diff --git a/cache/cache.go b/cache/cache.go index 4d30cf64c57b1a31cce923cb667d5b10b6b2a466..b29dd080176c920948817ac89d9a25a56c8b3941 100644 --- a/cache/cache.go +++ b/cache/cache.go @@ -1,9 +1,12 @@ package cache import ( + "fmt" "net/url" "path" + "path/filepath" "strconv" + "strings" "github.com/sirupsen/logrus" @@ -20,17 +23,33 @@ func getCacheConfig(build *common.Build) *common.CacheConfig { return build.Runner.Cache } -func generateObjectName(build *common.Build, config *common.CacheConfig, key string) string { - if key == "" { - return "" - } - +func generateBaseObjectName(build *common.Build, config *common.CacheConfig) string { runnerSegment := "" if !config.GetShared() { runnerSegment = path.Join("runner", build.Runner.ShortDescription()) } - return path.Join(config.GetPath(), runnerSegment, "project", strconv.Itoa(build.JobInfo.ProjectID), key) + return path.Join(config.GetPath(), runnerSegment, "project", strconv.Itoa(build.JobInfo.ProjectID)) +} + +func generateObjectName(build *common.Build, config *common.CacheConfig, key string) (string, error) { + if key == "" { + return "", nil + } + + basePath := generateBaseObjectName(build, config) + path := path.Join(basePath, key) + + relative, err := filepath.Rel(basePath, path) + if err != nil { + return "", fmt.Errorf("cache path correctness check failed with: %v", err) + } + + if strings.HasPrefix(relative, ".."+string(filepath.Separator)) { + return "", fmt.Errorf("computed cache path outside of project bucket. Please remove `../` from cache key") + } + + return path, nil } func onAdapter(build *common.Build, key string, handler func(adapter Adapter) *url.URL) *url.URL { @@ -40,7 +59,12 @@ func onAdapter(build *common.Build, key string, handler func(adapter Adapter) *u return nil } - objectName := generateObjectName(build, config, key) + objectName, err := generateObjectName(build, config, key) + if err != nil { + logrus.WithError(err).Error("Error while generating cache bucket.") + return nil + } + if objectName == "" { logrus.Warning("Empty cache key. Skipping adapter selection.") return nil diff --git a/cache/cache_test.go b/cache/cache_test.go index 48c8f35c1a0a83ea85c4602847398d9e82b8670d..7e03f0c65056f47e56450289d00888a46adae19f 100644 --- a/cache/cache_test.go +++ b/cache/cache_test.go @@ -173,63 +173,92 @@ func defaultBuild(cacheConfig *common.CacheConfig) *common.Build { } } -func TestGenerateObjectNameWhenKeyIsEmptyResultIsAlsoEmpty(t *testing.T) { - cache := defaultCacheConfig() - cacheBuild := defaultBuild(cache) - - url := generateObjectName(cacheBuild, cache, "") - assert.Empty(t, url) -} - -func TestGetCacheObjectName(t *testing.T) { - cache := defaultCacheConfig() - cacheBuild := defaultBuild(cache) - - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "runner/longtoke/project/10/key", url) -} +type generateObjectNameTestCase struct { + cache *common.CacheConfig + build *common.Build -func TestGetCacheObjectNameWhenPathIsSetThenUrlContainsIt(t *testing.T) { - cache := defaultCacheConfig() - cache.Path = "whatever" - cacheBuild := defaultBuild(cache) + key string + path string + shared bool - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "whatever/runner/longtoke/project/10/key", url) + expectedObjectName string + expectedError string } -func TestGetCacheObjectNameWhenPathHasMultipleSegmentIsSetThenUrlContainsIt(t *testing.T) { +func TestGenerateObjectName(t *testing.T) { cache := defaultCacheConfig() - cache.Path = "some/other/path/goes/here" cacheBuild := defaultBuild(cache) - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "some/other/path/goes/here/runner/longtoke/project/10/key", url) -} - -func TestGetCacheObjectNameWhenPathIsNotSetThenUrlDoesNotContainIt(t *testing.T) { - cache := defaultCacheConfig() - cache.Path = "" - cacheBuild := defaultBuild(cache) - - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "runner/longtoke/project/10/key", url) -} - -func TestGetCacheObjectNameWhenSharedFlagIsFalseThenRunnerSegmentExistsInTheUrl(t *testing.T) { - cache := defaultCacheConfig() - cache.Shared = false - cacheBuild := defaultBuild(cache) + tests := map[string]generateObjectNameTestCase{ + "default usage": { + cache: cache, + build: cacheBuild, + key: "key", + expectedObjectName: "runner/longtoke/project/10/key", + }, + "empty key": { + cache: cache, + build: cacheBuild, + key: "", + expectedObjectName: "", + }, + "short path is set": { + cache: cache, + build: cacheBuild, + key: "key", + path: "whatever", + expectedObjectName: "whatever/runner/longtoke/project/10/key", + }, + "multiple segment path is set": { + cache: cache, + build: cacheBuild, + key: "key", + path: "some/other/path/goes/here", + expectedObjectName: "some/other/path/goes/here/runner/longtoke/project/10/key", + }, + "path is empty": { + cache: cache, + build: cacheBuild, + key: "key", + path: "", + expectedObjectName: "runner/longtoke/project/10/key", + }, + "shared flag is set to true": { + cache: cache, + build: cacheBuild, + key: "key", + shared: true, + expectedObjectName: "project/10/key", + }, + "shared flag is set to false": { + cache: cache, + build: cacheBuild, + key: "key", + shared: false, + expectedObjectName: "runner/longtoke/project/10/key", + }, + "key escapes project namespace": { + cache: cache, + build: cacheBuild, + key: "../9/key", + expectedObjectName: "", + expectedError: "computed cache path outside of project bucket. Please remove `../` from cache key", + }, + } - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "runner/longtoke/project/10/key", url) -} + for name, test := range tests { + t.Run(name, func(t *testing.T) { + cache.Path = test.path + cache.Shared = test.shared -func TestGetCacheObjectNameWhenSharedFlagIsFalseThenRunnerSegmentShouldNotBePresent(t *testing.T) { - cache := defaultCacheConfig() - cache.Shared = true - cacheBuild := defaultBuild(cache) + objectName, err := generateObjectName(test.build, test.cache, test.key) - url := generateObjectName(cacheBuild, cache, "key") - assert.Equal(t, "project/10/key", url) + assert.Equal(t, test.expectedObjectName, objectName) + if len(test.expectedError) == 0 { + assert.NoError(t, err) + } else { + assert.EqualError(t, err, test.expectedError) + } + }) + } }