    // File: storage_test.go (web-viewer navigation and commit-attribution text removed).

    package kubernetes
    
    import (
    	"context"
    	"crypto/tls"
    	"errors"
    	"fmt"
    	"net/http"
    	"net/http/httptest"
    	"os"
    	"path/filepath"
    	"strings"
    	"testing"
    	"time"

    	"github.com/sirupsen/logrus"
    	"github.com/stretchr/testify/require"
    	"github.com/stretchr/testify/suite"

    	"github.com/dexidp/dex/storage"
    	"github.com/dexidp/dex/storage/conformance"
    )
    
    
    // kubeconfigPathVariableName is the name of the environment variable that
    // holds the path to the kubeconfig file used by the cluster-backed tests in
    // this file. Those tests skip themselves when the variable is empty.
    const kubeconfigPathVariableName = "DEX_KUBERNETES_CONFIG_PATH"
    
    // TestStorage is the go-test entry point for StorageTestSuite. The suite is
    // only executed when the kubeconfig environment variable is set; otherwise
    // the test is reported as skipped.
    func TestStorage(t *testing.T) {
    	if kubeconfig := os.Getenv(kubeconfigPathVariableName); len(kubeconfig) == 0 {
    		t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
    	}

    	suite.Run(t, &StorageTestSuite{})
    }
    
    // StorageTestSuite is a testify suite that runs the storage conformance
    // tests against the Kubernetes-backed client.
    type StorageTestSuite struct {
    	suite.Suite
    	// client is the storage client under test; it is assigned in SetupTest.
    	client *client
    }
    
    
    // expandDir normalizes a path taken from the environment.
    //
    // Surrounding double quotes are stripped, and a leading "~/" is replaced by
    // the current user's home directory. Any other input is returned unchanged
    // (apart from the quote trimming). An error is returned only when the home
    // directory is needed but cannot be determined.
    func expandDir(dir string) (string, error) {
    	// The value may have been exported with literal quotes around it.
    	dir = strings.Trim(dir, `"`)

    	if strings.HasPrefix(dir, "~/") {
    		homedir, err := os.UserHomeDir()
    		if err != nil {
    			return "", err
    		}

    		dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
    	}

    	return dir, nil
    }
    
    
    func (s *StorageTestSuite) SetupTest() {
    
    	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
    	s.Require().NoError(err)
    
    		KubeConfigFile: kubeconfigPath,
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    	logger := &logrus.Logger{
    		Out:       os.Stderr,
    		Formatter: &logrus.TextFormatter{DisableColors: true},
    		Level:     logrus.DebugLevel,
    	}
    
    	kubeClient, err := config.open(logger, true)
    
    	s.Require().NoError(err)
    
    
    	s.client = kubeClient
    
    }
    
    func (s *StorageTestSuite) TestStorage() {
    	newStorage := func() storage.Storage {
    		for _, resource := range []string{
    			resourceAuthCode,
    			resourceAuthRequest,
    
    			resourceDeviceRequest,
    			resourceDeviceToken,
    
    			resourceClient,
    			resourceRefreshToken,
    			resourceKeys,
    			resourcePassword,
    		} {
    			if err := s.client.deleteAll(resource); err != nil {
    				s.T().Fatalf("delete all %q failed: %v", resource, err)
    			}
    		}
    		return s.client
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    
    	conformance.RunTests(s.T(), newStorage)
    	conformance.RunTransactionTests(s.T(), newStorage)
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    func TestURLFor(t *testing.T) {
    	tests := []struct {
    		apiVersion, namespace, resource, name string
    
    		baseURL string
    		want    string
    	}{
    		{
    			"v1", "default", "pods", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/api/v1/namespaces/default/pods/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com/",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com/",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			// no namespace
    			"foo/v1", "", "bar", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/apis/foo/v1/bar/a",
    		},
    	}
    
    	for _, test := range tests {
    
    		c := &client{baseURL: test.baseURL}
    
    Eric Chiang's avatar
    Eric Chiang committed
    		got := c.urlFor(test.apiVersion, test.namespace, test.resource, test.name)
    		if got != test.want {
    			t.Errorf("(&client{baseURL:%q}).urlFor(%q, %q, %q, %q): expected %q got %q",
    				test.baseURL,
    				test.apiVersion, test.namespace, test.resource, test.name,
    				test.want, got,
    			)
    		}
    	}
    }
    
    
    // TestUpdateKeys drives client.UpdateKeys against a fake API server that
    // answers GET requests with getResponseCode and every other method with
    // actionResponseCode, and checks the resulting error.
    //
    // For each case: when wantErr is false the call must succeed; when wantErr
    // is true the call must fail, and if exactErr is set the error message must
    // match it exactly.
    func TestUpdateKeys(t *testing.T) {
    	fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil }

    	tests := []struct {
    		name               string
    		updater            func(old storage.Keys) (storage.Keys, error)
    		getResponseCode    int
    		actionResponseCode int
    		wantErr            bool
    		exactErr           error
    	}{
    		{
    			"Create OK test",
    			fakeUpdater,
    			404,
    			201,
    			false,
    			nil,
    		},
    		{
    			"Update should be OK",
    			fakeUpdater,
    			200,
    			200,
    			false,
    			nil,
    		},
    		{
    			"Create conflict should be OK",
    			fakeUpdater,
    			404,
    			409,
    			true,
    			errors.New("keys already created by another server instance"),
    		},
    		{
    			"Update conflict should be OK",
    			fakeUpdater,
    			200,
    			409,
    			true,
    			errors.New("keys already rotated by another server instance"),
    		},
    		{
    			"Client error is error",
    			fakeUpdater,
    			404,
    			500,
    			true,
    			nil,
    		},
    		{
    			"Client error during update is error",
    			fakeUpdater,
    			200,
    			500,
    			true,
    			nil,
    		},
    		{
    			"Get error is error",
    			fakeUpdater,
    			500,
    			200,
    			true,
    			nil,
    		},
    		{
    			"Updater error is error",
    			func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") },
    			200,
    			201,
    			true,
    			nil,
    		},
    	}

    	for _, test := range tests {
    		// The subtest body runs synchronously, so capturing the loop
    		// variable is safe on every Go version.
    		t.Run(test.name, func(t *testing.T) {
    			c := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)

    			err := c.UpdateKeys(test.updater)
    			if !test.wantErr {
    				if err != nil {
    					t.Fatalf("Test %q: %v", test.name, err)
    				}
    				return
    			}

    			// A case that expects an error must not pass silently when the
    			// call succeeds.
    			if err == nil {
    				t.Fatalf("Test %q: expected an error, got nil", test.name)
    			}
    			if test.exactErr != nil && test.exactErr.Error() != err.Error() {
    				t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr)
    			}
    		})
    	}
    }
    
    // newStatusCodesResponseTestClient returns a client wired to a throwaway TLS
    // test server. The server replies to GET requests with getResponseCode and
    // to every other method with actionResponseCode; the body is always an
    // empty JSON object. Certificate verification is disabled on the client
    // transport.
    //
    // NOTE(review): the test server is never closed.
    func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client {
    	respond := func(w http.ResponseWriter, r *http.Request) {
    		status := actionResponseCode
    		if r.Method == http.MethodGet {
    			status = getResponseCode
    		}
    		w.WriteHeader(status)
    		// Only the status code matters to callers, so a minimal JSON body is enough.
    		w.Write([]byte(`{}`))
    	}
    	server := httptest.NewTLSServer(http.HandlerFunc(respond))

    	httpClient := &http.Client{
    		Transport: &http.Transport{
    			TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    		},
    	}
    	testLogger := &logrus.Logger{
    		Out:       os.Stderr,
    		Formatter: &logrus.TextFormatter{DisableColors: true},
    		Level:     logrus.DebugLevel,
    	}

    	return &client{
    		client:  httpClient,
    		baseURL: server.URL,
    		logger:  testLogger,
    	}
    }
    
    
    // TestRetryOnConflict runs retryOnConflict with actions that return a 409
    // httpErr, a 500 httpErr, a plain error, and nil, and compares the returned
    // error message with the expected text. An empty expectation means the call
    // must succeed.
    func TestRetryOnConflict(t *testing.T) {
    	cases := []struct {
    		name     string
    		action   func() error
    		exactErr string
    	}{
    		{
    			name:     "Timeout reached",
    			action:   func() error { return &httpErr{status: 409} },
    			exactErr: "maximum timeout reached while retrying a conflicted request:   Conflict: response from server \"\"",
    		},
    		{
    			name:     "HTTP Error",
    			action:   func() error { return &httpErr{status: 500} },
    			exactErr: "  Internal Server Error: response from server \"\"",
    		},
    		{
    			name:     "Error",
    			action:   func() error { return errors.New("test") },
    			exactErr: "test",
    		},
    		{
    			name:     "OK",
    			action:   func() error { return nil },
    			exactErr: "",
    		},
    	}

    	for _, tc := range cases {
    		t.Run(tc.name, func(t *testing.T) {
    			err := retryOnConflict(context.TODO(), tc.action)
    			if tc.exactErr == "" {
    				require.NoError(t, err)
    				return
    			}
    			require.EqualError(t, err, tc.exactErr)
    		})
    	}
    }
    
    
    // TestRefreshTokenLock exercises refresh-token locking against a real
    // cluster by calling UpdateRefreshToken from inside another
    // UpdateRefreshToken updater for the same token.
    //
    // It is skipped unless the kubeconfig environment variable is set. The test
    // changes the package-level lockCheckPeriod and lockTimeout; both are
    // restored before returning so other tests in the package are unaffected.
    func TestRefreshTokenLock(t *testing.T) {
    	if os.Getenv(kubeconfigPathVariableName) == "" {
    		t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
    	}

    	kubeconfigPath, err := expandDir(os.Getenv(kubeconfigPathVariableName))
    	require.NoError(t, err)

    	config := Config{
    		KubeConfigFile: kubeconfigPath,
    	}

    	logger := &logrus.Logger{
    		Out:       os.Stderr,
    		Formatter: &logrus.TextFormatter{DisableColors: true},
    		Level:     logrus.DebugLevel,
    	}

    	kubeClient, err := config.open(logger, true)
    	require.NoError(t, err)

    	// Shorten the polling interval for the duration of this test only.
    	prevCheckPeriod := lockCheckPeriod
    	defer func() { lockCheckPeriod = prevCheckPeriod }()
    	lockCheckPeriod = time.Nanosecond

    	// Creating a storage with an existing refresh token and offline session for the user.
    	id := storage.NewID()
    	r := storage.RefreshToken{
    		ID:          id,
    		Token:       "bar",
    		Nonce:       "foo",
    		ClientID:    "client_id",
    		ConnectorID: "client_secret",
    		Scopes:      []string{"openid", "email", "profile"},
    		CreatedAt:   time.Now().UTC().Round(time.Millisecond),
    		LastUsed:    time.Now().UTC().Round(time.Millisecond),
    		Claims: storage.Claims{
    			UserID:        "1",
    			Username:      "jane",
    			Email:         "jane.doe@example.com",
    			EmailVerified: true,
    			Groups:        []string{"a", "b"},
    		},
    		ConnectorData: []byte(`{"some":"data"}`),
    	}

    	err = kubeClient.CreateRefresh(r)
    	require.NoError(t, err)

    	t.Run("Timeout lock error", func(t *testing.T) {
    		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
    			r.Token = "update-result-1"
    			// The nested update targets the same token while the outer
    			// update is still in progress and is expected to time out.
    			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
    				r.Token = "timeout-err"
    				return r, nil
    			})
    			require.Equal(t, fmt.Errorf("timeout waiting for refresh token %s lock", r.ID), err)
    			return r, nil
    		})
    		require.NoError(t, err)

    		token, err := kubeClient.GetRefresh(r.ID)
    		require.NoError(t, err)
    		require.Equal(t, "update-result-1", token.Token)
    	})

    	t.Run("Break the lock", func(t *testing.T) {
    		var lockBroken bool

    		// A negative timeout is set so the nested update below can break
    		// the lock; the previous value is restored when the subtest ends.
    		prevLockTimeout := lockTimeout
    		defer func() { lockTimeout = prevLockTimeout }()
    		lockTimeout = -time.Hour

    		err = kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
    			r.Token = "update-result-2"
    			// If the updater is invoked again after the lock was broken,
    			// skip the nested update.
    			if lockBroken {
    				return r, nil
    			}

    			err := kubeClient.UpdateRefreshToken(r.ID, func(r storage.RefreshToken) (storage.RefreshToken, error) {
    				r.Token = "should-break-the-lock-and-finish-updating"
    				return r, nil
    			})
    			require.NoError(t, err)

    			lockBroken = true
    			return r, nil
    		})
    		require.NoError(t, err)

    		token, err := kubeClient.GetRefresh(r.ID)
    		require.NoError(t, err)
    		// Because concurrent update breaks the lock, the final result will be the value of the first update
    		require.Equal(t, "update-result-2", token.Token)
    	})
    }