cab-backend/internal/grpcserver/auth.go
2026-03-30 21:00:35 +03:00

95 lines
2.3 KiB
Go

package grpcserver
import (
"context"
"crypto/hmac"
"crypto/sha256"
"encoding/hex"
"sort"
"strconv"
"strings"
"time"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// HMACUnaryServerInterceptor returns a unary server interceptor that
// authenticates every call with an HMAC-SHA256 signature carried in the
// incoming gRPC metadata.
//
// Callers must send three metadata entries:
//   - x-service-name: the calling service's name,
//   - x-timestamp: Unix time in seconds as a base-10 integer,
//   - x-signature: the hex-encoded HMAC-SHA256, keyed with secret, of the
//     payload produced by buildSigningPayload for that service name, the
//     full method name, the timestamp and the request metadata.
//
// A call is rejected with codes.Unauthenticated when metadata or any of those
// entries is missing, when the timestamp is not an integer or is more than
// maxClockSkew seconds away from the server clock (in either direction), or
// when the signature does not match. Otherwise the handler runs unchanged.
//
// An empty secret would let any caller forge signatures. In that case the
// returned interceptor fails closed and rejects every call with codes.Internal.
//
// NOTE(review): only metadata is signed (not the request message), and no
// nonce is tracked. A captured request can therefore be replayed, possibly
// with a different body, within the skew window.
func HMACUnaryServerInterceptor(secret string) grpc.UnaryServerInterceptor {
	// maxClockSkew is the largest accepted |now - x-timestamp|, in seconds.
	const maxClockSkew = 60

	if secret == "" {
		// Fail closed: an unkeyed HMAC provides no authentication at all.
		return func(context.Context, any, *grpc.UnaryServerInfo, grpc.UnaryHandler) (any, error) {
			return nil, status.Error(codes.Internal, "hmac secret not configured")
		}
	}

	return func(
		ctx context.Context,
		req any,
		info *grpc.UnaryServerInfo,
		handler grpc.UnaryHandler,
	) (any, error) {
		md, ok := metadata.FromIncomingContext(ctx)
		if !ok {
			return nil, status.Error(codes.Unauthenticated, "missing metadata")
		}

		serviceName := firstMD(md, "x-service-name")
		timestamp := firstMD(md, "x-timestamp")
		signature := firstMD(md, "x-signature")
		if serviceName == "" || timestamp == "" || signature == "" {
			return nil, status.Error(codes.Unauthenticated, "missing auth metadata")
		}

		ts, err := strconv.ParseInt(timestamp, 10, 64)
		if err != nil {
			return nil, status.Error(codes.Unauthenticated, "invalid timestamp")
		}
		// Reject both stale and future-dated timestamps outside the skew window.
		if delta := time.Now().Unix() - ts; delta > maxClockSkew || delta < -maxClockSkew {
			return nil, status.Error(codes.Unauthenticated, "timestamp expired")
		}

		// The signed timestamp is the raw string as sent, not the parsed value,
		// so the client and server payloads agree byte-for-byte.
		payload := buildSigningPayload(serviceName, info.FullMethod, timestamp, md)
		expected := computeHMAC(payload, secret)
		// hmac.Equal compares in constant time, so a mismatch does not leak
		// how many leading bytes were correct.
		if !hmac.Equal([]byte(expected), []byte(signature)) {
			return nil, status.Error(codes.Unauthenticated, "bad signature")
		}

		return handler(ctx, req)
	}
}
// firstMD returns the first value stored under key in md. It returns "" when
// the key is absent or has no values. The lookup goes through md.Get, which
// lower-cases key before searching.
func firstMD(md metadata.MD, key string) string {
	if vals := md.Get(key); len(vals) > 0 {
		return vals[0]
	}
	return ""
}
// computeHMAC returns the lowercase hex encoding of HMAC-SHA256 over payload,
// keyed with secret.
func computeHMAC(payload, secret string) string {
	key := []byte(secret)
	h := hmac.New(sha256.New, key)
	// hash.Hash.Write is documented never to return an error, so the result
	// is deliberately discarded.
	_, _ = h.Write([]byte(payload))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
// buildSigningPayload returns the canonical string covered by the request
// signature. It has the form
//
//	service=<serviceName>|method=<method>|timestamp=<timestamp>|<k1>=<v,...>|<k2>=...
//
// After the three fixed fields comes one "key=values" entry per metadata key,
// with keys sorted and each key's values sorted and comma-joined. The
// x-signature key and HTTP/2 pseudo-headers (keys starting with ':') are
// excluded. md is not modified.
//
// NOTE(review): keys and values are not escaped, so metadata containing ',',
// '=' or '|' can yield the same payload as different metadata. Changing the
// format requires updating every signing client at the same time.
// NOTE(review): server-side incoming metadata may include headers added by
// the transport (e.g. user-agent); confirm clients sign the same key set.
func buildSigningPayload(serviceName, method, timestamp string, md metadata.MD) string {
	parts := make([]string, 0, 3+len(md))
	parts = append(parts,
		"service="+serviceName,
		"method="+method,
		"timestamp="+timestamp,
	)

	keys := make([]string, 0, len(md))
	for key := range md {
		lowerKey := strings.ToLower(key)
		if lowerKey == "x-signature" || strings.HasPrefix(lowerKey, ":") {
			continue
		}
		keys = append(keys, key)
	}
	// Map iteration order is random; sort so the payload is deterministic.
	sort.Strings(keys)

	for _, key := range keys {
		// md.Get returns the map's own slice. Copy it before sorting so the
		// caller's metadata values are not reordered as a side effect.
		values := append([]string(nil), md.Get(key)...)
		sort.Strings(values)
		parts = append(parts, key+"="+strings.Join(values, ","))
	}
	return strings.Join(parts, "|")
}