From cbb007663f11549beca58e924b4c55ddbc7b33d8 Mon Sep 17 00:00:00 2001
From: Ben Navetta <ben.navetta@gmail.com>
Date: Wed, 21 Jun 2017 22:56:02 -0700
Subject: [PATCH] add documentation and tests

---
 Documentation/oidc-connector.md |   7 +++
 connector/oidc/oidc.go          |  49 ++++++++++------
 connector/oidc/oidc_test.go     | 101 +++++++++++++++++++++++++++++++-
 3 files changed, 139 insertions(+), 18 deletions(-)

diff --git a/Documentation/oidc-connector.md b/Documentation/oidc-connector.md
index 57a2b66e..8171bc37 100644
--- a/Documentation/oidc-connector.md
+++ b/Documentation/oidc-connector.md
@@ -42,6 +42,13 @@ connectors:
     # following field.
     #
     # basicAuthUnsupported: true
+    
+    # Google supports whitelisting allowed domains when using G Suite
+    # (Google Apps). The following field can be set to a list of domains
+    # that can log in:
+    #
+    # hostedDomains:
+    #  - example.com
 ```
 
 [oidc-doc]: openid-connect.md
diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go
index ffe80719..cd04a374 100644
--- a/connector/oidc/oidc.go
+++ b/connector/oidc/oidc.go
@@ -33,7 +33,9 @@ type Config struct {
 
 	Scopes []string `json:"scopes"` // defaults to "profile" and "email"
 
-	HostedDomain string `json:"hostedDomain"`
+	// Optional list of whitelisted domains when using Google
+	// If this field is nonempty, only users from a listed domain will be allowed to log in
+	HostedDomains []string `json:"hostedDomain"`
 }
 
 // Domains that don't support basic auth. golang.org/x/oauth2 has an internal
@@ -111,9 +113,9 @@ func (c *Config) Open(logger logrus.FieldLogger) (conn connector.Connector, err
 		verifier: provider.Verifier(
 			&oidc.Config{ClientID: clientID},
 		),
-		logger:       logger,
-		cancel:       cancel,
-		hostedDomain: c.HostedDomain,
+		logger:        logger,
+		cancel:        cancel,
+		hostedDomains: c.HostedDomains,
 	}, nil
 }
 
@@ -123,13 +125,13 @@ var (
 )
 
 type oidcConnector struct {
-	redirectURI  string
-	oauth2Config *oauth2.Config
-	verifier     *oidc.IDTokenVerifier
-	ctx          context.Context
-	cancel       context.CancelFunc
-	logger       logrus.FieldLogger
-	hostedDomain string
+	redirectURI   string
+	oauth2Config  *oauth2.Config
+	verifier      *oidc.IDTokenVerifier
+	ctx           context.Context
+	cancel        context.CancelFunc
+	logger        logrus.FieldLogger
+	hostedDomains []string
 }
 
 func (c *oidcConnector) Close() error {
@@ -142,11 +144,14 @@ func (c *oidcConnector) LoginURL(s connector.Scopes, callbackURL, state string)
 		return "", fmt.Errorf("expected callback URL %q did not match the URL in the config %q", callbackURL, c.redirectURI)
 	}
 
-	if c.hostedDomain != "" {
-		return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", c.hostedDomain)), nil
-	} else {
-		return c.oauth2Config.AuthCodeURL(state), nil
+	if len(c.hostedDomains) > 0 {
+		preferredDomain := c.hostedDomains[0]
+		if len(c.hostedDomains) > 1 {
+			preferredDomain = "*"
+		}
+		return c.oauth2Config.AuthCodeURL(state, oauth2.SetAuthURLParam("hd", preferredDomain)), nil
 	}
+	return c.oauth2Config.AuthCodeURL(state), nil
 }
 
 type oauth2Error struct {
@@ -190,8 +195,18 @@ func (c *oidcConnector) HandleCallback(s connector.Scopes, r *http.Request) (ide
 		return identity, fmt.Errorf("oidc: failed to decode claims: %v", err)
 	}
 
-	if claims.HostedDomain != c.hostedDomain {
-		return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
+	if len(c.hostedDomains) > 0 {
+		found := false
+		for _, domain := range c.hostedDomains {
+			if claims.HostedDomain != domain {
+				found = true
+				break
+			}
+		}
+
+		if !found {
+			return identity, fmt.Errorf("oidc: unexpected hd claim %v", claims.HostedDomain)
+		}
 	}
 
 	identity = connector.Identity{
diff --git a/connector/oidc/oidc_test.go b/connector/oidc/oidc_test.go
index b3f609d1..a484d0a2 100644
--- a/connector/oidc/oidc_test.go
+++ b/connector/oidc/oidc_test.go
@@ -1,6 +1,13 @@
 package oidc
 
-import "testing"
+import (
+	"github.com/Sirupsen/logrus"
+	"github.com/coreos/dex/connector"
+	"net/url"
+	"os"
+	"reflect"
+	"testing"
+)
 
 func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
 	tests := []struct {
@@ -21,3 +28,95 @@ func TestKnownBrokenAuthHeaderProvider(t *testing.T) {
 		}
 	}
 }
+
+func TestOidcConnector_LoginURL(t *testing.T) {
+	logger := &logrus.Logger{
+		Out:       os.Stderr,
+		Formatter: &logrus.TextFormatter{DisableColors: true},
+		Level:     logrus.DebugLevel,
+	}
+
+	tests := []struct {
+		scopes        connector.Scopes
+		hostedDomains []string
+
+		wantScopes  string
+		wantHdParam string
+	}{
+		{
+			connector.Scopes{}, []string{"example.com"},
+			"openid profile email", "example.com",
+		},
+		{
+			connector.Scopes{}, []string{"mydomain.org", "example.com"},
+			"openid profile email", "*",
+		},
+		{
+			connector.Scopes{}, []string{},
+			"openid profile email", "",
+		},
+		{
+			connector.Scopes{OfflineAccess: true}, []string{},
+			"openid profile email", "",
+		},
+	}
+
+	callback := "https://dex.example.com/callback"
+	state := "secret"
+
+	for _, test := range tests {
+		config := &Config{
+			Issuer:        "https://accounts.google.com",
+			ClientID:      "client-id",
+			ClientSecret:  "client-secret",
+			RedirectURI:   "https://dex.example.com/callback",
+			HostedDomains: test.hostedDomains,
+		}
+
+		conn, err := config.Open(logger)
+		if err != nil {
+			t.Errorf("failed to open connector: %v", err)
+			continue
+		}
+
+		loginURL, err := conn.(connector.CallbackConnector).LoginURL(test.scopes, callback, state)
+		if err != nil {
+			t.Errorf("failed to get login URL: %v", err)
+			continue
+		}
+
+		actual, err := url.Parse(loginURL)
+		if err != nil {
+			t.Errorf("failed to parse login URL: %v", err)
+			continue
+		}
+
+		wanted, _ := url.Parse("https://accounts.google.com/o/oauth2/v2/auth")
+		wantedQuery := &url.Values{}
+		wantedQuery.Set("client_id", config.ClientID)
+		wantedQuery.Set("redirect_uri", config.RedirectURI)
+		wantedQuery.Set("response_type", "code")
+		wantedQuery.Set("state", "secret")
+		wantedQuery.Set("scope", test.wantScopes)
+		if test.wantHdParam != "" {
+			wantedQuery.Set("hd", test.wantHdParam)
+		}
+		wanted.RawQuery = wantedQuery.Encode()
+
+		if !reflect.DeepEqual(actual, wanted) {
+			t.Errorf("Wanted %v, got %v", wanted, actual)
+		}
+	}
+}
+
+//func TestOidcConnector_HandleCallback(t *testing.T) {
+//	logger := &logrus.Logger{
+//		Out:       os.Stderr,
+//		Formatter: &logrus.TextFormatter{DisableColors: true},
+//		Level:     logrus.DebugLevel,
+//	}
+//
+//	tests := []struct {
+//
+//	}
+//}
-- 
GitLab