diff --git a/pkg/httpclient/httpclient.go b/pkg/httpclient/httpclient.go index 04837a7da0b70c5840c64d9f8b3926d582527f71..671e0e7754e71988b461e2ca8bf47b75bc49429e 100644 --- a/pkg/httpclient/httpclient.go +++ b/pkg/httpclient/httpclient.go @@ -3,6 +3,7 @@ package httpclient import ( "crypto/tls" "crypto/x509" + "encoding/base64" "fmt" "net" "net/http" @@ -10,6 +11,26 @@ import ( "time" ) +func extractCAs(input []string) [][]byte { + result := make([][]byte, 0, len(input)) + for _, ca := range input { + if ca == "" { + continue + } + + pemData, err := os.ReadFile(ca) + if err != nil { + pemData, err = base64.StdEncoding.DecodeString(ca) + if err != nil { + pemData = []byte(ca) + } + } + + result = append(result, pemData) + } + return result +} + func NewHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, error) { pool, err := x509.SystemCertPool() if err != nil { @@ -17,13 +38,11 @@ func NewHTTPClient(rootCAs []string, insecureSkipVerify bool) (*http.Client, err } tlsConfig := tls.Config{RootCAs: pool, InsecureSkipVerify: insecureSkipVerify} - for _, rootCA := range rootCAs { - rootCABytes, err := os.ReadFile(rootCA) - if err != nil { - return nil, fmt.Errorf("failed to read root-ca: %v", err) - } + for index, rootCABytes := range extractCAs(rootCAs) { if !tlsConfig.RootCAs.AppendCertsFromPEM(rootCABytes) { - return nil, fmt.Errorf("no certs found in root CA file %q", rootCA) + return nil, fmt.Errorf("rootCAs.%d is not in PEM format, certificate must be "+ + "a PEM encoded string, a base64 encoded bytes that contain PEM encoded string, "+ + "or a path to a PEM encoded certificate", index) } } diff --git a/pkg/httpclient/httpclient_test.go b/pkg/httpclient/httpclient_test.go index 07baea04eef566d771503d7bdfa437eabfd8343e..6f561c10303ab779179211ea4087c9636e4e7898 100644 --- a/pkg/httpclient/httpclient_test.go +++ b/pkg/httpclient/httpclient_test.go @@ -2,10 +2,12 @@ package httpclient_test import ( "crypto/tls" + "encoding/base64" "fmt" "io" "net/http" "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" @@ -20,18 +22,31 @@ func TestRootCAs(t *testing.T) { assert.Nil(t, err) defer ts.Close() - rootCAs := []string{"testdata/rootCA.pem"} - testClient, err := httpclient.NewHTTPClient(rootCAs, false) - assert.Nil(t, err) + runTest := func(name string, certs []string) { + t.Run(name, func(t *testing.T) { + rootCAs := certs + testClient, err := httpclient.NewHTTPClient(rootCAs, false) + assert.Nil(t, err) - res, err := testClient.Get(ts.URL) - assert.Nil(t, err) + res, err := testClient.Get(ts.URL) + assert.Nil(t, err) - greeting, err := io.ReadAll(res.Body) - res.Body.Close() - assert.Nil(t, err) + greeting, err := io.ReadAll(res.Body) + res.Body.Close() + assert.Nil(t, err) - assert.Equal(t, "Hello, client", string(greeting)) + assert.Equal(t, "Hello, client", string(greeting)) + }) + } + + runTest("From file", []string{"testdata/rootCA.pem"}) + + content, err := os.ReadFile("testdata/rootCA.pem") + assert.NoError(t, err) + runTest("From string", []string{string(content)}) + + contentStr := base64.StdEncoding.EncodeToString(content) + runTest("From bytes", []string{contentStr}) } func TestInsecureSkipVerify(t *testing.T) {