diff --git a/goKMS/kms/kms.go b/goKMS/kms/kms.go index f6fe6388237693f3047978c32e8ed185b6572e08..83394d6b219b5f27333b003fe694a4d31f210c21 100644 --- a/goKMS/kms/kms.go +++ b/goKMS/kms/kms.go @@ -15,8 +15,6 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/health" healthpb "google.golang.org/grpc/health/grpc_health_v1" "google.golang.org/grpc/metadata" @@ -182,15 +180,9 @@ func (kms *KMS) initializePeers(config *config.Config) error { return nil } - var gRPCTransportCreds credentials.TransportCredentials - if config.KmsTLS.TLS { - gRPCTransportCreds, err = kmstls.GenerateGRPCClientTransportCredsWithTLS(kms.tlsConfig) - if err != nil { - log.Error(err) - return nil - } - } else { - gRPCTransportCreds = insecure.NewCredentials() + gRPCTransportCreds, err := kmstls.GenerateGRPCClientTransportCredsBasedOnTLSFlag(kms.tlsConfig) + if err != nil { + return fmt.Errorf("unable to generate gRPC transport creds: %w", err) } newPeerConn, err := grpc.Dial(peer.PeerInterComAddr, grpc.WithTransportCredentials(gRPCTransportCreds)) @@ -232,14 +224,9 @@ func (kms *KMS) startGRPC() { log.Fatalf("failed to listen: %v", err) } - var gRPCTransportCreds credentials.TransportCredentials - if kms.tlsConfig.TLS { - gRPCTransportCreds, err = kmstls.GenerateGRPCServerTransportCredsWithTLS(kms.tlsConfig) - if err != nil { - log.Fatalf("unable to generate TLS creds: %v", err) - } - } else { - gRPCTransportCreds = insecure.NewCredentials() + gRPCTransportCreds, err := kmstls.GenerateGRPCServerTransportCredsBasedOnTLSFlag(kms.tlsConfig) + if err != nil { + log.Fatalf("unable to generate gRPC server: %v", err) } interKMSServer := grpc.NewServer(grpc.Creds(gRPCTransportCreds)) @@ -489,14 +476,9 @@ func (kms *KMS) GenerateAndSendKSAKey(remoteKMSId string, pathId uuid.UUID, requ // TODO: move this somewhere else! // send to remote - var gRPCTransportCreds credentials.TransportCredentials - if kms.tlsConfig.TLS { - gRPCTransportCreds, err = kmstls.GenerateGRPCClientTransportCredsWithTLS(kms.tlsConfig) - if err != nil { - log.Fatalf("unable to generate TLS creds: %v", err) - } - } else { - gRPCTransportCreds = insecure.NewCredentials() + gRPCTransportCreds, err := kmstls.GenerateGRPCClientTransportCredsBasedOnTLSFlag(kms.tlsConfig) + if err != nil { + return fmt.Errorf("unable to generate gRPC transport creds: %w", err) } remoteConn, err := grpc.Dial(remoteKMS.Address, grpc.WithTransportCredentials(gRPCTransportCreds)) diff --git a/goKMS/kms/kmsintercom.go b/goKMS/kms/kmsintercom.go index 73a1eb5a5964131d37de02594f7ec72cfde6b898..9868f18450fa06da36bada3db34ed5bc7954507b 100644 --- a/goKMS/kms/kmsintercom.go +++ b/goKMS/kms/kmsintercom.go @@ -3,6 +3,7 @@ package kms import ( "context" "encoding/base64" + "fmt" "io" "time" @@ -19,8 +20,6 @@ import ( kmstls "code.fbi.h-da.de/danet/quant/goKMS/kms/tls" "google.golang.org/grpc" "google.golang.org/grpc/codes" - "google.golang.org/grpc/credentials" - "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/peer" "google.golang.org/grpc/status" ) @@ -174,7 +173,7 @@ func (s *kmsTalkerServer) SyncKeyIdsForBulk(ctx context.Context, in *pb.SyncKeyI }, nil } -func (s *kmsTalkerServer) InterComTransportKeyNegotiation(ctx context.Context, in *pb.InterComTransportKeyNegotiationRequest) (capReply *pb.InterComTransportKeyNegotiationResponse, err error) { +func (s *kmsTalkerServer) InterComTransportKeyNegotiation(ctx context.Context, in *pb.InterComTransportKeyNegotiationRequest) (*pb.InterComTransportKeyNegotiationResponse, error) { // NOTE: For the current prototype it is assumed that a negotiation request // is always valid. In the future an incoming negotiation request should // also check for a suitable forwarding assignment from the controller. @@ -213,7 +212,7 @@ func (s *kmsTalkerServer) InterComTransportKeyNegotiation(ctx context.Context, i return &pb.InterComTransportKeyNegotiationResponse{Timestamp: time.Now().Unix()}, nil } -func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardingRequest) (capReply *pb.KeyForwardingResponse, err error) { +func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardingRequest) (*pb.KeyForwardingResponse, error) { pathId, err := uuid.Parse(in.GetPathId()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "") @@ -224,9 +223,9 @@ func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardin return nil, status.Errorf(codes.InvalidArgument, "") } - decryptKey, ok := s.keyNegotiationMap[pathId] - if !ok { - return nil, status.Errorf(codes.InvalidArgument, "") + keyID, err := uuid.Parse(in.GetKey().GetId()) + if err != nil { + return nil, status.Errorf(codes.Internal, "%s", err) } route, ok := s.KMS.routingTable[pathId] @@ -236,27 +235,19 @@ func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardin log.Infof("%s received a key: %s, from %s", s.KMS.kmsName, in.GetKey(), route.Previous.TcpSocketStr) - keyAsByte, err := base64.StdEncoding.DecodeString(in.GetKey().GetKey()) - if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) - } - - nonceAsByte, err := base64.StdEncoding.DecodeString(in.GetKey().GetNonce()) - if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) + decryptKey, ok := s.keyNegotiationMap[pathId] + if !ok { + return nil, status.Errorf(codes.InvalidArgument, "") } - decryptedKey, err := route.Previous.CryptoAlgo().Decrypt(nonceAsByte, keyAsByte, decryptKey.Key) + decryptedKey, err := s.getDecryptedKey(decryptKey.Key, route.Previous.CryptoAlgo(), in.GetKey()) if err != nil { return nil, status.Errorf(codes.Internal, "%s", err) } - keyID, err := uuid.Parse(in.GetKey().GetId()) if route.Next != nil { log.Infof("%s forwards payload to : %s", s.KMS.kmsName, route.Next.TcpSocketStr) - if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) - } + // TODO: Find a better way of handling this; ignore the lint error for // now. go route.Next.SendPayload(&crypto.Key{ //nolint:errcheck @@ -265,51 +256,10 @@ func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardin }, pathId, processId) } else { log.Infof("%s received the final payload: %s", s.KMS.kmsName, string(decryptedKey)) - s.KMS.PKStoreMutex.Lock() - keys, ok := s.KMS.PKStore[route.RemoteKMS.Id] - if !ok { - s.KMS.PKStore[route.RemoteKMS.Id] = map[uuid.UUID]*PlatformKey{ - keyID: { - Id: keyID, - Value: decryptedKey, - ProcessId: in.GetProcessId(), - }, - } - } else { - keys[keyID] = &PlatformKey{ - Id: keyID, - Value: decryptedKey, - ProcessId: in.GetProcessId(), - } - } - log.Debug("Current PKSTORE: ", s.KMS.PKStore) - s.KMS.PKStoreMutex.Unlock() + s.storeReceivedPlatformKey(route.RemoteKMS.Id, in.GetProcessId(), keyID, decryptedKey) - var gRPCTransportCreds credentials.TransportCredentials - if s.KMS.tlsConfig.TLS { - gRPCTransportCreds, err = kmstls.GenerateGRPCClientTransportCredsWithTLS(s.KMS.tlsConfig) - if err != nil { - log.Fatalf("unable to generate TLS creds: %v", err) - } - } else { - gRPCTransportCreds = insecure.NewCredentials() - } - - newPeerConn, err := grpc.Dial(route.RemoteKMS.Address, grpc.WithTransportCredentials(gRPCTransportCreds)) - if err != nil { - return nil, err - } - - // inform about successful key forwarding - client := pb.NewKmsTalkerClient(newPeerConn) - - _, err = client.AckKeyForwarding(ctx, &pb.AckKeyForwardingRequest{ - Timestamp: time.Now().Unix(), - PathId: in.PathId, - ProcessId: in.ProcessId, - KeyId: keyID.String(), - }) + err = s.sendAcknowledgeKeyForwarding(ctx, route.RemoteKMS.Address, in.PathId, in.ProcessId, in.GetKey().GetId()) if err != nil { return nil, err } @@ -318,7 +268,7 @@ func (s *kmsTalkerServer) KeyForwarding(ctx context.Context, in *pb.KeyForwardin return &pb.KeyForwardingResponse{Timestamp: time.Now().Unix()}, nil } -func (s *kmsTalkerServer) AckKeyForwarding(ctx context.Context, in *pb.AckKeyForwardingRequest) (capReply *pb.AckKeyForwardingResponse, err error) { +func (s *kmsTalkerServer) AckKeyForwarding(ctx context.Context, in *pb.AckKeyForwardingRequest) (*pb.AckKeyForwardingResponse, error) { pathId, err := uuid.Parse(in.GetPathId()) if err != nil { return nil, status.Errorf(codes.InvalidArgument, "") @@ -353,17 +303,9 @@ func (s *kmsTalkerServer) KeyDelivery(ctx context.Context, in *pb.KeyDeliveryReq // decrypt keys akmsKSAKeys := make([]client.KSAKey, len(in.Key)) for i, key := range in.Key { - encryptedKeyAsByte, err := base64.StdEncoding.DecodeString(key.GetKey()) - if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) - } - nonceAsByte, err := base64.StdEncoding.DecodeString(key.GetNonce()) - if err != nil { - return nil, status.Errorf(codes.Internal, "%s", err) - } // decrypt the key cryptoAlgo := crypto.NewAES() - decryptedKSAKey, err := cryptoAlgo.Decrypt(nonceAsByte, encryptedKeyAsByte, pk.Value) + decryptedKSAKey, err := s.getDecryptedKey(pk.Value, cryptoAlgo, key) if err != nil { return nil, status.Errorf(codes.Internal, "%s", err) } @@ -386,3 +328,73 @@ func (s *kmsTalkerServer) KeyDelivery(ctx context.Context, in *pb.KeyDeliveryReq return &pb.KeyDeliveryResponse{Timestamp: time.Now().Unix()}, nil } + +func (s *kmsTalkerServer) getDecryptedKey(keyForDecryption []byte, cryptoAlgorithm crypto.CryptoAlgorithm, encryptedKey *pb.Key) ([]byte, error) { + keyAsByte, err := base64.StdEncoding.DecodeString(encryptedKey.GetKey()) + if err != nil { + return nil, err + } + + nonceAsByte, err := base64.StdEncoding.DecodeString(encryptedKey.GetNonce()) + if err != nil { + return nil, err + } + + decryptedKey, err := cryptoAlgorithm.Decrypt(nonceAsByte, keyAsByte, keyForDecryption) + if err != nil { + return nil, err + } + + return decryptedKey, nil +} + +func (s *kmsTalkerServer) storeReceivedPlatformKey(remoteKmsID, processID string, keyID uuid.UUID, decryptedKey []byte) { + s.KMS.PKStoreMutex.Lock() + defer s.KMS.PKStoreMutex.Unlock() + + keys, ok := s.KMS.PKStore[remoteKmsID] + if !ok { + s.KMS.PKStore[remoteKmsID] = map[uuid.UUID]*PlatformKey{ + keyID: { + Id: keyID, + Value: decryptedKey, + ProcessId: processID, + }, + } + } else { + keys[keyID] = &PlatformKey{ + Id: keyID, + Value: decryptedKey, + ProcessId: processID, + } + } + + log.Debug("Current PKSTORE: ", s.KMS.PKStore) +} + +func (s *kmsTalkerServer) sendAcknowledgeKeyForwarding(ctx context.Context, remoteKmsAddr, pathID, processID, keyID string) error { + gRPCTransportCreds, err := kmstls.GenerateGRPCClientTransportCredsBasedOnTLSFlag(s.KMS.tlsConfig) + if err != nil { + return fmt.Errorf("unable to generate gRPC transport creds: %w", err) + } + + newPeerConn, err := grpc.Dial(remoteKmsAddr, grpc.WithTransportCredentials(gRPCTransportCreds)) + if err != nil { + return err + } + + // inform about successful key forwarding + client := pb.NewKmsTalkerClient(newPeerConn) + + _, err = client.AckKeyForwarding(ctx, &pb.AckKeyForwardingRequest{ + Timestamp: time.Now().Unix(), + PathId: pathID, + ProcessId: processID, + KeyId: keyID, + }) + if err != nil { + return err + } + + return nil +} diff --git a/goKMS/kms/peers/qmodule.go b/goKMS/kms/peers/qmodule.go index 3452094f716a0e5bffdb41a4ebd88f939fd4a6bf..373286fad636867e3198df8a590eae0ae4d5e31a 100644 --- a/goKMS/kms/peers/qmodule.go +++ b/goKMS/kms/peers/qmodule.go @@ -248,7 +248,7 @@ func NewETSI014HTTPQuantumModule(addr, kmsId, slaveSAEID, masterSAEID string, tl if tlsConfig.TLS { tlsConf, err := kmstls.GenerateTlsLibraryConfig(tlsConfig) if err != nil { - return nil, err + return nil, fmt.Errorf("unable to generate TLS config: %w", err) } restClientConf.HTTPClient = &http.Client{ diff --git a/goKMS/kms/tls/tls.go b/goKMS/kms/tls/tls.go index a7de18613d8f411e53b1cc8fa2ebb1baa9532be3..df180446acf23f767906adeb4093f5f03923ec9b 100644 --- a/goKMS/kms/tls/tls.go +++ b/goKMS/kms/tls/tls.go @@ -8,11 +8,28 @@ import ( "code.fbi.h-da.de/danet/quant/goKMS/config" "google.golang.org/grpc/credentials" + "google.golang.org/grpc/credentials/insecure" ) -func GenerateGRPCServerTransportCredsWithTLS(tlsData config.TLSConfig) (credentials.TransportCredentials, error) { +func GenerateGRPCServerTransportCredsBasedOnTLSFlag(tlsData config.TLSConfig) (credentials.TransportCredentials, error) { + var gRPCTransportCreds credentials.TransportCredentials + if tlsData.TLS { + creds, err := generateGRPCServerTransportCredsWithTLS(tlsData.CAFile, tlsData.CertFile, tlsData.KeyFile) + if err != nil { + return nil, err + } + + gRPCTransportCreds = creds + } else { + gRPCTransportCreds = insecure.NewCredentials() + } + + return gRPCTransportCreds, nil +} + +func generateGRPCServerTransportCredsWithTLS(caFile, certFile, keyFile string) (credentials.TransportCredentials, error) { cp := x509.NewCertPool() - b, err := os.ReadFile(tlsData.CAFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err } @@ -21,7 +38,7 @@ func GenerateGRPCServerTransportCredsWithTLS(tlsData config.TLSConfig) (credenti return nil, fmt.Errorf("credentials: failed to append certificates") } - cert, err := tls.LoadX509KeyPair(tlsData.CertFile, tlsData.KeyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } @@ -36,10 +53,26 @@ func GenerateGRPCServerTransportCredsWithTLS(tlsData config.TLSConfig) (credenti return credentials.NewTLS(tlsConfig), nil } -func GenerateGRPCClientTransportCredsWithTLS(tlsData config.TLSConfig) (credentials.TransportCredentials, error) { +func GenerateGRPCClientTransportCredsBasedOnTLSFlag(tlsConfig config.TLSConfig) (credentials.TransportCredentials, error) { + var gRPCTransportCreds credentials.TransportCredentials + if tlsConfig.TLS { + creds, err := generateGRPCClientTransportCredsWithTLS(tlsConfig.CAFile, tlsConfig.CertFile, tlsConfig.KeyFile) + if err != nil { + return nil, err + } + + gRPCTransportCreds = creds + } else { + gRPCTransportCreds = insecure.NewCredentials() + } + + return gRPCTransportCreds, nil +} + +func generateGRPCClientTransportCredsWithTLS(caFile, certFile, keyFile string) (credentials.TransportCredentials, error) { cp := x509.NewCertPool() - b, err := os.ReadFile(tlsData.CAFile) + b, err := os.ReadFile(caFile) if err != nil { return nil, err } @@ -47,7 +80,7 @@ func GenerateGRPCClientTransportCredsWithTLS(tlsData config.TLSConfig) (credenti return nil, fmt.Errorf("credentials: failed to append certificates") } - cert, err := tls.LoadX509KeyPair(tlsData.CertFile, tlsData.KeyFile) + cert, err := tls.LoadX509KeyPair(certFile, keyFile) if err != nil { return nil, err } @@ -61,8 +94,8 @@ func GenerateGRPCClientTransportCredsWithTLS(tlsData config.TLSConfig) (credenti return credentials.NewTLS(tlsConfig), nil } -func GenerateTlsLibraryConfig(tlsData config.TLSConfig) (*tls.Config, error) { - caCert, err := os.ReadFile(tlsData.CAFile) +func GenerateTlsLibraryConfig(tlsConfig config.TLSConfig) (*tls.Config, error) { + caCert, err := os.ReadFile(tlsConfig.CAFile) if err != nil { return nil, err } @@ -71,7 +104,7 @@ func GenerateTlsLibraryConfig(tlsData config.TLSConfig) (*tls.Config, error) { return nil, fmt.Errorf("credentials: failed to append certificates") } - cert, err := tls.LoadX509KeyPair(tlsData.CertFile, tlsData.KeyFile) + cert, err := tls.LoadX509KeyPair(tlsConfig.CertFile, tlsConfig.KeyFile) if err != nil { return nil, err }