From affd4d4e495326ba5aeee81bb5ea88d22c39b419 Mon Sep 17 00:00:00 2001
From: Sean Liao <sean+git@liao.dev>
Date: Tue, 1 Aug 2023 09:32:16 +0100
Subject: [PATCH] verify access tokens by checking getuserinfo during a token
 exchange (#3031)

The provider.Verifier.Verify endpoint we were using only works with ID
tokens. This isn't an issue with systems which use ID tokens as access
tokens (e.g. dex), but for systems with opaque access tokens (e.g.
Google / GCP), those access tokens could not be verified.
Instead, check the access token against the getUserInfo endpoint.

Signed-off-by: Sean Liao <sean+git@liao.dev>
Co-authored-by: Maksim Nabokikh <max.nabokih@gmail.com>
---
 connector/oidc/oidc.go      | 27 +++++++++++++++++++--------
 connector/oidc/oidc_test.go |  5 +++++
 2 files changed, 24 insertions(+), 8 deletions(-)

diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go
index 14329c00..ff4713c2 100644
--- a/connector/oidc/oidc.go
+++ b/connector/oidc/oidc.go
@@ -301,6 +301,7 @@ func (c *oidcConnector) TokenIdentity(ctx context.Context, subjectTokenType, sub
 	var identity connector.Identity
 	token := &oauth2.Token{
 		AccessToken: subjectToken,
+		TokenType:   subjectTokenType,
 	}
 	return c.createIdentity(ctx, identity, token, exchangeCaller)
 }
@@ -318,20 +319,30 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
 			return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
 		}
 	} else if caller == exchangeCaller {
-		// AccessToken here could be either an id token or an access token
-		idToken, err := c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, token.AccessToken)
-		if err != nil {
-			return identity, fmt.Errorf("oidc: failed to verify token: %v", err)
-		}
-		if err := idToken.Claims(&claims); err != nil {
-			return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
+		switch token.TokenType {
+		case "urn:ietf:params:oauth:token-type:id_token":
+			// Verify only works on ID tokens
+			idToken, err := c.provider.Verifier(&oidc.Config{SkipClientIDCheck: true}).Verify(ctx, token.AccessToken)
+			if err != nil {
+				return identity, fmt.Errorf("oidc: failed to verify token: %v", err)
+			}
+			if err := idToken.Claims(&claims); err != nil {
+				return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
+			}
+		case "urn:ietf:params:oauth:token-type:access_token":
+			if !c.getUserInfo {
+				return identity, fmt.Errorf("oidc: getUserInfo is required for access token exchange")
+			}
+		default:
+			return identity, fmt.Errorf("unknown token type for token exchange: %s", token.TokenType)
 		}
 	} else if caller != refreshCaller {
 		// ID tokens aren't mandatory in the reply when using a refresh_token grant
 		return identity, errors.New("oidc: no id_token in token response")
 	}
 
-	// We immediately want to run getUserInfo if configured before we validate the claims
+	// We immediately want to run getUserInfo if configured before we validate the claims.
+	// For token exchanges with access tokens, this is how we verify the token.
 	if c.getUserInfo {
 		userInfo, err := c.provider.UserInfo(ctx, oauth2.StaticTokenSource(token))
 		if err != nil {
diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go
index 5c5208a6..29e8875e 100644
--- a/connector/oidc/oidc_test.go
+++ b/connector/oidc/oidc_test.go
@@ -441,6 +441,7 @@ func TestTokenIdentity(t *testing.T) {
 		name        string
 		subjectType string
 		userInfo    bool
+		expectError bool
 	}{
 		{
 			name:        "id_token",
@@ -448,6 +449,7 @@ func TestTokenIdentity(t *testing.T) {
 		}, {
 			name:        "access_token",
 			subjectType: tokenTypeAccess,
+			expectError: true,
 		}, {
 			name:        "id_token with user info",
 			subjectType: tokenTypeID,
@@ -494,6 +496,9 @@ func TestTokenIdentity(t *testing.T) {
 			origToken := tokenResponse[long2short[tc.subjectType]].(string)
 			identity, err := conn.TokenIdentity(ctx, tc.subjectType, origToken)
 			if err != nil {
+				if tc.expectError {
+					return
+				}
 				t.Fatal("failed to get token identity", err)
 			}
 
-- 
GitLab