diff --git a/Makefile b/Makefile index 5e0aa31..0a04bcc 100644 --- a/Makefile +++ b/Makefile @@ -25,12 +25,12 @@ setup: test: $(MAKE) -C agent test - $(MAKE) -C sdks/python test @echo "TODO: sdk go test" $(MAKE) -C sdks/go test $(MAKE) -C scaffold test $(MAKE) -C agent relay-test $(MAKE) -C examples test + $(MAKE) -C sdks/python test publish: $(MAKE) -C sdks publish-go-sdk diff --git a/agent/common/integration.go b/agent/common/integration.go index cf1ed8e..0a4ab38 100644 --- a/agent/common/integration.go +++ b/agent/common/integration.go @@ -1,18 +1,15 @@ package common import ( - "crypto/sha256" _ "embed" - "encoding/hex" "encoding/json" "fmt" - "net/url" "os" "path" - "regexp" - "sort" "strings" - "time" + + "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" ) type Integration string @@ -60,7 +57,7 @@ func ParseIntegration(s string) (Integration, error) { } func ValidIntegrations() []Integration { - return []Integration{IntegrationGithub, IntegrationJira, IntegrationGitlab, IntegrationBitbucket, IntegrationSonarqube, IntegrationPrometheus} + return []Integration{IntegrationCustom, IntegrationGithub, IntegrationJira, IntegrationGitlab, IntegrationBitbucket, IntegrationSonarqube, IntegrationPrometheus} } type IntegrationInfo struct { @@ -70,284 +67,27 @@ type IntegrationInfo struct { AcceptFilePath string } -func (ii IntegrationInfo) AcceptFile(localhostBase string) (string, error) { - - // load/locate the accept file then fix it up to have - // an entry to talk to the axon HTTP server itself - // which we use for status, etc - alias := ii.Alias - if len(alias) == 0 { - alias = "default" - } - - content, err := ii.getAcceptFileContents() - if err != nil { - return "", err +func (ii IntegrationInfo) Validate() error { + if err := ii.Integration.Validate(); err != nil { + return err } - - rawFile := []byte(content) - h := sha256.New() - h.Write(rawFile) - expectedPath := path.Join( - os.TempDir(), - "accept-files", - ii.Integration.String(), - alias, - hex.EncodeToString(h.Sum(nil)), - "accept.json", - ) - - dict := map[string]interface{}{} - err = json.Unmarshal(rawFile, &dict) - if err != nil { - return "", err - } - - // we add a section like this to allow the server side - // to hit the axon agent via an HTTP bridge - // "private": [ - // { - // "method": "any", - // "path": "/__axon/*", - // "origin": "http://localhost" - // } - // ] - - entries, ok := dict["private"].([]interface{}) - if !ok { - entries = []interface{}{} - dict["private"] = entries - } - - entry := map[string]string{ - "method": "any", - "path": "/__axon/*", - "origin": localhostBase, - } - dict["private"] = append([]interface{}{entry}, entries...) 
- - if _, ok := dict["public"]; !ok { - dict["public"] = []interface{}{} - } - - json, err := json.Marshal(dict) - if err != nil { - return "", err - } - err = os.MkdirAll(path.Dir(expectedPath), os.ModeDir|os.ModePerm) - if err != nil { - return "", err - } - err = os.WriteFile(expectedPath, json, os.ModePerm) - - return expectedPath, err -} - -type ValueResolver func() string - -func StringValueResolver(value string) ValueResolver { - return func() string { - return value + if _, err := ii.ValidateSubtype(); err != nil { + return err } + return nil } -type ResolverMap map[string]ValueResolver - -func NewResolverMapFromMap(m map[string]string) ResolverMap { - rm := make(ResolverMap, len(m)) - for key, value := range m { - rm[key] = StringValueResolver(value) - } - return rm -} - -func (rm ResolverMap) ToStringMap() map[string]string { - resolved := make(map[string]string, len(rm)) - for key, resolver := range rm { - resolved[key] = resolver() - } - return resolved -} - -func (rm ResolverMap) Resolve(key string) string { - resolver, ok := rm[key] - if !ok { - return "" - } - return resolver() -} - -func EnvValueResolver(envVar string, defaultValue string, capture bool) ValueResolver { - - captured := os.ExpandEnv(envVar) - return func() string { - if capture { - return captured - } - if val := os.ExpandEnv(envVar); val != "" { - return val - } - - return defaultValue - } -} - -func (ii IntegrationInfo) RewriteOrigins(acceptFilePath string, writer func(string, ResolverMap) string) (*AcceptFileInfo, error) { - - info, err := newAcceptFileInfo(acceptFilePath, writer) - if err != nil { - return nil, err - } - _, err = info.Rewrite() - if err != nil { - return nil, err - } - return info, nil -} - -type AcceptFileInfo struct { - OriginalPath string - RewrittenPath string - Content string - rawContent []byte - Routes []AcceptFileRoute - originRewriter func(uri string, headers ResolverMap) string -} - -var IgnoreHosts = []string{ - "localhost", - "127.0.0.1", -} +func (ii IntegrationInfo) ToAcceptFile(cfg config.AgentConfig) (*acceptfile.AcceptFile, error) { -func newAcceptFileInfo(acceptFilePath string, originRewriter func(string, ResolverMap) string) (*AcceptFileInfo, error) { - stat, err := os.Stat(acceptFilePath) - if err != nil { + if err := ii.Validate(); err != nil { return nil, err } - if stat.IsDir() { - return nil, fmt.Errorf("accept file path %q is a directory, expected a file", acceptFilePath) - } - - info := &AcceptFileInfo{ - OriginalPath: acceptFilePath, - originRewriter: originRewriter, - } - - info.rawContent, err = os.ReadFile(acceptFilePath) + content, err := ii.getAcceptFileContents() if err != nil { return nil, err } - - return info, nil -} - -func (afi *AcceptFileInfo) isIgnoredHost(host string) bool { - for _, ignoreHost := range IgnoreHosts { - if strings.HasPrefix(host, ignoreHost) { - return true - } - } - return false -} - -func (afi *AcceptFileInfo) Rewrite() (string, error) { - - if afi.Content == "" { - dict := map[string]interface{}{} - err := json.Unmarshal(afi.rawContent, &dict) - if err != nil { - return "", err - } - - entries, ok := dict["private"].([]interface{}) - if !ok { - return "", nil - } - - for _, entry := range entries { - values := entry.(map[string]any) - rawOrigin, ok := values["origin"].(string) - if !ok { - continue - } - - origin := os.ExpandEnv(rawOrigin) - - parsed, err := url.Parse(origin) - if err != nil { - return "", fmt.Errorf("failed to parse origin %q: %w", origin, err) - } - - if afi.isIgnoredHost(parsed.Host) { - continue - } - - if 
parsed.Scheme == "" { - parsed.Scheme = "https" - origin = parsed.String() - } - - // Extract headers if present - var headers ResolverMap - if headersInterface, hasHeaders := values["headers"]; hasHeaders { - if headersMap, ok := headersInterface.(map[string]interface{}); ok { - headers = make(ResolverMap, len(headersMap)) - for k, v := range headersMap { - if strVal, ok := v.(string); ok { - // Resolve environment variables in header values - headers[k] = EnvValueResolver(strVal, "", true) - } - } - } - } - - // rewrite the origin to use the writer function - newOrigin := afi.originRewriter(origin, headers) - if newOrigin != "" { - values["origin"] = newOrigin - } - afi.Routes = append(afi.Routes, AcceptFileRoute{ - ResolvedOrigin: origin, - ProxyOrigin: newOrigin, - Headers: headers, - }) - } - - json, err := json.Marshal(dict) - if err != nil { - return "", err - } - afi.Content = string(json) - - stat, err := os.Stat(afi.OriginalPath) - - if err != nil { - return "", err - } - newFilePath := path.Join( - os.TempDir(), - "accept-files-written", - fmt.Sprintf("rewrite.%v.%v", time.Now().UnixMilli(), stat.Name()), - ) - err = os.MkdirAll(path.Dir(newFilePath), os.ModeDir|os.ModePerm) - if err != nil { - return "", err - } - - err = os.WriteFile(newFilePath, json, os.ModePerm) - if err != nil { - return "", err - } - afi.RewrittenPath = newFilePath - } - return afi.Content, nil -} - -type AcceptFileRoute struct { - ResolvedOrigin string - ProxyOrigin string - Headers ResolverMap + return acceptfile.NewAcceptFile([]byte(content), cfg) } func (ii IntegrationInfo) getAcceptFileContents() (string, error) { @@ -478,31 +218,5 @@ func (ii IntegrationInfo) getIntegrationAcceptFile() (string, error) { if err != nil { return "", err } - if err := ii.ensureAcceptFileVars(strContent); err != nil { - return "", err - } return strContent, nil } - -var reContentVars = regexp.MustCompile(`\$\{(.*?)\}`) - -func (ii IntegrationInfo) ensureAcceptFileVars(content string) error { - varMatch := reContentVars.FindAllStringSubmatch(content, -1) - - envVars := []string{} - - // sort these so they have a stable order - - for _, match := range varMatch { - envVars = append(envVars, match[1]) - } - - sort.Strings(envVars) - - for _, envVar := range envVars { - if os.Getenv(envVar) == "" && os.Getenv(envVar+"_POOL") == "" { - return fmt.Errorf("missing required environment variable %q for integration %s", envVar, ii.Integration.String()) - } - } - return nil -} diff --git a/agent/common/integration_headers_test.go b/agent/common/integration_headers_test.go deleted file mode 100644 index e42b805..0000000 --- a/agent/common/integration_headers_test.go +++ /dev/null @@ -1,347 +0,0 @@ -package common - -import ( - "encoding/json" - "os" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestRewriteOriginsWithHeaderExtraction(t *testing.T) { - // Set up environment variables for testing - os.Setenv("TEST_API_KEY", "secret-key-123") - os.Setenv("TEST_TOKEN", "bearer-token-456") - defer func() { - os.Unsetenv("TEST_API_KEY") - os.Unsetenv("TEST_TOKEN") - }() - - tests := []struct { - name string - acceptContent map[string]interface{} - expectedHeaderCalls []headerCall - }{ - { - name: "single route with headers", - acceptContent: map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api/*", - "origin": "https://api.example.com", - "headers": map[string]any{ - "x-api-key": "${TEST_API_KEY}", - "x-static": 
"static-value", - }, - }, - }, - }, - expectedHeaderCalls: []headerCall{ - { - origin: "https://api.example.com", - headers: map[string]string{ - "x-api-key": "secret-key-123", - "x-static": "static-value", - }, - }, - }, - }, - { - name: "multiple routes with different headers", - acceptContent: map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api1/*", - "origin": "https://api1.example.com", - "headers": map[string]interface{}{ - "authorization": "Bearer ${TEST_TOKEN}", - }, - }, - map[string]interface{}{ - "method": "POST", - "path": "/api2/*", - "origin": "https://api2.example.com", - "headers": map[string]interface{}{ - "x-api-key": "${TEST_API_KEY}", - "x-service": "test-service", - }, - }, - }, - }, - expectedHeaderCalls: []headerCall{ - { - origin: "https://api1.example.com", - headers: map[string]string{ - "authorization": "Bearer bearer-token-456", - }, - }, - { - origin: "https://api2.example.com", - headers: map[string]string{ - "x-api-key": "secret-key-123", - "x-service": "test-service", - }, - }, - }, - }, - { - name: "route without headers", - acceptContent: map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api/*", - "origin": "https://api.example.com", - }, - }, - }, - expectedHeaderCalls: []headerCall{}, // No headers, so no calls expected - }, - { - name: "mixed routes with and without headers", - acceptContent: map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api1/*", - "origin": "https://api1.example.com", - "headers": map[string]interface{}{ - "x-api-key": "${TEST_API_KEY}", - }, - }, - map[string]interface{}{ - "method": "GET", - "path": "/api2/*", - "origin": "https://api2.example.com", - // No headers section - }, - }, - }, - expectedHeaderCalls: []headerCall{ - { - origin: "https://api1.example.com", - headers: map[string]string{ - "x-api-key": "secret-key-123", - }, - }, - }, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Create temporary accept file - acceptFile := createTempAcceptFile(t, tt.acceptContent) - defer os.Remove(acceptFile) - - // Create integration info - integrationInfo := IntegrationInfo{ - Integration: IntegrationCustom, - AcceptFilePath: acceptFile, - } - - // Capture header extraction calls - var headerCalls []headerCall - headerExtractor := func(origin string, headers ResolverMap) string { - if len(headers) > 0 { - headerCalls = append(headerCalls, headerCall{ - origin: origin, - headers: headers.ToStringMap(), - }) - } - return "proxy-" + origin // Mocking the proxy URI generation - } - - // Call RewriteOriginsWithHeaderExtraction - newFile, err := integrationInfo.RewriteOrigins( - acceptFile, - headerExtractor, - ) - require.NoError(t, err) - require.NotEmpty(t, newFile) - - // Verify header extraction calls - assert.Equal(t, len(tt.expectedHeaderCalls), len(headerCalls)) - for i, expected := range tt.expectedHeaderCalls { - if i < len(headerCalls) { - assert.Equal(t, expected.origin, headerCalls[i].origin) - assert.Equal(t, expected.headers, headerCalls[i].headers) - } - } - }) - } -} - -func TestHeaderEnvironmentVariableResolution(t *testing.T) { - // Test missing environment variable - os.Unsetenv("MISSING_VAR") - os.Setenv("EXISTING_VAR", "existing-value") - defer os.Unsetenv("EXISTING_VAR") - - acceptContent := map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api/*", 
- "origin": "https://api.example.com", - "headers": map[string]interface{}{ - "x-existing": "${EXISTING_VAR}", - "x-missing": "${MISSING_VAR}", - "x-static": "no-vars-here", - }, - }, - }, - } - - acceptFile := createTempAcceptFile(t, acceptContent) - defer os.Remove(acceptFile) - - integrationInfo := IntegrationInfo{ - Integration: IntegrationCustom, - AcceptFilePath: acceptFile, - } - - var capturedHeaders map[string]string - headerExtractor := func(_ string, headers ResolverMap) { - capturedHeaders = headers.ToStringMap() - } - - _, err := integrationInfo.RewriteOrigins( - acceptFile, - func(originalURI string, headers ResolverMap) string { - headerExtractor(originalURI, headers) - return originalURI - }, - ) - require.NoError(t, err) - - // Verify environment variable resolution - assert.Equal(t, "existing-value", capturedHeaders["x-existing"]) - assert.Equal(t, "", capturedHeaders["x-missing"]) // os.ExpandEnv returns empty string for missing vars - assert.Equal(t, "no-vars-here", capturedHeaders["x-static"]) -} - -func TestComplexEnvironmentVariablePatterns(t *testing.T) { - os.Setenv("PREFIX", "api") - os.Setenv("VERSION", "v1") - os.Setenv("SECRET", "my-secret-123") - defer func() { - os.Unsetenv("PREFIX") - os.Unsetenv("VERSION") - os.Unsetenv("SECRET") - }() - - acceptContent := map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api/*", - "origin": "https://api.example.com", - "headers": map[string]interface{}{ - "x-service-name": "${PREFIX}-service-${VERSION}", - "authorization": "Bearer ${SECRET}", - "x-mixed": "prefix-${VERSION}-suffix", - }, - }, - }, - } - - acceptFile := createTempAcceptFile(t, acceptContent) - defer os.Remove(acceptFile) - - integrationInfo := IntegrationInfo{ - Integration: IntegrationCustom, - AcceptFilePath: acceptFile, - } - - var capturedHeaders map[string]string - headerExtractor := func(origin string, headers ResolverMap) string { - capturedHeaders = headers.ToStringMap() - return origin - } - - _, err := integrationInfo.RewriteOrigins( - acceptFile, - headerExtractor, - ) - require.NoError(t, err) - - // Verify complex environment variable patterns - assert.Equal(t, "api-service-v1", capturedHeaders["x-service-name"]) - assert.Equal(t, "Bearer my-secret-123", capturedHeaders["authorization"]) - assert.Equal(t, "prefix-v1-suffix", capturedHeaders["x-mixed"]) -} - -func TestEmptyAndInvalidHeaderValues(t *testing.T) { - acceptContent := map[string]interface{}{ - "private": []interface{}{ - map[string]interface{}{ - "method": "GET", - "path": "/api/*", - "origin": "https://api.example.com", - "headers": map[string]interface{}{ - "x-empty": "", - "x-number": 123, // Non-string value - "x-boolean": true, // Non-string value - "x-string": "valid", // Valid string value - }, - }, - }, - } - - acceptFile := createTempAcceptFile(t, acceptContent) - defer os.Remove(acceptFile) - - integrationInfo := IntegrationInfo{ - Integration: IntegrationCustom, - AcceptFilePath: acceptFile, - } - - var capturedHeaders map[string]string - headerExtractor := func(origin string, headers ResolverMap) string { - capturedHeaders = headers.ToStringMap() - return origin - } - - _, err := integrationInfo.RewriteOrigins( - acceptFile, - headerExtractor, - ) - require.NoError(t, err) - - // Verify that only string values are processed - assert.Equal(t, "", capturedHeaders["x-empty"]) - assert.Equal(t, "valid", capturedHeaders["x-string"]) - - // Non-string values should not be included - _, hasNumber := 
capturedHeaders["x-number"] - _, hasBoolean := capturedHeaders["x-boolean"] - assert.False(t, hasNumber) - assert.False(t, hasBoolean) -} - -// Helper types and functions - -type headerCall struct { - origin string - headers map[string]string -} - -func createTempAcceptFile(t *testing.T, content map[string]interface{}) string { - jsonContent, err := json.MarshalIndent(content, "", " ") - require.NoError(t, err) - - tmpFile, err := os.CreateTemp("", "accept-*.json") - require.NoError(t, err) - - _, err = tmpFile.Write(jsonContent) - require.NoError(t, err) - - err = tmpFile.Close() - require.NoError(t, err) - - return tmpFile.Name() -} diff --git a/agent/common/integration_test.go b/agent/common/integration_test.go index d4ac9c1..bedcdf0 100644 --- a/agent/common/integration_test.go +++ b/agent/common/integration_test.go @@ -4,37 +4,16 @@ import ( "fmt" "os" "path" - "strings" "testing" "time" + "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" + "github.com/stretchr/testify/require" + "go.uber.org/zap" ) -func TestEmptyAcceptFile(t *testing.T) { - - acceptFileContents := "{}" - acceptFilePath := writeTempFile(t, acceptFileContents) - - info := IntegrationInfo{ - Integration: IntegrationGithub, - Alias: fmt.Sprintf("%v", time.Now().UnixMilli()), - AcceptFilePath: acceptFilePath, - } - - defer func() { - os.Remove(info.AcceptFilePath) - }() - - resultPath, err := info.AcceptFile("http://localhost:9999") - require.NoError(t, err) - - contents, err := os.ReadFile(resultPath) - require.NoError(t, err) - require.Equal(t, "{\"private\":[{\"method\":\"any\",\"origin\":\"http://localhost:9999\",\"path\":\"/__axon/*\"}],\"public\":[]}", string(contents)) - -} - func TestGithubDefaultAcceptFile(t *testing.T) { os.Setenv("GITHUB_TOKEN", "the-github-token") @@ -48,10 +27,9 @@ func TestGithubDefaultAcceptFile(t *testing.T) { Alias: fmt.Sprintf("%v", time.Now().UnixMilli()), } - resultPath, err := info.AcceptFile("http://localhost:9999") - require.NoError(t, err) - _, err = os.Stat(resultPath) + file, err := info.ToAcceptFile(config.NewAgentEnvConfig()) require.NoError(t, err) + require.NotNil(t, file) } func TestGithubDefaultAcceptFileSubtypeInvalid(t *testing.T) { @@ -68,7 +46,7 @@ func TestGithubDefaultAcceptFileSubtypeInvalid(t *testing.T) { Alias: fmt.Sprintf("%v", time.Now().UnixMilli()), } - _, err := info.AcceptFile("http://localhost:9999") + _, err := info.ToAcceptFile(config.NewAgentEnvConfig()) require.Error(t, err) } @@ -98,16 +76,16 @@ func TestExistingAcceptFile(t *testing.T) { AcceptFilePath: acceptFilePath, } - defer func() { + t.Cleanup(func() { os.Remove(info.AcceptFilePath) - }() + }) - resultPath, err := info.AcceptFile("http://localhost:9999") + af, err := info.ToAcceptFile(config.NewAgentEnvConfig()) require.NoError(t, err) - contents, err := os.ReadFile(resultPath) + contents, err := af.Render(zap.NewNop()) require.NoError(t, err) - require.Equal(t, `{"private":[{"method":"any","origin":"http://localhost:9999","path":"/__axon/*"},{"method":"any","origin":"http://python-server","path":"/*"}],"public":[{"method":"any","path":"/*"}]}`, string(contents)) + require.Equal(t, `{"private":[{"method":"any","origin":"http://localhost:80","path":"/__axon/*"},{"method":"any","origin":"http://python-server","path":"/*"}],"public":[{"method":"any","path":"/*"}]}`, string(contents)) } @@ -118,12 +96,12 @@ func setAcceptFileDir(t *testing.T) { os.Setenv("ACCEPTFILE_DIR", acceptFileDir) } -func loadAcceptFile(t *testing.T, integration Integration) 
(string, error) { +func loadAcceptFile(t *testing.T, integration Integration) (*acceptfile.AcceptFile, error) { setAcceptFileDir(t) ii := IntegrationInfo{ Integration: integration, } - return ii.AcceptFile("http://localhost:9999") + return ii.ToAcceptFile(config.NewAgentEnvConfig()) } func init() { setAcceptFileDir(&testing.T{}) @@ -135,9 +113,8 @@ func TestLoadIntegrationAcceptFileSuccess(t *testing.T) { os.Setenv("GITHUB_API", "foo.github.com") os.Setenv("GITHUB_GRAPHQL", "foo.github.com/graphql") - acceptFile, err := loadAcceptFile(t, IntegrationGithub) + _, err := loadAcceptFile(t, IntegrationGithub) require.NoError(t, err) - require.NotEmpty(t, acceptFile) } func TestLoadIntegrationAcceptFileMissingVars(t *testing.T) { @@ -157,67 +134,12 @@ func TestLoadIntegrationAcceptFilePoolVars(t *testing.T) { acceptFile, err := loadAcceptFile(t, IntegrationGithub) require.NoError(t, err) - contents, err := os.ReadFile(acceptFile) + contents, err := acceptFile.Render(zap.NewNop()) require.NoError(t, err) require.Contains(t, string(contents), "GITHUB_TOKEN") require.NotContains(t, string(contents), "GITHUB_TOKEN_POOL") } -func TestAcceptRewrite(t *testing.T) { - acceptFileContents := ` - { - "public": [ - { - "method": "any", - "path": "/*" - } - ], - "private": [ - { - "method": "any", - "path": "/*", - "origin": "http://python-server" - }, - { - "method": "any", - "path": "/*", - "origin": "http://localhost" - }, - { - "method": "any", - "path": "/stuff/*", - "origin": "http://localhost:9999" - }, - { - "method": "any", - "path": "/*", - "origin": "api.foo.com" - } - ] - } - ` - acceptFilePath := writeTempFile(t, acceptFileContents) - info := IntegrationInfo{ - Integration: IntegrationGithub, - AcceptFilePath: acceptFilePath, - } - rewritten, err := info.RewriteOrigins(acceptFilePath, func(origin string, headers ResolverMap) string { - - if strings.Contains(origin, "http://localhost") { - require.Fail(t, "should not rewrite localhost origins") - } - - if origin == "http://python-server" { - return "http://new-python-server" - } - return origin - }) - require.NoError(t, err) - contents, err := os.ReadFile(rewritten.RewrittenPath) - require.NoError(t, err) - require.Equal(t, `{"private":[{"method":"any","origin":"http://new-python-server","path":"/*"},{"method":"any","origin":"http://localhost","path":"/*"},{"method":"any","origin":"http://localhost:9999","path":"/stuff/*"},{"method":"any","origin":"https://api.foo.com","path":"/*"}],"public":[{"method":"any","path":"/*"}]}`, string(contents)) -} - func TestGetOrigin(t *testing.T) { os.Setenv("USER", "testuser") @@ -310,6 +232,9 @@ func writeTempFile(t *testing.T, contents string) string { f, err := os.CreateTemp(t.TempDir(), "accept.*.json") require.NoError(t, err) defer f.Close() + t.Cleanup(func() { + os.Remove(f.Name()) + }) _, err = f.WriteString(contents) require.NoError(t, err) diff --git a/agent/common/version.go b/agent/common/version.go index 30e2ecb..0b99bb5 100644 --- a/agent/common/version.go +++ b/agent/common/version.go @@ -1,10 +1,13 @@ package common +import _ "embed" + // This is the version we send to Cortex to identify our // client protocol Backcompat should always be possible with this // sent along const ClientVersion = "0.0.1" // This is the version the GRPC client -// go:embed grpcversion.txt +// +//go:embed grpcversion.txt var GrpcVersion string diff --git a/agent/config/config.go b/agent/config/config.go index 05d7746..6109689 100644 --- a/agent/config/config.go +++ b/agent/config/config.go @@ -51,6 +51,7 @@ type 
AgentConfig struct { FailWaitTime time.Duration AutoRegisterFrequency time.Duration VerboseOutput bool + PluginDirs []string HandlerHistoryPath string HandlerHistoryMaxAge time.Duration @@ -205,25 +206,32 @@ func NewAgentEnvConfig() AgentConfig { } cfg := AgentConfig{ - GrpcPort: port, - CortexApiBaseUrl: baseUrl, - CortexApiToken: token, - DryRun: dryRun, - DequeueWaitTime: dequeueWaitTime, - InstanceId: getInstanceId(), - IntegrationAlias: identifier, - HttpServerPort: httpPort, - WebhookServerPort: WebhookServerPort, - SnykBrokerPort: snykBrokerPort, - EnableApiProxy: true, - FailWaitTime: time.Second * 2, - + GrpcPort: port, + CortexApiBaseUrl: baseUrl, + CortexApiToken: token, + DryRun: dryRun, + DequeueWaitTime: dequeueWaitTime, + InstanceId: getInstanceId(), + IntegrationAlias: identifier, + HttpServerPort: httpPort, + WebhookServerPort: WebhookServerPort, + SnykBrokerPort: snykBrokerPort, + EnableApiProxy: true, + FailWaitTime: time.Second * 2, + PluginDirs: []string{"./plugins"}, AutoRegisterFrequency: reregisterFrequency, HandlerHistoryPath: historyPath, HandlerHistoryMaxAge: handlerHistoryMaxAge, HandlerHistoryMaxSizeBytes: handlerHistoryMaxSizeBytes, } + if pluginDirsEnv := os.Getenv("PLUGIN_DIRS"); pluginDirsEnv != "" { + pluginDirs := filepath.SplitList(pluginDirsEnv) + for _, dir := range pluginDirs { + cfg.PluginDirs = append(cfg.PluginDirs, filepath.Clean(dir)) + } + } + if DisableTLS := os.Getenv("DISABLE_TLS"); DisableTLS == "true" { cfg.HttpDisableTLS = true } diff --git a/agent/go.mod b/agent/go.mod index 59dc34e..3efd6bd 100644 --- a/agent/go.mod +++ b/agent/go.mod @@ -3,7 +3,6 @@ module github.com/cortexapps/axon go 1.22.9 require ( - github.com/PaesslerAG/jsonpath v0.1.1 github.com/google/uuid v1.6.0 github.com/gorilla/mux v1.8.1 github.com/prometheus/client_golang v1.22.0 @@ -18,7 +17,6 @@ require ( ) require ( - github.com/PaesslerAG/gval v1.0.0 // indirect github.com/beorn7/perks v1.0.1 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect github.com/davecgh/go-spew v1.1.1 // indirect diff --git a/agent/go.sum b/agent/go.sum index a2f6ef3..cbe838f 100644 --- a/agent/go.sum +++ b/agent/go.sum @@ -1,8 +1,3 @@ -github.com/PaesslerAG/gval v1.0.0 h1:GEKnRwkWDdf9dOmKcNrar9EA1bz1z9DqPIO1+iLzhd8= -github.com/PaesslerAG/gval v1.0.0/go.mod h1:y/nm5yEyTeX6av0OfKJNp9rBNj2XrGhAf5+v24IBN1I= -github.com/PaesslerAG/jsonpath v0.1.0/go.mod h1:4BzmtoM/PI8fPO4aQGIusjGxGir2BzcV0grWtFzq1Y8= -github.com/PaesslerAG/jsonpath v0.1.1 h1:c1/AToHQMVsduPAa4Vh6xp2U0evy4t8SWp8imEsylIk= -github.com/PaesslerAG/jsonpath v0.1.1/go.mod h1:lVboNxFGal/VwW6d9JzIy56bUsYAP6tH/x80vjnCseY= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= diff --git a/agent/server/http/axon_handler_test.go b/agent/server/http/axon_handler_test.go index 77a54eb..3906005 100644 --- a/agent/server/http/axon_handler_test.go +++ b/agent/server/http/axon_handler_test.go @@ -70,7 +70,7 @@ func TestInvokeEndpoint(t *testing.T) { req.Header.Set("Content-Type", "application/json") resp, err := http.DefaultClient.Do(req) assert.NoError(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode) assert.NotNil(t, manager.GetByTag("test-handler").LastInvoked()) body, err := io.ReadAll(resp.Body) @@ -127,7 +127,7 @@ func TestInvokeEndpointErr(t *testing.T) { req.Header.Set("Content-Type", 
"application/json") resp, err := http.DefaultClient.Do(req) assert.NoError(t, err) - assert.Equal(t, http.StatusInternalServerError, resp.StatusCode) + require.Equal(t, http.StatusInternalServerError, resp.StatusCode) assert.NotNil(t, manager.GetByTag("test-handler").LastInvoked()) body, err := io.ReadAll(resp.Body) diff --git a/agent/server/http/http_server.go b/agent/server/http/http_server.go index a7b6765..def9171 100644 --- a/agent/server/http/http_server.go +++ b/agent/server/http/http_server.go @@ -175,7 +175,9 @@ type responseRecorder struct { func (rr *responseRecorder) WriteHeader(code int) { rr.statusCode = code - rr.ResponseWriter.WriteHeader(code) + if code != http.StatusOK { + rr.ResponseWriter.WriteHeader(code) + } } // needed for websockets/HTTP2 diff --git a/agent/server/http/http_server_webhook_test.go b/agent/server/http/http_server_webhook_test.go index 997bbfc..829c92f 100644 --- a/agent/server/http/http_server_webhook_test.go +++ b/agent/server/http/http_server_webhook_test.go @@ -55,7 +55,7 @@ func TestHandleWebhook(t *testing.T) { resp, err := http.DefaultClient.Do(req) assert.NoError(t, err) defer resp.Body.Close() - assert.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode) /** Ensure that the handler was invoked diff --git a/agent/server/snykbroker/acceptfile/accept_file.go b/agent/server/snykbroker/acceptfile/accept_file.go new file mode 100644 index 0000000..c16fe18 --- /dev/null +++ b/agent/server/snykbroker/acceptfile/accept_file.go @@ -0,0 +1,256 @@ +package acceptfile + +import ( + "encoding/json" + "fmt" + + "github.com/cortexapps/axon/config" + axonHttp "github.com/cortexapps/axon/server/http" + "go.uber.org/zap" +) + +// AcceptFile is an abstraction over the Snyk Broker "accept.json" format. +// It owns manipulating a raw file into one that is customized for Axon including +// adding the Axon rules, and doing replacements for the reflector (eg all traffic is sent through the agent +// instead of directly to the target), as well as adding support for adding outbound headers. +type AcceptFile struct { + wrapper acceptFileWrapper + content []byte + config config.AgentConfig +} + +// NewAcceptFile creates a new AcceptFile instance, taking the raw content of the accept file +// and the agent configuration. It preprocesses the content to handle plugin invocations. +func NewAcceptFile(content []byte, cfg config.AgentConfig) (*AcceptFile, error) { + + // Fixup ${} references to support plugins without confusing with env vars + processedContent, err := preProcessContent(content) + if err != nil { + return nil, fmt.Errorf("failed to preprocess accept file content: %w", err) + } + + af := &AcceptFile{ + content: processedContent, + config: cfg, + } + + if err := ensureAcceptFileVars(string(af.content)); err != nil { + return nil, err + } + + af.wrapper = newAcceptFileWrapper(processedContent, af) + return af, nil +} + +type RenderContext struct { + AcceptFile acceptFileWrapper + Logger *zap.Logger +} + +type RenderStep func(renderContext RenderContext) error + +var IgnoreHosts = []string{ + "localhost", + "127.0.0.1", +} + +const RULES_PRIVATE = "private" +const RULES_PUBLIC = "public" + +// Render renders the accept file by applying Axon updates plus any additional render steps provided. +// It returns the rendered JSON content of the accept file. 
+func (a *AcceptFile) Render(logger *zap.Logger, extraRenderSteps ...RenderStep) ([]byte, error) { + + renderContext := RenderContext{ + Logger: logger, + AcceptFile: newAcceptFileWrapper(a.content, a), + } + + renderSteps := append([]RenderStep{ + a.ensurePublicAndPrivate, + a.addAxonRoute, + }, extraRenderSteps...) + + for _, step := range renderSteps { + if err := step(renderContext); err != nil { + logger.Error("failed to render accept file", zap.Error(err)) + return nil, err + } + } + + json, err := renderContext.AcceptFile.toJSON() + if err != nil { + logger.Error("failed to marshal accept file content", zap.Error(err)) + return nil, err + } + + return json, nil +} + +func (a *AcceptFile) ensurePublicAndPrivate(renderContext RenderContext) error { + renderContext.AcceptFile.PrivateRules() + renderContext.AcceptFile.PublicRules() + return nil +} + +func (a *AcceptFile) addAxonRoute(renderContext RenderContext) error { + // we add a section like this to allow the server side + // to hit the axon agent via an HTTP bridge + // "private": [ + // { + // "method": "any", + // "path": "/__axon/*", + // "origin": "http://localhost" + // } + // ] + + entry := acceptFileRule{ + Method: "any", + Path: fmt.Sprintf("%s/*", axonHttp.AxonPathRoot), + Origin: a.config.HttpBaseUrl(), + } + + renderContext.AcceptFile.AddRule(RULES_PRIVATE, entry) + return nil +} + +// +// The wrapper classes below provide strongly typed access to the accept file +// which is parsed as a map[string]any. We do this to ensure full compatibility +// with the Snyk Broker accept file format while also being able to add extra functionality. +// In other words, if we simply mapped the file to a struct here, its possible there would be content +// in the accept file that we don't know about, and we would lose that content on a round trip. +// + +// acceptFileWrapper provides a strongly typed wrapper around the accept file content. +type acceptFileWrapper struct { + dict map[string]any + acceptFile *AcceptFile +} + +func newAcceptFileWrapper(content []byte, af *AcceptFile) acceptFileWrapper { + dict := make(map[string]any) + err := json.Unmarshal(content, &dict) + if err != nil { + panic(fmt.Errorf("failed to unmarshal accept file content: %w, content was:\n%s", err, string(content))) + } + + return acceptFileWrapper{dict: dict, acceptFile: af} +} + +func (w acceptFileWrapper) PrivateRules() []acceptFileRuleWrapper { + return w.rules(RULES_PRIVATE) +} + +func (w acceptFileWrapper) PublicRules() []acceptFileRuleWrapper { + return w.rules(RULES_PUBLIC) +} + +func (w acceptFileWrapper) rules(routeType string) []acceptFileRuleWrapper { + routesEntry, ok := w.dict[routeType].([]interface{}) + if !ok { + routesEntry = []any{} + w.dict[routeType] = routesEntry + } + + routes := make([]acceptFileRuleWrapper, len(routesEntry)) + for i, route := range routesEntry { + routeDict, ok := route.(map[string]any) + if !ok { + return nil + } + routes[i] = acceptFileRuleWrapper{ + dict: routeDict, + acceptFile: w.acceptFile, + } + } + return routes +} + +// AddRule adds a new route to the accept file for the specified route type. +func (w acceptFileWrapper) AddRule(routeType string, entry acceptFileRule) acceptFileRuleWrapper { + + // with a little extra work here we could probably just directly use + // the entry structure above, but the acceptFileRuleWrapper takes a dict so we need + // to convert it to a map[string]any first, so we round trip it through JSON. 
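// For example, a generated rule such as
//
//	acceptFileRule{Method: "any", Path: "/__axon/*", Origin: "http://localhost:80"}
//
// round-trips into the map
//
//	map[string]any{"method": "any", "path": "/__axon/*", "origin": "http://localhost:80"}
//
// which is then prepended to the existing rules for the given route type.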
+ + routeAsJson, err := json.Marshal(entry) + if err != nil { + panic(fmt.Errorf("failed to marshal accept file route: %w", err)) + } + var routeDict map[string]any + err = json.Unmarshal(routeAsJson, &routeDict) + if err != nil { + panic(fmt.Errorf("failed to unmarshal accept file route: %w", err)) + } + existingRoutes := w.dict[routeType].([]any) + w.dict[routeType] = append([]any{routeDict}, existingRoutes...) + return acceptFileRuleWrapper{dict: routeDict} +} + +func (w acceptFileWrapper) toJSON() ([]byte, error) { + jsonData, err := json.Marshal(w.dict) + if err != nil { + return nil, fmt.Errorf("failed to marshal accept file content: %w", err) + } + return jsonData, nil +} + +type acceptFileRuleWrapper struct { + dict map[string]any + acceptFile *AcceptFile +} + +func (r acceptFileRuleWrapper) Origin() string { + origin, ok := r.dict["origin"].(string) + if !ok { + return "" + } + return origin +} + +func (r acceptFileRuleWrapper) Path() string { + path, ok := r.dict["path"].(string) + if !ok { + return "" + } + return path +} + +func (r acceptFileRuleWrapper) SetOrigin(origin string) { + r.dict["origin"] = origin +} + +func (r acceptFileRuleWrapper) Headers() ResolverMap { + headers, ok := r.dict["headers"].(map[string]any) + if !ok { + return nil + } + + result := make(ResolverMap) + for k, v := range headers { + if str, ok := v.(string); ok { + result[k] = CreateResolver(str, zap.NewNop(), r.acceptFile.config.PluginDirs) + } + } + return result +} + +// Here are our JSON structed types that represent the accept file rules. +// that we can use for things that we are generating such that we don't need to worry +// about additional fields that might be in the accept file that we don't know about. + +type acceptFileRule struct { + Method string `json:"method"` + Path string `json:"path"` + Origin string `json:"origin"` + Auth *acceptFileRuleAuth `json:"auth,omitempty"` + Headers map[string]string `json:"headers,omitempty"` +} + +type acceptFileRuleAuth struct { + Scheme string `json:"scheme"` + Username string `json:"username,omitempty"` + Password string `json:"password,omitempty"` + Token string `json:"token,omitempty"` +} diff --git a/agent/server/snykbroker/acceptfile/accept_file_test.go b/agent/server/snykbroker/acceptfile/accept_file_test.go new file mode 100644 index 0000000..d71134d --- /dev/null +++ b/agent/server/snykbroker/acceptfile/accept_file_test.go @@ -0,0 +1,191 @@ +package acceptfile + +import ( + "fmt" + "os" + "strings" + "testing" + + axonConfig "github.com/cortexapps/axon/config" + "go.uber.org/zap" + + "github.com/stretchr/testify/require" +) + +func TestEmptyAcceptFile(t *testing.T) { + + acceptFiles := []string{ + "{}", + `{"private": [], "public": []}`, + `{"private": []}`, + `{"public": []}`, + } + + for _, acceptFileContents := range acceptFiles { + t.Run(acceptFileContents, func(t *testing.T) { + cfg := axonConfig.NewAgentEnvConfig() + cfg.HttpServerPort = 9999 + acceptFile, err := NewAcceptFile([]byte(acceptFileContents), cfg) + require.NoError(t, err) + contents, err := acceptFile.Render(zap.NewNop()) + require.NoError(t, err) + require.Equal(t, fmt.Sprintf("{\"private\":[{\"method\":\"any\",\"origin\":\"%s\",\"path\":\"/__axon/*\"}],\"public\":[]}", cfg.HttpBaseUrl()), string(contents)) + }) + } +} + +func TestAcceptFileValidate(t *testing.T) { + + cfg := axonConfig.NewAgentEnvConfig() + + files := []struct { + content string + valid bool + envVars map[string]string + }{ + { + content: `{"private": [], "public": []}`, + valid: true, + envVars: nil, + }, 
+ { + content: `{"private": [ + {"method": "GET", "origin": "${API}", "path": "/*"} + ]}`, + valid: true, + envVars: map[string]string{"API": "value"}, + }, + { + content: `{"private": [ + {"method": "GET", "origin": "${API}", "path": "/*"} + ]}`, + valid: false, + envVars: nil, + }, + { + content: `{"private": [ + {"method": "GET", "origin": "${plugin:API}", "path": "/*"} + ]}`, + valid: true, + envVars: nil, + }, + { + content: `{"private": [ + {"method": "GET", "origin": "${env:API}", "path": "/*"} + ]}`, + valid: false, + envVars: nil, + }, + { + content: `{"vars": ["${env:API}", "${OTHER}"], "private": []}`, + valid: true, + envVars: map[string]string{"API": "value", "OTHER": "othervalue"}, + }, + } + + for _, file := range files { + t.Run(file.content, func(t *testing.T) { + if file.envVars != nil { + for k, v := range file.envVars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range file.envVars { + os.Unsetenv(k) + } + }) + } + _, err := NewAcceptFile([]byte(file.content), cfg) + if file.valid { + require.NoError(t, err) + } else { + require.Error(t, err) + return + } + + }) + } + +} + +func TestRenderEnvVars(t *testing.T) { + + vars := map[string]string{ + "API": "value", + "OTHER": "othervalue", + "plugin": "nope", + } + + for k, v := range vars { + os.Setenv(k, v) + } + t.Cleanup(func() { + for k := range vars { + os.Unsetenv(k) + } + }) + + cfg := axonConfig.NewAgentEnvConfig() + + content := `{ + "$vars":["${env:API}", "${OTHER}", "${plugin:foo}", "${OTHER}"], "private": []}` + + af, err := NewAcceptFile([]byte(content), cfg) + require.NoError(t, err) + rendered, err := af.Render(zap.NewNop()) + require.NoError(t, err) + expected := `{"$vars":["${API}","${OTHER}","{{plugin:foo}}","${OTHER}"],"private":[{"method":"any","origin":"http://localhost:80","path":"/__axon/*"}],"public":[]}` + require.Equal(t, expected, string(rendered), "Rendered accept file does not match expected output") +} + +func TestExtraRenderSteps(t *testing.T) { + acceptFileContents := `{ + + "private": [ + {"method": "GET", "origin": "http://localhost:9999", "path": "/private/*"} + ] + }` + cfg := axonConfig.NewAgentEnvConfig() + cfg.HttpServerPort = 9999 + logger := zap.NewNop() + acceptFile, err := NewAcceptFile([]byte(acceptFileContents), cfg) + require.NoError(t, err) + + rendered, err := acceptFile.Render(logger, func(renderContext RenderContext) error { + + for _, entry := range renderContext.AcceptFile.PrivateRules() { + if !strings.Contains(entry.Path(), "axon") { + entry.SetOrigin("http://localhost:8888") + } + } + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, rendered) + expected := `{"private":[{"method":"any","origin":"http://localhost:9999","path":"/__axon/*"},{"method":"GET","origin":"http://localhost:8888","path":"/private/*"}],"public":[]}` + require.Equal(t, expected, string(rendered), "Rendered accept file does not match expected output") +} + +func TestPreProcessContent(t *testing.T) { + content := `{ + "$vars":[ + "${env:API}", + "${OTHER}", + "${plugin:foo}", + "${OTHER}" + ] + }` + + expected := `{ + "$vars":[ + "${API}", + "${OTHER}", + "{{plugin:foo}}", + "${OTHER}" + ] + }` + + processed, err := preProcessContent([]byte(content)) + require.NoError(t, err) + require.Equal(t, expected, string(processed), "Processed content does not match expected output") +} diff --git a/agent/server/snykbroker/acceptfile/plugin.go b/agent/server/snykbroker/acceptfile/plugin.go new file mode 100644 index 0000000..43a8d5c --- /dev/null +++ 
b/agent/server/snykbroker/acceptfile/plugin.go @@ -0,0 +1,80 @@ +package acceptfile + +import ( + "bytes" + "fmt" + "os" + "os/exec" + "time" + + "go.uber.org/zap" +) + +// Plugin is an extension to the accept file system that allows some values in the +// accept file to be dynamically generated by executing an executable on the local system. +// +// It is currently only supported for headers. +// +// The format of plugin invocations in the accept file is: `${plugin:myplugin}`, where +// `myplugin` is the name of the executable in the provided directories specified by +// config.AgentConfig.PluginDirs. The default is the `./plugins` directory in the agent's working directory, +// but this can be overridden by setting the `PLUGIN_DIRS` environment variable. +type Plugin struct { + Name string + FullPath string + Logger *zap.Logger +} + +func NewPlugin(name, fullPath string, logger *zap.Logger) Plugin { + _, err := os.Stat(fullPath) + if err != nil { + logger.Panic("Failed to stat plugin file", zap.String("path", fullPath), + zap.Error(err)) + } + return Plugin{ + Name: name, + FullPath: fullPath, + Logger: logger.Named("plugin-" + name), + } +} + +func FindPlugin(name string, dirs []string, logger *zap.Logger) (Plugin, error) { + // searches for the plugin in the provided directories + for _, dir := range dirs { + fullPath := dir + "/" + name + if _, err := exec.LookPath(fullPath); err == nil { + + return NewPlugin(name, fullPath, logger), nil + } + } + return Plugin{}, fmt.Errorf("plugin %q not found in directories: %v", name, dirs) +} + +func (p Plugin) Execute() (string, error) { + + // executes the plugin and returns the output from stdout + + cmd := exec.Command(p.FullPath) + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + now := time.Now() + err := cmd.Run() + + if err != nil { + p.Logger.Error("Plugin execution failed", + zap.String("path", p.FullPath), + zap.Error(err), + zap.String("stderr", stderr.String()), + zap.String("stdout", stdout.String()), + zap.Int("exit-code", cmd.ProcessState.ExitCode()), + ) + return "", fmt.Errorf("failed to execute %q (%v), output was:\nstderr: %s, stdout:%s", p.FullPath, err, stderr.String(), stdout.String()) + } else { + duration := time.Since(now) + p.Logger.Info("Executed plugin", zap.String("path", p.FullPath), zap.Duration("duration", duration)) + } + return stdout.String(), nil + +} diff --git a/agent/server/snykbroker/acceptfile/plugin.sh b/agent/server/snykbroker/acceptfile/plugin.sh new file mode 100755 index 0000000..d20bdcb --- /dev/null +++ b/agent/server/snykbroker/acceptfile/plugin.sh @@ -0,0 +1,2 @@ +#! /bin/sh +echo "HOME=$HOME" \ No newline at end of file diff --git a/agent/server/snykbroker/acceptfile/plugin_fail.sh b/agent/server/snykbroker/acceptfile/plugin_fail.sh new file mode 100755 index 0000000..f84ba43 --- /dev/null +++ b/agent/server/snykbroker/acceptfile/plugin_fail.sh @@ -0,0 +1,2 @@ +#! 
/bin/sh +exit 1 \ No newline at end of file diff --git a/agent/server/snykbroker/acceptfile/plugin_test.go b/agent/server/snykbroker/acceptfile/plugin_test.go new file mode 100644 index 0000000..19abfc8 --- /dev/null +++ b/agent/server/snykbroker/acceptfile/plugin_test.go @@ -0,0 +1,118 @@ +package acceptfile + +import ( + "os" + "testing" + + "github.com/stretchr/testify/require" + "go.uber.org/zap" +) + +func TestEnvVarTypes(t *testing.T) { + content := ` + ${NORMAL_ENV_VAR} + ${plugin:PLUGIN_NAME} + ${env:ENV_VAR} + ` + + os.Setenv("NORMAL_ENV_VAR", "normal_value") + os.Setenv("PLUGIN_NAME", "INVALID") + os.Setenv("ENV_VAR", "env_value") + + t.Cleanup(func() { + os.Unsetenv("NORMAL_ENV_VAR") + os.Unsetenv("PLUGIN_NAME") + os.Unsetenv("ENV_VAR") + }) + + fileVars := findFileVars(content) + + expectedVars := []fileVar{ + {Name: "NORMAL_ENV_VAR", Type: VarTypeEnv, Original: "${NORMAL_ENV_VAR}"}, + {Name: "PLUGIN_NAME", Type: VarTypePlugin, Original: "${plugin:PLUGIN_NAME}"}, + {Name: "ENV_VAR", Type: VarTypeEnv, Original: "${env:ENV_VAR}"}, + } + + require.ElementsMatch(t, expectedVars, fileVars, "File vars do not match expected values") + +} + +func TestParsePluginVar(t *testing.T) { + content := `${plugin:my-plugin}` + item := parseEnvType(content) + require.Equal(t, "my-plugin", item.Name, "Plugin name should match") + require.Equal(t, VarTypePlugin, item.Type, "Variable type should be VarTypePlugin") +} + +func TestParseReplacedFormat(t *testing.T) { + content := `{{plugin:my-plugin}}` + item := parseInterpolation(content) + require.Equal(t, "my-plugin", item.Name, "Plugin name should match") + require.Equal(t, VarTypePlugin, item.Type, "Variable type should be VarTypePlugin") +} + +func TestCreatePluginResolver(t *testing.T) { + // Assuming the plugin.sh is in the same directory as the test file + + logger := zap.NewNop() + + // Create a plugin resolver + resolver := CreateResolver("{{plugin:plugin.sh}}", logger, []string{"."}) + + // Execute the resolver and get the output + output := resolver.Resolve() + require.NotEmpty(t, output, "Output should not be empty") + require.Contains(t, output, "HOME="+os.Getenv("HOME"), "Output should contain $HOME, but was: "+output) +} + +func TestPluginNotPresent(t *testing.T) { + + require.Panics(t, func() { + NewPlugin("test-plugin", "bad-path", zap.NewNop()) + }) + +} + +func TestPluginExecution(t *testing.T) { + // Assuming the plugin.sh is in the same directory as the test file + pluginPath := "./plugin.sh" + plugin := NewPlugin("test-plugin", pluginPath, zap.NewNop()) + output, err := plugin.Execute() + require.NoError(t, err, "Plugin execution should not return an error") + require.NotEmpty(t, output, "Output should not be empty") + require.Contains(t, output, "HOME="+os.Getenv("HOME"), "Output should contain $HOME, but was: "+output) +} + +func TestPluginExecutionFail(t *testing.T) { + // Assuming the plugin.sh is in the same directory as the test file + pluginPath := "./plugin_fail.sh" + plugin := NewPlugin("test-plugin-fail", pluginPath, zap.NewNop()) + output, err := plugin.Execute() + require.Error(t, err, "Plugin execution should not return an error") + require.Contains(t, err.Error(), "exit status 1", "Error message should indicate failure to run") + require.Empty(t, output, "Output should not be empty") +} + +func TestFindPlugin(t *testing.T) { + // Assuming the plugin.sh is in the same directory as the test file + pluginFile := "plugin.sh" + logger := zap.NewNop() + + // Test finding an existing plugin + plugin, err := 
FindPlugin(pluginFile, []string{"."}, logger) + require.NoError(t, err, "Should find the plugin") + require.Equal(t, pluginFile, plugin.Name) + require.Equal(t, "./plugin.sh", plugin.FullPath) + + // Test finding a non-existing plugin + _, err = FindPlugin("nonexistent.sh", []string{"."}, logger) + require.Error(t, err, "Should not find a non-existing plugin") + + // Test finding non-executable plugin + pluginFile = "plugin_test.go" + _, err = os.Stat(pluginFile) + require.NoError(t, err, "Should be able to stat the plugin file") + _, err = FindPlugin(pluginFile, []string{"."}, logger) + require.Error(t, err, "Should not find a non-executable plugin") + +} diff --git a/agent/server/snykbroker/acceptfile/resolver.go b/agent/server/snykbroker/acceptfile/resolver.go new file mode 100644 index 0000000..4f64a4b --- /dev/null +++ b/agent/server/snykbroker/acceptfile/resolver.go @@ -0,0 +1,255 @@ +package acceptfile + +import ( + "fmt" + "os" + "regexp" + "sort" + "strings" + + "go.uber.org/zap" +) + +// preProcessContent processes the content to replace plugin placeholders +// for simplicity the vars in the raw accept file are of the form +// "${plugin:foo}" and we replace them with "{{plugin:foo}}" to avoid +// conflicts with env variables os we can use the system env expansion +func preProcessContent(content []byte) ([]byte, error) { + // Change all "${plugin:foo}" to "{{ plugin:foo }}" + rePlugin := regexp.MustCompile(`\$\{plugin:([^}]+)\}`) + content = rePlugin.ReplaceAll(content, []byte("{{plugin:$1}}")) + + // Change all "${env:FOO}" to "${FOO}" (unchanged) + reEnv := regexp.MustCompile(`\$\{env:([^}]+)\}`) + content = reEnv.ReplaceAll(content, []byte("${$1}")) + + return content, nil +} + +// CreateResolver takes content and resolves it to a function that will do any +// interpolation and plugin execution when called. +// It expands environment variables and finds plugins in the provided directories. +// Then for each execution we loop the plugins and replace the plugin placeholder content +func CreateResolver(value string, logger *zap.Logger, pluginDirs []string) ValueResolver { + + // here we expand all the env vars and then + // look for plugins and execute those. + content := os.ExpandEnv(value) + plugins := findPlugins(content, pluginDirs, logger) + + return ValueResolver{ + Key: value, + Resolve: func() string { + + execContent := content + for _, plugin := range plugins { + // replace the plugin variable with the output of the plugin + pluginOutput, err := plugin.Plugin.Execute() + if err != nil { + logger.Error("failed to execute plugin", zap.String("plugin", plugin.Plugin.FullPath), zap.Error(err)) + continue + } + execContent = strings.ReplaceAll(execContent, plugin.Content, strings.Trim(pluginOutput, "\n")) + } + return execContent + }, + } +} + +type PluginResult struct { + Plugin + Content string +} + +// findPlugins finds all plugin invocations in the content and returns a list of PluginResults +// each of which represents a plugin that was found and its content in the accept file. 
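// A short sketch of the expected behavior (assumes an executable named "my-plugin"
// exists in "./plugins"; the placeholder and directory are illustrative):
//
//	results := findPlugins("Bearer {{plugin:my-plugin}}", []string{"./plugins"}, logger)
//	// results[0].Content  == "{{plugin:my-plugin}}"
//	// results[0].FullPath == "./plugins/my-plugin"
//
// Repeated occurrences of the same placeholder are reported only once.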
+func findPlugins(content string, pluginDirs []string, logger *zap.Logger) []PluginResult { + + pluginStrings := reInterpolation.FindAllString(content, -1) + + seen := map[string]bool{} + + plugins := make([]PluginResult, 0, len(pluginStrings)) + for _, pluginString := range pluginStrings { + + if seen[pluginString] { + continue + } + + result := parseInterpolation(pluginString) + if result.Type != VarTypePlugin { + logger.Panic("expected plugin type", zap.String("plugin", result.Name), zap.String("content", content)) + } + plugin, err := FindPlugin(result.Name, pluginDirs, logger) + if err != nil { + logger.Panic( + fmt.Sprintf("failed to find plugin from %q", result.Name), + zap.String("plugin", result.Name), + zap.String("workingDir", os.Getenv("PWD")), + zap.Error(err)) + } + plugins = append(plugins, PluginResult{ + Plugin: plugin, + Content: pluginString, + }) + seen[result.Name] = true + } + + return plugins +} + +type ValueResolver struct { + Resolve func() string + Key string +} + +func StringValueResolver(value string) ValueResolver { + return ValueResolver{ + Resolve: func() string { + return value + }, + Key: value, + } +} + +type ResolverMap map[string]ValueResolver + +func NewResolverMapFromMap(m map[string]string) ResolverMap { + rm := make(ResolverMap, len(m)) + for key, value := range m { + rm[key] = StringValueResolver(value) + } + return rm +} + +func (rm ResolverMap) ToStringMap() map[string]string { + resolved := make(map[string]string, len(rm)) + for key, resolver := range rm { + resolved[key] = resolver.Resolve() + } + return resolved +} + +func (rm ResolverMap) Resolve(key string) string { + resolver, ok := rm[key] + if !ok { + return "" + } + return resolver.Resolve() +} + +func (rm ResolverMap) ResolverKey(key string) string { + resolver, ok := rm[key] + if !ok { + return "" + + } + return resolver.Key +} + +// Parsing code for extracting environment variables and plugin invocations from content +// Examples: `${env:API}`, `${plugin:my-plugin}`, `${API}` +var reContentVars = regexp.MustCompile(`(?m)\$\{([^:}]+):([^}]+)\}|\$\{([^}]+)\}`) + +const VAR_TYPE_INDEX = 1 +const VAR_NAME_INDEX = 2 +const VAR_NAME_ONLY_INDEX = 3 + +// define an enum of variable types +type varType int + +const ( + VarTypeEnv varType = iota + VarTypePlugin +) + +type fileVar struct { + Name string + Type varType + Original string +} + +func (vt varType) String() string { + switch vt { + case VarTypeEnv: + return "Env" + case VarTypePlugin: + return "Plugin" + default: + return "Unknown" + } +} + +// Parser for interpolation in the accept file +// This is used to parse the `{{plugin:my-plugin}}` format in the accept file. 
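// For example, parseInterpolation("{{plugin:my-plugin}}") yields
//
//	fileVar{Name: "my-plugin", Type: VarTypePlugin, Original: "{{plugin:my-plugin}}"}
//
// mirroring TestParseReplacedFormat above.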
+var reInterpolation = regexp.MustCompile(`\{\{([^}]+)\}\}`) + +func parseInterpolation(content string) fileVar { + match := reInterpolation.FindStringSubmatch(content) + if match == nil { + panic(fmt.Sprintf("invalid interpolation format: %q", content)) + } + found := strings.Trim(match[1], " ") + parts := strings.Split(found, ":") + return fileVar{ + Name: strings.Trim(parts[1], " "), + Type: VarTypePlugin, + Original: match[0], + } +} + +func parseEnvType(content string) fileVar { + + match := reContentVars.FindStringSubmatch(content) + if len(match) < 4 { + panic(fmt.Sprintf("invalid env var format %q", content)) + } + varTypeName := match[VAR_TYPE_INDEX] + value := match[VAR_NAME_INDEX] + if value == "" { + value = match[VAR_NAME_ONLY_INDEX] + } + + switch varTypeName { + case "env", "": + return fileVar{value, VarTypeEnv, content} + case "plugin": + return fileVar{value, VarTypePlugin, content} + default: + panic(fmt.Sprintf("unknown env var type %q", content)) + } +} + +func findFileVars(content string) []fileVar { + varMatch := reContentVars.FindAllStringSubmatch(content, -1) + + envVars := []fileVar{} + + // sort these so they have a stable order + for _, match := range varMatch { + fv := parseEnvType(match[0]) + envVars = append(envVars, fv) + } + + sort.Slice(envVars, func(i, j int) bool { + return envVars[i].Name < envVars[j].Name + }) + return envVars +} + +func ensureAcceptFileVars(content string) error { + + fileVars := findFileVars(content) + + for _, envVar := range fileVars { + if envVar.Type != VarTypeEnv { + continue + } + + envVar := envVar.Name + if os.Getenv(envVar) == "" && os.Getenv(envVar+"_POOL") == "" { + return fmt.Errorf("missing required environment variable %q", envVar) + } + } + return nil +} diff --git a/agent/server/snykbroker/reflector.go b/agent/server/snykbroker/reflector.go index 5c3d461..ed7dc35 100644 --- a/agent/server/snykbroker/reflector.go +++ b/agent/server/snykbroker/reflector.go @@ -10,9 +10,9 @@ import ( "strings" "sync/atomic" - "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" cortexHttp "github.com/cortexapps/axon/server/http" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" "go.uber.org/fx" @@ -94,7 +94,7 @@ func (rr *RegistrationReflector) Stop() error { return nil } -func (rr *RegistrationReflector) getProxy(targetURI string, isDefault bool, headers common.ResolverMap) (*proxyEntry, error) { +func (rr *RegistrationReflector) getProxy(targetURI string, isDefault bool, headers acceptfile.ResolverMap) (*proxyEntry, error) { if targetURI == "" { return nil, fmt.Errorf("target URI cannot be empty") @@ -121,6 +121,8 @@ func (rr *RegistrationReflector) getProxy(targetURI string, isDefault bool, head rr.logger.Info("Registered redirector", zap.String("targetURI", entry.TargetURI), zap.String("proxyURI", entry.proxyURI), + zap.Bool("isDefault", entry.isDefault), + zap.String("key", key), zap.Any("headers", headers), ) return &entry, nil @@ -171,7 +173,7 @@ type ProxyOption func(*proxyOption) type proxyOption struct { isDefault bool - headerResolvers common.ResolverMap + headerResolvers acceptfile.ResolverMap } func WithDefault(value bool) ProxyOption { @@ -183,20 +185,34 @@ func WithDefault(value bool) ProxyOption { func WithHeaders(headers map[string]string) ProxyOption { return func(option *proxyOption) { if option.headerResolvers == nil { - option.headerResolvers = make(common.ResolverMap, len(headers)) + option.headerResolvers = 
make(acceptfile.ResolverMap, len(headers)) } for k, v := range headers { - option.headerResolvers[k] = common.StringValueResolver(v) + option.headerResolvers[k] = acceptfile.StringValueResolver(v) } } } -func WithHeadersResolver(headers common.ResolverMap) ProxyOption { +func WithHeadersResolver(headers acceptfile.ResolverMap) ProxyOption { return func(option *proxyOption) { option.headerResolvers = headers } } +func (rr *RegistrationReflector) getUriForTarget(target string) (string, error) { + + if target == "" { + return "", fmt.Errorf("target URI cannot be empty") + } + + for _, entry := range rr.targets { + if entry.TargetURI == target { + return entry.proxyURI, nil + } + } + return "", fmt.Errorf("no proxy entry found for target URI: %s", target) +} + func (rr *RegistrationReflector) ProxyURI(target string, options ...ProxyOption) string { opts := &proxyOption{} @@ -220,9 +236,13 @@ func (rr *RegistrationReflector) RegisterRoutes(mux *mux.Router) error { func (rr *RegistrationReflector) ServeHTTP(w http.ResponseWriter, r *http.Request) { + rr.logger.Debug("Received request for proxy", + zap.String("method", r.Method), + zap.String("path", r.URL.Path), + ) entry, newPath, err := rr.parseTargetUri(r.URL.Path) if err != nil { - rr.logger.Error("Failed to parse target URI", zap.Error(err)) + rr.logger.Error("Failed to find Entry for target URI", zap.Error(err)) http.Error(w, "Invalid target URI", http.StatusBadGateway) w.WriteHeader(http.StatusBadGateway) return @@ -232,6 +252,12 @@ func (rr *RegistrationReflector) ServeHTTP(w http.ResponseWriter, r *http.Reques newPath = "/" + newPath } r.URL.Path = newPath + rr.logger.Debug("Proxying request", + zap.String("targetURI", entry.TargetURI), + zap.String("proxyURI", entry.proxyURI), + zap.String("key", entry.key()), + zap.String("newPath", newPath), + ) entry.handler.ServeHTTP(w, r) } @@ -246,12 +272,12 @@ type proxyEntry struct { TargetURI string // Exported for clean access proxyURI string handler http.Handler - headers common.ResolverMap + headers acceptfile.ResolverMap responseHeaders map[string]string hashCode string } -func newProxyEntry(targetURI string, isDefault bool, port int, headers common.ResolverMap, transport *http.Transport) (*proxyEntry, error) { +func newProxyEntry(targetURI string, isDefault bool, port int, headers acceptfile.ResolverMap, transport *http.Transport) (*proxyEntry, error) { if targetURI == "" { return nil, fmt.Errorf("target URI cannot be empty") } @@ -315,7 +341,7 @@ func (pe *proxyEntry) key() string { // Create a unique key that includes headers to allow different header sets for the same URI headerKey := "" for k := range pe.headers { - headerKey += fmt.Sprintf("|%s=%s", k, pe.headers.Resolve(k)) + headerKey += fmt.Sprintf("|%s=%s", k, pe.headers.ResolverKey(k)) } key = key + headerKey } diff --git a/agent/server/snykbroker/reflector_headers_simple_test.go b/agent/server/snykbroker/reflector_headers_simple_test.go index 5d1b9ec..36b8387 100644 --- a/agent/server/snykbroker/reflector_headers_simple_test.go +++ b/agent/server/snykbroker/reflector_headers_simple_test.go @@ -6,7 +6,7 @@ import ( "os" "testing" - "github.com/cortexapps/axon/common" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -40,7 +40,7 @@ func TestHeaderApplicationInProxy(t *testing.T) { } // Create proxy with headers - proxyEntry, err := newProxyEntry(backendServer.URL, false, 8080, common.NewResolverMapFromMap(headers), nil) + 
proxyEntry, err := newProxyEntry(backendServer.URL, false, 8080, acceptfile.NewResolverMapFromMap(headers), nil) proxyEntry.addResponseHeader("x-response", "response-value") require.NoError(t, err) require.NotNil(t, proxyEntry) @@ -56,18 +56,18 @@ func TestHeaderApplicationInProxy(t *testing.T) { proxyEntry.handler.ServeHTTP(rr, req) // Verify the request was successful - assert.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, http.StatusOK, rr.Code) // Verify that headers were applied to the backend request - assert.Equal(t, "secret-key-123", receivedHeaders.Get("x-api-key")) - assert.Equal(t, "Bearer bearer-token-456", receivedHeaders.Get("authorization")) - assert.Equal(t, "static-value", receivedHeaders.Get("x-static")) - assert.Equal(t, "", receivedHeaders.Get("x-response")) + require.Equal(t, "secret-key-123", receivedHeaders.Get("x-api-key")) + require.Equal(t, "Bearer bearer-token-456", receivedHeaders.Get("authorization")) + require.Equal(t, "static-value", receivedHeaders.Get("x-static")) + require.Equal(t, "", receivedHeaders.Get("x-response")) // Verify original headers are preserved - assert.Equal(t, "test-client", receivedHeaders.Get("user-agent")) + require.Equal(t, "test-client", receivedHeaders.Get("user-agent")) - assert.Equal(t, "response-value", rr.Header().Get("x-response")) + require.Equal(t, "response-value", rr.Header().Get("x-response")) // Verify the response body assert.JSONEq(t, `{"message": "success"}`, rr.Body.String()) @@ -98,7 +98,7 @@ func TestMultipleProxiesWithDifferentHeaders(t *testing.T) { "x-api-key": "key-for-server-1", "x-service": "service-1", } - proxy1, err := newProxyEntry(server1.URL, false, 8080, common.NewResolverMapFromMap(headers1), nil) + proxy1, err := newProxyEntry(server1.URL, false, 8080, acceptfile.NewResolverMapFromMap(headers1), nil) require.NoError(t, err) // Create second proxy with different headers @@ -106,7 +106,7 @@ func TestMultipleProxiesWithDifferentHeaders(t *testing.T) { "x-api-key": "key-for-server-2", "x-service": "service-2", } - proxy2, err := newProxyEntry(server2.URL, false, 8080, common.NewResolverMapFromMap(headers2), nil) + proxy2, err := newProxyEntry(server2.URL, false, 8080, acceptfile.NewResolverMapFromMap(headers2), nil) require.NoError(t, err) // Send requests through both proxies @@ -119,15 +119,15 @@ func TestMultipleProxiesWithDifferentHeaders(t *testing.T) { proxy2.handler.ServeHTTP(rr2, req2) // Verify both requests were successful - assert.Equal(t, http.StatusOK, rr1.Code) - assert.Equal(t, http.StatusOK, rr2.Code) + require.Equal(t, http.StatusOK, rr1.Code) + require.Equal(t, http.StatusOK, rr2.Code) // Verify correct headers were sent to each server - assert.Equal(t, "key-for-server-1", server1Headers.Get("x-api-key")) - assert.Equal(t, "service-1", server1Headers.Get("x-service")) + require.Equal(t, "key-for-server-1", server1Headers.Get("x-api-key")) + require.Equal(t, "service-1", server1Headers.Get("x-service")) - assert.Equal(t, "key-for-server-2", server2Headers.Get("x-api-key")) - assert.Equal(t, "service-2", server2Headers.Get("x-service")) + require.Equal(t, "key-for-server-2", server2Headers.Get("x-api-key")) + require.Equal(t, "service-2", server2Headers.Get("x-service")) } // TestProxyWithNoHeaders tests that proxies work correctly without headers @@ -166,11 +166,11 @@ func TestProxyWithNoHeaders(t *testing.T) { proxyEntry.handler.ServeHTTP(rr, req) // Verify the request was successful - assert.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, http.StatusOK, rr.Code) // Verify 
original headers are preserved (no custom headers added) - assert.Equal(t, "test-client", receivedHeaders.Get("user-agent")) - assert.Equal(t, "application/json", receivedHeaders.Get("content-type")) + require.Equal(t, "test-client", receivedHeaders.Get("user-agent")) + require.Equal(t, "application/json", receivedHeaders.Get("content-type")) // Verify no custom headers were added assert.Empty(t, receivedHeaders.Get("x-api-key")) @@ -206,7 +206,7 @@ func TestHeaderOverwriting(t *testing.T) { } // Create proxy with headers - proxyEntry, err := reflector.getProxy(backendServer.URL, false, common.NewResolverMapFromMap(headers)) + proxyEntry, err := reflector.getProxy(backendServer.URL, false, acceptfile.NewResolverMapFromMap(headers)) require.NoError(t, err) // Create request with original headers @@ -221,12 +221,12 @@ func TestHeaderOverwriting(t *testing.T) { proxyEntry.handler.ServeHTTP(rr, req) // Verify the request was successful - assert.Equal(t, http.StatusOK, rr.Code) + require.Equal(t, http.StatusOK, rr.Code) // Verify that custom headers overwrote original headers - assert.Equal(t, "proxy-agent", receivedHeaders.Get("user-agent")) // Should be overwritten - assert.Equal(t, "Bearer custom-token", receivedHeaders.Get("authorization")) + require.Equal(t, "proxy-agent", receivedHeaders.Get("user-agent")) // Should be overwritten + require.Equal(t, "Bearer custom-token", receivedHeaders.Get("authorization")) // Verify that non-conflicting headers are preserved - assert.Equal(t, "application/json", receivedHeaders.Get("content-type")) + require.Equal(t, "application/json", receivedHeaders.Get("content-type")) } diff --git a/agent/server/snykbroker/reflector_headers_test.go b/agent/server/snykbroker/reflector_headers_test.go index 3be922f..c9e6b47 100644 --- a/agent/server/snykbroker/reflector_headers_test.go +++ b/agent/server/snykbroker/reflector_headers_test.go @@ -10,8 +10,8 @@ import ( "testing" "time" - "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/zap" @@ -21,7 +21,7 @@ import ( func TestAcceptFileHeadersAppliedToLiveRequests(t *testing.T) { - common.IgnoreHosts = []string{} + acceptfile.IgnoreHosts = []string{} // Set up environment variables for testing os.Setenv("TEST_API_KEY", "secret-api-key-123") @@ -127,6 +127,27 @@ func TestAcceptFileHeadersAppliedToLiveRequests(t *testing.T) { index: 1, testPath: "/api-v2/test", }, + { + name: "header with env var and plugin", + acceptContent: map[string]any{ + "private": []any{ + map[string]any{ + "method": "GET", + "path": "/api/*", + "origin": "http://example.com", + "headers": map[string]any{ + "x-api-key": "${TEST_API_KEY}", + "x-plugin-header": "${plugin:plugin.sh}", + }, + }, + }, + }, + expectedHeaders: map[string]string{ + "x-api-key": "secret-api-key-123", + "x-plugin-header": "HOME=" + os.Getenv("HOME"), + }, + testPath: "/api/test", + }, } for _, tt := range tests { @@ -143,23 +164,15 @@ func TestAcceptFileHeadersAppliedToLiveRequests(t *testing.T) { // Update the accept content to use the real backend server URL updateOriginInAcceptContent(tt.acceptContent, backendServer.URL) - // Create temporary accept file - acceptFile := createTempAcceptFile(t, tt.acceptContent) - defer os.Remove(acceptFile) - - // Create integration info and process the accept file - integrationInfo := common.IntegrationInfo{ - Integration: common.IntegrationCustom, - AcceptFilePath: 
acceptFile, - } - // Create reflector logger := zap.NewNop() + cfg := config.AgentConfig{ + HttpRelayReflectorMode: config.RelayReflectorAllTraffic, + PluginDirs: []string{".", "./acceptfile"}, + } reflector := NewRegistrationReflector(RegistrationReflectorParams{ Logger: logger, - Config: config.AgentConfig{ - HttpRelayReflectorMode: config.RelayReflectorAllTraffic, - }, + Config: cfg, }) // Start the reflector server @@ -168,37 +181,53 @@ func TestAcceptFileHeadersAppliedToLiveRequests(t *testing.T) { defer reflector.Stop() // Process the accept file with header extraction - var capturedOrigin string - var capturedHeaders map[string]string - - info, err := integrationInfo.RewriteOrigins( - acceptFile, - func(originalURI string, headers common.ResolverMap) string { - capturedOrigin = originalURI - capturedHeaders = headers.ToStringMap() - return reflector.ProxyURI(originalURI, WithHeaders(headers.ToStringMap())) + capturedOrigins := make([]string, 0) + capturedHeaders := make([]map[string]string, 0) + + jsonContent, err := json.MarshalIndent(tt.acceptContent, "", " ") + require.NoError(t, err) + + af, err := acceptfile.NewAcceptFile(jsonContent, cfg) + require.NoError(t, err) + proxyUris := []string{} + + _, err = af.Render( + logger, + func(renderContext acceptfile.RenderContext) error { + for _, entry := range renderContext.AcceptFile.PrivateRules() { + originalURI := entry.Origin() + if originalURI == cfg.HttpBaseUrl() { + continue + } + capturedOrigins = append(capturedOrigins, originalURI) + headers := entry.Headers() + capturedHeaders = append(capturedHeaders, headers.ToStringMap()) + newURI := reflector.ProxyURI(originalURI, WithHeadersResolver(headers)) + entry.SetOrigin(newURI) + proxyUris = append(proxyUris, newURI) + } + return nil }, ) require.NoError(t, err) - require.NotNil(t, info) // Verify headers were captured correctly - assert.Equal(t, backendServer.URL, capturedOrigin) - assert.Equal(t, tt.expectedHeaders, capturedHeaders) + require.Equal(t, backendServer.URL, capturedOrigins[tt.index]) + require.Equal(t, tt.expectedHeaders, capturedHeaders[tt.index]) // Make a live HTTP request through the proxy - proxyURL := fmt.Sprintf("%s%s", info.Routes[tt.index].ProxyOrigin, tt.testPath) + proxyURL := fmt.Sprintf("%s%s", proxyUris[tt.index], tt.testPath) resp, err := http.Get(proxyURL) require.NoError(t, err) defer resp.Body.Close() // Verify the request was successful - assert.Equal(t, http.StatusOK, resp.StatusCode) + require.Equal(t, http.StatusOK, resp.StatusCode) // Verify that all expected headers were received by the backend for expectedKey, expectedValue := range tt.expectedHeaders { actualValue := receivedHeaders.Get(expectedKey) - assert.Equal(t, expectedValue, actualValue, + require.Equal(t, expectedValue, actualValue, "Header %s: expected %s, got %s", expectedKey, expectedValue, actualValue) } @@ -210,54 +239,6 @@ func TestAcceptFileHeadersAppliedToLiveRequests(t *testing.T) { } } -func TestAcceptFileHeadersWithMissingEnvVars(t *testing.T) { - // Ensure env var is not set - os.Unsetenv("MISSING_ENV_VAR") - - acceptContent := map[string]any{ - "private": []any{ - map[string]any{ - "method": "GET", - "path": "/api/*", - "origin": "http://example.com", - "headers": map[string]any{ - "x-api-key": "${MISSING_ENV_VAR}", - }, - }, - }, - } - - acceptFile := createTempAcceptFile(t, acceptContent) - defer os.Remove(acceptFile) - - integrationInfo := common.IntegrationInfo{ - Integration: common.IntegrationCustom, - AcceptFilePath: acceptFile, - } - - logger := zap.NewNop() - 
reflector := NewRegistrationReflector(RegistrationReflectorParams{ - Logger: logger, - Config: config.AgentConfig{ - HttpRelayReflectorMode: config.RelayReflectorAllTraffic, - }, - }) - - var capturedHeaders map[string]string - - _, err := integrationInfo.RewriteOrigins( - acceptFile, - func(originalURI string, headers common.ResolverMap) string { - capturedHeaders = headers.ToStringMap() - return reflector.ProxyURI(originalURI, WithHeadersResolver(headers)) - }, - ) - require.NoError(t, err) - - // Verify that missing env vars result in empty string (os.ExpandEnv behavior) - assert.Equal(t, "", capturedHeaders["x-api-key"]) -} - func TestMultipleRoutesWithDifferentHeaders(t *testing.T) { os.Setenv("API_KEY_1", "key-one") os.Setenv("API_KEY_2", "key-two") @@ -309,17 +290,13 @@ func TestMultipleRoutesWithDifferentHeaders(t *testing.T) { acceptFile := createTempAcceptFile(t, acceptContent) defer os.Remove(acceptFile) - integrationInfo := common.IntegrationInfo{ - Integration: common.IntegrationCustom, - AcceptFilePath: acceptFile, - } - logger := zap.NewNop() + cfg := config.AgentConfig{ + HttpRelayReflectorMode: config.RelayReflectorAllTraffic, + } reflector := NewRegistrationReflector(RegistrationReflectorParams{ Logger: logger, - Config: config.AgentConfig{ - HttpRelayReflectorMode: config.RelayReflectorAllTraffic, - }, + Config: cfg, }) _, err := reflector.Start() @@ -328,41 +305,64 @@ func TestMultipleRoutesWithDifferentHeaders(t *testing.T) { // Process the accept file headerExtractionCount := 0 - info, err := integrationInfo.RewriteOrigins( - acceptFile, - func(originalURI string, headers common.ResolverMap) string { - headerExtractionCount++ - return reflector.ProxyURI(originalURI, WithHeadersResolver(headers)) + + capturedOrigins := make([]string, 0) + capturedHeaders := make([]map[string]string, 0) + + jsonContent, err := json.MarshalIndent(acceptContent, "", " ") + require.NoError(t, err) + + af, err := acceptfile.NewAcceptFile(jsonContent, cfg) + require.NoError(t, err) + proxyUris := []string{} + + _, err = af.Render( + logger, + func(renderContext acceptfile.RenderContext) error { + for _, entry := range renderContext.AcceptFile.PrivateRules() { + originalURI := entry.Origin() + if originalURI == cfg.HttpBaseUrl() { + continue + } + headerExtractionCount++ + capturedOrigins = append(capturedOrigins, originalURI) + headers := entry.Headers().ToStringMap() + capturedHeaders = append(capturedHeaders, headers) + newURI := reflector.ProxyURI(originalURI, WithHeaders(headers)) + entry.SetOrigin(newURI) + proxyUris = append(proxyUris, newURI) + } + return nil }, ) - require.NotNil(t, info) + require.NoError(t, err) // Verify that headers were extracted for both routes - assert.Equal(t, 2, headerExtractionCount) + require.Equal(t, 2, headerExtractionCount) // Make requests to both routes - resp1, err := http.Get(fmt.Sprintf("%s/api1/test", info.Routes[0].ProxyOrigin)) + resp1, err := http.Get(fmt.Sprintf("%s/api1/test", proxyUris[0])) require.NoError(t, err) defer resp1.Body.Close() - resp2, err := http.Get(fmt.Sprintf("%s/api2/test", info.Routes[1].ProxyOrigin)) + resp2, err := http.Get(fmt.Sprintf("%s/api2/test", proxyUris[1])) require.NoError(t, err) defer resp2.Body.Close() // Verify both requests were successful - assert.Equal(t, http.StatusOK, resp1.StatusCode) - assert.Equal(t, http.StatusOK, resp2.StatusCode) + require.Equal(t, http.StatusOK, resp1.StatusCode) + require.Equal(t, http.StatusOK, resp2.StatusCode) // Give servers time to process requests time.Sleep(100 * 
time.Millisecond) // Verify correct headers were sent to each server - assert.Equal(t, "key-one", server1Headers.Get("x-api-key")) - assert.Equal(t, "service-1", server1Headers.Get("x-service")) + require.Equal(t, "key-one", server1Headers.Get("x-api-key")) + require.Equal(t, "service-1", server1Headers.Get("x-service")) - assert.Equal(t, "key-two", server2Headers.Get("x-api-key")) - assert.Equal(t, "service-2", server2Headers.Get("x-service")) + require.Equal(t, "key-two", server2Headers.Get("x-api-key")) + require.Equal(t, "service-2", server2Headers.Get("x-service")) } // Helper functions diff --git a/agent/server/snykbroker/relay_instance_manager.go b/agent/server/snykbroker/relay_instance_manager.go index 8873487..74c4f99 100644 --- a/agent/server/snykbroker/relay_instance_manager.go +++ b/agent/server/snykbroker/relay_instance_manager.go @@ -15,6 +15,7 @@ import ( "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" cortexHttp "github.com/cortexapps/axon/server/http" + "github.com/cortexapps/axon/server/snykbroker/acceptfile" "github.com/cortexapps/axon/util" "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" @@ -299,10 +300,52 @@ func (r *relayInstanceManager) Start() error { executable = directPath } - acceptFile, err := r.integrationInfo.AcceptFile(r.config.HttpBaseUrl()) + af, err := r.integrationInfo.ToAcceptFile(r.config) + if err != nil { + r.logger.Error("Error creating accept file", zap.Error(err)) + return fmt.Errorf("error creating accept file: %w", err) + } + + rendered, err := af.Render(r.logger, func(renderContext acceptfile.RenderContext) error { + if r.reflector != nil { + + // Here we loop all the private (incoming) routes and do two things + // 1. We rewrite the origin to point back to the reflector. This captures the original URI so + // that the reflector can proxy the request to the correct origin. + // 2. We add any custom headers that are defined in the route, which is a functional addition not available + // in the original accept file / snyk-broker. + // + // The returned proxyURI is an encoded URI path that has an additional path section which is used + // to identify the original route and headers. 
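+			//
+			// Illustrative example (the exact path encoding is an implementation detail of the
+			// reflector): an origin such as "https://api.example.com" is swapped for a
+			// reflector-local URI along the lines of "http://localhost:<port>/<route-key>";
+			// when the broker later sends traffic there, the reflector maps the route key back
+			// to the original origin and applies the configured headers before forwarding.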
+ + for _, route := range renderContext.AcceptFile.PrivateRules() { + headers := route.Headers() + if len(headers) > 0 && r.config.HttpRelayReflectorMode != config.RelayReflectorAllTraffic { + panic("HttpRelayReflectorMode must be set to 'all' to add custom headers") + } + + routeUri := r.reflector.ProxyURI(route.Origin(), WithHeadersResolver(headers)) + route.SetOrigin(routeUri) + } + } + return nil + }) + + if err != nil { + r.logger.Error("Error rendering accept file", zap.Error(err)) + return fmt.Errorf("error rendering accept file: %w", err) + } + + tmpDir, err := os.MkdirTemp(os.TempDir(), "axon-accept-files-*") + if err != nil { + r.logger.Error("Error creating temp directory for accept file", zap.Error(err)) + return fmt.Errorf("error creating temp directory for accept file: %w", err) + } + tmpAcceptFile := path.Join(tmpDir, "accept.json") + err = os.WriteFile(tmpAcceptFile, rendered, 0644) if err != nil { - fmt.Println("Error getting accept file", err) + fmt.Println("Error writing accept file", err) panic(err) } @@ -388,13 +431,11 @@ func (r *relayInstanceManager) Start() error { zap.Strings("args", args), zap.String("token", info.Token), zap.String("uri", info.ServerUri), - zap.String("acceptFile", acceptFile), + zap.String("acceptFile", tmpAcceptFile), ) - acceptFile = r.applyAcceptFileTransforms(acceptFile) - brokerEnv := map[string]string{ - "ACCEPT": acceptFile, + "ACCEPT": tmpAcceptFile, "BROKER_SERVER_URL": info.ServerUri, "BROKER_TOKEN": info.Token, "PORT": fmt.Sprintf("%d", r.getSnykBrokerPort()), @@ -455,31 +496,6 @@ func (r *relayInstanceManager) Start() error { return err } -func (r *relayInstanceManager) applyAcceptFileTransforms(acceptFile string) string { - if r.config.HttpRelayReflectorMode == config.RelayReflectorAllTraffic && r.reflector != nil { - - info, err := r.integrationInfo.RewriteOrigins(acceptFile, func(uri string, headers common.ResolverMap) string { - r.logger.Info("Rewriting accept file URI", zap.String("uri", uri), zap.Any("headers", headers)) - return r.reflector.ProxyURI(uri, WithHeadersResolver(headers)) - }) - - if err != nil { - r.logger.Error("Error creating accept file", zap.String("acceptFile", acceptFile), zap.Error(err)) - return acceptFile // return original if error occurs - } - _, err = info.Rewrite() - - if err != nil { - r.logger.Error("Error rewriting accept file", zap.String("acceptFile", acceptFile), - - zap.Error(err)) - return acceptFile // return original if error occurs - } - return info.RewrittenPath - } - return acceptFile -} - func (r *relayInstanceManager) setHttpProxyEnvVars(brokerEnv map[string]string) { // This is mostly for testing so we can validate no traffic goes out from the broker diff --git a/agent/server/snykbroker/relay_instance_manager_test.go b/agent/server/snykbroker/relay_instance_manager_test.go index 55740a7..05a4f51 100644 --- a/agent/server/snykbroker/relay_instance_manager_test.go +++ b/agent/server/snykbroker/relay_instance_manager_test.go @@ -2,34 +2,37 @@ package snykbroker import ( "context" - "encoding/json" + "fmt" "net" "net/http" "net/http/httptest" + "net/http/httputil" "net/url" "os" "testing" "time" - "github.com/PaesslerAG/jsonpath" "github.com/cortexapps/axon/common" "github.com/cortexapps/axon/config" cortex_http "github.com/cortexapps/axon/server/http" "github.com/cortexapps/axon/util" "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.uber.org/fx/fxtest" "go.uber.org/mock/gomock" 
"go.uber.org/zap" ) +var defaultIntegrationInfo = common.IntegrationInfo{ + Integration: common.IntegrationGithub, +} + func TestManagerSuccess(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mgr := createTestRelayInstanceManager(t, controller, nil, false) + mgr := createTestRelayInstanceManager(t, controller, nil, false, defaultIntegrationInfo) err := mgr.Close() require.NoError(t, err) @@ -40,7 +43,7 @@ func TestManagerSuccessWithReflector(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mgr := createTestRelayInstanceManager(t, controller, nil, true) + mgr := createTestRelayInstanceManager(t, controller, nil, true, defaultIntegrationInfo) // call the reflector uri uri := mgr.reflector.ProxyURI(mgr.serverUri) @@ -56,7 +59,82 @@ func TestManagerSuccessWithReflector(t *testing.T) { require.NoError(t, err) require.Equal(t, 1, len(mgr.requestUrls), "Expected one request to the reflector URI") - assert.Equal(t, "/foo/bar", mgr.requestUrls[0].Path, "Expected request to the reflector URI to have the correct path") + require.Equal(t, "/foo/bar", mgr.requestUrls[0].Path, "Expected request to the reflector URI to have the correct path") + +} + +func TestManagerSuccessWithReflectorHeadersAndProxy(t *testing.T) { + + var mgr *wrappedRelayInstanceManager + controller := gomock.NewController(t) + defer controller.Finish() + + os.Setenv("PLUGIN_DIRS", "./acceptfile") + + // This is janky but we have a circular dep where we need to know the + // target server URI before we create the file, but we need to create the file + // before calling createTestRelayInstanceManager. So as a hack we create a second + // server that just forwards requests to the test server URI, once we know it. + knownTestServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // forward to the the test server + url, err := url.Parse(mgr.serverUri) + require.NoError(t, err, "Failed to parse test server URL") + proxy := httputil.NewSingleHostReverseProxy(url) + proxy.ServeHTTP(w, r) + })) + + t.Cleanup(knownTestServer.Close) + + content := fmt.Sprintf(` + { + "private": [ + { + "method": "GET", + "origin": "%s", + "path": "/foo/bar", + "headers": { + "x-plugin-value": "${plugin:plugin.sh}", + "x-other-header": "other-value" + } + } + ] + } + `, knownTestServer.URL) + + tmpFile := t.TempDir() + "/github.json" + err := os.WriteFile(tmpFile, []byte(content), 0644) + require.NoError(t, err, "Failed to write test accept file") + + ii := common.IntegrationInfo{ + Integration: common.IntegrationGithub, + AcceptFilePath: tmpFile, + } + + mgr = createTestRelayInstanceManager(t, controller, nil, true, ii) + + t.Cleanup(func() { + os.Remove(tmpFile) + os.Unsetenv("PLUGIN_DIRS") + }) + + // call the reflector uri + uri, err := mgr.reflector.getUriForTarget(knownTestServer.URL) + require.NoError(t, err, "Failed to get URI for target") + + req, err := http.NewRequest(http.MethodGet, uri+"/foo/bar", nil) + require.NoError(t, err) + resp, err := http.DefaultClient.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + require.Equal(t, http.StatusOK, resp.StatusCode) + + err = mgr.Close() + require.NoError(t, err) + + require.Equal(t, 1, len(mgr.requests), "Expected one request to the reflector URI") + require.Equal(t, "/foo/bar", mgr.requests[0].URL.Path, "Expected request to the reflector URI to have the correct path") + require.Equal(t, "HOME="+os.Getenv("HOME"), mgr.requests[0].Header.Get("x-plugin-value"), "Expected request to the 
reflector URI to have the correct plugin header value") + require.Equal(t, "other-value", mgr.requests[0].Header.Get("x-other-header"), "Expected request to the reflector URI to have the correct header value") } @@ -64,7 +142,7 @@ func TestManagerUnauthorized(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() - mgr := createTestRelayInstanceManager(t, controller, ErrUnauthorized, false) + mgr := createTestRelayInstanceManager(t, controller, ErrUnauthorized, false, defaultIntegrationInfo) err := mgr.Start() require.Error(t, err, ErrUnauthorized) @@ -87,9 +165,9 @@ func TestApplyValidationConfig(t *testing.T) { mgr := &relayInstanceManager{} mgr.applyClientValidationConfig(validationConfig, envVars) - assert.Equal(t, "https://api.github.com/user", envVars["BROKER_CLIENT_VALIDATION_URL"]) - assert.Equal(t, "POST", envVars["BROKER_CLIENT_VALIDATION_METHOD"]) - assert.Equal(t, "bearer the-token", envVars["BROKER_CLIENT_VALIDATION_AUTHORIZATION_HEADER"]) + require.Equal(t, "https://api.github.com/user", envVars["BROKER_CLIENT_VALIDATION_URL"]) + require.Equal(t, "POST", envVars["BROKER_CLIENT_VALIDATION_METHOD"]) + require.Equal(t, "bearer the-token", envVars["BROKER_CLIENT_VALIDATION_AUTHORIZATION_HEADER"]) } @@ -97,14 +175,14 @@ func TestLoadCertsDir(t *testing.T) { mgr := &relayInstanceManager{} path := mgr.getCertFilePath("../../test/certs") - assert.Equal(t, "../../test/certs/selfsigned-1.pem", path) + require.Equal(t, "../../test/certs/selfsigned-1.pem", path) } func TestLoadCertsFile(t *testing.T) { mgr := &relayInstanceManager{} path := mgr.getCertFilePath("../../test/certs/selfsigned-2.pem") - assert.Equal(t, "../../test/certs/selfsigned-2.pem", path) + require.Equal(t, "../../test/certs/selfsigned-2.pem", path) } func TestHttpProxy(t *testing.T) { @@ -124,10 +202,10 @@ func TestHttpProxy(t *testing.T) { env := map[string]string{} mgr.setHttpProxyEnvVars(env) - assert.Equal(t, "http://proxy.example.com:8080", env["HTTP_PROXY"]) - assert.Equal(t, "http://proxy.example.com:8080", env["HTTPS_PROXY"]) - assert.Equal(t, "localhost,127.0.0.1", env["NO_PROXY"]) - assert.Equal(t, cfg.HttpCaCertFilePath, env["NODE_EXTRA_CA_CERTS"]) + require.Equal(t, "http://proxy.example.com:8080", env["HTTP_PROXY"]) + require.Equal(t, "http://proxy.example.com:8080", env["HTTPS_PROXY"]) + require.Equal(t, "localhost,127.0.0.1", env["NO_PROXY"]) + require.Equal(t, cfg.HttpCaCertFilePath, env["NODE_EXTRA_CA_CERTS"]) } @@ -135,7 +213,7 @@ func TestHttpProxy(t *testing.T) { func TestRelayRestartServer(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mgr := createTestRelayInstanceManager(t, ctrl, nil, false) + mgr := createTestRelayInstanceManager(t, ctrl, nil, false, defaultIntegrationInfo) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/__axon/broker/restart", nil) @@ -145,15 +223,15 @@ func TestRelayRestartServer(t *testing.T) { httpHandler.RegisterRoutes(mux) mux.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, 2, int(mgr.Instance().startCount.Load())) + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 2, int(mgr.Instance().startCount.Load())) } func TestRelayReRegisterServer(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - mgr := createTestRelayInstanceManager(t, ctrl, nil, false) + mgr := createTestRelayInstanceManager(t, ctrl, nil, false, defaultIntegrationInfo) w := httptest.NewRecorder() req := httptest.NewRequest(http.MethodPost, "/__axon/broker/reregister", nil) @@ -163,8 
+241,8 @@ func TestRelayReRegisterServer(t *testing.T) { httpHandler.RegisterRoutes(mux) mux.ServeHTTP(w, req) - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, 1, int(mgr.Instance().startCount.Load())) + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, 1, int(mgr.Instance().startCount.Load())) } @@ -217,99 +295,8 @@ func TestSystemCheck(t *testing.T) { mux.ServeHTTP(w, req) // Verify the response - assert.Equal(t, http.StatusOK, w.Code) - assert.Equal(t, jsonPayload, w.Body.String()) -} - -func TestApplyAcceptTransforms(t *testing.T) { - - lifecycle := fxtest.NewLifecycle(t) - cfg := config.NewAgentEnvConfig() - logger := zap.NewNop() - cfg.HttpRelayReflectorMode = config.RelayReflectorAllTraffic - - params := RegistrationReflectorParams{ - Lifecycle: lifecycle, - Logger: logger.Named("reflector"), - Config: cfg, - } - reflector := NewRegistrationReflector( - params, - ) - reflector.Start() - defer reflector.Stop() - - mgr := &relayInstanceManager{ - config: cfg, - reflector: reflector, - logger: zap.NewNop(), - } - - cases := []struct { - acceptFile string - env map[string]string - }{ - { - acceptFile: "accept_files/accept.github.json", - env: map[string]string{ - "GITHUB_API": "api.github.com", - "GITHUB_GRAPHQL": "api.github.com/graphql", - }, - }, - { - acceptFile: "accept_files/accept.bitbucket.basic.json", - env: map[string]string{ - "BITBUCKET_API": "api.bitbucket.com", - }, - }, - } - - for _, c := range cases { - t.Run(c.acceptFile, func(t *testing.T) { - // Set environment variables - for k, v := range c.env { - t.Setenv(k, v) - } - - // validate it doesn't do it when mode is - cfgCopy := cfg - cfgCopy.HttpRelayReflectorMode = config.RelayReflectorRegistrationOnly - mgr.config = cfgCopy - - newFile := mgr.applyAcceptFileTransforms(c.acceptFile) - // Check that the new file is the same as the original - assert.Equal(t, c.acceptFile, newFile, "Expected the accept file to not be transformed when reflector mode is disabled") - - mgr.config = cfg // reset the config to the original - - // Apply the accept file transforms - newFile = mgr.applyAcceptFileTransforms(c.acceptFile) - // Check that the new file is not the same as the original - assert.NotEqual(t, c.acceptFile, newFile, "Expected the accept file to be transformed") - - // gather all of the "origin" values - newFileContent, err := os.ReadFile(newFile) - require.NoError(t, err, "Failed to read transformed accept file") - - v := any(nil) - - err = json.Unmarshal(newFileContent, &v) - require.NoError(t, err, "Failed to unmarshal transformed accept file") - origins, err := jsonpath.Get("$.private[*].origin", v) - require.NoError(t, err, "Failed to get origins from transformed accept file") - for _, origin := range origins.([]any) { - originStr, ok := origin.(string) - require.True(t, ok, "Expected origin to be a string") - // Check that the origin is not empty - assert.NotEmpty(t, originStr, "Expected origin to be non-empty") - // Check that the origin is a valid URL - url, err := url.ParseRequestURI(originStr) - assert.NoError(t, err, "Expected origin to be a valid URL") - require.Contains(t, url.Host, "localhost:", "Expected origin to contain localhost") - } - }) - - } + require.Equal(t, http.StatusOK, w.Code) + require.Equal(t, jsonPayload, w.Body.String()) } type wrappedRelayInstanceManager struct { @@ -317,6 +304,7 @@ type wrappedRelayInstanceManager struct { mockRegistration *MockRegistration reflector *RegistrationReflector requestUrls []url.URL + requests []*http.Request serverUri string } @@ -324,7 
+312,7 @@ func (w *wrappedRelayInstanceManager) Instance() *relayInstanceManager { return w.RelayInstanceManager.(*relayInstanceManager) } -func createTestRelayInstanceManager(t *testing.T, controller *gomock.Controller, expectedError error, useReflector bool) *wrappedRelayInstanceManager { +func createTestRelayInstanceManager(t *testing.T, controller *gomock.Controller, expectedError error, useReflector bool, ii common.IntegrationInfo) *wrappedRelayInstanceManager { envVars := map[string]string{ "ACCEPTFILE_DIR": "./accept_files", "GITHUB_TOKEN": "the-token", @@ -343,10 +331,9 @@ func createTestRelayInstanceManager(t *testing.T, controller *gomock.Controller, if useReflector { cfg.HttpRelayReflectorMode = config.RelayReflectorAllTraffic } - logger := zap.NewNop() - ii := common.IntegrationInfo{ - Integration: common.IntegrationGithub, - } + loggerConfig := zap.NewDevelopmentConfig() + loggerConfig.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + logger := zap.Must(loggerConfig.Build()) mockServer := cortex_http.NewMockServer(controller) mockServer.EXPECT().RegisterHandler(gomock.Any()).MinTimes(1) @@ -387,8 +374,11 @@ func createTestRelayInstanceManager(t *testing.T, controller *gomock.Controller, testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { mgr.requestUrls = append(mgr.requestUrls, *r.URL) + mgr.requests = append(mgr.requests, r) })) + t.Cleanup(testServer.Close) + response := &RegistrationInfoResponse{ ServerUri: testServer.URL, Token: "abcd1234", @@ -401,6 +391,8 @@ func createTestRelayInstanceManager(t *testing.T, controller *gomock.Controller, mockRegistration.EXPECT().Register(gomock.Eq(common.IntegrationGithub), gomock.Eq("")).MinTimes(1).Return(response, nil) } + os.Setenv("TEST_SERVER_URL", testServer.URL) + lifecycle.Start(context.Background()) return mgr diff --git a/agent/test/relay/accept-client.json b/agent/test/relay/accept-client.json index 04e7fe6..00643c8 100644 --- a/agent/test/relay/accept-client.json +++ b/agent/test/relay/accept-client.json @@ -19,7 +19,8 @@ "path": "/echo/*", "origin": "http://cortex-fake:8081", "headers": { - "x-test-header": "added-fake-server" + "x-test-header": "added-fake-server", + "x-test-header-plugin": "${plugin:plugin.sh}" } }, diff --git a/agent/test/relay/docker-compose.yml b/agent/test/relay/docker-compose.yml index 95cbeee..2b2809e 100644 --- a/agent/test/relay/docker-compose.yml +++ b/agent/test/relay/docker-compose.yml @@ -28,11 +28,13 @@ services: volumes: - .:/src - ./.mitmproxy:/certs + - ../../server/snykbroker/acceptfile:/agent/plugins environment: CORTEX_API_TOKEN: fake-token CORTEX_API_BASE_URL: http://cortex-fake:8081 PORT: 7433 - env_file: ${ENVFILE:-noproxy.env} + PLUGIN_DIRS: /agent/plugins + env_file: ${ENVFILE:-noproxy.env} command: relay -f /src/accept-client.json -i github -a axon-test depends_on: mitmproxy: diff --git a/agent/test/relay/noproxy.env b/agent/test/relay/noproxy.env index e69de29..ad3e98a 100644 --- a/agent/test/relay/noproxy.env +++ b/agent/test/relay/noproxy.env @@ -0,0 +1 @@ +HEADER_PROXY_VALUE=nope \ No newline at end of file diff --git a/agent/test/relay/relay_test.sh b/agent/test/relay/relay_test.sh index cf4d3f1..654b616 100755 --- a/agent/test/relay/relay_test.sh +++ b/agent/test/relay/relay_test.sh @@ -175,6 +175,15 @@ if [ "$PROXY" == "1" ]; then else echo "Success: Found expected injected header value in result" fi + + # Make sure the plugin header is also injected + if ! 
echo "$proxy_result" | grep -q "HOME=/root"; then + echo "FAIL: Expected injected plugin header value but not found" + echo "$proxy_result" + exit 1 + else + echo "Success: Found expected injected plugin header value in result" + fi else echo "Checking relay non proxy config..." if echo "$result" | grep -i "x-proxy-mitmproxy"