Skip to content
Snippets Groups Projects
storage_test.go 6.24 KiB
Newer Older
  • Learn to ignore specific revisions
  • Eric Chiang's avatar
    Eric Chiang committed
    package kubernetes
    
    import (
    	"context"
    	"crypto/tls"
    	"errors"
    	"fmt"
    	"net/http"
    	"net/http/httptest"
    	"os"
    	"path/filepath"
    	"strings"
    	"testing"

    	"github.com/sirupsen/logrus"
    	"github.com/stretchr/testify/require"
    	"github.com/stretchr/testify/suite"

    	"github.com/dexidp/dex/storage"
    	"github.com/dexidp/dex/storage/conformance"
    )
    
    
    // kubeconfigPathVariableName is the environment variable holding the path to
    // the kubeconfig used by the Kubernetes storage tests; when it is empty the
    // suite is skipped.
    const (
    	kubeconfigPathVariableName = "DEX_KUBERNETES_CONFIG_PATH"
    )
    
    // TestStorage runs StorageTestSuite against the cluster described by the
    // kubeconfig named in kubeconfigPathVariableName. Without that variable the
    // test is skipped rather than failed.
    func TestStorage(t *testing.T) {
    	kubeconfig := os.Getenv(kubeconfigPathVariableName)
    	if len(kubeconfig) == 0 {
    		t.Skipf("variable %q not set, skipping kubernetes storage tests\n", kubeconfigPathVariableName)
    	}

    	suite.Run(t, new(StorageTestSuite))
    }
    
    type StorageTestSuite struct {
    	suite.Suite
    	client *client
    }
    
    
    func (s *StorageTestSuite) expandDir(dir string) string {
    
    	dir = strings.Trim(dir, `"`)
    
    	if strings.HasPrefix(dir, "~/") {
    		homedir, err := os.UserHomeDir()
    		s.Require().NoError(err)
    
    		dir = filepath.Join(homedir, strings.TrimPrefix(dir, "~/"))
    	}
    	return dir
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    
    func (s *StorageTestSuite) SetupTest() {
    
    	kubeconfigPath := s.expandDir(os.Getenv(kubeconfigPathVariableName))
    
    		KubeConfigFile: kubeconfigPath,
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    	logger := &logrus.Logger{
    		Out:       os.Stderr,
    		Formatter: &logrus.TextFormatter{DisableColors: true},
    		Level:     logrus.DebugLevel,
    	}
    
    	kubeClient, err := config.open(logger, true)
    
    	s.Require().NoError(err)
    
    
    	s.client = kubeClient
    
    }
    
    func (s *StorageTestSuite) TestStorage() {
    	newStorage := func() storage.Storage {
    		for _, resource := range []string{
    			resourceAuthCode,
    			resourceAuthRequest,
    
    			resourceDeviceRequest,
    			resourceDeviceToken,
    
    			resourceClient,
    			resourceRefreshToken,
    			resourceKeys,
    			resourcePassword,
    		} {
    			if err := s.client.deleteAll(resource); err != nil {
    				s.T().Fatalf("delete all %q failed: %v", resource, err)
    			}
    		}
    		return s.client
    
    Eric Chiang's avatar
    Eric Chiang committed
    	}
    
    
    	conformance.RunTests(s.T(), newStorage)
    	conformance.RunTransactionTests(s.T(), newStorage)
    
    Eric Chiang's avatar
    Eric Chiang committed
    }
    
    func TestURLFor(t *testing.T) {
    	tests := []struct {
    		apiVersion, namespace, resource, name string
    
    		baseURL string
    		want    string
    	}{
    		{
    			"v1", "default", "pods", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/api/v1/namespaces/default/pods/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com/",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			"foo/v1", "default", "bar", "a",
    			"https://k8s.example.com/",
    			"https://k8s.example.com/apis/foo/v1/namespaces/default/bar/a",
    		},
    		{
    			// no namespace
    			"foo/v1", "", "bar", "a",
    			"https://k8s.example.com",
    			"https://k8s.example.com/apis/foo/v1/bar/a",
    		},
    	}
    
    	for _, test := range tests {
    
    		c := &client{baseURL: test.baseURL}
    
    Eric Chiang's avatar
    Eric Chiang committed
    		got := c.urlFor(test.apiVersion, test.namespace, test.resource, test.name)
    		if got != test.want {
    			t.Errorf("(&client{baseURL:%q}).urlFor(%q, %q, %q, %q): expected %q got %q",
    				test.baseURL,
    				test.apiVersion, test.namespace, test.resource, test.name,
    				test.want, got,
    			)
    		}
    	}
    }
    
    
    // TestUpdateKeys drives client.UpdateKeys against a fake API server that
    // answers GET requests with getResponseCode and every other method with
    // actionResponseCode. Each case asserts that an error is returned exactly
    // when wantErr is set. When exactErr is non-nil, it also asserts the error
    // message.
    func TestUpdateKeys(t *testing.T) {
    	fakeUpdater := func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, nil }

    	tests := []struct {
    		name               string
    		updater            func(old storage.Keys) (storage.Keys, error)
    		getResponseCode    int
    		actionResponseCode int
    		wantErr            bool
    		exactErr           error
    	}{
    		{
    			"Create OK test",
    			fakeUpdater,
    			404,
    			201,
    			false,
    			nil,
    		},
    		{
    			"Update should be OK",
    			fakeUpdater,
    			200,
    			200,
    			false,
    			nil,
    		},
    		{
    			"Create conflict should be OK",
    			fakeUpdater,
    			404,
    			409,
    			true,
    			errors.New("keys already created by another server instance"),
    		},
    		{
    			"Update conflict should be OK",
    			fakeUpdater,
    			200,
    			409,
    			true,
    			errors.New("keys already rotated by another server instance"),
    		},
    		{
    			"Client error is error",
    			fakeUpdater,
    			404,
    			500,
    			true,
    			nil,
    		},
    		{
    			"Client error during update is error",
    			fakeUpdater,
    			200,
    			500,
    			true,
    			nil,
    		},
    		{
    			"Get error is error",
    			fakeUpdater,
    			500,
    			200,
    			true,
    			nil,
    		},
    		{
    			"Updater error is error",
    			func(old storage.Keys) (storage.Keys, error) { return storage.Keys{}, fmt.Errorf("test") },
    			200,
    			201,
    			true,
    			nil,
    		},
    	}

    	for _, test := range tests {
    		t.Run(test.name, func(t *testing.T) {
    			c := newStatusCodesResponseTestClient(test.getResponseCode, test.actionResponseCode)

    			err := c.UpdateKeys(test.updater)
    			switch {
    			case err == nil && test.wantErr:
    				// Previously a missing error went unnoticed.
    				t.Fatalf("Test %q: expected an error, got nil", test.name)
    			case err != nil && !test.wantErr:
    				t.Fatalf("Test %q: %v", test.name, err)
    			case err != nil && test.exactErr != nil && test.exactErr.Error() != err.Error():
    				t.Fatalf("Test %q: %v, wanted: %v", test.name, err, test.exactErr)
    			}
    		})
    	}
    }
    
    // newStatusCodesResponseTestClient returns a client wired to a local TLS test
    // server. The server replies to GET with getResponseCode and to any other
    // method with actionResponseCode, always with an empty JSON object as the
    // body. Only status-code handling is exercised. The client skips certificate
    // verification because the server uses a self-signed certificate.
    //
    // NOTE: the test server is never closed; it lives until the test binary exits.
    func newStatusCodesResponseTestClient(getResponseCode, actionResponseCode int) *client {
    	handler := func(w http.ResponseWriter, r *http.Request) {
    		status := actionResponseCode
    		if r.Method == http.MethodGet {
    			status = getResponseCode
    		}
    		w.WriteHeader(status)
    		w.Write([]byte(`{}`))
    	}
    	server := httptest.NewTLSServer(http.HandlerFunc(handler))

    	logger := &logrus.Logger{
    		Out:       os.Stderr,
    		Formatter: &logrus.TextFormatter{DisableColors: true},
    		Level:     logrus.DebugLevel,
    	}
    	transport := &http.Transport{
    		TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
    	}

    	return &client{
    		client:  &http.Client{Transport: transport},
    		baseURL: server.URL,
    		logger:  logger,
    	}
    }
    
    
    // TestRetryOnConflict checks that retryOnConflict gives up with a timeout
    // error when the action keeps returning a 409 conflict. It also checks that
    // other errors are returned unchanged and that success yields a nil error.
    func TestRetryOnConflict(t *testing.T) {
    	testCases := []struct {
    		name     string
    		action   func() error
    		exactErr string
    	}{
    		{
    			name:     "Timeout reached",
    			action:   func() error { return &httpErr{status: 409} },
    			exactErr: "maximum timeout reached while retrying a conflicted request:   Conflict: response from server \"\"",
    		},
    		{
    			name:     "HTTP Error",
    			action:   func() error { return &httpErr{status: 500} },
    			exactErr: "  Internal Server Error: response from server \"\"",
    		},
    		{
    			name:     "Error",
    			action:   func() error { return errors.New("test") },
    			exactErr: "test",
    		},
    		{
    			name:     "OK",
    			action:   func() error { return nil },
    			exactErr: "",
    		},
    	}

    	for _, tc := range testCases {
    		t.Run(tc.name, func(t *testing.T) {
    			err := retryOnConflict(context.TODO(), tc.action)
    			if tc.exactErr == "" {
    				require.NoError(t, err)
    				return
    			}
    			require.EqualError(t, err, tc.exactErr)
    		})
    	}
    }