package grpcserver import ( "context" "crypto/hmac" "crypto/sha256" "encoding/hex" "sort" "strconv" "strings" "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) func HMACUnaryServerInterceptor(secret string) grpc.UnaryServerInterceptor { return func( ctx context.Context, req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, ) (any, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Error(codes.Unauthenticated, "missing metadata") } serviceName := firstMD(md, "x-service-name") timestamp := firstMD(md, "x-timestamp") signature := firstMD(md, "x-signature") if serviceName == "" || timestamp == "" || signature == "" { return nil, status.Error(codes.Unauthenticated, "missing auth metadata") } ts, err := strconv.ParseInt(timestamp, 10, 64) if err != nil { return nil, status.Error(codes.Unauthenticated, "invalid timestamp") } if delta := time.Now().Unix() - ts; delta > 60 || delta < -60 { return nil, status.Error(codes.Unauthenticated, "timestamp expired") } payload := buildSigningPayload(serviceName, info.FullMethod, timestamp, md) expected := computeHMAC(payload, secret) if !hmac.Equal([]byte(expected), []byte(signature)) { return nil, status.Error(codes.Unauthenticated, "bad signature") } return handler(ctx, req) } } func firstMD(md metadata.MD, key string) string { values := md.Get(key) if len(values) == 0 { return "" } return values[0] } func computeHMAC(payload, secret string) string { mac := hmac.New(sha256.New, []byte(secret)) mac.Write([]byte(payload)) return hex.EncodeToString(mac.Sum(nil)) } func buildSigningPayload(serviceName, method, timestamp string, md metadata.MD) string { parts := []string{ "service=" + serviceName, "method=" + method, "timestamp=" + timestamp, } keys := make([]string, 0, len(md)) for key := range md { lowerKey := strings.ToLower(key) if lowerKey == "x-signature" || strings.HasPrefix(lowerKey, ":") { continue } keys = append(keys, key) } sort.Strings(keys) for _, key := range keys { values := md.Get(key) sort.Strings(values) parts = append(parts, key+"="+strings.Join(values, ",")) } return strings.Join(parts, "|") }