// handlers_test.go (2.81 KiB), extracted from a GitLab blame view;
// the commits shown in that view are by Eric Chiang.

    package server
    
    	"net/http"
    	"net/http/httptest"
    	"testing"
    
    )
    
    func TestHandleHealth(t *testing.T) {
    
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    
    Eric Chiang's avatar
    Eric Chiang committed
    	httpServer, server := newTestServer(ctx, t, nil)
    
    	defer httpServer.Close()
    
    	rr := httptest.NewRecorder()
    
    	server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil))
    
    	if rr.Code != http.StatusOK {
    		t.Errorf("expected 200 got %d", rr.Code)
    	}
    }
    
    
    // badStorage wraps a storage.Storage and makes every CreateAuthRequest
    // call fail. All other methods go straight to the embedded Storage, so
    // tests can simulate a partially broken backend.
    type badStorage struct {
    	storage.Storage
    }

    // CreateAuthRequest never delegates to the embedded Storage; it always
    // reports the backend as unavailable.
    func (*badStorage) CreateAuthRequest(storage.AuthRequest) error {
    	return errors.New("storage unavailable")
    }
    
    func TestHandleHealthFailure(t *testing.T) {
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()
    
    	httpServer, server := newTestServer(ctx, t, func(c *Config) {
    		c.Storage = &badStorage{c.Storage}
    	})
    	defer httpServer.Close()
    
    	rr := httptest.NewRecorder()
    	server.ServeHTTP(rr, httptest.NewRequest("GET", "/healthz", nil))
    	if rr.Code != http.StatusInternalServerError {
    		t.Errorf("expected 500 got %d", rr.Code)
    	}
    }
    
    
    // emptyStorage wraps a storage.Storage but behaves as if no auth request
    // exists. Every other method is served by the embedded Storage.
    type emptyStorage struct {
    	storage.Storage
    }

    // GetAuthRequest ignores the requested ID and always returns a zero-value
    // AuthRequest together with storage.ErrNotFound.
    func (*emptyStorage) GetAuthRequest(_ string) (storage.AuthRequest, error) {
    	var missing storage.AuthRequest
    	return missing, storage.ErrNotFound
    }
    
    // TestHandleInvalidOAuth2Callbacks sends GET /callback requests with missing,
    // empty, or unknown code/state query parameters and expects 400 Bad Request
    // for each. The server's storage is wrapped in emptyStorage, so any
    // GetAuthRequest lookup returns storage.ErrNotFound.
    func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()

    	httpServer, server := newTestServer(ctx, t, func(c *Config) {
    		c.Storage = &emptyStorage{c.Storage}
    	})
    	defer httpServer.Close()

    	tests := []struct {
    		TargetURI    string
    		ExpectedCode int
    	}{
    		{"/callback", http.StatusBadRequest},
    		{"/callback?code=&state=", http.StatusBadRequest},
    		{"/callback?code=AAAAAAA&state=BBBBBBB", http.StatusBadRequest},
    	}

    	for i, r := range tests {
    		// Use a fresh recorder for every request. ResponseRecorder keeps only
    		// the first status passed to WriteHeader, so a shared recorder would
    		// keep reporting case 0's code and hide failures in later cases.
    		rr := httptest.NewRecorder()
    		server.ServeHTTP(rr, httptest.NewRequest("GET", r.TargetURI, nil))
    		if rr.Code != r.ExpectedCode {
    			t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
    		}
    	}
    }
    
    // TestHandleInvalidSAMLCallbacks POSTs to /callback with an absent or unknown
    // RelayState and expects 400 Bad Request for each. The server's storage is
    // wrapped in emptyStorage, so any GetAuthRequest lookup returns
    // storage.ErrNotFound.
    func TestHandleInvalidSAMLCallbacks(t *testing.T) {
    	ctx, cancel := context.WithCancel(context.Background())
    	defer cancel()

    	httpServer, server := newTestServer(ctx, t, func(c *Config) {
    		c.Storage = &emptyStorage{c.Storage}
    	})
    	defer httpServer.Close()

    	type requestForm struct {
    		RelayState string
    	}
    	tests := []struct {
    		RequestForm  requestForm
    		ExpectedCode int
    	}{
    		{requestForm{}, http.StatusBadRequest},
    		{requestForm{RelayState: "AAAAAAA"}, http.StatusBadRequest},
    	}

    	for i, r := range tests {
    		jsonValue, err := json.Marshal(r.RequestForm)
    		if err != nil {
    			t.Fatal(err.Error())
    		}
    		// Use a fresh recorder for every request. ResponseRecorder keeps only
    		// the first status passed to WriteHeader, so a shared recorder would
    		// keep reporting case 0's code and hide failures in later cases.
    		rr := httptest.NewRecorder()
    		// NOTE(review): the body is JSON and no Content-Type is set. If the
    		// handler reads RelayState through form parsing (e.g. PostFormValue),
    		// it never sees RelayState here, and both cases exercise the same
    		// "missing RelayState" path. Confirm this, and consider sending
    		// url.Values{"RelayState": ...}.Encode() with
    		// Content-Type: application/x-www-form-urlencoded.
    		server.ServeHTTP(rr, httptest.NewRequest("POST", "/callback", bytes.NewBuffer(jsonValue)))
    		if rr.Code != r.ExpectedCode {
    			t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
    		}
    	}
    }