diff --git a/controller/api/initialise_test.go b/controller/api/initialise_test.go index 8638d200c325aef24fe3d8dac48c67bc20191bc1..21fc210175207bc3ccd3a766c040aada357f9184 100644 --- a/controller/api/initialise_test.go +++ b/controller/api/initialise_test.go @@ -142,9 +142,7 @@ func bootstrapUnitTest() { jwtManager := rbacImpl.NewJWTManager("", (10000 * time.Hour)) - northbound := nbi.NewNBI(pndStore, userService, roleService) - northbound.Auth = nbi.NewAuthServer(jwtManager) - northbound.User = nbi.NewUserServer(jwtManager) + northbound := nbi.NewNBI(pndStore, userService, roleService, *jwtManager) cpb.RegisterCoreServiceServer(s, northbound.Core) ppb.RegisterPndServiceServer(s, northbound.Pnd) diff --git a/controller/controller.go b/controller/controller.go index eebc8ae8c87515631c1aea22cafc4fd095bb7a26..cb853c45d6b94fc97552340672830dfad5706de3 100644 --- a/controller/controller.go +++ b/controller/controller.go @@ -106,11 +106,9 @@ func startGrpc() error { log.Infof("listening to %v", lislisten.Addr()) jwtManager := rbacImpl.NewJWTManager(config.JWTSecret, config.JWTDuration) - setupGRPCServerWithCorrectSecurityLevel(jwtManager) + setupGRPCServerWithCorrectSecurityLevel(jwtManager, c.userService, c.roleService) - c.nbi = nbi.NewNBI(c.pndStore, c.userService, c.roleService) - c.nbi.Auth = nbi.NewAuthServer(jwtManager) - c.nbi.User = nbi.NewUserServer(jwtManager) + c.nbi = nbi.NewNBI(c.pndStore, c.userService, c.roleService, *jwtManager) pb.RegisterCoreServiceServer(c.grpcServer, c.nbi.Core) ppb.RegisterPndServiceServer(c.grpcServer, c.nbi.Pnd) @@ -288,13 +286,13 @@ func callback(id uuid.UUID, ch chan device.Details) { // This allows users to operate on the controller without any authentication/authorization, // but they could still login if they want to. // Use insecure only for testing purposes and with caution. -func setupGRPCServerWithCorrectSecurityLevel(jwt *rbacImpl.JWTManager) { +func setupGRPCServerWithCorrectSecurityLevel(jwt *rbacImpl.JWTManager, userService rbac.UserService, roleService rbac.RoleService) { securityLevel := viper.GetString("security") if securityLevel == "insecure" { c.grpcServer = grpc.NewServer() log.Info("set up grpc server in insecure mode") } else { - interceptor := server.NewAuthInterceptor(jwt) + interceptor := server.NewAuthInterceptor(jwt, userService, roleService) c.grpcServer = grpc.NewServer(grpc.UnaryInterceptor(interceptor.Unary()), grpc.StreamInterceptor(interceptor.Stream())) log.Info("set up grpc server in secure mode") } diff --git a/controller/northbound/server/auth.go b/controller/northbound/server/auth.go index bf00968736113bf76cacb96fcf958454379a0c40..9abf3a936b6ac0568d125b5ca71f920c67bfba01 100644 --- a/controller/northbound/server/auth.go +++ b/controller/northbound/server/auth.go @@ -6,6 +6,7 @@ import ( "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/metrics" "code.fbi.h-da.de/danet/gosdn/controller/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -19,13 +20,15 @@ import ( // Auth holds a JWTManager and represents a AuthServiceServer. type Auth struct { apb.UnimplementedAuthServiceServer - jwtManager *rbac.JWTManager + jwtManager *rbac.JWTManager + userService rbacInterfaces.UserService } -// NewAuthServer receives a JWTManager and returns a new Auth interface. -func NewAuthServer(jwtManager *rbac.JWTManager) *Auth { +// NewAuthServer receives a JWTManager and a userService and returns a new Auth interface. +func NewAuthServer(jwtManager *rbac.JWTManager, userService rbacInterfaces.UserService) *Auth { return &Auth{ - jwtManager: jwtManager, + jwtManager: jwtManager, + userService: userService, } } @@ -52,14 +55,14 @@ func (s Auth) Login(ctx context.Context, request *apb.LoginRequest) (*apb.LoginR return nil, err } - userToUpdate, err := userService.Get(store.Query{Name: user.UserName}) + userToUpdate, err := s.userService.Get(store.Query{Name: user.UserName}) if err != nil { return nil, err } userToUpdate.SetToken(token) - err = userService.Update(userToUpdate) + err = s.userService.Update(userToUpdate) if err != nil { return nil, err } @@ -90,7 +93,7 @@ func (s Auth) Logout(ctx context.Context, request *apb.LogoutRequest) (*apb.Logo // isValidUser checks if the provided user name fits to a stored one and then checks if the provided password is correct. func (s Auth) isValidUser(user rbac.User) error { - storedUser, err := userService.Get(store.Query{Name: user.Name()}) + storedUser, err := s.userService.Get(store.Query{Name: user.Name()}) if err != nil { return err } @@ -136,7 +139,7 @@ func (s Auth) handleLogout(ctx context.Context, userName string) error { return status.Errorf(codes.Aborted, "missing match of user associated to token and provided user name") } - storedUser, err := userService.Get(store.Query{Name: userName}) + storedUser, err := s.userService.Get(store.Query{Name: userName}) if err != nil { return err } @@ -145,7 +148,7 @@ func (s Auth) handleLogout(ctx context.Context, userName string) error { return status.Errorf(codes.Aborted, "missing match of token provied for user") } - err = userService.Update(&rbac.User{UserID: storedUser.ID(), + err = s.userService.Update(&rbac.User{UserID: storedUser.ID(), UserName: storedUser.Name(), Roles: storedUser.GetRoles(), Password: storedUser.GetPassword(), diff --git a/controller/northbound/server/auth_interceptor.go b/controller/northbound/server/auth_interceptor.go index 9c5f7a746cfd0d0627a75188ed04eec409885947..11338596bbb311850260f1c610809ed1f5f895cc 100644 --- a/controller/northbound/server/auth_interceptor.go +++ b/controller/northbound/server/auth_interceptor.go @@ -4,6 +4,8 @@ import ( "context" "time" + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" + csbipb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/csbi" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" "code.fbi.h-da.de/danet/gosdn/controller/rbac" @@ -16,13 +18,21 @@ import ( // AuthInterceptor provides an AuthInterceptor type AuthInterceptor struct { - jwtManager *rbac.JWTManager + jwtManager *rbac.JWTManager + userService rbacInterfaces.UserService + roleService rbacInterfaces.RoleService } // NewAuthInterceptor receives a JWTManager and a rbacMand returns a new AuthInterceptor provding gRPC Interceptor functionality. -func NewAuthInterceptor(jwtManager *rbac.JWTManager) *AuthInterceptor { +func NewAuthInterceptor( + jwtManager *rbac.JWTManager, + userService rbacInterfaces.UserService, + roleService rbacInterfaces.RoleService, +) *AuthInterceptor { return &AuthInterceptor{ - jwtManager: jwtManager, + jwtManager: jwtManager, + userService: userService, + roleService: roleService, } } @@ -82,7 +92,7 @@ func (auth *AuthInterceptor) authorize(ctx context.Context, method string) error return status.Errorf(codes.PermissionDenied, "token expired at %v, please login", time.Unix(claims.ExpiresAt, 0)) } - user, err := userService.Get(store.Query{Name: claims.Username}) + user, err := auth.userService.Get(store.Query{Name: claims.Username}) if err != nil { return err } @@ -114,7 +124,7 @@ func (auth *AuthInterceptor) verifyPermisisonForRequestedCall(userRoles map[stri } func (auth *AuthInterceptor) verifyUserRoleAndRequestedCall(userRole, requestedMethod string) error { - storedRoles, err := roleService.GetAll() + storedRoles, err := auth.roleService.GetAll() if err != nil { return err } diff --git a/controller/northbound/server/auth_interceptor_test.go b/controller/northbound/server/auth_interceptor_test.go index 499a63b2c31cbecfe8c97b0512448a4572821462..0b4bb3757e4e00fbb5694ea1ccf0b90bd6a0e3d2 100644 --- a/controller/northbound/server/auth_interceptor_test.go +++ b/controller/northbound/server/auth_interceptor_test.go @@ -5,23 +5,52 @@ import ( "log" "net" "testing" + "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" + "code.fbi.h-da.de/danet/gosdn/controller/nucleus" + "code.fbi.h-da.de/danet/gosdn/controller/rbac" "google.golang.org/grpc" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/bufconn" ) -func dialer() func(context.Context, string) (net.Conn, error) { +func getTestAuthInterceptorServer(t *testing.T) (*AuthInterceptor, *User, *Role, *SbiServer) { + initUUIDs(t) + jwtManager := rbac.NewJWTManager("test", time.Minute) + userStore := rbac.NewMemoryUserStore() + userService := rbac.NewUserService(userStore) + + roleStore := rbac.NewMemoryRoleStore() + roleService := rbac.NewRoleService(roleStore) + + mockPnd := getMockPnd(t) + + pndStore := nucleus.NewMemoryPndStore() + if err := pndStore.Add(mockPnd); err != nil { + t.Fatal(err) + } + + s := NewAuthInterceptor(jwtManager, userService, roleService) + u := NewUserServer(jwtManager, userService) + r := NewRoleServer(jwtManager, roleService) + sbiServer := NewSbiServer(pndStore) + + clearAndCreateAuthTestSetup(userService, roleService) + + return s, u, r, sbiServer +} + +func dialer(interceptorServer *AuthInterceptor, userServer *User, roleServer *Role, sbiServer *SbiServer) func(context.Context, string) (net.Conn, error) { listener := bufconn.Listen(1024 * 1024) - interceptor := NewAuthInterceptor(jwt) + interceptor := interceptorServer server := grpc.NewServer(grpc.UnaryInterceptor(interceptor.Unary()), grpc.StreamInterceptor(interceptor.Stream())) - apb.RegisterUserServiceServer(server, &User{}) - spb.RegisterSbiServiceServer(server, &sbiServer{}) + apb.RegisterUserServiceServer(server, userServer) + spb.RegisterSbiServiceServer(server, sbiServer) go func() { if err := server.Serve(listener); err != nil { @@ -35,20 +64,26 @@ func dialer() func(context.Context, string) (net.Conn, error) { } func TestAuthInterceptor_Unary(t *testing.T) { - validToken, err := createTestUserToken("testAdmin", true) + authServer, userServer, roleServer, sbiServer := getTestAuthInterceptorServer(t) + validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { - log.Fatal(err) + t.Fatal(err) } - wrongUserToken, err := createTestUserToken("foo", false) + wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager) if err != nil { - log.Fatal(err) + t.Fatal(err) } ctx := context.Background() - conn, err := grpc.DialContext(ctx, "", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dialer())) + conn, err := grpc.DialContext( + ctx, + "", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(dialer(authServer, userServer, roleServer, sbiServer)), + ) if err != nil { - log.Fatal(err) + t.Fatal(err) } defer conn.Close() @@ -121,15 +156,21 @@ func TestAuthInterceptor_Unary(t *testing.T) { } func TestAuthInterceptor_Stream(t *testing.T) { - validToken, err := createTestUserToken("testAdmin", true) + authServer, userServer, roleServer, sbiServer := getTestAuthInterceptorServer(t) + validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { - log.Fatal(err) + t.Fatal(err) } ctx := context.Background() - conn, err := grpc.DialContext(ctx, "", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dialer())) + conn, err := grpc.DialContext( + ctx, + "", + grpc.WithTransportCredentials(insecure.NewCredentials()), + grpc.WithContextDialer(dialer(authServer, userServer, roleServer, sbiServer)), + ) if err != nil { - log.Fatal(err) + t.Fatal(err) } defer conn.Close() @@ -186,14 +227,15 @@ func TestAuthInterceptor_Stream(t *testing.T) { } func TestAuthInterceptor_authorize(t *testing.T) { - validToken, err := createTestUserToken("testAdmin", true) + authServer, _, _, _ := getTestAuthInterceptorServer(t) + validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { - log.Fatal(err) + t.Fatal(err) } - wrongUserToken, err := createTestUserToken("foo", false) + wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager) if err != nil { - log.Fatal(err) + t.Fatal(err) } type args struct { @@ -232,11 +274,7 @@ func TestAuthInterceptor_authorize(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - auth := &AuthInterceptor{ - jwtManager: jwt, - } - - if err := auth.authorize(tt.args.ctx, tt.args.method); (err != nil) != tt.wantErr { + if err := authServer.authorize(tt.args.ctx, tt.args.method); (err != nil) != tt.wantErr { t.Errorf("AuthInterceptor.authorize() error = %v, wantErr %v", err, tt.wantErr) } }) diff --git a/controller/northbound/server/auth_test.go b/controller/northbound/server/auth_test.go index 9a875ec84e1a9485e6da80759f9672729073deb1..fa3e0d9ef891be7ff9d4f961b7501fb0d3c931c6 100644 --- a/controller/northbound/server/auth_test.go +++ b/controller/northbound/server/auth_test.go @@ -4,12 +4,31 @@ import ( "context" "log" "testing" + "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" "code.fbi.h-da.de/danet/gosdn/controller/rbac" "google.golang.org/grpc/metadata" ) +func getTestAuthServer(t *testing.T) *Auth { + jwtManager := rbac.NewJWTManager("test", time.Minute) + + userStore := rbac.NewMemoryUserStore() + userService := rbac.NewUserService(userStore) + + roleStore := rbac.NewMemoryRoleStore() + roleService := rbac.NewRoleService(roleStore) + + s := NewAuthServer(jwtManager, userService) + err := clearAndCreateAuthTestSetup(s.userService, roleService) + if err != nil { + t.Fatalf("%v", err) + } + + return s +} + func TestAuth_Login(t *testing.T) { type args struct { ctx context.Context @@ -47,9 +66,7 @@ func TestAuth_Login(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - r := Auth{ - jwtManager: jwt, - } + r := getTestAuthServer(t) resp, err := r.Login(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Auth.Login() error = %v, wantErr %v", err, tt.wantErr) @@ -67,7 +84,8 @@ func TestAuth_Login(t *testing.T) { } func TestAuth_Logout(t *testing.T) { - validToken, err := createTestUserToken("testAdmin", true) + s := getTestAuthServer(t) + validToken, err := createTestUserToken("testAdmin", true, s.userService, s.jwtManager) if err != nil { log.Fatal(err) } @@ -99,9 +117,6 @@ func TestAuth_Logout(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Auth{ - jwtManager: jwt, - } got, err := s.Logout(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Auth.Logout() error = %v, wantErr %v", err, tt.wantErr) @@ -158,7 +173,7 @@ func TestAuth_isValidUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Auth{} + s := getTestAuthServer(t) if err := s.isValidUser(tt.args.user); (err != nil) != tt.wantErr { t.Errorf("Auth.isValidUser() error = %v, wantErr %v", err, tt.wantErr) } @@ -167,12 +182,13 @@ func TestAuth_isValidUser(t *testing.T) { } func TestAuth_handleLogout(t *testing.T) { - validToken, err := createTestUserToken("testAdmin", true) + s := getTestAuthServer(t) + validToken, err := createTestUserToken("testAdmin", true, s.userService, s.jwtManager) if err != nil { log.Fatal(err) } - invalidToken, err := createTestUserToken("testAdmin", false) + invalidToken, err := createTestUserToken("testAdmin", false, s.userService, s.jwtManager) if err != nil { log.Fatal(err) } @@ -221,9 +237,6 @@ func TestAuth_handleLogout(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Auth{ - jwtManager: jwt, - } if err := s.handleLogout(tt.args.ctx, tt.args.userName); (err != nil) != tt.wantErr { t.Errorf("Auth.handleLogout() error = %v, wantErr %v", err, tt.wantErr) } diff --git a/controller/northbound/server/core.go b/controller/northbound/server/core.go index 13a7304c5975738d97f543a7f02f6452c25da171..5a167a4b46a10851871709d9916056f7d108cf8a 100644 --- a/controller/northbound/server/core.go +++ b/controller/northbound/server/core.go @@ -6,6 +6,7 @@ import ( pb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/core" ppb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/pnd" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" "code.fbi.h-da.de/danet/gosdn/controller/metrics" "code.fbi.h-da.de/danet/gosdn/controller/nucleus" "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -13,11 +14,21 @@ import ( "github.com/prometheus/client_golang/prometheus" ) -type core struct { +// Core represents a core server +type Core struct { pb.UnimplementedCoreServiceServer + pndStore networkdomain.PndStore } -func (s core) GetPnd(ctx context.Context, request *pb.GetPndRequest) (*pb.GetPndResponse, error) { +// NewCoreServer receives a pndStore and returns a new coreServer. +func NewCoreServer(pndStore networkdomain.PndStore) *Core { + return &Core{ + pndStore: pndStore, + } +} + +// GetPnd returns a existing pnd +func (s Core) GetPnd(ctx context.Context, request *pb.GetPndRequest) (*pb.GetPndResponse, error) { labels := prometheus.Labels{"service": "core", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -27,7 +38,7 @@ func (s core) GetPnd(ctx context.Context, request *pb.GetPndRequest) (*pb.GetPnd return nil, handleRPCError(labels, err) } - storedPnd, err := pndc.Get(store.Query{ID: pndID}) + storedPnd, err := s.pndStore.Get(store.Query{ID: pndID}) if err != nil { return nil, err } @@ -44,12 +55,13 @@ func (s core) GetPnd(ctx context.Context, request *pb.GetPndRequest) (*pb.GetPnd }, nil } -func (s core) GetPndList(ctx context.Context, request *pb.GetPndListRequest) (*pb.GetPndListResponse, error) { +// GetPndList returns all existing pnds +func (s Core) GetPndList(ctx context.Context, request *pb.GetPndListRequest) (*pb.GetPndListResponse, error) { labels := prometheus.Labels{"service": "core", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - pndList, err := pndc.GetAll() + pndList, err := s.pndStore.GetAll() if err != nil { return nil, err } @@ -68,7 +80,8 @@ func (s core) GetPndList(ctx context.Context, request *pb.GetPndListRequest) (*p }, nil } -func (s core) CreatePndList(ctx context.Context, request *pb.CreatePndListRequest) (*pb.CreatePndListResponse, error) { +// CreatePndList creates a pnd list +func (s Core) CreatePndList(ctx context.Context, request *pb.CreatePndListRequest) (*pb.CreatePndListResponse, error) { labels := prometheus.Labels{"service": "core", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -77,7 +90,7 @@ func (s core) CreatePndList(ctx context.Context, request *pb.CreatePndListReques if err != nil { return nil, handleRPCError(labels, err) } - if err := pndc.Add(pnd); err != nil { + if err := s.pndStore.Add(pnd); err != nil { return nil, handleRPCError(labels, err) } } @@ -87,7 +100,8 @@ func (s core) CreatePndList(ctx context.Context, request *pb.CreatePndListReques }, nil } -func (s core) DeletePnd(ctx context.Context, request *pb.DeletePndRequest) (*pb.DeletePndResponse, error) { +// DeletePnd deletes an existing pnd +func (s Core) DeletePnd(ctx context.Context, request *pb.DeletePndRequest) (*pb.DeletePndResponse, error) { labels := prometheus.Labels{"service": "core", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -97,11 +111,11 @@ func (s core) DeletePnd(ctx context.Context, request *pb.DeletePndRequest) (*pb. return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pndID}) + pnd, err := s.pndStore.Get(store.Query{ID: pndID}) if err != nil { return nil, handleRPCError(labels, err) } - err = pndc.Delete(pnd) + err = s.pndStore.Delete(pnd) if err != nil { return &pb.DeletePndResponse{ Timestamp: time.Now().UnixNano(), diff --git a/controller/northbound/server/core_test.go b/controller/northbound/server/core_test.go index ba01c8c4c930d12f0a80fc71e3ee89f568205aff..333ac2eb6707fae6ecf467c8f5f57b3e9ffcaf25 100644 --- a/controller/northbound/server/core_test.go +++ b/controller/northbound/server/core_test.go @@ -7,8 +7,97 @@ import ( "time" pb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/core" + ppb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/pnd" + spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" + "code.fbi.h-da.de/danet/gosdn/controller/mocks" + "code.fbi.h-da.de/danet/gosdn/controller/nucleus" + "code.fbi.h-da.de/danet/gosdn/models/generated/openconfig" + "github.com/google/uuid" + "github.com/stretchr/testify/mock" ) +func getTestCoreServer(t *testing.T) *Core { + var err error + pndUUID, err = uuid.Parse(pndID) + if err != nil { + t.Fatal(err) + } + + sbiUUID, err = uuid.Parse(sbiID) + if err != nil { + t.Fatal(err) + } + + pendingChangeUUID, err = uuid.Parse(pendingChangeID) + if err != nil { + t.Fatal(err) + } + + committedChangeUUID, err = uuid.Parse(committedChangeID) + if err != nil { + t.Fatal(err) + } + + deviceUUID, err = uuid.Parse(ondID) + if err != nil { + t.Fatal(err) + } + + mockDevice = &nucleus.CommonDevice{ + Model: &openconfig.Device{ + System: &openconfig.OpenconfigSystem_System{ + Config: &openconfig.OpenconfigSystem_System_Config{ + Hostname: &hostname, + DomainName: &domainname, + }, + }, + }, + UUID: deviceUUID, + } + + sbi, err := nucleus.NewSBI(spb.Type_TYPE_OPENCONFIG, sbiUUID) + if err != nil { + t.Fatal(err) + } + mockDevice.(*nucleus.CommonDevice).SetSBI(sbi) + mockDevice.(*nucleus.CommonDevice).SetTransport(&mocks.Transport{}) + mockDevice.(*nucleus.CommonDevice).SetName(hostname) + sbiStore = nucleus.NewSbiStore(pndUUID) + if err := sbiStore.Add(mockDevice.SBI()); err != nil { + t.Fatal(err) + } + + mockChange := &mocks.Change{} + mockChange.On("Age").Return(time.Hour) + mockChange.On("State").Return(ppb.ChangeState_CHANGE_STATE_INCONSISTENT) + + mockPnd = &mocks.NetworkDomain{} + mockPnd.On("ID").Return(pndUUID) + mockPnd.On("GetName").Return("test") + mockPnd.On("GetDescription").Return("test") + mockPnd.On("GetSBIs").Return(sbiStore) + mockPnd.On("GetSBI", mock.Anything).Return(mockDevice.SBI(), nil) + mockPnd.On("Devices").Return([]uuid.UUID{deviceUUID}) + mockPnd.On("PendingChanges").Return([]uuid.UUID{pendingChangeUUID}) + mockPnd.On("CommittedChanges").Return([]uuid.UUID{committedChangeUUID}) + mockPnd.On("GetChange", mock.Anything).Return(mockChange, nil) + mockPnd.On("AddDevice", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockPnd.On("GetDevice", mock.Anything).Return(mockDevice, nil) + mockPnd.On("Commit", mock.Anything).Return(nil) + mockPnd.On("Confirm", mock.Anything).Return(nil) + mockPnd.On("ChangeOND", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(uuid.Nil, nil) + mockPnd.On("Request", mock.Anything, mock.Anything).Return(nil, nil) + + pndStore := nucleus.NewMemoryPndStore() + if err := pndStore.Add(mockPnd); err != nil { + t.Fatal(err) + } + + c := NewCoreServer(pndStore) + + return c +} + func Test_core_Set(t *testing.T) { type args struct { ctx context.Context @@ -41,9 +130,7 @@ func Test_core_Set(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := core{ - UnimplementedCoreServiceServer: pb.UnimplementedCoreServiceServer{}, - } + s := getTestCoreServer(t) got, err := s.CreatePndList(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("core.Set() error = %v, wantErr %v", err, tt.wantErr) @@ -86,9 +173,7 @@ func Test_core_GetPnd(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := core{ - UnimplementedCoreServiceServer: pb.UnimplementedCoreServiceServer{}, - } + s := getTestCoreServer(t) resp, err := s.GetPnd(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("core.GetPnd() error = %v, wantErr %v", err, tt.wantErr) @@ -129,14 +214,12 @@ func Test_core_GetPndList(t *testing.T) { Timestamp: time.Now().UnixNano(), }, }, - length: 2, + length: 1, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := core{ - UnimplementedCoreServiceServer: pb.UnimplementedCoreServiceServer{}, - } + s := getTestCoreServer(t) resp, err := s.GetPndList(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("core.GetPndList() error = %v, wantErr %v", err, tt.wantErr) diff --git a/controller/northbound/server/csbi.go b/controller/northbound/server/csbi.go index acfb4d67078ef410c7b6267ec2add229dcc16ee1..31bf08b981c88d15094486408464f3ece382a6a4 100644 --- a/controller/northbound/server/csbi.go +++ b/controller/northbound/server/csbi.go @@ -11,6 +11,7 @@ import ( cpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/csbi" "code.fbi.h-da.de/danet/gosdn/controller/interfaces/device" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" "code.fbi.h-da.de/danet/gosdn/controller/metrics" "code.fbi.h-da.de/danet/gosdn/controller/store" "google.golang.org/grpc/codes" @@ -18,15 +19,25 @@ import ( "google.golang.org/grpc/status" ) -type csbi struct { +// Csbi represents a csbi server +type Csbi struct { cpb.UnimplementedCsbiServiceServer + pndStore networkdomain.PndStore } -func (s csbi) Hello(ctx context.Context, syn *cpb.Syn) (*cpb.Ack, error) { +// NewCsbiServer receives a pndStore and returns a new csbiServer. +func NewCsbiServer(pndStore networkdomain.PndStore) *Csbi { + return &Csbi{ + pndStore: pndStore, + } +} + +// Hello is used for tests +func (s Csbi) Hello(ctx context.Context, syn *cpb.Syn) (*cpb.Ack, error) { labels := prometheus.Labels{"service": "csbi", "rpc": "hello"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - ch, err := pndc.PendingChannels(store.FromString(syn.Id)) + ch, err := s.pndStore.PendingChannels(store.FromString(syn.Id)) if err != nil { return nil, handleRPCError(labels, err) } diff --git a/controller/northbound/server/nbi.go b/controller/northbound/server/nbi.go index 040407e326693301a4e8050cf004185f70b60748..cd4e983516f8ef4d498ebdf7fb956333e796a0be 100644 --- a/controller/northbound/server/nbi.go +++ b/controller/northbound/server/nbi.go @@ -2,7 +2,9 @@ package server import ( "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" - "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" + "code.fbi.h-da.de/danet/gosdn/controller/rbac" + "code.fbi.h-da.de/danet/gosdn/controller/metrics" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" @@ -10,35 +12,28 @@ import ( "google.golang.org/grpc/status" ) -var pndc networkdomain.PndStore -var userService rbac.UserService -var roleService rbac.RoleService - // NorthboundInterface is the representation of the // gRPC services used provided. type NorthboundInterface struct { - Pnd *pndServer - Core *core - Csbi *csbi - Sbi *sbiServer + Pnd *PndServer + Core *Core + Csbi *Csbi + Sbi *SbiServer Auth *Auth User *User Role *Role } // NewNBI receives a PndStore and returns a new gRPC *NorthboundInterface -func NewNBI(pnds networkdomain.PndStore, users rbac.UserService, roles rbac.RoleService) *NorthboundInterface { - pndc = pnds - userService = users - roleService = roles +func NewNBI(pnds networkdomain.PndStore, users rbacInterfaces.UserService, roles rbacInterfaces.RoleService, jwt rbac.JWTManager) *NorthboundInterface { return &NorthboundInterface{ - Pnd: &pndServer{}, - Core: &core{}, - Csbi: &csbi{}, - Sbi: &sbiServer{}, - Auth: &Auth{}, - User: &User{}, - Role: &Role{}, + Pnd: NewPndServer(pnds), + Core: NewCoreServer(pnds), + Csbi: NewCsbiServer(pnds), + Sbi: NewSbiServer(pnds), + Auth: NewAuthServer(&jwt, users), + User: NewUserServer(&jwt, users), + Role: NewRoleServer(&jwt, roles), } } diff --git a/controller/northbound/server/pnd.go b/controller/northbound/server/pnd.go index 9f3f40636c89316bf2f67b5e0543cb4f413cb130..5124730e44bda354d81c0433a395783a5221b462 100644 --- a/controller/northbound/server/pnd.go +++ b/controller/northbound/server/pnd.go @@ -23,11 +23,21 @@ import ( "google.golang.org/grpc/status" ) -type pndServer struct { +// PndServer implements a pnd server +type PndServer struct { ppb.UnimplementedPndServiceServer + pndStore networkdomain.PndStore } -func (p pndServer) GetOnd(ctx context.Context, request *ppb.GetOndRequest) (*ppb.GetOndResponse, error) { +// NewPndServer receives a pndStore and returns a new pndServer. +func NewPndServer(pndStore networkdomain.PndStore) *PndServer { + return &PndServer{ + pndStore: pndStore, + } +} + +// GetOnd gets a specific ond +func (p PndServer) GetOnd(ctx context.Context, request *ppb.GetOndRequest) (*ppb.GetOndResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -36,7 +46,7 @@ func (p pndServer) GetOnd(ctx context.Context, request *ppb.GetOndRequest) (*ppb return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -59,7 +69,8 @@ func (p pndServer) GetOnd(ctx context.Context, request *ppb.GetOndRequest) (*ppb }, nil } -func (p pndServer) GetOndList(ctx context.Context, request *ppb.GetOndListRequest) (*ppb.GetOndListResponse, error) { +// GetOndList returns a list of existing onds +func (p PndServer) GetOndList(ctx context.Context, request *ppb.GetOndListRequest) (*ppb.GetOndListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -68,7 +79,7 @@ func (p pndServer) GetOndList(ctx context.Context, request *ppb.GetOndListReques return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -163,7 +174,8 @@ func genGnmiNotification(path *gnmi.Path, val any) (*gnmi.Notification, error) { }, nil } -func (p pndServer) GetSbi(ctx context.Context, request *ppb.GetSbiRequest) (*ppb.GetSbiResponse, error) { +// GetSbi gets a specific sbi +func (p PndServer) GetSbi(ctx context.Context, request *ppb.GetSbiRequest) (*ppb.GetSbiResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -172,7 +184,7 @@ func (p pndServer) GetSbi(ctx context.Context, request *ppb.GetSbiRequest) (*ppb return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -202,7 +214,8 @@ func (p pndServer) GetSbi(ctx context.Context, request *ppb.GetSbiRequest) (*ppb }, nil } -func (p pndServer) GetSbiList(ctx context.Context, request *ppb.GetSbiListRequest) (*ppb.GetSbiListResponse, error) { +// GetSbiList gets all existing sbis +func (p PndServer) GetSbiList(ctx context.Context, request *ppb.GetSbiListRequest) (*ppb.GetSbiListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -211,7 +224,7 @@ func (p pndServer) GetSbiList(ctx context.Context, request *ppb.GetSbiListReques return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -266,7 +279,8 @@ func stringArrayToUUIDs(sid []string) ([]uuid.UUID, error) { return UUIDs, nil } -func (p pndServer) GetPath(ctx context.Context, request *ppb.GetPathRequest) (*ppb.GetPathResponse, error) { +// GetPath gets a path on a ond +func (p PndServer) GetPath(ctx context.Context, request *ppb.GetPathRequest) (*ppb.GetPathResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -275,7 +289,7 @@ func (p pndServer) GetPath(ctx context.Context, request *ppb.GetPathRequest) (*p return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -312,7 +326,8 @@ func (p pndServer) GetPath(ctx context.Context, request *ppb.GetPathRequest) (*p }, nil } -func (p pndServer) GetChange(ctx context.Context, request *ppb.GetChangeRequest) (*ppb.GetChangeResponse, error) { +// GetChange gets a specific change of a ond +func (p PndServer) GetChange(ctx context.Context, request *ppb.GetChangeRequest) (*ppb.GetChangeResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -321,7 +336,7 @@ func (p pndServer) GetChange(ctx context.Context, request *ppb.GetChangeRequest) return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -342,7 +357,8 @@ func (p pndServer) GetChange(ctx context.Context, request *ppb.GetChangeRequest) }, nil } -func (p pndServer) GetChangeList(ctx context.Context, request *ppb.GetChangeListRequest) (*ppb.GetChangeListResponse, error) { +// GetChangeList gets all existing changes +func (p PndServer) GetChangeList(ctx context.Context, request *ppb.GetChangeListRequest) (*ppb.GetChangeListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "get"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -351,7 +367,7 @@ func (p pndServer) GetChangeList(ctx context.Context, request *ppb.GetChangeList return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -418,7 +434,8 @@ func fillChanges(pnd networkdomain.NetworkDomain, all bool, cuid ...string) ([]* return changes, nil } -func (p pndServer) SetOndList(ctx context.Context, request *ppb.SetOndListRequest) (*ppb.SetOndListResponse, error) { +// SetOndList updates the list of onds +func (p PndServer) SetOndList(ctx context.Context, request *ppb.SetOndListRequest) (*ppb.SetOndListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -427,7 +444,7 @@ func (p pndServer) SetOndList(ctx context.Context, request *ppb.SetOndListReques return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { return nil, handleRPCError(labels, err) } @@ -459,7 +476,8 @@ func (p pndServer) SetOndList(ctx context.Context, request *ppb.SetOndListReques }, nil } -func (p pndServer) SetChangeList(ctx context.Context, request *ppb.SetChangeListRequest) (*ppb.SetChangeListResponse, error) { +// SetChangeList sets a list of changes +func (p PndServer) SetChangeList(ctx context.Context, request *ppb.SetChangeListRequest) (*ppb.SetChangeListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -468,7 +486,7 @@ func (p pndServer) SetChangeList(ctx context.Context, request *ppb.SetChangeList return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { return nil, handleRPCError(labels, err) } @@ -510,7 +528,8 @@ func (p pndServer) SetChangeList(ctx context.Context, request *ppb.SetChangeList }, nil } -func (p pndServer) SetPathList(ctx context.Context, request *ppb.SetPathListRequest) (*ppb.SetPathListResponse, error) { +// SetPathList sets a list of paths +func (p PndServer) SetPathList(ctx context.Context, request *ppb.SetPathListRequest) (*ppb.SetPathListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -519,7 +538,7 @@ func (p pndServer) SetPathList(ctx context.Context, request *ppb.SetPathListRequ return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { return nil, handleRPCError(labels, err) } @@ -550,7 +569,8 @@ func (p pndServer) SetPathList(ctx context.Context, request *ppb.SetPathListRequ }, nil } -func (p pndServer) SetSbiList(ctx context.Context, request *ppb.SetSbiListRequest) (*ppb.SetSbiListResponse, error) { +// SetSbiList sets a list of sbis +func (p PndServer) SetSbiList(ctx context.Context, request *ppb.SetSbiListRequest) (*ppb.SetSbiListResponse, error) { labels := prometheus.Labels{"service": "pnd", "rpc": "set"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -559,7 +579,7 @@ func (p pndServer) SetSbiList(ctx context.Context, request *ppb.SetSbiListReques return nil, handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { return nil, handleRPCError(labels, err) } @@ -605,13 +625,14 @@ func filterSbiType(sbiType ppb.SbiType) spb.Type { return spbType } -func (p pndServer) DeleteOnd(ctx context.Context, request *ppb.DeleteOndRequest) (*ppb.DeleteOndResponse, error) { +// DeleteOnd deletes a ond +func (p PndServer) DeleteOnd(ctx context.Context, request *ppb.DeleteOndRequest) (*ppb.DeleteOndResponse, error) { pid, err := uuid.Parse(request.Pid) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := p.pndStore.Get(store.Query{ID: pid}) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) diff --git a/controller/northbound/server/pnd_test.go b/controller/northbound/server/pnd_test.go index 58771744dc7f0922f8a9a42ca89f89ef2b1aa66d..9534e6fdd0eeb58930057e3dabf8dc4e9eedac40 100644 --- a/controller/northbound/server/pnd_test.go +++ b/controller/northbound/server/pnd_test.go @@ -2,72 +2,47 @@ package server import ( "context" - "os" - "reflect" "testing" "time" - pb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/core" ppb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/pnd" spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" - "code.fbi.h-da.de/danet/gosdn/controller/interfaces/device" - "code.fbi.h-da.de/danet/gosdn/controller/interfaces/southbound" "code.fbi.h-da.de/danet/gosdn/controller/mocks" "code.fbi.h-da.de/danet/gosdn/controller/nucleus" - "code.fbi.h-da.de/danet/gosdn/controller/rbac" "code.fbi.h-da.de/danet/gosdn/models/generated/openconfig" "github.com/golang/protobuf/proto" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/google/uuid" "github.com/openconfig/gnmi/proto/gnmi" - log "github.com/sirupsen/logrus" "github.com/stretchr/testify/mock" ) -const pndID = "2043519e-46d1-4963-9a8e-d99007e104b8" -const pendingChangeID = "0992d600-f7d4-4906-9559-409b04d59a5f" -const committedChangeID = "804787d6-e5a8-4dba-a1e6-e73f96b0119e" -const sbiID = "f6fd4b35-f039-4111-9156-5e4501bb8a5a" -const ondID = "7e0ed8cc-ebf5-46fa-9794-741494914883" - -var hostname = "manfred" -var domainname = "uwe" -var pndUUID uuid.UUID -var sbiUUID uuid.UUID -var pendingChangeUUID uuid.UUID -var committedChangeUUID uuid.UUID -var deviceUUID uuid.UUID -var mockPnd *mocks.NetworkDomain -var mockDevice device.Device -var sbiStore southbound.Store - -func TestMain(m *testing.M) { - log.SetReportCaller(true) +func getTestPndServer(t *testing.T) *PndServer { var err error pndUUID, err = uuid.Parse(pndID) if err != nil { - log.Fatal(err) + t.Fatal(err) } sbiUUID, err = uuid.Parse(sbiID) if err != nil { - log.Fatal(err) + t.Fatal(err) } pendingChangeUUID, err = uuid.Parse(pendingChangeID) if err != nil { - log.Fatal(err) + t.Fatal(err) } committedChangeUUID, err = uuid.Parse(committedChangeID) if err != nil { - log.Fatal(err) + t.Fatal(err) } deviceUUID, err = uuid.Parse(ondID) if err != nil { - log.Fatal(err) + t.Fatal(err) } mockDevice = &nucleus.CommonDevice{ @@ -84,14 +59,14 @@ func TestMain(m *testing.M) { sbi, err := nucleus.NewSBI(spb.Type_TYPE_OPENCONFIG, sbiUUID) if err != nil { - log.Fatal(err) + t.Fatal(err) } mockDevice.(*nucleus.CommonDevice).SetSBI(sbi) mockDevice.(*nucleus.CommonDevice).SetTransport(&mocks.Transport{}) mockDevice.(*nucleus.CommonDevice).SetName(hostname) sbiStore = nucleus.NewSbiStore(pndUUID) if err := sbiStore.Add(mockDevice.SBI()); err != nil { - log.Fatal(err) + t.Fatal(err) } mockChange := &mocks.Change{} @@ -115,75 +90,19 @@ func TestMain(m *testing.M) { mockPnd.On("ChangeOND", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(uuid.Nil, nil) mockPnd.On("Request", mock.Anything, mock.Anything).Return(nil, nil) - pndc = nucleus.NewMemoryPndStore() - if err := pndc.Add(mockPnd); err != nil { - log.Fatal(err) - } - - // everyting auth related - userService = rbac.NewUserService(rbac.NewMemoryUserStore()) - roleService = rbac.NewRoleService(rbac.NewMemoryRoleStore()) - err = clearAndCreateAuthTestSetup() - if err != nil { - log.Fatal(err) - } - jwt = rbac.NewJWTManager("", 1*time.Minute) - - os.Exit(m.Run()) -} - -// TODO: We should re-add all tests for changes. -// As of now this is not possible as we can't use the mock pnd, as it can't be serialized because of -// cyclic use of mock in it. -func Test_pnd_Get(t *testing.T) { - type args struct { - ctx context.Context - request *pb.GetPndRequest - } - tests := []struct { - name string - args args - want *pb.GetPndResponse - wantErr bool - }{ - { - name: "get pnd", - args: args{ - ctx: context.Background(), - request: &pb.GetPndRequest{ - Pid: pndID, - }, - }, - want: &pb.GetPndResponse{ - Pnd: &ppb.PrincipalNetworkDomain{ - Id: pndID, - Name: "test", - Description: "test"}, - }, - }, + pndStore := nucleus.NewMemoryPndStore() + if err := pndStore.Add(mockPnd); err != nil { + t.Fatal(err) } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - p := core{ - UnimplementedCoreServiceServer: pb.UnimplementedCoreServiceServer{}, - } - resp, err := p.GetPnd(tt.args.ctx, tt.args.request) - if (err != nil) != tt.wantErr { - t.Errorf("Get() error = %v, wantErr %v", err, tt.wantErr) - return - } - - got := resp.GetPnd() + c := NewPndServer(pndStore) - if !reflect.DeepEqual(got, tt.want.Pnd) { - t.Errorf("Get() got = %v, want %v", got, tt.want.Pnd) - } - }) - } + return c } func Test_pnd_GetPath(t *testing.T) { + initUUIDs(t) + opts := cmp.Options{ cmpopts.SortSlices( func(x, y *gnmi.Update) bool { @@ -209,9 +128,9 @@ func Test_pnd_GetPath(t *testing.T) { ctx: context.Background(), request: &ppb.GetPathRequest{ Timestamp: time.Now().UnixNano(), - Did: mockDevice.ID().String(), + Did: deviceUUID.String(), Path: "system/config/hostname", - Pid: mockPnd.ID().String(), + Pid: pndUUID.String(), }, }, want: []*gnmi.Notification{ @@ -247,9 +166,9 @@ func Test_pnd_GetPath(t *testing.T) { ctx: context.Background(), request: &ppb.GetPathRequest{ Timestamp: time.Now().UnixNano(), - Did: mockDevice.ID().String(), + Did: deviceUUID.String(), Path: "system", - Pid: mockPnd.ID().String(), + Pid: pndUUID.String(), }, }, want: []*gnmi.Notification{ @@ -279,9 +198,9 @@ func Test_pnd_GetPath(t *testing.T) { ctx: context.Background(), request: &ppb.GetPathRequest{ Timestamp: time.Now().UnixNano(), - Did: mockDevice.ID().String(), + Did: deviceUUID.String(), Path: "this/path/is/not/valid", - Pid: mockPnd.ID().String(), + Pid: pndUUID.String(), }, }, want: []*gnmi.Notification{}, @@ -290,9 +209,7 @@ func Test_pnd_GetPath(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := pndServer{ - UnimplementedPndServiceServer: ppb.UnimplementedPndServiceServer{}, - } + s := getTestPndServer(t) resp, err := s.GetPath(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("GetPath() error = %v, wantErr %v", err, tt.wantErr) diff --git a/controller/northbound/server/role.go b/controller/northbound/server/role.go index 4f4990d0fc47425d02da1a8257243cff49ce3b8d..fd742a96e6f21497de4985357d36447fe0090729 100644 --- a/controller/northbound/server/role.go +++ b/controller/northbound/server/role.go @@ -5,6 +5,7 @@ import ( "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" "code.fbi.h-da.de/danet/gosdn/controller/metrics" "code.fbi.h-da.de/danet/gosdn/controller/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" @@ -18,13 +19,15 @@ import ( // Role holds a JWTManager and represents a RoleServiceServer. type Role struct { apb.UnimplementedRoleServiceServer - jwtManager *rbac.JWTManager + jwtManager *rbac.JWTManager + roleService rbacInterfaces.RoleService } -// NewRoleServer receives a JWTManager and returns a new Role. -func NewRoleServer(jwtManager *rbac.JWTManager) *Role { +// NewRoleServer receives a JWTManager and a RoleService and returns a new RoleServer. +func NewRoleServer(jwtManager *rbac.JWTManager, roleService rbacInterfaces.RoleService) *Role { return &Role{ - jwtManager: jwtManager, + jwtManager: jwtManager, + roleService: roleService, } } @@ -34,10 +37,10 @@ func (r Role) CreateRoles(ctx context.Context, request *apb.CreateRolesRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, r := range request.Roles { - role := rbac.NewRole(uuid.New(), r.Name, r.Description, r.Permissions) + for _, rrole := range request.Roles { + role := rbac.NewRole(uuid.New(), rrole.Name, rrole.Description, rrole.Permissions) - err := roleService.Add(role) + err := r.roleService.Add(role) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -56,7 +59,7 @@ func (r Role) GetRole(ctx context.Context, request *apb.GetRoleRequest) (*apb.Ge start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - roleData, err := roleService.Get(store.Query{Name: request.RoleName}) + roleData, err := r.roleService.Get(store.Query{Name: request.RoleName}) if err != nil { return nil, err } @@ -81,7 +84,7 @@ func (r Role) GetRoles(ctx context.Context, request *apb.GetRolesRequest) (*apb. start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - roleList, err := roleService.GetAll() + roleList, err := r.roleService.GetAll() if err != nil { return nil, err } @@ -109,19 +112,19 @@ func (r Role) UpdateRoles(ctx context.Context, request *apb.UpdateRolesRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, r := range request.Roles { - rid, err := uuid.Parse(r.Id) + for _, role := range request.Roles { + rid, err := uuid.Parse(role.Id) if err != nil { return nil, handleRPCError(labels, err) } - _, err = roleService.Get(store.Query{ID: rid}) + _, err = r.roleService.Get(store.Query{ID: rid}) if err != nil { return nil, status.Errorf(codes.Canceled, "role not found %v", err) } - roleToUpdate := rbac.NewRole(rid, r.Name, r.Description, r.Permissions) - err = roleService.Update(roleToUpdate) + roleToUpdate := rbac.NewRole(rid, role.Name, role.Description, role.Permissions) + err = r.roleService.Update(roleToUpdate) if err != nil { return nil, status.Errorf(codes.Aborted, "could not update role %v", err) } @@ -139,7 +142,7 @@ func (r Role) DeletePermissionsForRole(ctx context.Context, request *apb.DeleteP start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - roleToUpdate, err := roleService.Get(store.Query{Name: request.RoleName}) + roleToUpdate, err := r.roleService.Get(store.Query{Name: request.RoleName}) if err != nil { return nil, status.Errorf(codes.Canceled, "role not found %v", err) } @@ -164,7 +167,7 @@ func (r Role) DeletePermissionsForRole(ctx context.Context, request *apb.DeleteP // updates the existing role with the trimmed set of permissions roleToUpdate.RemovePermissionsFromRole(request.PermissionsToDelete) - err = roleService.Update(roleToUpdate) + err = r.roleService.Update(roleToUpdate) if err != nil { return nil, status.Errorf(codes.Aborted, "could not update role %v", err) } @@ -181,13 +184,13 @@ func (r Role) DeleteRoles(ctx context.Context, request *apb.DeleteRolesRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, r := range request.RoleName { - roleToDelete, err := roleService.Get(store.Query{Name: r}) + for _, role := range request.RoleName { + roleToDelete, err := r.roleService.Get(store.Query{Name: role}) if err != nil { return nil, status.Errorf(codes.Canceled, "role not found") } - err = roleService.Delete(roleToDelete) + err = r.roleService.Delete(roleToDelete) if err != nil { return nil, status.Errorf(codes.Aborted, "error deleting role %v", err) } diff --git a/controller/northbound/server/role_test.go b/controller/northbound/server/role_test.go index 9fdbe2e5f86714ca952df1734c4c804f0fd18d28..1f75f2bc410eac0764b83aff2b33aaa7caf8d037 100644 --- a/controller/northbound/server/role_test.go +++ b/controller/northbound/server/role_test.go @@ -4,11 +4,31 @@ import ( "context" "reflect" "testing" + "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" + "code.fbi.h-da.de/danet/gosdn/controller/rbac" "github.com/google/uuid" ) +func getTestRoleServer(t *testing.T) *Role { + jwtManager := rbac.NewJWTManager("test", time.Second) + + userStore := rbac.NewMemoryUserStore() + userService := rbac.NewUserService(userStore) + + roleStore := rbac.NewMemoryRoleStore() + roleService := rbac.NewRoleService(roleStore) + + s := NewRoleServer(jwtManager, roleService) + err := clearAndCreateAuthTestSetup(userService, roleService) + if err != nil { + t.Fatalf("%v", err) + } + + return s +} + func TestRole_CreateRoles(t *testing.T) { type args struct { ctx context.Context @@ -39,7 +59,7 @@ func TestRole_CreateRoles(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} + s := getTestRoleServer(t) got, err := s.CreateRoles(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Role.CreateRoles() error = %v, wantErr %v", err, tt.wantErr) @@ -95,7 +115,7 @@ func TestRole_GetRole(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} + s := getTestRoleServer(t) got, err := s.GetRole(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Role.GetRole() error = %v, wantErr %v", err, tt.wantErr) @@ -116,11 +136,6 @@ func TestRole_GetRole(t *testing.T) { } func TestRole_GetRoles(t *testing.T) { - err := clearAndCreateAuthTestSetup() - if err != nil { - t.Fatalf("%v", err) - } - type args struct { ctx context.Context request *apb.GetRolesRequest @@ -171,7 +186,7 @@ func TestRole_GetRoles(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} + s := getTestRoleServer(t) got, err := s.GetRoles(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Role.GetRoles() error = %v, wantErr %v", err, tt.wantErr) @@ -250,7 +265,7 @@ func TestRole_UpdateRoles(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} + s := getTestRoleServer(t) got, err := s.UpdateRoles(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Role.UpdateRoles() error = %v, wantErr %v", err, tt.wantErr) @@ -265,8 +280,6 @@ func TestRole_UpdateRoles(t *testing.T) { } func TestRole_DeletePermissionsForRole(t *testing.T) { - clearAndCreateAuthTestSetup() - type args struct { ctx context.Context request *apb.DeletePermissionsForRoleRequest @@ -311,7 +324,7 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} + s := getTestRoleServer(t) got, err := s.DeletePermissionsForRole(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("Role.DeletePermissionsForRole() error = %v, wantErr %v", err, tt.wantErr) @@ -370,8 +383,7 @@ func TestRole_DeleteRoles(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := Role{} - clearAndCreateAuthTestSetup() + s := getTestRoleServer(t) got, err := s.DeleteRoles(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { diff --git a/controller/northbound/server/sbi.go b/controller/northbound/server/sbi.go index 6c6102806bc9c7e644fe3cd59c88fe6c156ce31b..a8358dd23f6521bffc7ff7afcaed810363afdd37 100644 --- a/controller/northbound/server/sbi.go +++ b/controller/northbound/server/sbi.go @@ -5,6 +5,7 @@ import ( "io" spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" "code.fbi.h-da.de/danet/gosdn/controller/metrics" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" @@ -22,11 +23,21 @@ const ( MB ) -type sbiServer struct { +// SbiServer represents a sbi server +type SbiServer struct { spb.UnimplementedSbiServiceServer + pndStore networkdomain.PndStore } -func (s sbiServer) GetSchema(request *spb.GetSchemaRequest, stream spb.SbiService_GetSchemaServer) error { +// NewSbiServer receives a pndStore and returns a new sbiServer. +func NewSbiServer(pndStore networkdomain.PndStore) *SbiServer { + return &SbiServer{ + pndStore: pndStore, + } +} + +// GetSchema returns the schema of a sbi +func (s SbiServer) GetSchema(request *spb.GetSchemaRequest, stream spb.SbiService_GetSchemaServer) error { labels := prometheus.Labels{"service": "pnd", "rpc": "get schema"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) @@ -41,7 +52,7 @@ func (s sbiServer) GetSchema(request *spb.GetSchemaRequest, stream spb.SbiServic return handleRPCError(labels, err) } - pnd, err := pndc.Get(store.Query{ID: pid}) + pnd, err := s.pndStore.Get(store.Query{ID: pid}) if err != nil { return handleRPCError(labels, err) } diff --git a/controller/northbound/server/test_util_test.go b/controller/northbound/server/test_util_test.go index 6eb5c982b9ef030ec3a427ad5c780b10452f440c..1456c90e0271e427922a6dad4ef52fd9ce708937 100644 --- a/controller/northbound/server/test_util_test.go +++ b/controller/northbound/server/test_util_test.go @@ -6,13 +6,40 @@ import ( "log" "testing" + spb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/southbound" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/device" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/networkdomain" + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" + "code.fbi.h-da.de/danet/gosdn/controller/interfaces/southbound" + "code.fbi.h-da.de/danet/gosdn/controller/mocks" + "code.fbi.h-da.de/danet/gosdn/controller/nucleus" + "code.fbi.h-da.de/danet/gosdn/models/generated/openconfig" + "code.fbi.h-da.de/danet/gosdn/controller/rbac" "code.fbi.h-da.de/danet/gosdn/controller/store" "github.com/google/uuid" "github.com/sethvargo/go-password/password" + "github.com/stretchr/testify/mock" "golang.org/x/crypto/argon2" ) +const pndID = "2043519e-46d1-4963-9a8e-d99007e104b8" +const pendingChangeID = "0992d600-f7d4-4906-9559-409b04d59a5f" +const committedChangeID = "804787d6-e5a8-4dba-a1e6-e73f96b0119e" +const sbiID = "f6fd4b35-f039-4111-9156-5e4501bb8a5a" +const ondID = "7e0ed8cc-ebf5-46fa-9794-741494914883" + +var hostname = "manfred" +var domainname = "uwe" +var pndUUID uuid.UUID +var sbiUUID uuid.UUID +var pendingChangeUUID uuid.UUID +var committedChangeUUID uuid.UUID +var deviceUUID uuid.UUID +var mockPnd *mocks.NetworkDomain +var mockDevice device.Device +var sbiStore southbound.Store + // Name of this file requires _test at the end, because of how the availability of varibales is handled in test files of go packages. // Does not include actual file tests! @@ -23,9 +50,8 @@ const randomRoleName = "bertram" var adminRoleMap = map[string]string{pndID: "adminTestRole"} var userRoleMap = map[string]string{pndID: "userTestRole"} -var jwt *rbac.JWTManager -func clearAndCreateAuthTestSetup() error { +func clearAndCreateAuthTestSetup(userService rbacInterfaces.UserService, roleService rbacInterfaces.RoleService) error { //clear setup if changed storedUsers, err := userService.GetAll() if err != nil { @@ -50,12 +76,12 @@ func clearAndCreateAuthTestSetup() error { } // create dataset - err = createTestUsers() + err = createTestUsers(userService) if err != nil { return err } - err = createTestRoles() + err = createTestRoles(roleService) if err != nil { return err } @@ -63,7 +89,7 @@ func clearAndCreateAuthTestSetup() error { return nil } -func createTestUsers() error { +func createTestUsers(userService rbacInterfaces.UserService) error { randomRoleMap := map[string]string{pndID: randomRoleName} // Generate a salt that is 16 characters long with 3 digits, 0 symbols, @@ -93,7 +119,7 @@ func createTestUsers() error { return nil } -func createTestRoles() error { +func createTestRoles(roleService rbacInterfaces.RoleService) error { roles := []rbac.Role{ { RoleID: uuid.MustParse(adminRoleID), @@ -153,7 +179,7 @@ func patchLogger(t *testing.T) { // Creates a token to be used in auth interceptor tests. If validTokenRequired is set as true, the generated token will also // be attached to the provided user. Else the user won't have the token and can not be authorized. -func createTestUserToken(userName string, validTokenRequired bool) (string, error) { +func createTestUserToken(userName string, validTokenRequired bool, userService rbacInterfaces.UserService, jwt *rbac.JWTManager) (string, error) { token, err := jwt.GenerateToken(rbac.User{UserName: userName}) if err != nil { return token, err @@ -178,3 +204,75 @@ func createTestUserToken(userName string, validTokenRequired bool) (string, erro func createHashedAndSaltedPassword(plainPWD, salt string) string { return base64.RawStdEncoding.EncodeToString(argon2.IDKey([]byte(plainPWD), []byte(salt), 1, 64*1024, 4, 32)) } + +func getMockPnd(t *testing.T) networkdomain.NetworkDomain { + mockDevice = &nucleus.CommonDevice{ + Model: &openconfig.Device{ + System: &openconfig.OpenconfigSystem_System{ + Config: &openconfig.OpenconfigSystem_System_Config{ + Hostname: &hostname, + DomainName: &domainname, + }, + }, + }, + UUID: deviceUUID, + } + + sbi, err := nucleus.NewSBI(spb.Type_TYPE_OPENCONFIG, sbiUUID) + if err != nil { + t.Fatal(err) + } + mockDevice.(*nucleus.CommonDevice).SetSBI(sbi) + mockDevice.(*nucleus.CommonDevice).SetTransport(&mocks.Transport{}) + mockDevice.(*nucleus.CommonDevice).SetName(hostname) + sbiStore = nucleus.NewSbiStore(pndUUID) + if err := sbiStore.Add(mockDevice.SBI()); err != nil { + t.Fatal(err) + } + + mockPnd = &mocks.NetworkDomain{} + mockPnd.On("ID").Return(pndUUID) + mockPnd.On("GetName").Return("test") + mockPnd.On("GetDescription").Return("test") + mockPnd.On("GetSBIs").Return(sbiStore) + mockPnd.On("GetSBI", mock.Anything).Return(mockDevice.SBI(), nil) + mockPnd.On("Devices").Return([]uuid.UUID{deviceUUID}) + mockPnd.On("PendingChanges").Return([]uuid.UUID{pendingChangeUUID}) + mockPnd.On("CommittedChanges").Return([]uuid.UUID{committedChangeUUID}) + mockPnd.On("AddDevice", mock.Anything, mock.Anything, mock.Anything).Return(nil) + mockPnd.On("GetDevice", mock.Anything).Return(mockDevice, nil) + mockPnd.On("Commit", mock.Anything).Return(nil) + mockPnd.On("Confirm", mock.Anything).Return(nil) + mockPnd.On("ChangeOND", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(uuid.Nil, nil) + mockPnd.On("Request", mock.Anything, mock.Anything).Return(nil, nil) + + return mockPnd +} + +func initUUIDs(t *testing.T) { + var err error + pndUUID, err = uuid.Parse(pndID) + if err != nil { + t.Fatal(err) + } + + sbiUUID, err = uuid.Parse(sbiID) + if err != nil { + t.Fatal(err) + } + + pendingChangeUUID, err = uuid.Parse(pendingChangeID) + if err != nil { + t.Fatal(err) + } + + committedChangeUUID, err = uuid.Parse(committedChangeID) + if err != nil { + t.Fatal(err) + } + + deviceUUID, err = uuid.Parse(ondID) + if err != nil { + t.Fatal(err) + } +} diff --git a/controller/northbound/server/user.go b/controller/northbound/server/user.go index f15f3412450d3762306fef57122a92cfc63547e2..23885863725164f7643f933203c60fdf3c57438d 100644 --- a/controller/northbound/server/user.go +++ b/controller/northbound/server/user.go @@ -17,18 +17,22 @@ import ( "google.golang.org/grpc/status" "golang.org/x/crypto/argon2" + + rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac" ) // User holds a JWTManager and represents a UserServiceServer. type User struct { apb.UnimplementedUserServiceServer - jwtManager *rbac.JWTManager + jwtManager *rbac.JWTManager + userService rbacInterfaces.UserService } -// NewUserServer receives a JWTManager and returns a new UserServer. -func NewUserServer(jwtManager *rbac.JWTManager) *User { +// NewUserServer receives a JWTManager and a UserService and returns a new UserServer. +func NewUserServer(jwtManager *rbac.JWTManager, userService rbacInterfaces.UserService) *User { return &User{ - jwtManager: jwtManager, + jwtManager: jwtManager, + userService: userService, } } @@ -38,9 +42,9 @@ func (u User) CreateUsers(ctx context.Context, request *apb.CreateUsersRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, u := range request.User { + for _, user := range request.User { roles := map[string]string{} - for key, elem := range u.Roles { + for key, elem := range user.Roles { _, err := uuid.Parse(key) if err != nil { return nil, handleRPCError(labels, err) @@ -56,10 +60,10 @@ func (u User) CreateUsers(ctx context.Context, request *apb.CreateUsersRequest) return nil, status.Errorf(codes.Aborted, "%v", err) } - hashedPassword := base64.RawStdEncoding.EncodeToString(argon2.IDKey([]byte(u.Password), []byte(salt), 1, 64*1024, 4, 32)) + hashedPassword := base64.RawStdEncoding.EncodeToString(argon2.IDKey([]byte(user.Password), []byte(salt), 1, 64*1024, 4, 32)) - user := rbac.NewUser(uuid.New(), u.Name, roles, string(hashedPassword), u.Token, salt) - err = userService.Add(user) + user := rbac.NewUser(uuid.New(), user.Name, roles, string(hashedPassword), user.Token, salt) + err = u.userService.Add(user) if err != nil { log.Error(err) return nil, status.Errorf(codes.Aborted, "%v", err) @@ -78,7 +82,7 @@ func (u User) GetUser(ctx context.Context, request *apb.GetUserRequest) (*apb.Ge start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - userData, err := userService.Get(store.Query{Name: request.Name}) + userData, err := u.userService.Get(store.Query{Name: request.Name}) if err != nil { return nil, err } @@ -102,7 +106,7 @@ func (u User) GetUsers(ctx context.Context, request *apb.GetUsersRequest) (*apb. start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - userList, err := userService.GetAll() + userList, err := u.userService.GetAll() if err != nil { return nil, err } @@ -129,22 +133,22 @@ func (u User) UpdateUsers(ctx context.Context, request *apb.UpdateUsersRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, u := range request.User { - uid, err := uuid.Parse(u.Id) + for _, user := range request.User { + uid, err := uuid.Parse(user.Id) if err != nil { return nil, handleRPCError(labels, err) } - storedUser, err := userService.Get(store.Query{ID: uid}) + storedUser, err := u.userService.Get(store.Query{ID: uid}) if err != nil { return nil, status.Errorf(codes.Canceled, "user not found %v", err) } - hashedPassword := base64.RawStdEncoding.EncodeToString(argon2.IDKey([]byte(u.Password), []byte(storedUser.GetSalt()), 1, 64*1024, 4, 32)) + hashedPassword := base64.RawStdEncoding.EncodeToString(argon2.IDKey([]byte(user.Password), []byte(storedUser.GetSalt()), 1, 64*1024, 4, 32)) - userToUpdate := rbac.NewUser(uid, u.Name, u.Roles, string(hashedPassword), u.Token, storedUser.GetSalt()) + userToUpdate := rbac.NewUser(uid, user.Name, user.Roles, string(hashedPassword), user.Token, storedUser.GetSalt()) - err = userService.Update(userToUpdate) + err = u.userService.Update(userToUpdate) if err != nil { return nil, status.Errorf(codes.Aborted, "could not update user %v", err) } @@ -162,13 +166,13 @@ func (u User) DeleteUsers(ctx context.Context, request *apb.DeleteUsersRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - for _, u := range request.Username { - userToDelete, err := userService.Get(store.Query{Name: u}) + for _, user := range request.Username { + userToDelete, err := u.userService.Get(store.Query{Name: user}) if err != nil { return nil, status.Errorf(codes.Canceled, "user not found %v", err) } - err = userService.Delete(userToDelete) + err = u.userService.Delete(userToDelete) if err != nil { return nil, status.Errorf(codes.Aborted, "error deleting user %v", err) } @@ -180,7 +184,7 @@ func (u User) DeleteUsers(ctx context.Context, request *apb.DeleteUsersRequest) } func (u User) isValidUser(user rbac.User) (bool, error) { - storedUser, err := userService.Get(store.Query{Name: user.Name()}) + storedUser, err := u.userService.Get(store.Query{Name: user.Name()}) if err != nil { return false, err } else if storedUser == nil { diff --git a/controller/northbound/server/user_test.go b/controller/northbound/server/user_test.go index a5d2d25f291387cb5c363434e641e4a42da2a2f8..c84533635306c916a2f0011510cef50b0bc110de 100644 --- a/controller/northbound/server/user_test.go +++ b/controller/northbound/server/user_test.go @@ -4,11 +4,31 @@ import ( "context" "reflect" "testing" + "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" + "code.fbi.h-da.de/danet/gosdn/controller/rbac" "github.com/google/uuid" ) +func getTestUserServer(t *testing.T) *User { + jwtManager := rbac.NewJWTManager("test", time.Second) + + userStore := rbac.NewMemoryUserStore() + userService := rbac.NewUserService(userStore) + + roleStore := rbac.NewMemoryRoleStore() + roleService := rbac.NewRoleService(roleStore) + + s := NewUserServer(jwtManager, userService) + err := clearAndCreateAuthTestSetup(s.userService, roleService) + if err != nil { + t.Fatalf("%v", err) + } + + return s +} + func TestUser_CreateUsers(t *testing.T) { type args struct { ctx context.Context @@ -41,9 +61,7 @@ func TestUser_CreateUsers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := User{ - jwtManager: jwt, - } + s := getTestUserServer(t) got, err := s.CreateUsers(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("User.CreateUsers() error = %v, wantErr %v", err, tt.wantErr) @@ -96,7 +114,7 @@ func TestUser_GetUser(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := User{} + s := getTestUserServer(t) got, err := s.GetUser(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("User.GetUser() error = %v, wantErr %v", err, tt.wantErr) @@ -117,10 +135,6 @@ func TestUser_GetUser(t *testing.T) { } func TestUser_GetUsers(t *testing.T) { - err := clearAndCreateAuthTestSetup() - if err != nil { - t.Fatalf("%v", err) - } type args struct { ctx context.Context request *apb.GetUsersRequest @@ -150,7 +164,8 @@ func TestUser_GetUsers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := User{} + s := getTestUserServer(t) + got, err := s.GetUsers(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("User.GetUsers() error = %v, wantErr %v", err, tt.wantErr) @@ -220,7 +235,7 @@ func TestUser_UpdateUsers(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := User{} + s := getTestUserServer(t) got, err := s.UpdateUsers(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("User.UpdateUsers() error = %v, wantErr %v", err, tt.wantErr) @@ -264,7 +279,7 @@ func TestUser_DeleteUsers(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - s := User{} + s := getTestUserServer(t) got, err := s.DeleteUsers(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("User.DeleteUsers() error = %v, wantErr %v", err, tt.wantErr) diff --git a/controller/nucleus/memoryPndStore.go b/controller/nucleus/memoryPndStore.go index 57cb3289d3a327b333491fe8005e980b9c4daa03..471947aabed96a3413ca4232548ca017a7a1c9d0 100644 --- a/controller/nucleus/memoryPndStore.go +++ b/controller/nucleus/memoryPndStore.go @@ -45,7 +45,7 @@ func (t *MemoryPndStore) Delete(item networkdomain.NetworkDomain) error { func (t *MemoryPndStore) Get(query store.Query) (networkdomain.NetworkDomain, error) { item, ok := t.Store[query.ID] if !ok { - return nil, nil + return nil, &nerrors.ErrNotFound{ID: query.ID} } return item, nil