diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index 4bf423abc564cdd3c8766f041a20b51b2b340dc5..2c59d85f775f944ca1ac4ea1cab735c1efbadc8f 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -9,6 +9,7 @@ import ( "log/slog" "net/http" "net/url" + "regexp" "strings" "time" @@ -97,6 +98,7 @@ type Config struct { // ClaimMutations holds all claim mutations options ClaimMutations struct { NewGroupFromClaims []NewGroupFromClaims `json:"newGroupFromClaims"` + FilterGroupClaims FilterGroupClaims `json:"filterGroupClaims"` } `json:"claimModifications"` } @@ -176,6 +178,12 @@ type NewGroupFromClaims struct { Prefix string `json:"prefix"` } +// FilterGroupClaims is a regex filter for to keep only the matching groups. +// This is useful when the groups list is too large to fit within an HTTP header. +type FilterGroupClaims struct { + GroupsFilter string `json:"groupsFilter"` +} + // Domains that don't support basic auth. golang.org/x/oauth2 has an internal // list, but it only matches specific URLs, not top level domains. var brokenAuthHeaderDomains = []string{ @@ -252,6 +260,14 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, promptType = *c.PromptType } + var groupsFilter *regexp.Regexp + if c.ClaimMutations.FilterGroupClaims.GroupsFilter != "" { + groupsFilter, err = regexp.Compile(c.ClaimMutations.FilterGroupClaims.GroupsFilter) + if err != nil { + logger.Warnf("ignoring invalid regex `%s`", c.ClaimMutations.FilterGroupClaims.GroupsFilter) + } + } + clientID := c.ClientID return &oidcConnector{ provider: provider, @@ -283,6 +299,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, emailKey: c.ClaimMapping.EmailKey, groupsKey: c.ClaimMapping.GroupsKey, newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, + groupsFilter: groupsFilter, }, nil } @@ -312,6 +329,7 @@ type oidcConnector struct { emailKey string groupsKey string newGroupFromClaims []NewGroupFromClaims + groupsFilter *regexp.Regexp } func (c *oidcConnector) Close() error { @@ -518,6 +536,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I if found { for _, v := range vs { if s, ok := v.(string); ok { + if c.groupsFilter != nil && !c.groupsFilter.MatchString(s) { + continue + } groups = append(groups, s) } else { return identity, fmt.Errorf("malformed \"%v\" claim", groupsKey) diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go index 07291f7e5b931b0ea457feb2eb5c5861912f0471..66b35c3feff7aad47a964972bc9e32d6b35017d1 100644 --- a/connector/oidc/oidc_test.go +++ b/connector/oidc/oidc_test.go @@ -64,6 +64,7 @@ func TestHandleCallback(t *testing.T) { expectPreferredUsername string expectedEmailField string token map[string]interface{} + groupsRegex string newGroupFromClaims []NewGroupFromClaims }{ { @@ -364,6 +365,23 @@ func TestHandleCallback(t *testing.T) { "non-string-claim2": 666, }, }, + { + name: "filterGroupClaims", + userIDKey: "", // not configured + userNameKey: "", // not configured + groupsRegex: `^.*\d$`, + expectUserID: "subvalue", + expectUserName: "namevalue", + expectGroups: []string{"group1", "group2"}, + expectedEmailField: "emailvalue", + token: map[string]interface{}{ + "sub": "subvalue", + "name": "namevalue", + "groups": []string{"group1", "group2", "groupA", "groupB"}, + "email": "emailvalue", + "email_verified": true, + }, + }, } for _, tc := range tests { @@ -400,6 +418,7 @@ func TestHandleCallback(t *testing.T) { config.ClaimMapping.EmailKey = tc.emailKey config.ClaimMapping.GroupsKey = tc.groupsKey config.ClaimMutations.NewGroupFromClaims = tc.newGroupFromClaims + config.ClaimMutations.FilterGroupClaims.GroupsFilter = tc.groupsRegex conn, err := newConnector(config) if err != nil {