Skip to content
Snippets Groups Projects
jwtManager.go 1.85 KiB
Newer Older
  • Learn to ignore specific revisions
  • package rbac
    
    import (
    	"time"
    
    	"github.com/golang-jwt/jwt"
    	"google.golang.org/grpc/codes"
    	"google.golang.org/grpc/status"
    )
    
    var (
    	// signingMethod is the JWT algorithm (HMAC with SHA-256) that
    	// GenerateToken signs tokens with, keyed by the JWTManager secret.
    	// NOTE(review): an asymmetric method such as jwt.SigningMethodPS256 was
    	// noted as an alternative; it would need an RSA key pair rather than a
    	// shared secret string.
    	signingMethod = jwt.SigningMethodHS256
    )
    
    // JWTManager issues and verifies JWTs using a shared secret and a fixed
    // validity window for newly generated tokens.
    type JWTManager struct {
    	// secretKey is the shared secret used as the signing/verification key.
    	secretKey string

    	// tokenDuration is added to the current time to compute a new token's
    	// expiry.
    	tokenDuration time.Duration
    }
    
    // NewJWTManager builds a JWTManager that signs and verifies tokens with
    // secretKey and issues tokens that expire tokenDuration after creation.
    // The arguments are stored as given; no validation is performed.
    func NewJWTManager(secretKey string, tokenDuration time.Duration) *JWTManager {
    	mgr := new(JWTManager)
    	mgr.secretKey = secretKey
    	mgr.tokenDuration = tokenDuration
    	return mgr
    }
    
    // UserClaims is the token payload: the registered JWT claims (such as the
    // expiry) plus the name of the user the token was generated for.
    type UserClaims struct {
    	jwt.StandardClaims

    	// Username is the user's name, encoded under the "username" key.
    	Username string `json:"username"`
    }
    
    // GenerateToken creates a signed JWT for user. The token carries the user's
    // name and an expiry of now plus the manager's token duration, and it is
    // signed with signingMethod using the manager's secret key. It returns the
    // compact serialized token or the signing error.
    func (man *JWTManager) GenerateToken(user User) (string, error) {
    	expiry := time.Now().Add(man.tokenDuration)

    	payload := UserClaims{
    		StandardClaims: jwt.StandardClaims{ExpiresAt: expiry.Unix()},
    		Username:       user.GetName(),
    	}

    	key := []byte(man.secretKey)
    	return jwt.NewWithClaims(signingMethod, payload).SignedString(key)
    }
    
    // VerifyToken parses accessToken and checks that it was signed with exactly
    // signingMethod using the manager's secret key. The jwt library validates
    // the registered claims (e.g. expiry) during parsing. On success it returns
    // the decoded UserClaims. Every failure is returned as a gRPC status error
    // with code Unauthenticated.
    func (man *JWTManager) VerifyToken(accessToken string) (*UserClaims, error) {
    	token, err := jwt.ParseWithClaims(
    		accessToken,
    		&UserClaims{},
    		func(token *jwt.Token) (interface{}, error) {
    			// Pin the algorithm to the one GenerateToken uses. Accepting any
    			// HMAC variant would also admit HS384/HS512 tokens, and the
    			// []byte key below is only meaningful for HMAC methods.
    			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok || token.Method.Alg() != signingMethod.Alg() {
    				return nil, status.Errorf(codes.Unauthenticated, "unexpected token signing method: %q", token.Method.Alg())
    			}

    			return []byte(man.secretKey), nil
    		},
    	)
    	if err != nil {
    		return nil, status.Errorf(codes.Unauthenticated, "invalid token: %v", err)
    	}

    	// ParseWithClaims reports invalid tokens through err. The Valid check is
    	// kept as defense-in-depth in case that contract ever changes.
    	claims, ok := token.Claims.(*UserClaims)
    	if !ok || !token.Valid {
    		return nil, status.Errorf(codes.Unauthenticated, "invalid token claims: %T", token.Claims)
    	}

    	return claims, nil
    }