Skip to content
Snippets Groups Projects
Commit 415a68f9 authored by Simon HEGE's avatar Simon HEGE
Browse files

Allow CORS on keys and token endpoints

parent ca7d2b8f
No related branches found
No related tags found
No related merge requests found
...@@ -99,11 +99,11 @@ type OAuth2 struct { ...@@ -99,11 +99,11 @@ type OAuth2 struct {
// Web is the config format for the HTTP server. // Web is the config format for the HTTP server.
type Web struct { type Web struct {
HTTP string `json:"http"` HTTP string `json:"http"`
HTTPS string `json:"https"` HTTPS string `json:"https"`
TLSCert string `json:"tlsCert"` TLSCert string `json:"tlsCert"`
TLSKey string `json:"tlsKey"` TLSKey string `json:"tlsKey"`
DiscoveryAllowedOrigins []string `json:"discoveryAllowedOrigins"` AllowedOrigins []string `json:"allowedOrigins"`
} }
// GRPC is the config for the gRPC API. // GRPC is the config for the gRPC API.
......
...@@ -179,24 +179,24 @@ func serve(cmd *cobra.Command, args []string) error { ...@@ -179,24 +179,24 @@ func serve(cmd *cobra.Command, args []string) error {
if c.OAuth2.SkipApprovalScreen { if c.OAuth2.SkipApprovalScreen {
logger.Infof("config skipping approval screen") logger.Infof("config skipping approval screen")
} }
if len(c.Web.DiscoveryAllowedOrigins) > 0 { if len(c.Web.AllowedOrigins) > 0 {
logger.Infof("config discovery allowed origins: %s", c.Web.DiscoveryAllowedOrigins) logger.Infof("config allowed origins: %s", c.Web.AllowedOrigins)
} }
// explicitly convert to UTC. // explicitly convert to UTC.
now := func() time.Time { return time.Now().UTC() } now := func() time.Time { return time.Now().UTC() }
serverConfig := server.Config{ serverConfig := server.Config{
SupportedResponseTypes: c.OAuth2.ResponseTypes, SupportedResponseTypes: c.OAuth2.ResponseTypes,
SkipApprovalScreen: c.OAuth2.SkipApprovalScreen, SkipApprovalScreen: c.OAuth2.SkipApprovalScreen,
DiscoveryAllowedOrigins: c.Web.DiscoveryAllowedOrigins, AllowedOrigins: c.Web.AllowedOrigins,
Issuer: c.Issuer, Issuer: c.Issuer,
Connectors: connectors, Connectors: connectors,
Storage: s, Storage: s,
Web: c.Frontend, Web: c.Frontend,
EnablePasswordDB: c.EnablePasswordDB, EnablePasswordDB: c.EnablePasswordDB,
Logger: logger, Logger: logger,
Now: now, Now: now,
} }
if c.Expiry.SigningKeys != "" { if c.Expiry.SigningKeys != "" {
signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys) signingKeys, err := time.ParseDuration(c.Expiry.SigningKeys)
......
...@@ -12,7 +12,6 @@ import ( ...@@ -12,7 +12,6 @@ import (
"strings" "strings"
"time" "time"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
jose "gopkg.in/square/go-jose.v2" jose "gopkg.in/square/go-jose.v2"
...@@ -104,7 +103,7 @@ type discovery struct { ...@@ -104,7 +103,7 @@ type discovery struct {
Claims []string `json:"claims_supported"` Claims []string `json:"claims_supported"`
} }
func (s *Server) discoveryHandler() (http.Handler, error) { func (s *Server) discoveryHandler() (http.HandlerFunc, error) {
d := discovery{ d := discovery{
Issuer: s.issuerURL.String(), Issuer: s.issuerURL.String(),
Auth: s.absURL("/auth"), Auth: s.absURL("/auth"),
...@@ -130,18 +129,11 @@ func (s *Server) discoveryHandler() (http.Handler, error) { ...@@ -130,18 +129,11 @@ func (s *Server) discoveryHandler() (http.Handler, error) {
return nil, fmt.Errorf("failed to marshal discovery data: %v", err) return nil, fmt.Errorf("failed to marshal discovery data: %v", err)
} }
var discoveryHandler http.Handler return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
discoveryHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
w.Header().Set("Content-Length", strconv.Itoa(len(data))) w.Header().Set("Content-Length", strconv.Itoa(len(data)))
w.Write(data) w.Write(data)
}) }), nil
if len(s.discoveryAllowedOrigins) > 0 {
corsOption := handlers.AllowedOrigins(s.discoveryAllowedOrigins)
discoveryHandler = handlers.CORS(corsOption)(discoveryHandler)
}
return discoveryHandler, nil
} }
// handleAuthorization handles the OAuth2 auth endpoint. // handleAuthorization handles the OAuth2 auth endpoint.
......
...@@ -22,61 +22,3 @@ func TestHandleHealth(t *testing.T) { ...@@ -22,61 +22,3 @@ func TestHandleHealth(t *testing.T) {
} }
} }
var discoveryHandlerCORSTests = []struct {
DiscoveryAllowedOrigins []string
Origin string
ResponseAllowOrigin string //The expected response: same as Origin in case of valid CORS flow
}{
{nil, "http://foo.example", ""}, //Default behavior: cross origin requests not allowed
{[]string{}, "http://foo.example", ""},
{[]string{"http://foo.example"}, "http://foo.example", "http://foo.example"},
{[]string{"http://bar.example", "http://foo.example"}, "http://foo.example", "http://foo.example"},
{[]string{"*"}, "http://foo.example", "http://foo.example"},
{[]string{"http://bar.example"}, "http://foo.example", ""},
}
func TestDiscoveryHandlerCORS(t *testing.T) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for _, testcase := range discoveryHandlerCORSTests {
httpServer, server := newTestServer(ctx, t, func(c *Config) {
c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins
})
defer httpServer.Close()
discoveryHandler, err := server.discoveryHandler()
if err != nil {
t.Fatalf("failed to get discovery handler: %v", err)
}
//Perform preflight request
rrPreflight := httptest.NewRecorder()
reqPreflight := httptest.NewRequest("OPTIONS", "/.well-kown/openid-configuration", nil)
reqPreflight.Header.Set("Origin", testcase.Origin)
reqPreflight.Header.Set("Access-Control-Request-Method", "GET")
discoveryHandler.ServeHTTP(rrPreflight, reqPreflight)
if rrPreflight.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rrPreflight.Code)
}
headerAccessControlPreflight := rrPreflight.HeaderMap.Get("Access-Control-Allow-Origin")
if headerAccessControlPreflight != testcase.ResponseAllowOrigin {
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControlPreflight)
}
//Perform request
rr := httptest.NewRecorder()
req := httptest.NewRequest("GET", "/.well-kown/openid-configuration", nil)
req.Header.Set("Origin", testcase.Origin)
discoveryHandler.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Errorf("expected 200 got %d", rr.Code)
}
headerAccessControl := rr.HeaderMap.Get("Access-Control-Allow-Origin")
if headerAccessControl != testcase.ResponseAllowOrigin {
t.Errorf("expected '%s' got '%s'", testcase.ResponseAllowOrigin, headerAccessControl)
}
}
}
...@@ -13,6 +13,7 @@ import ( ...@@ -13,6 +13,7 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"github.com/Sirupsen/logrus" "github.com/Sirupsen/logrus"
"github.com/gorilla/handlers"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/coreos/dex/connector" "github.com/coreos/dex/connector"
...@@ -42,10 +43,10 @@ type Config struct { ...@@ -42,10 +43,10 @@ type Config struct {
// flow. If no response types are supplied this value defaults to "code". // flow. If no response types are supplied this value defaults to "code".
SupportedResponseTypes []string SupportedResponseTypes []string
// List of allowed origins for CORS requests on discovery endpoint. // List of allowed origins for CORS requests on discovery, token and keys endpoint.
// If none are indicated, CORS requests are disabled. Passing in "*" will allow any // If none are indicated, CORS requests are disabled. Passing in "*" will allow any
// domain. // domain.
DiscoveryAllowedOrigins []string AllowedOrigins []string
// If enabled, the server won't prompt the user to approve authorization requests. // If enabled, the server won't prompt the user to approve authorization requests.
// Logging in implies approval. // Logging in implies approval.
...@@ -116,8 +117,6 @@ type Server struct { ...@@ -116,8 +117,6 @@ type Server struct {
supportedResponseTypes map[string]bool supportedResponseTypes map[string]bool
discoveryAllowedOrigins []string
now func() time.Time now func() time.Time
idTokensValidFor time.Duration idTokensValidFor time.Duration
...@@ -185,16 +184,15 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) ...@@ -185,16 +184,15 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
} }
s := &Server{ s := &Server{
issuerURL: *issuerURL, issuerURL: *issuerURL,
connectors: make(map[string]Connector), connectors: make(map[string]Connector),
storage: newKeyCacher(c.Storage, now), storage: newKeyCacher(c.Storage, now),
supportedResponseTypes: supported, supportedResponseTypes: supported,
discoveryAllowedOrigins: c.DiscoveryAllowedOrigins, idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour),
idTokensValidFor: value(c.IDTokensValidFor, 24*time.Hour), skipApproval: c.SkipApprovalScreen,
skipApproval: c.SkipApprovalScreen, now: now,
now: now, templates: tmpls,
templates: tmpls, logger: c.Logger,
logger: c.Logger,
} }
for _, conn := range c.Connectors { for _, conn := range c.Connectors {
...@@ -205,24 +203,29 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy) ...@@ -205,24 +203,29 @@ func newServer(ctx context.Context, c Config, rotationStrategy rotationStrategy)
handleFunc := func(p string, h http.HandlerFunc) { handleFunc := func(p string, h http.HandlerFunc) {
r.HandleFunc(path.Join(issuerURL.Path, p), h) r.HandleFunc(path.Join(issuerURL.Path, p), h)
} }
handle := func(p string, h http.Handler) {
r.Handle(path.Join(issuerURL.Path, p), h)
}
handlePrefix := func(p string, h http.Handler) { handlePrefix := func(p string, h http.Handler) {
prefix := path.Join(issuerURL.Path, p) prefix := path.Join(issuerURL.Path, p)
r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h)) r.PathPrefix(prefix).Handler(http.StripPrefix(prefix, h))
} }
handleWithCORS := func(p string, h http.HandlerFunc) {
var handler http.Handler = h
if len(c.AllowedOrigins) > 0 {
corsOption := handlers.AllowedOrigins(c.AllowedOrigins)
handler = handlers.CORS(corsOption)(handler)
}
r.Handle(path.Join(issuerURL.Path, p), handler)
}
r.NotFoundHandler = http.HandlerFunc(http.NotFound) r.NotFoundHandler = http.HandlerFunc(http.NotFound)
discoveryHandler, err := s.discoveryHandler() discoveryHandler, err := s.discoveryHandler()
if err != nil { if err != nil {
return nil, err return nil, err
} }
handle("/.well-known/openid-configuration", discoveryHandler) handleWithCORS("/.well-known/openid-configuration", discoveryHandler)
// TODO(ericchiang): rate limit certain paths based on IP. // TODO(ericchiang): rate limit certain paths based on IP.
handleFunc("/token", s.handleToken) handleWithCORS("/token", s.handleToken)
handleFunc("/keys", s.handlePublicKeys) handleWithCORS("/keys", s.handlePublicKeys)
handleFunc("/auth", s.handleAuthorization) handleFunc("/auth", s.handleAuthorization)
handleFunc("/auth/{connector}", s.handleConnectorLogin) handleFunc("/auth/{connector}", s.handleConnectorLogin)
handleFunc("/callback", s.handleConnectorCallback) handleFunc("/callback", s.handleConnectorCallback)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment