diff --git a/nucleus/controller.go b/nucleus/controller.go index fff5b4cc8c3c4b272e79853f61c68e57903ce22e..bb928f4ae44ba399da1b7e3262e7700380eb546d 100644 --- a/nucleus/controller.go +++ b/nucleus/controller.go @@ -12,6 +12,9 @@ import ( "time" ) +var coreLock sync.RWMutex +var coreOnce sync.Once + // Core is the representation of the controllers core type Core struct { // deprecated @@ -84,8 +87,7 @@ func createPrincipalNetworkDomain(sbi SouthboundInterface) error { // Run calls initialize to start the controller func Run(ctx context.Context) error { var initError error - var once sync.Once - once.Do(func() { + coreOnce.Do(func() { initError = initialize() }) if initError != nil { diff --git a/nucleus/controller_test.go b/nucleus/controller_test.go index 4608dc56800cfeca27a9e76c5c6e25dd9f0d8da1..c414d6c61dc1e538b494dd53df65d7b6f10f288d 100644 --- a/nucleus/controller_test.go +++ b/nucleus/controller_test.go @@ -15,43 +15,55 @@ func TestRun(t *testing.T) { name string args args want interface{} - wantErr bool }{ { name: "liveliness indicator", args: args{request: apiEndpoint + "/livez"}, want: http.StatusOK, - wantErr: false, }, { name: "readyness indicator", args: args{request: apiEndpoint + "/readyz"}, want: http.StatusOK, - wantErr: false, }, { name: "init", args: args{request: apiEndpoint + "/api?q=init"}, want: http.StatusOK, - wantErr: false, }, } - for _, tt := range tests { - ctx, cancel := context.WithCancel(context.Background()) - go func() { - if err := Run(ctx); (err != nil) != tt.wantErr { - t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) - } - }() - t.Run(tt.name, func(t *testing.T) { - got, err := http.Get(tt.args.request) - if err != nil { - t.Error(err) - } - if !reflect.DeepEqual(got.StatusCode, tt.want) { - t.Errorf("Run() got: %v, want %v", got.StatusCode, tt.want) - } - }) - cancel() - } + ctx, cancel := context.WithCancel(context.Background()) + go func() { + if err := Run(ctx); err != nil{ + t.Errorf("Run() error = %v", err) + } + }() + t.Run("Controller Start HTTP API", func(t *testing.T) { + got, err := http.Get(tests[0].args.request) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(got.StatusCode, tests[0].want) { + t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[0].want) + } + got, err = http.Get(tests[0].args.request) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(got.StatusCode, tests[1].want) { + t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[1].want) + } + got, err = http.Get(tests[0].args.request) + if err != nil { + t.Error(err) + return + } + if !reflect.DeepEqual(got.StatusCode, tests[2].want) { + t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[2].want) + } + }) + + cancel() } diff --git a/nucleus/http.go b/nucleus/http.go index dc2df371c62f1392cbab26516b4d47662152c087..41fea9aaa78245df4bf991451beb65743dc5b127 100644 --- a/nucleus/http.go +++ b/nucleus/http.go @@ -9,9 +9,12 @@ import ( log "github.com/sirupsen/logrus" "net/http" "net/url" + "sync" "time" ) +var httpOnce sync.Once + func stopHttpServer() error { ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -32,7 +35,9 @@ func registerHttpHandler() { // deprecated func httpAPI() (err error) { - registerHttpHandler() + coreLock.Lock() + defer coreLock.Unlock() + httpOnce.Do(registerHttpHandler) c.httpServer = &http.Server{Addr: ":8080"} go func() { err = c.httpServer.ListenAndServe() diff --git a/nucleus/store.go b/nucleus/store.go index 14d09d335f9b7bc1938e7c68dfd0f5e1dd22a0ea..a1fbdab764bd7fcb750b0e60ad309211f9caff2c 100644 --- a/nucleus/store.go +++ b/nucleus/store.go @@ -4,8 +4,11 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" "reflect" + "sync" ) +var storeLock sync.RWMutex + // Storable provides an interface for the controller's storage architecture. type Storable interface { ID() uuid.UUID @@ -14,6 +17,8 @@ type Storable interface { type store map[uuid.UUID]Storable func (s store) exists(id uuid.UUID) bool { + storeLock.RLock() + defer storeLock.RUnlock() _, ok := s[id] return ok } @@ -22,7 +27,9 @@ func (s store) add(item Storable) error { if s.exists(item.ID()) { return &ErrAlreadyExists{item: item} } + storeLock.Lock() s[item.ID()] = item + storeLock.Unlock() log.WithFields(log.Fields{ "type": reflect.TypeOf(item), "uuid": item.ID(), @@ -37,6 +44,8 @@ func (s store) get(id uuid.UUID) (Storable, error) { log.WithFields(log.Fields{ "uuid": id, }).Info("storable was accessed") + storeLock.RLock() + defer storeLock.RUnlock() return s[id], nil } @@ -44,7 +53,9 @@ func (s store) delete(id uuid.UUID) error { if !s.exists(id) { return &ErrNotFound{id: id} } + storeLock.Lock() delete(s, id) + storeLock.Unlock() log.WithFields(log.Fields{ "uuid": id, }).Info("storable has been deleted") @@ -52,8 +63,9 @@ func (s store) delete(id uuid.UUID) error { } func (s store) UUIDs() []uuid.UUID { + storeLock.RLock() + defer storeLock.RUnlock() keys := make([]uuid.UUID, len(s)) - i := 0 for k := range s { keys[i] = k