Skip to content
Snippets Groups Projects
handlers.go 52.1 KiB
Newer Older
  • Learn to ignore specific revisions
  • Eric Chiang's avatar
    Eric Chiang committed
    		for _, scope := range authCode.Scopes {
    			if scope == scopeOfflineAccess {
    				return true
    			}
    		}
    		return false
    	}()
    	var refreshToken string
    	if reqRefresh {
    
    Eric Chiang's avatar
    Eric Chiang committed
    		refresh := storage.RefreshToken{
    
    			ID:            storage.NewID(),
    			Token:         storage.NewID(),
    			ClientID:      authCode.ClientID,
    			ConnectorID:   authCode.ConnectorID,
    			Scopes:        authCode.Scopes,
    			Claims:        authCode.Claims,
    			Nonce:         authCode.Nonce,
    			ConnectorData: authCode.ConnectorData,
    			CreatedAt:     s.now(),
    			LastUsed:      s.now(),
    
    Eric Chiang's avatar
    Eric Chiang committed
    		}
    
    		token := &internal.RefreshToken{
    			RefreshId: refresh.ID,
    			Token:     refresh.Token,
    		}
    		if refreshToken, err = internal.Marshal(token); err != nil {
    
    			s.logger.ErrorContext(ctx, "failed to marshal refresh token", "err", err)
    
    			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    
    			return nil, err
    
    		if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
    
    			s.logger.ErrorContext(ctx, "failed to create refresh token", "err", err)
    
    			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    
    			return nil, err
    
    Eric Chiang's avatar
    Eric Chiang committed
    		}
    
    
    		// deleteToken determines if we need to delete the newly created refresh token
    		// due to a failure in updating/creating the OfflineSession object for the
    		// corresponding user.
    		var deleteToken bool
    		defer func() {
    			if deleteToken {
    				// Delete newly created refresh token from storage.
    
    				if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
    
    					s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
    
    					s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    					return
    				}
    			}
    		}()
    
    		tokenRef := storage.RefreshTokenRef{
    			ID:        refresh.ID,
    			ClientID:  refresh.ClientID,
    			CreatedAt: refresh.CreatedAt,
    			LastUsed:  refresh.LastUsed,
    		}
    
    		// Try to retrieve an existing OfflineSession object for the corresponding user.
    
    		if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil {
    
    			if err != storage.ErrNotFound {
    
    				s.logger.ErrorContext(ctx, "failed to get offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    
    				return nil, err
    
    			}
    			offlineSessions := storage.OfflineSessions{
    				UserID:  refresh.Claims.UserID,
    				ConnID:  refresh.ConnectorID,
    				Refresh: make(map[string]*storage.RefreshTokenRef),
    			}
    			offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
    
    			// Create a new OfflineSession object for the user and add a reference object for
    
    			// the newly received refreshtoken.
    
    			if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
    
    				s.logger.ErrorContext(ctx, "failed to create offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    
    				return nil, err
    
    			}
    		} else {
    			if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
    				// Delete old refresh token from storage.
    
    				if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil && err != storage.ErrNotFound {
    
    					s.logger.ErrorContext(ctx, "failed to delete refresh token", "err", err)
    
    					s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    					deleteToken = true
    
    					return nil, err
    
    				}
    			}
    
    			// Update existing OfflineSession obj with new RefreshTokenRef.
    
    			if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
    
    				old.Refresh[tokenRef.ClientID] = &tokenRef
    				return old, nil
    			}); err != nil {
    
    				s.logger.ErrorContext(ctx, "failed to update offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    
    				return nil, err
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    	return s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry), nil
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    
    func (s *Server) handleUserInfo(w http.ResponseWriter, r *http.Request) {
    
    	const prefix = "Bearer "
    
    	auth := r.Header.Get("authorization")
    	if len(auth) < len(prefix) || !strings.EqualFold(prefix, auth[:len(prefix)]) {
    		w.Header().Set("WWW-Authenticate", "Bearer")
    		s.tokenErrHelper(w, errAccessDenied, "Invalid bearer token.", http.StatusUnauthorized)
    
    	rawIDToken := auth[len(prefix):]
    
    	verifier := oidc.NewVerifier(s.issuerURL.String(), &storageKeySet{s.storage}, &oidc.Config{SkipClientIDCheck: true})
    
    	idToken, err := verifier.Verify(ctx, rawIDToken)
    
    	if err != nil {
    
    		s.tokenErrHelper(w, errAccessDenied, err.Error(), http.StatusForbidden)
    
    	var claims json.RawMessage
    	if err := idToken.Claims(&claims); err != nil {
    		s.tokenErrHelper(w, errServerError, err.Error(), http.StatusInternalServerError)
    		return
    
    	w.Header().Set("Content-Type", "application/json")
    	w.Write(claims)
    
    func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, client storage.Client) {
    
    	// Parse the fields
    	if err := r.ParseForm(); err != nil {
    		s.tokenErrHelper(w, errInvalidRequest, "Couldn't parse data", http.StatusBadRequest)
    		return
    	}
    	q := r.Form
    
    	nonce := q.Get("nonce")
    	// Some clients, like the old go-oidc, provide extra whitespace. Tolerate this.
    	scopes := strings.Fields(q.Get("scope"))
    
    	// Parse the scopes if they are passed
    	var (
    		unrecognized  []string
    		invalidScopes []string
    	)
    	hasOpenIDScope := false
    	for _, scope := range scopes {
    		switch scope {
    		case scopeOpenID:
    			hasOpenIDScope = true
    		case scopeOfflineAccess, scopeEmail, scopeProfile, scopeGroups, scopeFederatedID:
    		default:
    			peerID, ok := parseCrossClientScope(scope)
    			if !ok {
    				unrecognized = append(unrecognized, scope)
    				continue
    			}
    
    
    			isTrusted, err := s.validateCrossClientTrust(ctx, client.ID, peerID)
    
    			if err != nil {
    				s.tokenErrHelper(w, errInvalidClient, fmt.Sprintf("Error validating cross client trust %v.", err), http.StatusBadRequest)
    				return
    			}
    			if !isTrusted {
    				invalidScopes = append(invalidScopes, scope)
    			}
    		}
    	}
    	if !hasOpenIDScope {
    		s.tokenErrHelper(w, errInvalidRequest, `Missing required scope(s) ["openid"].`, http.StatusBadRequest)
    		return
    	}
    	if len(unrecognized) > 0 {
    		s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Unrecognized scope(s) %q", unrecognized), http.StatusBadRequest)
    		return
    	}
    	if len(invalidScopes) > 0 {
    		s.tokenErrHelper(w, errInvalidRequest, fmt.Sprintf("Client can't request scope(s) %q", invalidScopes), http.StatusBadRequest)
    		return
    	}
    
    	// Which connector
    	connID := s.passwordConnector
    
    	conn, err := s.getConnector(ctx, connID)
    
    	if err != nil {
    		s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
    		return
    	}
    
    	passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
    	if !ok {
    		s.tokenErrHelper(w, errInvalidRequest, "Requested password connector does not correct type.", http.StatusBadRequest)
    		return
    	}
    
    	// Login
    	username := q.Get("username")
    	password := q.Get("password")
    
    	identity, ok, err := passwordConnector.Login(ctx, parseScopes(scopes), username, password)
    
    	if err != nil {
    
    		s.logger.ErrorContext(r.Context(), "failed to login user", "err", err)
    
    		s.tokenErrHelper(w, errInvalidRequest, "Could not login user", http.StatusBadRequest)
    		return
    	}
    	if !ok {
    		s.tokenErrHelper(w, errAccessDenied, "Invalid username or password", http.StatusUnauthorized)
    		return
    	}
    
    	// Build the claims to send the id token
    	claims := storage.Claims{
    		UserID:            identity.UserID,
    		Username:          identity.Username,
    		PreferredUsername: identity.PreferredUsername,
    		Email:             identity.Email,
    		EmailVerified:     identity.EmailVerified,
    		Groups:            identity.Groups,
    	}
    
    
    	accessToken, _, err := s.newAccessToken(ctx, client.ID, claims, scopes, nonce, connID)
    
    		s.logger.ErrorContext(r.Context(), "password grant failed to create new access token", "err", err)
    
    		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    		return
    	}
    
    
    	idToken, expiry, err := s.newIDToken(ctx, client.ID, claims, scopes, nonce, accessToken, "", connID)
    
    	if err != nil {
    
    		s.logger.ErrorContext(r.Context(), "password grant failed to create new ID token", "err", err)
    
    		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    
    		return
    	}
    
    	reqRefresh := func() bool {
    		// Ensure the connector supports refresh tokens.
    		//
    		// Connectors like `saml` do not implement RefreshConnector.
    		_, ok := conn.Connector.(connector.RefreshConnector)
    		if !ok {
    			return false
    		}
    
    		for _, scope := range scopes {
    			if scope == scopeOfflineAccess {
    				return true
    			}
    		}
    		return false
    	}()
    	var refreshToken string
    	if reqRefresh {
    		refresh := storage.RefreshToken{
    			ID:          storage.NewID(),
    			Token:       storage.NewID(),
    
    			ClientID:    client.ID,
    
    			ConnectorID: connID,
    			Scopes:      scopes,
    			Claims:      claims,
    			Nonce:       nonce,
    			// ConnectorData: authCode.ConnectorData,
    			CreatedAt: s.now(),
    			LastUsed:  s.now(),
    		}
    		token := &internal.RefreshToken{
    			RefreshId: refresh.ID,
    			Token:     refresh.Token,
    		}
    		if refreshToken, err = internal.Marshal(token); err != nil {
    
    			s.logger.ErrorContext(r.Context(), "failed to marshal refresh token", "err", err)
    
    			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    			return
    		}
    
    
    		if err := s.storage.CreateRefresh(ctx, refresh); err != nil {
    
    			s.logger.ErrorContext(r.Context(), "failed to create refresh token", "err", err)
    
    			s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    			return
    		}
    
    		// deleteToken determines if we need to delete the newly created refresh token
    		// due to a failure in updating/creating the OfflineSession object for the
    		// corresponding user.
    		var deleteToken bool
    		defer func() {
    			if deleteToken {
    				// Delete newly created refresh token from storage.
    
    				if err := s.storage.DeleteRefresh(ctx, refresh.ID); err != nil {
    
    					s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err)
    
    					s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    					return
    				}
    			}
    		}()
    
    		tokenRef := storage.RefreshTokenRef{
    			ID:        refresh.ID,
    			ClientID:  refresh.ClientID,
    			CreatedAt: refresh.CreatedAt,
    			LastUsed:  refresh.LastUsed,
    		}
    
    		// Try to retrieve an existing OfflineSession object for the corresponding user.
    
    		if session, err := s.storage.GetOfflineSessions(ctx, refresh.Claims.UserID, refresh.ConnectorID); err != nil {
    
    			if err != storage.ErrNotFound {
    
    				s.logger.ErrorContext(r.Context(), "failed to get offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    				return
    			}
    			offlineSessions := storage.OfflineSessions{
    
    				UserID:        refresh.Claims.UserID,
    				ConnID:        refresh.ConnectorID,
    				Refresh:       make(map[string]*storage.RefreshTokenRef),
    				ConnectorData: identity.ConnectorData,
    
    			}
    			offlineSessions.Refresh[tokenRef.ClientID] = &tokenRef
    
    			// Create a new OfflineSession object for the user and add a reference object for
    			// the newly received refreshtoken.
    
    			if err := s.storage.CreateOfflineSessions(ctx, offlineSessions); err != nil {
    
    				s.logger.ErrorContext(r.Context(), "failed to create offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    				return
    			}
    		} else {
    			if oldTokenRef, ok := session.Refresh[tokenRef.ClientID]; ok {
    				// Delete old refresh token from storage.
    
    				if err := s.storage.DeleteRefresh(ctx, oldTokenRef.ID); err != nil {
    
    						s.logger.Warn("database inconsistent, refresh token missing", "token_id", oldTokenRef.ID)
    
    						s.logger.ErrorContext(r.Context(), "failed to delete refresh token", "err", err)
    
    						s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    						deleteToken = true
    						return
    					}
    
    				}
    			}
    
    			// Update existing OfflineSession obj with new RefreshTokenRef.
    
    			if err := s.storage.UpdateOfflineSessions(ctx, session.UserID, session.ConnID, func(old storage.OfflineSessions) (storage.OfflineSessions, error) {
    
    				old.Refresh[tokenRef.ClientID] = &tokenRef
    
    				old.ConnectorData = identity.ConnectorData
    
    				return old, nil
    			}); err != nil {
    
    				s.logger.ErrorContext(r.Context(), "failed to update offline session", "err", err)
    
    				s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    				deleteToken = true
    				return
    			}
    		}
    	}
    
    
    	resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry)
    	s.writeAccessToken(w, resp)
    
    func (s *Server) handleTokenExchange(w http.ResponseWriter, r *http.Request, client storage.Client) {
    	ctx := r.Context()
    
    	if err := r.ParseForm(); err != nil {
    
    		s.logger.ErrorContext(r.Context(), "could not parse request body", "err", err)
    
    		s.tokenErrHelper(w, errInvalidRequest, "", http.StatusBadRequest)
    		return
    	}
    	q := r.Form
    
    	scopes := strings.Fields(q.Get("scope"))            // OPTIONAL, map to issued token scope
    	requestedTokenType := q.Get("requested_token_type") // OPTIONAL, default to access token
    	if requestedTokenType == "" {
    		requestedTokenType = tokenTypeAccess
    	}
    	subjectToken := q.Get("subject_token")          // REQUIRED
    	subjectTokenType := q.Get("subject_token_type") // REQUIRED
    	connID := q.Get("connector_id")                 // REQUIRED, not in RFC
    
    	switch subjectTokenType {
    	case tokenTypeID, tokenTypeAccess: // ok, continue
    	default:
    		s.tokenErrHelper(w, errRequestNotSupported, "Invalid subject_token_type.", http.StatusBadRequest)
    		return
    	}
    
    	if subjectToken == "" {
    		s.tokenErrHelper(w, errInvalidRequest, "Missing subject_token", http.StatusBadRequest)
    		return
    	}
    
    
    	conn, err := s.getConnector(ctx, connID)
    
    	if err != nil {
    
    		s.logger.ErrorContext(r.Context(), "failed to get connector", "err", err)
    
    		s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
    		return
    	}
    	teConn, ok := conn.Connector.(connector.TokenIdentityConnector)
    	if !ok {
    
    		s.logger.ErrorContext(r.Context(), "connector doesn't implement token exchange", "connector_id", connID)
    
    		s.tokenErrHelper(w, errInvalidRequest, "Requested connector does not exist.", http.StatusBadRequest)
    		return
    	}
    	identity, err := teConn.TokenIdentity(ctx, subjectTokenType, subjectToken)
    	if err != nil {
    
    		s.logger.ErrorContext(r.Context(), "failed to verify subject token", "err", err)
    
    		s.tokenErrHelper(w, errAccessDenied, "", http.StatusUnauthorized)
    		return
    	}
    
    	claims := storage.Claims{
    		UserID:            identity.UserID,
    		Username:          identity.Username,
    		PreferredUsername: identity.PreferredUsername,
    		Email:             identity.Email,
    		EmailVerified:     identity.EmailVerified,
    		Groups:            identity.Groups,
    	}
    	resp := accessTokenResponse{
    		IssuedTokenType: requestedTokenType,
    		TokenType:       "bearer",
    	}
    	var expiry time.Time
    	switch requestedTokenType {
    	case tokenTypeID:
    
    		resp.AccessToken, expiry, err = s.newIDToken(r.Context(), client.ID, claims, scopes, "", "", "", connID)
    
    	case tokenTypeAccess:
    
    		resp.AccessToken, expiry, err = s.newAccessToken(r.Context(), client.ID, claims, scopes, "", connID)
    
    	default:
    		s.tokenErrHelper(w, errRequestNotSupported, "Invalid requested_token_type.", http.StatusBadRequest)
    		return
    	}
    	if err != nil {
    
    		s.logger.ErrorContext(r.Context(), "token exchange failed to create new token", "requested_token_type", requestedTokenType, "err", err)
    
    		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    		return
    	}
    	resp.ExpiresIn = int(time.Until(expiry).Seconds())
    
    	// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
    	w.Header().Set("Cache-Control", "no-store")
    	w.Header().Set("Pragma", "no-cache")
    	w.Header().Set("Content-Type", "application/json")
    	json.NewEncoder(w).Encode(resp)
    }
    
    
    Josh Soref's avatar
    Josh Soref committed
    type accessTokenResponse struct {
    
    	AccessToken     string `json:"access_token"`
    	IssuedTokenType string `json:"issued_token_type,omitempty"`
    	TokenType       string `json:"token_type"`
    	ExpiresIn       int    `json:"expires_in,omitempty"`
    	RefreshToken    string `json:"refresh_token,omitempty"`
    	IDToken         string `json:"id_token,omitempty"`
    	Scope           string `json:"scope,omitempty"`
    
    Josh Soref's avatar
    Josh Soref committed
    func (s *Server) toAccessTokenResponse(idToken, accessToken, refreshToken string, expiry time.Time) *accessTokenResponse {
    	return &accessTokenResponse{
    
    		AccessToken:  accessToken,
    		TokenType:    "bearer",
    		ExpiresIn:    int(expiry.Sub(s.now()).Seconds()),
    		RefreshToken: refreshToken,
    		IDToken:      idToken,
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    Josh Soref's avatar
    Josh Soref committed
    func (s *Server) writeAccessToken(w http.ResponseWriter, resp *accessTokenResponse) {
    
    Eric Chiang's avatar
    Eric Chiang committed
    	data, err := json.Marshal(resp)
    	if err != nil {
    
    		// TODO(nabokihms): error with context
    
    		s.logger.Error("failed to marshal access token response", "err", err)
    
    		s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
    
    Eric Chiang's avatar
    Eric Chiang committed
    		return
    	}
    	w.Header().Set("Content-Type", "application/json")
    	w.Header().Set("Content-Length", strconv.Itoa(len(data)))
    
    
    	// Token response must include cache headers https://tools.ietf.org/html/rfc6749#section-5.1
    	w.Header().Set("Cache-Control", "no-store")
    	w.Header().Set("Pragma", "no-cache")
    
    Eric Chiang's avatar
    Eric Chiang committed
    	w.Write(data)
    }
    
    
    func (s *Server) renderError(r *http.Request, w http.ResponseWriter, status int, description string) {
    	if err := s.templates.err(r, w, status, description); err != nil {
    
    		s.logger.ErrorContext(r.Context(), "server template error", "err", err)
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    
    func (s *Server) tokenErrHelper(w http.ResponseWriter, typ string, description string, statusCode int) {
    	if err := tokenErr(w, typ, description, statusCode); err != nil {
    
    		// TODO(nabokihms): error with context
    
    		s.logger.Error("token error response", "err", err)
    
    
    // Check for username prompt override from connector. Defaults to "Username".
    func usernamePrompt(conn connector.PasswordConnector) string {
    	if attr := conn.Prompt(); attr != "" {
    		return attr
    	}
    	return "Username"
    }