diff --git a/server/handlers_test.go b/server/handlers_test.go
index 3e0b1e81f19a1ab8a31861737b652f615daffa40..395b7e72c1e90c6e9d67045233e469942b9dc771 100644
--- a/server/handlers_test.go
+++ b/server/handlers_test.go
@@ -1,7 +1,9 @@
 package server
 
 import (
+	"bytes"
 	"context"
+	"encoding/json"
 	"errors"
 	"net/http"
 	"net/http/httptest"
@@ -48,3 +50,73 @@ func TestHandleHealthFailure(t *testing.T) {
 		t.Errorf("expected 500 got %d", rr.Code)
 	}
 }
+
+type emptyStorage struct {
+	storage.Storage
+}
+
+func (*emptyStorage) GetAuthRequest(string) (storage.AuthRequest, error) {
+	return storage.AuthRequest{}, storage.ErrNotFound
+}
+
+func TestHandleInvalidOAuth2Callbacks(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	httpServer, server := newTestServer(ctx, t, func(c *Config) {
+		c.Storage = &emptyStorage{c.Storage}
+	})
+	defer httpServer.Close()
+
+	tests := []struct {
+		TargetURI    string
+		ExpectedCode int
+	}{
+		{"/callback", http.StatusBadRequest},
+		{"/callback?code=&state=", http.StatusBadRequest},
+		{"/callback?code=AAAAAAA&state=BBBBBBB", http.StatusBadRequest},
+	}
+
+	rr := httptest.NewRecorder()
+
+	for i, r := range tests {
+		server.ServeHTTP(rr, httptest.NewRequest("GET", r.TargetURI, nil))
+		if rr.Code != r.ExpectedCode {
+			t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
+		}
+	}
+}
+
+func TestHandleInvalidSAMLCallbacks(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+	defer cancel()
+
+	httpServer, server := newTestServer(ctx, t, func(c *Config) {
+		c.Storage = &emptyStorage{c.Storage}
+	})
+	defer httpServer.Close()
+
+	type requestForm struct {
+		RelayState string
+	}
+	tests := []struct {
+		RequestForm  requestForm
+		ExpectedCode int
+	}{
+		{requestForm{}, http.StatusBadRequest},
+		{requestForm{RelayState: "AAAAAAA"}, http.StatusBadRequest},
+	}
+
+	rr := httptest.NewRecorder()
+
+	for i, r := range tests {
+		jsonValue, err := json.Marshal(r.RequestForm)
+		if err != nil {
+			t.Fatal(err.Error())
+		}
+		server.ServeHTTP(rr, httptest.NewRequest("POST", "/callback", bytes.NewBuffer(jsonValue)))
+		if rr.Code != r.ExpectedCode {
+			t.Fatalf("test %d expected %d, got %d", i, r.ExpectedCode, rr.Code)
+		}
+	}
+}