Skip to content
Snippets Groups Projects
auth_interceptor_test.go 9.14 KiB
Newer Older
  • Learn to ignore specific revisions
  • package server
    
    import (
    	"context"
    	"log"
    	"net"
    	"testing"
    
    	pipb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-internal"
    	rpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-registry"
    
    	apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
    
    	eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService"
    
    	"code.fbi.h-da.de/danet/gosdn/controller/nucleus"
    	"code.fbi.h-da.de/danet/gosdn/controller/rbac"
    
    	"github.com/bufbuild/protovalidate-go"
    
    	"google.golang.org/grpc"
    
    	"google.golang.org/grpc/credentials/insecure"
    	"google.golang.org/grpc/metadata"
    
    	"google.golang.org/grpc/resolver"
    
    	"google.golang.org/grpc/test/bufconn"
    )
    
    
    // getTestAuthInterceptorServer assembles the fixtures shared by the auth
    // interceptor tests: a JWT manager, in-memory user/role stores with their
    // services, a mock plugin service, an in-memory PND store, and the
    // AuthInterceptor / UserServer / RoleServer / PluginInternalServer built on
    // top of them. Setup failures abort the calling test via t.Fatal, except for
    // the protovalidate constructor, which panics.
    //
    // NOTE(review): as captured here the listing looks truncated — the `if`
    // around pndService.Add is never closed, no return statement or closing brace
    // for the function is visible, and `time` is used without appearing in the
    // visible import block. Confirm against the upstream file.
    func getTestAuthInterceptorServer(t *testing.T) (*AuthInterceptor, *UserServer, *RoleServer, *PluginInternalServer) {
    
    	initUUIDs(t)
    	// presumably ("test", time.Minute) are the signing secret and the token
    	// lifetime — TODO confirm against rbac.NewJWTManager.
    	jwtManager := rbac.NewJWTManager("test", time.Minute)
    
    	// One mock event service is shared by the user and role services below.
    	eventService := eventservice.NewMockEventService()
    
    
    	userStore := rbac.NewMemoryUserStore()
    
    	userService := rbac.NewUserService(userStore, eventService)
    
    
    	roleStore := rbac.NewMemoryRoleStore()
    
    	roleService := rbac.NewRoleService(roleStore, eventService)
    
    	// The registry client wraps a zero-value grpc.ClientConn; nothing in this
    	// function dials a real endpoint for it.
    	registryClient := rpb.NewPluginRegistryServiceClient(&grpc.ClientConn{})
    
    	// Register the mock plugin with the mock plugin service.
    	mockPlugin := getMockPlugin(t)
    	pluginService := nucleus.NewPluginServiceMock()
    	if err := pluginService.Add(mockPlugin); err != nil {
    		t.Fatal(err)
    	}
    
    
    	pndStore := nucleus.NewMemoryPndStore()
    
    	pndService := nucleus.NewPndService(pndStore)
    
    	pnd := nucleus.NewPND(pndUUID, "test", "test")
    
    	// NOTE(review): the body and closing brace of this `if` are missing in
    	// this listing.
    	if err := pndService.Add(pnd); err != nil {
    
    	// Unlike the other setup steps, a validator construction failure panics
    	// instead of calling t.Fatal.
    	protoValidator, err := protovalidate.New()
    	if err != nil {
    		panic(err)
    	}
    
    
    	// The interceptor and the servers share the same JWT manager and services.
    	s := NewAuthInterceptor(jwtManager, userService, roleService)
    
    	u := NewUserServer(jwtManager, userService, protoValidator)
    	r := NewRoleServer(jwtManager, roleService, protoValidator)
    	p := NewPluginInternalServer(registryClient, pluginService, protoValidator)
    
    	// presumably resets the stores and seeds the users/roles the tests log in
    	// with — verify against clearAndCreateAuthTestSetup.
    	if err := clearAndCreateAuthTestSetup(userService, roleService); err != nil {
    		t.Fatal(err)
    	}
    
    // dialer spins up an in-process gRPC server on a bufconn listener and returns
    // a context dialer that connects to it.
    //
    // The server is created with the unary and stream interceptors of
    // interceptorServer, and userServer and pluginServer are registered on it.
    // roleServer is accepted but not registered by this function. The server is
    // served from a background goroutine; a Serve error terminates the process
    // via log.Fatal. The default resolver scheme is set to "passthrough" before
    // the listener is created.
    func dialer(interceptorServer *AuthInterceptor, userServer *UserServer, roleServer *RoleServer, pluginServer *PluginInternalServer) func(context.Context, string) (net.Conn, error) {
    	resolver.SetDefaultScheme("passthrough")

    	// Size passed to bufconn.Listen for the in-memory listener.
    	const bufSize = 1024 * 1024
    	lis := bufconn.Listen(bufSize)

    	srv := grpc.NewServer(
    		grpc.UnaryInterceptor(interceptorServer.Unary()),
    		grpc.StreamInterceptor(interceptorServer.Stream()),
    	)
    	apb.RegisterUserServiceServer(srv, userServer)
    	pipb.RegisterPluginInternalServiceServer(srv, pluginServer)

    	// Serve in the background so the caller can dial immediately.
    	go func() {
    		err := srv.Serve(lis)
    		if err != nil {
    			log.Fatal(err)
    		}
    	}()

    	// Context and target are ignored: every dial goes to the bufconn listener.
    	dial := func(context.Context, string) (net.Conn, error) {
    		return lis.Dial()
    	}
    	return dial
    }
    
    // TestAuthInterceptor_Unary drives the unary interceptor end to end: it calls
    // GetUsers on a UserService client connected through the bufconn dialer and
    // checks, per table entry, whether an error is returned for the token placed
    // under the "authorize" metadata key.
    //
    // NOTE(review): this listing looks truncated — several `if err != nil {`
    // blocks and the grpc.NewClient call are never closed, and the table uses a
    // `want` field that the visible struct definition does not declare. Confirm
    // against the upstream file.
    func TestAuthInterceptor_Unary(t *testing.T) {
    
    	authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t)
    
    	// Token for the "testAdmin" user; used by the case that expects success.
    	validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager)
    
    	if err != nil {
    
    	// Token for user "foo"; the table expects requests carrying it to fail.
    	wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager)
    
    	if err != nil {
    
    	// Connect a client to the in-process server over bufconn, without TLS.
    	dialerFunc := dialer(authServer, userServer, roleServer, pluginServer)
    	conn, err := grpc.NewClient("bufnet",
    
    		grpc.WithTransportCredentials(insecure.NewCredentials()),
    
    		grpc.WithContextDialer(dialerFunc),
    
    	if err != nil {
    
    	defer func() {
    		if err := conn.Close(); err != nil {
    			log.Fatal(err)
    		}
    	}()
    
    
    	client := apb.NewUserServiceClient(conn)
    
    	type args struct {
    		ctx     context.Context
    		request *apb.GetUsersRequest
    	}
    	tests := []struct {
    		name    string
    		args    args
    
    		wantErr bool
    	}{
    		{
    			// Valid token under the "authorize" key: no error expected.
    			name: "default unary interceptor",
    			args: args{
    				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)),
    				request: &apb.GetUsersRequest{},
    			},
    
    			wantErr: false,
    		},
    		{
    			// Token created for user "foo": error expected.
    			name: "error unary invalid user token",
    			args: args{
    				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", wrongUserToken)),
    				request: &apb.GetUsersRequest{},
    			},
    			want:    nil,
    			wantErr: true,
    		},
    		{
    			// The literal string "foo" in place of a token: error expected.
    			name: "error unary invalid token string",
    			args: args{
    				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")),
    				request: &apb.GetUsersRequest{},
    			},
    			want:    nil,
    			wantErr: true,
    		},
    		{
    			// Metadata carries only a "foo" key, no "authorize" key: error expected.
    			name: "error unary no token in metadata",
    			args: args{
    				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "foo")),
    				request: &apb.GetUsersRequest{},
    			},
    			want:    nil,
    			wantErr: true,
    		},
    	}
    
    	for _, tt := range tests {
    		t.Run(tt.name, func(t *testing.T) {
    			got, err := client.GetUsers(tt.args.ctx, tt.args.request)
    			// Fail when the presence of an error does not match wantErr.
    			if (err != nil) != tt.wantErr {
    				t.Errorf("AuthInterceptor.Unary() = %v, wantErr %v", err, tt.wantErr)
    				return
    			}
    
    
    			if got != nil {
    				// TODO: check why we don't hit this.
    				// fmt.Printf("Got.User %+v", got.User)
    
    				// NOTE(review): this reports a failure when the names are EQUAL,
    				// while the message reads like a mismatch report; `!=` may have
    				// been intended. It also indexes got.User[0] without a length
    				// check and dereferences tt.want, which is nil in the cases
    				// shown — confirm.
    				if got.User[0].Name == tt.want.Name {
    					t.Errorf("Got user = %s, want %v", got.User[0].Name, tt.want.Name)
    					return
    				}
    
    			}
    		})
    	}
    }
    
    // TestAuthInterceptor_Stream drives the stream interceptor end to end: it
    // opens a GetPluginSchema stream on a PluginInternalService client connected
    // through the bufconn dialer, reads one message, and checks both whether an
    // error occurred and, for error cases, the gRPC status code.
    //
    // NOTE(review): this listing looks truncated — several `if` blocks, the
    // grpc.NewClient call and some table entries are never closed, and `codes` /
    // `status` are used without appearing in the visible import block. Confirm
    // against the upstream file.
    func TestAuthInterceptor_Stream(t *testing.T) {
    
    	authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t)
    
    	// Token for the "testAdmin" user; used by the case that expects success.
    	validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager)
    
    	if err != nil {
    
    	// Token for "testUser"; the table expects PermissionDenied for it.
    	tokenWithMissingRights, err := createTestUserToken("testUser", true, authServer.userService, authServer.jwtManager)
    	if err != nil {
    		t.Fatal(err)
    	}
    
    	// Connect a client to the in-process server over bufconn, without TLS.
    	dialerFunc := dialer(authServer, userServer, roleServer, pluginServer)
    	conn, err := grpc.NewClient("bufnet",
    
    		grpc.WithTransportCredentials(insecure.NewCredentials()),
    
    		grpc.WithContextDialer(dialerFunc),
    
    	if err != nil {
    
    	defer func() {
    		if err := conn.Close(); err != nil {
    			log.Fatal(err)
    		}
    	}()
    
    	client := pipb.NewPluginInternalServiceClient(conn)
    
    
    	type args struct {
    		ctx     context.Context
    
    		request *pipb.GetPluginSchemaRequest
    
    	}
    	tests := []struct {
    
    		name       string
    		args       args
    		// statusCode is only compared when wantErr is true.
    		statusCode codes.Code
    		wantErr    bool
    
    	}{
    		{
    			name: "default stream interceptor",
    			args: args{
    				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)),
    
    				// NOTE(review): the end of this entry and the opening of the
    				// next one are missing in this listing.
    				request: &pipb.GetPluginSchemaRequest{
    					Pid: pluginID,
    
    			name: "user without sufficient rights",
    			args: args{
    				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", tokenWithMissingRights)),
    				request: &pipb.GetPluginSchemaRequest{
    					Pid: pluginID,
    				},
    			},
    			statusCode: codes.PermissionDenied,
    			wantErr:    true,
    		},
    		{
    			// The literal string "foo" in place of a token: Unauthenticated expected.
    			name: "user not authenticated",
    
    			args: args{
    				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")),
    
    				request: &pipb.GetPluginSchemaRequest{
    					Pid: pluginID,
    
    			statusCode: codes.Unauthenticated,
    			wantErr:    true,
    
    		},
    	}
    
    	for _, tt := range tests {
    		t.Run(tt.name, func(t *testing.T) {
    
    			got, err := client.GetPluginSchema(tt.args.ctx, tt.args.request)
    
    			// An error from opening the stream fails the case regardless of
    			// wantErr; expected errors are checked on Recv below.
    			if err != nil {
    				t.Errorf("AuthInterceptor.Stream() = %v", err)
    				return
    			}
    
    
    			// Read one message from the stream and compare the error outcome
    			// with wantErr.
    			_, err = got.Recv()
    			if (err != nil) != tt.wantErr {
    				t.Errorf("AuthInterceptor.Stream() = got error: %v, wantErr: %v", err, tt.wantErr)
    
    
    			// For expected failures, also verify the gRPC status code.
    			if tt.wantErr {
    				statusCode, ok := status.FromError(err)
    				if !ok {
    					t.Errorf("AuthInterceptor.Stream() = %v", err)
    					return
    				}
    				if statusCode.Code() != tt.statusCode {
    					t.Errorf("AuthInterceptor.Stream() = got error with status code: %v, want: %v", statusCode.Code(), tt.statusCode)
    					return
    				}
    			}
    
    		})
    	}
    }
    
    // TestAuthInterceptor_authorize calls AuthInterceptor.authorize directly (no
    // gRPC server involved) with an incoming context carrying a token under the
    // "authorize" metadata key and a full method name, and checks whether an
    // error is returned.
    //
    // NOTE(review): the first `if err != nil {` below is never closed in this
    // listing, and the second has an empty body — lines appear to be missing.
    // Confirm against the upstream file.
    func TestAuthInterceptor_authorize(t *testing.T) {
    
    	// Only the interceptor is needed; the three servers are discarded.
    	authServer, _, _, _ := getTestAuthInterceptorServer(t)
    	validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager)
    
    	if err != nil {
    
    	// Token for user "foo"; the table expects authorize to reject it.
    	wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager)
    
    	if err != nil {
    
    	}
    
    	type args struct {
    		ctx    context.Context
    		method string
    	}
    	tests := []struct {
    		name    string
    		args    args
    		wantErr bool
    	}{
    		{
    			// Valid token calling GetUsers: no error expected.
    			name: "default authorize",
    			args: args{
    				ctx:    metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)),
    				method: "/gosdn.rbac.UserService/GetUsers",
    			},
    			wantErr: false,
    		},
    		{
    			// Token created for user "foo": error expected.
    			name: "error invalid token",
    			args: args{
    				ctx:    metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", wrongUserToken)),
    				method: "/gosdn.rbac.UserService/GetUsers",
    			},
    			wantErr: true,
    		},
    		{
    			// Same valid token, but for the DeleteMne method: error expected.
    			name: "error no permission for request",
    			args: args{
    				ctx:    metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)),
    
    				method: "/gosdn.pnd.PndService/DeleteMne",
    
    			},
    			wantErr: true,
    		},
    	}
    	for _, tt := range tests {
    		t.Run(tt.name, func(t *testing.T) {
    
    			// Fail when the presence of an error does not match wantErr.
    			if err := authServer.authorize(tt.args.ctx, tt.args.method); (err != nil) != tt.wantErr {
    
    				t.Errorf("AuthInterceptor.authorize() error = %v, wantErr %v", err, tt.wantErr)
    			}
    		})
    	}
    }