package server import ( "context" "log" "net" "testing" "time" pipb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-internal" rpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-registry" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService" "code.fbi.h-da.de/danet/gosdn/controller/nucleus" "code.fbi.h-da.de/danet/gosdn/controller/rbac" "github.com/bufbuild/protovalidate-go" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/metadata" "google.golang.org/grpc/resolver" "google.golang.org/grpc/status" "google.golang.org/grpc/test/bufconn" ) func getTestAuthInterceptorServer(t *testing.T) (*AuthInterceptor, *UserServer, *RoleServer, *PluginInternalServer) { initUUIDs(t) jwtManager := rbac.NewJWTManager("test", time.Minute) eventService := eventservice.NewMockEventService() userStore := rbac.NewMemoryUserStore() userService := rbac.NewUserService(userStore, eventService) roleStore := rbac.NewMemoryRoleStore() roleService := rbac.NewRoleService(roleStore, eventService) registryClient := rpb.NewPluginRegistryServiceClient(&grpc.ClientConn{}) mockPlugin := getMockPlugin(t) pluginService := nucleus.NewPluginServiceMock() if err := pluginService.Add(mockPlugin); err != nil { t.Fatal(err) } pndStore := nucleus.NewMemoryPndStore() pndService := nucleus.NewPndService(pndStore) pnd := nucleus.NewPND(pndUUID, "test", "test") if err := pndService.Add(pnd); err != nil { t.Fatal(err) } protoValidator, err := protovalidate.New() if err != nil { panic(err) } s := NewAuthInterceptor(jwtManager, userService, roleService) u := NewUserServer(jwtManager, userService, protoValidator) r := NewRoleServer(jwtManager, roleService, protoValidator) p := NewPluginInternalServer(registryClient, pluginService, protoValidator) if err := clearAndCreateAuthTestSetup(userService, roleService); err != nil { t.Fatal(err) } return s, u, r, p } func dialer(interceptorServer *AuthInterceptor, userServer *UserServer, roleServer *RoleServer, pluginServer *PluginInternalServer) func(context.Context, string) (net.Conn, error) { resolver.SetDefaultScheme("passthrough") listener := bufconn.Listen(1024 * 1024) interceptor := interceptorServer server := grpc.NewServer(grpc.UnaryInterceptor(interceptor.Unary()), grpc.StreamInterceptor(interceptor.Stream())) apb.RegisterUserServiceServer(server, userServer) pipb.RegisterPluginInternalServiceServer(server, pluginServer) go func() { if err := server.Serve(listener); err != nil { log.Fatal(err) } }() return func(context.Context, string) (net.Conn, error) { return listener.Dial() } } func TestAuthInterceptor_Unary(t *testing.T) { authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t) validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } dialerFunc := dialer(authServer, userServer, roleServer, pluginServer) conn, err := grpc.NewClient("bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dialerFunc), ) if err != nil { t.Fatal(err) } defer func() { if err := conn.Close(); err != nil { log.Fatal(err) } }() client := apb.NewUserServiceClient(conn) type args struct { ctx context.Context request *apb.GetUsersRequest } tests := []struct { name string args args want *apb.User wantErr bool }{ { name: "default unary interceptor", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)), request: &apb.GetUsersRequest{}, }, want: &apb.User{}, wantErr: false, }, { name: "error unary invalid user token", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", wrongUserToken)), request: &apb.GetUsersRequest{}, }, want: nil, wantErr: true, }, { name: "error unary invalid token string", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")), request: &apb.GetUsersRequest{}, }, want: nil, wantErr: true, }, { name: "error unary no token in metadata", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "foo")), request: &apb.GetUsersRequest{}, }, want: nil, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := client.GetUsers(tt.args.ctx, tt.args.request) if (err != nil) != tt.wantErr { t.Errorf("AuthInterceptor.Unary() = %v, wantErr %v", err, tt.wantErr) return } if got != nil { // Todo: check why we don't hit this. // fmt.Printf("Got.User %+v", got.User) if got.User[0].Name == tt.want.Name { t.Errorf("Got user = %s, want %v", got.User[0].Name, tt.want.Name) return } } }) } } func TestAuthInterceptor_Stream(t *testing.T) { authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t) validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } tokenWithMissingRights, err := createTestUserToken("testUser", true, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } dialerFunc := dialer(authServer, userServer, roleServer, pluginServer) conn, err := grpc.NewClient("bufnet", grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithContextDialer(dialerFunc), ) if err != nil { t.Fatal(err) } defer func() { if err := conn.Close(); err != nil { log.Fatal(err) } }() client := pipb.NewPluginInternalServiceClient(conn) type args struct { ctx context.Context request *pipb.GetPluginSchemaRequest } tests := []struct { name string args args statusCode codes.Code wantErr bool }{ { name: "default stream interceptor", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)), request: &pipb.GetPluginSchemaRequest{ Pid: pluginID, }, }, statusCode: codes.OK, wantErr: false, }, { name: "user without sufficient rights", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", tokenWithMissingRights)), request: &pipb.GetPluginSchemaRequest{ Pid: pluginID, }, }, statusCode: codes.PermissionDenied, wantErr: true, }, { name: "user not authenticated", args: args{ ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")), request: &pipb.GetPluginSchemaRequest{ Pid: pluginID, }, }, statusCode: codes.Unauthenticated, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := client.GetPluginSchema(tt.args.ctx, tt.args.request) if err != nil { t.Errorf("AuthInterceptor.Stream() = %v", err) return } _, err = got.Recv() if (err != nil) != tt.wantErr { t.Errorf("AuthInterceptor.Stream() = got error: %v, wantErr: %v", err, tt.wantErr) return } if tt.wantErr { statusCode, ok := status.FromError(err) if !ok { t.Errorf("AuthInterceptor.Stream() = %v", err) return } if statusCode.Code() != tt.statusCode { t.Errorf("AuthInterceptor.Stream() = got error with status code: %v, want: %v", statusCode.Code(), tt.statusCode) return } } }) } } func TestAuthInterceptor_authorize(t *testing.T) { authServer, _, _, _ := getTestAuthInterceptorServer(t) validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager) if err != nil { t.Fatal(err) } type args struct { ctx context.Context method string } tests := []struct { name string args args wantErr bool }{ { name: "default authorize", args: args{ ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)), method: "/gosdn.rbac.UserService/GetUsers", }, wantErr: false, }, { name: "error invalid token", args: args{ ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", wrongUserToken)), method: "/gosdn.rbac.UserService/GetUsers", }, wantErr: true, }, { name: "error no permission for request", args: args{ ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)), method: "/gosdn.pnd.PndService/DeleteMne", }, wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if err := authServer.authorize(tt.args.ctx, tt.args.method); (err != nil) != tt.wantErr { t.Errorf("AuthInterceptor.authorize() error = %v, wantErr %v", err, tt.wantErr) } }) } }