package server

import (
	"context"
	"log"
	"net"
	"testing"
	"time"

	pipb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-internal"
	rpb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/plugin-registry"
	apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
	eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService"
	"code.fbi.h-da.de/danet/gosdn/controller/nucleus"
	"code.fbi.h-da.de/danet/gosdn/controller/rbac"
	"github.com/bufbuild/protovalidate-go"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials/insecure"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/resolver"
	"google.golang.org/grpc/status"
	"google.golang.org/grpc/test/bufconn"
)

// getTestAuthInterceptorServer builds an AuthInterceptor together with the
// UserServer, RoleServer and PluginInternalServer used by the interceptor
// tests. User and role services are backed by in-memory stores and a mock
// event service. A mock plugin is added to a mock plugin service, a test PND
// is added to an in-memory PND service, and clearAndCreateAuthTestSetup is
// run against the user and role services. Any setup error fails the calling
// test via t.Fatal.
func getTestAuthInterceptorServer(t *testing.T) (*AuthInterceptor, *UserServer, *RoleServer, *PluginInternalServer) {
	t.Helper()

	initUUIDs(t)
	jwtManager := rbac.NewJWTManager("test", time.Minute)
	eventService := eventservice.NewMockEventService()

	userStore := rbac.NewMemoryUserStore()
	userService := rbac.NewUserService(userStore, eventService)

	roleStore := rbac.NewMemoryRoleStore()
	roleService := rbac.NewRoleService(roleStore, eventService)

	// NOTE(review): this client wraps a zero-value ClientConn, so it is only a
	// placeholder; presumably no test path here issues a registry RPC — confirm.
	registryClient := rpb.NewPluginRegistryServiceClient(&grpc.ClientConn{})

	mockPlugin := getMockPlugin(t)
	pluginService := nucleus.NewPluginServiceMock()
	if err := pluginService.Add(mockPlugin); err != nil {
		t.Fatal(err)
	}

	pndStore := nucleus.NewMemoryPndStore()
	pndService := nucleus.NewPndService(pndStore)

	pnd := nucleus.NewPND(pndUUID, "test", "test")
	if err := pndService.Add(pnd); err != nil {
		t.Fatal(err)
	}

	protoValidator, err := protovalidate.New()
	if err != nil {
		// Fail through the testing framework like every other setup step
		// instead of panicking and aborting the whole test binary.
		t.Fatal(err)
	}

	s := NewAuthInterceptor(jwtManager, userService, roleService)
	u := NewUserServer(jwtManager, userService, protoValidator)
	r := NewRoleServer(jwtManager, roleService, protoValidator)
	p := NewPluginInternalServer(registryClient, pluginService, protoValidator)

	if err := clearAndCreateAuthTestSetup(userService, roleService); err != nil {
		t.Fatal(err)
	}

	return s, u, r, p
}

// dialer starts an in-memory gRPC server on a bufconn listener with a 1 MiB
// buffer. Every unary and streaming call on that server passes through the
// given AuthInterceptor. dialer returns a context dialer that connects
// clients to that listener.
//
// Only the UserService and the PluginInternalService are registered;
// roleServer is accepted but not registered here.
//
// Side effect: the process-wide default resolver scheme is set to
// "passthrough", so the "bufnet" target used by the tests is passed straight
// to the dialer.
//
// NOTE(review): the server is never stopped, so each call leaves a serving
// goroutine and listener alive for the rest of the test binary.
func dialer(interceptorServer *AuthInterceptor, userServer *UserServer, roleServer *RoleServer, pluginServer *PluginInternalServer) func(context.Context, string) (net.Conn, error) {
	resolver.SetDefaultScheme("passthrough")
	lis := bufconn.Listen(1024 * 1024)

	srv := grpc.NewServer(
		grpc.UnaryInterceptor(interceptorServer.Unary()),
		grpc.StreamInterceptor(interceptorServer.Stream()),
	)

	apb.RegisterUserServiceServer(srv, userServer)
	pipb.RegisterPluginInternalServiceServer(srv, pluginServer)

	go func() {
		serveErr := srv.Serve(lis)
		if serveErr != nil {
			log.Fatal(serveErr)
		}
	}()

	return func(_ context.Context, _ string) (net.Conn, error) {
		return lis.Dial()
	}
}

// TestAuthInterceptor_Unary exercises the unary interceptor end-to-end. It
// calls UserService.GetUsers over an in-memory gRPC connection in four cases:
// with the admin's token, with a token created for another user ("foo"), with
// a malformed token string, and with no "authorize" metadata at all. Only the
// admin call is expected to succeed, and its response must contain the admin
// user.
func TestAuthInterceptor_Unary(t *testing.T) {
	authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t)
	validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager)
	if err != nil {
		t.Fatal(err)
	}

	wrongUserToken, err := createTestUserToken("foo", false, authServer.userService, authServer.jwtManager)
	if err != nil {
		t.Fatal(err)
	}

	dialerFunc := dialer(authServer, userServer, roleServer, pluginServer)
	conn, err := grpc.NewClient("bufnet",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithContextDialer(dialerFunc),
	)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		// Report through t instead of log.Fatal, which would kill the whole
		// test binary and skip the remaining tests' reporting.
		if err := conn.Close(); err != nil {
			t.Errorf("closing client connection: %v", err)
		}
	}()

	client := apb.NewUserServiceClient(conn)

	type args struct {
		ctx     context.Context
		request *apb.GetUsersRequest
	}
	tests := []struct {
		name    string
		args    args
		want    *apb.User
		wantErr bool
	}{
		{
			name: "default unary interceptor",
			args: args{
				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)),
				request: &apb.GetUsersRequest{},
			},
			// The authenticated admin must be part of the returned user list.
			want:    &apb.User{Name: "testAdmin"},
			wantErr: false,
		},
		{
			name: "error unary invalid user token",
			args: args{
				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", wrongUserToken)),
				request: &apb.GetUsersRequest{},
			},
			want:    nil,
			wantErr: true,
		},
		{
			name: "error unary invalid token string",
			args: args{
				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")),
				request: &apb.GetUsersRequest{},
			},
			want:    nil,
			wantErr: true,
		},
		{
			name: "error unary no token in metadata",
			args: args{
				ctx:     metadata.NewOutgoingContext(context.Background(), metadata.Pairs("foo", "foo")),
				request: &apb.GetUsersRequest{},
			},
			want:    nil,
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := client.GetUsers(tt.args.ctx, tt.args.request)
			if (err != nil) != tt.wantErr {
				t.Errorf("AuthInterceptor.Unary() = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if tt.wantErr || tt.want == nil {
				return
			}

			// Search the whole list rather than assuming the expected user is
			// first; GetUser() is nil-safe and avoids an index-out-of-range panic
			// on an empty response.
			users := got.GetUser()
			for _, u := range users {
				if u.GetName() == tt.want.GetName() {
					return
				}
			}
			names := make([]string, 0, len(users))
			for _, u := range users {
				names = append(names, u.GetName())
			}
			t.Errorf("Got users = %v, want user %q among them", names, tt.want.GetName())
		})
	}
}

// TestAuthInterceptor_Stream exercises the stream interceptor end-to-end. It
// calls PluginInternalService.GetPluginSchema over an in-memory gRPC
// connection. The test expects opening the stream to succeed in every case.
// Any authorization failure is expected on the first Recv, and its gRPC status
// code is compared with the expected one (PermissionDenied for a user lacking
// rights, Unauthenticated for a malformed token).
func TestAuthInterceptor_Stream(t *testing.T) {
	authServer, userServer, roleServer, pluginServer := getTestAuthInterceptorServer(t)
	validToken, err := createTestUserToken("testAdmin", true, authServer.userService, authServer.jwtManager)
	if err != nil {
		t.Fatal(err)
	}
	tokenWithMissingRights, err := createTestUserToken("testUser", true, authServer.userService, authServer.jwtManager)
	if err != nil {
		t.Fatal(err)
	}

	dialerFunc := dialer(authServer, userServer, roleServer, pluginServer)
	conn, err := grpc.NewClient("bufnet",
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithContextDialer(dialerFunc),
	)
	if err != nil {
		t.Fatal(err)
	}

	defer func() {
		// Report through t instead of log.Fatal, which would kill the whole
		// test binary and skip the remaining tests' reporting.
		if err := conn.Close(); err != nil {
			t.Errorf("closing client connection: %v", err)
		}
	}()

	client := pipb.NewPluginInternalServiceClient(conn)

	type args struct {
		ctx     context.Context
		request *pipb.GetPluginSchemaRequest
	}
	tests := []struct {
		name       string
		args       args
		statusCode codes.Code
		wantErr    bool
	}{
		{
			name: "default stream interceptor",
			args: args{
				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", validToken)),
				request: &pipb.GetPluginSchemaRequest{
					Pid: pluginID,
				},
			},
			statusCode: codes.OK,
			wantErr:    false,
		},
		{
			name: "user without sufficient rights",
			args: args{
				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", tokenWithMissingRights)),
				request: &pipb.GetPluginSchemaRequest{
					Pid: pluginID,
				},
			},
			statusCode: codes.PermissionDenied,
			wantErr:    true,
		},
		{
			name: "user not authenticated",
			args: args{
				ctx: metadata.NewOutgoingContext(context.Background(), metadata.Pairs("authorize", "foo")),
				request: &pipb.GetPluginSchemaRequest{
					Pid: pluginID,
				},
			},
			statusCode: codes.Unauthenticated,
			wantErr:    true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			stream, err := client.GetPluginSchema(tt.args.ctx, tt.args.request)
			if err != nil {
				t.Errorf("AuthInterceptor.Stream() = %v", err)
				return
			}

			_, err = stream.Recv()
			if (err != nil) != tt.wantErr {
				t.Errorf("AuthInterceptor.Stream() = got error: %v, wantErr: %v", err, tt.wantErr)
				return
			}
			if !tt.wantErr {
				return
			}

			st, ok := status.FromError(err)
			if !ok {
				t.Errorf("AuthInterceptor.Stream() = %v", err)
				return
			}
			if st.Code() != tt.statusCode {
				t.Errorf("AuthInterceptor.Stream() = got error with status code: %v, want: %v", st.Code(), tt.statusCode)
			}
		})
	}
}

// TestAuthInterceptor_authorize calls AuthInterceptor.authorize directly. It
// uses contexts that carry a token in incoming metadata and checks the result
// for each case:
//   - the admin token on UserService/GetUsers must be accepted;
//   - a token created for user "foo" must be rejected;
//   - the admin token on PndService/DeleteMne must be rejected.
func TestAuthInterceptor_authorize(t *testing.T) {
	interceptor, _, _, _ := getTestAuthInterceptorServer(t)

	adminToken, err := createTestUserToken("testAdmin", true, interceptor.userService, interceptor.jwtManager)
	if err != nil {
		t.Fatal(err)
	}

	fooToken, err := createTestUserToken("foo", false, interceptor.userService, interceptor.jwtManager)
	if err != nil {
		t.Fatal(err)
	}

	// incomingWith builds an incoming context whose metadata holds tok under
	// the "authorize" key.
	incomingWith := func(tok string) context.Context {
		return metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", tok))
	}

	const getUsersMethod = "/gosdn.rbac.UserService/GetUsers"

	cases := []struct {
		name    string
		ctx     context.Context
		method  string
		wantErr bool
	}{
		{name: "default authorize", ctx: incomingWith(adminToken), method: getUsersMethod, wantErr: false},
		{name: "error invalid token", ctx: incomingWith(fooToken), method: getUsersMethod, wantErr: true},
		{name: "error no permission for request", ctx: incomingWith(adminToken), method: "/gosdn.pnd.PndService/DeleteMne", wantErr: true},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			gotErr := interceptor.authorize(tc.ctx, tc.method)
			if failed := gotErr != nil; failed != tc.wantErr {
				t.Errorf("AuthInterceptor.authorize() error = %v, wantErr %v", gotErr, tc.wantErr)
			}
		})
	}
}