From d31f6eabd467dd2b9ddb1aae184b7c67f44f5dd2 Mon Sep 17 00:00:00 2001
From: Andrew Block <andy.block@gmail.com>
Date: Thu, 26 Dec 2019 20:32:12 -0600
Subject: [PATCH] Corrected logic in group verification

---
 connector/openshift/openshift.go      | 14 ++++++++------
 connector/openshift/openshift_test.go | 24 +++++++++++++++++++++---
 2 files changed, 29 insertions(+), 9 deletions(-)

diff --git a/connector/openshift/openshift.go b/connector/openshift/openshift.go
index e1974694..6ac5d044 100644
--- a/connector/openshift/openshift.go
+++ b/connector/openshift/openshift.go
@@ -165,10 +165,12 @@ func (c *openshiftConnector) HandleCallback(s connector.Scopes, r *http.Request)
 		return identity, fmt.Errorf("openshift: get user: %v", err)
 	}
 
-	validGroups := validateRequiredGroups(user.Groups, c.groups)
+	if len(c.groups) > 0 {
+		validGroups := validateAllowedGroups(user.Groups, c.groups)
 
-	if !validGroups {
-		return identity, fmt.Errorf("openshift: user %q is not in any of the required groups", user.Name)
+		if !validGroups {
+			return identity, fmt.Errorf("openshift: user %q is not in any of the required groups", user.Name)
+		}
 	}
 
 	identity = connector.Identity{
@@ -211,10 +213,10 @@ func (c *openshiftConnector) user(ctx context.Context, client *http.Client) (u u
 	return u, err
 }
 
-func validateRequiredGroups(userGroups, requiredGroups []string) bool {
-	matchingGroups := groups.Filter(userGroups, requiredGroups)
+func validateAllowedGroups(userGroups, allowedGroups []string) bool {
+	matchingGroups := groups.Filter(userGroups, allowedGroups)
 
-	return len(requiredGroups) == len(matchingGroups)
+	return len(matchingGroups) != 0
 }
 
 // newHTTPClient returns a new HTTP client
diff --git a/connector/openshift/openshift_test.go b/connector/openshift/openshift_test.go
index 2ed50150..316af60a 100644
--- a/connector/openshift/openshift_test.go
+++ b/connector/openshift/openshift_test.go
@@ -83,11 +83,29 @@ func TestGetUser(t *testing.T) {
 	expectEquals(t, len(u.Groups), 1)
 }
 
-func TestVerifyGroupFn(t *testing.T) {
-	requiredGroups := []string{"users"}
+func TestVerifySingleGroupFn(t *testing.T) {
+	allowedGroups := []string{"users"}
 	groupMembership := []string{"users", "org1"}
 
-	validGroupMembership := validateRequiredGroups(groupMembership, requiredGroups)
+	validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups)
+
+	expectEquals(t, validGroupMembership, true)
+}
+
+func TestVerifySingleGroupFailureFn(t *testing.T) {
+	allowedGroups := []string{"admins"}
+	groupMembership := []string{"users"}
+
+	validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups)
+
+	expectEquals(t, validGroupMembership, false)
+}
+
+func TestVerifyMultipleGroupFn(t *testing.T) {
+	allowedGroups := []string{"users", "admins"}
+	groupMembership := []string{"users", "org1"}
+
+	validGroupMembership := validateAllowedGroups(groupMembership, allowedGroups)
 
 	expectEquals(t, validGroupMembership, true)
 }
-- 
GitLab