Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
auth.go 5.60 KiB
package server

import (
	"context"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"time"

	apb "code.fbi.h-da.de/danet/gosdn/api/go/gosdn/rbac"
	rbacInterfaces "code.fbi.h-da.de/danet/gosdn/controller/interfaces/rbac"
	"code.fbi.h-da.de/danet/gosdn/controller/metrics"
	"code.fbi.h-da.de/danet/gosdn/controller/rbac"
	"code.fbi.h-da.de/danet/gosdn/controller/store"
	"github.com/bufbuild/protovalidate-go"
	"github.com/prometheus/client_golang/prometheus"
	"golang.org/x/crypto/argon2"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/metadata"
	"google.golang.org/grpc/status"
	"google.golang.org/protobuf/reflect/protoreflect"
)

// AuthServer implements the gRPC AuthServiceServer (Login and Logout).
//
// It embeds apb.UnimplementedAuthServiceServer, so any RPC of the service that
// is not defined on AuthServer falls back to the generated unimplemented stub.
// The remaining fields are the dependencies used by the handlers:
//   - jwtManager generates tokens on Login and parses token claims on Logout.
//   - userService looks up stored users and persists their token lists.
//   - protoValidator checks incoming requests before they are processed.
type AuthServer struct {
	apb.UnimplementedAuthServiceServer
	jwtManager     *rbac.JWTManager
	userService    rbacInterfaces.UserService
	protoValidator *protovalidate.Validator
}

// NewAuthServer wires the JWT manager, the user service and the request
// validator into a new AuthServer and returns a pointer to it.
func NewAuthServer(
	jwtManager *rbac.JWTManager,
	userService rbacInterfaces.UserService,
	protoValidator *protovalidate.Validator,
) *AuthServer {
	server := new(AuthServer)
	server.jwtManager = jwtManager
	server.userService = userService
	server.protoValidator = protoValidator

	return server
}

// checkForValidationErrors validates request against its protovalidate rules.
//
// It returns nil when the request is valid. A *protovalidate.ValidationError
// becomes a codes.Aborted gRPC status carrying the violations as details; any
// other validator error becomes a codes.Aborted status with the error's text.
func (s AuthServer) checkForValidationErrors(request protoreflect.ProtoMessage) error {
	err := s.protoValidator.Validate(request)
	if err == nil {
		return nil
	}

	var valErr *protovalidate.ValidationError
	if !errors.As(err, &valErr) {
		return status.Errorf(codes.Aborted, "%v", err)
	}

	grpcStatus, detailErr := status.New(codes.Aborted, "Validation failed").WithDetails(valErr.ToProto())
	if detailErr != nil {
		// WithDetails returns a nil *Status on failure, and a nil *Status's Err()
		// is nil. Ignoring detailErr would therefore let an invalid request pass;
		// fall back to a status without details so the rejection is never lost.
		return status.Errorf(codes.Aborted, "Validation failed: %v", err)
	}

	return grpcStatus.Err()
}

// Login authenticates the user named in the request. On success it asks the JWT
// manager for a new token, appends that token to the stored user's tokens,
// persists the user and returns the token with the current Unix time in
// nanoseconds.
func (s AuthServer) Login(ctx context.Context, request *apb.LoginRequest) (*apb.LoginResponse, error) {
	rpcLabels := prometheus.Labels{"service": "auth", "rpc": "post"}
	startedAt := metrics.StartHook(rpcLabels, grpcRequestsTotal)
	defer metrics.FinishHook(rpcLabels, startedAt, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)

	if err := s.checkForValidationErrors(request); err != nil {
		return nil, err
	}

	credentials := rbac.User{UserName: request.Username, Password: request.Pwd}

	// Reject unknown users and wrong passwords before issuing anything.
	if err := s.isValidUser(credentials); err != nil {
		return nil, err
	}

	token, err := s.jwtManager.GenerateToken(credentials)
	if err != nil {
		return nil, err
	}

	// Record the new session token on the stored user.
	storedUser, err := s.userService.Get(store.Query{Name: credentials.UserName})
	if err != nil {
		return nil, err
	}
	storedUser.AddToken(token)

	if err := s.userService.Update(storedUser); err != nil {
		return nil, err
	}

	response := &apb.LoginResponse{
		Token:     token,
		Timestamp: time.Now().UnixNano(),
	}
	return response, nil
}

// Logout validates the request and then delegates to handleLogout, which
// matches the token in the call metadata against the given user name and
// removes it from that user's stored tokens. On success it returns the current
// Unix time in nanoseconds.
func (s AuthServer) Logout(ctx context.Context, request *apb.LogoutRequest) (*apb.LogoutResponse, error) {
	rpcLabels := prometheus.Labels{"service": "auth", "rpc": "post"}
	startedAt := metrics.StartHook(rpcLabels, grpcRequestsTotal)
	defer metrics.FinishHook(rpcLabels, startedAt, grpcRequestDurationSecondsTotal, grpcRequestDurationSeconds)

	if err := s.checkForValidationErrors(request); err != nil {
		return nil, err
	}

	if err := s.handleLogout(ctx, request.Username); err != nil {
		return nil, err
	}

	return &apb.LogoutResponse{Timestamp: time.Now().UnixNano()}, nil
}

// isValidUser checks the provided credentials against the stored user.
//
// It looks up the stored user by name, requires the stored name to match the
// provided one exactly, and then verifies the password via isCorrectPassword.
// It fails closed: a name mismatch is reported as codes.Unauthenticated rather
// than silently skipping the password check.
func (s AuthServer) isValidUser(user rbac.User) error {
	storedUser, err := s.userService.Get(store.Query{Name: user.Name()})
	if err != nil {
		// NOTE(review): the store error is returned as-is, so callers may be able
		// to distinguish "unknown user" from "wrong password" (user enumeration).
		// Confirm whether this should be mapped to a generic Unauthenticated error.
		return err
	}

	// Never treat a lookup that returned a differently named user as success;
	// the original code skipped the password check in that case and returned nil.
	if storedUser.Name() != user.Name() {
		return status.Errorf(codes.Unauthenticated, "incorrect user name or password")
	}

	return s.isCorrectPassword(storedUser.GetPassword(), storedUser.GetSalt(), user.Password)
}

// isCorrectPassword hashes loginPassword with argon2id using salt and compares
// the unpadded standard-base64 encoding of the result with storedPassword.
//
// It returns nil on a match and a codes.Unauthenticated status otherwise.
func (s AuthServer) isCorrectPassword(storedPassword, salt, loginPassword string) error {
	// argon2id parameters: time=1, memory=64*1024 KiB (64 MiB), threads=4,
	// key length=32 bytes. NOTE(review): these presumably must match the
	// parameters used when the stored hash was created; confirm before changing.
	hashedPasswordFromLogin := base64.RawStdEncoding.EncodeToString(
		argon2.IDKey([]byte(loginPassword), []byte(salt), 1, 64*1024, 4, 32),
	)

	// Constant-time comparison so response timing does not reveal how much of
	// the hash matched.
	if subtle.ConstantTimeCompare([]byte(storedPassword), []byte(hashedPasswordFromLogin)) == 1 {
		return nil
	}

	return status.Errorf(codes.Unauthenticated, "incorrect user name or password")
}

// handleLogout removes the token sent in the "authorize" metadata entry from
// the stored tokens of userName and persists the updated user.
//
// It returns a codes.Aborted status when no incoming metadata is present, when
// the token's claims name a different user, or when the token is not among the
// user's stored tokens. Errors from parsing the token and from the user store
// are returned unchanged.
//
// NOTE(review): when metadata is present but carries no "authorize" entry,
// nothing is removed and nil is returned, so Logout reports success. Confirm
// this is intended.
func (s AuthServer) handleLogout(ctx context.Context, userName string) error {
	md, ok := metadata.FromIncomingContext(ctx)
	if !ok {
		return status.Errorf(codes.Aborted, "metadata is not provided")
	}

	authValues := md["authorize"]
	if len(authValues) == 0 {
		return nil
	}
	token := authValues[0]

	claims, err := s.jwtManager.GetClaimsFromToken(token)
	if err != nil {
		return err
	}

	// A caller may only log out the user its token was issued for.
	if claims.Username != userName {
		return status.Errorf(codes.Aborted, "missing match of user associated to token and provided user name")
	}

	storedUser, err := s.userService.Get(store.Query{Name: userName})
	if err != nil {
		return err
	}

	foundToken := false
	for _, storedToken := range storedUser.GetTokens() {
		if storedToken == token {
			foundToken = true
			break
		}
	}
	if !foundToken {
		return status.Errorf(codes.Aborted, "missing match of token provided for user")
	}

	storedUser.RemoveToken(token)

	// Persist the user object returned by the store directly, as Login does.
	// Rebuilding an rbac.User field by field would drop any field not copied.
	return s.userService.Update(storedUser)
}