Skip to content
Snippets Groups Projects
Commit 290ccc06 authored by Manuel Kieweg's avatar Manuel Kieweg
Browse files

fix race conditions

parent ab86aeba
No related branches found
No related tags found
2 merge requests!121Resolve "Data Races in Store",!90Develop
Pipeline #67370 passed with warnings
...@@ -12,6 +12,9 @@ import ( ...@@ -12,6 +12,9 @@ import (
"time" "time"
) )
var coreLock sync.RWMutex
var coreOnce sync.Once
// Core is the representation of the controllers core // Core is the representation of the controllers core
type Core struct { type Core struct {
// deprecated // deprecated
...@@ -84,8 +87,7 @@ func createPrincipalNetworkDomain(sbi SouthboundInterface) error { ...@@ -84,8 +87,7 @@ func createPrincipalNetworkDomain(sbi SouthboundInterface) error {
// Run calls initialize to start the controller // Run calls initialize to start the controller
func Run(ctx context.Context) error { func Run(ctx context.Context) error {
var initError error var initError error
var once sync.Once coreOnce.Do(func() {
once.Do(func() {
initError = initialize() initError = initialize()
}) })
if initError != nil { if initError != nil {
......
...@@ -15,43 +15,55 @@ func TestRun(t *testing.T) { ...@@ -15,43 +15,55 @@ func TestRun(t *testing.T) {
name string name string
args args args args
want interface{} want interface{}
wantErr bool
}{ }{
{ {
name: "liveliness indicator", name: "liveliness indicator",
args: args{request: apiEndpoint + "/livez"}, args: args{request: apiEndpoint + "/livez"},
want: http.StatusOK, want: http.StatusOK,
wantErr: false,
}, },
{ {
name: "readyness indicator", name: "readyness indicator",
args: args{request: apiEndpoint + "/readyz"}, args: args{request: apiEndpoint + "/readyz"},
want: http.StatusOK, want: http.StatusOK,
wantErr: false,
}, },
{ {
name: "init", name: "init",
args: args{request: apiEndpoint + "/api?q=init"}, args: args{request: apiEndpoint + "/api?q=init"},
want: http.StatusOK, want: http.StatusOK,
wantErr: false,
}, },
} }
for _, tt := range tests { ctx, cancel := context.WithCancel(context.Background())
ctx, cancel := context.WithCancel(context.Background()) go func() {
go func() { if err := Run(ctx); err != nil{
if err := Run(ctx); (err != nil) != tt.wantErr { t.Errorf("Run() error = %v", err)
t.Errorf("Run() error = %v, wantErr %v", err, tt.wantErr) }
} }()
}() t.Run("Controller Start HTTP API", func(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { got, err := http.Get(tests[0].args.request)
got, err := http.Get(tt.args.request) if err != nil {
if err != nil { t.Error(err)
t.Error(err) return
} }
if !reflect.DeepEqual(got.StatusCode, tt.want) { if !reflect.DeepEqual(got.StatusCode, tests[0].want) {
t.Errorf("Run() got: %v, want %v", got.StatusCode, tt.want) t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[0].want)
} }
}) got, err = http.Get(tests[0].args.request)
cancel() if err != nil {
} t.Error(err)
return
}
if !reflect.DeepEqual(got.StatusCode, tests[1].want) {
t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[1].want)
}
got, err = http.Get(tests[0].args.request)
if err != nil {
t.Error(err)
return
}
if !reflect.DeepEqual(got.StatusCode, tests[2].want) {
t.Errorf("Run() got: %v, want %v", got.StatusCode, tests[2].want)
}
})
cancel()
} }
...@@ -9,9 +9,12 @@ import ( ...@@ -9,9 +9,12 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"net/http" "net/http"
"net/url" "net/url"
"sync"
"time" "time"
) )
var httpOnce sync.Once
func stopHttpServer() error { func stopHttpServer() error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
...@@ -32,7 +35,9 @@ func registerHttpHandler() { ...@@ -32,7 +35,9 @@ func registerHttpHandler() {
// deprecated // deprecated
func httpAPI() (err error) { func httpAPI() (err error) {
registerHttpHandler() coreLock.Lock()
defer coreLock.Unlock()
httpOnce.Do(registerHttpHandler)
c.httpServer = &http.Server{Addr: ":8080"} c.httpServer = &http.Server{Addr: ":8080"}
go func() { go func() {
err = c.httpServer.ListenAndServe() err = c.httpServer.ListenAndServe()
......
...@@ -4,8 +4,11 @@ import ( ...@@ -4,8 +4,11 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"reflect" "reflect"
"sync"
) )
var storeLock sync.RWMutex
// Storable provides an interface for the controller's storage architecture. // Storable provides an interface for the controller's storage architecture.
type Storable interface { type Storable interface {
ID() uuid.UUID ID() uuid.UUID
...@@ -14,6 +17,8 @@ type Storable interface { ...@@ -14,6 +17,8 @@ type Storable interface {
type store map[uuid.UUID]Storable type store map[uuid.UUID]Storable
func (s store) exists(id uuid.UUID) bool { func (s store) exists(id uuid.UUID) bool {
storeLock.RLock()
defer storeLock.RUnlock()
_, ok := s[id] _, ok := s[id]
return ok return ok
} }
...@@ -22,7 +27,9 @@ func (s store) add(item Storable) error { ...@@ -22,7 +27,9 @@ func (s store) add(item Storable) error {
if s.exists(item.ID()) { if s.exists(item.ID()) {
return &ErrAlreadyExists{item: item} return &ErrAlreadyExists{item: item}
} }
storeLock.Lock()
s[item.ID()] = item s[item.ID()] = item
storeLock.Unlock()
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"type": reflect.TypeOf(item), "type": reflect.TypeOf(item),
"uuid": item.ID(), "uuid": item.ID(),
...@@ -37,6 +44,8 @@ func (s store) get(id uuid.UUID) (Storable, error) { ...@@ -37,6 +44,8 @@ func (s store) get(id uuid.UUID) (Storable, error) {
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"uuid": id, "uuid": id,
}).Info("storable was accessed") }).Info("storable was accessed")
storeLock.RLock()
defer storeLock.RUnlock()
return s[id], nil return s[id], nil
} }
...@@ -44,7 +53,9 @@ func (s store) delete(id uuid.UUID) error { ...@@ -44,7 +53,9 @@ func (s store) delete(id uuid.UUID) error {
if !s.exists(id) { if !s.exists(id) {
return &ErrNotFound{id: id} return &ErrNotFound{id: id}
} }
storeLock.Lock()
delete(s, id) delete(s, id)
storeLock.Unlock()
log.WithFields(log.Fields{ log.WithFields(log.Fields{
"uuid": id, "uuid": id,
}).Info("storable has been deleted") }).Info("storable has been deleted")
...@@ -52,8 +63,9 @@ func (s store) delete(id uuid.UUID) error { ...@@ -52,8 +63,9 @@ func (s store) delete(id uuid.UUID) error {
} }
func (s store) UUIDs() []uuid.UUID { func (s store) UUIDs() []uuid.UUID {
storeLock.RLock()
defer storeLock.RUnlock()
keys := make([]uuid.UUID, len(s)) keys := make([]uuid.UUID, len(s))
i := 0 i := 0
for k := range s { for k := range s {
keys[i] = k keys[i] = k
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment