95 lines
2.3 KiB
Go
95 lines
2.3 KiB
Go
package grpcserver
|
|
|
|
import (
|
|
"context"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/hex"
|
|
"sort"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"google.golang.org/grpc"
|
|
"google.golang.org/grpc/codes"
|
|
"google.golang.org/grpc/metadata"
|
|
"google.golang.org/grpc/status"
|
|
)
|
|
|
|
func HMACUnaryServerInterceptor(secret string) grpc.UnaryServerInterceptor {
|
|
return func(
|
|
ctx context.Context,
|
|
req any,
|
|
info *grpc.UnaryServerInfo,
|
|
handler grpc.UnaryHandler,
|
|
) (any, error) {
|
|
md, ok := metadata.FromIncomingContext(ctx)
|
|
if !ok {
|
|
return nil, status.Error(codes.Unauthenticated, "missing metadata")
|
|
}
|
|
|
|
serviceName := firstMD(md, "x-service-name")
|
|
timestamp := firstMD(md, "x-timestamp")
|
|
signature := firstMD(md, "x-signature")
|
|
if serviceName == "" || timestamp == "" || signature == "" {
|
|
return nil, status.Error(codes.Unauthenticated, "missing auth metadata")
|
|
}
|
|
|
|
ts, err := strconv.ParseInt(timestamp, 10, 64)
|
|
if err != nil {
|
|
return nil, status.Error(codes.Unauthenticated, "invalid timestamp")
|
|
}
|
|
if delta := time.Now().Unix() - ts; delta > 60 || delta < -60 {
|
|
return nil, status.Error(codes.Unauthenticated, "timestamp expired")
|
|
}
|
|
|
|
payload := buildSigningPayload(serviceName, info.FullMethod, timestamp, md)
|
|
expected := computeHMAC(payload, secret)
|
|
if !hmac.Equal([]byte(expected), []byte(signature)) {
|
|
return nil, status.Error(codes.Unauthenticated, "bad signature")
|
|
}
|
|
|
|
return handler(ctx, req)
|
|
}
|
|
}
|
|
|
|
func firstMD(md metadata.MD, key string) string {
|
|
values := md.Get(key)
|
|
if len(values) == 0 {
|
|
return ""
|
|
}
|
|
return values[0]
|
|
}
|
|
|
|
// computeHMAC returns the HMAC-SHA256 of payload keyed with secret, encoded
// as a lowercase hexadecimal string (64 characters for the 32-byte MAC).
func computeHMAC(payload, secret string) string {
	h := hmac.New(sha256.New, []byte(secret))
	// hash.Hash's Write is documented never to return an error.
	_, _ = h.Write([]byte(payload))
	digest := h.Sum(nil)
	return hex.EncodeToString(digest)
}
|
|
|
|
func buildSigningPayload(serviceName, method, timestamp string, md metadata.MD) string {
|
|
parts := []string{
|
|
"service=" + serviceName,
|
|
"method=" + method,
|
|
"timestamp=" + timestamp,
|
|
}
|
|
|
|
keys := make([]string, 0, len(md))
|
|
for key := range md {
|
|
lowerKey := strings.ToLower(key)
|
|
if lowerKey == "x-signature" || strings.HasPrefix(lowerKey, ":") {
|
|
continue
|
|
}
|
|
keys = append(keys, key)
|
|
}
|
|
sort.Strings(keys)
|
|
|
|
for _, key := range keys {
|
|
values := md.Get(key)
|
|
sort.Strings(values)
|
|
parts = append(parts, key+"="+strings.Join(values, ","))
|
|
}
|
|
|
|
return strings.Join(parts, "|")
|
|
}
|