diff --git a/connector/connector.go b/connector/connector.go index 8235caae4120ed2d27ed40a3cc2aa7f539571ba1..9f84d3e66f93f6c6806336179a1b9e495794b0a6 100644 --- a/connector/connector.go +++ b/connector/connector.go @@ -33,7 +33,7 @@ type PasswordConnector interface { // CallbackConnector is an optional interface for callback based connectors. type CallbackConnector interface { LoginURL(callbackURL, state string) (string, error) - HandleCallback(r *http.Request) (identity Identity, state string, err error) + HandleCallback(r *http.Request) (identity Identity, err error) } // GroupsConnector is an optional interface for connectors which can map a user to groups. diff --git a/connector/github/github.go b/connector/github/github.go index 988c2c536ab73ba451186ad51f1af3da97080740..ab22217b4853d131984e843ea7f9fd110f6b9d53 100644 --- a/connector/github/github.go +++ b/connector/github/github.go @@ -83,28 +83,28 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("github: failed to get token: %v", err) + return identity, fmt.Errorf("github: failed to get token: %v", err) } resp, err := c.oauth2Config.Client(c.ctx, token).Get(baseURL + "/user") if err != nil { - return identity, "", fmt.Errorf("github: get URL %v", err) + return identity, fmt.Errorf("github: get URL %v", err) } defer resp.Body.Close() if resp.StatusCode != http.StatusOK { body, err := ioutil.ReadAll(resp.Body) if err != nil { - return identity, "", fmt.Errorf("github: read body: %v", err) + return identity, fmt.Errorf("github: read body: %v", err) } - return identity, "", fmt.Errorf("%s: %s", resp.Status, body) + return identity, fmt.Errorf("%s: %s", resp.Status, body) } var user struct { Name string `json:"name"` @@ -113,13 +113,13 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id Email string `json:"email"` } if err := json.NewDecoder(resp.Body).Decode(&user); err != nil { - return identity, "", fmt.Errorf("failed to decode response: %v", err) + return identity, fmt.Errorf("failed to decode response: %v", err) } data := connectorData{AccessToken: token.AccessToken} connData, err := json.Marshal(data) if err != nil { - return identity, "", fmt.Errorf("marshal connector data: %v", err) + return identity, fmt.Errorf("marshal connector data: %v", err) } username := user.Name @@ -133,7 +133,7 @@ func (c *githubConnector) HandleCallback(r *http.Request) (identity connector.Id EmailVerified: true, ConnectorData: connData, } - return identity, q.Get("state"), nil + return identity, nil } func (c *githubConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/mock/connectortest.go b/connector/mock/connectortest.go index 0d4b87baf333ad4ae2eb00f84bbcac3312caf26f..1ceeb36119a0d3a61293c0f4e2f4ad404353c981 100644 --- a/connector/mock/connectortest.go +++ b/connector/mock/connectortest.go @@ -41,14 +41,14 @@ func (m callbackConnector) LoginURL(callbackURL, state string) (string, error) { var connectorData = []byte("foobar") -func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, string, error) { +func (m callbackConnector) HandleCallback(r *http.Request) (connector.Identity, error) { return connector.Identity{ UserID: "0-385-28089-0", Username: "Kilgore Trout", Email: "kilgore@kilgore.trout", EmailVerified: true, ConnectorData: connectorData, - }, r.URL.Query().Get("state"), nil + }, nil } func (m callbackConnector) Groups(identity connector.Identity) ([]string, error) { diff --git a/connector/oidc/oidc.go b/connector/oidc/oidc.go index e2d3361b75a137d054745cc16bbdb6356ec0a38f..9b5aad20da587305d396d60b8a84f152c9a74a15 100644 --- a/connector/oidc/oidc.go +++ b/connector/oidc/oidc.go @@ -94,23 +94,23 @@ func (e *oauth2Error) Error() string { return e.error + ": " + e.errorDescription } -func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, state string, err error) { +func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Identity, err error) { q := r.URL.Query() if errType := q.Get("error"); errType != "" { - return identity, "", &oauth2Error{errType, q.Get("error_description")} + return identity, &oauth2Error{errType, q.Get("error_description")} } token, err := c.oauth2Config.Exchange(c.ctx, q.Get("code")) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to get token: %v", err) + return identity, fmt.Errorf("oidc: failed to get token: %v", err) } rawIDToken, ok := token.Extra("id_token").(string) if !ok { - return identity, "", errors.New("oidc: no id_token in token response") + return identity, errors.New("oidc: no id_token in token response") } idToken, err := c.verifier.Verify(rawIDToken) if err != nil { - return identity, "", fmt.Errorf("oidc: failed to verify ID Token: %v", err) + return identity, fmt.Errorf("oidc: failed to verify ID Token: %v", err) } var claims struct { @@ -119,7 +119,7 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden EmailVerified bool `json:"email_verified"` } if err := idToken.Claims(&claims); err != nil { - return identity, "", fmt.Errorf("oidc: failed to decode claims: %v", err) + return identity, fmt.Errorf("oidc: failed to decode claims: %v", err) } identity = connector.Identity{ @@ -128,5 +128,5 @@ func (c *oidcConnector) HandleCallback(r *http.Request) (identity connector.Iden Email: claims.Email, EmailVerified: claims.EmailVerified, } - return identity, q.Get("state"), nil + return identity, nil } diff --git a/server/handlers.go b/server/handlers.go index 00dbe82e68833b749899092d7aee2b7e7a6abf96..3bc49e7c97d7458ad34723c5650ed97cf2b0a018 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -2,7 +2,6 @@ package server import ( "encoding/json" - "errors" "fmt" "log" "net/http" @@ -149,11 +148,9 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - state := authReq.ID - if len(s.connectors) == 1 { for id := range s.connectors { - http.Redirect(w, r, s.absPath("/auth", id)+"?state="+state, http.StatusFound) + http.Redirect(w, r, s.absPath("/auth", id)+"?req="+authReq.ID, http.StatusFound) return } } @@ -169,7 +166,7 @@ func (s *Server) handleAuthorization(w http.ResponseWriter, r *http.Request) { i++ } - s.templates.login(w, connectorInfos, state) + s.templates.login(w, connectorInfos, authReq.ID) } func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { @@ -180,14 +177,29 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } + authReqID := r.FormValue("req") + // TODO(ericchiang): cache user identity. - state := r.FormValue("state") switch r.Method { case "GET": + // Set the connector being used for the login. + updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { + a.ConnectorID = connID + return a, nil + } + if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + log.Printf("Failed to set connector ID on auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + switch conn := conn.Connector.(type) { case connector.CallbackConnector: - callbackURL, err := conn.LoginURL(s.absURL("/callback", connID), state) + // Use the auth request ID as the "state" token. + // + // TODO(ericchiang): Is this appropriate or should we also be using a nonce? + callbackURL, err := conn.LoginURL(s.absURL("/callback"), authReqID) if err != nil { log.Printf("Connector %q returned error when creating callback: %v", connID, err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -195,7 +207,7 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } http.Redirect(w, r, callbackURL, http.StatusFound) case connector.PasswordConnector: - s.templates.password(w, state, r.URL.String(), "", false) + s.templates.password(w, authReqID, r.URL.String(), "", false) default: s.notFound(w, r) } @@ -216,10 +228,16 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { return } if !ok { - s.templates.password(w, state, r.URL.String(), username, true) + s.templates.password(w, authReqID, r.URL.String(), username, true) + return + } + authReq, err := s.storage.GetAuthRequest(authReqID) + if err != nil { + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -233,8 +251,31 @@ func (s *Server) handleConnectorLogin(w http.ResponseWriter, r *http.Request) { } func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) { - connID := mux.Vars(r)["connector"] - conn, ok := s.connectors[connID] + // SAML redirect bindings use the "RelayState" URL query field. When we support + // SAML, we'll have to check that field too and possibly let callback connectors + // indicate which field is used to determine the state. + // + // See: + // https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf + // Section: "3.4.3 RelayState" + state := r.URL.Query().Get("state") + if state == "" { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "no 'state' parameter provided") + return + } + + authReq, err := s.storage.GetAuthRequest(state) + if err != nil { + if err == storage.ErrNotFound { + s.renderError(w, http.StatusBadRequest, errInvalidRequest, "invalid 'state' parameter provided") + return + } + log.Printf("Failed to get auth request: %v", err) + s.renderError(w, http.StatusInternalServerError, errServerError, "") + return + } + + conn, ok := s.connectors[authReq.ConnectorID] if !ok { s.notFound(w, r) return @@ -245,14 +286,14 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) return } - identity, state, err := callbackConnector.HandleCallback(r) + identity, err := callbackConnector.HandleCallback(r) if err != nil { log.Printf("Failed to authenticate: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") return } - redirectURL, err := s.finalizeLogin(identity, state, connID, conn.Connector) + redirectURL, err := s.finalizeLogin(identity, authReq, conn.Connector) if err != nil { log.Printf("Failed to finalize login: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") @@ -262,10 +303,11 @@ func (s *Server) handleConnectorCallback(w http.ResponseWriter, r *http.Request) http.Redirect(w, r, redirectURL, http.StatusSeeOther) } -func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connectorID string, conn connector.Connector) (string, error) { - if authReqID == "" { - return "", errors.New("no auth request ID passed") +func (s *Server) finalizeLogin(identity connector.Identity, authReq storage.AuthRequest, conn connector.Connector) (string, error) { + if authReq.ConnectorID == "" { + } + claims := storage.Claims{ UserID: identity.UserID, Username: identity.Username, @@ -275,10 +317,6 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector groupsConn, ok := conn.(connector.GroupsConnector) if ok { - authReq, err := s.storage.GetAuthRequest(authReqID) - if err != nil { - return "", fmt.Errorf("get auth request: %v", err) - } reqGroups := func() bool { for _, scope := range authReq.Scopes { if scope == scopeGroups { @@ -288,27 +326,28 @@ func (s *Server) finalizeLogin(identity connector.Identity, authReqID, connector return false }() if reqGroups { - if claims.Groups, err = groupsConn.Groups(identity); err != nil { + groups, err := groupsConn.Groups(identity) + if err != nil { return "", fmt.Errorf("getting groups: %v", err) } + claims.Groups = groups } } updater := func(a storage.AuthRequest) (storage.AuthRequest, error) { a.LoggedIn = true a.Claims = claims - a.ConnectorID = connectorID a.ConnectorData = identity.ConnectorData return a, nil } - if err := s.storage.UpdateAuthRequest(authReqID, updater); err != nil { + if err := s.storage.UpdateAuthRequest(authReq.ID, updater); err != nil { return "", fmt.Errorf("failed to update auth request: %v", err) } - return path.Join(s.issuerURL.Path, "/approval") + "?state=" + authReqID, nil + return path.Join(s.issuerURL.Path, "/approval") + "?req=" + authReq.ID, nil } func (s *Server) handleApproval(w http.ResponseWriter, r *http.Request) { - authReq, err := s.storage.GetAuthRequest(r.FormValue("state")) + authReq, err := s.storage.GetAuthRequest(r.FormValue("req")) if err != nil { log.Printf("Failed to get auth request: %v", err) s.renderError(w, http.StatusInternalServerError, errServerError, "") diff --git a/server/server.go b/server/server.go index 628dfb4642f4f252159b017e493049ad1fdf8587..603a23cb701c06ccac820e22e4cc8814f511f48c 100644 --- a/server/server.go +++ b/server/server.go @@ -172,7 +172,7 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) handleFunc("/keys", s.handlePublicKeys) handleFunc("/auth", s.handleAuthorization) handleFunc("/auth/{connector}", s.handleConnectorLogin) - handleFunc("/callback/{connector}", s.handleConnectorCallback) + handleFunc("/callback", s.handleConnectorCallback) handleFunc("/approval", s.handleApproval) handleFunc("/healthz", s.handleHealth) s.mux = r diff --git a/server/templates.go b/server/templates.go index 117d12c55e315073b48a25b2f85c11993709b1e5..e8285fe31e26b579b076cc430289ec8fe1f2b10b 100644 --- a/server/templates.go +++ b/server/templates.go @@ -138,29 +138,29 @@ func (n byName) Len() int { return len(n) } func (n byName) Less(i, j int) bool { return n[i].Name < n[j].Name } func (n byName) Swap(i, j int) { n[i], n[j] = n[j], n[i] } -func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, state string) { +func (t *templates) login(w http.ResponseWriter, connectors []connectorInfo, authReqID string) { sort.Sort(byName(connectors)) data := struct { TemplateConfig Connectors []connectorInfo - State string - }{t.globalData, connectors, state} + AuthReqID string + }{t.globalData, connectors, authReqID} renderTemplate(w, t.loginTmpl, data) } -func (t *templates) password(w http.ResponseWriter, state, callback, lastUsername string, lastWasInvalid bool) { +func (t *templates) password(w http.ResponseWriter, authReqID, callback, lastUsername string, lastWasInvalid bool) { data := struct { TemplateConfig - State string - PostURL string - Username string - Invalid bool - }{t.globalData, state, callback, lastUsername, lastWasInvalid} + AuthReqID string + PostURL string + Username string + Invalid bool + }{t.globalData, authReqID, callback, lastUsername, lastWasInvalid} renderTemplate(w, t.passwordTmpl, data) } -func (t *templates) approval(w http.ResponseWriter, state, username, clientName string, scopes []string) { +func (t *templates) approval(w http.ResponseWriter, authReqID, username, clientName string, scopes []string) { accesses := []string{} for _, scope := range scopes { access, ok := scopeDescriptions[scope] @@ -171,11 +171,11 @@ func (t *templates) approval(w http.ResponseWriter, state, username, clientName sort.Strings(accesses) data := struct { TemplateConfig - User string - Client string - State string - Scopes []string - }{t.globalData, username, clientName, state, accesses} + User string + Client string + AuthReqID string + Scopes []string + }{t.globalData, username, clientName, authReqID, accesses} renderTemplate(w, t.approvalTmpl, data) } diff --git a/server/templates_default.go b/server/templates_default.go index 3a3d031a2bdec7ab779d66527ece476ba18a3791..651c7411b8c9eaa9a7f26fae82aa59533ad31427 100644 --- a/server/templates_default.go +++ b/server/templates_default.go @@ -25,7 +25,7 @@ var defaultTemplates = map[string]string{ <div> <div class="form-row"> <form method="post"> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> <input type="hidden" name="approval" value="approve"> <button type="submit" class="btn btn-success"> <span class="btn-text">Grant Access</span> @@ -34,7 +34,7 @@ var defaultTemplates = map[string]string{ </div> <div class="form-row"> <form method="post"> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> <input type="hidden" name="approval" value="rejected"> <button type="submit" class="btn btn-provider"> <span class="btn-text">Cancel</span> @@ -300,7 +300,7 @@ var defaultTemplates = map[string]string{ <div> {{ range $c := .Connectors }} <div class="form-row"> - <a href="{{ $c.URL }}?state={{ $.State }}" target="_self"> + <a href="{{ $c.URL }}?req={{ $.AuthReqID }}" target="_self"> <button class="btn btn-provider"> <span class="btn-icon btn-icon-{{ $c.ID }}"></span> <span class="btn-text">Log in with {{ $c.Name }}</span> @@ -344,7 +344,7 @@ var defaultTemplates = map[string]string{ </div> <input tabindex="2" required id="password" name="password" type="password" class="input-box" placeholder="password" {{ if .Invalid }} autofocus {{ end }}/> </div> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> {{ if .Invalid }} <div class="error-box"> diff --git a/web/templates/approval.html b/web/templates/approval.html index c73a522e0d6cfde30e9a4aa7fbde26e7065cf122..076049c0b2a57d94d8961ee813f8a21c76788d61 100644 --- a/web/templates/approval.html +++ b/web/templates/approval.html @@ -19,7 +19,7 @@ <div> <div class="form-row"> <form method="post"> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> <input type="hidden" name="approval" value="approve"> <button type="submit" class="btn btn-success"> <span class="btn-text">Grant Access</span> @@ -28,7 +28,7 @@ </div> <div class="form-row"> <form method="post"> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> <input type="hidden" name="approval" value="rejected"> <button type="submit" class="btn btn-provider"> <span class="btn-text">Cancel</span> diff --git a/web/templates/login.html b/web/templates/login.html index d43b5142773425057774ead2a8b09ecdbbc1ddff..ea43903a538e55b652babf21dc4f0ec722b6dd58 100644 --- a/web/templates/login.html +++ b/web/templates/login.html @@ -6,7 +6,7 @@ <div> {{ range $c := .Connectors }} <div class="form-row"> - <a href="{{ $c.URL }}?state={{ $.State }}" target="_self"> + <a href="{{ $c.URL }}?req={{ $.AuthReqID }}" target="_self"> <button class="btn btn-provider"> <span class="btn-icon btn-icon-{{ $c.ID }}"></span> <span class="btn-text">Log in with {{ $c.Name }}</span> diff --git a/web/templates/password.html b/web/templates/password.html index 89f833fcdcd2f5a840ab8584877061669324aeb8..7a9ffb14ca9e1e0d332a2eba1d1eda1213c627f3 100644 --- a/web/templates/password.html +++ b/web/templates/password.html @@ -15,7 +15,7 @@ </div> <input tabindex="2" required id="password" name="password" type="password" class="input-box" placeholder="password" {{ if .Invalid }} autofocus {{ end }}/> </div> - <input type="hidden" name="state" value="{{ .State }}"/> + <input type="hidden" name="req" value="{{ .AuthReqID }}"/> {{ if .Invalid }} <div class="error-box">