Skip to content
Snippets Groups Projects
server_test.go 33.3 KiB
Newer Older
  • Learn to ignore specific revisions
  • 				t.Errorf("%s: expected valid password", tc.name)
    			}
    			continue
    		}
    
    		if tc.wantInvalid {
    			t.Errorf("%s: expected invalid password", tc.name)
    			continue
    		}
    
    		if diff := pretty.Compare(tc.wantIdentity, ident); diff != "" {
    			t.Errorf("%s: %s", tc.name, diff)
    		}
    	}
    
    }
    
    
    func TestPasswordDBUsernamePrompt(t *testing.T) {
    	s := memory.New(logger)
    	conn := newPasswordDB(s)
    
    	expected := "Email Address"
    	if actual := conn.Prompt(); actual != expected {
    		t.Errorf("expected %v, got %v", expected, actual)
    	}
    }
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    // storageWithKeysTrigger wraps a storage.Storage and runs a hook every time
    // GetKeys is called, before delegating to the wrapped storage. It lets a test
    // observe whether a wrapper around it actually reached the underlying storage.
    type storageWithKeysTrigger struct {
    	// Storage is the wrapped backend; every method other than GetKeys is
    	// promoted from it unchanged.
    	storage.Storage
    	// f is invoked at the start of each GetKeys call.
    	f func()
    }
    
    // GetKeys invokes the trigger hook and then returns the wrapped storage's keys.
    func (st storageWithKeysTrigger) GetKeys() (storage.Keys, error) {
    	st.f()
    	keys, err := st.Storage.GetKeys()
    	return keys, err
    }
    
    func TestKeyCacher(t *testing.T) {
    	tNow := time.Now()
    	now := func() time.Time { return tNow }
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    
    	tests := []struct {
    		before            func()
    		wantCallToStorage bool
    	}{
    		{
    			before:            func() {},
    			wantCallToStorage: true,
    		},
    		{
    			before: func() {
    				s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
    					old.NextRotation = tNow.Add(time.Minute)
    					return old, nil
    				})
    			},
    			wantCallToStorage: true,
    		},
    		{
    			before:            func() {},
    			wantCallToStorage: false,
    		},
    		{
    			before: func() {
    				tNow = tNow.Add(time.Hour)
    			},
    			wantCallToStorage: true,
    		},
    		{
    			before: func() {
    				tNow = tNow.Add(time.Hour)
    				s.UpdateKeys(func(old storage.Keys) (storage.Keys, error) {
    					old.NextRotation = tNow.Add(time.Minute)
    					return old, nil
    				})
    			},
    			wantCallToStorage: true,
    		},
    		{
    			before:            func() {},
    			wantCallToStorage: false,
    		},
    	}
    
    	gotCall := false
    	s = newKeyCacher(storageWithKeysTrigger{s, func() { gotCall = true }}, now)
    	for i, tc := range tests {
    		gotCall = false
    		tc.before()
    		s.GetKeys()
    		if gotCall != tc.wantCallToStorage {
    			t.Errorf("case %d: expected call to storage=%t got call to storage=%t", i, tc.wantCallToStorage, gotCall)
    		}
    	}
    }
    
    
    // oauth2Client bundles the state of a test OAuth2 relying party: its client
    // configuration, the most recent token it obtained, and the HTTP server that
    // hosts its redirect and callback handler.
    type oauth2Client struct {
    	// config is the client's OAuth2 configuration.
    	config *oauth2.Config
    	// token is the latest token stored by the callback handler.
    	token *oauth2.Token
    	// server is the HTTP server serving the client's endpoints.
    	server *httptest.Server
    }
    
    // TestRefreshTokenFlow tests the refresh token code flow for oauth2. The test verifies
    // that only valid refresh tokens can be used to refresh an expired token.
    //
    // It registers a test client, logs in once to capture a refresh token, logs
    // in a second time, and then expects refreshing an expired token with the
    // refresh token from the first login to fail.
    func TestRefreshTokenFlow(t *testing.T) {
    	state := "state"
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    	httpServer, s := newTestServer(ctx, t, func(c *Config) {
    		c.Now = time.Now
    	})
    	defer httpServer.Close()
    
    	p, err := oidc.NewProvider(ctx, httpServer.URL)
    	if err != nil {
    		t.Fatalf("failed to get provider: %v", err)
    	}
    
    	var oauth2Client oauth2Client
    
    	oauth2Client.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
    		if r.URL.Path != "/callback" {
    			// User is visiting app first time. Redirect to dex.
    			http.Redirect(w, r, oauth2Client.config.AuthCodeURL(state), http.StatusSeeOther)
    			return
    		}
    
    		// User is at '/callback' so they were just redirected _from_ dex.
    		q := r.URL.Query()
    
    		if errType := q.Get("error"); errType != "" {
    			if desc := q.Get("error_description"); desc != "" {
    				t.Errorf("got error from server %s: %s", errType, desc)
    			} else {
    				t.Errorf("got error from server %s", errType)
    			}
    			w.WriteHeader(http.StatusInternalServerError)
    			return
    		}
    
    		// Grab code, exchange for token.
    		if code := q.Get("code"); code != "" {
    			token, err := oauth2Client.config.Exchange(ctx, code)
    			if err != nil {
    				t.Errorf("failed to exchange code for token: %v", err)
    				return
    			}
    			oauth2Client.token = token
    		}
    
    		// Ensure state matches.
    		if gotState := q.Get("state"); gotState != state {
    			t.Errorf("state did not match, want=%q got=%q", state, gotState)
    		}
    		w.WriteHeader(http.StatusOK)
    	}))
    	defer oauth2Client.server.Close()
    
    	// Register the client above with dex.
    	redirectURL := oauth2Client.server.URL + "/callback"
    	client := storage.Client{
    		ID:           "testclient",
    		Secret:       "testclientsecret",
    		RedirectURIs: []string{redirectURL},
    	}
    	if err := s.storage.CreateClient(client); err != nil {
    		t.Fatalf("failed to create client: %v", err)
    	}
    
    	oauth2Client.config = &oauth2.Config{
    		ClientID:     client.ID,
    		ClientSecret: client.Secret,
    		Endpoint:     p.Endpoint(),
    		Scopes:       []string{oidc.ScopeOpenID, "email", "offline_access"},
    		RedirectURL:  redirectURL,
    	}
    
    	// login requests the client app's /login page. The handler redirects to
    	// dex; http.Get follows redirects, and the /callback handler stores any
    	// token it receives in oauth2Client.token. The response body is closed so
    	// the underlying connection is released.
    	login := func() {
    		resp, err := http.Get(oauth2Client.server.URL + "/login")
    		if err != nil {
    			t.Fatalf("get failed: %v", err)
    		}
    		resp.Body.Close()
    	}
    
    	login()
    	// Guard against a nil dereference if the callback never stored a token.
    	if oauth2Client.token == nil {
    		t.Fatalf("no token received from initial login")
    	}
    
    	tok := &oauth2.Token{
    		RefreshToken: oauth2Client.token.RefreshToken,
    		Expiry:       time.Now().Add(-time.Hour),
    	}
    
    	// Log in again to receive a new token.
    	login()
    
    	// Try to refresh the expired token with the old refresh token; this must fail.
    	if _, err := oauth2Client.config.TokenSource(ctx, tok).Token(); err == nil {
    		t.Errorf("Token refreshed with invalid refresh token.")
    	}
    }