diff --git a/controller/northbound/server/user.go b/controller/northbound/server/user.go index 0d2cb4e7f7d196e615dccbc9237370091324a393..1ea31bee575daa4e29faa594610937452543e90b 100644 --- a/controller/northbound/server/user.go +++ b/controller/northbound/server/user.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/base64" + "errors" "fmt" "time" @@ -20,6 +21,7 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/reflect/protoreflect" "golang.org/x/crypto/argon2" ) @@ -41,14 +43,33 @@ func NewUserServer(jwtManager *rbac.JWTManager, userService rbacInterfaces.UserS } } +func (r UserServer) checkForValidationErrors(request protoreflect.ProtoMessage) error { + err := r.protoValidator.Validate(request) + if err != nil { + var valErr *protovalidate.ValidationError + + if ok := errors.As(err, &valErr); ok { + protoErr := valErr.ToProto() + grpcError, _ := status.New(codes.Aborted, "Validation failed").WithDetails(protoErr) + + return grpcError.Err() + } + + return status.Errorf(codes.Aborted, "%v", err) + } + + return nil +} + // CreateUsers creates new users, can be 1 or more. func (u UserServer) CreateUsers(ctx context.Context, request *apb.CreateUsersRequest) (*apb.CreateUsersResponse, error) { labels := prometheus.Labels{"service": "auth", "rpc": "post"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := u.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := u.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, user := range request.User { @@ -90,8 +111,9 @@ func (u UserServer) GetUser(ctx context.Context, request *apb.GetUserRequest) (* start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := u.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := u.checkForValidationErrors(request) + if err != nil { + return nil, err } userID, err := uuid.Parse(request.Id) @@ -125,8 +147,9 @@ func (u UserServer) GetUsers(ctx context.Context, request *apb.GetUsersRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := u.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := u.checkForValidationErrors(request) + if err != nil { + return nil, err } userList, err := u.userService.GetAll() @@ -158,8 +181,9 @@ func (u UserServer) UpdateUsers(ctx context.Context, request *apb.UpdateUsersReq start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := u.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := u.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, user := range request.User { @@ -198,8 +222,9 @@ func (u UserServer) DeleteUsers(ctx context.Context, request *apb.DeleteUsersReq start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := u.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := u.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, user := range request.Username { diff --git a/controller/northbound/server/user_test.go b/controller/northbound/server/user_test.go index 4c5d70eb9f29e3f6c1f157c602fc00c974a2ff13..6d2058dee8e113108b21a95cdb30b58b55f66ec8 100644 --- a/controller/northbound/server/user_test.go +++ b/controller/northbound/server/user_test.go @@ -2,10 +2,10 @@ package server import ( "context" - "fmt" "testing" "time" + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/conflict" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService" @@ -44,9 +44,10 @@ func TestUser_CreateUsers(t *testing.T) { request *apb.CreateUsersRequest } tests := []struct { - name string - args args - wantErr bool + name string + args args + wantErr bool + validationErrors []*validate.Violation }{ { name: "default create users", @@ -87,6 +88,38 @@ func TestUser_CreateUsers(t *testing.T) { }, }, wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "user[0].password", + ConstraintId: "string.min_len", + Message: "value length must be at least 5 characters", + }}, + }, + { + name: "create users with too short username should fail", + args: args{ + ctx: context.TODO(), + request: &apb.CreateUsersRequest{ + User: []*apb.User{ + { + Name: "a", + Roles: map[string]string{pndID: "userTestRole"}, + Password: "password", + Token: "", + Metadata: &conflict.Metadata{ + ResourceVersion: 0, + }, + }, + }, + }, + }, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "user[0].name", + ConstraintId: "string.min_len", + Message: "value length must be at least 3 characters", + }}, }, } for _, tt := range tests { @@ -97,7 +130,10 @@ func TestUser_CreateUsers(t *testing.T) { t.Errorf("User.CreateUsers() error = %v, wantErr %v", err, tt.wantErr) return } - fmt.Printf("err %+v,want no error\n", err) + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } @@ -109,10 +145,11 @@ func TestUser_GetUser(t *testing.T) { request *apb.GetUserRequest } tests := []struct { - name string - args args - want *apb.GetUserResponse - wantErr bool + name string + args args + want *apb.GetUserResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default get user", @@ -140,6 +177,24 @@ func TestUser_GetUser(t *testing.T) { want: nil, wantErr: true, }, + { + name: "fail get user due to missing name", + args: args{ + ctx: context.TODO(), + request: &apb.GetUserRequest{ + Name: "", + Id: uuid.Nil.String(), + }, + }, + want: nil, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "name", + ConstraintId: "required", + Message: "value is required", + }}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -150,6 +205,10 @@ func TestUser_GetUser(t *testing.T) { return } + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } + if got != nil { if got.User.Name != tt.want.User.Name || got.User.Id != tt.want.User.Id { t.Errorf("User.GetUser() = %v, want %v", got, tt.want) @@ -230,10 +289,11 @@ func TestUser_UpdateUsers(t *testing.T) { request *apb.UpdateUsersRequest } tests := []struct { - name string - args args - want *apb.UpdateUsersResponse - wantErr bool + name string + args args + want *apb.UpdateUsersResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default update user", @@ -266,7 +326,7 @@ func TestUser_UpdateUsers(t *testing.T) { wantErr: true, }, { - name: "update user without name should fail", + name: "update user without name should fail validation", args: args{ctx: context.TODO(), request: &apb.UpdateUsersRequest{User: []*apb.UpdateUser{ { @@ -279,6 +339,13 @@ func TestUser_UpdateUsers(t *testing.T) { }, want: nil, wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "name", + ConstraintId: "required", + Message: "value is required", + }, + }, }, } @@ -290,6 +357,10 @@ func TestUser_UpdateUsers(t *testing.T) { t.Errorf("User.UpdateUsers() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } @@ -300,10 +371,11 @@ func TestUser_DeleteUsers(t *testing.T) { request *apb.DeleteUsersRequest } tests := []struct { - name string - args args - want *apb.DeleteUsersResponse - wantErr bool + name string + args args + want *apb.DeleteUsersResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default delete users", @@ -314,13 +386,28 @@ func TestUser_DeleteUsers(t *testing.T) { wantErr: false, }, { - name: "error delete users", + name: "error delete users for non existing user", args: args{ctx: context.TODO(), request: &apb.DeleteUsersRequest{Username: []string{"no user"}}, }, want: &apb.DeleteUsersResponse{}, wantErr: true, }, + { + name: "error delete users due to missing name", + args: args{ctx: context.TODO(), + request: &apb.DeleteUsersRequest{Username: []string{""}}, + }, + want: &apb.DeleteUsersResponse{}, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "name", + ConstraintId: "required", + Message: "value is required", + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -330,6 +417,10 @@ func TestUser_DeleteUsers(t *testing.T) { t.Errorf("User.DeleteUsers() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } }