From 29b3523e07b61d006914ac9e24f317610c7eda16 Mon Sep 17 00:00:00 2001
From: Marcelo Clavel <mclavel00@gmail.com>
Date: Thu, 1 Sep 2022 07:46:24 -0400
Subject: [PATCH] feat(connector/authproxy): support multiple groups (#2643)

Signed-off-by: Marcelo Clavel <mclavel00@gmail.com>
---
 connector/authproxy/authproxy.go      |   7 +-
 connector/authproxy/authproxy_test.go | 134 ++++++++++++++++++++++++++
 2 files changed, 140 insertions(+), 1 deletion(-)
 create mode 100644 connector/authproxy/authproxy_test.go

diff --git a/connector/authproxy/authproxy.go b/connector/authproxy/authproxy.go
index 487a3f60..87154121 100644
--- a/connector/authproxy/authproxy.go
+++ b/connector/authproxy/authproxy.go
@@ -7,6 +7,7 @@ import (
 	"fmt"
 	"net/http"
 	"net/url"
+	"strings"
 
 	"github.com/dexidp/dex/connector"
 	"github.com/dexidp/dex/pkg/log"
@@ -69,7 +70,11 @@ func (m *callback) HandleCallback(s connector.Scopes, r *http.Request) (connecto
 	groups := m.groups
 	headerGroup := r.Header.Get(m.groupHeader)
 	if headerGroup != "" {
-		groups = append(groups, headerGroup)
+		splitheaderGroup := strings.Split(headerGroup, ",")
+		for i, v := range splitheaderGroup {
+			splitheaderGroup[i] = strings.TrimSpace(v)
+		}
+		groups = append(splitheaderGroup, groups...)
 	}
 	return connector.Identity{
 		UserID:        remoteUser, // TODO: figure out if this is a bad ID value.
diff --git a/connector/authproxy/authproxy_test.go b/connector/authproxy/authproxy_test.go
new file mode 100644
index 00000000..5d42530e
--- /dev/null
+++ b/connector/authproxy/authproxy_test.go
@@ -0,0 +1,134 @@
+package authproxy
+
+import (
+	"io"
+	"net/http"
+	"reflect"
+	"testing"
+
+	"github.com/sirupsen/logrus"
+
+	"github.com/dexidp/dex/connector"
+)
+
+const (
+	testEmail        = "testuser@example.com"
+	testGroup1       = "group1"
+	testGroup2       = "group2"
+	testGroup3       = "group 3"
+	testGroup4       = "group 4"
+	testStaticGroup1 = "static1"
+	testStaticGroup2 = "static 2"
+)
+
+var logger = &logrus.Logger{Out: io.Discard, Formatter: &logrus.TextFormatter{}}
+
+func TestUser(t *testing.T) {
+	config := Config{
+		UserHeader: "X-Remote-User",
+	}
+	conn := callback{userHeader: config.UserHeader, logger: logger, pathSuffix: "/test"}
+
+	req, err := http.NewRequest("GET", "/", nil)
+	expectNil(t, err)
+	req.Header = map[string][]string{
+		"X-Remote-User": {testEmail},
+	}
+
+	ident, err := conn.HandleCallback(connector.Scopes{OfflineAccess: true, Groups: true}, req)
+	expectNil(t, err)
+
+	expectEquals(t, ident.UserID, testEmail)
+	expectEquals(t, ident.Email, testEmail)
+	expectEquals(t, len(ident.Groups), 0)
+}
+
+func TestSingleGroup(t *testing.T) {
+	config := Config{
+		UserHeader:  "X-Remote-User",
+		GroupHeader: "X-Remote-Group",
+	}
+
+	conn := callback{userHeader: config.UserHeader, groupHeader: config.GroupHeader, logger: logger, pathSuffix: "/test"}
+
+	req, err := http.NewRequest("GET", "/", nil)
+	expectNil(t, err)
+	req.Header = map[string][]string{
+		"X-Remote-User":  {testEmail},
+		"X-Remote-Group": {testGroup1},
+	}
+
+	ident, err := conn.HandleCallback(connector.Scopes{OfflineAccess: true, Groups: true}, req)
+	expectNil(t, err)
+
+	expectEquals(t, ident.UserID, testEmail)
+	expectEquals(t, len(ident.Groups), 1)
+	expectEquals(t, ident.Groups[0], testGroup1)
+}
+
+func TestMultipleGroup(t *testing.T) {
+	config := Config{
+		UserHeader:  "X-Remote-User",
+		GroupHeader: "X-Remote-Group",
+	}
+
+	conn := callback{userHeader: config.UserHeader, groupHeader: config.GroupHeader, logger: logger, pathSuffix: "/test"}
+
+	req, err := http.NewRequest("GET", "/", nil)
+	expectNil(t, err)
+	req.Header = map[string][]string{
+		"X-Remote-User":  {testEmail},
+		"X-Remote-Group": {testGroup1 + ", " + testGroup2 + ", " + testGroup3 + ", " + testGroup4},
+	}
+
+	ident, err := conn.HandleCallback(connector.Scopes{OfflineAccess: true, Groups: true}, req)
+	expectNil(t, err)
+
+	expectEquals(t, ident.UserID, testEmail)
+	expectEquals(t, len(ident.Groups), 4)
+	expectEquals(t, ident.Groups[0], testGroup1)
+	expectEquals(t, ident.Groups[1], testGroup2)
+	expectEquals(t, ident.Groups[2], testGroup3)
+	expectEquals(t, ident.Groups[3], testGroup4)
+}
+
+func TestStaticGroup(t *testing.T) {
+	config := Config{
+		UserHeader:  "X-Remote-User",
+		GroupHeader: "X-Remote-Group",
+		Groups:      []string{"static1", "static 2"},
+	}
+
+	conn := callback{userHeader: config.UserHeader, groupHeader: config.GroupHeader, groups: config.Groups, logger: logger, pathSuffix: "/test"}
+
+	req, err := http.NewRequest("GET", "/", nil)
+	expectNil(t, err)
+	req.Header = map[string][]string{
+		"X-Remote-User":  {testEmail},
+		"X-Remote-Group": {testGroup1 + ", " + testGroup2 + ", " + testGroup3 + ", " + testGroup4},
+	}
+
+	ident, err := conn.HandleCallback(connector.Scopes{OfflineAccess: true, Groups: true}, req)
+	expectNil(t, err)
+
+	expectEquals(t, ident.UserID, testEmail)
+	expectEquals(t, len(ident.Groups), 6)
+	expectEquals(t, ident.Groups[0], testGroup1)
+	expectEquals(t, ident.Groups[1], testGroup2)
+	expectEquals(t, ident.Groups[2], testGroup3)
+	expectEquals(t, ident.Groups[3], testGroup4)
+	expectEquals(t, ident.Groups[4], testStaticGroup1)
+	expectEquals(t, ident.Groups[5], testStaticGroup2)
+}
+
+func expectNil(t *testing.T, a interface{}) {
+	if a != nil {
+		t.Errorf("Expected %+v to equal nil", a)
+	}
+}
+
+func expectEquals(t *testing.T, a interface{}, b interface{}) {
+	if !reflect.DeepEqual(a, b) {
+		t.Errorf("Expected %+v to equal %+v", a, b)
+	}
+}
-- 
GitLab