diff --git a/goKMS/kms/peers/etsi14Quantummodule.go b/goKMS/kms/peers/etsi14Quantummodule.go index 11b21ee0cde4426e0e1878a0919092ac5d0fc06c..def13a49f184923dae1dfdaa5092958e1aeaefce 100644 --- a/goKMS/kms/peers/etsi14Quantummodule.go +++ b/goKMS/kms/peers/etsi14Quantummodule.go @@ -31,7 +31,7 @@ type ETSI014HTTPQuantumModule struct { keyFetchInterval int keyFetchAmount int64 maxKeyFillLevel uint64 - stopFetch chan struct{} + stopFetch context.CancelFunc } func NewETSI014HTTPQuantumModule(addr, kmsId, localSAEID, targetSAEID string, tlsConfig config.TLSConfig, master bool, keyFetchInterval int, keyFetchAmount int64, maxKeyFillLevel uint64) (*ETSI014HTTPQuantumModule, error) { @@ -107,7 +107,8 @@ func (qm *ETSI014HTTPQuantumModule) Client() *etsi14ClientImpl.ClientImpl { } func (qm *ETSI014HTTPQuantumModule) Initialize() error { - qm.stopFetch = make(chan struct{}, 0) + var ctx context.Context + ctx, qm.stopFetch = context.WithCancel(context.Background()) // start polling keys if qm.master { @@ -117,8 +118,14 @@ func (qm *ETSI014HTTPQuantumModule) Initialize() error { defer ticker.Stop() // immediately start with the ticker instead of waiting the defined amount - for ; true; <-ticker.C { - qm.doKeyFetching() + RestartFetchLoop: + for { + select { + case <-ticker.C: + qm.doKeyFetching(ctx) + case <-ctx.Done(): + break RestartFetchLoop + } } }() } @@ -127,7 +134,7 @@ func (qm *ETSI014HTTPQuantumModule) Initialize() error { func (qm *ETSI014HTTPQuantumModule) StopKeyFetching() { if qm.master { - close(qm.stopFetch) + qm.stopFetch() } } @@ -185,18 +192,19 @@ func (qm *ETSI014HTTPQuantumModule) GetKeyWithIds(keyIds []etsi14ClientGenerated return container, nil } -func (qm *ETSI014HTTPQuantumModule) doKeyFetching() { +func (qm *ETSI014HTTPQuantumModule) doKeyFetching(ctx context.Context) { ticker := time.NewTicker(time.Duration(qm.keyFetchInterval) * time.Second) defer ticker.Stop() failedAttemps := 0 +FetchLoop: for { select { case <-ticker.C: if failedAttemps == maxFailedKeyRequestAttempts { log.Errorf("stopped trying to fetch keys from qkd module after %d tries", failedAttemps) - break + break FetchLoop } if qm.keyStore.Length() < int(qm.maxKeyFillLevel) { @@ -233,9 +241,8 @@ func (qm *ETSI014HTTPQuantumModule) doKeyFetching() { failedAttemps = 0 } - case <-qm.stopFetch: - break + case <-ctx.Done(): + break FetchLoop } } - } diff --git a/goKMS/qkdnManager/server.go b/goKMS/qkdnManager/server.go index fe3582c4a1951ca73ce64b87f2a9a2819676dea5..3e365b4782406bf5210e3ccfe689c5192d3fb874 100644 --- a/goKMS/qkdnManager/server.go +++ b/goKMS/qkdnManager/server.go @@ -214,14 +214,10 @@ func (qs *QkdnManagerServer) handleSetKeyStore(w http.ResponseWriter, r *http.Re return } if fetch == "true" { - eqm.Initialize() - - w.WriteHeader(http.StatusOK) - _, err = w.Write([]byte("OK\n")) - if err != nil { - logrus.Error(err) + if err := eqm.Initialize(); err != nil { + http.Error(w, fmt.Sprintf("Failed to restart fetching for quantum module of peer: %s", peerID), http.StatusBadRequest) + return } - return } else if fetch == "false" { eqm.StopKeyFetching() eqm.KeyStore().Reset()