Skip to content
Snippets Groups Projects
Commit fe9496c6 authored by André Sterba's avatar André Sterba Committed by Neil-Jocelyn Schark
Browse files

Add test for role api inputs

See merge request !567
parent 5f2b5aff
Branches
Tags
1 merge request!567Add test for role api inputs
Pipeline #164206 passed
...@@ -2,6 +2,7 @@ package server ...@@ -2,6 +2,7 @@ package server
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"time" "time"
...@@ -16,6 +17,7 @@ import ( ...@@ -16,6 +17,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/reflect/protoreflect"
) )
// RoleServer holds a JWTManager and represents a RoleServiceServer. // RoleServer holds a JWTManager and represents a RoleServiceServer.
...@@ -39,14 +41,33 @@ func NewRoleServer( ...@@ -39,14 +41,33 @@ func NewRoleServer(
} }
} }
func (r RoleServer) checkForValidationErrors(request protoreflect.ProtoMessage) error {
err := r.protoValidator.Validate(request)
if err != nil {
var valErr *protovalidate.ValidationError
if ok := errors.As(err, &valErr); ok {
protoErr := valErr.ToProto()
grpcError, _ := status.New(codes.Aborted, "Validation failed").WithDetails(protoErr)
return grpcError.Err()
}
return status.Errorf(codes.Aborted, "%v", err)
}
return nil
}
// CreateRoles creates one are multiple new roles. // CreateRoles creates one are multiple new roles.
func (r RoleServer) CreateRoles(ctx context.Context, request *apb.CreateRolesRequest) (*apb.CreateRolesResponse, error) { func (r RoleServer) CreateRoles(ctx context.Context, request *apb.CreateRolesRequest) (*apb.CreateRolesResponse, error) {
labels := prometheus.Labels{"service": "auth", "rpc": "post"} labels := prometheus.Labels{"service": "auth", "rpc": "post"}
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
for _, rrole := range request.Roles { for _, rrole := range request.Roles {
...@@ -70,8 +91,9 @@ func (r RoleServer) GetRole(ctx context.Context, request *apb.GetRoleRequest) (* ...@@ -70,8 +91,9 @@ func (r RoleServer) GetRole(ctx context.Context, request *apb.GetRoleRequest) (*
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
roleID, err := uuid.Parse(request.Id) roleID, err := uuid.Parse(request.Id)
...@@ -103,8 +125,9 @@ func (r RoleServer) GetRoles(ctx context.Context, request *apb.GetRolesRequest) ...@@ -103,8 +125,9 @@ func (r RoleServer) GetRoles(ctx context.Context, request *apb.GetRolesRequest)
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
roleList, err := r.roleService.GetAll() roleList, err := r.roleService.GetAll()
...@@ -134,8 +157,9 @@ func (r RoleServer) UpdateRoles(ctx context.Context, request *apb.UpdateRolesReq ...@@ -134,8 +157,9 @@ func (r RoleServer) UpdateRoles(ctx context.Context, request *apb.UpdateRolesReq
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
for _, role := range request.Roles { for _, role := range request.Roles {
...@@ -166,8 +190,9 @@ func (r RoleServer) DeletePermissionsForRole(ctx context.Context, request *apb.D ...@@ -166,8 +190,9 @@ func (r RoleServer) DeletePermissionsForRole(ctx context.Context, request *apb.D
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
roleToUpdate, err := r.roleService.Get(store.Query{Name: request.RoleName}) roleToUpdate, err := r.roleService.Get(store.Query{Name: request.RoleName})
...@@ -211,8 +236,9 @@ func (r RoleServer) DeleteRoles(ctx context.Context, request *apb.DeleteRolesReq ...@@ -211,8 +236,9 @@ func (r RoleServer) DeleteRoles(ctx context.Context, request *apb.DeleteRolesReq
start := metrics.StartHook(labels, grpcRequestsTotal) start := metrics.StartHook(labels, grpcRequestsTotal)
defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
if err := r.protoValidator.Validate(request); err != nil { err := r.checkForValidationErrors(request)
return nil, status.Errorf(codes.Aborted, "%v", err) if err != nil {
return nil, err
} }
for _, role := range request.RoleName { for _, role := range request.RoleName {
......
...@@ -6,6 +6,7 @@ import ( ...@@ -6,6 +6,7 @@ import (
"testing" "testing"
"time" "time"
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
"code.fbi.h-da.de/danet/gosdn/controller/rbac" "code.fbi.h-da.de/danet/gosdn/controller/rbac"
"github.com/bufbuild/protovalidate-go" "github.com/bufbuild/protovalidate-go"
...@@ -44,10 +45,11 @@ func TestRole_CreateRoles(t *testing.T) { ...@@ -44,10 +45,11 @@ func TestRole_CreateRoles(t *testing.T) {
request *apb.CreateRolesRequest request *apb.CreateRolesRequest
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *apb.CreateRolesResponse want *apb.CreateRolesResponse
wantErr bool wantErr bool
validationErrors []*validate.Violation
}{ }{
{ {
name: "default create roles", name: "default create roles",
...@@ -65,6 +67,49 @@ func TestRole_CreateRoles(t *testing.T) { ...@@ -65,6 +67,49 @@ func TestRole_CreateRoles(t *testing.T) {
want: &apb.CreateRolesResponse{}, want: &apb.CreateRolesResponse{},
wantErr: false, wantErr: false,
}, },
{
name: "role with too short name should fail",
args: args{ctx: context.TODO(),
request: &apb.CreateRolesRequest{
Roles: []*apb.Role{
{
Name: "r1",
Description: "Role 1",
Permissions: []string{"permission 1", "permission 2"},
},
},
},
},
want: &apb.CreateRolesResponse{},
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "roles[0].name",
ConstraintId: "string.min_len",
Message: "value length must be at least 3 characters",
}},
},
{
name: "role with too short description should fail",
args: args{ctx: context.TODO(),
request: &apb.CreateRolesRequest{
Roles: []*apb.Role{
{
Name: "new role 1",
Description: "r1",
Permissions: []string{"permission 1", "permission 2"},
},
},
},
},
want: &apb.CreateRolesResponse{},
wantErr: true,
validationErrors: []*validate.Violation{{
FieldPath: "roles[0].description",
ConstraintId: "string.min_len",
Message: "value length must be at least 3 characters",
}},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -74,6 +119,10 @@ func TestRole_CreateRoles(t *testing.T) { ...@@ -74,6 +119,10 @@ func TestRole_CreateRoles(t *testing.T) {
t.Errorf("Role.CreateRoles() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Role.CreateRoles() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if tt.wantErr {
assertValidationErrors(t, err, tt.validationErrors)
}
}) })
} }
} }
...@@ -84,10 +133,11 @@ func TestRole_GetRole(t *testing.T) { ...@@ -84,10 +133,11 @@ func TestRole_GetRole(t *testing.T) {
request *apb.GetRoleRequest request *apb.GetRoleRequest
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *apb.GetRoleResponse want *apb.GetRoleResponse
wantErr bool wantErr bool
validationErrors []*validate.Violation
}{ }{
{ {
name: "default get role", name: "default get role",
...@@ -118,6 +168,25 @@ func TestRole_GetRole(t *testing.T) { ...@@ -118,6 +168,25 @@ func TestRole_GetRole(t *testing.T) {
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{
name: "error get role with missing name",
args: args{
ctx: context.TODO(),
request: &apb.GetRoleRequest{
RoleName: "",
Id: uuid.Nil.String(),
},
},
want: nil,
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "role_name",
ConstraintId: "required",
Message: "value is required",
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -133,8 +202,8 @@ func TestRole_GetRole(t *testing.T) { ...@@ -133,8 +202,8 @@ func TestRole_GetRole(t *testing.T) {
t.Errorf("Role.GetRole() = %v, want %v", got, tt.want) t.Errorf("Role.GetRole() = %v, want %v", got, tt.want)
} }
} else { } else {
if got != nil { if tt.wantErr {
t.Errorf("Role.GetRole() = %v, want %v", got, tt.want) assertValidationErrors(t, err, tt.validationErrors)
} }
} }
}) })
...@@ -199,7 +268,7 @@ func TestRole_GetRoles(t *testing.T) { ...@@ -199,7 +268,7 @@ func TestRole_GetRoles(t *testing.T) {
} }
if got != nil { if got != nil {
if len(got.Roles) != 3 { if len(got.Roles) != tt.wantLen {
t.Errorf("Role.GetRoles() = %v, want %v", got, tt.want) t.Errorf("Role.GetRoles() = %v, want %v", got, tt.want)
} }
for _, gotR := range got.Roles { for _, gotR := range got.Roles {
...@@ -228,10 +297,11 @@ func TestRole_UpdateRoles(t *testing.T) { ...@@ -228,10 +297,11 @@ func TestRole_UpdateRoles(t *testing.T) {
request *apb.UpdateRolesRequest request *apb.UpdateRolesRequest
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *apb.UpdateRolesResponse want *apb.UpdateRolesResponse
wantErr bool wantErr bool
validationErrors []*validate.Violation
}{ }{
{ {
name: "default update roles", name: "default update roles",
...@@ -267,6 +337,54 @@ func TestRole_UpdateRoles(t *testing.T) { ...@@ -267,6 +337,54 @@ func TestRole_UpdateRoles(t *testing.T) {
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{
name: "error update roles with too short name",
args: args{
ctx: context.TODO(),
request: &apb.UpdateRolesRequest{
Roles: []*apb.Role{
{
Id: uuid.NewString(),
Name: "a",
Description: "Test role",
},
},
},
},
want: nil,
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "roles[0].name",
ConstraintId: "string.min_len",
Message: "value length must be at least 3 characters",
},
},
},
{
name: "error update roles with too short description",
args: args{
ctx: context.TODO(),
request: &apb.UpdateRolesRequest{
Roles: []*apb.Role{
{
Id: uuid.NewString(),
Name: "My role",
Description: "r",
},
},
},
},
want: nil,
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "roles[0].description",
ConstraintId: "string.min_len",
Message: "value length must be at least 3 characters",
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -276,6 +394,10 @@ func TestRole_UpdateRoles(t *testing.T) { ...@@ -276,6 +394,10 @@ func TestRole_UpdateRoles(t *testing.T) {
t.Errorf("Role.UpdateRoles() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Role.UpdateRoles() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if tt.wantErr {
assertValidationErrors(t, err, tt.validationErrors)
}
}) })
} }
} }
...@@ -286,10 +408,11 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { ...@@ -286,10 +408,11 @@ func TestRole_DeletePermissionsForRole(t *testing.T) {
request *apb.DeletePermissionsForRoleRequest request *apb.DeletePermissionsForRoleRequest
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *apb.DeletePermissionsForRoleResponse want *apb.DeletePermissionsForRoleResponse
wantErr bool wantErr bool
validationErrors []*validate.Violation
}{ }{
{ {
name: "default delete permissions for role", name: "default delete permissions for role",
...@@ -320,6 +443,30 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { ...@@ -320,6 +443,30 @@ func TestRole_DeletePermissionsForRole(t *testing.T) {
want: nil, want: nil,
wantErr: true, wantErr: true,
}, },
{
name: "error delete permissions for role with no proper name and permissions provided",
args: args{
ctx: context.TODO(),
request: &apb.DeletePermissionsForRoleRequest{
RoleName: "",
PermissionsToDelete: []string{},
},
},
want: nil,
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "role_name",
ConstraintId: "required",
Message: "value is required",
},
{
FieldPath: "permissions_to_delete",
ConstraintId: "required",
Message: "value is required",
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -329,6 +476,10 @@ func TestRole_DeletePermissionsForRole(t *testing.T) { ...@@ -329,6 +476,10 @@ func TestRole_DeletePermissionsForRole(t *testing.T) {
t.Errorf("Role.DeletePermissionsForRole() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Role.DeletePermissionsForRole() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if tt.wantErr {
assertValidationErrors(t, err, tt.validationErrors)
}
}) })
} }
} }
...@@ -339,10 +490,11 @@ func TestRole_DeleteRoles(t *testing.T) { ...@@ -339,10 +490,11 @@ func TestRole_DeleteRoles(t *testing.T) {
request *apb.DeleteRolesRequest request *apb.DeleteRolesRequest
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want *apb.DeleteRolesResponse want *apb.DeleteRolesResponse
wantErr bool wantErr bool
validationErrors []*validate.Violation
}{ }{
{ {
name: "default delete roles", name: "default delete roles",
...@@ -371,6 +523,26 @@ func TestRole_DeleteRoles(t *testing.T) { ...@@ -371,6 +523,26 @@ func TestRole_DeleteRoles(t *testing.T) {
want: &apb.DeleteRolesResponse{}, want: &apb.DeleteRolesResponse{},
wantErr: true, wantErr: true,
}, },
{
name: "error delete roles with missing role name",
args: args{
ctx: context.TODO(),
request: &apb.DeleteRolesRequest{
RoleName: []string{
"",
},
},
},
want: &apb.DeleteRolesResponse{},
wantErr: true,
validationErrors: []*validate.Violation{
{
FieldPath: "role_name",
ConstraintId: "required",
Message: "value is required",
},
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
...@@ -381,6 +553,10 @@ func TestRole_DeleteRoles(t *testing.T) { ...@@ -381,6 +553,10 @@ func TestRole_DeleteRoles(t *testing.T) {
t.Errorf("Role.DeleteRoles() error = %v, wantErr %v", err, tt.wantErr) t.Errorf("Role.DeleteRoles() error = %v, wantErr %v", err, tt.wantErr)
return return
} }
if tt.wantErr {
assertValidationErrors(t, err, tt.validationErrors)
}
}) })
} }
} }
package server
import (
"testing"
"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
"google.golang.org/grpc/status"
)
func contains(array []*validate.Violation, err *validate.Violation) bool {
for _, v := range array {
if v.FieldPath == err.FieldPath && v.ConstraintId == err.ConstraintId && v.Message == err.Message {
return true
}
}
return false
}
func assertValidationErrors(t *testing.T, err error, expectedValidationErrors []*validate.Violation) {
st := status.Convert(err)
errDetails := st.Details()
for _, detail := range errDetails {
switch errorType := detail.(type) {
case *validate.Violations:
for _, violation := range errorType.Violations {
ok := contains(expectedValidationErrors, violation)
if !ok {
t.Errorf("Received unexptected validation error: %v, expected %v", violation, expectedValidationErrors)
}
}
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment