From 5cae300b31130836047375dc38704fcdfc815108 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andr=C3=A9=20Sterba?= <hda@andre-sterba.de>
Date: Thu, 26 Oct 2023 20:17:43 +0000
Subject: [PATCH] Add input validaton for auth

See merge request danet/gosdn!586
---
 controller/northbound/server/auth.go      | 34 ++++++++++---
 controller/northbound/server/auth_test.go | 61 ++++++++++++++++++++---
 2 files changed, 81 insertions(+), 14 deletions(-)

diff --git a/controller/northbound/server/auth.go b/controller/northbound/server/auth.go
index 5cdc4b9a9..2569f8541 100644
--- a/controller/northbound/server/auth.go
+++ b/controller/northbound/server/auth.go
@@ -3,6 +3,7 @@ package server
 import (
 	"context"
 	"encoding/base64"
+	"errors"
 	"time"
 
 	apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
@@ -16,6 +17,7 @@ import (
 	"google.golang.org/grpc/codes"
 	"google.golang.org/grpc/metadata"
 	"google.golang.org/grpc/status"
+	"google.golang.org/protobuf/reflect/protoreflect"
 )
 
 // AuthServer holds a JWTManager and represents a AuthServiceServer.
@@ -39,14 +41,33 @@ func NewAuthServer(
 	}
 }
 
+func (s AuthServer) checkForValidationErrors(request protoreflect.ProtoMessage) error {
+	err := s.protoValidator.Validate(request)
+	if err != nil {
+		var valErr *protovalidate.ValidationError
+
+		if ok := errors.As(err, &valErr); ok {
+			protoErr := valErr.ToProto()
+			grpcError, _ := status.New(codes.Aborted, "Validation failed").WithDetails(protoErr)
+
+			return grpcError.Err()
+		}
+
+		return status.Errorf(codes.Aborted, "%v", err)
+	}
+
+	return nil
+}
+
 // Login logs a user in.
 func (s AuthServer) Login(ctx context.Context, request *apb.LoginRequest) (*apb.LoginResponse, error) {
 	labels := prometheus.Labels{"service": "auth", "rpc": "post"}
 	start := metrics.StartHook(labels, grpcRequestsTotal)
 	defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
 
-	if err := s.protoValidator.Validate(request); err != nil {
-		return nil, status.Errorf(codes.Aborted, "%v", err)
+	err := s.checkForValidationErrors(request)
+	if err != nil {
+		return nil, err
 	}
 
 	user := rbac.User{
@@ -55,7 +76,7 @@ func (s AuthServer) Login(ctx context.Context, request *apb.LoginRequest) (*apb.
 	}
 
 	// validation of credentials
-	err := s.isValidUser(user)
+	err = s.isValidUser(user)
 	if err != nil {
 		return nil, err
 	}
@@ -90,11 +111,12 @@ func (s AuthServer) Logout(ctx context.Context, request *apb.LogoutRequest) (*ap
 	start := metrics.StartHook(labels, grpcRequestsTotal)
 	defer metrics.FinishHook(labels, start, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)
 
-	if err := s.protoValidator.Validate(request); err != nil {
-		return nil, status.Errorf(codes.Aborted, "%v", err)
+	err := s.checkForValidationErrors(request)
+	if err != nil {
+		return nil, err
 	}
 
-	err := s.handleLogout(ctx, request.Username)
+	err = s.handleLogout(ctx, request.Username)
 	if err != nil {
 		return nil, err
 	}
diff --git a/controller/northbound/server/auth_test.go b/controller/northbound/server/auth_test.go
index c2f8dd149..7716ec358 100644
--- a/controller/northbound/server/auth_test.go
+++ b/controller/northbound/server/auth_test.go
@@ -6,6 +6,7 @@ import (
 	"testing"
 	"time"
 
+	"buf.build/gen/go/bufbuild/protovalidate/protocolbuffers/go/buf/validate"
 	apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
 	eventservice "code.fbi.h-da.de/danet/gosdn/controller/eventService"
 	"code.fbi.h-da.de/danet/gosdn/controller/rbac"
@@ -44,10 +45,11 @@ func TestAuth_Login(t *testing.T) {
 		request *apb.LoginRequest
 	}
 	tests := []struct {
-		name    string
-		args    args
-		want    string
-		wantErr bool
+		name             string
+		args             args
+		want             string
+		wantErr          bool
+		validationErrors []*validate.Violation
 	}{
 		{
 			name: "default login",
@@ -71,6 +73,23 @@ func TestAuth_Login(t *testing.T) {
 			},
 			wantErr: true,
 		},
+		{
+			name: "login fail due to missing username",
+			want: "",
+			args: args{
+				request: &apb.LoginRequest{
+					Username: "",
+					Pwd:      "nope",
+				},
+			},
+			wantErr: true,
+			validationErrors: []*validate.Violation{
+				{
+					FieldPath:    "username",
+					ConstraintId: "required",
+					Message:      "value is required",
+				}},
+		},
 	}
 
 	for _, tt := range tests {
@@ -82,6 +101,10 @@ func TestAuth_Login(t *testing.T) {
 				return
 			}
 
+			if tt.wantErr {
+				assertValidationErrors(t, err, tt.validationErrors)
+			}
+
 			if resp != nil {
 				got := resp.Token
 				if got == "" {
@@ -104,10 +127,11 @@ func TestAuth_Logout(t *testing.T) {
 		request *apb.LogoutRequest
 	}
 	tests := []struct {
-		name    string
-		args    args
-		want    *apb.LogoutResponse
-		wantErr bool
+		name             string
+		args             args
+		want             *apb.LogoutResponse
+		wantErr          bool
+		validationErrors []*validate.Violation
 	}{
 		{
 			name: "default log out",
@@ -120,6 +144,23 @@ func TestAuth_Logout(t *testing.T) {
 			want:    &apb.LogoutResponse{},
 			wantErr: false,
 		},
+		{
+			name: "default log out fails due to missing username",
+			args: args{
+				ctx: metadata.NewIncomingContext(context.Background(), metadata.Pairs("authorize", validToken)),
+				request: &apb.LogoutRequest{
+					Username: "",
+				},
+			},
+			want:    &apb.LogoutResponse{},
+			wantErr: true,
+			validationErrors: []*validate.Violation{
+				{
+					FieldPath:    "username",
+					ConstraintId: "required",
+					Message:      "value is required",
+				}},
+		},
 	}
 
 	for _, tt := range tests {
@@ -129,6 +170,10 @@ func TestAuth_Logout(t *testing.T) {
 				t.Errorf("Auth.Logout() error = %v, wantErr %v", err, tt.wantErr)
 				return
 			}
+
+			if tt.wantErr {
+				assertValidationErrors(t, err, tt.validationErrors)
+			}
 		})
 	}
 }
-- 
GitLab