Skip to content
Snippets Groups Projects
Unverified Commit 6ceb2650 authored by Márk Sági-Kazár's avatar Márk Sági-Kazár Committed by GitHub
Browse files

Merge pull request #3063 from jacksonargo/oidc-group-regex

add regex for oidc group matching
parents 36e6e081 5df16057
No related branches found
No related tags found
No related merge requests found
...@@ -9,6 +9,7 @@ import ( ...@@ -9,6 +9,7 @@ import (
"log/slog" "log/slog"
"net/http" "net/http"
"net/url" "net/url"
"regexp"
"strings" "strings"
"time" "time"
...@@ -97,6 +98,7 @@ type Config struct { ...@@ -97,6 +98,7 @@ type Config struct {
// ClaimMutations holds all claim mutations options // ClaimMutations holds all claim mutations options
ClaimMutations struct { ClaimMutations struct {
NewGroupFromClaims []NewGroupFromClaims `json:"newGroupFromClaims"` NewGroupFromClaims []NewGroupFromClaims `json:"newGroupFromClaims"`
FilterGroupClaims FilterGroupClaims `json:"filterGroupClaims"`
} `json:"claimModifications"` } `json:"claimModifications"`
} }
...@@ -176,6 +178,12 @@ type NewGroupFromClaims struct { ...@@ -176,6 +178,12 @@ type NewGroupFromClaims struct {
Prefix string `json:"prefix"` Prefix string `json:"prefix"`
} }
// FilterGroupClaims is a regex filter for to keep only the matching groups.
// This is useful when the groups list is too large to fit within an HTTP header.
type FilterGroupClaims struct {
GroupsFilter string `json:"groupsFilter"`
}
// Domains that don't support basic auth. golang.org/x/oauth2 has an internal // Domains that don't support basic auth. golang.org/x/oauth2 has an internal
// list, but it only matches specific URLs, not top level domains. // list, but it only matches specific URLs, not top level domains.
var brokenAuthHeaderDomains = []string{ var brokenAuthHeaderDomains = []string{
...@@ -252,6 +260,14 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ...@@ -252,6 +260,14 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
promptType = *c.PromptType promptType = *c.PromptType
} }
var groupsFilter *regexp.Regexp
if c.ClaimMutations.FilterGroupClaims.GroupsFilter != "" {
groupsFilter, err = regexp.Compile(c.ClaimMutations.FilterGroupClaims.GroupsFilter)
if err != nil {
logger.Warnf("ignoring invalid regex `%s`", c.ClaimMutations.FilterGroupClaims.GroupsFilter)
}
}
clientID := c.ClientID clientID := c.ClientID
return &oidcConnector{ return &oidcConnector{
provider: provider, provider: provider,
...@@ -283,6 +299,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector, ...@@ -283,6 +299,7 @@ func (c *Config) Open(id string, logger *slog.Logger) (conn connector.Connector,
emailKey: c.ClaimMapping.EmailKey, emailKey: c.ClaimMapping.EmailKey,
groupsKey: c.ClaimMapping.GroupsKey, groupsKey: c.ClaimMapping.GroupsKey,
newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims, newGroupFromClaims: c.ClaimMutations.NewGroupFromClaims,
groupsFilter: groupsFilter,
}, nil }, nil
} }
...@@ -312,6 +329,7 @@ type oidcConnector struct { ...@@ -312,6 +329,7 @@ type oidcConnector struct {
emailKey string emailKey string
groupsKey string groupsKey string
newGroupFromClaims []NewGroupFromClaims newGroupFromClaims []NewGroupFromClaims
groupsFilter *regexp.Regexp
} }
func (c *oidcConnector) Close() error { func (c *oidcConnector) Close() error {
...@@ -518,6 +536,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I ...@@ -518,6 +536,9 @@ func (c *oidcConnector) createIdentity(ctx context.Context, identity connector.I
if found { if found {
for _, v := range vs { for _, v := range vs {
if s, ok := v.(string); ok { if s, ok := v.(string); ok {
if c.groupsFilter != nil && !c.groupsFilter.MatchString(s) {
continue
}
groups = append(groups, s) groups = append(groups, s)
} else { } else {
return identity, fmt.Errorf("malformed \"%v\" claim", groupsKey) return identity, fmt.Errorf("malformed \"%v\" claim", groupsKey)
......
...@@ -64,6 +64,7 @@ func TestHandleCallback(t *testing.T) { ...@@ -64,6 +64,7 @@ func TestHandleCallback(t *testing.T) {
expectPreferredUsername string expectPreferredUsername string
expectedEmailField string expectedEmailField string
token map[string]interface{} token map[string]interface{}
groupsRegex string
newGroupFromClaims []NewGroupFromClaims newGroupFromClaims []NewGroupFromClaims
}{ }{
{ {
...@@ -364,6 +365,23 @@ func TestHandleCallback(t *testing.T) { ...@@ -364,6 +365,23 @@ func TestHandleCallback(t *testing.T) {
"non-string-claim2": 666, "non-string-claim2": 666,
}, },
}, },
{
name: "filterGroupClaims",
userIDKey: "", // not configured
userNameKey: "", // not configured
groupsRegex: `^.*\d$`,
expectUserID: "subvalue",
expectUserName: "namevalue",
expectGroups: []string{"group1", "group2"},
expectedEmailField: "emailvalue",
token: map[string]interface{}{
"sub": "subvalue",
"name": "namevalue",
"groups": []string{"group1", "group2", "groupA", "groupB"},
"email": "emailvalue",
"email_verified": true,
},
},
} }
for _, tc := range tests { for _, tc := range tests {
...@@ -400,6 +418,7 @@ func TestHandleCallback(t *testing.T) { ...@@ -400,6 +418,7 @@ func TestHandleCallback(t *testing.T) {
config.ClaimMapping.EmailKey = tc.emailKey config.ClaimMapping.EmailKey = tc.emailKey
config.ClaimMapping.GroupsKey = tc.groupsKey config.ClaimMapping.GroupsKey = tc.groupsKey
config.ClaimMutations.NewGroupFromClaims = tc.newGroupFromClaims config.ClaimMutations.NewGroupFromClaims = tc.newGroupFromClaims
config.ClaimMutations.FilterGroupClaims.GroupsFilter = tc.groupsRegex
conn, err := newConnector(config) conn, err := newConnector(config)
if err != nil { if err != nil {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment