From f3ef7d46dfd38b0be706306a567b0f79df5bd431 Mon Sep 17 00:00:00 2001
From: Doug Goldstein <cardoe@cardoe.com>
Date: Sun, 2 Jun 2024 19:56:53 -0500
Subject: [PATCH] feat: allow domain names or IDs in keystone connector (#3506)

OpenStack Keystone allows a user to authenticate against a domain. That
domain can be specified either as the domain ID or the domain name when
authenticating. The domain ID is a UUID or the special "default" domain
ID so key off of that when deciding what to submit to the keystone API.
Collapsed the code to share the domainKeystone struct by utilizing
omitempty to skip unset fields.

Signed-off-by: Doug Goldstein <cardoe@cardoe.com>
---
 connector/keystone/keystone.go      |  36 +++--
 connector/keystone/keystone_test.go | 224 +++++++++++++++++++++++-----
 go.mod                              |   2 +-
 3 files changed, 214 insertions(+), 48 deletions(-)

diff --git a/connector/keystone/keystone.go b/connector/keystone/keystone.go
index f8dff9e3..cdfdb558 100644
--- a/connector/keystone/keystone.go
+++ b/connector/keystone/keystone.go
@@ -10,11 +10,13 @@ import (
 	"log/slog"
 	"net/http"
 
+	"github.com/google/uuid"
+
 	"github.com/dexidp/dex/connector"
 )
 
 type conn struct {
-	Domain        string
+	Domain        domainKeystone
 	Host          string
 	AdminUsername string
 	AdminPassword string
@@ -29,8 +31,8 @@ type userKeystone struct {
 }
 
 type domainKeystone struct {
-	ID   string `json:"id"`
-	Name string `json:"name"`
+	ID   string `json:"id,omitempty"`
+	Name string `json:"name,omitempty"`
 }
 
 // Config holds the configuration parameters for Keystone connector.
@@ -71,13 +73,9 @@ type password struct {
 }
 
 type user struct {
-	Name     string `json:"name"`
-	Domain   domain `json:"domain"`
-	Password string `json:"password"`
-}
-
-type domain struct {
-	ID string `json:"id"`
+	Name     string         `json:"name"`
+	Domain   domainKeystone `json:"domain"`
+	Password string         `json:"password"`
 }
 
 type token struct {
@@ -112,8 +110,22 @@ var (
 
 // Open returns an authentication strategy using Keystone.
 func (c *Config) Open(id string, logger *slog.Logger) (connector.Connector, error) {
+	_, err := uuid.Parse(c.Domain)
+	var domain domainKeystone
+	// check if the supplied domain is a UUID or the special "default" value
+	// which is treated as an ID and not a name
+	if err == nil || c.Domain == "default" {
+		domain = domainKeystone{
+			ID: c.Domain,
+		}
+	} else {
+		domain = domainKeystone{
+			Name: c.Domain,
+		}
+	}
+
 	return &conn{
-		Domain:        c.Domain,
+		Domain:        domain,
 		Host:          c.Host,
 		AdminUsername: c.AdminUsername,
 		AdminPassword: c.AdminPassword,
@@ -202,7 +214,7 @@ func (p *conn) getTokenResponse(ctx context.Context, username, pass string) (res
 				Password: password{
 					User: user{
 						Name:     username,
-						Domain:   domain{ID: p.Domain},
+						Domain:   p.Domain,
 						Password: pass,
 					},
 				},
diff --git a/connector/keystone/keystone_test.go b/connector/keystone/keystone_test.go
index 8f1ea1bb..9b0590df 100644
--- a/connector/keystone/keystone_test.go
+++ b/connector/keystone/keystone_test.go
@@ -17,11 +17,13 @@ import (
 const (
 	invalidPass = "WRONG_PASS"
 
-	testUser   = "test_user"
-	testPass   = "test_pass"
-	testEmail  = "test@example.com"
-	testGroup  = "test_group"
-	testDomain = "default"
+	testUser          = "test_user"
+	testPass          = "test_pass"
+	testEmail         = "test@example.com"
+	testGroup         = "test_group"
+	testDomainAltName = "altdomain"
+	testDomainID      = "default"
+	testDomainName    = "Default"
 )
 
 var (
@@ -32,8 +34,26 @@ var (
 	authTokenURL     = ""
 	usersURL         = ""
 	groupsURL        = ""
+	domainsURL       = ""
 )
 
+type userReq struct {
+	Name     string   `json:"name"`
+	Email    string   `json:"email"`
+	Enabled  bool     `json:"enabled"`
+	Password string   `json:"password"`
+	Roles    []string `json:"roles"`
+	DomainID string   `json:"domain_id,omitempty"`
+}
+
+type domainResponse struct {
+	Domain domainKeystone `json:"domain"`
+}
+
+type domainsResponse struct {
+	Domains []domainKeystone `json:"domains"`
+}
+
 type groupResponse struct {
 	Group struct {
 		ID string `json:"id"`
@@ -49,7 +69,7 @@ func getAdminToken(t *testing.T, adminName, adminPass string) (token, id string)
 				Password: password{
 					User: user{
 						Name:     adminName,
-						Domain:   domain{ID: testDomain},
+						Domain:   domainKeystone{ID: testDomainID},
 						Password: adminPass,
 					},
 				},
@@ -89,16 +109,91 @@ func getAdminToken(t *testing.T, adminName, adminPass string) (token, id string)
 	return token, tokenResp.Token.User.ID
 }
 
-func createUser(t *testing.T, token, userName, userEmail, userPass string) string {
+func getOrCreateDomain(t *testing.T, token, domainName string) string {
+	t.Helper()
+
+	domainSearchURL := domainsURL + "?name=" + domainName
+	reqGet, err := http.NewRequest("GET", domainSearchURL, nil)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	reqGet.Header.Set("X-Auth-Token", token)
+	reqGet.Header.Add("Content-Type", "application/json")
+	respGet, err := http.DefaultClient.Do(reqGet)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	dataGet, err := io.ReadAll(respGet.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer respGet.Body.Close()
+
+	domainsResp := new(domainsResponse)
+	err = json.Unmarshal(dataGet, &domainsResp)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if len(domainsResp.Domains) >= 1 {
+		return domainsResp.Domains[0].ID
+	}
+
+	createDomainData := map[string]interface{}{
+		"domain": map[string]interface{}{
+			"name":    domainName,
+			"enabled": true,
+		},
+	}
+
+	body, err := json.Marshal(createDomainData)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	req, err := http.NewRequest("POST", domainsURL, bytes.NewBuffer(body))
+	if err != nil {
+		t.Fatal(err)
+	}
+	req.Header.Set("X-Auth-Token", token)
+	req.Header.Add("Content-Type", "application/json")
+	resp, err := http.DefaultClient.Do(req)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	if resp.StatusCode != 201 {
+		t.Fatalf("failed to create domain %s", domainName)
+	}
+
+	data, err := io.ReadAll(resp.Body)
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer resp.Body.Close()
+
+	domainResp := new(domainResponse)
+	err = json.Unmarshal(data, &domainResp)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	return domainResp.Domain.ID
+}
+
+func createUser(t *testing.T, token, domainID, userName, userEmail, userPass string) string {
 	t.Helper()
 
 	createUserData := map[string]interface{}{
-		"user": map[string]interface{}{
-			"name":     userName,
-			"email":    userEmail,
-			"enabled":  true,
-			"password": userPass,
-			"roles":    []string{"admin"},
+		"user": userReq{
+			DomainID: domainID,
+			Name:     userName,
+			Email:    userEmail,
+			Enabled:  true,
+			Password: userPass,
+			Roles:    []string{"admin"},
 		},
 	}
 
@@ -214,7 +309,7 @@ func TestIncorrectCredentialsLogin(t *testing.T) {
 	setupVariables(t)
 	c := conn{
 		client: http.DefaultClient,
-		Host:   keystoneURL, Domain: testDomain,
+		Host:   keystoneURL, Domain: domainKeystone{ID: testDomainID},
 		AdminUsername: adminUser, AdminPassword: adminPass,
 	}
 	s := connector.Scopes{OfflineAccess: true, Groups: true}
@@ -238,10 +333,11 @@ func TestValidUserLogin(t *testing.T) {
 	token, _ := getAdminToken(t, adminUser, adminPass)
 
 	type tUser struct {
-		username string
-		domain   string
-		email    string
-		password string
+		createDomain bool
+		domain       domainKeystone
+		username     string
+		email        string
+		password     string
 	}
 
 	type expect struct {
@@ -258,10 +354,11 @@ func TestValidUserLogin(t *testing.T) {
 		{
 			name: "test with email address",
 			input: tUser{
-				username: testUser,
-				domain:   testDomain,
-				email:    testEmail,
-				password: testPass,
+				createDomain: false,
+				domain:       domainKeystone{ID: testDomainID},
+				username:     testUser,
+				email:        testEmail,
+				password:     testPass,
 			},
 			expected: expect{
 				username:      testUser,
@@ -272,10 +369,11 @@ func TestValidUserLogin(t *testing.T) {
 		{
 			name: "test without email address",
 			input: tUser{
-				username: testUser,
-				domain:   testDomain,
-				email:    "",
-				password: testPass,
+				createDomain: false,
+				domain:       domainKeystone{ID: testDomainID},
+				username:     testUser,
+				email:        "",
+				password:     testPass,
 			},
 			expected: expect{
 				username:      testUser,
@@ -283,11 +381,66 @@ func TestValidUserLogin(t *testing.T) {
 				verifiedEmail: false,
 			},
 		},
+		{
+			name: "test with default domain Name",
+			input: tUser{
+				createDomain: false,
+				domain:       domainKeystone{Name: testDomainName},
+				username:     testUser,
+				email:        testEmail,
+				password:     testPass,
+			},
+			expected: expect{
+				username:      testUser,
+				email:         testEmail,
+				verifiedEmail: true,
+			},
+		},
+		{
+			name: "test with custom domain Name",
+			input: tUser{
+				createDomain: true,
+				domain:       domainKeystone{Name: testDomainAltName},
+				username:     testUser,
+				email:        testEmail,
+				password:     testPass,
+			},
+			expected: expect{
+				username:      testUser,
+				email:         testEmail,
+				verifiedEmail: true,
+			},
+		},
+		{
+			name: "test with custom domain ID",
+			input: tUser{
+				createDomain: true,
+				domain:       domainKeystone{},
+				username:     testUser,
+				email:        testEmail,
+				password:     testPass,
+			},
+			expected: expect{
+				username:      testUser,
+				email:         testEmail,
+				verifiedEmail: true,
+			},
+		},
 	}
 
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
-			userID := createUser(t, token, tt.input.username, tt.input.email, tt.input.password)
+			domainID := ""
+			if tt.input.createDomain == true {
+				domainID = getOrCreateDomain(t, token, testDomainAltName)
+				t.Logf("getOrCreateDomain ID: %s\n", domainID)
+
+				// if there was nothing set then use the dynamically generated domain ID
+				if tt.input.domain.ID == "" && tt.input.domain.Name == "" {
+					tt.input.domain.ID = domainID
+				}
+			}
+			userID := createUser(t, token, domainID, tt.input.username, tt.input.email, tt.input.password)
 			defer deleteResource(t, token, userID, usersURL)
 
 			c := conn{
@@ -298,7 +451,7 @@ func TestValidUserLogin(t *testing.T) {
 			s := connector.Scopes{OfflineAccess: true, Groups: true}
 			identity, validPW, err := c.Login(context.Background(), s, tt.input.username, tt.input.password)
 			if err != nil {
-				t.Fatal(err.Error())
+				t.Fatalf("Login failed for user %s: %v", tt.input.username, err.Error())
 			}
 			t.Log(identity)
 			if identity.Username != tt.expected.username {
@@ -330,7 +483,7 @@ func TestUseRefreshToken(t *testing.T) {
 
 	c := conn{
 		client: http.DefaultClient,
-		Host:   keystoneURL, Domain: testDomain,
+		Host:   keystoneURL, Domain: domainKeystone{ID: testDomainID},
 		AdminUsername: adminUser, AdminPassword: adminPass,
 	}
 	s := connector.Scopes{OfflineAccess: true, Groups: true}
@@ -352,11 +505,11 @@ func TestUseRefreshToken(t *testing.T) {
 func TestUseRefreshTokenUserDeleted(t *testing.T) {
 	setupVariables(t)
 	token, _ := getAdminToken(t, adminUser, adminPass)
-	userID := createUser(t, token, testUser, testEmail, testPass)
+	userID := createUser(t, token, "", testUser, testEmail, testPass)
 
 	c := conn{
 		client: http.DefaultClient,
-		Host:   keystoneURL, Domain: testDomain,
+		Host:   keystoneURL, Domain: domainKeystone{ID: testDomainID},
 		AdminUsername: adminUser, AdminPassword: adminPass,
 	}
 	s := connector.Scopes{OfflineAccess: true, Groups: true}
@@ -382,12 +535,12 @@ func TestUseRefreshTokenUserDeleted(t *testing.T) {
 func TestUseRefreshTokenGroupsChanged(t *testing.T) {
 	setupVariables(t)
 	token, _ := getAdminToken(t, adminUser, adminPass)
-	userID := createUser(t, token, testUser, testEmail, testPass)
+	userID := createUser(t, token, "", testUser, testEmail, testPass)
 	defer deleteResource(t, token, userID, usersURL)
 
 	c := conn{
 		client: http.DefaultClient,
-		Host:   keystoneURL, Domain: testDomain,
+		Host:   keystoneURL, Domain: domainKeystone{ID: testDomainID},
 		AdminUsername: adminUser, AdminPassword: adminPass,
 	}
 	s := connector.Scopes{OfflineAccess: true, Groups: true}
@@ -419,12 +572,12 @@ func TestUseRefreshTokenGroupsChanged(t *testing.T) {
 func TestNoGroupsInScope(t *testing.T) {
 	setupVariables(t)
 	token, _ := getAdminToken(t, adminUser, adminPass)
-	userID := createUser(t, token, testUser, testEmail, testPass)
+	userID := createUser(t, token, "", testUser, testEmail, testPass)
 	defer deleteResource(t, token, userID, usersURL)
 
 	c := conn{
 		client: http.DefaultClient,
-		Host:   keystoneURL, Domain: testDomain,
+		Host:   keystoneURL, Domain: domainKeystone{ID: testDomainID},
 		AdminUsername: adminUser, AdminPassword: adminPass,
 	}
 	s := connector.Scopes{OfflineAccess: true, Groups: false}
@@ -474,6 +627,7 @@ func setupVariables(t *testing.T) {
 	authTokenURL = keystoneURL + "/v3/auth/tokens/"
 	usersURL = keystoneAdminURL + "/v3/users/"
 	groupsURL = keystoneAdminURL + "/v3/groups/"
+	domainsURL = keystoneAdminURL + "/v3/domains/"
 }
 
 func expectEquals(t *testing.T, a interface{}, b interface{}) {
diff --git a/go.mod b/go.mod
index 2667cf8a..4a1fd126 100644
--- a/go.mod
+++ b/go.mod
@@ -17,6 +17,7 @@ require (
 	github.com/go-jose/go-jose/v4 v4.0.2
 	github.com/go-ldap/ldap/v3 v3.4.6
 	github.com/go-sql-driver/mysql v1.8.1
+	github.com/google/uuid v1.6.0
 	github.com/gorilla/handlers v1.5.2
 	github.com/gorilla/mux v1.8.1
 	github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0
@@ -65,7 +66,6 @@ require (
 	github.com/golang/protobuf v1.5.4 // indirect
 	github.com/google/go-cmp v0.6.0 // indirect
 	github.com/google/s2a-go v0.1.7 // indirect
-	github.com/google/uuid v1.6.0 // indirect
 	github.com/googleapis/enterprise-certificate-proxy v0.3.2 // indirect
 	github.com/googleapis/gax-go/v2 v2.12.4 // indirect
 	github.com/hashicorp/hcl/v2 v2.13.0 // indirect
-- 
GitLab