Skip to content
Snippets Groups Projects
Unverified Commit 665a5b62 authored by Maksim Nabokikh's avatar Maksim Nabokikh Committed by GitHub
Browse files

Override OIDC provider discovered claims (#3267)

parent 231a97d0
No related branches found
No related tags found
No related merge requests found
......@@ -27,6 +27,10 @@ type Config struct {
ClientSecret string `json:"clientSecret"`
RedirectURI string `json:"redirectURI"`
// The section to override options discovered automatically from
// the providers' discovery URL (.well-known/openid-configuration).
ProviderDiscoveryOverrides ProviderDiscoveryOverrides `json:"providerDiscoveryOverrides"`
// Causes client_secret to be passed as POST parameters instead of basic
// auth. This is specifically "NOT RECOMMENDED" by the OAuth2 RFC, but some
// providers require it.
......@@ -96,6 +100,61 @@ type Config struct {
} `json:"claimModifications"`
}
type ProviderDiscoveryOverrides struct {
// TokenURL provides a way to user overwrite the Token URL
// from the .well-known/openid-configuration token_endpoint
TokenURL string `json:"tokenURL"`
// AuthURL provides a way to user overwrite the Auth URL
// from the .well-known/openid-configuration authorization_endpoint
AuthURL string `json:"authURL"`
}
func (o *ProviderDiscoveryOverrides) Empty() bool {
return o.TokenURL == "" && o.AuthURL == ""
}
func getProvider(ctx context.Context, issuer string, overrides ProviderDiscoveryOverrides) (*oidc.Provider, error) {
provider, err := oidc.NewProvider(ctx, issuer)
if err != nil {
return nil, fmt.Errorf("failed to get provider: %v", err)
}
if overrides.Empty() {
return provider, nil
}
v := &struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
DeviceAuthURL string `json:"device_authorization_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
Algorithms []string `json:"id_token_signing_alg_values_supported"`
}{}
if err := provider.Claims(v); err != nil {
return nil, fmt.Errorf("failed to extract provider discovery claims: %v", err)
}
config := oidc.ProviderConfig{
IssuerURL: v.Issuer,
AuthURL: v.AuthURL,
TokenURL: v.TokenURL,
DeviceAuthURL: v.DeviceAuthURL,
JWKSURL: v.JWKSURL,
UserInfoURL: v.UserInfoURL,
Algorithms: v.Algorithms,
}
if overrides.TokenURL != "" {
config.TokenURL = overrides.TokenURL
}
if overrides.AuthURL != "" {
config.AuthURL = overrides.AuthURL
}
return config.NewProvider(context.Background()), nil
}
// NewGroupFromClaims creates a new group from a list of claims and appends it to the list of existing groups.
type NewGroupFromClaims struct {
// List of claim to join together
......@@ -152,13 +211,16 @@ func (c *Config) Open(id string, logger log.Logger) (conn connector.Connector, e
return nil, err
}
ctx, cancel := context.WithCancel(context.Background())
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
bgctx, cancel := context.WithCancel(context.Background())
ctx := context.WithValue(bgctx, oauth2.HTTPClient, httpClient)
provider, err := oidc.NewProvider(ctx, c.Issuer)
provider, err := getProvider(ctx, c.Issuer, c.ProviderDiscoveryOverrides)
if err != nil {
cancel()
return nil, fmt.Errorf("failed to get provider: %v", err)
return nil, err
}
if !c.ProviderDiscoveryOverrides.Empty() {
logger.Warnf("overrides for connector %q are set, this can be a vulnerability when not properly configured", id)
}
endpoint := provider.Endpoint()
......
......@@ -584,6 +584,57 @@ func TestTokenIdentity(t *testing.T) {
}
}
func TestProviderOverride(t *testing.T) {
testServer, err := setupServer(map[string]any{
"sub": "subvalue",
"name": "namevalue",
}, true)
if err != nil {
t.Fatal("failed to setup test server", err)
}
t.Run("No override", func(t *testing.T) {
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
})
if err != nil {
t.Fatal("failed to create new connector", err)
}
expAuth := fmt.Sprintf("%s/authorize", testServer.URL)
if conn.provider.Endpoint().AuthURL != expAuth {
t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth)
}
expToken := fmt.Sprintf("%s/token", testServer.URL)
if conn.provider.Endpoint().TokenURL != expToken {
t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken)
}
})
t.Run("Override", func(t *testing.T) {
conn, err := newConnector(Config{
Issuer: testServer.URL,
Scopes: []string{"openid", "groups"},
ProviderDiscoveryOverrides: ProviderDiscoveryOverrides{TokenURL: "/test1", AuthURL: "/test2"},
})
if err != nil {
t.Fatal("failed to create new connector", err)
}
expAuth := "/test2"
if conn.provider.Endpoint().AuthURL != expAuth {
t.Fatalf("unexpected auth URL: %s, expected: %s\n", conn.provider.Endpoint().AuthURL, expAuth)
}
expToken := "/test1"
if conn.provider.Endpoint().TokenURL != expToken {
t.Fatalf("unexpected token URL: %s, expected: %s\n", conn.provider.Endpoint().TokenURL, expToken)
}
})
}
func setupServer(tok map[string]interface{}, idTokenDesired bool) (*httptest.Server, error) {
key, err := rsa.GenerateKey(rand.Reader, 1024)
if err != nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment