Skip to content
Snippets Groups Projects
handlers_test.go 2.73 KiB
Newer Older
  • Learn to ignore specific revisions
  • Eric Chiang's avatar
    Eric Chiang committed
    package server
    
    
    import (
    	"net/http"
    	"net/http/httptest"
    	"testing"
    
    
    	"golang.org/x/net/context"
    
    )
    
    func TestHandleHealth(t *testing.T) {
    
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    	httpServer, server := newTestServer(ctx, t, nil)
    
    	defer httpServer.Close()
    
    	rr := httptest.NewRecorder()
    	server.handleHealth(rr, httptest.NewRequest("GET", "/healthz", nil))
    	if rr.Code != http.StatusOK {
    		t.Errorf("expected 200 got %d", rr.Code)
    	}
    
    }
    
    
    // discoveryHandlerCORSTests is the case table for TestDiscoveryHandlerCORS.
    // Each case gives the server's configured allowed origins, the Origin header
    // sent by the client, and the Access-Control-Allow-Origin value the response
    // must carry ("" means the header is expected to be absent).
    var discoveryHandlerCORSTests = []struct {
    	DiscoveryAllowedOrigins []string
    	Origin                  string
    	// ResponseAllowOrigin equals Origin when the cross-origin request should
    	// be allowed, and is empty otherwise.
    	ResponseAllowOrigin     string
    }{
    	{
    		// Default behavior: with no configured origins, cross-origin
    		// requests are not allowed.
    		DiscoveryAllowedOrigins: nil,
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "",
    	},
    	{
    		// An explicitly empty (non-nil) list behaves like the default.
    		DiscoveryAllowedOrigins: []string{},
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "",
    	},
    	{
    		// Exact single-origin match.
    		DiscoveryAllowedOrigins: []string{"http://foo.example"},
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "http://foo.example",
    	},
    	{
    		// Match found later in a multi-origin list.
    		DiscoveryAllowedOrigins: []string{"http://bar.example", "http://foo.example"},
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "http://foo.example",
    	},
    	{
    		// Wildcard: the request's own origin is echoed back.
    		DiscoveryAllowedOrigins: []string{"*"},
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "http://foo.example",
    	},
    	{
    		// Origin not in the list: no allow header.
    		DiscoveryAllowedOrigins: []string{"http://bar.example"},
    		Origin:                  "http://foo.example",
    		ResponseAllowOrigin:     "",
    	},
    }
    
    // TestDiscoveryHandlerCORS runs every case in discoveryHandlerCORSTests
    // against the server's discovery handler. For each case it checks the
    // Access-Control-Allow-Origin response header on both a CORS preflight
    // (OPTIONS with Access-Control-Request-Method) and the actual GET request.
    func TestDiscoveryHandlerCORS(t *testing.T) {
    	const discoveryPath = "/.well-known/openid-configuration"

    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()

    	for _, testcase := range discoveryHandlerCORSTests {
    		// Each case is its own subtest. The deferred httpServer.Close then runs
    		// when the case finishes, instead of every server staying open until
    		// the outer test returns. A Fatalf also aborts only the failing case.
    		// An empty name makes the testing package number cases #00, #01, ...
    		t.Run("", func(t *testing.T) {
    			// Shown with the failure output so the failing case is identifiable.
    			t.Logf("allowed origins %q, request origin %q",
    				testcase.DiscoveryAllowedOrigins, testcase.Origin)

    			httpServer, server := newTestServer(ctx, t, func(c *Config) {
    				c.DiscoveryAllowedOrigins = testcase.DiscoveryAllowedOrigins
    			})
    			defer httpServer.Close()

    			discoveryHandler, err := server.discoveryHandler()
    			if err != nil {
    				t.Fatalf("failed to get discovery handler: %v", err)
    			}

    			// Preflight request.
    			rrPreflight := httptest.NewRecorder()
    			reqPreflight := httptest.NewRequest(http.MethodOptions, discoveryPath, nil)
    			reqPreflight.Header.Set("Origin", testcase.Origin)
    			reqPreflight.Header.Set("Access-Control-Request-Method", "GET")
    			discoveryHandler.ServeHTTP(rrPreflight, reqPreflight)
    			if rrPreflight.Code != http.StatusOK {
    				t.Errorf("preflight: expected 200 got %d", rrPreflight.Code)
    			}
    			// Result().Header is the supported way to read the handler's
    			// response headers; ResponseRecorder.HeaderMap is deprecated.
    			headerAccessControlPreflight := rrPreflight.Result().Header.Get("Access-Control-Allow-Origin")
    			if headerAccessControlPreflight != testcase.ResponseAllowOrigin {
    				t.Errorf("preflight: expected %q got %q", testcase.ResponseAllowOrigin, headerAccessControlPreflight)
    			}

    			// Actual request.
    			rr := httptest.NewRecorder()
    			req := httptest.NewRequest(http.MethodGet, discoveryPath, nil)
    			req.Header.Set("Origin", testcase.Origin)
    			discoveryHandler.ServeHTTP(rr, req)
    			if rr.Code != http.StatusOK {
    				t.Errorf("request: expected 200 got %d", rr.Code)
    			}
    			headerAccessControl := rr.Result().Header.Get("Access-Control-Allow-Origin")
    			if headerAccessControl != testcase.ResponseAllowOrigin {
    				t.Errorf("request: expected %q got %q", testcase.ResponseAllowOrigin, headerAccessControl)
    			}
    		})
    	}
    }