diff --git a/controller/northbound/server/role.go b/controller/northbound/server/role.go index 09c9a0bbcd8ac031f52481bd7f19b6b78d093d99..7691ce523f47285610df8b5dd734ffa0ed2e8c80 100644 --- a/controller/northbound/server/role.go +++ b/controller/northbound/server/role.go @@ -2,6 +2,7 @@ package server import ( "context" + "errors" "fmt" "time" @@ -16,6 +17,7 @@ import ( log "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/reflect/protoreflect" ) // RoleServer holds a JWTManager and represents a RoleServiceServer. @@ -39,14 +41,33 @@ func NewRoleServer( } } +func (r RoleServer) checkForValidationErrors(request protoreflect.ProtoMessage) error { + err := r.protoValidator.Validate(request) + if err != nil { + var valErr *protovalidate.ValidationError + + if ok := errors.As(err, &valErr); ok { + protoErr := valErr.ToProto() + grpcError, _ := status.New(codes.Aborted, "Validation failed").WithDetails(protoErr) + + return grpcError.Err() + } + + return status.Errorf(codes.Aborted, "%v", err) + } + + return nil +} + // CreateRoles creates one are multiple new roles. func (r RoleServer) CreateRoles(ctx context.Context, request *apb.CreateRolesRequest) (*apb.CreateRolesResponse, error) { labels := prometheus.Labels{"service": "auth", "rpc": "post"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, rrole := range request.Roles { @@ -70,8 +91,9 @@ func (r RoleServer) GetRole(ctx context.Context, request *apb.GetRoleRequest) (* start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } roleID, err := uuid.Parse(request.Id) @@ -103,8 +125,9 @@ func (r RoleServer) GetRoles(ctx context.Context, request *apb.GetRolesRequest) start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } roleList, err := r.roleService.GetAll() @@ -134,8 +157,9 @@ func (r RoleServer) UpdateRoles(ctx context.Context, request *apb.UpdateRolesReq start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, role := range request.Roles { @@ -166,8 +190,9 @@ func (r RoleServer) DeletePermissionsForRole(ctx context.Context, request *apb.D start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } roleToUpdate, err := r.roleService.Get(store.Query{Name: request.RoleName}) @@ -211,8 +236,9 @@ func (r RoleServer) DeleteRoles(ctx context.Context, request *apb.DeleteRolesReq start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := r.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := r.checkForValidationErrors(request) + if err != nil { + return nil, err } for _, role := range request.RoleName { diff --git a/controller/northbound/server/role_test.go b/controller/northbound/server/role_test.go index 228e5503e4830cd70bad9781d921b2c36f746589..d2f3dc7e042ebefe196f3b2cb356daebde57ce1f 100644 --- a/controller/northbound/server/role_test.go +++ b/controller/northbound/server/role_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" "code.fbi.h-da.de/danet/gosdn/controller/rbac" "github.com/bufbuild/protovalidate-go" @@ -44,10 +45,11 @@ func TestRole_CreateRoles(t *testing.T) { request *apb.CreateRolesRequest } tests := []struct { - name string - args args - want *apb.CreateRolesResponse - wantErr bool + name string + args args + want *apb.CreateRolesResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default create roles", @@ -65,6 +67,49 @@ func TestRole_CreateRoles(t *testing.T) { want: &apb.CreateRolesResponse{}, wantErr: false, }, + { + name: "role with too short name should fail", + args: args{ctx: context.TODO(), + request: &apb.CreateRolesRequest{ + Roles: []*apb.Role{ + { + Name: "r1", + Description: "Role 1", + Permissions: []string{"permission 1", "permission 2"}, + }, + }, + }, + }, + want: &apb.CreateRolesResponse{}, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "roles[0].name", + ConstraintId: "string.min_len", + Message: "value length must be at least 3 characters", + }}, + }, + { + name: "role with too short description should fail", + args: args{ctx: context.TODO(), + request: &apb.CreateRolesRequest{ + Roles: []*apb.Role{ + { + Name: "new role 1", + Description: "r1", + Permissions: []string{"permission 1", "permission 2"}, + }, + }, + }, + }, + want: &apb.CreateRolesResponse{}, + wantErr: true, + validationErrors: []*validate.Violation{{ + FieldPath: "roles[0].description", + ConstraintId: "string.min_len", + Message: "value length must be at least 3 characters", + }}, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -74,6 +119,10 @@ func TestRole_CreateRoles(t *testing.T) { t.Errorf("Role.CreateRoles() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } @@ -84,10 +133,11 @@ func TestRole_GetRole(t *testing.T) { request *apb.GetRoleRequest } tests := []struct { - name string - args args - want *apb.GetRoleResponse - wantErr bool + name string + args args + want *apb.GetRoleResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default get role", @@ -118,6 +168,25 @@ func TestRole_GetRole(t *testing.T) { want: nil, wantErr: true, }, + { + name: "error get role with missing name", + args: args{ + ctx: context.TODO(), + request: &apb.GetRoleRequest{ + RoleName: "", + Id: uuid.Nil.String(), + }, + }, + want: nil, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "role_name", + ConstraintId: "required", + Message: "value is required", + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -133,8 +202,8 @@ func TestRole_GetRole(t *testing.T) { t.Errorf("Role.GetRole() = %v, want %v", got, tt.want) } } else { - if got != nil { - t.Errorf("Role.GetRole() = %v, want %v", got, tt.want) + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) } } }) @@ -199,7 +268,7 @@ func TestRole_GetRoles(t *testing.T) { } if got != nil { - if len(got.Roles) != 3 { + if len(got.Roles) != tt.wantLen { t.Errorf("Role.GetRoles() = %v, want %v", got, tt.want) } for _, gotR := range got.Roles { @@ -228,10 +297,11 @@ func TestRole_UpdateRoles(t *testing.T) { request *apb.UpdateRolesRequest } tests := []struct { - name string - args args - want *apb.UpdateRolesResponse - wantErr bool + name string + args args + want *apb.UpdateRolesResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default update roles", @@ -267,6 +337,54 @@ func TestRole_UpdateRoles(t *testing.T) { want: nil, wantErr: true, }, + { + name: "error update roles with too short name", + args: args{ + ctx: context.TODO(), + request: &apb.UpdateRolesRequest{ + Roles: []*apb.Role{ + { + Id: uuid.NewString(), + Name: "a", + Description: "Test role", + }, + }, + }, + }, + want: nil, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "roles[0].name", + ConstraintId: "string.min_len", + Message: "value length must be at least 3 characters", + }, + }, + }, + { + name: "error update roles with too short description", + args: args{ + ctx: context.TODO(), + request: &apb.UpdateRolesRequest{ + Roles: []*apb.Role{ + { + Id: uuid.NewString(), + Name: "My role", + Description: "r", + }, + }, + }, + }, + want: nil, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "roles[0].description", + ConstraintId: "string.min_len", + Message: "value length must be at least 3 characters", + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -276,6 +394,10 @@ func TestRole_UpdateRoles(t *testing.T) { t.Errorf("Role.UpdateRoles() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } @@ -286,10 +408,11 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { request *apb.DeletePermissionsForRoleRequest } tests := []struct { - name string - args args - want *apb.DeletePermissionsForRoleResponse - wantErr bool + name string + args args + want *apb.DeletePermissionsForRoleResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default delete permissions for role", @@ -320,6 +443,30 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { want: nil, wantErr: true, }, + { + name: "error delete permissions for role with no proper name and permissions provided", + args: args{ + ctx: context.TODO(), + request: &apb.DeletePermissionsForRoleRequest{ + RoleName: "", + PermissionsToDelete: []string{}, + }, + }, + want: nil, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "role_name", + ConstraintId: "required", + Message: "value is required", + }, + { + FieldPath: "permissions_to_delete", + ConstraintId: "required", + Message: "value is required", + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -329,6 +476,10 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { t.Errorf("Role.DeletePermissionsForRole() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } @@ -339,10 +490,11 @@ func TestRole_DeleteRoles(t *testing.T) { request *apb.DeleteRolesRequest } tests := []struct { - name string - args args - want *apb.DeleteRolesResponse - wantErr bool + name string + args args + want *apb.DeleteRolesResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default delete roles", @@ -371,6 +523,26 @@ func TestRole_DeleteRoles(t *testing.T) { want: &apb.DeleteRolesResponse{}, wantErr: true, }, + { + name: "error delete roles with missing role name", + args: args{ + ctx: context.TODO(), + request: &apb.DeleteRolesRequest{ + RoleName: []string{ + "", + }, + }, + }, + want: &apb.DeleteRolesResponse{}, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "role_name", + ConstraintId: "required", + Message: "value is required", + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -381,6 +553,10 @@ func TestRole_DeleteRoles(t *testing.T) { t.Errorf("Role.DeleteRoles() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } diff --git a/controller/northbound/server/utils_test.go b/controller/northbound/server/utils_test.go new file mode 100644 index 0000000000000000000000000000000000000000..cfd0458979ebc26e589c4ab132425536de243ccd --- /dev/null +++ b/controller/northbound/server/utils_test.go @@ -0,0 +1,35 @@ +package server + +import ( + "testing" + + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" + "google.golang.org/grpc/status" +) + +func contains(array []*validate.Violation, err *validate.Violation) bool { + for _, v := range array { + if v.FieldPath == err.FieldPath && v.ConstraintId == err.ConstraintId && v.Message == err.Message { + return true + } + } + + return false +} + +func assertValidationErrors(t *testing.T, err error, expectedValidationErrors []*validate.Violation) { + st := status.Convert(err) + errDetails := st.Details() + + for _, detail := range errDetails { + switch errorType := detail.(type) { + case *validate.Violations: + for _, violation := range errorType.Violations { + ok := contains(expectedValidationErrors, violation) + if !ok { + t.Errorf("Received unexptected validation error: %v, expected %v", violation, expectedValidationErrors) + } + } + } + } +}