Newer
Older
"fmt"
"log"
"net/http"
"net/url"
"path"
"strconv"
"strings"
"time"
"github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2"
"github.com/coreos/dex/connector"
"github.com/coreos/dex/storage"
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
)
func (s *Server) handlePublicKeys(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Cache this.
keys, err := s.storage.GetKeys()
if err != nil {
log.Printf("failed to get keys: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
if keys.SigningKeyPub == nil {
log.Printf("No public keys found.")
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
jwks := jose.JSONWebKeySet{
Keys: make([]jose.JSONWebKey, len(keys.VerificationKeys)+1),
}
jwks.Keys[0] = *keys.SigningKeyPub
for i, verificationKey := range keys.VerificationKeys {
jwks.Keys[i+1] = *verificationKey.PublicKey
}
data, err := json.MarshalIndent(jwks, "", " ")
if err != nil {
log.Printf("failed to marshal discovery data: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
maxAge := keys.NextRotation.Sub(s.now())
if maxAge < (time.Minute * 2) {
maxAge = time.Minute * 2
}
w.Header().Set("Cache-Control", fmt.Sprintf("max-age=%d, must-revalidate", maxAge))
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
type discovery struct {
Issuer string `json:"issuer"`
Auth string `json:"authorization_endpoint"`
Token string `json:"token_endpoint"`
Keys string `json:"jwks_uri"`
ResponseTypes []string `json:"response_types_supported"`
Subjects []string `json:"subject_types_supported"`
IDTokenAlgs []string `json:"id_token_signing_alg_values_supported"`
Scopes []string `json:"scopes_supported"`
AuthMethods []string `json:"token_endpoint_auth_methods_supported"`
Claims []string `json:"claims_supported"`
}
func (s *Server) handleDiscovery(w http.ResponseWriter, r *http.Request) {
// TODO(ericchiang): Cache this
d := discovery{
Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"),
Token: s.absURL("/token"),
Keys: s.absURL("/keys"),
ResponseTypes: []string{"code"},
Subjects: []string{"public"},
IDTokenAlgs: []string{string(jose.RS256)},
Scopes: []string{"openid", "email", "profile", "offline_access"},
AuthMethods: []string{"client_secret_basic"},
Claims: []string{
"aud", "email", "email_verified", "exp",
"iat", "iss", "locale", "name", "sub",
},
}
data, err := json.MarshalIndent(d, "", " ")
if err != nil {
log.Printf("failed to marshal discovery data: %v", err)
http.Error(w, "Internal server error", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
// handleAuthorization handles the OAuth2 auth endpoint.
func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) {
authReq, err := parseAuthorizationRequest(s.storage, s.supportedResponseTypes, r)
if err != nil {
s.renderError(w, http.StatusInternalServerError, err.Type, err.Description)
return
}
if err := s.storage.CreateAuthRequest(authReq); err != nil {
log.Printf("Failed to create authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
state := authReq.ID
if len(s.connectors) == 1 {
for id := range s.connectors {
http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound)
return
}
}
connectorInfos := make([]connectorInfo, len(s.connectors))
i := 0
for id := range s.connectors {
connectorInfos[i] = connectorInfo{
DisplayName: id,
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
}
i++
}
renderLoginOptions(w, connectorInfos, state)
}
func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) {
connID := mux.Vars(r)["connector"]
conn, ok := s.connectors[connID]
if !ok {
s.notFound(w, r)
return
}
// TODO(ericchiang): cache user identity.
state := r.FormValue("state")
switch r.Method {
case "GET":
switch conn := conn.Connector.(type) {
case connector.CallbackConnector:
callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state)
if err != nil {
log.Printf("Connector %q returned error when creating callback: %v", connID, err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
http.Redirect(w, r, callbackURL, http.StatusFound)
case connector.PasswordConnector:
renderPasswordTmpl(w, state, r.URL.String(), "")
default:
s.notFound(w, r)
}
case "POST":
passwordConnector, ok := conn.Connector.(connector.PasswordConnector)
if !ok {
s.notFound(w, r)
return
}
username := r.FormValue("username")
password := r.FormValue("password")
identity, ok, err := passwordConnector.Login(username, password)
if err != nil {
log.Printf("Failed to login user: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if !ok {
renderPasswordTmpl(w, state, r.URL.String(), "Invalid credentials")
return
}
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
log.Printf("Failed to finalize login: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
default:
s.notFound(w, r)
}
}
func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) {
connID := mux.Vars(r)["connector"]
conn, ok := s.connectors[connID]
if !ok {
s.notFound(w, r)
return
}
callbackConnector, ok := conn.Connector.(connector.CallbackConnector)
if !ok {
s.notFound(w, r)
return
}
identity, state, err := callbackConnector.HandleCallback(r)
if err != nil {
log.Printf("Failed to authenticate: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector)
log.Printf("Failed to finalize login: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
http.Redirect(w, r, redirectURL, http.StatusSeeOther)
func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) {
if authReqID == "" {
return "", errors.New("no auth request ID passed")
}
UserID: identity.UserID,
Username: identity.Username,
Email: identity.Email,
EmailVerified: identity.EmailVerified,
}
groupsConn, ok := conn.(connector.GroupsConnector)
if ok {
authReq, err := s.storage.GetAuthRequest(authReqID)
if err != nil {
return "", fmt.Errorf("get auth request: %v", err)
}
reqGroups := func() bool {
for _, scope := range authReq.Scopes {
if scope == scopeGroups {
return true
}
}
return false
}()
if reqGroups {
if claims.Groups, err = groupsConn.Groups(identity); err != nil {
return "", fmt.Errorf("getting groups: %v", err)
updater := func(a storage.AuthRequest) (storage.AuthRequest, error) {
a.ConnectorID = connectorID
a.ConnectorData = identity.ConnectorData
return a, nil
}
if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil {
return "", fmt.Errorf("failed to update auth request: %v", err)
}
return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil
}
func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) {
authReq, err := s.storage.GetAuthRequest(r.FormValue("state"))
if err != nil {
log.Printf("Failed to get auth request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
log.Printf("Auth request does not have an identity for approval")
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
switch r.Method {
case "GET":
if s.skipApproval {
return
}
client, err := s.storage.GetClient(authReq.ClientID)
if err != nil {
log.Printf("Failed to get client %q: %v", authReq.ClientID, err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
renderApprovalTmpl(w, authReq.ID, *authReq.Claims, client, authReq.Scopes)
case "POST":
if r.FormValue("approval") != "approve" {
s.renderError(w, http.StatusInternalServerError, "approval rejected", "")
return
}
func (s *Server) sendCodeResponse(w http.ResponseWriter, r *http.Request, authReq storage.AuthRequest) {
if authReq.Expiry.After(s.now()) {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request period has expired.")
return
}
if err := s.storage.DeleteAuthRequest(authReq.ID); err != nil {
if err != storage.ErrNotFound {
log.Printf("Failed to delete authorization request: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
} else {
s.renderError(w, http.StatusBadRequest, errInvalidRequest, "Authorization request has already been completed.")
}
return
}
u, err := url.Parse(authReq.RedirectURI)
if err != nil {
s.renderError(w, http.StatusInternalServerError, errServerError, "Invalid redirect URI.")
return
}
q := u.Query()
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
for _, responseType := range authReq.ResponseTypes {
switch responseType {
case responseTypeCode:
code := storage.AuthCode{
ID: storage.NewID(),
ClientID: authReq.ClientID,
ConnectorID: authReq.ConnectorID,
Nonce: authReq.Nonce,
Scopes: authReq.Scopes,
Claims: *authReq.Claims,
Expiry: s.now().Add(time.Minute * 5),
RedirectURI: authReq.RedirectURI,
}
if err := s.storage.CreateAuthCode(code); err != nil {
log.Printf("Failed to create auth code: %v", err)
s.renderError(w, http.StatusInternalServerError, errServerError, "")
return
}
if authReq.RedirectURI == redirectURIOOB {
// TODO(ericchiang): Add a proper template.
fmt.Fprintf(w, "Code: %s", code.ID)
return
}
q.Set("code", code.ID)
case responseTypeToken:
idToken, expiry, err := s.newIDToken(authReq.ClientID, *authReq.Claims, authReq.Scopes, authReq.Nonce)
if err != nil {
log.Printf("failed to create ID token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
v := url.Values{}
v.Set("access_token", storage.NewID())
v.Set("token_type", "bearer")
v.Set("id_token", idToken)
v.Set("state", authReq.State)
v.Set("expires_in", strconv.Itoa(int(expiry.Sub(s.now()))))
u.Fragment = v.Encode()
}
}
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
q.Set("state", authReq.State)
u.RawQuery = q.Encode()
http.Redirect(w, r, u.String(), http.StatusSeeOther)
}
func (s *Server) handleToken(w http.ResponseWriter, r *http.Request) {
clientID, clientSecret, ok := r.BasicAuth()
if ok {
var err error
if clientID, err = url.QueryUnescape(clientID); err != nil {
tokenErr(w, errInvalidRequest, "client_id improperly encoded", http.StatusBadRequest)
return
}
if clientSecret, err = url.QueryUnescape(clientSecret); err != nil {
tokenErr(w, errInvalidRequest, "client_secret improperly encoded", http.StatusBadRequest)
return
}
} else {
clientID = r.PostFormValue("client_id")
clientSecret = r.PostFormValue("client_secret")
}
client, err := s.storage.GetClient(clientID)
if err != nil {
if err != storage.ErrNotFound {
log.Printf("failed to get client: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
}
return
}
if client.Secret != clientSecret {
tokenErr(w, errInvalidClient, "Invalid client credentials.", http.StatusUnauthorized)
return
}
grantType := r.PostFormValue("grant_type")
switch grantType {
case grantTypeAuthorizationCode:
case grantTypeRefreshToken:
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
s.handleRefreshToken(w, r, client)
default:
tokenErr(w, errInvalidGrant, "", http.StatusBadRequest)
}
}
// handle an access token request https://tools.ietf.org/html/rfc6749#section-4.1.3
func (s *Server) handleAuthCode(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("code")
redirectURI := r.PostFormValue("redirect_uri")
authCode, err := s.storage.GetAuthCode(code)
if err != nil || s.now().After(authCode.Expiry) || authCode.ClientID != client.ID {
if err != storage.ErrNotFound {
log.Printf("failed to get auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidRequest, "Invalid or expired code parameter.", http.StatusBadRequest)
}
return
}
if authCode.RedirectURI != redirectURI {
tokenErr(w, errInvalidRequest, "redirect_uri did not match URI from initial request.", http.StatusBadRequest)
return
}
idToken, expiry, err := s.newIDToken(client.ID, authCode.Claims, authCode.Scopes, authCode.Nonce)
if err != nil {
log.Printf("failed to create ID token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.DeleteAuthCode(code); err != nil {
log.Printf("failed to delete auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
reqRefresh := func() bool {
for _, scope := range authCode.Scopes {
if scope == scopeOfflineAccess {
return true
}
}
return false
}()
var refreshToken string
if reqRefresh {
refresh := storage.RefreshToken{
RefreshToken: storage.NewID(),
ClientID: authCode.ClientID,
ConnectorID: authCode.ConnectorID,
Scopes: authCode.Scopes,
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
Nonce: authCode.Nonce,
}
if err := s.storage.CreateRefresh(refresh); err != nil {
log.Printf("failed to create refresh token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
refreshToken = refresh.RefreshToken
}
s.writeAccessToken(w, idToken, refreshToken, expiry)
}
// handle a refresh token request https://tools.ietf.org/html/rfc6749#section-6
func (s *Server) handleRefreshToken(w http.ResponseWriter, r *http.Request, client storage.Client) {
code := r.PostFormValue("refresh_token")
scope := r.PostFormValue("scope")
if code == "" {
tokenErr(w, errInvalidRequest, "No refresh token in request.", http.StatusBadRequest)
return
}
refresh, err := s.storage.GetRefresh(code)
if err != nil || refresh.ClientID != client.ID {
if err != storage.ErrNotFound {
log.Printf("failed to get auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
} else {
tokenErr(w, errInvalidRequest, "Refresh token is invalid or has already been claimed by another client.", http.StatusBadRequest)
}
return
}
scopes := refresh.Scopes
if scope != "" {
requestedScopes := strings.Split(scope, " ")
contains := func() bool {
Loop:
for _, s := range requestedScopes {
for _, scope := range refresh.Scopes {
if s == scope {
continue Loop
}
}
return false
}
return true
}()
if !contains {
tokenErr(w, errInvalidRequest, "Requested scopes did not contain authorized scopes.", http.StatusBadRequest)
return
}
scopes = requestedScopes
}
// TODO(ericchiang): re-auth with backends
idToken, expiry, err := s.newIDToken(client.ID, refresh.Claims, scopes, refresh.Nonce)
if err != nil {
log.Printf("failed to create ID token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.DeleteRefresh(code); err != nil {
log.Printf("failed to delete auth code: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
if err := s.storage.CreateRefresh(refresh); err != nil {
log.Printf("failed to create refresh token: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
s.writeAccessToken(w, idToken, refresh.RefreshToken, expiry)
}
func (s *Server) writeAccessToken(w http.ResponseWriter, idToken, refreshToken string, expiry time.Time) {
// TODO(ericchiang): figure out an access token story and support the user info
// endpoint. For now use a random value so no one depends on the access_token
// holding a specific structure.
resp := struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ExpiresIn int `json:"expires_in"`
RefreshToken string `json:"refresh_token,omitempty"`
IDToken string `json:"id_token"`
}{
"bearer",
int(expiry.Sub(s.now())),
refreshToken,
idToken,
}
data, err := json.Marshal(resp)
if err != nil {
log.Printf("failed to marshal access token response: %v", err)
tokenErr(w, errServerError, "", http.StatusInternalServerError)
return
}
w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data)
}
func (s *Server) renderError(w http.ResponseWriter, status int, err, description string) {
http.Error(w, fmt.Sprintf("%s: %s", err, description), status)
}
func (s *Server) notFound(w http.ResponseWriter, r *http.Request) {
http.NotFound(w, r)
}