From 5cae300b31130836047375dc38704fcdfc815108 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9=20Sterba?= <hda@andre-sterba.de> Date: Thu, 26 Oct 2023 20:17:43 +0000 Subject: [PATCH] Add input validaton for auth See merge request danet/gosdn!586 --- controller/northbound/server/auth.go | 34 ++++++++++--- controller/northbound/server/auth_test.go | 61 ++++++++++++++++++++--- 2 files changed, 81 insertions(+), 14 deletions(-) diff --git a/controller/northbound/server/auth.go b/controller/northbound/server/auth.go index 5cdc4b9a9..2569f8541 100644 --- a/controller/northbound/server/auth.go +++ b/controller/northbound/server/auth.go @@ -3,6 +3,7 @@ package server import ( "context" "encoding/base64" + "errors" "time" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" @@ -16,6 +17,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/reflect/protoreflect" ) // AuthServer holds a JWTManager and represents a AuthServiceServer. @@ -39,14 +41,33 @@ func NewAuthServer( } } +func (s AuthServer) checkForValidationErrors(request protoreflect.ProtoMessage) error { + err := s.protoValidator.Validate(request) + if err != nil { + var valErr *protovalidate.ValidationError + + if ok := errors.As(err, &valErr); ok { + protoErr := valErr.ToProto() + grpcError, _ := status.New(codes.Aborted, "Validation failed").WithDetails(protoErr) + + return grpcError.Err() + } + + return status.Errorf(codes.Aborted, "%v", err) + } + + return nil +} + // Login logs a user in. func (s AuthServer) Login(ctx context.Context, request *apb.LoginRequest) (*apb.LoginResponse, error) { labels := prometheus.Labels{"service": "auth", "rpc": "post"} start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := s.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := s.checkForValidationErrors(request) + if err != nil { + return nil, err } user := rbac.User{ @@ -55,7 +76,7 @@ func (s AuthServer) Login(ctx context.Context, request *apb.LoginRequest) (*apb. } // validation of credentials - err := s.isValidUser(user) + err = s.isValidUser(user) if err != nil { return nil, err } @@ -90,11 +111,12 @@ func (s AuthServer) Logout(ctx context.Context, request *apb.LogoutRequest) (*ap start := metrics.StartHook(labels, grpcRequestsTotal) defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds) - if err := s.protoValidator.Validate(request); err != nil { - return nil, status.Errorf(codes.Aborted, "%v", err) + err := s.checkForValidationErrors(request) + if err != nil { + return nil, err } - err := s.handleLogout(ctx, request.Username) + err = s.handleLogout(ctx, request.Username) if err != nil { return nil, err } diff --git a/controller/northbound/server/auth_test.go b/controller/northbound/server/auth_test.go index c2f8dd149..7716ec358 100644 --- a/controller/northbound/server/auth_test.go +++ b/controller/northbound/server/auth_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate" apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac" eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService" "code.fbi.h-da.de/danet/gosdn/controller/rbac" @@ -44,10 +45,11 @@ func TestAuth_Login(t *testing.T) { request *apb.LoginRequest } tests := []struct { - name string - args args - want string - wantErr bool + name string + args args + want string + wantErr bool + validationErrors []*validate.Violation }{ { name: "default login", @@ -71,6 +73,23 @@ func TestAuth_Login(t *testing.T) { }, wantErr: true, }, + { + name: "login fail due to missing username", + want: "", + args: args{ + request: &apb.LoginRequest{ + Username: "", + Pwd: "nope", + }, + }, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "username", + ConstraintId: "required", + Message: "value is required", + }}, + }, } for _, tt := range tests { @@ -82,6 +101,10 @@ func TestAuth_Login(t *testing.T) { return } + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } + if resp != nil { got := resp.Token if got == "" { @@ -104,10 +127,11 @@ func TestAuth_Logout(t *testing.T) { request *apb.LogoutRequest } tests := []struct { - name string - args args - want *apb.LogoutResponse - wantErr bool + name string + args args + want *apb.LogoutResponse + wantErr bool + validationErrors []*validate.Violation }{ { name: "default log out", @@ -120,6 +144,23 @@ func TestAuth_Logout(t *testing.T) { want: &apb.LogoutResponse{}, wantErr: false, }, + { + name: "default log out fails due to missing username", + args: args{ + ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)), + request: &apb.LogoutRequest{ + Username: "", + }, + }, + want: &apb.LogoutResponse{}, + wantErr: true, + validationErrors: []*validate.Violation{ + { + FieldPath: "username", + ConstraintId: "required", + Message: "value is required", + }}, + }, } for _, tt := range tests { @@ -129,6 +170,10 @@ func TestAuth_Logout(t *testing.T) { t.Errorf("Auth.Logout() error = %v, wantErr %v", err, tt.wantErr) return } + + if tt.wantErr { + assertValidationErrors(t, err, tt.validationErrors) + } }) } } -- GitLab