From c95035b1c6e34d6a33a98a78fe143d12086d1c71 Mon Sep 17 00:00:00 2001
From: Malte Bauch <malte.bauch@h-da.de>
Date: Fri, 16 Aug 2024 12:57:07 +0200
Subject: [PATCH] Stop fetching of keys and reset keystore through qkdn manager

---
 goKMS/kms/kms.go                       |  3 +
 goKMS/kms/peers/etsi14Quantummodule.go | 80 +++++++++++++++-----------
 goKMS/kms/store/kms-keystore.go        |  6 ++
 goKMS/qkdnManager/server.go            | 32 +++++++++++
 4 files changed, 88 insertions(+), 33 deletions(-)

diff --git a/goKMS/kms/kms.go b/goKMS/kms/kms.go
index c806d8e1..7ba5551d 100644
--- a/goKMS/kms/kms.go
+++ b/goKMS/kms/kms.go
@@ -498,6 +498,9 @@ func (kms *KMS) RemovePeer(kmsPeerSocket string) {
 }
 
 func (kms *KMS) FindPeerUuid(lookup uuid.UUID) (peer *peers.KmsPeer) {
+	kms.kmsPeersMutex.Lock()
+	defer kms.kmsPeersMutex.Unlock()
+
 	if kms.KmsPeers != nil {
 		for _, peer = range kms.KmsPeers {
 			if peer.GetKmsPeerId() == lookup {
diff --git a/goKMS/kms/peers/etsi14Quantummodule.go b/goKMS/kms/peers/etsi14Quantummodule.go
index 902272ed..6a7a037d 100644
--- a/goKMS/kms/peers/etsi14Quantummodule.go
+++ b/goKMS/kms/peers/etsi14Quantummodule.go
@@ -31,6 +31,7 @@ type ETSI014HTTPQuantumModule struct {
 	keyFetchInterval int
 	keyFetchAmount   int64
 	maxKeyFillLevel  uint64
+	stopFetch        chan struct{}
 }
 
 func NewETSI014HTTPQuantumModule(addr, kmsId, localSAEID, targetSAEID string, tlsConfig config.TLSConfig, master bool, keyFetchInterval int, keyFetchAmount int64, maxKeyFillLevel uint64) (*ETSI014HTTPQuantumModule, error) {
@@ -106,6 +107,8 @@ func (qm *ETSI014HTTPQuantumModule) Client() *etsi14ClientImpl.ClientImpl {
 }
 
 func (qm *ETSI014HTTPQuantumModule) Initialize() error {
+	qm.stopFetch = make(chan struct{}, 0)
+
 	// start polling keys
 	if qm.master {
 		go func() {
@@ -115,42 +118,47 @@ func (qm *ETSI014HTTPQuantumModule) Initialize() error {
 			failedAttemps := 0
 
 			// TODO: add context/channel to stop
-			for range ticker.C {
-				if failedAttemps == maxFailedKeyRequestAttempts {
-					log.Errorf("stopped trying to fetch keys from qkd module after %d tries", failedAttemps)
-					break
-				}
-
-				if qm.keyStore.Length() < int(qm.maxKeyFillLevel) {
-					container, err := qm.GetKeys(qm.keyFetchAmount, 256, nil, nil, nil)
-					if err != nil {
-						log.Error(err)
-						failedAttemps++
-						continue
-					}
-
-					keyIds := make([]string, len(container.GetKeys()))
-					for i, keyItem := range container.GetKeys() {
-						keyIds[i] = keyItem.GetKeyID()
-					}
-
-					_, err = qm.kmsClient.KeyIdNotification(context.Background(),
-						&pbIC.KeyIdNotificationRequest{
-							Timestamp: time.Now().Unix(),
-							KmsId:     qm.kmsId,
-							KeyIds:    keyIds,
-						})
-					if err != nil {
-						log.Error(err)
-						failedAttemps++
-						continue
+			for {
+				select {
+				case <-ticker.C:
+					if failedAttemps == maxFailedKeyRequestAttempts {
+						log.Errorf("stopped trying to fetch keys from qkd module after %d tries", failedAttemps)
+						break
 					}
 
-					if err := store.AddETSIKeysToKeystore(qm.keyStore, container.GetKeys()); err != nil {
-						log.Error(err)
+					if qm.keyStore.Length() < int(qm.maxKeyFillLevel) {
+						container, err := qm.GetKeys(qm.keyFetchAmount, 256, nil, nil, nil)
+						if err != nil {
+							log.Error(err)
+							failedAttemps++
+							continue
+						}
+
+						keyIds := make([]string, len(container.GetKeys()))
+						for i, keyItem := range container.GetKeys() {
+							keyIds[i] = keyItem.GetKeyID()
+						}
+
+						_, err = qm.kmsClient.KeyIdNotification(context.Background(),
+							&pbIC.KeyIdNotificationRequest{
+								Timestamp: time.Now().Unix(),
+								KmsId:     qm.kmsId,
+								KeyIds:    keyIds,
+							})
+						if err != nil {
+							log.Error(err)
+							failedAttemps++
+							continue
+						}
+
+						if err := store.AddETSIKeysToKeystore(qm.keyStore, container.GetKeys()); err != nil {
+							log.Error(err)
+						}
+
+						failedAttemps = 0
 					}
-
-					failedAttemps = 0
+				case <-qm.stopFetch:
+					break
 				}
 			}
 		}()
@@ -158,6 +166,12 @@ func (qm *ETSI014HTTPQuantumModule) Initialize() error {
 	return nil
 }
 
+func (qm *ETSI014HTTPQuantumModule) StopKeyFetching() {
+	if qm.master {
+		close(qm.stopFetch)
+	}
+}
+
 func (qm *ETSI014HTTPQuantumModule) SetKmsPeerInformation(kmsClient *GRPCClient, kmsEventBus *event.EventBus, kmsTcpSocketStr string) error {
 	qm.kmsClient = kmsClient
 	return nil
diff --git a/goKMS/kms/store/kms-keystore.go b/goKMS/kms/store/kms-keystore.go
index 9a4334b1..289a17dd 100644
--- a/goKMS/kms/store/kms-keystore.go
+++ b/goKMS/kms/store/kms-keystore.go
@@ -102,6 +102,12 @@ func (ks *KmsKeyStore) DeleteKey(keyId uuid.UUID) {
 	delete(ks.keyStore, keyId)
 }
 
+func (ks *KmsKeyStore) Reset() {
+	ks.keyStoreMutex.Lock()
+	defer ks.keyStoreMutex.Unlock()
+	ks.keyStore = make(map[uuid.UUID]*KmsKSElement)
+}
+
 func AddETSIKeysToKeystore(keyStore *KmsKeyStore, keyContainer []etsi14.KeyContainerKeysInner) error {
 	for _, keyItem := range keyContainer {
 		// decode base64 encoded key string
diff --git a/goKMS/qkdnManager/server.go b/goKMS/qkdnManager/server.go
index 1cbf5c98..fe3582c4 100644
--- a/goKMS/qkdnManager/server.go
+++ b/goKMS/qkdnManager/server.go
@@ -9,6 +9,7 @@ import (
 
 	"code.fbi.h-da.de/danet/quant/goKMS/config"
 	"code.fbi.h-da.de/danet/quant/goKMS/kms"
+	"code.fbi.h-da.de/danet/quant/goKMS/kms/peers"
 	"github.com/google/uuid"
 	"github.com/sirupsen/logrus"
 )
@@ -196,6 +197,37 @@ func (qs *QkdnManagerServer) handleSetKeyStore(w http.ResponseWriter, r *http.Re
 
 	logrus.Debugf("KeyFillLevel: %s, PeerIDs: %v, Fetch: %s", keyFillLevel, peerIDs, fetch)
 
+	for _, peerID := range peerIDs {
+		peerUUID, err := uuid.Parse(peerID)
+		if err != nil {
+			http.Error(w, err.Error(), http.StatusBadRequest)
+			return
+		}
+		peer := qs.kms.FindPeerUuid(peerUUID)
+		if peer == nil {
+			http.Error(w, fmt.Sprintf("No peer for ID: %s found", peerID), http.StatusBadRequest)
+			return
+		}
+		eqm, ok := peer.QuantumModule().(*peers.ETSI014HTTPQuantumModule)
+		if !ok {
+			http.Error(w, fmt.Sprintf("QuantumModule is not of Type ETSI014"), http.StatusBadRequest)
+			return
+		}
+		if fetch == "true" {
+			eqm.Initialize()
+
+			w.WriteHeader(http.StatusOK)
+			_, err = w.Write([]byte("OK\n"))
+			if err != nil {
+				logrus.Error(err)
+			}
+			return
+		} else if fetch == "false" {
+			eqm.StopKeyFetching()
+			eqm.KeyStore().Reset()
+		}
+	}
+
 	w.WriteHeader(http.StatusOK)
 	_, err = w.Write([]byte("OK\n"))
 	if err != nil {
-- 
GitLab