From db489f19b161feddecb1d69286adb5d5a309ebd9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Sterba?= <hda@andre-sterba.de> Date: Mon, 18 Dec 2023 13:54:46 +0000 Subject: [PATCH] Refactor mongo store handling See merge request danet/gosdn!658 --- controller/api/initialise_test.go | 5 +- controller/app/DatabaseStore.go | 154 ++++++++++ controller/app/Service.go | 12 +- controller/app/Store.go | 235 --------------- controller/app/app.go | 4 +- controller/app/store.go | 21 ++ controller/controller.go | 38 ++- .../interfaces/networkdomain/pndStore.go | 10 +- .../networkelement/networkElementStore.go | 12 +- controller/interfaces/plugin/pluginStore.go | 12 +- controller/interfaces/rbac/roleStore.go | 16 +- controller/interfaces/rbac/userStore.go | 12 +- controller/northbound/server/topology_test.go | 6 +- .../nucleus/database/mongo-connection.go | 39 ++- .../nucleus/databaseNetworkElementStore.go | 168 +++-------- controller/nucleus/databasePluginStore.go | 106 ++----- controller/nucleus/databasePndStore.go | 113 ++----- .../nucleus/memoryNetworkElementStore.go | 11 +- controller/nucleus/memoryPluginStore.go | 11 +- controller/nucleus/memoryPndStore.go | 9 +- .../nucleus/networkElementFilesystemStore.go | 11 +- .../networkElementFilesystemStore_test.go | 34 ++- controller/nucleus/networkElementService.go | 23 +- controller/nucleus/networkElementStore.go | 8 +- controller/nucleus/pluginFilesystemStore.go | 11 +- .../nucleus/pluginFilesystemStore_test.go | 26 +- controller/nucleus/pluginService.go | 15 +- controller/nucleus/pluginStore.go | 7 +- controller/nucleus/pndFilesystemStore.go | 9 +- controller/nucleus/pndFilesystemStore_test.go | 26 +- controller/nucleus/pndService.go | 14 +- controller/nucleus/pndStore.go | 11 +- controller/rbac/databaseRoleStore.go | 128 ++------ controller/rbac/databaseUserStore.go | 171 ++++------- controller/rbac/memoryRoleStore.go | 11 +- controller/rbac/memoryUserStore.go | 11 +- controller/rbac/rbacService.go | 49 ++- controller/rbac/roleFileSystemStore.go | 11 +- controller/rbac/roleFileSystemStore_test.go | 35 ++- controller/rbac/roleStore.go | 5 +- controller/rbac/userFileSystemStore.go | 11 +- controller/rbac/userFileSystemStore_test.go | 34 ++- controller/rbac/userStore.go | 5 +- controller/store/genericStore.go | 15 +- .../topology/nodes/databaseNodeStore.go | 210 +++++++++++++ controller/topology/nodes/nodeService.go | 17 +- controller/topology/nodes/nodeService_test.go | 4 +- controller/topology/nodes/nodeStore.go | 282 ------------------ controller/topology/nodes/store.go | 21 ++ controller/topology/ports/portService.go | 17 +- controller/topology/ports/portService_test.go | 4 +- controller/topology/ports/portStore.go | 136 ++------- controller/topology/ports/store.go | 21 ++ .../routing-tables/routingTableService.go | 17 +- .../routingTableService_test.go | 10 +- .../routing-tables/routingTableStore.go | 136 ++------- controller/topology/routing-tables/store.go | 21 ++ controller/topology/store.go | 22 ++ controller/topology/store/genericStore.go | 11 +- .../topology/store/genericStore_test.go | 26 +- controller/topology/topologyService.go | 17 +- controller/topology/topologyService_test.go | 10 +- controller/topology/topologyStore.go | 136 ++------- 63 files changed, 1169 insertions(+), 1624 deletions(-) create mode 100644 controller/app/DatabaseStore.go delete mode 100644 controller/app/Store.go create mode 100644 controller/app/store.go create mode 100644 controller/topology/nodes/databaseNodeStore.go delete mode 100644 controller/topology/nodes/nodeStore.go create mode 100644 controller/topology/nodes/store.go create mode 100644 controller/topology/ports/store.go create mode 100644 controller/topology/routing-tables/store.go create mode 100644 controller/topology/store.go diff --git a/controller/api/initialise_test.go b/controller/api/initialise_test.go index c05da0204..039159784 100644 --- a/controller/api/initialise_test.go +++ b/controller/api/initialise_test.go @@ -38,6 +38,7 @@ import ( "github.com/google/uuid" log "github.com/sirupsen/logrus" "github.com/stretchr/testify/mock" + "go.mongodb.org/mongo-driver/mongo" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/test/bufconn" @@ -168,7 +169,7 @@ func bootstrapUnitTest() { }) mockPnd.On("ChangeMNE", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(uuid.Nil, nil) - if err := pndStore.Add(&mockPnd); err != nil { + if err := pndStore.Add(context.TODO(), &mockPnd); err != nil { log.Fatal(err) } @@ -190,7 +191,7 @@ func bootstrapUnitTest() { pluginStore = nucleus.NewMemoryPluginStore() pluginService := nucleus.NewPluginService(pluginStore, eventService, nucleus.NewPluginThroughReattachConfig, rpb.NewPluginRegistryServiceClient(&grpc.ClientConn{})) - networkElementStore := nucleus.NewNetworkElementStore() + networkElementStore := nucleus.NewNetworkElementStore(&mongo.Database{}, pndUUID) networkElementService := nucleus.NewNetworkElementService(networkElementStore, pluginService, eventService) mne, _ := nucleus.NewNetworkElement("test", mneUUID, &tpb.TransportOption{ diff --git a/controller/app/DatabaseStore.go b/controller/app/DatabaseStore.go new file mode 100644 index 000000000..b47fa3e26 --- /dev/null +++ b/controller/app/DatabaseStore.go @@ -0,0 +1,154 @@ +package app + +import ( + "context" + "log" + + "code.fbi.h-da.de/danet/gosdn/controller/customerrs" + "code.fbi.h-da.de/danet/gosdn/controller/store" + + "github.com/google/uuid" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +// ManagementStore is a store for apps. +type ManagementStore interface { + Add(context.Context, App) error + Update(context.Context, App) error + Delete(context.Context, App) error + Get(context.Context, store.Query) (App, error) + GetAll(context.Context) ([]App, error) +} + +const storeName = "app-store.json" + +// Store stores registered apps. +type Store struct { + collection *mongo.Collection +} + +// NewDatabaseAppStore returns a AppStore. +func NewDatabaseAppStore(db *mongo.Database) ManagementStore { + collection := db.Collection(storeName) + + return &Store{ + collection: collection, + } +} + +// Get takes a app's UUID or name and returns the app. +func (s *Store) Get(ctx context.Context, query store.Query) (App, error) { + var loadedApp App + + if query.ID.String() != "" && query.ID != uuid.Nil { + loadedApp, err := s.getByID(ctx, query.ID) + if err != nil { + return loadedApp, err + } + + return loadedApp, nil + } + + loadedApp, err := s.getByName(ctx, query.Name) + if err != nil { + return loadedApp, err + } + + return loadedApp, nil +} + +func (s *Store) getByID(ctx context.Context, appID uuid.UUID) (loadedApp App, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: appID.String()}}) + if result == nil { + return loadedApp, customerrs.CouldNotFindError{ID: appID} + } + + err = result.Decode(&loadedApp) + if err != nil { + log.Printf("Failed marshalling %v", err) + return loadedApp, customerrs.CouldNotMarshallError{Identifier: appID, Type: loadedApp, Err: err} + } + + return loadedApp, nil +} + +func (s *Store) getByName(ctx context.Context, appName string) (loadedApp App, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: appName}}) + if result == nil { + return loadedApp, customerrs.CouldNotFindError{Name: appName} + } + + err = result.Decode(&loadedApp) + if err != nil { + log.Printf("Failed marshalling %v", err) + return loadedApp, customerrs.CouldNotMarshallError{Identifier: appName, Type: loadedApp, Err: err} + } + + return loadedApp, nil +} + +// GetAll returns all stored apps. +func (s *Store) GetAll(ctx context.Context) (loadedApps []App, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) + if err != nil { + return nil, err + } + + err = cursor.All(ctx, &loadedApps) + if err != nil { + log.Printf("Failed marshalling %v", err) + + return nil, customerrs.CouldNotMarshallError{Type: loadedApps, Err: err} + } + + return loadedApps, nil +} + +// Add adds a app to the app store. +func (s *Store) Add(ctx context.Context, app App) (err error) { + _, err = s.collection.InsertOne(ctx, app) + if err != nil { + log.Printf("Could not create app: %v", err) + return customerrs.CouldNotCreateError{Identifier: app.GetID(), Type: app, Err: err} + } + + return nil +} + +// Update updates a existing app. +func (s *Store) Update(ctx context.Context, app App) (err error) { + var updatedApp App + + update := bson.D{primitive.E{Key: "$set", Value: app}} + + upsert := false + after := options.After + opt := options.FindOneAndUpdateOptions{ + Upsert: &upsert, + ReturnDocument: &after, + } + + err = s.collection.FindOneAndUpdate( + ctx, bson.M{"_id": app.GetID().String()}, update, &opt). + Decode(&updatedApp) + if err != nil { + log.Printf("Could not update app: %v", err) + + return customerrs.CouldNotUpdateError{Identifier: app.GetID().String(), Type: app, Err: err} + } + + return nil +} + +// Delete deletes a app from the app store. +func (s *Store) Delete(ctx context.Context, app App) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: app.GetID().String()}}) + if err != nil { + return customerrs.CouldNotDeleteError{Identifier: app.GetID().String(), Type: app, Err: err} + } + + return nil +} diff --git a/controller/app/Service.go b/controller/app/Service.go index 1f92ddc9a..2f02f7743 100644 --- a/controller/app/Service.go +++ b/controller/app/Service.go @@ -1,6 +1,7 @@ package app import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -27,11 +28,12 @@ func NewAppService(store ManagementStore) ManagementService { // Register checks if the app already exists and if not creates a new one. func (a *Service) Register(appName, token string) (*App, error) { + ctx := context.Background() if token != "SecurePresharedToken" { return nil, fmt.Errorf("token not valid") } - exisitingApp, err := a.store.Get(store.Query{ID: uuid.Nil, Name: appName}) + exisitingApp, err := a.store.Get(ctx, store.Query{ID: uuid.Nil, Name: appName}) if err != nil { if exisitingApp.ID == uuid.Nil { return a.createNewApp(appName) @@ -45,12 +47,13 @@ func (a *Service) Register(appName, token string) (*App, error) { // Deregister deregisters an app. func (a *Service) Deregister(appName string) error { - app, err := a.store.Get(store.Query{Name: appName}) + ctx := context.Background() + app, err := a.store.Get(ctx, store.Query{Name: appName}) if err != nil { return err } - err = a.store.Delete(app) + err = a.store.Delete(ctx, app) if err != nil { return err } @@ -59,6 +62,7 @@ func (a *Service) Deregister(appName string) error { } func (a *Service) createNewApp(appName string) (*App, error) { + ctx := context.Background() app := App{ ID: uuid.New(), Name: appName, @@ -66,7 +70,7 @@ func (a *Service) createNewApp(appName string) (*App, error) { } // generate app credentials - err := a.store.Add(app) + err := a.store.Add(ctx, app) if err != nil { return nil, err } diff --git a/controller/app/Store.go b/controller/app/Store.go deleted file mode 100644 index 158fb75e8..000000000 --- a/controller/app/Store.go +++ /dev/null @@ -1,235 +0,0 @@ -package app - -import ( - "fmt" - "log" - - "code.fbi.h-da.de/danet/gosdn/controller/customerrs" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" - "code.fbi.h-da.de/danet/gosdn/controller/store" - - "github.com/google/uuid" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo/options" -) - -// ManagementStore is a store for apps. -type ManagementStore interface { - Add(App) error - Update(App) error - Delete(App) error - Get(store.Query) (App, error) - GetAll() ([]App, error) -} - -// Store stores registered apps. -type Store struct { - storeName string -} - -// NewAppStore returns a AppStore. -func NewAppStore() ManagementStore { - return &Store{ - storeName: "app-store.json", - } -} - -// Get takes a app's UUID or name and returns the app. -func (s *Store) Get(query store.Query) (App, error) { - var loadedApp App - - if query.ID.String() != "" && query.ID != uuid.Nil { - loadedApp, err := s.getByID(query.ID) - if err != nil { - return loadedApp, err - } - - return loadedApp, nil - } - - loadedApp, err := s.getByName(query.Name) - if err != nil { - return loadedApp, err - } - - return loadedApp, nil -} - -func (s *Store) getByID(appID uuid.UUID) (loadedApp App, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedApp, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: appID.String()}}) - if result == nil { - return loadedApp, customerrs.CouldNotFindError{ID: appID} - } - - err = result.Decode(&loadedApp) - if err != nil { - log.Printf("Failed marshalling %v", err) - return loadedApp, customerrs.CouldNotMarshallError{Identifier: appID, Type: loadedApp, Err: err} - } - - return loadedApp, nil -} - -func (s *Store) getByName(appName string) (loadedApp App, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedApp, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: appName}}) - if result == nil { - return loadedApp, customerrs.CouldNotFindError{Name: appName} - } - - err = result.Decode(&loadedApp) - if err != nil { - log.Printf("Failed marshalling %v", err) - return loadedApp, customerrs.CouldNotMarshallError{Identifier: appName, Type: loadedApp, Err: err} - } - - return loadedApp, nil -} - -// GetAll returns all stored apps. -func (s *Store) GetAll() (loadedApps []App, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) - if err != nil { - return nil, err - } - - err = cursor.All(ctx, &loadedApps) - if err != nil { - log.Printf("Failed marshalling %v", err) - - return nil, customerrs.CouldNotMarshallError{Type: loadedApps, Err: err} - } - - return loadedApps, nil -} - -// Add adds a app to the app store. -func (s *Store) Add(app App) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). - InsertOne(ctx, app) - if err != nil { - log.Printf("Could not create app: %v", err) - return customerrs.CouldNotCreateError{Identifier: app.GetID(), Type: app, Err: err} - } - - return nil -} - -// Update updates a existing app. -func (s *Store) Update(app App) (err error) { - var updatedApp App - - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - update := bson.D{primitive.E{Key: "$set", Value: app}} - - upsert := false - after := options.After - opt := options.FindOneAndUpdateOptions{ - Upsert: &upsert, - ReturnDocument: &after, - } - - err = client.Database(database.DatabaseName). - Collection(s.storeName). - FindOneAndUpdate( - ctx, bson.M{"_id": app.GetID().String()}, update, &opt). - Decode(&updatedApp) - if err != nil { - log.Printf("Could not update app: %v", err) - - return customerrs.CouldNotUpdateError{Identifier: app.GetID().String(), Type: app, Err: err} - } - - return nil -} - -// Delete deletes a app from the app store. -func (s *Store) Delete(app App) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: app.GetID().String()}}) - if err != nil { - return customerrs.CouldNotDeleteError{Identifier: app.GetID().String(), Type: app, Err: err} - } - - return nil -} diff --git a/controller/app/app.go b/controller/app/app.go index e433ae93f..3bc7475c0 100644 --- a/controller/app/app.go +++ b/controller/app/app.go @@ -10,11 +10,11 @@ type App struct { } // GetID returns the uuid of an app. -func (a *App) GetID() uuid.UUID { +func (a App) GetID() uuid.UUID { return a.ID } // GetCredentials returns the credentials of an app. -func (a *App) GetCredentials() string { +func (a App) GetCredentials() string { return a.EventSystemCredentials } diff --git a/controller/app/store.go b/controller/app/store.go new file mode 100644 index 000000000..c7e81f08e --- /dev/null +++ b/controller/app/store.go @@ -0,0 +1,21 @@ +package app + +import ( + "code.fbi.h-da.de/danet/gosdn/controller/store" + topoStore "code.fbi.h-da.de/danet/gosdn/controller/topology/store" + + "go.mongodb.org/mongo-driver/mongo" +) + +// NewAppStore returns a Topologytore. +func NewAppStore(db *mongo.Database) ManagementStore { + storeMode := store.GetStoreMode() + + switch storeMode { + case store.Database: + return NewDatabaseAppStore(db) + + default: + return topoStore.NewGenericStore[App]() + } +} diff --git a/controller/controller.go b/controller/controller.go index 3eb2c153f..ca31ed040 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -30,6 +30,7 @@ import ( tpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/topology" "code.fbi.h-da.de/danet/gosdn/controller/app" + "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" "code.fbi.h-da.de/danet/gosdn/controller/config" "code.fbi.h-da.de/danet/gosdn/controller/conflict" @@ -97,31 +98,42 @@ func initialize() error { return err } - nodeService := nodes.NewNodeService(nodes.NewDatabaseNodeStore(), eventService) - portService := ports.NewPortService(ports.NewDatabasePortStore(), eventService) + db := database.GetDatabaseConnection() + + nodeService := nodes.NewNodeService(nodes.NewNodeStore(db), eventService) + portService := ports.NewPortService(ports.NewPortStore(db), eventService) routeService := routingtables.NewRoutingTableService( - routingtables.NewDatabaseRoutingTableStore(), + routingtables.NewRoutingTableStore(db), nodeService, portService, eventService, ) pluginRegistryClient := setupPluginRegistryClient() - pluginService := nucleus.NewPluginService(nucleus.NewPluginStore(), eventService, nucleus.NewPluginThroughReattachConfig, pluginRegistryClient) + pluginService := nucleus.NewPluginService( + nucleus.NewPluginStore(db), + eventService, + nucleus.NewPluginThroughReattachConfig, + pluginRegistryClient, + ) - pndStore := nucleus.NewPndStore(pluginService) + pndStore := nucleus.NewPndStore(db, pluginService) changeStore := store.NewChangeStore() c = &Core{ - pndStore: pndStore, - pndService: nucleus.NewPndService(pndStore), - mneService: nucleus.NewNetworkElementService(nucleus.NewNetworkElementStore(), pluginService, eventService), + pndStore: pndStore, + pndService: nucleus.NewPndService(pndStore), + mneService: nucleus.NewNetworkElementService( + nucleus.NewNetworkElementStore(db, config.BasePndUUID), + pluginService, + eventService, + ), changeStore: *changeStore, - userService: rbacImpl.NewUserService(rbacImpl.NewUserStore(), eventService), - roleService: rbacImpl.NewRoleService(rbacImpl.NewRoleStore(), eventService), + userService: rbacImpl.NewUserService(rbacImpl.NewUserStore(db), eventService), + roleService: rbacImpl.NewRoleService(rbacImpl.NewRoleStore(db), eventService), topologyService: topology.NewTopologyService( - topology.NewDatabaseTopologyStore(), + topology.NewTopologyStore(db), nodeService, portService, eventService, @@ -131,7 +143,7 @@ func initialize() error { routeService: routeService, eventService: eventService, pluginService: pluginService, - appService: app.NewAppService(app.NewAppStore()), + appService: app.NewAppService(app.NewAppStore(db)), stopChan: make(chan os.Signal, 1), pluginRegistryClient: pluginRegistryClient, } @@ -250,7 +262,7 @@ func createPrincipalNetworkDomain() error { "base", "gosdn base pnd", ) - err = c.pndStore.Add(pnd) + err = c.pndStore.Add(context.Background(), pnd) if err != nil { return err } diff --git a/controller/interfaces/networkdomain/pndStore.go b/controller/interfaces/networkdomain/pndStore.go index c031c1fec..d403a870a 100644 --- a/controller/interfaces/networkdomain/pndStore.go +++ b/controller/interfaces/networkdomain/pndStore.go @@ -1,6 +1,8 @@ package networkdomain import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkelement" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" @@ -8,10 +10,10 @@ import ( // PndStore describes an interface for pnd store implementations. type PndStore interface { - Add(NetworkDomain) error - Delete(NetworkDomain) error - Get(store.Query) (LoadedPnd, error) - GetAll() ([]LoadedPnd, error) + Add(context.Context, NetworkDomain) error + Delete(context.Context, NetworkDomain) error + Get(context.Context, store.Query) (LoadedPnd, error) + GetAll(context.Context) ([]LoadedPnd, error) PendingChannels(id uuid.UUID, parseErrors ...error) (chan networkelement.Details, error) AddPendingChannel(id uuid.UUID, ch chan networkelement.Details) RemovePendingChannel(id uuid.UUID) diff --git a/controller/interfaces/networkelement/networkElementStore.go b/controller/interfaces/networkelement/networkElementStore.go index e2a136105..5ea5f5244 100644 --- a/controller/interfaces/networkelement/networkElementStore.go +++ b/controller/interfaces/networkelement/networkElementStore.go @@ -1,14 +1,16 @@ package networkelement import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/store" ) // Store describes an interface for network element store implementations. type Store interface { - Add(NetworkElement) error - Update(NetworkElement) error - Delete(NetworkElement) error - Get(store.Query) (LoadedNetworkElement, error) - GetAll() ([]LoadedNetworkElement, error) + Add(context.Context, NetworkElement) error + Update(context.Context, NetworkElement) error + Delete(context.Context, NetworkElement) error + Get(context.Context, store.Query) (LoadedNetworkElement, error) + GetAll(context.Context) ([]LoadedNetworkElement, error) } diff --git a/controller/interfaces/plugin/pluginStore.go b/controller/interfaces/plugin/pluginStore.go index 0e8ef0a5e..3fd56d242 100644 --- a/controller/interfaces/plugin/pluginStore.go +++ b/controller/interfaces/plugin/pluginStore.go @@ -1,14 +1,16 @@ package plugin import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/store" ) // Store describes an interface for plugin store implementations. type Store interface { - Add(Plugin) error - Update(Plugin) error - Delete(Plugin) error - Get(store.Query) (LoadedPlugin, error) - GetAll() ([]LoadedPlugin, error) + Add(context.Context, Plugin) error + Update(context.Context, Plugin) error + Delete(context.Context, Plugin) error + Get(context.Context, store.Query) (LoadedPlugin, error) + GetAll(context.Context) ([]LoadedPlugin, error) } diff --git a/controller/interfaces/rbac/roleStore.go b/controller/interfaces/rbac/roleStore.go index 88ee8f59f..d9cbf1020 100644 --- a/controller/interfaces/rbac/roleStore.go +++ b/controller/interfaces/rbac/roleStore.go @@ -1,12 +1,16 @@ package rbac -import "code.fbi.h-da.de/danet/gosdn/controller/store" +import ( + "context" + + "code.fbi.h-da.de/danet/gosdn/controller/store" +) // RoleStore describes an interface for role store implementations. type RoleStore interface { - Add(r Role) error - Update(r Role) error - Delete(Role) error - Get(store.Query) (LoadedRole, error) - GetAll() ([]LoadedRole, error) + Add(context.Context, Role) error + Update(context.Context, Role) error + Delete(context.Context, Role) error + Get(context.Context, store.Query) (LoadedRole, error) + GetAll(context.Context) ([]LoadedRole, error) } diff --git a/controller/interfaces/rbac/userStore.go b/controller/interfaces/rbac/userStore.go index 014a20ce0..dc3386fd7 100644 --- a/controller/interfaces/rbac/userStore.go +++ b/controller/interfaces/rbac/userStore.go @@ -1,14 +1,16 @@ package rbac import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/store" ) // UserStore describes an interface for user store implementations. type UserStore interface { - Add(u User) error - Update(u User) error - Delete(User) error - Get(store.Query) (LoadedUser, error) - GetAll() ([]LoadedUser, error) + Add(context.Context, User) error + Update(context.Context, User) error + Delete(context.Context, User) error + Get(context.Context, store.Query) (LoadedUser, error) + GetAll(context.Context) ([]LoadedUser, error) } diff --git a/controller/northbound/server/topology_test.go b/controller/northbound/server/topology_test.go index 5306b61c6..cc5a6f2ae 100644 --- a/controller/northbound/server/topology_test.go +++ b/controller/northbound/server/topology_test.go @@ -131,7 +131,7 @@ func getTestStoreWithLinks(t *testing.T, nodes []links.Link) topology.Store { store := store.NewGenericStore[links.Link]() for _, node := range nodes { - err := store.Add(node) + err := store.Add(context.TODO(), node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } @@ -144,7 +144,7 @@ func getTestStoreWithNodes(t *testing.T, nodesToAdd []nodes.Node) nodes.Store { store := store.NewGenericStore[nodes.Node]() for _, node := range nodesToAdd { - err := store.Add(node) + err := store.Add(context.TODO(), node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } @@ -157,7 +157,7 @@ func getTestStoreWithPorts(t *testing.T, portsToAdd []ports.Port) ports.Store { store := store.NewGenericStore[ports.Port]() for _, port := range portsToAdd { - err := store.Add(port) + err := store.Add(context.TODO(), port) if err != nil { t.Fatalf("failed to prepare test store while adding port: %v", err) } diff --git a/controller/nucleus/database/mongo-connection.go b/controller/nucleus/database/mongo-connection.go index b26cac5fc..43b3a5132 100644 --- a/controller/nucleus/database/mongo-connection.go +++ b/controller/nucleus/database/mongo-connection.go @@ -2,38 +2,63 @@ package database import ( "context" + "fmt" "log" "time" "code.fbi.h-da.de/danet/gosdn/controller/config" + "code.fbi.h-da.de/danet/gosdn/controller/store" + "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) const ( // Timeout operations after N seconds. - connectTimeout = 5 + timeout = 5 * time.Second // DatabaseName is the name of the mongoDB database used. DatabaseName = "gosdn" ) -// GetMongoConnection Retrieves a client to the MongoDB. -func GetMongoConnection() (*mongo.Client, context.Context, context.CancelFunc, error) { +// Connect Retrieves a client to the MongoDB. +func Connect() (*mongo.Database, error) { mongoConnection := config.DatabaseConnection - ctx, cancel := context.WithTimeout(context.Background(), connectTimeout*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() client, err := mongo.Connect(ctx, options.Client().ApplyURI(mongoConnection)) if err != nil { log.Printf("Failed to create client: %v", err) - return nil, ctx, cancel, err + return nil, err } // Force a connection to verify our connection string err = client.Ping(ctx, nil) if err != nil { log.Printf("Failed to connect to database: %v\n", err) - return nil, ctx, cancel, err + return nil, err } - return client, ctx, cancel, nil + db := client.Database(DatabaseName) + if db == nil { + return nil, fmt.Errorf("can not connect to database %s", DatabaseName) + } + + return db, nil +} + +func GetDatabaseConnection() *mongo.Database { + var db *mongo.Database + + storeMode := store.GetStoreMode() + if storeMode == store.Database { + db, err := Connect() + if err != nil { + logrus.Infof("Could not connect to database") + } + + return db + } + + return db } diff --git a/controller/nucleus/databaseNetworkElementStore.go b/controller/nucleus/databaseNetworkElementStore.go index d2e63e0c4..70dd9a7a1 100644 --- a/controller/nucleus/databaseNetworkElementStore.go +++ b/controller/nucleus/databaseNetworkElementStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -19,22 +20,25 @@ import ( // DatabaseNetworkElementStore is used to store Network Elements. type DatabaseNetworkElementStore struct { - storeName string + collection *mongo.Collection } // NewDatabaseNetworkElementStore returns a NetworkElementStore. -func NewDatabaseNetworkElementStore(pndUUID uuid.UUID) networkelement.Store { +func NewDatabaseNetworkElementStore(pndUUID uuid.UUID, db *mongo.Database) networkelement.Store { + storeName := fmt.Sprintf("networkElement-store-%s.json", pndUUID.String()) + collection := db.Collection(storeName) + return &DatabaseNetworkElementStore{ - storeName: fmt.Sprintf("networkElement-store-%s.json", pndUUID.String()), + collection: collection, } } // Get takes a NetworkElement's UUID or name and returns the NetworkElement. -func (s *DatabaseNetworkElementStore) Get(query store.Query) (networkelement.LoadedNetworkElement, error) { +func (s *DatabaseNetworkElementStore) Get(ctx context.Context, query store.Query) (networkelement.LoadedNetworkElement, error) { var loadedNetworkElement networkelement.LoadedNetworkElement if query.ID.String() != "" { - loadedNetworkElement, err := s.getByID(query.ID) + loadedNetworkElement, err := s.getByID(ctx, query.ID) if err != nil { return loadedNetworkElement, err } @@ -42,7 +46,7 @@ func (s *DatabaseNetworkElementStore) Get(query store.Query) (networkelement.Loa return loadedNetworkElement, nil } - loadedNetworkElement, err := s.getByName(query.Name) + loadedNetworkElement, err := s.getByName(ctx, query.Name) if err != nil { return loadedNetworkElement, err } @@ -50,21 +54,8 @@ func (s *DatabaseNetworkElementStore) Get(query store.Query) (networkelement.Loa return loadedNetworkElement, nil } -func (s *DatabaseNetworkElementStore) getByID(idOfNetworkElement uuid.UUID) (loadedNetworkElement networkelement.LoadedNetworkElement, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedNetworkElement, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfNetworkElement.String()}}) +func (s *DatabaseNetworkElementStore) getByID(ctx context.Context, idOfNetworkElement uuid.UUID) (loadedNetworkElement networkelement.LoadedNetworkElement, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfNetworkElement.String()}}) if result == nil { return loadedNetworkElement, customerrs.CouldNotFindError{ID: idOfNetworkElement} } @@ -78,21 +69,8 @@ func (s *DatabaseNetworkElementStore) getByID(idOfNetworkElement uuid.UUID) (loa return loadedNetworkElement, nil } -func (s *DatabaseNetworkElementStore) getByName(nameOfNetworkElement string) (loadedNetworkElement networkelement.LoadedNetworkElement, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedNetworkElement, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfNetworkElement}}) +func (s *DatabaseNetworkElementStore) getByName(ctx context.Context, nameOfNetworkElement string) (loadedNetworkElement networkelement.LoadedNetworkElement, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfNetworkElement}}) if result == nil { return loadedNetworkElement, customerrs.CouldNotFindError{Name: nameOfNetworkElement} } @@ -107,23 +85,8 @@ func (s *DatabaseNetworkElementStore) getByName(nameOfNetworkElement string) (lo } // GetAll returns all stored network elements. -func (s *DatabaseNetworkElementStore) GetAll() (loadedNetworkElements []networkelement.LoadedNetworkElement, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabaseNetworkElementStore) GetAll(ctx context.Context) (loadedNetworkElements []networkelement.LoadedNetworkElement, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return nil, err } @@ -145,22 +108,8 @@ func (s *DatabaseNetworkElementStore) GetAll() (loadedNetworkElements []networke } // Add adds a network element to the network element store. -func (s *DatabaseNetworkElementStore) Add(networkElementToAdd networkelement.NetworkElement) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). - InsertOne(ctx, networkElementToAdd) +func (s *DatabaseNetworkElementStore) Add(ctx context.Context, networkElementToAdd networkelement.NetworkElement) (err error) { + _, err = s.collection.InsertOne(ctx, networkElementToAdd) if err != nil { log.Printf("Could not create NetworkElement: %v", err) return customerrs.CouldNotCreateError{Identifier: networkElementToAdd.ID(), Type: networkElementToAdd, Err: err} @@ -170,56 +119,46 @@ func (s *DatabaseNetworkElementStore) Add(networkElementToAdd networkelement.Net } // Update updates a existing network element. -func (s *DatabaseNetworkElementStore) Update(networkElementToUpdate networkelement.NetworkElement) (err error) { +func (s *DatabaseNetworkElementStore) Update(ctx context.Context, networkElementToUpdate networkelement.NetworkElement) (err error) { var updatedLoadedNetworkElement networkelement.LoadedNetworkElement - client, ctx, cancel, err := database.GetMongoConnection() + db, err := database.Connect() if err != nil { return err } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - // 1. Start Transaction - wcMajority := writeconcern.Majority() - wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) - userCollection := client.Database(database.DatabaseName).Collection(s.storeName, wcMajorityCollectionOpts) - - session, err := client.StartSession() + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + // Starts a session on the client + session, err := db.Client().StartSession() if err != nil { return err } + // Defers ending the session after the transaction is committed or ended defer session.EndSession(ctx) - // 2. Fetch exisiting Entity - existingNetworkElement, err := s.getByID(networkElementToUpdate.ID()) - if err != nil { - return err - } + // Transaction + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + // 1. Fetch exisiting Entity + existingNetworkElement, err := s.getByID(ctx, networkElementToUpdate.ID()) + if err != nil { + return nil, err + } - // 3. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion - if networkElementToUpdate.GetMetadata().ResourceVersion != existingNetworkElement.Metadata.ResourceVersion { - // 3.1.1 End transaction - // 3.1.2 If no -> return error + // 2. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion + if networkElementToUpdate.GetMetadata().ResourceVersion != existingNetworkElement.Metadata.ResourceVersion { + // 2.1 End transaction + // 2.2 If no -> return error - return fmt.Errorf( - "resource version %d of provided network element %s is older or newer than %d in the store", - networkElementToUpdate.GetMetadata().ResourceVersion, - networkElementToUpdate.ID().String(), existingNetworkElement.Metadata.ResourceVersion, - ) - } + return nil, fmt.Errorf( + "resource version %d of provided network element %s is older or newer than %d in the store", + networkElementToUpdate.GetMetadata().ResourceVersion, + networkElementToUpdate.ID().String(), existingNetworkElement.Metadata.ResourceVersion, + ) + } - // 3.2.1 If yes -> Update entity in callback - callback := func(sessCtx mongo.SessionContext) (interface{}, error) { // Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the // transaction. - u, _ := networkElementToUpdate.(*CommonNetworkElement) u.Metadata.ResourceVersion = u.Metadata.ResourceVersion + 1 @@ -232,7 +171,7 @@ func (s *DatabaseNetworkElementStore) Update(networkElementToUpdate networkeleme ReturnDocument: &after, } - err = userCollection. + err = s.collection. FindOneAndUpdate( ctx, bson.M{"_id": networkElementToUpdate.ID().String()}, update, &opt). Decode(&updatedLoadedNetworkElement) @@ -246,7 +185,7 @@ func (s *DatabaseNetworkElementStore) Update(networkElementToUpdate networkeleme return "", nil } - _, err = session.WithTransaction(ctx, callback) + _, err = session.WithTransaction(ctx, callback, txnOptions) if err != nil { return err } @@ -255,23 +194,8 @@ func (s *DatabaseNetworkElementStore) Update(networkElementToUpdate networkeleme } // Delete deletes a network element from the network element store. -func (s *DatabaseNetworkElementStore) Delete(networkElementToDelete networkelement.NetworkElement) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: networkElementToDelete.ID().String()}}) +func (s *DatabaseNetworkElementStore) Delete(ctx context.Context, networkElementToDelete networkelement.NetworkElement) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: networkElementToDelete.ID().String()}}) if err != nil { return customerrs.CouldNotDeleteError{Identifier: networkElementToDelete.ID(), Type: networkElementToDelete, Err: err} } diff --git a/controller/nucleus/databasePluginStore.go b/controller/nucleus/databasePluginStore.go index af61945b0..6fcf4bdc7 100644 --- a/controller/nucleus/databasePluginStore.go +++ b/controller/nucleus/databasePluginStore.go @@ -1,11 +1,11 @@ package nucleus import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" "code.fbi.h-da.de/danet/gosdn/controller/store" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" @@ -15,28 +15,24 @@ import ( log "github.com/sirupsen/logrus" ) +const storeName = "plugins-store.json" + // DatabasePluginStore is used to store Plugins. type DatabasePluginStore struct { - pluginStoreName string + collection *mongo.Collection } -// Add adds a plugin. -func (s *DatabasePluginStore) Add(pluginToAdd plugin.Plugin) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err +func NewDatabasePluginStore(db *mongo.Database) *DatabasePluginStore { + collection := db.Collection(storeName) + + return &DatabasePluginStore{ + collection: collection, } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() +} - _, err = client.Database(database.DatabaseName). - Collection(s.pluginStoreName). - InsertOne(ctx, pluginToAdd) +// Add adds a plugin. +func (s *DatabasePluginStore) Add(ctx context.Context, pluginToAdd plugin.Plugin) (err error) { + _, err = s.collection.InsertOne(ctx, pluginToAdd) if err != nil { if mongo.IsDuplicateKeyError(err) { return nil @@ -49,21 +45,9 @@ func (s *DatabasePluginStore) Add(pluginToAdd plugin.Plugin) (err error) { } // Update updates an existing plugin. -func (s *DatabasePluginStore) Update(pluginToUpdate plugin.Plugin) (err error) { +func (s *DatabasePluginStore) Update(ctx context.Context, pluginToUpdate plugin.Plugin) (err error) { var updatedLoadedPlugin plugin.LoadedPlugin - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - update := bson.D{primitive.E{Key: "$set", Value: pluginToUpdate}} upsert := false @@ -73,10 +57,8 @@ func (s *DatabasePluginStore) Update(pluginToUpdate plugin.Plugin) (err error) { ReturnDocument: &after, } - err = client.Database(database.DatabaseName). - Collection(s.pluginStoreName). - FindOneAndUpdate( - ctx, bson.M{"_id": pluginToUpdate.ID().String()}, update, &opt). + err = s.collection.FindOneAndUpdate( + ctx, bson.M{"_id": pluginToUpdate.ID().String()}, update, &opt). Decode(&updatedLoadedPlugin) if err != nil { log.Printf("Could not update Plugin: %v", err) @@ -88,22 +70,8 @@ func (s *DatabasePluginStore) Update(pluginToUpdate plugin.Plugin) (err error) { } // Delete deletes an plugin. -func (s *DatabasePluginStore) Delete(pluginToDelete plugin.Plugin) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.pluginStoreName). - DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: pluginToDelete.ID().String()}}) +func (s *DatabasePluginStore) Delete(ctx context.Context, pluginToDelete plugin.Plugin) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: pluginToDelete.ID().String()}}) if err != nil { return customerrs.CouldNotDeleteError{Identifier: pluginToDelete.ID(), Type: pluginToDelete, Err: err} } @@ -113,24 +81,8 @@ func (s *DatabasePluginStore) Delete(pluginToDelete plugin.Plugin) (err error) { // Get takes a SouthboundInterface's UUID or name and returns the SouthboundInterface. If the requested // SouthboundInterface does not exist an error is returned. -func (s *DatabasePluginStore) Get(query store.Query) (loadedPlugin plugin.LoadedPlugin, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedPlugin, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - log.Debugf("Plugin-Search-ID: %+v\n", query.ID.String()) - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pluginStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: query.ID.String()}}) +func (s *DatabasePluginStore) Get(ctx context.Context, query store.Query) (loadedPlugin plugin.LoadedPlugin, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: query.ID.String()}}) if result == nil { return loadedPlugin, customerrs.CouldNotFindError{ID: query.ID} } @@ -146,22 +98,8 @@ func (s *DatabasePluginStore) Get(query store.Query) (loadedPlugin plugin.Loaded } // GetAll returns all plugin. -func (s *DatabasePluginStore) GetAll() (loadedPlugins []plugin.LoadedPlugin, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pluginStoreName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabasePluginStore) GetAll(ctx context.Context) (loadedPlugins []plugin.LoadedPlugin, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return nil, err } diff --git a/controller/nucleus/databasePndStore.go b/controller/nucleus/databasePndStore.go index bdd68ead3..669aa2475 100644 --- a/controller/nucleus/databasePndStore.go +++ b/controller/nucleus/databasePndStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" @@ -8,28 +9,40 @@ import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" log "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" ) +const pndStoreName = "pnd-store.json" + // DatabasePndStore is used to store PrincipalNetworkDomains. type DatabasePndStore struct { - pndStoreName string pendingChannels map[uuid.UUID]chan networkelement.Details + collection *mongo.Collection pluginService plugin.Service } +func NewDatabasePndStore(db *mongo.Database, pluginService plugin.Service) *DatabasePndStore { + collection := db.Collection(pndStoreName) + + return &DatabasePndStore{ + pendingChannels: make(map[uuid.UUID]chan networkelement.Details), + pluginService: pluginService, + collection: collection, + } +} + // Get takes a PrincipalNetworkDomain's UUID or name and returns the PrincipalNetworkDomain. If the requested // PrincipalNetworkDomain does not exist an error is returned. -func (s *DatabasePndStore) Get(query store.Query) (newPnd networkdomain.LoadedPnd, err error) { +func (s *DatabasePndStore) Get(ctx context.Context, query store.Query) (newPnd networkdomain.LoadedPnd, err error) { var loadedPND networkdomain.LoadedPnd if query.ID != uuid.Nil { - loadedPND, err := s.getByID(query.ID) + loadedPND, err := s.getByID(ctx, query.ID) if err != nil { return loadedPND, err } @@ -37,7 +50,7 @@ func (s *DatabasePndStore) Get(query store.Query) (newPnd networkdomain.LoadedPn return loadedPND, nil } - loadedPND, err = s.getByName(query.Name) + loadedPND, err = s.getByName(ctx, query.Name) if err != nil { return loadedPND, err } @@ -57,22 +70,8 @@ func (s *DatabasePndStore) Get(query store.Query) (newPnd networkdomain.LoadedPn return loadedPND, nil } -func (s *DatabasePndStore) getByID(idOfPnd uuid.UUID) (loadedPnd networkdomain.LoadedPnd, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedPnd, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pndStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfPnd.String()}}) +func (s *DatabasePndStore) getByID(ctx context.Context, idOfPnd uuid.UUID) (loadedPnd networkdomain.LoadedPnd, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfPnd.String()}}) if result == nil { return loadedPnd, customerrs.CouldNotFindError{ID: idOfPnd} } @@ -86,22 +85,8 @@ func (s *DatabasePndStore) getByID(idOfPnd uuid.UUID) (loadedPnd networkdomain.L return loadedPnd, nil } -func (s *DatabasePndStore) getByName(nameOfPnd string) (loadedPnd networkdomain.LoadedPnd, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedPnd, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pndStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfPnd}}) +func (s *DatabasePndStore) getByName(ctx context.Context, nameOfPnd string) (loadedPnd networkdomain.LoadedPnd, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfPnd}}) if result == nil { return loadedPnd, customerrs.CouldNotFindError{ID: nameOfPnd} } @@ -116,25 +101,10 @@ func (s *DatabasePndStore) getByName(nameOfPnd string) (loadedPnd networkdomain. } // GetAll returns all stored pnds. -func (s *DatabasePndStore) GetAll() (pnds []networkdomain.LoadedPnd, err error) { +func (s *DatabasePndStore) GetAll(ctx context.Context) (pnds []networkdomain.LoadedPnd, err error) { var loadedPnds []networkdomain.LoadedPnd - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pndStoreName) - - cursor, err := collection.Find(ctx, bson.D{}) + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return nil, err } @@ -177,22 +147,8 @@ func (s *DatabasePndStore) GetAll() (pnds []networkdomain.LoadedPnd, err error) } // Add adds a pnd to the pnd store. -func (s *DatabasePndStore) Add(pndToAdd networkdomain.NetworkDomain) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.pndStoreName). - InsertOne(ctx, pndToAdd) +func (s *DatabasePndStore) Add(ctx context.Context, pndToAdd networkdomain.NetworkDomain) (err error) { + _, err = s.collection.InsertOne(ctx, pndToAdd) if err != nil { return customerrs.CouldNotCreateError{Identifier: pndToAdd.ID(), Type: pndToAdd, Err: err} } @@ -202,21 +158,8 @@ func (s *DatabasePndStore) Add(pndToAdd networkdomain.NetworkDomain) (err error) // Delete deletes a pnd. // It also deletes all assosicated devices and sbis. -func (s *DatabasePndStore) Delete(pndToDelete networkdomain.NetworkDomain) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.pndStoreName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: pndToDelete.ID().String()}}) +func (s *DatabasePndStore) Delete(ctx context.Context, pndToDelete networkdomain.NetworkDomain) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: pndToDelete.ID().String()}}) if err != nil { return customerrs.CouldNotDeleteError{Identifier: pndToDelete.ID(), Type: pndToDelete, Err: err} } diff --git a/controller/nucleus/memoryNetworkElementStore.go b/controller/nucleus/memoryNetworkElementStore.go index 100410af7..29de41d26 100644 --- a/controller/nucleus/memoryNetworkElementStore.go +++ b/controller/nucleus/memoryNetworkElementStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -23,7 +24,7 @@ func NewMemoryNetworkElementStore() networkelement.Store { } // Add adds a item to the store. -func (t *MemoryNetworkElementStore) Add(item networkelement.NetworkElement) error { +func (t *MemoryNetworkElementStore) Add(ctx context.Context, item networkelement.NetworkElement) error { var mne networkelement.LoadedNetworkElement b, err := json.Marshal(item) @@ -47,7 +48,7 @@ func (t *MemoryNetworkElementStore) Add(item networkelement.NetworkElement) erro } // Update updates a existing network element. -func (t *MemoryNetworkElementStore) Update(item networkelement.NetworkElement) error { +func (t *MemoryNetworkElementStore) Update(ctx context.Context, item networkelement.NetworkElement) error { _, ok := t.Store[item.ID().String()] if !ok { return customerrs.CouldNotFindError{ID: item.ID(), Name: item.Name()} @@ -71,14 +72,14 @@ func (t *MemoryNetworkElementStore) Update(item networkelement.NetworkElement) e } // Delete deletes a network element from the network element store. -func (t *MemoryNetworkElementStore) Delete(item networkelement.NetworkElement) error { +func (t *MemoryNetworkElementStore) Delete(ctx context.Context, item networkelement.NetworkElement) error { delete(t.Store, item.ID().String()) return nil } // Get takes a network element's UUID or name and returns the network element. -func (t *MemoryNetworkElementStore) Get(query store.Query) (networkelement.LoadedNetworkElement, error) { +func (t *MemoryNetworkElementStore) Get(ctx context.Context, query store.Query) (networkelement.LoadedNetworkElement, error) { // First search for direct hit on UUID. item, ok := t.Store[query.ID.String()] if !ok { @@ -100,7 +101,7 @@ func (t *MemoryNetworkElementStore) Get(query store.Query) (networkelement.Loade } // GetAll returns all stored network elements. -func (t *MemoryNetworkElementStore) GetAll() ([]networkelement.LoadedNetworkElement, error) { +func (t *MemoryNetworkElementStore) GetAll(ctx context.Context) ([]networkelement.LoadedNetworkElement, error) { var allItems []networkelement.LoadedNetworkElement for _, item := range t.Store { diff --git a/controller/nucleus/memoryPluginStore.go b/controller/nucleus/memoryPluginStore.go index 8c32b009a..5b78b54f9 100644 --- a/controller/nucleus/memoryPluginStore.go +++ b/controller/nucleus/memoryPluginStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -23,7 +24,7 @@ func NewMemoryPluginStore() plugin.Store { } // Add adds a item to the store. -func (t *MemoryPluginStore) Add(item plugin.Plugin) error { +func (t *MemoryPluginStore) Add(ctx context.Context, item plugin.Plugin) error { loadedPlugin, err := store.TransformObjectToLoadedObject[plugin.Plugin, plugin.LoadedPlugin](item) if err != nil { return err @@ -40,7 +41,7 @@ func (t *MemoryPluginStore) Add(item plugin.Plugin) error { } // Update updates a existing plugin. -func (t *MemoryPluginStore) Update(item plugin.Plugin) error { +func (t *MemoryPluginStore) Update(ctx context.Context, item plugin.Plugin) error { _, ok := t.Store[item.ID().String()] if ok { return nil @@ -65,14 +66,14 @@ func (t *MemoryPluginStore) Update(item plugin.Plugin) error { } // Delete deletes a plugin from the store. -func (t *MemoryPluginStore) Delete(item plugin.Plugin) error { +func (t *MemoryPluginStore) Delete(ctx context.Context, item plugin.Plugin) error { delete(t.Store, item.ID().String()) return nil } // Get takes a plugins's UUID or name and returns the plugin. -func (t *MemoryPluginStore) Get(query store.Query) (plugin.LoadedPlugin, error) { +func (t *MemoryPluginStore) Get(ctx context.Context, query store.Query) (plugin.LoadedPlugin, error) { // First search for direct hit on UUID. item, ok := t.Store[query.ID.String()] if !ok { @@ -94,7 +95,7 @@ func (t *MemoryPluginStore) Get(query store.Query) (plugin.LoadedPlugin, error) } // GetAll returns all stored plugins. -func (t *MemoryPluginStore) GetAll() ([]plugin.LoadedPlugin, error) { +func (t *MemoryPluginStore) GetAll(ctx context.Context) ([]plugin.LoadedPlugin, error) { var allItems []plugin.LoadedPlugin for _, item := range t.Store { diff --git a/controller/nucleus/memoryPndStore.go b/controller/nucleus/memoryPndStore.go index 8fe4286e9..9130ffc2b 100644 --- a/controller/nucleus/memoryPndStore.go +++ b/controller/nucleus/memoryPndStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -25,7 +26,7 @@ func NewMemoryPndStore() networkdomain.PndStore { } // Add adds a pnd to the store. -func (t *MemoryPndStore) Add(item networkdomain.NetworkDomain) error { +func (t *MemoryPndStore) Add(ctx context.Context, item networkdomain.NetworkDomain) error { var pnd networkdomain.LoadedPnd b, err := json.Marshal(item) @@ -49,14 +50,14 @@ func (t *MemoryPndStore) Add(item networkdomain.NetworkDomain) error { } // Delete deletes a pnd from the store. -func (t *MemoryPndStore) Delete(item networkdomain.NetworkDomain) error { +func (t *MemoryPndStore) Delete(ctx context.Context, item networkdomain.NetworkDomain) error { delete(t.Store, item.ID()) return nil } // Get provides a the query interface to find a stored pnd. -func (t *MemoryPndStore) Get(query store.Query) (networkdomain.LoadedPnd, error) { +func (t *MemoryPndStore) Get(ctx context.Context, query store.Query) (networkdomain.LoadedPnd, error) { // First search for direct hit on UUID. item, ok := t.Store[query.ID] if !ok { @@ -67,7 +68,7 @@ func (t *MemoryPndStore) Get(query store.Query) (networkdomain.LoadedPnd, error) } // GetAll returns all pnds currently on the store. -func (t *MemoryPndStore) GetAll() ([]networkdomain.LoadedPnd, error) { +func (t *MemoryPndStore) GetAll(ctx context.Context) ([]networkdomain.LoadedPnd, error) { var allItems []networkdomain.LoadedPnd for _, item := range t.Store { diff --git a/controller/nucleus/networkElementFilesystemStore.go b/controller/nucleus/networkElementFilesystemStore.go index d4b5850cd..449fc1ebc 100644 --- a/controller/nucleus/networkElementFilesystemStore.go +++ b/controller/nucleus/networkElementFilesystemStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "os" "sync" @@ -59,7 +60,7 @@ func (s *FilesystemNetworkElementStore) writeAllNetworkElementsToFile(mnes []net } // Get takes a network element's UUID or name and returns the network element. -func (s *FilesystemNetworkElementStore) Get(query store.Query) (networkelement.LoadedNetworkElement, error) { +func (s *FilesystemNetworkElementStore) Get(ctx context.Context, query store.Query) (networkelement.LoadedNetworkElement, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -80,7 +81,7 @@ func (s *FilesystemNetworkElementStore) Get(query store.Query) (networkelement.L } // GetAll returns all stored network elements. -func (s *FilesystemNetworkElementStore) GetAll() ([]networkelement.LoadedNetworkElement, error) { +func (s *FilesystemNetworkElementStore) GetAll(ctx context.Context) ([]networkelement.LoadedNetworkElement, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -90,7 +91,7 @@ func (s *FilesystemNetworkElementStore) GetAll() ([]networkelement.LoadedNetwork } // Add adds a network element to the network element store. -func (s *FilesystemNetworkElementStore) Add(networkElementToAdd networkelement.NetworkElement) error { +func (s *FilesystemNetworkElementStore) Add(ctx context.Context, networkElementToAdd networkelement.NetworkElement) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -116,7 +117,7 @@ func (s *FilesystemNetworkElementStore) Add(networkElementToAdd networkelement.N } // Update updates a existing network element. -func (s *FilesystemNetworkElementStore) Update(networkElementToUpdate networkelement.NetworkElement) error { +func (s *FilesystemNetworkElementStore) Update(ctx context.Context, networkElementToUpdate networkelement.NetworkElement) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -145,7 +146,7 @@ func (s *FilesystemNetworkElementStore) Update(networkElementToUpdate networkele } // Delete deletes a network element from the network element store. -func (s *FilesystemNetworkElementStore) Delete(networkElementToDelete networkelement.NetworkElement) error { +func (s *FilesystemNetworkElementStore) Delete(ctx context.Context, networkElementToDelete networkelement.NetworkElement) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() diff --git a/controller/nucleus/networkElementFilesystemStore_test.go b/controller/nucleus/networkElementFilesystemStore_test.go index 540234b2b..40add4c40 100644 --- a/controller/nucleus/networkElementFilesystemStore_test.go +++ b/controller/nucleus/networkElementFilesystemStore_test.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "testing" tpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/transport" @@ -10,6 +11,7 @@ import ( "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" "github.com/stretchr/testify/mock" + "go.mongodb.org/mongo-driver/mongo" ) func returnBasicTransportOption() tpb.TransportOption { @@ -39,10 +41,10 @@ func TestAddNetworkElement(t *testing.T) { plugin1.On("ID").Return(pluginID1) plugin1.On("Model", mock.Anything).Return([]byte{}, nil) - networkElementStore := NewNetworkElementStore() + networkElementStore := NewNetworkElementStore(&mongo.Database{}, defaultPndID) mne, _ := NewNetworkElement("testNetworkElement", mneID, &trop, defaultPndID, plugin1, [][]string{}, conflict.Metadata{}) - err := networkElementStore.Add(mne) + err := networkElementStore.Add(context.TODO(), mne) if err != nil { t.Error(err) } @@ -51,7 +53,7 @@ func TestAddNetworkElement(t *testing.T) { func TestGetAllNetworkElements(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.NetworkElementFilenameSuffix) - networkElementStore := NewNetworkElementStore() + networkElementStore := NewNetworkElementStore(&mongo.Database{}, defaultPndID) pluginID, _ := uuid.Parse("ssssssss-ssss-ssss-ssss-ssssssssssss") plugin := &mocks.Plugin{} @@ -76,13 +78,13 @@ func TestGetAllNetworkElements(t *testing.T) { inputNetworkElements := [2]networkelement.NetworkElement{mne1, mne2} for _, mne := range inputNetworkElements { - err := networkElementStore.Add(mne) + err := networkElementStore.Add(context.TODO(), mne) if err != nil { t.Error(err) } } - returnedNetworkElements, err := networkElementStore.GetAll() + returnedNetworkElements, err := networkElementStore.GetAll(context.TODO()) if err != nil { t.Error(err) } @@ -105,7 +107,7 @@ func TestGetAllNetworkElements(t *testing.T) { func TestGetNetworkElement(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.NetworkElementFilenameSuffix) - networkElementStore := NewNetworkElementStore() + networkElementStore := NewNetworkElementStore(&mongo.Database{}, defaultPndID) pluginID, _ := uuid.Parse("ssssssss-ssss-ssss-ssss-ssssssssssss") plugin := &mocks.Plugin{} @@ -130,13 +132,13 @@ func TestGetNetworkElement(t *testing.T) { inputNetworkElements := [2]networkelement.NetworkElement{mne1, mne2} for _, mne := range inputNetworkElements { - err := networkElementStore.Add(mne) + err := networkElementStore.Add(context.TODO(), mne) if err != nil { t.Error(err) } } - returnNetworkElement, err := networkElementStore.Get(store.Query{ID: mneID2, Name: "testname2"}) + returnNetworkElement, err := networkElementStore.Get(context.TODO(), store.Query{ID: mneID2, Name: "testname2"}) if err != nil { t.Error(err) } @@ -163,22 +165,22 @@ func TestUpdateNetworkElement(t *testing.T) { plugin1.On("ID").Return(pluginID1) plugin1.On("Model", mock.Anything).Return([]byte{}, nil) - networkElementStore := NewNetworkElementStore() + networkElementStore := NewNetworkElementStore(&mongo.Database{}, defaultPndID) mne, _ := NewNetworkElement("testNetworkElement", mneID, &trop, defaultPndID, plugin1, [][]string{}, conflict.Metadata{}) - err := networkElementStore.Add(mne) + err := networkElementStore.Add(context.TODO(), mne) if err != nil { t.Error(err) } mne, _ = NewNetworkElement(updatedNetworkElementName, mneID, &trop, defaultPndID, plugin1, [][]string{}, conflict.Metadata{}) - err = networkElementStore.Update(mne) + err = networkElementStore.Update(context.TODO(), mne) if err != nil { t.Error(err) } - returnNetworkElement, err := networkElementStore.Get(store.Query{ID: mneID, Name: updatedNetworkElementName}) + returnNetworkElement, err := networkElementStore.Get(context.TODO(), store.Query{ID: mneID, Name: updatedNetworkElementName}) if err != nil { t.Error(err) } @@ -194,7 +196,7 @@ func TestUpdateNetworkElement(t *testing.T) { func TestDeleteNetworkElement(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.NetworkElementFilenameSuffix) - networkElementStore := NewNetworkElementStore() + networkElementStore := NewNetworkElementStore(&mongo.Database{}, defaultPndID) pluginID, _ := uuid.Parse("ssssssss-ssss-ssss-ssss-ssssssssssss") plugin := &mocks.Plugin{} @@ -219,18 +221,18 @@ func TestDeleteNetworkElement(t *testing.T) { inputNetworkElements := [2]networkelement.NetworkElement{mne1, mne2} for _, mne := range inputNetworkElements { - err := networkElementStore.Add(mne) + err := networkElementStore.Add(context.TODO(), mne) if err != nil { t.Error(err) } } - err = networkElementStore.Delete(mne1) + err = networkElementStore.Delete(context.TODO(), mne1) if err != nil { t.Error(err) } - returnNetworkElements, err := networkElementStore.GetAll() + returnNetworkElements, err := networkElementStore.GetAll(context.TODO()) if err != nil { t.Error(err) } diff --git a/controller/nucleus/networkElementService.go b/controller/nucleus/networkElementService.go index 635c190be..5d79c71f8 100644 --- a/controller/nucleus/networkElementService.go +++ b/controller/nucleus/networkElementService.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "fmt" spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" @@ -45,7 +46,8 @@ func NewNetworkElementService( // Get takes a network element's UUID or name and returns the network element. func (s *NetworkElementService) Get(query store.Query) (networkelement.NetworkElement, error) { - loadedNetworkElement, err := s.networkElementStore.Get(query) + ctx := context.Background() + loadedNetworkElement, err := s.networkElementStore.Get(ctx, query) if err != nil { return nil, err } @@ -62,7 +64,8 @@ func (s *NetworkElementService) Get(query store.Query) (networkelement.NetworkEl func (s *NetworkElementService) GetAll() ([]networkelement.NetworkElement, error) { var mnes []networkelement.NetworkElement - loadedNetworkElements, err := s.networkElementStore.GetAll() + ctx := context.Background() + loadedNetworkElements, err := s.networkElementStore.GetAll(ctx) if err != nil { return nil, err } @@ -84,7 +87,8 @@ func (s *NetworkElementService) GetAll() ([]networkelement.NetworkElement, error // requesting network element information through this method is a lot faster than the // usual `GetAll` method. func (s *NetworkElementService) GetAllAsLoaded() ([]networkelement.LoadedNetworkElement, error) { - loadedNetworkElements, err := s.networkElementStore.GetAll() + ctx := context.Background() + loadedNetworkElements, err := s.networkElementStore.GetAll(ctx) if err != nil { return nil, err } @@ -94,7 +98,8 @@ func (s *NetworkElementService) GetAllAsLoaded() ([]networkelement.LoadedNetwork // Add adds a network element to the network element store. func (s *NetworkElementService) Add(networkElementToAdd networkelement.NetworkElement) error { - err := s.networkElementStore.Add(networkElementToAdd) + ctx := context.Background() + err := s.networkElementStore.Add(ctx, networkElementToAdd) if err != nil { return err } @@ -116,6 +121,8 @@ func (s *NetworkElementService) Add(networkElementToAdd networkelement.NetworkEl // UpdateModel updates a existing network element with a new model provided as string. func (s *NetworkElementService) UpdateModel(networkElementID uuid.UUID, modelAsString string) error { + ctx := context.Background() + exisitingNetworkElement, err := s.Get(store.Query{ID: networkElementID}) if err != nil { return err @@ -140,7 +147,7 @@ func (s *NetworkElementService) UpdateModel(networkElementID uuid.UUID, modelAsS return err } - err = s.networkElementStore.Update(exisitingNetworkElement) + err = s.networkElementStore.Update(ctx, exisitingNetworkElement) if err != nil { return err } @@ -162,7 +169,8 @@ func (s *NetworkElementService) UpdateModel(networkElementID uuid.UUID, modelAsS // Update updates a existing network element. func (s *NetworkElementService) Update(networkElementToUpdate networkelement.NetworkElement) error { - err := s.networkElementStore.Update(networkElementToUpdate) + ctx := context.Background() + err := s.networkElementStore.Update(ctx, networkElementToUpdate) if err != nil { return err } @@ -184,7 +192,8 @@ func (s *NetworkElementService) Update(networkElementToUpdate networkelement.Net // Delete deletes a network element from the network element store. func (s *NetworkElementService) Delete(networkElementToDelete networkelement.NetworkElement) error { - err := s.networkElementStore.Delete(networkElementToDelete) + ctx := context.Background() + err := s.networkElementStore.Delete(ctx, networkElementToDelete) if err != nil { return err } diff --git a/controller/nucleus/networkElementStore.go b/controller/nucleus/networkElementStore.go index 3d1449c20..617a0d7d2 100644 --- a/controller/nucleus/networkElementStore.go +++ b/controller/nucleus/networkElementStore.go @@ -3,20 +3,20 @@ package nucleus import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkelement" "code.fbi.h-da.de/danet/gosdn/controller/store" + "go.mongodb.org/mongo-driver/mongo" + "github.com/google/uuid" log "github.com/sirupsen/logrus" ) // NewNetworkElementStore returns a NetworkElementStore. -func NewNetworkElementStore() networkelement.Store { +func NewNetworkElementStore(db *mongo.Database, pndUUID uuid.UUID) networkelement.Store { storeMode := store.GetStoreMode() log.Debugf("StoreMode: %s", storeMode) switch storeMode { case store.Database: - return &DatabaseNetworkElementStore{ - storeName: "networkElement-store.json", - } + return NewDatabaseNetworkElementStore(pndUUID, db) default: store := NewFilesystemNetworkElementStore() diff --git a/controller/nucleus/pluginFilesystemStore.go b/controller/nucleus/pluginFilesystemStore.go index cc48af4a2..7d26960fa 100644 --- a/controller/nucleus/pluginFilesystemStore.go +++ b/controller/nucleus/pluginFilesystemStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "os" "sync" @@ -59,7 +60,7 @@ func (s *FilesystemPluginStore) writeAllPluginsToFile(plugins []plugin.LoadedPlu } // Add adds a Plugin. -func (s *FilesystemPluginStore) Add(pluginToAdd plugin.Plugin) error { +func (s *FilesystemPluginStore) Add(ctx context.Context, pluginToAdd plugin.Plugin) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -85,7 +86,7 @@ func (s *FilesystemPluginStore) Add(pluginToAdd plugin.Plugin) error { } // Update updates an existing plugin. -func (s *FilesystemPluginStore) Update(pluginToUpdate plugin.Plugin) error { +func (s *FilesystemPluginStore) Update(ctx context.Context, pluginToUpdate plugin.Plugin) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -114,7 +115,7 @@ func (s *FilesystemPluginStore) Update(pluginToUpdate plugin.Plugin) error { } // Delete deletes an Plugin. -func (s *FilesystemPluginStore) Delete(pluginToDelete plugin.Plugin) error { +func (s *FilesystemPluginStore) Delete(ctx context.Context, pluginToDelete plugin.Plugin) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -142,7 +143,7 @@ func (s *FilesystemPluginStore) Delete(pluginToDelete plugin.Plugin) error { // Get takes a Plugin's UUID or name and returns the Plugin. If the requested // Plugin does not exist an error is returned. -func (s *FilesystemPluginStore) Get(query store.Query) (plugin.LoadedPlugin, error) { +func (s *FilesystemPluginStore) Get(ctx context.Context, query store.Query) (plugin.LoadedPlugin, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -163,7 +164,7 @@ func (s *FilesystemPluginStore) Get(query store.Query) (plugin.LoadedPlugin, err } // GetAll returns all Plugins. -func (s *FilesystemPluginStore) GetAll() ([]plugin.LoadedPlugin, error) { +func (s *FilesystemPluginStore) GetAll(ctx context.Context) ([]plugin.LoadedPlugin, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() diff --git a/controller/nucleus/pluginFilesystemStore_test.go b/controller/nucleus/pluginFilesystemStore_test.go index 195feba8a..06b6a0a0c 100644 --- a/controller/nucleus/pluginFilesystemStore_test.go +++ b/controller/nucleus/pluginFilesystemStore_test.go @@ -1,12 +1,14 @@ package nucleus import ( + "context" "testing" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin" "code.fbi.h-da.de/danet/gosdn/controller/mocks" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" + "go.mongodb.org/mongo-driver/mongo" ) func ensureStoreFilesForTestsAreRemoved() { @@ -17,10 +19,10 @@ func ensureStoreFilesForTestsAreRemoved() { func TestAddPlugin(t *testing.T) { defer ensureStoreFilesForTestsAreRemoved() - pluginStore := NewPluginStore() + pluginStore := NewPluginStore(&mongo.Database{}) mockPlugin := mockPlugin(t) - err := pluginStore.Add(mockPlugin) + err := pluginStore.Add(context.TODO(), mockPlugin) if err != nil { t.Error(err) } @@ -29,7 +31,7 @@ func TestAddPlugin(t *testing.T) { func TestGetAllPlugins(t *testing.T) { defer ensureStoreFilesForTestsAreRemoved() - pluginStore := NewPluginStore() + pluginStore := NewPluginStore(&mongo.Database{}) mockPlugin1ID, err := uuid.Parse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") if err != nil { @@ -48,13 +50,13 @@ func TestGetAllPlugins(t *testing.T) { inputPlugins := [2]plugin.Plugin{mockPlugin1, mockPlugin2} for _, plugin := range inputPlugins { - err := pluginStore.Add(plugin) + err := pluginStore.Add(context.TODO(), plugin) if err != nil { t.Error(err) } } - returnPlugins, err := pluginStore.GetAll() + returnPlugins, err := pluginStore.GetAll(context.TODO()) if err != nil { t.Error(err) } @@ -74,7 +76,7 @@ func TestGetAllPlugins(t *testing.T) { func TestGetPlugin(t *testing.T) { defer ensureStoreFilesForTestsAreRemoved() - pluginStore := NewPluginStore() + pluginStore := NewPluginStore(&mongo.Database{}) mockPlugin1ID, err := uuid.Parse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") if err != nil { @@ -92,13 +94,13 @@ func TestGetPlugin(t *testing.T) { inputPlugins := [2]plugin.Plugin{mockPlugin1, mockPlugin2} for _, plugins := range inputPlugins { - err := pluginStore.Add(plugins) + err := pluginStore.Add(context.TODO(), plugins) if err != nil { t.Error(err) } } - returnPlugin, err := pluginStore.Get(store.Query{ID: mockPlugin2ID, Name: ""}) + returnPlugin, err := pluginStore.Get(context.TODO(), store.Query{ID: mockPlugin2ID, Name: ""}) if err != nil { t.Error(err) } @@ -111,7 +113,7 @@ func TestGetPlugin(t *testing.T) { func TestDeleteAllPlugins(t *testing.T) { defer ensureStoreFilesForTestsAreRemoved() - pluginStore := NewPluginStore() + pluginStore := NewPluginStore(&mongo.Database{}) mockPlugin1ID, err := uuid.Parse("aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa") if err != nil { @@ -129,18 +131,18 @@ func TestDeleteAllPlugins(t *testing.T) { inputPlugins := [2]plugin.Plugin{mockPlugin1, mockPlugin2} for _, plugins := range inputPlugins { - err := pluginStore.Add(plugins) + err := pluginStore.Add(context.TODO(), plugins) if err != nil { t.Error(err) } } - err = pluginStore.Delete(mockPlugin1) + err = pluginStore.Delete(context.TODO(), mockPlugin1) if err != nil { t.Error(err) } - returnPlugins, err := pluginStore.GetAll() + returnPlugins, err := pluginStore.GetAll(context.TODO()) if err != nil { t.Error(err) } diff --git a/controller/nucleus/pluginService.go b/controller/nucleus/pluginService.go index d40ef4392..f95d47f10 100644 --- a/controller/nucleus/pluginService.go +++ b/controller/nucleus/pluginService.go @@ -50,7 +50,8 @@ func NewPluginService(pluginStore plugin.Store, eventService eventInterfaces.Ser // Get takes a Plugin's UUID or name and returns the Plugin. func (s *PluginService) Get(query store.Query) (plugin.Plugin, error) { - loadedPlugin, err := s.pluginStore.Get(query) + ctx := context.Background() + loadedPlugin, err := s.pluginStore.Get(ctx, query) if err != nil { return nil, err } @@ -67,7 +68,8 @@ func (s *PluginService) Get(query store.Query) (plugin.Plugin, error) { func (s *PluginService) GetAll() ([]plugin.Plugin, error) { var plugins []plugin.Plugin - loadedPlugins, err := s.pluginStore.GetAll() + ctx := context.Background() + loadedPlugins, err := s.pluginStore.GetAll(ctx) if err != nil { return nil, err } @@ -86,7 +88,8 @@ func (s *PluginService) GetAll() ([]plugin.Plugin, error) { // Add adds a plugin to the plugin store. func (s *PluginService) Add(pluginToAdd plugin.Plugin) error { - err := s.pluginStore.Add(pluginToAdd) + ctx := context.Background() + err := s.pluginStore.Add(ctx, pluginToAdd) if err != nil { return err } @@ -100,7 +103,8 @@ func (s *PluginService) Add(pluginToAdd plugin.Plugin) error { // Delete deletes a plugin from the plugin store. func (s *PluginService) Delete(pluginToDelete plugin.Plugin) error { - err := s.pluginStore.Delete(pluginToDelete) + ctx := context.Background() + err := s.pluginStore.Delete(ctx, pluginToDelete) if err != nil { return err } @@ -116,6 +120,7 @@ func (s *PluginService) Delete(pluginToDelete plugin.Plugin) error { } func (s *PluginService) createPluginFromStore(loadedPlugin plugin.LoadedPlugin) (plugin.Plugin, error) { + ctx := context.Background() plugin, err := s.createPluginFromStoreFn(loadedPlugin) if err != nil { if errors.Is(err, hcplugin.ErrProcessNotFound) { @@ -123,7 +128,7 @@ func (s *PluginService) createPluginFromStore(loadedPlugin plugin.LoadedPlugin) if err != nil { return nil, err } - err := s.pluginStore.Update(plugin) + err := s.pluginStore.Update(ctx, plugin) if err != nil { return nil, err } diff --git a/controller/nucleus/pluginStore.go b/controller/nucleus/pluginStore.go index cc68ab041..d61e9219c 100644 --- a/controller/nucleus/pluginStore.go +++ b/controller/nucleus/pluginStore.go @@ -3,17 +3,16 @@ package nucleus import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin" "code.fbi.h-da.de/danet/gosdn/controller/store" + "go.mongodb.org/mongo-driver/mongo" ) // NewPluginStore returns a pluginStore. -func NewPluginStore() plugin.Store { +func NewPluginStore(db *mongo.Database) plugin.Store { storeMode := store.GetStoreMode() switch storeMode { case store.Database: - return &DatabasePluginStore{ - pluginStoreName: store.PluginFilenameSuffix, - } + return NewDatabasePluginStore(db) default: store := NewFilesystemPluginStore() diff --git a/controller/nucleus/pndFilesystemStore.go b/controller/nucleus/pndFilesystemStore.go index 386fc025a..863fe7205 100644 --- a/controller/nucleus/pndFilesystemStore.go +++ b/controller/nucleus/pndFilesystemStore.go @@ -1,6 +1,7 @@ package nucleus import ( + "context" "encoding/json" "os" "sync" @@ -87,7 +88,7 @@ func (t *FilesystemPndStore) writeAllPndsToFile(pnds []networkdomain.LoadedPnd) } // Add a pnd to the store. -func (t *FilesystemPndStore) Add(pndToAdd networkdomain.NetworkDomain) error { +func (t *FilesystemPndStore) Add(ctx context.Context, pndToAdd networkdomain.NetworkDomain) error { t.fileMutex.Lock() defer t.fileMutex.Unlock() @@ -113,7 +114,7 @@ func (t *FilesystemPndStore) Add(pndToAdd networkdomain.NetworkDomain) error { } // Delete deletes a pnd from the store. -func (t *FilesystemPndStore) Delete(pndToDelete networkdomain.NetworkDomain) error { +func (t *FilesystemPndStore) Delete(ctx context.Context, pndToDelete networkdomain.NetworkDomain) error { t.fileMutex.Lock() defer t.fileMutex.Unlock() @@ -141,7 +142,7 @@ func (t *FilesystemPndStore) Delete(pndToDelete networkdomain.NetworkDomain) err } // Get provides a the query interface to find a stored pnd. -func (t *FilesystemPndStore) Get(query store.Query) (networkdomain.LoadedPnd, error) { +func (t *FilesystemPndStore) Get(ctx context.Context, query store.Query) (networkdomain.LoadedPnd, error) { t.fileMutex.Lock() defer t.fileMutex.Unlock() @@ -162,7 +163,7 @@ func (t *FilesystemPndStore) Get(query store.Query) (networkdomain.LoadedPnd, er } // GetAll returns all pnds currently on the store. -func (t *FilesystemPndStore) GetAll() ([]networkdomain.LoadedPnd, error) { +func (t *FilesystemPndStore) GetAll(ctx context.Context) ([]networkdomain.LoadedPnd, error) { t.fileMutex.Lock() defer t.fileMutex.Unlock() diff --git a/controller/nucleus/pndFilesystemStore_test.go b/controller/nucleus/pndFilesystemStore_test.go index e4a1e69dc..096488713 100644 --- a/controller/nucleus/pndFilesystemStore_test.go +++ b/controller/nucleus/pndFilesystemStore_test.go @@ -1,23 +1,25 @@ package nucleus import ( + "context" "testing" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" + "go.mongodb.org/mongo-driver/mongo" ) func TestAddPnd(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.PndFilename) pluginServiceMock := NewPluginServiceMock() - pndStore := NewPndStore(pluginServiceMock) + pndStore := NewPndStore(&mongo.Database{}, pluginServiceMock) pndID, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") pnd := NewPND(pndID, "testpnd", "test") - err := pndStore.Add(pnd) + err := pndStore.Add(context.TODO(), pnd) if err != nil { t.Error(err) @@ -28,7 +30,7 @@ func TestGetAllPnds(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.PndFilename) pluginServiceMock := NewPluginServiceMock() - pndStore := NewPndStore(pluginServiceMock) + pndStore := NewPndStore(&mongo.Database{}, pluginServiceMock) pndID1, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") pndID2, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bab") pnd1 := NewPND(pndID1, "testpnd", "test") @@ -37,13 +39,13 @@ func TestGetAllPnds(t *testing.T) { inputPnds := [2]networkdomain.NetworkDomain{pnd1, pnd2} for _, pnd := range inputPnds { - err := pndStore.Add(pnd) + err := pndStore.Add(context.TODO(), pnd) if err != nil { t.Error(err) } } - returnPnds, err := pndStore.GetAll() + returnPnds, err := pndStore.GetAll(context.TODO()) if err != nil { t.Error(err) } @@ -65,7 +67,7 @@ func TestGetPnd(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.PndFilename) pluginServiceMock := NewPluginServiceMock() - pndStore := NewPndStore(pluginServiceMock) + pndStore := NewPndStore(&mongo.Database{}, pluginServiceMock) pndID1, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") pndID2, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bab") pnd1 := NewPND(pndID1, "testpnd", "test") @@ -74,13 +76,13 @@ func TestGetPnd(t *testing.T) { inputPnds := [2]networkdomain.NetworkDomain{pnd1, pnd2} for _, pnd := range inputPnds { - err := pndStore.Add(pnd) + err := pndStore.Add(context.TODO(), pnd) if err != nil { t.Error(err) } } - returnPnd, err := pndStore.Get(store.Query{ID: pndID2, Name: ""}) + returnPnd, err := pndStore.Get(context.TODO(), store.Query{ID: pndID2, Name: ""}) if err != nil { t.Error(err) } @@ -100,7 +102,7 @@ func TestDeletePnd(t *testing.T) { defer ensureStoreFileForTestsIsRemoved(store.PndFilename) pluginServiceMock := NewPluginServiceMock() - pndStore := NewPndStore(pluginServiceMock) + pndStore := NewPndStore(&mongo.Database{}, pluginServiceMock) pndID1, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bad") pndID2, _ := uuid.Parse("b4016412-eec5-45a1-aa29-f59915357bab") pnd1 := NewPND(pndID1, "testpnd", "test") @@ -109,18 +111,18 @@ func TestDeletePnd(t *testing.T) { inputPnds := [2]networkdomain.NetworkDomain{pnd1, pnd2} for _, pnd := range inputPnds { - err := pndStore.Add(pnd) + err := pndStore.Add(context.TODO(), pnd) if err != nil { t.Error(err) } } - err := pndStore.Delete(pnd1) + err := pndStore.Delete(context.TODO(), pnd1) if err != nil { t.Error(err) } - returnPnds, err := pndStore.GetAll() + returnPnds, err := pndStore.GetAll(context.TODO()) if err != nil { t.Error(err) } diff --git a/controller/nucleus/pndService.go b/controller/nucleus/pndService.go index fc2124a87..eff57da43 100644 --- a/controller/nucleus/pndService.go +++ b/controller/nucleus/pndService.go @@ -1,6 +1,8 @@ package nucleus import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" @@ -22,7 +24,8 @@ func NewPndService(pndStore networkdomain.PndStore) networkdomain.Service { // Add adds a PND to the PND store. func (p *PndService) Add(pndToAdd networkdomain.NetworkDomain) error { - err := p.pndStore.Add(pndToAdd) + ctx := context.Background() + err := p.pndStore.Add(ctx, pndToAdd) if err != nil { return err } @@ -32,7 +35,8 @@ func (p *PndService) Add(pndToAdd networkdomain.NetworkDomain) error { // Delete deletes a PND from the PND store. func (p *PndService) Delete(pndToDelete networkdomain.NetworkDomain) error { - err := p.pndStore.Delete(pndToDelete) + ctx := context.Background() + err := p.pndStore.Delete(ctx, pndToDelete) if err != nil { return err } @@ -42,7 +46,8 @@ func (p *PndService) Delete(pndToDelete networkdomain.NetworkDomain) error { // Get takes a PND's UUID or name and returns the PND. func (p *PndService) Get(query store.Query) (networkdomain.NetworkDomain, error) { - loadedPnd, err := p.pndStore.Get(query) + ctx := context.Background() + loadedPnd, err := p.pndStore.Get(ctx, query) if err != nil { return nil, err } @@ -53,8 +58,9 @@ func (p *PndService) Get(query store.Query) (networkdomain.NetworkDomain, error) // GetAll returns all stores PNDs. func (p *PndService) GetAll() ([]networkdomain.NetworkDomain, error) { var pnds []networkdomain.NetworkDomain + ctx := context.Background() - loadedPnds, err := p.pndStore.GetAll() + loadedPnds, err := p.pndStore.GetAll(ctx) if err != nil { return nil, err } diff --git a/controller/nucleus/pndStore.go b/controller/nucleus/pndStore.go index c865a26e2..8baf210a2 100644 --- a/controller/nucleus/pndStore.go +++ b/controller/nucleus/pndStore.go @@ -2,11 +2,10 @@ package nucleus import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" - "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkelement" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/plugin" "code.fbi.h-da.de/danet/gosdn/controller/store" - "github.com/google/uuid" log "github.com/sirupsen/logrus" + "go.mongodb.org/mongo-driver/mongo" ) // LoadedPnd represents a Principal Network Domain that was loaeded by using @@ -22,17 +21,13 @@ type PndStore struct { } // NewPndStore returns a PndStore. -func NewPndStore(pluginService plugin.Service) networkdomain.PndStore { +func NewPndStore(db *mongo.Database, pluginService plugin.Service) networkdomain.PndStore { storeMode := store.GetStoreMode() log.Debugf("StoreMode: %s", storeMode) switch storeMode { case store.Database: - return &DatabasePndStore{ - pendingChannels: make(map[uuid.UUID]chan networkelement.Details), - pndStoreName: "pnd-store.json", - pluginService: pluginService, - } + return NewDatabasePndStore(db, pluginService) default: store := NewFilesystemPndStore(pluginService) diff --git a/controller/rbac/databaseRoleStore.go b/controller/rbac/databaseRoleStore.go index 4066d97b2..193ef8bf9 100644 --- a/controller/rbac/databaseRoleStore.go +++ b/controller/rbac/databaseRoleStore.go @@ -1,11 +1,11 @@ package rbac import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" log "github.com/sirupsen/logrus" @@ -15,28 +15,24 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" ) +const roleStoreName = "role-store.json" + // DatabaseRoleStore is used to store roles in database. type DatabaseRoleStore struct { - roleStoreName string + collection *mongo.Collection } -// Add adds a Role. -func (s *DatabaseRoleStore) Add(roleToAdd rbac.Role) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err +func NewDatabaseRoleStore(db *mongo.Database) *DatabaseRoleStore { + collection := db.Collection(roleStoreName) + + return &DatabaseRoleStore{ + collection: collection, } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() +} - _, err = client.Database(database.DatabaseName). - Collection(s.roleStoreName). - InsertOne(ctx, roleToAdd) +// Add adds a Role. +func (s *DatabaseRoleStore) Add(ctx context.Context, roleToAdd rbac.Role) (err error) { + _, err = s.collection.InsertOne(ctx, roleToAdd) if err != nil { if mongo.IsDuplicateKeyError(err) { return nil @@ -49,22 +45,8 @@ func (s *DatabaseRoleStore) Add(roleToAdd rbac.Role) (err error) { } // Delete deletes a Role. -func (s *DatabaseRoleStore) Delete(roleToDelete rbac.Role) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.roleStoreName). - DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: roleToDelete.ID().String()}}) +func (s *DatabaseRoleStore) Delete(ctx context.Context, roleToDelete rbac.Role) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: roleToDelete.ID().String()}}) if err != nil { return customerrs.CouldNotDeleteError{Identifier: roleToDelete.ID(), Type: roleToDelete, Err: err} } @@ -74,11 +56,11 @@ func (s *DatabaseRoleStore) Delete(roleToDelete rbac.Role) (err error) { // Get takes a Roles's UUID or name and returns the Role. If the requested // Role does not exist an error is returned. -func (s *DatabaseRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { +func (s *DatabaseRoleStore) Get(ctx context.Context, query store.Query) (rbac.LoadedRole, error) { var loadedRole rbac.LoadedRole if query.ID != uuid.Nil { - loadedRole, err := s.getByID(query.ID) + loadedRole, err := s.getByID(ctx, query.ID) if err != nil { return loadedRole, err } @@ -86,7 +68,7 @@ func (s *DatabaseRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { return loadedRole, nil } - loadedRole, err := s.getByName(query.Name) + loadedRole, err := s.getByName(ctx, query.Name) if err != nil { return loadedRole, err } @@ -94,22 +76,8 @@ func (s *DatabaseRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { return loadedRole, nil } -func (s *DatabaseRoleStore) getByID(idOfRole uuid.UUID) (loadedRole rbac.LoadedRole, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedRole, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.roleStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfRole.String()}}) +func (s *DatabaseRoleStore) getByID(ctx context.Context, idOfRole uuid.UUID) (loadedRole rbac.LoadedRole, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfRole.String()}}) if result == nil { return loadedRole, customerrs.CouldNotFindError{ID: idOfRole} } @@ -123,22 +91,8 @@ func (s *DatabaseRoleStore) getByID(idOfRole uuid.UUID) (loadedRole rbac.LoadedR return loadedRole, nil } -func (s *DatabaseRoleStore) getByName(nameOfRole string) (loadedRole rbac.LoadedRole, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedRole, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.roleStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "rolename", Value: nameOfRole}}) +func (s *DatabaseRoleStore) getByName(ctx context.Context, nameOfRole string) (loadedRole rbac.LoadedRole, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "rolename", Value: nameOfRole}}) if result == nil { return loadedRole, customerrs.CouldNotFindError{Name: nameOfRole} } @@ -153,23 +107,8 @@ func (s *DatabaseRoleStore) getByName(nameOfRole string) (loadedRole rbac.Loaded } // GetAll returns all Roles. -func (s *DatabaseRoleStore) GetAll() (loadedRoles []rbac.LoadedRole, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.roleStoreName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabaseRoleStore) GetAll(ctx context.Context) (loadedRoles []rbac.LoadedRole, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return nil, err } @@ -190,20 +129,9 @@ func (s *DatabaseRoleStore) GetAll() (loadedRoles []rbac.LoadedRole, err error) } // Update updates the role. -func (s *DatabaseRoleStore) Update(roleToUpdate rbac.Role) (err error) { +func (s *DatabaseRoleStore) Update(ctx context.Context, roleToUpdate rbac.Role) (err error) { var updatedLoadedRole rbac.LoadedRole - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() update := bson.D{primitive.E{Key: "$set", Value: roleToUpdate}} upsert := false @@ -213,10 +141,8 @@ func (s *DatabaseRoleStore) Update(roleToUpdate rbac.Role) (err error) { ReturnDocument: &after, } - err = client.Database(database.DatabaseName). - Collection(s.roleStoreName). - FindOneAndUpdate( - ctx, bson.M{"_id": roleToUpdate.ID().String()}, update, &opt). + err = s.collection.FindOneAndUpdate( + ctx, bson.M{"_id": roleToUpdate.ID().String()}, update, &opt). Decode(&updatedLoadedRole) if err != nil { log.Printf("Could not update Role: %v", err) diff --git a/controller/rbac/databaseUserStore.go b/controller/rbac/databaseUserStore.go index 0c8242b06..8f4b7c93b 100644 --- a/controller/rbac/databaseUserStore.go +++ b/controller/rbac/databaseUserStore.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -16,28 +17,24 @@ import ( "go.mongodb.org/mongo-driver/mongo/writeconcern" ) +const storeName = "user.json" + // DatabaseUserStore is used to store users in database. type DatabaseUserStore struct { - userStoreName string + collection *mongo.Collection } -// Add adds an User. -func (s *DatabaseUserStore) Add(userToAdd rbac.User) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err +func NewDatabaseUserStore(db *mongo.Database) *DatabaseUserStore { + collection := db.Collection(storeName) + + return &DatabaseUserStore{ + collection: collection, } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() +} - _, err = client.Database(database.DatabaseName). - Collection(s.userStoreName). - InsertOne(ctx, userToAdd) +// Add adds an User. +func (s *DatabaseUserStore) Add(ctx context.Context, userToAdd rbac.User) (err error) { + _, err = s.collection.InsertOne(ctx, userToAdd) if err != nil { if mongo.IsDuplicateKeyError(err) { return nil @@ -50,22 +47,8 @@ func (s *DatabaseUserStore) Add(userToAdd rbac.User) (err error) { } // Delete deletes an User. -func (s *DatabaseUserStore) Delete(userToDelete rbac.User) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.userStoreName). - DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: userToDelete.ID().String()}}) +func (s *DatabaseUserStore) Delete(ctx context.Context, userToDelete rbac.User) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: "_id", Value: userToDelete.ID().String()}}) if err != nil { return customerrs.CouldNotDeleteError{Identifier: userToDelete.ID(), Type: userToDelete, Err: err} } @@ -75,11 +58,11 @@ func (s *DatabaseUserStore) Delete(userToDelete rbac.User) (err error) { // Get takes a User's UUID or name and returns the User. If the requested // User does not exist an error is returned. -func (s *DatabaseUserStore) Get(query store.Query) (rbac.LoadedUser, error) { +func (s *DatabaseUserStore) Get(ctx context.Context, query store.Query) (rbac.LoadedUser, error) { var loadedUser rbac.LoadedUser if query.ID != uuid.Nil { - loadedUser, err := s.getByID(query.ID) + loadedUser, err := s.getByID(ctx, query.ID) if err != nil { return loadedUser, err } @@ -87,7 +70,7 @@ func (s *DatabaseUserStore) Get(query store.Query) (rbac.LoadedUser, error) { return loadedUser, nil } - loadedUser, err := s.getByName(query.Name) + loadedUser, err := s.getByName(ctx, query.Name) if err != nil { return loadedUser, err } @@ -95,22 +78,8 @@ func (s *DatabaseUserStore) Get(query store.Query) (rbac.LoadedUser, error) { return loadedUser, nil } -func (s *DatabaseUserStore) getByID(idOfUser uuid.UUID) (loadedUser rbac.LoadedUser, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedUser, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.userStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfUser.String()}}) +func (s *DatabaseUserStore) getByID(ctx context.Context, idOfUser uuid.UUID) (loadedUser rbac.LoadedUser, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfUser.String()}}) if result == nil { return loadedUser, customerrs.CouldNotFindError{ID: idOfUser} } @@ -124,22 +93,8 @@ func (s *DatabaseUserStore) getByID(idOfUser uuid.UUID) (loadedUser rbac.LoadedU return loadedUser, nil } -func (s *DatabaseUserStore) getByName(nameOfUser string) (loadedUser rbac.LoadedUser, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedUser, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.userStoreName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "username", Value: nameOfUser}}) +func (s *DatabaseUserStore) getByName(ctx context.Context, nameOfUser string) (loadedUser rbac.LoadedUser, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "username", Value: nameOfUser}}) if result == nil { return loadedUser, customerrs.CouldNotFindError{Name: nameOfUser} } @@ -154,23 +109,8 @@ func (s *DatabaseUserStore) getByName(nameOfUser string) (loadedUser rbac.Loaded } // GetAll returns all Users. -func (s *DatabaseUserStore) GetAll() (loadedUsers []rbac.LoadedUser, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.userStoreName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabaseUserStore) GetAll(ctx context.Context) (loadedUsers []rbac.LoadedUser, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return nil, err } @@ -191,52 +131,44 @@ func (s *DatabaseUserStore) GetAll() (loadedUsers []rbac.LoadedUser, err error) } // Update updates the User. -func (s *DatabaseUserStore) Update(userToUpdate rbac.User) (err error) { +func (s *DatabaseUserStore) Update(ctx context.Context, userToUpdate rbac.User) (err error) { var updatedLoadedUser rbac.LoadedUser - client, ctx, cancel, err := database.GetMongoConnection() + db, err := database.Connect() if err != nil { return err } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - // 1. Start Transaction - wcMajority := writeconcern.Majority() - wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) - userCollection := client.Database(database.DatabaseName).Collection(s.userStoreName, wcMajorityCollectionOpts) - - session, err := client.StartSession() + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + // Starts a session on the client + session, err := db.Client().StartSession() if err != nil { return err } + // Defers ending the session after the transaction is committed or ended defer session.EndSession(ctx) - // 2. Fetch exisiting Entity - existingUser, err := s.getByID(userToUpdate.ID()) - if err != nil { - return err - } + // 3.2.1 If yes -> Update entity in callback + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + // 2. Fetch exisiting Entity + existingUser, err := s.getByID(ctx, userToUpdate.ID()) + if err != nil { + return nil, err + } - // 3. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion - if userToUpdate.GetMetadata().ResourceVersion != existingUser.Metadata.ResourceVersion { - // 3.1.1 End transaction - // 3.1.2 If no -> return error + // 3. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion + if userToUpdate.GetMetadata().ResourceVersion != existingUser.Metadata.ResourceVersion { + // 3.1.1 End transaction + // 3.1.2 If no -> return error - return fmt.Errorf( - "resource version %d of provided user %s is older or newer than %d in the store", - userToUpdate.GetMetadata().ResourceVersion, - userToUpdate.ID().String(), existingUser.Metadata.ResourceVersion, - ) - } + return nil, fmt.Errorf( + "resource version %d of provided user %s is older or newer than %d in the store", + userToUpdate.GetMetadata().ResourceVersion, + userToUpdate.ID().String(), existingUser.Metadata.ResourceVersion, + ) + } - // 3.2.1 If yes -> Update entity in callback - callback := func(sessCtx mongo.SessionContext) (interface{}, error) { // Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the // transaction. @@ -252,9 +184,8 @@ func (s *DatabaseUserStore) Update(userToUpdate rbac.User) (err error) { ReturnDocument: &after, } - err = userCollection. - FindOneAndUpdate( - ctx, bson.M{"_id": userToUpdate.ID().String()}, update, &opt). + err = s.collection.FindOneAndUpdate( + ctx, bson.M{"_id": userToUpdate.ID().String()}, update, &opt). Decode(&updatedLoadedUser) if err != nil { log.Printf("Could not update User: %v", err) @@ -266,7 +197,7 @@ func (s *DatabaseUserStore) Update(userToUpdate rbac.User) (err error) { return "", nil } - _, err = session.WithTransaction(ctx, callback) + _, err = session.WithTransaction(ctx, callback, txnOptions) if err != nil { return err } diff --git a/controller/rbac/memoryRoleStore.go b/controller/rbac/memoryRoleStore.go index 86c6ee25f..68e47594c 100644 --- a/controller/rbac/memoryRoleStore.go +++ b/controller/rbac/memoryRoleStore.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "encoding/json" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -23,7 +24,7 @@ func NewMemoryRoleStore() rbac.RoleStore { } // Add adds a item to the store. -func (s *MemoryRoleStore) Add(item rbac.Role) error { +func (s *MemoryRoleStore) Add(ctx context.Context, item rbac.Role) error { var role rbac.LoadedRole b, err := json.Marshal(item) @@ -47,13 +48,13 @@ func (s *MemoryRoleStore) Add(item rbac.Role) error { } // Delete deletes a role from the role store. -func (s *MemoryRoleStore) Delete(item rbac.Role) error { +func (s *MemoryRoleStore) Delete(ctx context.Context, item rbac.Role) error { delete(s.Store, item.ID().String()) return nil } // Update updates an existing role. -func (s *MemoryRoleStore) Update(item rbac.Role) error { +func (s *MemoryRoleStore) Update(ctx context.Context, item rbac.Role) error { _, ok := s.Store[item.ID().String()] if !ok { return customerrs.CouldNotFindError{ID: item.ID(), Name: item.Name()} @@ -77,7 +78,7 @@ func (s *MemoryRoleStore) Update(item rbac.Role) error { } // Get takes a role's UUID or name and returns the role. -func (s *MemoryRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { +func (s *MemoryRoleStore) Get(ctx context.Context, query store.Query) (rbac.LoadedRole, error) { // First search for direct hit on UUID. item, ok := s.Store[query.ID.String()] if !ok { @@ -99,7 +100,7 @@ func (s *MemoryRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { } // GetAll returns all stored roles. -func (s *MemoryRoleStore) GetAll() ([]rbac.LoadedRole, error) { +func (s *MemoryRoleStore) GetAll(ctx context.Context) ([]rbac.LoadedRole, error) { var allItems []rbac.LoadedRole for _, item := range s.Store { diff --git a/controller/rbac/memoryUserStore.go b/controller/rbac/memoryUserStore.go index b9bab5832..b05b5a280 100644 --- a/controller/rbac/memoryUserStore.go +++ b/controller/rbac/memoryUserStore.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "encoding/json" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" @@ -23,7 +24,7 @@ func NewMemoryUserStore() rbac.UserStore { } // Add adds a item to the store. -func (s *MemoryUserStore) Add(item rbac.User) error { +func (s *MemoryUserStore) Add(ctx context.Context, item rbac.User) error { var user rbac.LoadedUser b, err := json.Marshal(item) @@ -47,13 +48,13 @@ func (s *MemoryUserStore) Add(item rbac.User) error { } // Delete deletes a user from the user store. -func (s *MemoryUserStore) Delete(item rbac.User) error { +func (s *MemoryUserStore) Delete(ctx context.Context, item rbac.User) error { delete(s.Store, item.ID().String()) return nil } // Update updates an existing user. -func (s *MemoryUserStore) Update(item rbac.User) error { +func (s *MemoryUserStore) Update(ctx context.Context, item rbac.User) error { _, ok := s.Store[item.ID().String()] if !ok { return customerrs.CouldNotFindError{ID: item.ID(), Name: item.Name()} @@ -77,7 +78,7 @@ func (s *MemoryUserStore) Update(item rbac.User) error { } // Get takes a user's UUID or name and returns the user. -func (s *MemoryUserStore) Get(query store.Query) (rbac.LoadedUser, error) { +func (s *MemoryUserStore) Get(ctx context.Context, query store.Query) (rbac.LoadedUser, error) { // First search for direct hit on UUID. item, ok := s.Store[query.ID.String()] if !ok { @@ -99,7 +100,7 @@ func (s *MemoryUserStore) Get(query store.Query) (rbac.LoadedUser, error) { } // GetAll returns all stored users. -func (s *MemoryUserStore) GetAll() ([]rbac.LoadedUser, error) { +func (s *MemoryUserStore) GetAll(ctx context.Context) ([]rbac.LoadedUser, error) { var allItems []rbac.LoadedUser for _, item := range s.Store { diff --git a/controller/rbac/rbacService.go b/controller/rbac/rbacService.go index 0a1d26a47..d16505d70 100644 --- a/controller/rbac/rbacService.go +++ b/controller/rbac/rbacService.go @@ -1,6 +1,8 @@ package rbac import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/event" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -35,7 +37,8 @@ func NewUserService(userStore rbac.UserStore, eventService eventInterfaces.Servi // Add adds a user to the user store. func (s *UserService) Add(userToAdd rbac.User) error { - err := s.userStore.Add(userToAdd) + ctx := context.Background() + err := s.userStore.Add(ctx, userToAdd) if err != nil { return err } @@ -57,7 +60,8 @@ func (s *UserService) Add(userToAdd rbac.User) error { // Delete deletes a user from the user store. func (s *UserService) Delete(userToDelete rbac.User) error { - err := s.userStore.Delete(userToDelete) + ctx := context.Background() + err := s.userStore.Delete(ctx, userToDelete) if err != nil { return err } @@ -79,7 +83,8 @@ func (s *UserService) Delete(userToDelete rbac.User) error { // Update updates a existing user. func (s *UserService) Update(userToUpdate rbac.User) error { - err := s.userStore.Update(userToUpdate) + ctx := context.Background() + err := s.userStore.Update(ctx, userToUpdate) if err != nil { return err } @@ -101,7 +106,8 @@ func (s *UserService) Update(userToUpdate rbac.User) error { // Get takes a user's UUID or name and returns the user. func (s *UserService) Get(query store.Query) (rbac.User, error) { - loadedUser, err := s.userStore.Get(query) + ctx := context.Background() + loadedUser, err := s.userStore.Get(ctx, query) if err != nil { return nil, err } @@ -113,7 +119,8 @@ func (s *UserService) Get(query store.Query) (rbac.User, error) { func (s *UserService) GetAll() ([]rbac.User, error) { var users []rbac.User - loadedUsers, err := s.userStore.GetAll() + ctx := context.Background() + loadedUsers, err := s.userStore.GetAll(ctx) if err != nil { return nil, err } @@ -126,7 +133,15 @@ func (s *UserService) GetAll() ([]rbac.User, error) { } func (s *UserService) createUserFromStore(loadedUser rbac.LoadedUser) rbac.User { - return NewUser(uuid.MustParse(loadedUser.ID), loadedUser.UserName, loadedUser.Roles, loadedUser.Password, loadedUser.Token, loadedUser.Salt, loadedUser.Metadata) + return NewUser( + uuid.MustParse(loadedUser.ID), + loadedUser.UserName, + loadedUser.Roles, + loadedUser.Password, + loadedUser.Token, + loadedUser.Salt, + loadedUser.Metadata, + ) } // RoleService provides a role service implementation. @@ -145,7 +160,8 @@ func NewRoleService(roleStore rbac.RoleStore, eventService eventInterfaces.Servi // Add adds a role to the role store. func (s *RoleService) Add(roleToAdd rbac.Role) error { - err := s.roleStore.Add(roleToAdd) + ctx := context.Background() + err := s.roleStore.Add(ctx, roleToAdd) if err != nil { return err } @@ -167,7 +183,8 @@ func (s *RoleService) Add(roleToAdd rbac.Role) error { // Delete deletes a role from the role store. func (s *RoleService) Delete(roleToDelete rbac.Role) error { - err := s.roleStore.Delete(roleToDelete) + ctx := context.Background() + err := s.roleStore.Delete(ctx, roleToDelete) if err != nil { return err } @@ -188,7 +205,8 @@ func (s *RoleService) Delete(roleToDelete rbac.Role) error { // Update updates a existing role. func (s *RoleService) Update(roleToUpdate rbac.Role) error { - err := s.roleStore.Update(roleToUpdate) + ctx := context.Background() + err := s.roleStore.Update(ctx, roleToUpdate) if err != nil { return err } @@ -210,7 +228,8 @@ func (s *RoleService) Update(roleToUpdate rbac.Role) error { // Get takes a roles's UUID or name and returns the role. func (s *RoleService) Get(query store.Query) (rbac.Role, error) { - loadedRole, err := s.roleStore.Get(query) + ctx := context.Background() + loadedRole, err := s.roleStore.Get(ctx, query) if err != nil { return nil, err } @@ -222,7 +241,8 @@ func (s *RoleService) Get(query store.Query) (rbac.Role, error) { func (s *RoleService) GetAll() ([]rbac.Role, error) { var roles []rbac.Role - loadedRoles, err := s.roleStore.GetAll() + ctx := context.Background() + loadedRoles, err := s.roleStore.GetAll(ctx) if err != nil { return nil, err } @@ -235,5 +255,10 @@ func (s *RoleService) GetAll() ([]rbac.Role, error) { } func (s *RoleService) createRoleFromStore(loadedRole rbac.LoadedRole) rbac.Role { - return NewRole(uuid.MustParse(loadedRole.ID), loadedRole.RoleName, loadedRole.Description, loadedRole.Permissions) + return NewRole( + uuid.MustParse(loadedRole.ID), + loadedRole.RoleName, + loadedRole.Description, + loadedRole.Permissions, + ) } diff --git a/controller/rbac/roleFileSystemStore.go b/controller/rbac/roleFileSystemStore.go index c0780d950..b7fda1602 100644 --- a/controller/rbac/roleFileSystemStore.go +++ b/controller/rbac/roleFileSystemStore.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "encoding/json" "os" "sync" @@ -59,7 +60,7 @@ func (s *FileSystemRoleStore) writeAllRolesToFile(roles []rbac.LoadedRole) error } // Add adds a Role to the Role store. -func (s *FileSystemRoleStore) Add(roleToAdd rbac.Role) error { +func (s *FileSystemRoleStore) Add(ctx context.Context, roleToAdd rbac.Role) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -85,7 +86,7 @@ func (s *FileSystemRoleStore) Add(roleToAdd rbac.Role) error { } // Delete deletes a Role from the Role store. -func (s *FileSystemRoleStore) Delete(roleToDelete rbac.Role) error { +func (s *FileSystemRoleStore) Delete(ctx context.Context, roleToDelete rbac.Role) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -113,7 +114,7 @@ func (s *FileSystemRoleStore) Delete(roleToDelete rbac.Role) error { } // Get takes a Roles ID and return the Role if found. -func (s *FileSystemRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { +func (s *FileSystemRoleStore) Get(ctx context.Context, query store.Query) (rbac.LoadedRole, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -133,7 +134,7 @@ func (s *FileSystemRoleStore) Get(query store.Query) (rbac.LoadedRole, error) { } // GetAll returns all the Roles. -func (s *FileSystemRoleStore) GetAll() ([]rbac.LoadedRole, error) { +func (s *FileSystemRoleStore) GetAll(ctx context.Context) ([]rbac.LoadedRole, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -142,7 +143,7 @@ func (s *FileSystemRoleStore) GetAll() ([]rbac.LoadedRole, error) { } // Update updates an exsisting Role. -func (s *FileSystemRoleStore) Update(roleToUpdate rbac.Role) error { +func (s *FileSystemRoleStore) Update(ctx context.Context, roleToUpdate rbac.Role) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() diff --git a/controller/rbac/roleFileSystemStore_test.go b/controller/rbac/roleFileSystemStore_test.go index 347b42853..b5339ea7d 100644 --- a/controller/rbac/roleFileSystemStore_test.go +++ b/controller/rbac/roleFileSystemStore_test.go @@ -1,12 +1,14 @@ package rbac import ( + "context" "reflect" "testing" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" + "go.mongodb.org/mongo-driver/mongo" ) func TestFileSystemRoleStore_Add(t *testing.T) { @@ -25,8 +27,9 @@ func TestFileSystemRoleStore_Add(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewRoleStore() - if err := s.Add(tt.args.RoleToAdd); (err != nil) != tt.wantErr { + s := NewRoleStore(&mongo.Database{}) + + if err := s.Add(context.TODO(), tt.args.RoleToAdd); (err != nil) != tt.wantErr { t.Errorf("FileSystemRoleStore.Add() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -52,11 +55,11 @@ func TestFileSystemRoleStore_Delete(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewRoleStore() - if err := s.Add(addRole); err != nil { + s := NewRoleStore(&mongo.Database{}) + if err := s.Add(context.TODO(), addRole); err != nil { t.Error(err) } - if err := s.Delete(tt.args.RoleToDelete); (err != nil) != tt.wantErr { + if err := s.Delete(context.TODO(), tt.args.RoleToDelete); (err != nil) != tt.wantErr { t.Errorf("FileSystemRoleStore.Delete() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -88,11 +91,11 @@ func TestFileSystemRoleStore_Get(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewRoleStore() - if err := s.Add(addRole); err != nil { + s := NewRoleStore(&mongo.Database{}) + if err := s.Add(context.TODO(), addRole); err != nil { t.Error(err) } - got, err := s.Get(tt.args.query) + got, err := s.Get(context.TODO(), tt.args.query) if (err != nil) != tt.wantErr { t.Errorf("FileSystemRoleStore.Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -126,14 +129,14 @@ func TestFileSystemRoleStore_GetAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var errs []error - s := NewRoleStore() - if err := s.Add(addRole1); err != nil { + s := NewRoleStore(&mongo.Database{}) + if err := s.Add(context.TODO(), addRole1); err != nil { errs = append(errs, err) } - if err := s.Add(addRole2); err != nil { + if err := s.Add(context.TODO(), addRole2); err != nil { errs = append(errs, err) } - if err := s.Add(addRole3); err != nil { + if err := s.Add(context.TODO(), addRole3); err != nil { errs = append(errs, err) } @@ -141,7 +144,7 @@ func TestFileSystemRoleStore_GetAll(t *testing.T) { t.Error(errs) } - got, err := s.GetAll() + got, err := s.GetAll(context.TODO()) if (err != nil) != tt.wantErr { t.Errorf("FileSystemRoleStore.GetAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -173,11 +176,11 @@ func TestFileSystemRoleStore_Update(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewRoleStore() - if err := s.Add(addRole1); err != nil { + s := NewRoleStore(&mongo.Database{}) + if err := s.Add(context.TODO(), addRole1); err != nil { t.Error(err) } - if err := s.Update(tt.args.roleToUpdate); (err != nil) != tt.wantErr { + if err := s.Update(context.TODO(), tt.args.roleToUpdate); (err != nil) != tt.wantErr { t.Errorf("FileSystemRoleStore.Update() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/controller/rbac/roleStore.go b/controller/rbac/roleStore.go index ad76beaa7..d2b05aa61 100644 --- a/controller/rbac/roleStore.go +++ b/controller/rbac/roleStore.go @@ -3,15 +3,16 @@ package rbac import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" + "go.mongodb.org/mongo-driver/mongo" ) // NewRoleStore returns a roleStore. -func NewRoleStore() rbac.RoleStore { +func NewRoleStore(db *mongo.Database) rbac.RoleStore { storeMode := store.GetStoreMode() switch storeMode { case store.Database: - return &DatabaseRoleStore{"role.json"} + return NewDatabaseRoleStore(db) default: store := NewFileSystemRoleStore() return store diff --git a/controller/rbac/userFileSystemStore.go b/controller/rbac/userFileSystemStore.go index 257fb5e11..c37b2fcfc 100644 --- a/controller/rbac/userFileSystemStore.go +++ b/controller/rbac/userFileSystemStore.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "encoding/json" "os" "sync" @@ -59,7 +60,7 @@ func (s *FileSystemUserStore) writeAllUsersToFile(users []rbac.LoadedUser) error } // Add adds a User to the User store. -func (s *FileSystemUserStore) Add(UserToAdd rbac.User) error { +func (s *FileSystemUserStore) Add(ctx context.Context, UserToAdd rbac.User) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -85,7 +86,7 @@ func (s *FileSystemUserStore) Add(UserToAdd rbac.User) error { } // Delete deletes a User from the User store. -func (s *FileSystemUserStore) Delete(userToDelete rbac.User) error { +func (s *FileSystemUserStore) Delete(ctx context.Context, userToDelete rbac.User) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -113,7 +114,7 @@ func (s *FileSystemUserStore) Delete(userToDelete rbac.User) error { } // Get takes a Users ID and return the User if found. -func (s *FileSystemUserStore) Get(query store.Query) (rbac.LoadedUser, error) { +func (s *FileSystemUserStore) Get(ctx context.Context, query store.Query) (rbac.LoadedUser, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -133,7 +134,7 @@ func (s *FileSystemUserStore) Get(query store.Query) (rbac.LoadedUser, error) { } // GetAll returns all the Users. -func (s *FileSystemUserStore) GetAll() ([]rbac.LoadedUser, error) { +func (s *FileSystemUserStore) GetAll(ctx context.Context) ([]rbac.LoadedUser, error) { s.fileMutex.Lock() defer s.fileMutex.Unlock() @@ -142,7 +143,7 @@ func (s *FileSystemUserStore) GetAll() ([]rbac.LoadedUser, error) { } // Update updates an exsisting user. -func (s *FileSystemUserStore) Update(userToUpdate rbac.User) error { +func (s *FileSystemUserStore) Update(ctx context.Context, userToUpdate rbac.User) error { s.fileMutex.Lock() defer s.fileMutex.Unlock() diff --git a/controller/rbac/userFileSystemStore_test.go b/controller/rbac/userFileSystemStore_test.go index 847d2be83..7ddeb6ee3 100644 --- a/controller/rbac/userFileSystemStore_test.go +++ b/controller/rbac/userFileSystemStore_test.go @@ -1,6 +1,7 @@ package rbac import ( + "context" "reflect" "testing" @@ -8,6 +9,7 @@ import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" + "go.mongodb.org/mongo-driver/mongo" ) func TestFileSystemUserStore_Add(t *testing.T) { @@ -32,9 +34,9 @@ func TestFileSystemUserStore_Add(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewUserStore() + s := NewUserStore(&mongo.Database{}) - if err := s.Add(tt.args.UserToAdd); (err != nil) != tt.wantErr { + if err := s.Add(context.TODO(), tt.args.UserToAdd); (err != nil) != tt.wantErr { t.Errorf("FileSystemUserStore.Add() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -63,11 +65,11 @@ func TestFileSystemUserStore_Delete(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewUserStore() - if err := s.Add(testingUser); err != nil { + s := NewUserStore(&mongo.Database{}) + if err := s.Add(context.TODO(), testingUser); err != nil { t.Error(err) } - if err := s.Delete(tt.args.UserToDelete); (err != nil) != tt.wantErr { + if err := s.Delete(context.TODO(), tt.args.UserToDelete); (err != nil) != tt.wantErr { t.Errorf("FileSystemUserStore.Delete() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -100,11 +102,11 @@ func TestFileSystemUserStore_Get(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewUserStore() - if err := s.Add(testingUser); err != nil { + s := NewUserStore(&mongo.Database{}) + if err := s.Add(context.TODO(), testingUser); err != nil { t.Error(err) } - got, err := s.Get(tt.args.query) + got, err := s.Get(context.TODO(), tt.args.query) if (err != nil) != tt.wantErr { t.Errorf("FileSystemUserStore.Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -138,14 +140,14 @@ func TestFileSystemUserStore_GetAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var errs []error - s := NewUserStore() - if err := s.Add(testingUser1); err != nil { + s := NewUserStore(&mongo.Database{}) + if err := s.Add(context.TODO(), testingUser1); err != nil { errs = append(errs, err) } - if err := s.Add(testingUser2); err != nil { + if err := s.Add(context.TODO(), testingUser2); err != nil { errs = append(errs, err) } - if err := s.Add(testingUser3); err != nil { + if err := s.Add(context.TODO(), testingUser3); err != nil { errs = append(errs, err) } @@ -153,7 +155,7 @@ func TestFileSystemUserStore_GetAll(t *testing.T) { t.Error(errs) } - got, err := s.GetAll() + got, err := s.GetAll(context.TODO()) if (err != nil) != tt.wantErr { t.Errorf("FileSystemUserStore.GetAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -189,11 +191,11 @@ func TestFileSystemUserStore_Update(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := NewUserStore() - if err := s.Add(testingUser); err != nil { + s := NewUserStore(&mongo.Database{}) + if err := s.Add(context.TODO(), testingUser); err != nil { t.Error(err) } - if err := s.Update(tt.args.userToUpdate); (err != nil) != tt.wantErr { + if err := s.Update(context.TODO(), tt.args.userToUpdate); (err != nil) != tt.wantErr { t.Errorf("FileSystemUserStore.Update() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/controller/rbac/userStore.go b/controller/rbac/userStore.go index 48e6c8618..41f3501df 100644 --- a/controller/rbac/userStore.go +++ b/controller/rbac/userStore.go @@ -3,15 +3,16 @@ package rbac import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" + "go.mongodb.org/mongo-driver/mongo" ) // NewUserStore returns a userStore. -func NewUserStore() rbac.UserStore { +func NewUserStore(db *mongo.Database) rbac.UserStore { storeMode := store.GetStoreMode() switch storeMode { case store.Database: - return &DatabaseUserStore{"user.json"} + return NewDatabaseUserStore(db) default: store := NewFileSystemUserStore() return store diff --git a/controller/store/genericStore.go b/controller/store/genericStore.go index cb4c255cf..f3a032034 100644 --- a/controller/store/genericStore.go +++ b/controller/store/genericStore.go @@ -1,6 +1,7 @@ package store import ( + "context" "errors" "github.com/google/uuid" @@ -18,14 +19,14 @@ type GenericStore[T storableConstraint] struct { } // NewGenericStore returns a specific in-memory store for a type T. -func NewGenericStore[T storableConstraint]() GenericStore[T] { - return GenericStore[T]{ +func NewGenericStore[T storableConstraint]() *GenericStore[T] { + return &GenericStore[T]{ Store: make(map[uuid.UUID]T), nameLookupTable: make(map[string]uuid.UUID), } } -func (t *GenericStore[T]) Add(item T) error { +func (t *GenericStore[T]) Add(ctx context.Context, item T) error { _, ok := t.Store[item.ID()] if ok { return errors.New("item not found") @@ -37,7 +38,7 @@ func (t *GenericStore[T]) Add(item T) error { return nil } -func (t *GenericStore[T]) Update(item T) error { +func (t *GenericStore[T]) Update(ctx context.Context, item T) error { _, ok := t.Store[item.ID()] if ok { return nil @@ -49,13 +50,13 @@ func (t *GenericStore[T]) Update(item T) error { return nil } -func (t *GenericStore[T]) Delete(item T) error { +func (t *GenericStore[T]) Delete(ctx context.Context, item T) error { delete(t.Store, item.ID()) return nil } -func (t *GenericStore[T]) Get(query Query) (T, error) { +func (t *GenericStore[T]) Get(ctx context.Context, query Query) (T, error) { // First search for direct hit on UUID. item, ok := t.Store[query.ID] if !ok { @@ -76,7 +77,7 @@ func (t *GenericStore[T]) Get(query Query) (T, error) { return item, nil } -func (t *GenericStore[T]) GetAll() ([]T, error) { +func (t *GenericStore[T]) GetAll(ctx context.Context) ([]T, error) { var allItems []T for _, item := range t.Store { diff --git a/controller/topology/nodes/databaseNodeStore.go b/controller/topology/nodes/databaseNodeStore.go new file mode 100644 index 000000000..41271c9e0 --- /dev/null +++ b/controller/topology/nodes/databaseNodeStore.go @@ -0,0 +1,210 @@ +package nodes + +import ( + "context" + "fmt" + "time" + + "code.fbi.h-da.de/danet/gosdn/controller/customerrs" + "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" + query "code.fbi.h-da.de/danet/gosdn/controller/store" + + "github.com/google/uuid" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" + "go.mongodb.org/mongo-driver/mongo/writeconcern" +) + +const storeName = "node-store.json" + +// Store defines a NodeStore interface. +type Store interface { + Add(context.Context, Node) error + Update(context.Context, Node) error + Delete(context.Context, Node) error + Get(context.Context, query.Query) (Node, error) + GetAll(context.Context) ([]Node, error) +} + +// DatabaseNodeStore is a database store for nodes. +type DatabaseNodeStore struct { + collection *mongo.Collection +} + +// NewDatabaseNodeStore returns a NodeStore. +func NewDatabaseNodeStore(db *mongo.Database) Store { + collection := db.Collection(storeName) + + return &DatabaseNodeStore{ + collection: collection, + } +} + +// Get takes a nodes's UUID or name and returns the nodes. +func (s *DatabaseNodeStore) Get(ctx context.Context, query query.Query) (Node, error) { + var loadedNode Node + + if query.ID.String() != "" { + loadedNode, err := s.getByID(ctx, query.ID) + if err != nil { + return loadedNode, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} + } + + return loadedNode, nil + } + + loadedNode, err := s.getByName(ctx, query.Name) + if err != nil { + return loadedNode, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} + } + + return loadedNode, nil +} + +func (s *DatabaseNodeStore) getByID(ctx context.Context, idOfNode uuid.UUID) (loadedNode Node, err error) { + idAsByteArray, _ := idOfNode.MarshalBinary() + + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idAsByteArray}}) + if result == nil { + return loadedNode, customerrs.CouldNotFindError{ID: idOfNode} + } + + err = result.Decode(&loadedNode) + if err != nil { + return loadedNode, customerrs.CouldNotMarshallError{Identifier: idOfNode, Type: loadedNode, Err: err} + } + + return loadedNode, nil +} + +func (s *DatabaseNodeStore) getByName(ctx context.Context, nameOfNode string) (loadedNode Node, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfNode}}) + if result == nil { + return loadedNode, customerrs.CouldNotFindError{Name: nameOfNode} + } + + err = result.Decode(&loadedNode) + if err != nil { + return loadedNode, customerrs.CouldNotMarshallError{Identifier: nameOfNode, Type: loadedNode, Err: err} + } + + return loadedNode, nil +} + +// GetAll returns all stored nodes. +func (s *DatabaseNodeStore) GetAll(ctx context.Context) (loadedNode []Node, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) + if err != nil { + return []Node{}, err + } + defer func() { + if ferr := cursor.Close(ctx); ferr != nil { + fErrString := ferr.Error() + err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) + } + }() + + err = cursor.All(ctx, &loadedNode) + if err != nil { + return loadedNode, customerrs.CouldNotMarshallError{Type: loadedNode, Err: err} + } + + return loadedNode, nil +} + +// Add adds a node to the node store. +func (s *DatabaseNodeStore) Add(ctx context.Context, node Node) (err error) { + node.Metadata.ResourceVersion = 0 + node.Metadata.CreatedAt = time.Now() + node.Metadata.LastUpdated = time.Now() + + _, err = s.collection. + InsertOne(ctx, node) + if err != nil { + return customerrs.CouldNotCreateError{Identifier: node.ID, Type: node, Err: err} + } + + return nil +} + +// Update updates a existing node. +func (s *DatabaseNodeStore) Update(ctx context.Context, node Node) (err error) { + var updatedLoadedNodes Node + + db, err := database.Connect() + if err != nil { + return err + } + + wc := writeconcern.Majority() + txnOptions := options.Transaction().SetWriteConcern(wc) + // Starts a session on the client + session, err := db.Client().StartSession() + if err != nil { + return err + } + // Defers ending the session after the transaction is committed or ended + defer session.EndSession(ctx) + + // Transaction + callback := func(sessCtx mongo.SessionContext) (interface{}, error) { + // 1. Fetch exisiting Entity + existingNode, err := s.getByID(ctx, node.ID) + if err != nil { + return nil, err + } + + // 2. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion + if node.Metadata.ResourceVersion != existingNode.Metadata.ResourceVersion { + // 2.1 End transaction + // 2.2 If no -> return error + + return nil, fmt.Errorf( + "resource version %d of provided node %s is older or newer than %d in the store", + node.Metadata.ResourceVersion, + node.ID.String(), existingNode.Metadata.ResourceVersion, + ) + } + + // Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the + // transaction. + update := bson.D{primitive.E{Key: "$set", Value: node}} + + upsert := false + after := options.After + opt := options.FindOneAndUpdateOptions{ + Upsert: &upsert, + ReturnDocument: &after, + } + + err = s.collection. + FindOneAndUpdate( + ctx, bson.M{"_id": node.ID.String()}, update, &opt). + Decode(&updatedLoadedNodes) + if err != nil { + return nil, customerrs.CouldNotUpdateError{Identifier: node.ID, Type: node, Err: err} + } + + // 3. End transaction + return "", nil + } + + _, err = session.WithTransaction(ctx, callback, txnOptions) + if err != nil { + return err + } + + return nil +} + +// Delete deletes a node from the node store. +func (s *DatabaseNodeStore) Delete(ctx context.Context, node Node) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: node.ID.String()}}) + if err != nil { + return err + } + + return nil +} diff --git a/controller/topology/nodes/nodeService.go b/controller/topology/nodes/nodeService.go index 518626c15..e047e425a 100644 --- a/controller/topology/nodes/nodeService.go +++ b/controller/topology/nodes/nodeService.go @@ -1,6 +1,8 @@ package nodes import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/event" eventInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/event" query "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -55,7 +57,8 @@ func (n *NodeService) EnsureExists(node Node) (Node, error) { } func (n *NodeService) createNode(node Node) (Node, error) { - err := n.store.Add(node) + ctx := context.Background() + err := n.store.Add(ctx, node) if err != nil { return node, err } @@ -77,7 +80,8 @@ func (n *NodeService) createNode(node Node) (Node, error) { // Update updates an existing node. func (n *NodeService) Update(node Node) error { - err := n.store.Update(node) + ctx := context.Background() + err := n.store.Update(ctx, node) if err != nil { return err } @@ -99,7 +103,8 @@ func (n *NodeService) Update(node Node) error { // Delete deletes a node. func (n *NodeService) Delete(node Node) error { - err := n.store.Delete(node) + ctx := context.Background() + err := n.store.Delete(ctx, node) if err != nil { return err } @@ -121,7 +126,8 @@ func (n *NodeService) Delete(node Node) error { // Get gets a node. func (n *NodeService) Get(query query.Query) (Node, error) { - node, err := n.store.Get(query) + ctx := context.Background() + node, err := n.store.Get(ctx, query) if err != nil { return node, err } @@ -131,7 +137,8 @@ func (n *NodeService) Get(query query.Query) (Node, error) { // GetAll gets all existing nodes. func (n *NodeService) GetAll() ([]Node, error) { - nodes, err := n.store.GetAll() + ctx := context.Background() + nodes, err := n.store.GetAll(ctx) if err != nil { return nodes, err } diff --git a/controller/topology/nodes/nodeService_test.go b/controller/topology/nodes/nodeService_test.go index 3eca6b368..fd7273d8f 100644 --- a/controller/topology/nodes/nodeService_test.go +++ b/controller/topology/nodes/nodeService_test.go @@ -1,6 +1,7 @@ package nodes import ( + "context" "reflect" "testing" @@ -32,9 +33,10 @@ func getEmptyNode() Node { func getTestStoreWithNodes(t *testing.T, nodes []Node) Store { store := store.NewGenericStore[Node]() + ctx := context.TODO() for _, node := range nodes { - err := store.Add(node) + err := store.Add(ctx, node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } diff --git a/controller/topology/nodes/nodeStore.go b/controller/topology/nodes/nodeStore.go deleted file mode 100644 index ce9af485f..000000000 --- a/controller/topology/nodes/nodeStore.go +++ /dev/null @@ -1,282 +0,0 @@ -package nodes - -import ( - "fmt" - "time" - - "code.fbi.h-da.de/danet/gosdn/controller/customerrs" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" - query "code.fbi.h-da.de/danet/gosdn/controller/store" - - "github.com/google/uuid" - "go.mongodb.org/mongo-driver/bson" - "go.mongodb.org/mongo-driver/bson/primitive" - "go.mongodb.org/mongo-driver/mongo" - "go.mongodb.org/mongo-driver/mongo/options" - "go.mongodb.org/mongo-driver/mongo/writeconcern" -) - -// Store defines a NodeStore interface. -type Store interface { - Add(Node) error - Update(Node) error - Delete(Node) error - Get(query.Query) (Node, error) - GetAll() ([]Node, error) -} - -// DatabaseNodeStore is a database store for nodes. -type DatabaseNodeStore struct { - storeName string -} - -// NewDatabaseNodeStore returns a NodeStore. -func NewDatabaseNodeStore() Store { - return &DatabaseNodeStore{ - storeName: fmt.Sprint("node-store.json"), - } -} - -// Get takes a nodes's UUID or name and returns the nodes. -func (s *DatabaseNodeStore) Get(query query.Query) (Node, error) { - var loadedNode Node - - if query.ID.String() != "" { - loadedNode, err := s.getByID(query.ID) - if err != nil { - return loadedNode, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} - } - - return loadedNode, nil - } - - loadedNode, err := s.getByName(query.Name) - if err != nil { - return loadedNode, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} - } - - return loadedNode, nil -} - -func (s *DatabaseNodeStore) getByID(idOfNode uuid.UUID) (loadedNode Node, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedNode, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - idAsByteArray, _ := idOfNode.MarshalBinary() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idAsByteArray}}) - if result == nil { - return loadedNode, customerrs.CouldNotFindError{ID: idOfNode} - } - - err = result.Decode(&loadedNode) - if err != nil { - return loadedNode, customerrs.CouldNotMarshallError{Identifier: idOfNode, Type: loadedNode, Err: err} - } - - return loadedNode, nil -} - -func (s *DatabaseNodeStore) getByName(nameOfNode string) (loadedNode Node, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedNode, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfNode}}) - if result == nil { - return loadedNode, customerrs.CouldNotFindError{Name: nameOfNode} - } - - err = result.Decode(&loadedNode) - if err != nil { - return loadedNode, customerrs.CouldNotMarshallError{Identifier: nameOfNode, Type: loadedNode, Err: err} - } - - return loadedNode, nil -} - -// GetAll returns all stored nodes. -func (s *DatabaseNodeStore) GetAll() (loadedNode []Node, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedNode, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) - if err != nil { - return []Node{}, err - } - defer func() { - if ferr := cursor.Close(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - err = cursor.All(ctx, &loadedNode) - if err != nil { - return loadedNode, customerrs.CouldNotMarshallError{Type: loadedNode, Err: err} - } - - return loadedNode, nil -} - -// Add adds a node to the node store. -func (s *DatabaseNodeStore) Add(node Node) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - node.Metadata.ResourceVersion = 0 - node.Metadata.CreatedAt = time.Now() - node.Metadata.LastUpdated = time.Now() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). - InsertOne(ctx, node) - if err != nil { - return customerrs.CouldNotCreateError{Identifier: node.ID, Type: node, Err: err} - } - - return nil -} - -// Update updates a existing node. -func (s *DatabaseNodeStore) Update(node Node) (err error) { - var updatedLoadedNodes Node - - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - // 1. Start Transaction - wcMajority := writeconcern.Majority() - wcMajorityCollectionOpts := options.Collection().SetWriteConcern(wcMajority) - nodeCollection := client.Database(database.DatabaseName).Collection(s.storeName, wcMajorityCollectionOpts) - - session, err := client.StartSession() - if err != nil { - return err - } - defer session.EndSession(ctx) - - // 2. Fetch exisiting Entity - existingNode, err := s.getByID(node.ID) - if err != nil { - return err - } - - // 3. Check if Entity.Metadata.ResourceVersion == UpdatedEntity.Metadata.ResourceVersion - if node.Metadata.ResourceVersion != existingNode.Metadata.ResourceVersion { - // 3.1.1 End transaction - // 3.1.2 If no -> return error - - return fmt.Errorf( - "resource version %d of provided node %s is older or newer than %d in the store", - node.Metadata.ResourceVersion, - node.ID.String(), existingNode.Metadata.ResourceVersion, - ) - } - - // 3.2.1 If yes -> Update entity in callback - callback := func(sessCtx mongo.SessionContext) (interface{}, error) { - // Important: You must pass sessCtx as the Context parameter to the operations for them to be executed in the - // transaction. - update := bson.D{primitive.E{Key: "$set", Value: node}} - - upsert := false - after := options.After - opt := options.FindOneAndUpdateOptions{ - Upsert: &upsert, - ReturnDocument: &after, - } - - err = nodeCollection. - FindOneAndUpdate( - ctx, bson.M{"_id": node.ID.String()}, update, &opt). - Decode(&updatedLoadedNodes) - if err != nil { - return nil, customerrs.CouldNotUpdateError{Identifier: node.ID, Type: node, Err: err} - } - - // 3.2.2 End transaction - return "", nil - } - - _, err = session.WithTransaction(ctx, callback) - if err != nil { - return err - } - - return nil -} - -// Delete deletes a node from the node store. -func (s *DatabaseNodeStore) Delete(node Node) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: node.ID.String()}}) - if err != nil { - return err - } - - return nil -} diff --git a/controller/topology/nodes/store.go b/controller/topology/nodes/store.go new file mode 100644 index 000000000..1e4595ec9 --- /dev/null +++ b/controller/topology/nodes/store.go @@ -0,0 +1,21 @@ +package nodes + +import ( + "code.fbi.h-da.de/danet/gosdn/controller/store" + topoStore "code.fbi.h-da.de/danet/gosdn/controller/topology/store" + + "go.mongodb.org/mongo-driver/mongo" +) + +// NewNodeStore returns a NodeStore. +func NewNodeStore(db *mongo.Database) Store { + storeMode := store.GetStoreMode() + + switch storeMode { + case store.Database: + return NewDatabaseNodeStore(db) + + default: + return topoStore.NewGenericStore[Node]() + } +} diff --git a/controller/topology/ports/portService.go b/controller/topology/ports/portService.go index 25bfc09c6..6870d977b 100644 --- a/controller/topology/ports/portService.go +++ b/controller/topology/ports/portService.go @@ -1,6 +1,8 @@ package ports import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/event" eventInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/event" query "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -53,7 +55,8 @@ func (p *PortService) EnsureExists(port Port) (Port, error) { } func (p *PortService) createPort(port Port) (Port, error) { - err := p.store.Add(port) + ctx := context.Background() + err := p.store.Add(ctx, port) if err != nil { return port, err } @@ -75,7 +78,8 @@ func (p *PortService) createPort(port Port) (Port, error) { // Update updates an existing port. func (p *PortService) Update(port Port) error { - err := p.store.Update(port) + ctx := context.Background() + err := p.store.Update(ctx, port) if err != nil { return err } @@ -97,7 +101,8 @@ func (p *PortService) Update(port Port) error { // Delete deletes a port. func (p *PortService) Delete(port Port) error { - err := p.store.Delete(port) + ctx := context.Background() + err := p.store.Delete(ctx, port) if err != nil { return err } @@ -119,7 +124,8 @@ func (p *PortService) Delete(port Port) error { // Get gets a port. func (p *PortService) Get(query query.Query) (Port, error) { - port, err := p.store.Get(query) + ctx := context.Background() + port, err := p.store.Get(ctx, query) if err != nil { return port, err } @@ -129,7 +135,8 @@ func (p *PortService) Get(query query.Query) (Port, error) { // GetAll gets all existing ports. func (p *PortService) GetAll() ([]Port, error) { - nodes, err := p.store.GetAll() + ctx := context.Background() + nodes, err := p.store.GetAll(ctx) if err != nil { return nodes, err } diff --git a/controller/topology/ports/portService_test.go b/controller/topology/ports/portService_test.go index 951f2730e..537b8e690 100644 --- a/controller/topology/ports/portService_test.go +++ b/controller/topology/ports/portService_test.go @@ -1,6 +1,7 @@ package ports import ( + "context" "reflect" "testing" @@ -28,9 +29,10 @@ func getEmptyPort() Port { func getTestStoreWithPorts(t *testing.T, ports []Port) Store { store := store.NewGenericStore[Port]() + ctx := context.TODO() for _, port := range ports { - err := store.Add(port) + err := store.Add(ctx, port) if err != nil { t.Fatalf("failed to prepare test store while adding port: %v", err) } diff --git a/controller/topology/ports/portStore.go b/controller/topology/ports/portStore.go index 7045f2423..f0fd2c91e 100644 --- a/controller/topology/ports/portStore.go +++ b/controller/topology/ports/portStore.go @@ -1,46 +1,51 @@ package ports import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkelement" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" query "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +const storeName = "port-store.json" + // Store defines a PortStore interface. type Store interface { - Add(Port) error - Update(Port) error - Delete(Port) error - Get(query.Query) (Port, error) - GetAll() ([]Port, error) + Add(context.Context, Port) error + Update(context.Context, Port) error + Delete(context.Context, Port) error + Get(context.Context, query.Query) (Port, error) + GetAll(context.Context) ([]Port, error) } // DatabasePortStore is a database store for ports. type DatabasePortStore struct { - storeName string + collection *mongo.Collection } // NewDatabasePortStore returns a PortStore. -func NewDatabasePortStore() Store { +func NewDatabasePortStore(db *mongo.Database) Store { + collection := db.Collection(storeName) + return &DatabasePortStore{ - storeName: fmt.Sprint("port-store.json"), + collection: collection, } } // Get takes a Ports's UUID or name and returns the port. -func (s *DatabasePortStore) Get(query query.Query) (Port, error) { +func (s *DatabasePortStore) Get(ctx context.Context, query query.Query) (Port, error) { var loadedPort Port if query.ID.String() != "" { - loadedPort, err := s.getByID(query.ID) + loadedPort, err := s.getByID(ctx, query.ID) if err != nil { return loadedPort, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -48,7 +53,7 @@ func (s *DatabasePortStore) Get(query query.Query) (Port, error) { return loadedPort, nil } - loadedPort, err := s.getByName(query.Name) + loadedPort, err := s.getByName(ctx, query.Name) if err != nil { return loadedPort, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -56,24 +61,10 @@ func (s *DatabasePortStore) Get(query query.Query) (Port, error) { return loadedPort, nil } -func (s *DatabasePortStore) getByID(idOfPort uuid.UUID) (loadedPort Port, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedPort, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - +func (s *DatabasePortStore) getByID(ctx context.Context, idOfPort uuid.UUID) (loadedPort Port, err error) { idAsByteArray, _ := idOfPort.MarshalBinary() - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idAsByteArray}}) + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idAsByteArray}}) if result == nil { return loadedPort, customerrs.CouldNotFindError{ID: idOfPort} } @@ -86,22 +77,8 @@ func (s *DatabasePortStore) getByID(idOfPort uuid.UUID) (loadedPort Port, err er return loadedPort, nil } -func (s *DatabasePortStore) getByName(nameOfPort string) (loadedPort Port, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedPort, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfPort}}) +func (s *DatabasePortStore) getByName(ctx context.Context, nameOfPort string) (loadedPort Port, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfPort}}) if result == nil { return loadedPort, customerrs.CouldNotFindError{Name: nameOfPort} } @@ -115,23 +92,8 @@ func (s *DatabasePortStore) getByName(nameOfPort string) (loadedPort Port, err e } // GetAll returns all stored ports. -func (s *DatabasePortStore) GetAll() (loadedPorts []Port, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabasePortStore) GetAll(ctx context.Context) (loadedPorts []Port, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return []Port{}, err } @@ -151,21 +113,8 @@ func (s *DatabasePortStore) GetAll() (loadedPorts []Port, err error) { } // Add adds a port to the port store. -func (s *DatabasePortStore) Add(port Port) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). +func (s *DatabasePortStore) Add(ctx context.Context, port Port) (err error) { + _, err = s.collection. InsertOne(ctx, port) if err != nil { return customerrs.CouldNotCreateError{Identifier: port.ID, Type: port, Err: err} @@ -175,21 +124,9 @@ func (s *DatabasePortStore) Add(port Port) (err error) { } // Update updates a existing port. -func (s *DatabasePortStore) Update(port Port) (err error) { +func (s *DatabasePortStore) Update(ctx context.Context, port Port) (err error) { var updatedLoadedNetworkElement networkelement.LoadedNetworkElement - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - update := bson.D{primitive.E{Key: "$set", Value: port}} upsert := false @@ -199,8 +136,7 @@ func (s *DatabasePortStore) Update(port Port) (err error) { ReturnDocument: &after, } - err = client.Database(database.DatabaseName). - Collection(s.storeName). + err = s.collection. FindOneAndUpdate( ctx, bson.M{"_id": port.ID.String()}, update, &opt). Decode(&updatedLoadedNetworkElement) @@ -212,22 +148,8 @@ func (s *DatabasePortStore) Update(port Port) (err error) { } // Delete deletes a port from the port store. -func (s *DatabasePortStore) Delete(port Port) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: port.ID.String()}}) +func (s *DatabasePortStore) Delete(ctx context.Context, port Port) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: port.ID.String()}}) if err != nil { return err } diff --git a/controller/topology/ports/store.go b/controller/topology/ports/store.go new file mode 100644 index 000000000..9a21aa516 --- /dev/null +++ b/controller/topology/ports/store.go @@ -0,0 +1,21 @@ +package ports + +import ( + "code.fbi.h-da.de/danet/gosdn/controller/store" + topoStore "code.fbi.h-da.de/danet/gosdn/controller/topology/store" + + "go.mongodb.org/mongo-driver/mongo" +) + +// NewPortStore returns a PortStore. +func NewPortStore(db *mongo.Database) Store { + storeMode := store.GetStoreMode() + + switch storeMode { + case store.Database: + return NewDatabasePortStore(db) + + default: + return topoStore.NewGenericStore[Port]() + } +} diff --git a/controller/topology/routing-tables/routingTableService.go b/controller/topology/routing-tables/routingTableService.go index 2a1fc1348..7aeac9ac3 100644 --- a/controller/topology/routing-tables/routingTableService.go +++ b/controller/topology/routing-tables/routingTableService.go @@ -1,6 +1,8 @@ package routingtables import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/event" eventInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/event" query "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -63,7 +65,8 @@ func (r *RoutingTableService) EnsureExists(routingTable RoutingTable) (RoutingTa } func (r *RoutingTableService) createRoutingTable(routingTable RoutingTable) (RoutingTable, error) { - err := r.store.Add(routingTable) + ctx := context.Background() + err := r.store.Add(ctx, routingTable) if err != nil { return routingTable, err } @@ -85,7 +88,8 @@ func (r *RoutingTableService) createRoutingTable(routingTable RoutingTable) (Rou // Update updates an existing routingTable. func (r *RoutingTableService) Update(routingTable RoutingTable) error { - err := r.store.Update(routingTable) + ctx := context.Background() + err := r.store.Update(ctx, routingTable) if err != nil { return err } @@ -107,7 +111,8 @@ func (r *RoutingTableService) Update(routingTable RoutingTable) error { // Delete deletes a routingTable. func (r *RoutingTableService) Delete(routingTable RoutingTable) error { - err := r.store.Delete(routingTable) + ctx := context.Background() + err := r.store.Delete(ctx, routingTable) if err != nil { return err } @@ -129,7 +134,8 @@ func (r *RoutingTableService) Delete(routingTable RoutingTable) error { // Get gets a routingTable. func (r *RoutingTableService) Get(query query.Query) (RoutingTable, error) { - routingTable, err := r.store.Get(query) + ctx := context.Background() + routingTable, err := r.store.Get(ctx, query) if err != nil { return routingTable, err } @@ -139,7 +145,8 @@ func (r *RoutingTableService) Get(query query.Query) (RoutingTable, error) { // GetAll gets all existing routingTables. func (r *RoutingTableService) GetAll() ([]RoutingTable, error) { - nodes, err := r.store.GetAll() + ctx := context.Background() + nodes, err := r.store.GetAll(ctx) if err != nil { return nodes, err } diff --git a/controller/topology/routing-tables/routingTableService_test.go b/controller/topology/routing-tables/routingTableService_test.go index 432f967e9..097a6b41b 100644 --- a/controller/topology/routing-tables/routingTableService_test.go +++ b/controller/topology/routing-tables/routingTableService_test.go @@ -1,6 +1,7 @@ package routingtables import ( + "context" "reflect" "testing" @@ -47,9 +48,10 @@ func getTestRoutingTable() RoutingTable { func getTestStoreWithRoutingTables(t *testing.T, routingTables []RoutingTable) Store { store := store.NewGenericStore[RoutingTable]() + ctx := context.TODO() for _, rt := range routingTables { - err := store.Add(rt) + err := store.Add(ctx, rt) if err != nil { t.Fatalf("failed to prepare test store while adding routing table: %v", err) } @@ -60,9 +62,10 @@ func getTestStoreWithRoutingTables(t *testing.T, routingTables []RoutingTable) S func getTestStoreWithNodes(t *testing.T, nodesToAdd []nodes.Node) nodes.Store { store := store.NewGenericStore[nodes.Node]() + ctx := context.TODO() for _, node := range nodesToAdd { - err := store.Add(node) + err := store.Add(ctx, node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } @@ -73,9 +76,10 @@ func getTestStoreWithNodes(t *testing.T, nodesToAdd []nodes.Node) nodes.Store { func getTestStoreWithPorts(t *testing.T, portsToAdd []ports.Port) ports.Store { store := store.NewGenericStore[ports.Port]() + ctx := context.TODO() for _, port := range portsToAdd { - err := store.Add(port) + err := store.Add(ctx, port) if err != nil { t.Fatalf("failed to prepare test store while adding port: %v", err) } diff --git a/controller/topology/routing-tables/routingTableStore.go b/controller/topology/routing-tables/routingTableStore.go index d61299916..b73b28660 100644 --- a/controller/topology/routing-tables/routingTableStore.go +++ b/controller/topology/routing-tables/routingTableStore.go @@ -1,45 +1,50 @@ package routingtables import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" query "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +const storeName = "routing-table-store.json" + // Store defines a RoutingTable store interface. type Store interface { - Add(RoutingTable) error - Update(RoutingTable) error - Delete(RoutingTable) error - Get(query.Query) (RoutingTable, error) - GetAll() ([]RoutingTable, error) + Add(context.Context, RoutingTable) error + Update(context.Context, RoutingTable) error + Delete(context.Context, RoutingTable) error + Get(context.Context, query.Query) (RoutingTable, error) + GetAll(context.Context) ([]RoutingTable, error) } // DatabaseRoutingTableStore is a database store for routingTables. type DatabaseRoutingTableStore struct { - storeName string + collection *mongo.Collection } // NewDatabaseRoutingTableStore returns a RoutingTableStore. -func NewDatabaseRoutingTableStore() Store { +func NewDatabaseRoutingTableStore(db *mongo.Database) Store { + collection := db.Collection(storeName) + return &DatabaseRoutingTableStore{ - storeName: fmt.Sprint("routing-table-store.json"), + collection: collection, } } // Get takes a routing-tables's UUID or name and returns the entries. -func (s *DatabaseRoutingTableStore) Get(query query.Query) (RoutingTable, error) { +func (s *DatabaseRoutingTableStore) Get(ctx context.Context, query query.Query) (RoutingTable, error) { var loadedRoutingTable RoutingTable if query.ID.String() != "" { - loadedRoutingTable, err := s.getByID(query.ID) + loadedRoutingTable, err := s.getByID(ctx, query.ID) if err != nil { return loadedRoutingTable, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -47,7 +52,7 @@ func (s *DatabaseRoutingTableStore) Get(query query.Query) (RoutingTable, error) return loadedRoutingTable, nil } - loadedRoutingTable, err := s.getByName(query.Name) + loadedRoutingTable, err := s.getByName(ctx, query.Name) if err != nil { return loadedRoutingTable, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -55,22 +60,8 @@ func (s *DatabaseRoutingTableStore) Get(query query.Query) (RoutingTable, error) return loadedRoutingTable, nil } -func (s *DatabaseRoutingTableStore) getByID(idOfRoutingTable uuid.UUID) (routingTable RoutingTable, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return routingTable, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfRoutingTable.String()}}) +func (s *DatabaseRoutingTableStore) getByID(ctx context.Context, idOfRoutingTable uuid.UUID) (routingTable RoutingTable, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfRoutingTable.String()}}) if result == nil { return routingTable, customerrs.CouldNotFindError{ID: idOfRoutingTable} } @@ -83,22 +74,8 @@ func (s *DatabaseRoutingTableStore) getByID(idOfRoutingTable uuid.UUID) (routing return routingTable, nil } -func (s *DatabaseRoutingTableStore) getByName(nameOfRoutingTable string) (loadedRoutingTable RoutingTable, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedRoutingTable, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfRoutingTable}}) +func (s *DatabaseRoutingTableStore) getByName(ctx context.Context, nameOfRoutingTable string) (loadedRoutingTable RoutingTable, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfRoutingTable}}) if result == nil { return loadedRoutingTable, customerrs.CouldNotFindError{Name: nameOfRoutingTable} } @@ -112,23 +89,8 @@ func (s *DatabaseRoutingTableStore) getByName(nameOfRoutingTable string) (loaded } // GetAll returns all stored routingTables. -func (s *DatabaseRoutingTableStore) GetAll() (loadedRoutingTable []RoutingTable, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabaseRoutingTableStore) GetAll(ctx context.Context) (loadedRoutingTable []RoutingTable, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return []RoutingTable{}, err } @@ -148,21 +110,8 @@ func (s *DatabaseRoutingTableStore) GetAll() (loadedRoutingTable []RoutingTable, } // Add adds a RoutingTable to the store. -func (s *DatabaseRoutingTableStore) Add(routingTable RoutingTable) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). +func (s *DatabaseRoutingTableStore) Add(ctx context.Context, routingTable RoutingTable) (err error) { + _, err = s.collection. InsertOne(ctx, routingTable) if err != nil { return customerrs.CouldNotCreateError{Identifier: routingTable.ID, Type: routingTable, Err: err} @@ -172,21 +121,9 @@ func (s *DatabaseRoutingTableStore) Add(routingTable RoutingTable) (err error) { } // Update updates a existing routingTable. -func (s *DatabaseRoutingTableStore) Update(routingTable RoutingTable) (err error) { +func (s *DatabaseRoutingTableStore) Update(ctx context.Context, routingTable RoutingTable) (err error) { var updatedLoadedRoutingTable RoutingTable - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - update := bson.D{primitive.E{Key: "$set", Value: routingTable}} upsert := false @@ -196,8 +133,7 @@ func (s *DatabaseRoutingTableStore) Update(routingTable RoutingTable) (err error ReturnDocument: &after, } - err = client.Database(database.DatabaseName). - Collection(s.storeName). + err = s.collection. FindOneAndUpdate( ctx, bson.M{"_id": routingTable.ID.String()}, update, &opt). Decode(&updatedLoadedRoutingTable) @@ -209,22 +145,8 @@ func (s *DatabaseRoutingTableStore) Update(routingTable RoutingTable) (err error } // Delete deletes a node from the node store. -func (s *DatabaseRoutingTableStore) Delete(routingTable RoutingTable) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: routingTable.ID.String()}}) +func (s *DatabaseRoutingTableStore) Delete(ctx context.Context, routingTable RoutingTable) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: routingTable.ID.String()}}) if err != nil { return err } diff --git a/controller/topology/routing-tables/store.go b/controller/topology/routing-tables/store.go new file mode 100644 index 000000000..3ccf144c6 --- /dev/null +++ b/controller/topology/routing-tables/store.go @@ -0,0 +1,21 @@ +package routingtables + +import ( + "code.fbi.h-da.de/danet/gosdn/controller/store" + topoStore "code.fbi.h-da.de/danet/gosdn/controller/topology/store" + + "go.mongodb.org/mongo-driver/mongo" +) + +// NewRoutingTableStore returns a RoutingTableStore. +func NewRoutingTableStore(db *mongo.Database) Store { + storeMode := store.GetStoreMode() + + switch storeMode { + case store.Database: + return NewDatabaseRoutingTableStore(db) + + default: + return topoStore.NewGenericStore[RoutingTable]() + } +} diff --git a/controller/topology/store.go b/controller/topology/store.go new file mode 100644 index 000000000..a345ebf86 --- /dev/null +++ b/controller/topology/store.go @@ -0,0 +1,22 @@ +package topology + +import ( + "code.fbi.h-da.de/danet/gosdn/controller/store" + "code.fbi.h-da.de/danet/gosdn/controller/topology/links" + topoStore "code.fbi.h-da.de/danet/gosdn/controller/topology/store" + + "go.mongodb.org/mongo-driver/mongo" +) + +// NewTopologyStore returns a Topologytore. +func NewTopologyStore(db *mongo.Database) Store { + storeMode := store.GetStoreMode() + + switch storeMode { + case store.Database: + return NewDatabaseTopologyStore(db) + + default: + return topoStore.NewGenericStore[links.Link]() + } +} diff --git a/controller/topology/store/genericStore.go b/controller/topology/store/genericStore.go index 557426c30..8b60c3edc 100644 --- a/controller/topology/store/genericStore.go +++ b/controller/topology/store/genericStore.go @@ -1,6 +1,7 @@ package store import ( + "context" "errors" "github.com/google/uuid" @@ -24,7 +25,7 @@ func NewGenericStore[T storableConstraint]() *GenericStore[T] { } } -func (t *GenericStore[T]) Add(item T) error { +func (t *GenericStore[T]) Add(ctx context.Context, item T) error { _, ok := t.store[item.GetID()] if ok { return errors.New("item already exists") @@ -35,7 +36,7 @@ func (t *GenericStore[T]) Add(item T) error { return nil } -func (t *GenericStore[T]) Update(item T) error { +func (t *GenericStore[T]) Update(ctx context.Context, item T) error { _, ok := t.store[item.GetID()] if !ok { return errors.New("item not found") @@ -46,7 +47,7 @@ func (t *GenericStore[T]) Update(item T) error { return nil } -func (t *GenericStore[T]) Delete(item T) error { +func (t *GenericStore[T]) Delete(ctx context.Context, item T) error { _, ok := t.store[item.GetID()] if !ok { return errors.New("item not found") @@ -57,7 +58,7 @@ func (t *GenericStore[T]) Delete(item T) error { return nil } -func (t *GenericStore[T]) Get(query query.Query) (T, error) { +func (t *GenericStore[T]) Get(ctx context.Context, query query.Query) (T, error) { // First search for direct hit on UUID. item, ok := t.store[query.ID] if !ok { @@ -67,7 +68,7 @@ func (t *GenericStore[T]) Get(query query.Query) (T, error) { return item, nil } -func (t *GenericStore[T]) GetAll() ([]T, error) { +func (t *GenericStore[T]) GetAll(ctx context.Context) ([]T, error) { var allItems []T for _, item := range t.store { diff --git a/controller/topology/store/genericStore_test.go b/controller/topology/store/genericStore_test.go index b88286530..0927776ed 100644 --- a/controller/topology/store/genericStore_test.go +++ b/controller/topology/store/genericStore_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "reflect" "testing" @@ -90,16 +91,17 @@ func TestGenericStore_Get(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewGenericStore[testItem]() + ctx := context.TODO() for _, itemToAdd := range tt.fields.items { - err := s.Add(itemToAdd) + err := s.Add(ctx, itemToAdd) if err != nil { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) return } } - got, err := s.Get(tt.args.query) + got, err := s.Get(ctx, tt.args.query) if (err != nil) != tt.wantErr { t.Errorf("GenericStore.Get() error = %v, wantErr %v", err, tt.wantErr) return @@ -142,16 +144,17 @@ func TestGenericStore_GetAll(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewGenericStore[testItem]() + ctx := context.TODO() for _, itemToAdd := range tt.fields.items { - err := s.Add(itemToAdd) + err := s.Add(ctx, itemToAdd) if err != nil { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) return } } - got, err := s.GetAll() + got, err := s.GetAll(ctx) if (err != nil) != tt.wantErr { t.Errorf("GenericStore.GetAll() error = %v, wantErr %v", err, tt.wantErr) return @@ -197,16 +200,17 @@ func TestGenericStore_Add(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewGenericStore[testItem]() + ctx := context.TODO() for _, itemToAdd := range tt.fields.items { - err := s.Add(itemToAdd) + err := s.Add(ctx, itemToAdd) if err != nil { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) return } } - if err := s.Add(tt.args.item); (err != nil) != tt.wantErr { + if err := s.Add(ctx, tt.args.item); (err != nil) != tt.wantErr { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -247,16 +251,17 @@ func TestGenericStore_Update(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewGenericStore[testItem]() + ctx := context.TODO() for _, itemToAdd := range tt.fields.items { - err := s.Add(itemToAdd) + err := s.Add(ctx, itemToAdd) if err != nil { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) return } } - if err := s.Update(tt.args.item); (err != nil) != tt.wantErr { + if err := s.Update(ctx, tt.args.item); (err != nil) != tt.wantErr { t.Errorf("GenericStore.Update() error = %v, wantErr %v", err, tt.wantErr) } }) @@ -297,16 +302,17 @@ func TestGenericStore_Delete(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { s := NewGenericStore[testItem]() + ctx := context.TODO() for _, itemToAdd := range tt.fields.items { - err := s.Add(itemToAdd) + err := s.Add(ctx, itemToAdd) if err != nil { t.Errorf("GenericStore.Add() error = %v, wantErr %v", err, tt.wantErr) return } } - if err := s.Delete(tt.args.item); (err != nil) != tt.wantErr { + if err := s.Delete(ctx, tt.args.item); (err != nil) != tt.wantErr { t.Errorf("GenericStore.Delete() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/controller/topology/topologyService.go b/controller/topology/topologyService.go index f7934bf6c..b562a9d7b 100644 --- a/controller/topology/topologyService.go +++ b/controller/topology/topologyService.go @@ -1,6 +1,8 @@ package topology import ( + "context" + "code.fbi.h-da.de/danet/gosdn/controller/event" eventInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/event" query "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -68,8 +70,9 @@ func (t *TopoService) AddLink(link links.Link) error { // if err != nil { // return err // } + ctx := context.Background() - err := t.store.Add(link) + err := t.store.Add(ctx, link) if err != nil { return err } @@ -91,7 +94,8 @@ func (t *TopoService) AddLink(link links.Link) error { // UpdateLink updates an existing link. func (t *TopoService) UpdateLink(link links.Link) error { - err := t.store.Update(link) + ctx := context.Background() + err := t.store.Update(ctx, link) if err != nil { return err } @@ -113,9 +117,10 @@ func (t *TopoService) UpdateLink(link links.Link) error { // DeleteLink deletes a link. func (t *TopoService) DeleteLink(link links.Link) error { + ctx := context.Background() // TODO: Delete should also check if a node or port is used somewhere else and // if not, delete the node and its ports - err := t.store.Delete(link) + err := t.store.Delete(ctx, link) if err != nil { return err } @@ -137,7 +142,8 @@ func (t *TopoService) DeleteLink(link links.Link) error { // GetAll returns the current topology. func (t *TopoService) GetAll() ([]links.Link, error) { - topo, err := t.store.GetAll() + ctx := context.Background() + topo, err := t.store.GetAll(ctx) if err != nil { return topo, err } @@ -146,7 +152,8 @@ func (t *TopoService) GetAll() ([]links.Link, error) { // Get returns the current topology. func (t *TopoService) Get(query query.Query) (links.Link, error) { - link, err := t.store.Get(query) + ctx := context.Background() + link, err := t.store.Get(ctx, query) if err != nil { return link, err } diff --git a/controller/topology/topologyService_test.go b/controller/topology/topologyService_test.go index e28df6d9e..6ef41909a 100644 --- a/controller/topology/topologyService_test.go +++ b/controller/topology/topologyService_test.go @@ -1,6 +1,7 @@ package topology import ( + "context" "reflect" "testing" @@ -60,9 +61,10 @@ func getEmptyLink() links.Link { func getTestStoreWithLinks(t *testing.T, nodes []links.Link) Store { store := store.NewGenericStore[links.Link]() + ctx := context.TODO() for _, node := range nodes { - err := store.Add(node) + err := store.Add(ctx, node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } @@ -73,9 +75,10 @@ func getTestStoreWithLinks(t *testing.T, nodes []links.Link) Store { func getTestStoreWithNodes(t *testing.T, nodesToAdd []nodes.Node) nodes.Store { store := store.NewGenericStore[nodes.Node]() + ctx := context.TODO() for _, node := range nodesToAdd { - err := store.Add(node) + err := store.Add(ctx, node) if err != nil { t.Fatalf("failed to prepare test store while adding node: %v", err) } @@ -86,9 +89,10 @@ func getTestStoreWithNodes(t *testing.T, nodesToAdd []nodes.Node) nodes.Store { func getTestStoreWithPorts(t *testing.T, portsToAdd []ports.Port) ports.Store { store := store.NewGenericStore[ports.Port]() + ctx := context.TODO() for _, port := range portsToAdd { - err := store.Add(port) + err := store.Add(ctx, port) if err != nil { t.Fatalf("failed to prepare test store while adding port: %v", err) } diff --git a/controller/topology/topologyStore.go b/controller/topology/topologyStore.go index 399d71af7..009d1eab7 100644 --- a/controller/topology/topologyStore.go +++ b/controller/topology/topologyStore.go @@ -1,46 +1,51 @@ package topology import ( + "context" "fmt" "code.fbi.h-da.de/danet/gosdn/controller/customerrs" - "code.fbi.h-da.de/danet/gosdn/controller/nucleus/database" query "code.fbi.h-da.de/danet/gosdn/controller/store" "code.fbi.h-da.de/danet/gosdn/controller/topology/links" "github.com/google/uuid" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) +const storeName = "topology-store.json" + // Store defines a Topology store interface. type Store interface { - Add(links.Link) error - Update(links.Link) error - Delete(links.Link) error - Get(query.Query) (links.Link, error) - GetAll() ([]links.Link, error) + Add(context.Context, links.Link) error + Update(context.Context, links.Link) error + Delete(context.Context, links.Link) error + Get(context.Context, query.Query) (links.Link, error) + GetAll(context.Context) ([]links.Link, error) } // DatabaseTopologyStore is a database store for the topology. type DatabaseTopologyStore struct { - storeName string + collection *mongo.Collection } // NewDatabaseTopologyStore returns a TopologyStore. -func NewDatabaseTopologyStore() Store { +func NewDatabaseTopologyStore(db *mongo.Database) Store { + collection := db.Collection(storeName) + return &DatabaseTopologyStore{ - storeName: fmt.Sprint("topology-store.json"), + collection: collection, } } // Get takes a link's UUID or name and returns the link. -func (s *DatabaseTopologyStore) Get(query query.Query) (links.Link, error) { +func (s *DatabaseTopologyStore) Get(ctx context.Context, query query.Query) (links.Link, error) { var loadedTopology links.Link if query.ID.String() != "" { - loadedTopology, err := s.getByID(query.ID) + loadedTopology, err := s.getByID(ctx, query.ID) if err != nil { return loadedTopology, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -48,7 +53,7 @@ func (s *DatabaseTopologyStore) Get(query query.Query) (links.Link, error) { return loadedTopology, nil } - loadedTopology, err := s.getByName(query.Name) + loadedTopology, err := s.getByName(ctx, query.Name) if err != nil { return loadedTopology, customerrs.CouldNotFindError{ID: query.ID, Name: query.Name} } @@ -56,22 +61,8 @@ func (s *DatabaseTopologyStore) Get(query query.Query) (links.Link, error) { return loadedTopology, nil } -func (s *DatabaseTopologyStore) getByID(idOfTopology uuid.UUID) (loadedTopology links.Link, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedTopology, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfTopology.String()}}) +func (s *DatabaseTopologyStore) getByID(ctx context.Context, idOfTopology uuid.UUID) (loadedTopology links.Link, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "_id", Value: idOfTopology.String()}}) if result == nil { return loadedTopology, customerrs.CouldNotFindError{ID: idOfTopology} } @@ -84,22 +75,8 @@ func (s *DatabaseTopologyStore) getByID(idOfTopology uuid.UUID) (loadedTopology return loadedTopology, nil } -func (s *DatabaseTopologyStore) getByName(nameOfTopology string) (loadedTopology links.Link, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return loadedTopology, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - result := collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfTopology}}) +func (s *DatabaseTopologyStore) getByName(ctx context.Context, nameOfTopology string) (loadedTopology links.Link, err error) { + result := s.collection.FindOne(ctx, bson.D{primitive.E{Key: "name", Value: nameOfTopology}}) if result == nil { return loadedTopology, customerrs.CouldNotFindError{Name: nameOfTopology} } @@ -113,23 +90,8 @@ func (s *DatabaseTopologyStore) getByName(nameOfTopology string) (loadedTopology } // GetAll returns all stored links. -func (s *DatabaseTopologyStore) GetAll() (loadedTopology []links.Link, err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return nil, err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - - cursor, err := collection.Find(ctx, bson.D{}) +func (s *DatabaseTopologyStore) GetAll(ctx context.Context) (loadedTopology []links.Link, err error) { + cursor, err := s.collection.Find(ctx, bson.D{}) if err != nil { return loadedTopology, err } @@ -149,21 +111,8 @@ func (s *DatabaseTopologyStore) GetAll() (loadedTopology []links.Link, err error } // Add adds a link to the link store. -func (s *DatabaseTopologyStore) Add(link links.Link) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - _, err = client.Database(database.DatabaseName). - Collection(s.storeName). +func (s *DatabaseTopologyStore) Add(ctx context.Context, link links.Link) (err error) { + _, err = s.collection. InsertOne(ctx, link) if err != nil { return customerrs.CouldNotCreateError{Identifier: link.ID, Type: link, Err: err} @@ -173,21 +122,9 @@ func (s *DatabaseTopologyStore) Add(link links.Link) (err error) { } // Update updates a existing link. -func (s *DatabaseTopologyStore) Update(linkToUpdate links.Link) (err error) { +func (s *DatabaseTopologyStore) Update(ctx context.Context, linkToUpdate links.Link) (err error) { var updatedLink links.Link - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - update := bson.D{primitive.E{Key: "$set", Value: linkToUpdate}} upsert := false @@ -197,8 +134,7 @@ func (s *DatabaseTopologyStore) Update(linkToUpdate links.Link) (err error) { ReturnDocument: &after, } - err = client.Database(database.DatabaseName). - Collection(s.storeName). + err = s.collection. FindOneAndUpdate( ctx, bson.M{"_id": linkToUpdate.ID.String()}, update, &opt). Decode(&updatedLink) @@ -210,22 +146,8 @@ func (s *DatabaseTopologyStore) Update(linkToUpdate links.Link) (err error) { } // Delete deletes a link from the link store. -func (s *DatabaseTopologyStore) Delete(linkToDelete links.Link) (err error) { - client, ctx, cancel, err := database.GetMongoConnection() - if err != nil { - return err - } - defer cancel() - defer func() { - if ferr := client.Disconnect(ctx); ferr != nil { - fErrString := ferr.Error() - err = fmt.Errorf("InternalError=%w DeferError=%+s", err, fErrString) - } - }() - - db := client.Database(database.DatabaseName) - collection := db.Collection(s.storeName) - _, err = collection.DeleteOne(ctx, bson.D{primitive.E{Key: linkToDelete.ID.String()}}) +func (s *DatabaseTopologyStore) Delete(ctx context.Context, linkToDelete links.Link) (err error) { + _, err = s.collection.DeleteOne(ctx, bson.D{primitive.E{Key: linkToDelete.ID.String()}}) if err != nil { return err } -- GitLab