Skip to content
Snippets Groups Projects
Commit 9c699b10 authored by Justin Slowik's avatar Justin Slowik Committed by justin-slowik
Browse files

Server integration test for Device Flow (#3)


Extracted test cases from OAuth2Code flow tests to reuse in device flow

deviceHandler unit tests to test specific device endpoints

Include client secret as an optional parameter for standards compliance

Signed-off-by: default avatarjustin-slowik <justin.slowik@thermofisher.com>
parent 9bbdc721
Branches
Tags
No related merge requests found
......@@ -29,7 +29,7 @@ type deviceCodeResponse struct {
PollInterval int `json:"interval"`
}
func (s *Server) getDeviceAuthURI() string {
func (s *Server) getDeviceVerificationURI() string {
return path.Join(s.issuerURL.Path, "/device/auth/verify_code")
}
......@@ -41,8 +41,9 @@ func (s *Server) handleDeviceExchange(w http.ResponseWriter, r *http.Request) {
if err != nil {
invalidAttempt = false
}
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, invalidAttempt); err != nil {
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, invalidAttempt); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
s.renderError(r, w, http.StatusBadRequest, "Requested resource does not exist.")
......@@ -63,7 +64,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
//Get the client id and scopes from the post
clientID := r.Form.Get("client_id")
scopes := r.Form["scope"]
clientSecret := r.Form.Get("client_secret")
scopes := strings.Fields(r.Form.Get("scope"))
s.logger.Infof("Received device request for client %v with scopes %v", clientID, scopes)
......@@ -82,11 +84,12 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
//Store the Device Request
deviceReq := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
Scopes: scopes,
Expiry: expireTime,
UserCode: userCode,
DeviceCode: deviceCode,
ClientID: clientID,
ClientSecret: clientSecret,
Scopes: scopes,
Expiry: expireTime,
}
if err := s.storage.CreateDeviceRequest(deviceReq); err != nil {
......@@ -100,8 +103,8 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
DeviceCode: deviceCode,
Status: deviceTokenPending,
Expiry: expireTime,
LastRequestTime: time.Now(),
PollIntervalSeconds: 5,
LastRequestTime: s.now(),
PollIntervalSeconds: 0,
}
if err := s.storage.CreateDeviceToken(deviceToken); err != nil {
......@@ -113,7 +116,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
u, err := url.Parse(s.issuerURL.String())
if err != nil {
s.logger.Errorf("Could not parse issuer URL %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
s.tokenErrHelper(w, errInvalidRequest, "", http.StatusInternalServerError)
return
}
u.Path = path.Join(u.Path, "device")
......@@ -134,6 +137,7 @@ func (s *Server) handleDeviceCode(w http.ResponseWriter, r *http.Request) {
}
enc := json.NewEncoder(w)
enc.SetEscapeHTML(false)
enc.SetIndent("", " ")
enc.Encode(code)
......@@ -168,21 +172,25 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
now := s.now()
//Grab the device token
//Grab the device token, check validity
deviceToken, err := s.storage.GetDeviceToken(deviceCode)
if err != nil || now.After(deviceToken.Expiry) {
if err != nil {
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
}
s.tokenErrHelper(w, errInvalidRequest, "Invalid or expired device code.", http.StatusBadRequest)
s.tokenErrHelper(w, errInvalidRequest, "Invalid Device code.", http.StatusBadRequest)
return
} else if now.After(deviceToken.Expiry) {
s.tokenErrHelper(w, deviceTokenExpired, "", http.StatusBadRequest)
return
}
//Rate Limiting check
slowDown := false
pollInterval := deviceToken.PollIntervalSeconds
minRequestTime := deviceToken.LastRequestTime.Add(time.Second * time.Duration(pollInterval))
if now.Before(minRequestTime) {
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
slowDown = true
//Continually increase the poll interval until the user waits the proper time
pollInterval += 5
} else {
......@@ -202,7 +210,11 @@ func (s *Server) handleDeviceToken(w http.ResponseWriter, r *http.Request) {
s.renderError(r, w, http.StatusInternalServerError, "")
return
}
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
if slowDown {
s.tokenErrHelper(w, deviceTokenSlowDown, "", http.StatusBadRequest)
} else {
s.tokenErrHelper(w, deviceTokenPending, "", http.StatusUnauthorized)
}
case deviceTokenComplete:
w.Write([]byte(deviceToken.Token))
}
......@@ -230,44 +242,58 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) {
if err != storage.ErrNotFound {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get auth code: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, http.StatusBadRequest, "Invalid or expired auth code.")
s.renderError(r, w, errCode, "Invalid or expired auth code.")
return
}
//Grab the device request from storage
deviceReq, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceReq.Expiry) {
if err != storage.ErrNotFound {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device code: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
s.renderError(r, w, errCode, "Invalid or expired user code.")
return
}
reqClient, err := s.storage.GetClient(deviceReq.ClientID)
client, err := s.storage.GetClient(deviceReq.ClientID)
if err != nil {
s.logger.Errorf("Failed to get reqClient %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to retrieve device client.")
if err != storage.ErrNotFound {
s.logger.Errorf("failed to get client: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
} else {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
}
return
}
if client.Secret != deviceReq.ClientSecret {
s.tokenErrHelper(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
resp, err := s.exchangeAuthCode(w, authCode, reqClient)
resp, err := s.exchangeAuthCode(w, authCode, client)
if err != nil {
s.logger.Errorf("Could not exchange auth code for client %q: %v", deviceReq.ClientID, err)
s.renderError(r, w, http.StatusInternalServerError, "Failed to exchange auth code.")
return
}
//Grab the device request from storage
//Grab the device token from storage
old, err := s.storage.GetDeviceToken(deviceReq.DeviceCode)
if err != nil || s.now().After(old.Expiry) {
if err != storage.ErrNotFound {
errCode := http.StatusBadRequest
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device token: %v", err)
errCode = http.StatusInternalServerError
}
s.renderError(r, w, http.StatusInternalServerError, "Invalid or expired device code.")
s.renderError(r, w, errCode, "Invalid or expired device code.")
return
}
......@@ -290,12 +316,13 @@ func (s *Server) handleDeviceCallback(w http.ResponseWriter, r *http.Request) {
// Update refresh token in the storage, store the token and mark as complete
if err := s.storage.UpdateDeviceToken(deviceReq.DeviceCode, updater); err != nil {
s.logger.Errorf("failed to update device token: %v", err)
s.renderError(r, w, http.StatusInternalServerError, "")
s.renderError(r, w, http.StatusBadRequest, "")
return
}
if err := s.templates.deviceSuccess(r, w, reqClient.Name); err != nil {
if err := s.templates.deviceSuccess(r, w, client.Name); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
default:
......@@ -309,9 +336,8 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
case http.MethodPost:
err := r.ParseForm()
if err != nil {
message := "Could not parse user code verification Request body"
s.logger.Warnf("%s : %v", message, err)
s.tokenErrHelper(w, errInvalidRequest, message, http.StatusBadRequest)
s.logger.Warnf("Could not parse user code verification request body : %v", err)
s.renderError(r, w, http.StatusBadRequest, "")
return
}
......@@ -326,12 +352,12 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
//Find the user code in the available requests
deviceRequest, err := s.storage.GetDeviceRequest(userCode)
if err != nil || s.now().After(deviceRequest.Expiry) {
if err != storage.ErrNotFound {
if err != nil && err != storage.ErrNotFound {
s.logger.Errorf("failed to get device request: %v", err)
s.tokenErrHelper(w, errServerError, "", http.StatusInternalServerError)
}
if err := s.templates.device(r, w, s.getDeviceAuthURI(), userCode, true); err != nil {
if err := s.templates.device(r, w, s.getDeviceVerificationURI(), userCode, true); err != nil {
s.logger.Errorf("Server template error: %v", err)
s.renderError(r, w, http.StatusNotFound, "Page not found")
}
return
}
......@@ -345,6 +371,7 @@ func (s *Server) verifyUserCode(w http.ResponseWriter, r *http.Request) {
}
q := u.Query()
q.Set("client_id", deviceRequest.ClientID)
q.Set("client_secret", deviceRequest.ClientSecret)
q.Set("state", deviceRequest.UserCode)
q.Set("response_type", "code")
q.Set("redirect_uri", path.Join(s.issuerURL.Path, "/device/callback"))
......
This diff is collapsed.
......@@ -15,11 +15,12 @@ import (
"time"
oidc "github.com/coreos/go-oidc"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
"github.com/dexidp/dex/connector"
"github.com/dexidp/dex/server/internal"
"github.com/dexidp/dex/storage"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
)
// newHealthChecker returns the healthz handler. The handler runs until the
......@@ -153,7 +154,7 @@ type discovery struct {
Keys string `json:"jwks_uri"`
UserInfo string `json:"userinfo_endpoint"`
DeviceEndpoint string `json:"device_authorization_endpoint"`
GrantTypes []string `json:"grant_types_supported"'`
GrantTypes []string `json:"grant_types_supported"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
......@@ -1381,18 +1382,10 @@ func (s *Server) handlePasswordGrant(w http.ResponseWriter, r *http.Request, cli
}
}
s.writeAccessToken(w, idToken, accessToken, refreshToken, expiry)
resp := s.toAccessTokenResponse(idToken, accessToken, refreshToken, expiry)
s.writeAccessToken(w, resp)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, accessToken, refreshToken string, expiry time.Time) {
resp := struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
}{
type accessTokenReponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
......
......@@ -81,7 +81,7 @@ type Config struct {
DeviceRequestsValidFor time.Duration // Defaults to 5 minutes
// If set, the server will use this connector to handle password grants
PasswordConnector string
GCFrequency time.Duration // Defaults to 5 minutes
// If specified, the server will use this function for determining time.
......
This diff is collapsed.
......@@ -250,6 +250,9 @@ func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name }
func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (t *templates) device(r *http.Request, w http.ResponseWriter, postURL string, userCode string, lastWasInvalid bool) error {
if lastWasInvalid {
w.WriteHeader(http.StatusBadRequest)
}
data := struct {
PostURL string
UserCode string
......
......@@ -843,11 +843,12 @@ func testGC(t *testing.T, s storage.Storage) {
}
d := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expiry: expiry,
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
ClientSecret: "secret1",
Scopes: []string{"openid", "email"},
Expiry: expiry,
}
if err := s.CreateDeviceRequest(d); err != nil {
......@@ -863,9 +864,9 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected no device garbage collection results, got %#v", result)
}
}
//if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
// t.Errorf("expected to be able to get auth request after GC: %v", err)
//}
if _, err := s.GetDeviceRequest(d.UserCode); err != nil {
t.Errorf("expected to be able to get auth request after GC: %v", err)
}
}
if r, err := s.GarbageCollect(expiry.Add(time.Hour)); err != nil {
t.Errorf("garbage collection failed: %v", err)
......@@ -873,18 +874,19 @@ func testGC(t *testing.T, s storage.Storage) {
t.Errorf("expected to garbage collect 1 device request, got %d", r.DeviceRequests)
}
//TODO add this code back once Getters are written for device requests
//if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
// t.Errorf("expected device request to be GC'd")
//} else if err != storage.ErrNotFound {
// t.Errorf("expected storage.ErrNotFound, got %v", err)
//}
if _, err := s.GetDeviceRequest(d.UserCode); err == nil {
t.Errorf("expected device request to be GC'd")
} else if err != storage.ErrNotFound {
t.Errorf("expected storage.ErrNotFound, got %v", err)
}
dt := storage.DeviceToken{
DeviceCode: storage.NewID(),
Status: "pending",
Token: "foo",
Expiry: expiry,
DeviceCode: storage.NewID(),
Status: "pending",
Token: "foo",
Expiry: expiry,
LastRequestTime: time.Now(),
PollIntervalSeconds: 0,
}
if err := s.CreateDeviceToken(dt); err != nil {
......@@ -969,11 +971,12 @@ func testDeviceRequestCRUD(t *testing.T, s storage.Storage) {
panic(err)
}
d1 := storage.DeviceRequest{
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
Scopes: []string{"openid", "email"},
Expiry: neverExpire,
UserCode: userCode,
DeviceCode: storage.NewID(),
ClientID: "client1",
ClientSecret: "secret1",
Scopes: []string{"openid", "email"},
Expiry: neverExpire,
}
if err := s.CreateDeviceRequest(d1); err != nil {
......
......@@ -595,7 +595,7 @@ func (c *conn) listDeviceRequests(ctx context.Context) (requests []DeviceRequest
func (c *conn) CreateDeviceToken(t storage.DeviceToken) error {
ctx, cancel := context.WithTimeout(context.Background(), defaultStorageTimeout)
defer cancel()
return c.txnCreate(ctx, keyID(deviceRequestPrefix, t.DeviceCode), fromStorageDeviceToken(t))
return c.txnCreate(ctx, keyID(deviceTokenPrefix, t.DeviceCode), fromStorageDeviceToken(t))
}
func (c *conn) GetDeviceToken(deviceCode string) (t storage.DeviceToken, err error) {
......
......@@ -44,6 +44,8 @@ func cleanDB(c *conn) error {
passwordPrefix,
offlineSessionPrefix,
connectorPrefix,
deviceRequestPrefix,
deviceTokenPrefix,
} {
_, err := c.db.Delete(ctx, prefix, clientv3.WithPrefix())
if err != nil {
......
......@@ -219,20 +219,22 @@ func toStorageOfflineSessions(o OfflineSessions) storage.OfflineSessions {
// DeviceRequest is a mirrored struct from storage with JSON struct tags
type DeviceRequest struct {
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
Scopes []string `json:"scopes"`
Expiry time.Time `json:"expiry"`
UserCode string `json:"user_code"`
DeviceCode string `json:"device_code"`
ClientID string `json:"client_id"`
ClientSecret string `json:"client_secret"`
Scopes []string `json:"scopes"`
Expiry time.Time `json:"expiry"`
}
func fromStorageDeviceRequest(d storage.DeviceRequest) DeviceRequest {
return DeviceRequest{
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
Scopes: d.Scopes,
Expiry: d.Expiry,
UserCode: d.UserCode,
DeviceCode: d.DeviceCode,
ClientID: d.ClientID,
ClientSecret: d.ClientSecret,
Scopes: d.Scopes,
Expiry: d.Expiry,
}
}
......
......@@ -672,10 +672,11 @@ type DeviceRequest struct {
k8sapi.TypeMeta `json:",inline"`
k8sapi.ObjectMeta `json:"metadata,omitempty"`
DeviceCode string `json:"device_code,omitempty"`
CLientID string `json:"client_id,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Expiry time.Time `json:"expiry"`
DeviceCode string `json:"device_code,omitempty"`
ClientID string `json:"client_id,omitempty"`
ClientSecret string `json:"client_secret,omitempty"`
Scopes []string `json:"scopes,omitempty"`
Expiry time.Time `json:"expiry"`
}
// AuthRequestList is a list of AuthRequests.
......@@ -695,21 +696,23 @@ func (cli *client) fromStorageDeviceRequest(a storage.DeviceRequest) DeviceReque
Name: strings.ToLower(a.UserCode),
Namespace: cli.namespace,
},
DeviceCode: a.DeviceCode,
CLientID: a.ClientID,
Scopes: a.Scopes,
Expiry: a.Expiry,
DeviceCode: a.DeviceCode,
ClientID: a.ClientID,
ClientSecret: a.ClientSecret,
Scopes: a.Scopes,
Expiry: a.Expiry,
}
return req
}
func toStorageDeviceRequest(req DeviceRequest) storage.DeviceRequest {
return storage.DeviceRequest{
UserCode: strings.ToUpper(req.ObjectMeta.Name),
DeviceCode: req.DeviceCode,
ClientID: req.CLientID,
Scopes: req.Scopes,
Expiry: req.Expiry,
UserCode: strings.ToUpper(req.ObjectMeta.Name),
DeviceCode: req.DeviceCode,
ClientID: req.ClientID,
ClientSecret: req.ClientSecret,
Scopes: req.Scopes,
Expiry: req.Expiry,
}
}
......
......@@ -888,12 +888,12 @@ func (c *conn) delete(table, field, id string) error {
func (c *conn) CreateDeviceRequest(d storage.DeviceRequest) error {
_, err := c.Exec(`
insert into device_request (
user_code, device_code, client_id, scopes, expiry
user_code, device_code, client_id, client_secret, scopes, expiry
)
values (
$1, $2, $3, $4, $5
$1, $2, $3, $4, $5, $6
);`,
d.UserCode, d.DeviceCode, d.ClientID, encoder(d.Scopes), d.Expiry,
d.UserCode, d.DeviceCode, d.ClientID, d.ClientSecret, encoder(d.Scopes), d.Expiry,
)
if err != nil {
if c.alreadyExistsCheck(err) {
......@@ -930,10 +930,10 @@ func (c *conn) GetDeviceRequest(userCode string) (storage.DeviceRequest, error)
func getDeviceRequest(q querier, userCode string) (d storage.DeviceRequest, err error) {
err = q.QueryRow(`
select
device_code, client_id, scopes, expiry
device_code, client_id, client_secret, scopes, expiry
from device_request where user_code = $1;
`, userCode).Scan(
&d.DeviceCode, &d.ClientID, decoder(&d.Scopes), &d.Expiry,
&d.DeviceCode, &d.ClientID, &d.ClientSecret, decoder(&d.Scopes), &d.Expiry,
)
if err != nil {
if err == sql.ErrNoRows {
......
......@@ -235,6 +235,7 @@ var migrations = []migration{
user_code text not null primary key,
device_code text not null,
client_id text not null,
client_secret text ,
scopes bytea not null, -- JSON array of strings
expiry timestamptz not null
);`,
......
......@@ -392,6 +392,8 @@ type DeviceRequest struct {
DeviceCode string
//The client ID the code is for
ClientID string
//The Client Secret
ClientSecret string
//The scopes the device requests
Scopes []string
//The expire time
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment