Skip to content
Snippets Groups Projects
server_test.go 49.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • 					t.Errorf("failed to exchange code for token: %v", err)
    					return
    				}
    				rawIDToken, ok := token.Extra("id_token").(string)
    				if !ok {
    					t.Errorf("no id token found: %v", err)
    					return
    				}
    
    				idToken, err := p.Verifier(&oidc.Config{ClientID: testClientID}).Verify(ctx, rawIDToken)
    
    				if err != nil {
    					t.Errorf("failed to parse ID Token: %v", err)
    					return
    				}
    
    				sort.Strings(idToken.Audience)
    				expAudience := []string{peerID, testClientID}
    				if !reflect.DeepEqual(idToken.Audience, expAudience) {
    					t.Errorf("expected audience %q, got %q", expAudience, idToken.Audience)
    				}
    			}
    			if gotState := q.Get("state"); gotState != state {
    				t.Errorf("state did not match, want=%q got=%q", state, gotState)
    			}
    			w.WriteHeader(http.StatusOK)
    			return
    		}
    		http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
    	}))
    
    	defer oauth2Server.Close()
    
    	redirectURL := oauth2Server.URL + "/callback"
    	client := storage.Client{
    		ID:           testClientID,
    		Secret:       "testclientsecret",
    		RedirectURIs: []string{redirectURL},
    	}
    	if err := s.storage.CreateClient(client); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	peer := storage.Client{
    		ID:           peerID,
    		Secret:       "foobar",
    		TrustedPeers: []string{"testclient"},
    	}
    
    	if err := s.storage.CreateClient(peer); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	oauth2Config = &oauth2.Config{
    		ClientID:     client.ID,
    		ClientSecret: client.Secret,
    		Endpoint:     p.Endpoint(),
    		Scopes: []string{
    			oidc.ScopeOpenID, "profile", "email",
    			"audience:server:client_id:" + client.ID,
    			"audience:server:client_id:" + peer.ID,
    		},
    		RedirectURL: redirectURL,
    	}
    
    	resp, err := http.Get(oauth2Server.URL + "/login")
    	if err != nil {
    		t.Fatalf("get failed: %v", err)
    	}
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	defer resp.Body.Close()
    
    
    	if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
    		t.Fatal(err)
    	}
    	if respDump, err = httputil.DumpResponse(resp, true); err != nil {
    
    		t.Fatal(err)
    	}
    }
    
    func TestCrossClientScopesWithAzpInAudienceByDefault(t *testing.T) {
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    	httpServer, s := newTestServer(ctx, t, func(c *Config) {
    
    m.nabokikh's avatar
    m.nabokikh committed
    		c.Issuer += "/non-root-path"
    
    	})
    	defer httpServer.Close()
    
    	p, err := oidc.NewProvider(ctx, httpServer.URL)
    	if err != nil {
    		t.Fatalf("failed to get provider: %v", err)
    	}
    
    	var (
    		reqDump, respDump []byte
    		gotCode           bool
    		state             = "a_state"
    	)
    	defer func() {
    		if !gotCode {
    			t.Errorf("never got a code in callback\n%s\n%s", reqDump, respDump)
    		}
    	}()
    
    	testClientID := "testclient"
    	peerID := "peer"
    
    	var oauth2Config *oauth2.Config
    	oauth2Server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		if r.URL.Path == "/callback" {
    			q := r.URL.Query()
    			if errType := q.Get("error"); errType != "" {
    				if desc := q.Get("error_description"); desc != "" {
    					t.Errorf("got error from server %s: %s", errType, desc)
    				} else {
    					t.Errorf("got error from server %s", errType)
    				}
    				w.WriteHeader(http.StatusInternalServerError)
    				return
    			}
    
    			if code := q.Get("code"); code != "" {
    				gotCode = true
    				token, err := oauth2Config.Exchange(ctx, code)
    				if err != nil {
    					t.Errorf("failed to exchange code for token: %v", err)
    					return
    				}
    				rawIDToken, ok := token.Extra("id_token").(string)
    				if !ok {
    					t.Errorf("no id token found: %v", err)
    					return
    				}
    				idToken, err := p.Verifier(&oidc.Config{ClientID: testClientID}).Verify(ctx, rawIDToken)
    				if err != nil {
    					t.Errorf("failed to parse ID Token: %v", err)
    					return
    				}
    
    				sort.Strings(idToken.Audience)
    				expAudience := []string{peerID, testClientID}
    				if !reflect.DeepEqual(idToken.Audience, expAudience) {
    					t.Errorf("expected audience %q, got %q", expAudience, idToken.Audience)
    				}
    			}
    			if gotState := q.Get("state"); gotState != state {
    				t.Errorf("state did not match, want=%q got=%q", state, gotState)
    			}
    			w.WriteHeader(http.StatusOK)
    			return
    		}
    		http.Redirect(w, r, oauth2Config.AuthCodeURL(state), http.StatusSeeOther)
    	}))
    
    	defer oauth2Server.Close()
    
    	redirectURL := oauth2Server.URL + "/callback"
    	client := storage.Client{
    		ID:           testClientID,
    		Secret:       "testclientsecret",
    		RedirectURIs: []string{redirectURL},
    	}
    	if err := s.storage.CreateClient(client); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	peer := storage.Client{
    		ID:           peerID,
    		Secret:       "foobar",
    		TrustedPeers: []string{"testclient"},
    	}
    
    	if err := s.storage.CreateClient(peer); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	oauth2Config = &oauth2.Config{
    		ClientID:     client.ID,
    		ClientSecret: client.Secret,
    		Endpoint:     p.Endpoint(),
    		Scopes: []string{
    			oidc.ScopeOpenID, "profile", "email",
    			"audience:server:client_id:" + peer.ID,
    		},
    		RedirectURL: redirectURL,
    	}
    
    	resp, err := http.Get(oauth2Server.URL + "/login")
    	if err != nil {
    		t.Fatalf("get failed: %v", err)
    	}
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	defer resp.Body.Close()
    
    
    	if reqDump, err = httputil.DumpRequest(resp.Request, false); err != nil {
    		t.Fatal(err)
    	}
    	if respDump, err = httputil.DumpResponse(resp, true); err != nil {
    
    // TestPasswordDB stores one password entry for jane@example.com and then runs
    // conn.Login against a table of credentials: the matching password, an
    // unknown user, and a wrong password.
    //
    // NOTE(review): s, conn and pw are used below but are not declared in the
    // visible part of this function — confirm their definitions against the
    // full file.
    func TestPasswordDB(t *testing.T) {
    	h, err := bcrypt.GenerateFromPassword([]byte(pw), bcrypt.DefaultCost)
    	if err != nil {
    		t.Fatal(err)
    	}

    	// Abort if the fixture cannot be stored; otherwise every case below would
    	// fail with a misleading "expected valid password" message.
    	if err := s.CreatePassword(storage.Password{
    		Email:    "jane@example.com",
    		Username: "jane",
    		UserID:   "foobar",
    		Hash:     h,
    	}); err != nil {
    		t.Fatalf("failed to create password: %v", err)
    	}

    	tests := []struct {
    		name         string
    		username     string
    		password     string
    		wantIdentity connector.Identity // compared only when the login is valid
    		wantInvalid  bool               // Login is expected to report valid == false
    		wantErr      bool               // Login is expected to return a non-nil error
    	}{
    		{
    			name:     "valid password",
    			username: "jane@example.com",
    			password: pw,
    			wantIdentity: connector.Identity{
    				Email:         "jane@example.com",
    				Username:      "jane",
    				UserID:        "foobar",
    				EmailVerified: true,
    			},
    		},
    		{
    			name:        "unknown user",
    			username:    "john@example.com",
    			password:    pw,
    			wantInvalid: true,
    		},
    		{
    			name:        "invalid password",
    			username:    "jane@example.com",
    			password:    "not the correct password",
    			wantInvalid: true,
    		},
    	}

    	for _, tc := range tests {
    		ident, valid, err := conn.Login(context.Background(), connector.Scopes{}, tc.username, tc.password)

    		// An error ends the case: it is a failure only when it was not expected.
    		if err != nil {
    			if !tc.wantErr {
    				t.Errorf("%s: %v", tc.name, err)
    			}
    			continue
    		}
    		if tc.wantErr {
    			t.Errorf("%s: expected error", tc.name)
    			continue
    		}

    		// A rejected login ends the case: it is a failure only when the
    		// credentials were expected to be accepted.
    		if !valid {
    			if !tc.wantInvalid {
    				t.Errorf("%s: expected valid password", tc.name)
    			}
    			continue
    		}
    		if tc.wantInvalid {
    			t.Errorf("%s: expected invalid password", tc.name)
    			continue
    		}

    		// The login was accepted as expected; the returned identity must match.
    		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" {
    			t.Errorf("%s: %s", tc.name, diff)
    		}
    	}
    }
    
    
    func TestPasswordDBUsernamePrompt(t *testing.T) {
    	s := memory.New(logger)
    	conn := newPasswordDB(s)
    
    	expected := "Email Address"
    	if actual := conn.Prompt(); actual != expected {
    		t.Errorf("expected %v, got %v", expected, actual)
    	}
    }
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    // storageWithKeysTrigger embeds a storage.Storage and carries a callback that
    // its GetKeys method invokes before delegating to the embedded storage.
    type storageWithKeysTrigger struct {
    	storage.Storage
    	// f is called on every GetKeys call, before the embedded storage is queried.
    	f func()
    }
    
    // GetKeys first runs the trigger callback, then returns whatever the embedded
    // storage's GetKeys yields.
    func (w storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
    	w.f()
    	keys, err := w.Storage.GetKeys()
    	return keys, err
    }
    
    // TestKeyCacher calls GetKeys on the value returned by newKeyCacher once per
    // table entry and checks whether that call reached the wrapped storage. The
    // wrapped storage is a storageWithKeysTrigger whose callback sets gotCall, and
    // the clock handed to newKeyCacher returns tNow, which some cases advance.
    //
    // NOTE(review): s is used below without a visible declaration, and the two
    // "Eric Chiang" lines inside the body are not Go — both look like extraction
    // damage; confirm against the original file.
    func TestKeyCacher(t *testing.T) {
    	tNow := time.Now()
    	// now always reports the current value of tNow, so cases can move time forward.
    	now := func() time.Time { return tNow }
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    
    	// Cases run in order and share state: each one builds on the effects of
    	// the before hooks that ran earlier.
    	tests := []struct {
    		before            func()
    		wantCallToStorage bool
    	}{
    		{
    			// First GetKeys call: expected to reach storage.
    			before:            func() {},
    			wantCallToStorage: true,
    		},
    		{
    			// Set NextRotation to one minute after tNow: still expected to reach storage.
    			before: func() {
    				s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
    					old.NextRotation = tNow.Add(time.Minute)
    					return old, nil
    				})
    			},
    			wantCallToStorage: true,
    		},
    		{
    			// Nothing changed since the previous call: expected NOT to reach
    			// storage (presumably served from the cache — confirm in newKeyCacher).
    			before:            func() {},
    			wantCallToStorage: false,
    		},
    		{
    			// Clock moved one hour forward: expected to reach storage again.
    			before: func() {
    				tNow = tNow.Add(time.Hour)
    			},
    			wantCallToStorage: true,
    		},
    		{
    			// Clock moved another hour and NextRotation reset to one minute
    			// after the new tNow: expected to reach storage.
    			before: func() {
    				tNow = tNow.Add(time.Hour)
    				s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
    					old.NextRotation = tNow.Add(time.Minute)
    					return old, nil
    				})
    			},
    			wantCallToStorage: true,
    		},
    		{
    			// Nothing changed since the previous call: expected NOT to reach storage.
    			before:            func() {},
    			wantCallToStorage: false,
    		},
    	}
    
    	gotCall := false
    	// Replace s with the key cacher wrapped around the triggering storage; the
    	// callback records that GetKeys reached the underlying storage.
    	s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now)
    	for i, tc := range tests {
    		// Reset the flag so it reflects only this iteration's GetKeys call.
    		gotCall = false
    		tc.before()
    		s.GetKeys()
    		if gotCall != tc.wantCallToStorage {
    			t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
    		}
    	}
    }
    
    // checkErrorResponse verifies that err is the token-endpoint failure described
    // by tc.tokenError. It reports a test error when err is nil, when err is not
    // an *oauth2.RetrieveError, when the HTTP status code differs from
    // tc.tokenError.StatusCode, when the response body is not valid JSON for an
    // OAuth2ErrorResponse, or when tc.tokenError.Error is non-empty and differs
    // from the error code in the body.
    func checkErrorResponse(err error, t *testing.T, tc test) {
    	if err == nil {
    		t.Errorf("%s: DANGEROUS! got a token when we should not get one!", tc.name)
    		return
    	}

    	retrieveErr, isRetrieveErr := err.(*oauth2.RetrieveError)
    	if !isRetrieveErr {
    		t.Errorf("%s: unexpected error type: %s. expected *oauth2.RetrieveError", tc.name, reflect.TypeOf(err))
    		return
    	}

    	// A status mismatch is reported, but the body is still inspected afterwards.
    	if gotStatus := retrieveErr.Response.StatusCode; gotStatus != tc.tokenError.StatusCode {
    		t.Errorf("%s: got wrong StatusCode from server %d. expected %d",
    			tc.name, gotStatus, tc.tokenError.StatusCode)
    	}

    	var details OAuth2ErrorResponse
    	if parseErr := json.Unmarshal(retrieveErr.Body, &details); parseErr != nil {
    		t.Errorf("%s: could not parse return json: %s", tc.name, parseErr)
    		return
    	}

    	// An empty expected error code means the code is not checked.
    	if tc.tokenError.Error == "" || details.Error == tc.tokenError.Error {
    		return
    	}
    	t.Errorf("%s: got wrong Error in response: %s (%s). expected %s",
    		tc.name, details.Error, details.ErrorDescription, tc.tokenError.Error)
    }
    
    
    // oauth2Client groups the pieces of a test OAuth2 client application: its
    // configuration, the most recently obtained token, and the HTTP test server
    // that handles the redirect callback.
    type oauth2Client struct {
    	// config is used to build the auth code URL and to exchange codes for tokens.
    	config *oauth2.Config
    	// token holds the token obtained from the latest successful code exchange.
    	token *oauth2.Token
    	// server is the test HTTP server acting as the client application.
    	server *httptest.Server
    }
    
    // TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
    // that only valid refresh tokens can be used to refresh an expired token.
    func TestRefreshTokenFlow(t *testing.T) {
    	state := "state"
    
    m.nabokikh's avatar
    m.nabokikh committed
    	now := time.Now
    
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    	httpServer, s := newTestServer(ctx, t, func(c *Config) {
    		c.Now = now
    	})
    	defer httpServer.Close()
    
    	p, err := oidc.NewProvider(ctx, httpServer.URL)
    	if err != nil {
    		t.Fatalf("failed to get provider: %v", err)
    	}
    
    	var oauth2Client oauth2Client
    
    	oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		if r.URL.Path != "/callback" {
    			// User is visiting app first time. Redirect to dex.
    			http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
    			return
    		}
    
    		// User is at '/callback' so they were just redirected _from_ dex.
    		q := r.URL.Query()
    
    		if errType := q.Get("error"); errType != "" {
    			if desc := q.Get("error_description"); desc != "" {
    				t.Errorf("got error from server %s: %s", errType, desc)
    			} else {
    				t.Errorf("got error from server %s", errType)
    			}
    			w.WriteHeader(http.StatusInternalServerError)
    			return
    		}
    
    		// Grab code, exchange for token.
    		if code := q.Get("code"); code != "" {
    			token, err := oauth2Client.config.Exchange(ctx, code)
    			if err != nil {
    				t.Errorf("failed to exchange code for token: %v", err)
    				return
    			}
    			oauth2Client.token = token
    		}
    
    		// Ensure state matches.
    		if gotState := q.Get("state"); gotState != state {
    			t.Errorf("state did not match, want=%q got=%q", state, gotState)
    		}
    		w.WriteHeader(http.StatusOK)
    	}))
    	defer oauth2Client.server.Close()
    
    	// Register the client above with dex.
    	redirectURL := oauth2Client.server.URL + "/callback"
    	client := storage.Client{
    		ID:           "testclient",
    		Secret:       "testclientsecret",
    		RedirectURIs: []string{redirectURL},
    	}
    	if err := s.storage.CreateClient(client); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	oauth2Client.config = &oauth2.Config{
    		ClientID:     client.ID,
    		ClientSecret: client.Secret,
    		Endpoint:     p.Endpoint(),
    		Scopes:       []string{oidc.ScopeOpenID, "email", "offline_access"},
    		RedirectURL:  redirectURL,
    	}
    
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	resp, err := http.Get(oauth2Client.server.URL + "/login")
    	if err != nil {
    
    		t.Fatalf("get failed: %v", err)
    	}
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	defer resp.Body.Close()
    
    
    	tok := &oauth2.Token{
    		RefreshToken: oauth2Client.token.RefreshToken,
    		Expiry:       time.Now().Add(-time.Hour),
    	}
    
    
    	// Log in again to receive a new token.
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	resp, err = http.Get(oauth2Client.server.URL + "/login")
    	if err != nil {
    
    		t.Fatalf("get failed: %v", err)
    	}
    
    Mark Sagi-Kazar's avatar
    Mark Sagi-Kazar committed
    	defer resp.Body.Close()
    
    
    	// try to refresh expired token with old refresh token.
    
    	if _, err := oauth2Client.config.TokenSource(ctx, tok).Token(); err == nil {
    		t.Errorf("Token refreshed with invalid refresh token, error expected.")
    
    
    // TestOAuth2DeviceFlow runs device flow integration tests against a test server
    func TestOAuth2DeviceFlow(t *testing.T) {
    	clientID := "testclient"
    	clientSecret := ""
    	requestedScopes := []string{oidc.ScopeOpenID, "email", "profile", "groups", "offline_access"}
    
    	t0 := time.Now()
    
    	// Always have the time function used by the server return the same time so
    	// we can predict expected values of "expires_in" fields exactly.
    	now := func() time.Time { return t0 }
    
    	// Connector used by the tests.
    	var conn *mock.Callback
    	idTokensValidFor := time.Second * 30
    
    
    	tests := makeOAuth2Tests(clientID, clientSecret, now)
    	testCases := []struct {
    		name          string
    		tokenEndpoint string
    		oauth2Tests   oauth2Tests
    	}{
    		{
    			name:          "Actual token endpoint for devices",
    			tokenEndpoint: "/token",
    			oauth2Tests:   tests,
    		},
    		// TODO(nabokihms): delete temporary tests after removing the deprecated token endpoint support
    		{
    			name:          "Deprecated token endpoint for devices",
    			tokenEndpoint: "/device/token",
    			oauth2Tests:   tests,
    		},
    	}
    
    	for _, testCase := range testCases {
    		for _, tc := range testCase.oauth2Tests.tests {
    			t.Run(tc.name, func(t *testing.T) {
    				ctx, cancel := context.WithCancel(context.Background())
    				defer cancel()
    
    				// Setup a dex server.
    				httpServer, s := newTestServer(ctx, t, func(c *Config) {
    					c.Issuer += "/non-root-path"
    					c.Now = now
    					c.IDTokensValidFor = idTokensValidFor
    				})
    				defer httpServer.Close()
    
    				mockConn := s.connectors["mock"]
    				conn = mockConn.Connector.(*mock.Callback)
    
    				p, err := oidc.NewProvider(ctx, httpServer.URL)
    				if err != nil {
    					t.Fatalf("failed to get provider: %v", err)
    				}
    
    				// Add the Clients to the test server
    				client := storage.Client{
    					ID:           clientID,
    					RedirectURIs: []string{deviceCallbackURI},
    					Public:       true,
    				}
    				if err := s.storage.CreateClient(client); err != nil {
    					t.Fatalf("failed to create client: %v", err)
    				}
    
    				// Grab the issuer that we'll reuse for the different endpoints to hit
    				issuer, err := url.Parse(s.issuerURL.String())
    				if err != nil {
    					t.Errorf("Could not parse issuer URL %v", err)
    				}
    
    				// Send a new Device Request
    				codeURL, _ := url.Parse(issuer.String())
    				codeURL.Path = path.Join(codeURL.Path, "device/code")
    
    				data := url.Values{}
    				data.Set("client_id", clientID)
    				data.Add("scope", strings.Join(requestedScopes, " "))
    				resp, err := http.PostForm(codeURL.String(), data)
    				if err != nil {
    					t.Errorf("Could not request device code: %v", err)
    				}
    				defer resp.Body.Close()
    				responseBody, err := ioutil.ReadAll(resp.Body)
    				if err != nil {
    					t.Errorf("Could read device code response %v", err)
    				}
    				if resp.StatusCode != http.StatusOK {
    					t.Errorf("%v - Unexpected Response Type.  Expected 200 got  %v.  Response: %v", tc.name, resp.StatusCode, string(responseBody))
    				}
    				if resp.Header.Get("Cache-Control") != "no-store" {
    					t.Errorf("Cache-Control header doesn't exist in Device Code Response")
    				}
    
    				// Parse the code response
    				var deviceCode deviceCodeResponse
    				if err := json.Unmarshal(responseBody, &deviceCode); err != nil {
    					t.Errorf("Unexpected Device Code Response Format %v", string(responseBody))
    				}
    
    				// Mock the user hitting the verification URI and posting the form
    				verifyURL, _ := url.Parse(issuer.String())
    				verifyURL.Path = path.Join(verifyURL.Path, "/device/auth/verify_code")
    				urlData := url.Values{}
    				urlData.Set("user_code", deviceCode.UserCode)
    				resp, err = http.PostForm(verifyURL.String(), urlData)
    				if err != nil {
    					t.Errorf("Error Posting Form: %v", err)
    				}
    				defer resp.Body.Close()
    				responseBody, err = ioutil.ReadAll(resp.Body)
    				if err != nil {
    					t.Errorf("Could read verification response %v", err)
    				}
    				if resp.StatusCode != http.StatusOK {
    					t.Errorf("%v - Unexpected Response Type.  Expected 200 got  %v.  Response: %v", tc.name, resp.StatusCode, string(responseBody))
    				}
    
    				// Hit the Token Endpoint, and try and get an access token
    				tokenURL, _ := url.Parse(issuer.String())
    				tokenURL.Path = path.Join(tokenURL.Path, testCase.tokenEndpoint)
    				v := url.Values{}
    				v.Add("grant_type", grantTypeDeviceCode)
    				v.Add("device_code", deviceCode.DeviceCode)
    				resp, err = http.PostForm(tokenURL.String(), v)
    				if err != nil {
    					t.Errorf("Could not request device token: %v", err)
    				}
    				defer resp.Body.Close()
    				responseBody, err = ioutil.ReadAll(resp.Body)
    				if err != nil {
    					t.Errorf("Could read device token response %v", err)
    				}
    				if resp.StatusCode != http.StatusOK {
    					t.Errorf("%v - Unexpected Token Response Type.  Expected 200 got  %v.  Response: %v", tc.name, resp.StatusCode, string(responseBody))
    				}
    
    				// Parse the response
    				var tokenRes accessTokenResponse
    				if err := json.Unmarshal(responseBody, &tokenRes); err != nil {
    					t.Errorf("Unexpected Device Access Token Response Format %v", string(responseBody))
    				}
    
    				token := &oauth2.Token{
    					AccessToken:  tokenRes.AccessToken,
    					TokenType:    tokenRes.TokenType,
    					RefreshToken: tokenRes.RefreshToken,
    				}
    				raw := make(map[string]interface{})
    				json.Unmarshal(responseBody, &raw) // no error checks for optional fields
    				token = token.WithExtra(raw)
    				if secs := tokenRes.ExpiresIn; secs > 0 {
    					token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
    				}
    
    				// Run token tests to validate info is correct
    				// Create the OAuth2 config.
    				oauth2Config := &oauth2.Config{
    					ClientID:     client.ID,
    					ClientSecret: client.Secret,
    					Endpoint:     p.Endpoint(),
    					Scopes:       requestedScopes,
    					RedirectURL:  deviceCallbackURI,
    				}
    				if len(tc.scopes) != 0 {
    					oauth2Config.Scopes = tc.scopes
    				}
    				err = tc.handleToken(ctx, p, oauth2Config, token, conn)
    				if err != nil {
    					t.Errorf("%s: %v", tc.name, err)
    				}
    			})
    		}