diff --git a/.github/workflows/workflow-scanner.yml b/.github/workflows/workflow-scanner.yml index b6a33fd2..9e9a27b6 100644 --- a/.github/workflows/workflow-scanner.yml +++ b/.github/workflows/workflow-scanner.yml @@ -1,7 +1,9 @@ name: Test Workflow Scanner on: - pull_request: + push: + branches: + - improve-pr jobs: test-scanner: @@ -28,5 +30,5 @@ jobs: with: api-token: ${{ secrets.FS_API_TOKEN }} github-token: ${{ secrets.GH_PAT }} - llm-api-key: ${{ secrets.GEMINI_API_KEY }} - target-branch: test-locally \ No newline at end of file + openai-api-key: ${{ secrets.OPENAI_API_KEY }} + target-branch: improve-pr \ No newline at end of file diff --git a/README.md b/README.md index 14965425..e5d094fd 100644 --- a/README.md +++ b/README.md @@ -123,4 +123,4 @@ This project is licensed under the terms included in the LICENSE file. ## Next steps - See if Docker image + entrypoint script, instead of composite, can be better. - Don't make this repo public until we remove the LLM KEY and PAT from secrets. -- See what are the possibilities of using GITHUB_TOKEN instead PAT_TOKEN. \ No newline at end of file +- See what are the possibilities of using GITHUB_TOKEN instead PAT_TOKEN. diff --git a/cmd/scanner/main.go b/cmd/scanner/main.go index 9932896c..5430dcdd 100644 --- a/cmd/scanner/main.go +++ b/cmd/scanner/main.go @@ -26,7 +26,10 @@ type batchConfig struct { provider string githubToken string gitlabToken string - llmAPIKey string + openaiKey string + anthropicKey string + geminiKey string + model string commitSHA string sourceBase64 string useGitClone bool @@ -42,7 +45,6 @@ func main() { config.repository, config.commitSHA, config.useGitClone) validateConfig(config) - setupLLMEnvironment(config.llmAPIKey) validateDaggerEnvironment() ctx := context.Background() @@ -71,11 +73,24 @@ func loadConfig() batchConfig { provider := strings.ToLower(os.Getenv("PROVIDER")) // optional override githubToken := os.Getenv("GITHUB_TOKEN") gitlabToken := os.Getenv("GITLAB_TOKEN") - llmAPIKey := os.Getenv("LLM_API_KEY") + openaiKey := os.Getenv("OPENAI_API_KEY") + anthropicKey := os.Getenv("ANTHROPIC_API_KEY") + geminiKey := os.Getenv("GEMINI_API_KEY") + model := os.Getenv("MODEL") commitSHA := os.Getenv("COMMIT_SHA") sourceBase64 := os.Getenv("SOURCE_BASE64") - useGitClone := sourceBase64 == "" && llmAPIKey != "" + // Check if any LLM key is available + hasLLMKey := openaiKey != "" || anthropicKey != "" || geminiKey != "" + useGitClone := sourceBase64 == "" && hasLLMKey + + // Validate API key formats to catch user errors early + validateAPIKeyFormats(openaiKey, anthropicKey, geminiKey) + + // Set default model based on available API key if not specified + if model == "" { + model = getDefaultModel(openaiKey, anthropicKey, geminiKey) + } if provider == "" { if strings.Contains(repository, "gitlab.com") { @@ -90,18 +105,60 @@ func loadConfig() batchConfig { provider: provider, githubToken: githubToken, gitlabToken: gitlabToken, - llmAPIKey: llmAPIKey, + openaiKey: openaiKey, + anthropicKey: anthropicKey, + geminiKey: geminiKey, + model: model, commitSHA: commitSHA, sourceBase64: sourceBase64, useGitClone: useGitClone, } } +func validateAPIKeyFormats(openaiKey, anthropicKey, geminiKey string) { + if openaiKey != "" && !strings.HasPrefix(openaiKey, "sk-") { + log.Fatal("OPENAI_API_KEY appears to be invalid format (should start with 'sk-')") + } + if anthropicKey != "" && !strings.HasPrefix(anthropicKey, "sk-ant-") { + log.Fatal("ANTHROPIC_API_KEY appears to be invalid format (should start with 'sk-ant-')") + } + if geminiKey != "" { + if strings.HasPrefix(geminiKey, "sk-") { + if strings.HasPrefix(geminiKey, "sk-ant-") { + log.Fatal("GEMINI_API_KEY appears to be an Anthropic key (starts with 'sk-ant-'), please use anthropic-api-key input instead") + } + log.Fatal("GEMINI_API_KEY appears to be an OpenAI key (starts with 'sk-'), please use openai-api-key input instead") + } + if !strings.HasPrefix(geminiKey, "AIza") { + log.Fatal("GEMINI_API_KEY appears to be invalid format (should start with 'AIza')") + } + } +} + +func getDefaultModel(openaiKey, anthropicKey, geminiKey string) string { + if openaiKey != "" { + return "gpt-4o" + } + if anthropicKey != "" { + return "claude-3-5-sonnet" + } + if geminiKey != "" { + return "gemini-2.0-flash" + } + + return "" +} + func validateConfig(config batchConfig) { if config.repository == "" { log.Fatal("Missing required environment variable: REPOSITORY") } + validateProviderTokens(config) + validateModeRequirements(config) +} + +func validateProviderTokens(config batchConfig) { if config.provider == "gitlab" { if config.gitlabToken == "" { log.Fatal("Missing GITLAB_TOKEN for gitlab provider") @@ -111,13 +168,15 @@ func validateConfig(config batchConfig) { log.Fatal("Missing GITHUB_TOKEN for github provider") } } +} +func validateModeRequirements(config batchConfig) { if !config.useGitClone && config.sourceBase64 == "" { log.Fatal("Missing SOURCE_BASE64 for legacy mode") } - if config.useGitClone && config.llmAPIKey == "" { - log.Fatal("Missing LLM_API_KEY for git clone mode") + if config.useGitClone && config.openaiKey == "" && config.anthropicKey == "" && config.geminiKey == "" { + log.Fatal("Missing API key for git clone mode (need OPENAI_API_KEY, ANTHROPIC_API_KEY, or GEMINI_API_KEY)") } } @@ -239,34 +298,6 @@ func decodeSourceData(dag *dagger.Client, sourceBase64 string) *dagger.Directory return dag.Directory().WithNewFile("workflows.tar.gz", string(sourceData)) } -func setupLLMEnvironment(llmAPIKey string) { - // Detect provider based on key format and set only the appropriate env var - var providerKey string - var providerName string - - if strings.HasPrefix(llmAPIKey, "sk-") { - providerKey = "OPENAI_API_KEY" - providerName = "OpenAI" - } else if strings.HasPrefix(llmAPIKey, "sk-ant-") { - providerKey = "ANTHROPIC_API_KEY" - providerName = "Anthropic" - } else if strings.HasPrefix(llmAPIKey, "AIza") { - providerKey = "GEMINI_API_KEY" - providerName = "Gemini" - } else { - // Default to OpenAI if format is unknown - providerKey = "OPENAI_API_KEY" - providerName = "OpenAI (default)" - log.Printf("Warning: Unknown API key format, defaulting to OpenAI") - } - - if err := os.Setenv(providerKey, llmAPIKey); err != nil { - log.Printf("Warning: Failed to set %s: %v", providerKey, err) - } else { - log.Printf("Set LLM environment for %s", providerName) - } -} - func incrementUsage(repository string, success bool) error { serviceURL := os.Getenv("SERVICE_URL") if serviceURL == "" { diff --git a/main_test.go b/main_test.go index c1ec9e38..74c9d2d1 100644 --- a/main_test.go +++ b/main_test.go @@ -36,7 +36,7 @@ func TestScanAndFixWorflowsImpl(t *testing.T) { // Step 1: Run ZIZMOR auto-fix mockZizmor.EXPECT(). RunZizmorAutoFix(gomock.Any(), mockDirectory). - Return(mockDirectory, "Fixed 2 security issues automatically", nil) + Return(mockDirectory, []zizmor.Finding{}, "Fixed 2 security issues automatically", nil) // Step 2: Check remaining issues - none found mockZizmor.EXPECT(). @@ -84,7 +84,7 @@ func TestScanAndFixWorflowsImpl(t *testing.T) { // Step 1: Run ZIZMOR auto-fix mockZizmor.EXPECT(). RunZizmorAutoFix(gomock.Any(), mockDirectory). - Return(mockDirectory, "Fixed some issues", nil) + Return(mockDirectory, []zizmor.Finding{}, "Fixed some issues", nil) // Step 2: Check remaining issues - some found remainingIssues := `[{"desc": "manual fix needed"}]` @@ -134,10 +134,7 @@ func TestScanAndFixWorflowsImpl(t *testing.T) { // Step 1: ZIZMOR auto-fix fails mockZizmor.EXPECT(). RunZizmorAutoFix(gomock.Any(), mockDirectory). - Return(nil, "", errors.New("ZIZMOR container failed")) - - // No other calls should happen after failure - + Return(nil, []zizmor.Finding{}, "", errors.New("ZIZMOR container failed")) return mockZizmor, mockAgent, mockGithub, mockDirectory }, expectedResult: "", @@ -156,7 +153,7 @@ func TestScanAndFixWorflowsImpl(t *testing.T) { // Step 1: Run ZIZMOR auto-fix mockZizmor.EXPECT(). RunZizmorAutoFix(gomock.Any(), mockDirectory). - Return(mockDirectory, "Fixed some issues", nil) + Return(mockDirectory, []zizmor.Finding{}, "Fixed some issues", nil) // Step 2: Check remaining issues - some found remainingIssues := `[{"desc": "complex issue"}]` @@ -193,7 +190,7 @@ func TestScanAndFixWorflowsImpl(t *testing.T) { // Step 1: Run ZIZMOR auto-fix mockZizmor.EXPECT(). RunZizmorAutoFix(gomock.Any(), mockDirectory). - Return(mockDirectory, "Fixed issues", nil) + Return(mockDirectory, []zizmor.Finding{}, "Fixed issues", nil) // Step 2: Check remaining issues - none mockZizmor.EXPECT(). diff --git a/mocks/client_mock.go b/mocks/client_mock.go index c8729625..8ab7fddd 100644 --- a/mocks/client_mock.go +++ b/mocks/client_mock.go @@ -109,6 +109,20 @@ func (mr *MockClientMockRecorder) LLM(opts ...any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LLM", reflect.TypeOf((*MockClient)(nil).LLM), opts...) } +// SetSecret mocks base method. +func (m *MockClient) SetSecret(name, plaintext string) *dagger.Secret { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetSecret", name, plaintext) + ret0, _ := ret[0].(*dagger.Secret) + return ret0 +} + +// SetSecret indicates an expected call of SetSecret. +func (mr *MockClientMockRecorder) SetSecret(name, plaintext any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetSecret", reflect.TypeOf((*MockClient)(nil).SetSecret), name, plaintext) +} + // Workspace mocks base method. func (m *MockClient) Workspace(source *dagger.Directory) dagger0.Workspace { m.ctrl.T.Helper() diff --git a/mocks/container_mock.go b/mocks/container_mock.go index ccb20781..34e01c32 100644 --- a/mocks/container_mock.go +++ b/mocks/container_mock.go @@ -109,6 +109,25 @@ func (mr *MockContainerMockRecorder) WithDirectory(path, source any, opts ...any return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithDirectory", reflect.TypeOf((*MockContainer)(nil).WithDirectory), varargs...) } +// WithEnvVariable mocks base method. +func (m *MockContainer) WithEnvVariable(name, value string, opts ...dagger.ContainerWithEnvVariableOpts) dagger0.Container { + m.ctrl.T.Helper() + varargs := []any{name, value} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WithEnvVariable", varargs...) + ret0, _ := ret[0].(dagger0.Container) + return ret0 +} + +// WithEnvVariable indicates an expected call of WithEnvVariable. +func (mr *MockContainerMockRecorder) WithEnvVariable(name, value any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{name, value}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithEnvVariable", reflect.TypeOf((*MockContainer)(nil).WithEnvVariable), varargs...) +} + // WithExec mocks base method. func (m *MockContainer) WithExec(args []string, opts ...dagger.ContainerWithExecOpts) dagger0.Container { m.ctrl.T.Helper() @@ -128,6 +147,25 @@ func (mr *MockContainerMockRecorder) WithExec(args any, opts ...any) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithExec", reflect.TypeOf((*MockContainer)(nil).WithExec), varargs...) } +// WithNewFile mocks base method. +func (m *MockContainer) WithNewFile(path, contents string, opts ...dagger.ContainerWithNewFileOpts) dagger0.Container { + m.ctrl.T.Helper() + varargs := []any{path, contents} + for _, a := range opts { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WithNewFile", varargs...) + ret0, _ := ret[0].(dagger0.Container) + return ret0 +} + +// WithNewFile indicates an expected call of WithNewFile. +func (mr *MockContainerMockRecorder) WithNewFile(path, contents any, opts ...any) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]any{path, contents}, opts...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithNewFile", reflect.TypeOf((*MockContainer)(nil).WithNewFile), varargs...) +} + // WithWorkdir mocks base method. func (m *MockContainer) WithWorkdir(path string, opts ...dagger.ContainerWithWorkdirOpts) dagger0.Container { m.ctrl.T.Helper() diff --git a/mocks/env_mock.go b/mocks/env_mock.go index b77e9767..15d994b2 100644 --- a/mocks/env_mock.go +++ b/mocks/env_mock.go @@ -69,6 +69,48 @@ func (mr *MockEnvMockRecorder) Output(name any) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Output", reflect.TypeOf((*MockEnv)(nil).Output), name) } +// WithDirectoryInput mocks base method. +func (m *MockEnv) WithDirectoryInput(name string, value *dagger.Directory, description string) dagger0.Env { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithDirectoryInput", name, value, description) + ret0, _ := ret[0].(dagger0.Env) + return ret0 +} + +// WithDirectoryInput indicates an expected call of WithDirectoryInput. +func (mr *MockEnvMockRecorder) WithDirectoryInput(name, value, description any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithDirectoryInput", reflect.TypeOf((*MockEnv)(nil).WithDirectoryInput), name, value, description) +} + +// WithDirectoryOutput mocks base method. +func (m *MockEnv) WithDirectoryOutput(name, description string) dagger0.Env { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithDirectoryOutput", name, description) + ret0, _ := ret[0].(dagger0.Env) + return ret0 +} + +// WithDirectoryOutput indicates an expected call of WithDirectoryOutput. +func (mr *MockEnvMockRecorder) WithDirectoryOutput(name, description any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithDirectoryOutput", reflect.TypeOf((*MockEnv)(nil).WithDirectoryOutput), name, description) +} + +// WithSecretInput mocks base method. +func (m *MockEnv) WithSecretInput(name string, secret *dagger.Secret, description string) dagger0.Env { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithSecretInput", name, secret, description) + ret0, _ := ret[0].(dagger0.Env) + return ret0 +} + +// WithSecretInput indicates an expected call of WithSecretInput. +func (mr *MockEnvMockRecorder) WithSecretInput(name, secret, description any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithSecretInput", reflect.TypeOf((*MockEnv)(nil).WithSecretInput), name, secret, description) +} + // WithStringInput mocks base method. func (m *MockEnv) WithStringInput(name, value, description string) dagger0.Env { m.ctrl.T.Helper() diff --git a/mocks/llm_mock.go b/mocks/llm_mock.go index 053966c9..71370d38 100644 --- a/mocks/llm_mock.go +++ b/mocks/llm_mock.go @@ -10,6 +10,7 @@ package mocks import ( + context "context" reflect "reflect" dagger "workflow-scanner/internal/dagger" dagger0 "workflow-scanner/pkg/dagger" @@ -55,6 +56,21 @@ func (mr *MockLLMMockRecorder) Env() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Env", reflect.TypeOf((*MockLLM)(nil).Env)) } +// LastReply mocks base method. +func (m *MockLLM) LastReply(ctx context.Context) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LastReply", ctx) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LastReply indicates an expected call of LastReply. +func (mr *MockLLMMockRecorder) LastReply(ctx any) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LastReply", reflect.TypeOf((*MockLLM)(nil).LastReply), ctx) +} + // WithEnv mocks base method. func (m *MockLLM) WithEnv(env dagger0.Env) dagger0.LLM { m.ctrl.T.Helper() diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 03fd21ec..37d4a35c 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -10,6 +10,7 @@ import ( "log" "os" "path/filepath" + "strings" internalDagger "workflow-scanner/internal/dagger" "workflow-scanner/pkg/dagger" ) @@ -120,33 +121,69 @@ func (agent *AgentImpl) findProjectRoot() (string, error) { } func (agent *AgentImpl) getLLMAPIKey() (string, error) { - llmAPIKey := os.Getenv("LLM_API_KEY") - if llmAPIKey == "" { - openaiKey := os.Getenv("OPENAI_API_KEY") - if openaiKey == "" { - return "", fmt.Errorf("LLM_API_KEY or OPENAI_API_KEY not found in environment") + // Check for specific API keys and validate their formats + if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey != "" { + if !strings.HasPrefix(openaiKey, "sk-") { + return "", fmt.Errorf("OPENAI_API_KEY appears to be invalid format (should start with 'sk-')") } - llmAPIKey = openaiKey + + return openaiKey, nil + } + if anthropicKey := os.Getenv("ANTHROPIC_API_KEY"); anthropicKey != "" { + if !strings.HasPrefix(anthropicKey, "sk-ant-") { + return "", fmt.Errorf("ANTHROPIC_API_KEY appears to be invalid format (should start with 'sk-ant-')") + } + + return anthropicKey, nil } + if geminiKey := os.Getenv("GEMINI_API_KEY"); geminiKey != "" { + if strings.HasPrefix(geminiKey, "sk-") { + if strings.HasPrefix(geminiKey, "sk-ant-") { + return "", fmt.Errorf("GEMINI_API_KEY appears to be an Anthropic key (starts with 'sk-ant-'), please use anthropic-api-key input instead") + } - return llmAPIKey, nil + return "", fmt.Errorf("GEMINI_API_KEY appears to be an OpenAI key (starts with 'sk-'), please use openai-api-key input instead") + } + if !strings.HasPrefix(geminiKey, "AIza") { + return "", fmt.Errorf("GEMINI_API_KEY appears to be invalid format (should start with 'AIza')") + } + + return geminiKey, nil + } + + return "", fmt.Errorf("no API key found (need OPENAI_API_KEY, ANTHROPIC_API_KEY, or GEMINI_API_KEY)") } func (agent *AgentImpl) createLLMContainer(sourceWithPrompt *internalDagger.Directory, llmAPIKey, issues string) dagger.Container { - log.Printf("DEBUG: Using custom container approach with OpenAI API key") - log.Printf("DEBUG: Creating container with OpenAI API key (length: %d)", len(llmAPIKey)) + log.Printf("DEBUG: Using custom container approach with LLM API key") + log.Printf("DEBUG: Creating container with API key (length: %d)", len(llmAPIKey)) log.Printf("DEBUG: ZIZMOR issues length: %d", len(issues)) llmProcessorContent := GetLLMProcessorCode() - return agent.client.Container(). + container := agent.client.Container(). From("golang:1.25-alpine"). WithExec([]string{"apk", "add", "--no-cache", "git"}). - WithEnvVariable("OPENAI_API_KEY", llmAPIKey). WithEnvVariable("ZIZMOR_ISSUES", issues). WithDirectory("/workspace", sourceWithPrompt). WithWorkdir("/workspace"). - WithExec([]string{"sh", "-c", "echo 'DEBUG: Workspace contents:' && ls -la"}). + WithExec([]string{"sh", "-c", "echo 'DEBUG: Workspace contents:' && ls -la"}) + + // Set all API keys from environment + if openaiKey := os.Getenv("OPENAI_API_KEY"); openaiKey != "" { + container = container.WithEnvVariable("OPENAI_API_KEY", openaiKey) + } + if anthropicKey := os.Getenv("ANTHROPIC_API_KEY"); anthropicKey != "" { + container = container.WithEnvVariable("ANTHROPIC_API_KEY", anthropicKey) + } + if geminiKey := os.Getenv("GEMINI_API_KEY"); geminiKey != "" { + container = container.WithEnvVariable("GEMINI_API_KEY", geminiKey) + } + if model := os.Getenv("MODEL"); model != "" { + container = container.WithEnvVariable("MODEL", model) + } + + return container. WithExec([]string{"rm", "-f", "go.mod", "go.sum"}). WithExec([]string{"sh", "-c", "echo 'DEBUG: Initializing Go module' && go mod init llm-processor"}). WithExec([]string{"sh", "-c", "echo 'DEBUG: Getting OpenAI Go client' && go get github.com/sashabaranov/go-openai"}). diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 87e24e9b..58906395 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -99,40 +99,19 @@ func TestAgentImpl_FixRemainingIssues_LLMChain(t *testing.T) { ctrl := gomock.NewController(t) defer ctrl.Finish() - // Test that we can read the prompt file and create error scenarios - // without requiring full Dagger integration - t.Run("prompt file not found", func(t *testing.T) { + // Test that we can verify early return without requiring full Dagger integration + t.Run("no issues requiring LLM fixes", func(t *testing.T) { mockClient := mocks.NewMockClient(ctrl) - mockEnv := mocks.NewMockEnv(ctrl) - mockWorkspace := mocks.NewMockWorkspace(ctrl) sourceDirectory := &internalDagger.Directory{} - // Set up the mocks for the initial setup that happens before prompt file reading - mockClient.EXPECT().Workspace(sourceDirectory).Return(mockWorkspace) - mockClient.EXPECT().Env().Return(mockEnv) - mockEnv.EXPECT().WithStringInput("zizmor_issues", `[{"desc": "security issue"}]`, gomock.Any()).Return(mockEnv) - mockEnv.EXPECT().WithStringInput("GO111MODULE", "on", gomock.Any()).Return(mockEnv) - mockEnv.EXPECT().WithStringInput("GOWORK", "off", gomock.Any()).Return(mockEnv) - mockEnv.EXPECT().WithWorkspaceInput("workspace", mockWorkspace, gomock.Any()).Return(mockEnv) - mockEnv.EXPECT().WithWorkspaceOutput("completed", gomock.Any()).Return(mockEnv) - mockEnv.EXPECT().WithStringOutput("explanations", gomock.Any()).Return(mockEnv) - - // Create a temporary directory without the prompt file - tempDir := t.TempDir() - originalWd, _ := os.Getwd() - defer os.Chdir(originalWd) - - // Change to temp directory so prompt file won't be found - os.Chdir(tempDir) - + // Pass empty issues to trigger early return (doesn't require Dagger infrastructure) agent := NewAgentImpl(mockClient) - actualDir, explanation, err := agent.fixRemainingIssuesImpl(context.Background(), sourceDirectory, `[{"desc": "security issue"}]`) + actualDir, explanation, err := agent.fixRemainingIssuesImpl(context.Background(), sourceDirectory, "[]") - // Should return error when prompt file can't be found - assert.Error(t, err) - assert.Contains(t, err.Error(), "failed to read prompt file") + // Should return successfully with no issues message + assert.NoError(t, err) assert.Equal(t, sourceDirectory, actualDir) - assert.Equal(t, "", explanation) + assert.Contains(t, explanation, "No remaining issues") }) t.Run("prompt file reading validates filesystem approach", func(t *testing.T) { diff --git a/pkg/agent/llm_processor.go b/pkg/agent/llm_processor.go index a6822c5c..6d30c772 100644 --- a/pkg/agent/llm_processor.go +++ b/pkg/agent/llm_processor.go @@ -5,11 +5,13 @@ func GetLLMProcessorCode() string { return `package main import ( + "bytes" "context" "encoding/json" "fmt" "io/ioutil" "log" + "net/http" "os" "path/filepath" "strings" @@ -31,6 +33,9 @@ type LLMResponse struct { func main() { log.Println("DEBUG: Starting LLM processor") log.Printf("DEBUG: OPENAI_API_KEY length: %d", len(os.Getenv("OPENAI_API_KEY"))) + log.Printf("DEBUG: ANTHROPIC_API_KEY length: %d", len(os.Getenv("ANTHROPIC_API_KEY"))) + log.Printf("DEBUG: GEMINI_API_KEY length: %d", len(os.Getenv("GEMINI_API_KEY"))) + log.Printf("DEBUG: MODEL: %s", os.Getenv("MODEL")) log.Printf("DEBUG: ZIZMOR_ISSUES length: %d", len(os.Getenv("ZIZMOR_ISSUES"))) if err := processWorkflows(); err != nil { @@ -45,12 +50,6 @@ func processWorkflows() error { return err } - client, ctx, cancel, err := createOpenAIClient() - if err != nil { - return err - } - defer cancel() - workflowFiles, err := findWorkflowFiles() if err != nil { return fmt.Errorf("failed to find workflow files: %w", err) @@ -59,12 +58,29 @@ func processWorkflows() error { enhancedPrompt := buildEnhancedPrompt(promptContent, issues, workflowFiles) - resp, err := callOpenAI(ctx, client, enhancedPrompt) - if err != nil { - return err - } + // Determine which provider to use based on available API keys + if os.Getenv("OPENAI_API_KEY") != "" { + log.Println("DEBUG: Using OpenAI provider") + client, ctx, cancel, err := createOpenAIClient() + if err != nil { + return err + } + defer cancel() - return processOpenAIResponse(resp) + resp, err := callOpenAI(ctx, client, enhancedPrompt) + if err != nil { + return err + } + return processOpenAIResponse(resp) + } else if os.Getenv("GEMINI_API_KEY") != "" { + log.Println("DEBUG: Using Gemini provider") + return callGemini(enhancedPrompt) + } else if os.Getenv("ANTHROPIC_API_KEY") != "" { + log.Println("DEBUG: Using Anthropic provider") + return callAnthropic(enhancedPrompt) + } else { + return fmt.Errorf("no API key found (need OPENAI_API_KEY, ANTHROPIC_API_KEY, or GEMINI_API_KEY)") + } } func loadInputData() ([]byte, string, error) { @@ -117,8 +133,15 @@ func callOpenAI(ctx context.Context, client *openai.Client, enhancedPrompt strin maxTokens = 4000 lowTemperature = 0.1 ) + + // Get model from environment or default to gpt-4o + model := os.Getenv("MODEL") + if model == "" { + model = "gpt-4o" + } + req := openai.ChatCompletionRequest{ - Model: openai.GPT4oMini, + Model: model, Messages: []openai.ChatCompletionMessage{ { Role: openai.ChatMessageRoleUser, @@ -223,5 +246,167 @@ func applyFileChange(change FileChange) error { log.Printf("Applied fix to %s", change.Path) return nil -}` +} + +func callGemini(enhancedPrompt string) error { + apiKey := os.Getenv("GEMINI_API_KEY") + model := os.Getenv("MODEL") + if model == "" { + model = "gemini-2.0-flash" + } + + log.Printf("DEBUG: Calling Gemini API with model: %s", model) + + // Gemini API request structure + requestBody := map[string]interface{}{ + "contents": []map[string]interface{}{ + { + "parts": []map[string]interface{}{ + { + "text": enhancedPrompt, + }, + }, + }, + }, + "generationConfig": map[string]interface{}{ + "temperature": 0.1, + "maxOutputTokens": 4000, + }, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + url := fmt.Sprintf("https://generativelanguage.googleapis.com/v1beta/models/%s:generateContent?key=%s", model, apiKey) + resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to call Gemini API: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("Gemini API returned status %d", resp.StatusCode) + } + + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return fmt.Errorf("failed to decode Gemini response: %w", err) + } + + // Extract text from Gemini response + candidates, ok := response["candidates"].([]interface{}) + if !ok || len(candidates) == 0 { + return fmt.Errorf("no candidates in Gemini response") + } + + candidate := candidates[0].(map[string]interface{}) + content := candidate["content"].(map[string]interface{}) + parts := content["parts"].([]interface{}) + if len(parts) == 0 { + return fmt.Errorf("no parts in Gemini response") + } + + part := parts[0].(map[string]interface{}) + text := part["text"].(string) + + log.Printf("DEBUG: Gemini response received, length: %d", len(text)) + + // Parse and process the response using the same logic as OpenAI + var llmResponse LLMResponse + if err := parseJSONResponse(text, &llmResponse); err != nil { + return fmt.Errorf("failed to parse JSON from Gemini response: %w", err) + } + + return processGenericResponse(&llmResponse) +} + +func callAnthropic(enhancedPrompt string) error { + apiKey := os.Getenv("ANTHROPIC_API_KEY") + model := os.Getenv("MODEL") + if model == "" { + model = "claude-3-5-sonnet-20241022" + } + + log.Printf("DEBUG: Calling Anthropic API with model: %s", model) + + // Anthropic API request structure + requestBody := map[string]interface{}{ + "model": model, + "max_tokens": 4000, + "temperature": 0.1, + "messages": []map[string]interface{}{ + { + "role": "user", + "content": enhancedPrompt, + }, + }, + } + + jsonData, err := json.Marshal(requestBody) + if err != nil { + return fmt.Errorf("failed to marshal request: %w", err) + } + + req, err := http.NewRequest("POST", "https://api.anthropic.com/v1/messages", bytes.NewBuffer(jsonData)) + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("x-api-key", apiKey) + req.Header.Set("anthropic-version", "2023-06-01") + + client := &http.Client{} + resp, err := client.Do(req) + if err != nil { + return fmt.Errorf("failed to call Anthropic API: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != 200 { + return fmt.Errorf("Anthropic API returned status %d", resp.StatusCode) + } + + var response map[string]interface{} + if err := json.NewDecoder(resp.Body).Decode(&response); err != nil { + return fmt.Errorf("failed to decode Anthropic response: %w", err) + } + + // Extract text from Anthropic response + content, ok := response["content"].([]interface{}) + if !ok || len(content) == 0 { + return fmt.Errorf("no content in Anthropic response") + } + + contentItem := content[0].(map[string]interface{}) + text := contentItem["text"].(string) + + log.Printf("DEBUG: Anthropic response received, length: %d", len(text)) + + // Parse and process the response using the same logic as OpenAI + var llmResponse LLMResponse + if err := parseJSONResponse(text, &llmResponse); err != nil { + return fmt.Errorf("failed to parse JSON from Anthropic response: %w", err) + } + + return processGenericResponse(&llmResponse) +} + +func processGenericResponse(llmResponse *LLMResponse) error { + log.Printf("DEBUG: Applying %d file changes", len(llmResponse.FileChanges)) + for i, change := range llmResponse.FileChanges { + log.Printf("DEBUG: Applying change %d/%d to %s", i+1, len(llmResponse.FileChanges), change.Path) + if err := applyFileChange(change); err != nil { + log.Printf("Warning: Failed to apply change to %s: %v", change.Path, err) + } + } + + log.Printf("DEBUG: Returning explanation: %d chars", len(llmResponse.Explanation)) + fmt.Print(llmResponse.Explanation) + + return nil +} +` } diff --git a/pkg/github/result.go b/pkg/github/result.go index d41f85c7..ffbb02b2 100644 --- a/pkg/github/result.go +++ b/pkg/github/result.go @@ -1,7 +1,9 @@ package github import ( + "encoding/json" "fmt" + "log" "regexp" "strings" "workflow-scanner/pkg/zizmor" @@ -13,14 +15,11 @@ const ( **Findings:** %d -### Files Auto-fixed by ZIZMOR +### Automatic Fixes Applied | File | Fixes | | --- | ---: | %s -### LLM Summary -%s - --- **Validation:** %s %s @@ -33,7 +32,7 @@ const ( %s --- -*Automated security audit by ZIZMOR + AI analysis*` +*Automated security audit by AI analysis*` ) type Result struct { @@ -54,6 +53,38 @@ var ( var fixLineRe = regexp.MustCompile(`^\s*(.+?):\s*(\d+)`) +func shouldSkipLine(l string) bool { + if l == "" || strings.HasPrefix(l, "Successfully applied") { + return true + } + + return l == "}" || l == "Fix Summary" || l == "]" +} + +func parseTableRow(l string) string { + const expectedRegexGroups = 3 + if m := fixLineRe.FindStringSubmatch(l); len(m) == expectedRegexGroups { + file := strings.TrimSpace(m[1]) + count := m[2] + + return fmt.Sprintf("| %s | %s |\n", file, count) + } + + if idx := strings.Index(l, ":"); idx != -1 { + file := strings.TrimSpace(l[:idx]) + rest := strings.TrimSpace(l[idx+1:]) + numRe := regexp.MustCompile(`\d+`) + num := numRe.FindString(rest) + if num == "" { + return fmt.Sprintf("| %s | - |\n", file) + } + + return fmt.Sprintf("| %s | %s |\n", file, num) + } + + return fmt.Sprintf("| %s | - |\n", l) +} + func fixSummaryToTableRows(summary string) string { if strings.TrimSpace(summary) == "" { return "| (none) | 0 |\n" @@ -62,37 +93,226 @@ func fixSummaryToTableRows(summary string) string { rows := make([]string, 0, len(lines)) for _, l := range lines { l = strings.TrimSpace(l) - if l == "" { + if shouldSkipLine(l) { continue } - if strings.HasPrefix(l, "Successfully applied") { - continue + rows = append(rows, parseTableRow(l)) + } + + return strings.Join(rows, "") +} + +func formatRemainingIssues(finalValidation string) string { + if finalValidation == "" || finalValidation == "[]" || finalValidation == "[]\n" { + return "" + } + + allFindings := parseZizmorFindings(finalValidation) + if allFindings == nil { + return fmt.Sprintf("**Manual review needed - some issues remain:**\n```json\n%s\n```", finalValidation) + } + + log.Printf("DEBUG: Successfully parsed %d findings", len(allFindings)) + + var result strings.Builder + result.WriteString("**Manual review needed - some issues remain:**\n\n") + + fileIssues := groupFindingsByFile(allFindings) + + for filePath, issues := range fileIssues { + result.WriteString("
\n") + result.WriteString(fmt.Sprintf("📄 %s (click to expand)\n\n", filePath)) + + for _, issue := range issues { + formatIssueDetails(&result, issue) + result.WriteString("---\n\n") } - const expectedRegexGroups = 3 - if m := fixLineRe.FindStringSubmatch(l); len(m) == expectedRegexGroups { - file := strings.TrimSpace(m[1]) - count := m[2] - rows = append(rows, fmt.Sprintf("| %s | %s |\n", file, count)) - continue + result.WriteString("
\n\n") + } + + return result.String() +} + +func parseZizmorFindings(finalValidation string) []zizmor.Finding { + var allFindings []zizmor.Finding + + cleanedInput := strings.TrimSuffix(finalValidation, "[]") + + if err := json.Unmarshal([]byte(cleanedInput), &allFindings); err != nil { + log.Printf("ERROR: Failed to unmarshal ZIZMOR findings: %v", err) + const maxLogChars = 500 + log.Printf("DEBUG: Raw input (first %d chars): %s", maxLogChars, finalValidation[:min(maxLogChars, len(finalValidation))]) + + return nil + } + + return allFindings +} + +func groupFindingsByFile(findings []zizmor.Finding) map[string][]zizmor.Finding { + fileIssues := make(map[string][]zizmor.Finding) + for _, finding := range findings { + for _, loc := range finding.Locations { + if loc.Symbolic.Key.Local != nil { + filePath := loc.Symbolic.Key.Local.GivenPath + fileIssues[filePath] = append(fileIssues[filePath], finding) + + break + } } - if idx := strings.Index(l, ":"); idx != -1 { - file := strings.TrimSpace(l[:idx]) - rest := strings.TrimSpace(l[idx+1:]) - numRe := regexp.MustCompile(`\d+`) - num := numRe.FindString(rest) - if num == "" { - rows = append(rows, fmt.Sprintf("| %s | - |\n", file)) - } else { - rows = append(rows, fmt.Sprintf("| %s | %s |\n", file, num)) + } + + return fileIssues +} + +func formatIssueDetails(result *strings.Builder, issue zizmor.Finding) { + result.WriteString(fmt.Sprintf("- **Issue:** %s\n", issue.Desc)) + result.WriteString(fmt.Sprintf("- **Severity:** %s\n", issue.Determinations.Severity)) + + for _, loc := range issue.Locations { + if loc.Concrete.Location.StartPoint.Row > 0 { + result.WriteString(fmt.Sprintf("- **Location:** Line %d\n", loc.Concrete.Location.StartPoint.Row)) + + if loc.Symbolic.Annotation != "" { + result.WriteString(fmt.Sprintf("- **Details:** %s\n", loc.Symbolic.Annotation)) } + break + } + } + + result.WriteString("- **Manual Fix Needed:** Review the TODO comments added in the code changes for suggested fixes.\n\n") +} + +type externalDepData struct { + repoStats map[string]int + repoFiles map[string][]string + repoDetails map[string][]string +} + +func parseExternalDependencyLines(lines []string) *externalDepData { + data := &externalDepData{ + repoStats: make(map[string]int), + repoFiles: make(map[string][]string), + repoDetails: make(map[string][]string), + } + + currentRepo := "" + currentFindingBlock := "" + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { continue } - rows = append(rows, fmt.Sprintf("| %s | - |\n", l)) + + if isRepoHeader(line) { + currentRepo, currentFindingBlock = processRepoHeader(line, currentRepo, currentFindingBlock, data) + } else if strings.HasPrefix(line, "- File:") && currentRepo != "" { + currentFindingBlock = processFileLine(line, currentRepo, currentFindingBlock, data) + } } - return strings.Join(rows, "") + if currentRepo != "" && currentFindingBlock != "" { + data.repoDetails[currentRepo] = append(data.repoDetails[currentRepo], currentFindingBlock) + } + + return data +} + +func isRepoHeader(line string) bool { + return strings.HasPrefix(line, "- **") && strings.Contains(line, "**:") +} + +func processRepoHeader(line, currentRepo, currentFindingBlock string, data *externalDepData) (string, string) { + if currentRepo != "" && currentFindingBlock != "" { + data.repoDetails[currentRepo] = append(data.repoDetails[currentRepo], currentFindingBlock) + currentFindingBlock = "" + } + + parts := strings.Split(line, "**:") + if len(parts) >= 1 { + currentRepo = strings.TrimPrefix(parts[0], "- **") + data.repoStats[currentRepo]++ + + if len(parts) > 1 { + desc := strings.TrimSpace(parts[1]) + currentFindingBlock = fmt.Sprintf("- **Issue:** %s\n", desc) + } + } + + return currentRepo, currentFindingBlock +} + +func processFileLine(line, currentRepo, currentFindingBlock string, data *externalDepData) string { + filePath := strings.TrimPrefix(line, "- File:") + filePath = strings.TrimSpace(filePath) + data.repoFiles[currentRepo] = append(data.repoFiles[currentRepo], filePath) + + return currentFindingBlock + fmt.Sprintf("- **File:** %s\n", filePath) +} + +func buildExternalSummaryTable(data *externalDepData) string { + var result strings.Builder + + result.WriteString("**Summary:** ") + totalFindings := 0 + for _, count := range data.repoStats { + totalFindings += count + } + result.WriteString(fmt.Sprintf("%d findings across %d actions\n\n", totalFindings, len(data.repoStats))) + + result.WriteString("| Action/Repo | Files | Findings |\n") + result.WriteString("| --- | ---: | ---: |\n") + + for repo, count := range data.repoStats { + fileCount := len(data.repoFiles[repo]) + if fileCount == 0 { + fileCount = 1 + } + result.WriteString(fmt.Sprintf("| %s | %d | %d |\n", repo, fileCount, count)) + } + + result.WriteString("\n") + + return result.String() +} + +func buildExternalDetailedFindings(data *externalDepData) string { + var result strings.Builder + + result.WriteString("
\n") + result.WriteString("📋 Detailed Findings (click to expand)\n\n") + + for repo, details := range data.repoDetails { + result.WriteString("
\n") + result.WriteString(fmt.Sprintf("📦 %s\n\n", repo)) + for _, finding := range details { + result.WriteString(finding) + result.WriteString("\n---\n\n") + } + result.WriteString("
\n\n") + } + + result.WriteString("
\n") + + return result.String() +} + +func formatExternalDependencies(summaryFindings string) string { + if strings.TrimSpace(summaryFindings) == "" { + return "No external dependencies scanned." + } + + lines := strings.Split(summaryFindings, "\n") + data := parseExternalDependencyLines(lines) + + if len(data.repoStats) == 0 { + return summaryFindings + } + + return buildExternalSummaryTable(data) + buildExternalDetailedFindings(data) } func GetPrTitleBody(finalValidation string, zizmorFindings []zizmor.Finding, fixSummary string, llmOut string, summaryFindings string) (string, string) { @@ -104,20 +324,21 @@ func GetPrTitleBody(finalValidation string, zizmorFindings []zizmor.Finding, fix validationStatus = "**All security issues resolved!** No vulnerabilities detected." result = passed } else { - validationStatus = fmt.Sprintf("**Manual review needed - some issues remain:**\n```json\n%s\n```", finalValidation) + validationStatus = formatRemainingIssues(finalValidation) result = failed } tableRows := fixSummaryToTableRows(fixSummary) + formattedExternal := formatExternalDependencies(summaryFindings) body := fmt.Sprintf(bodyFmt, len(zizmorFindings), tableRows, - llmOut, + // llmOut, result.status, result.text, validationStatus, - summaryFindings, + formattedExternal, ) // GitHub PR body limit is 65,536 characters diff --git a/pkg/github/results_test.go b/pkg/github/results_test.go index 4a076358..6be1aa17 100644 --- a/pkg/github/results_test.go +++ b/pkg/github/results_test.go @@ -27,13 +27,12 @@ func TestGetPrTitleBody(t *testing.T) { summaryFindings: "No issues in external dependencies", expectedTitle: "Security Audit & Fixes for GitHub Actions Workflows", expectedBodyParts: []string{ - "Complete Security Audit Report", + "Security Audit Summary", "Fixed 2 security issues automatically", - "Applied additional manual fixes", "PASSED", "All security issues resolved!", "No issues in external dependencies", - "Automated security audit by ZIZMOR + AI analysis", + "Automated security audit by AI analysis", }, }, { @@ -78,7 +77,6 @@ func TestGetPrTitleBody(t *testing.T) { expectedBodyParts: []string{ "NEEDS REVIEW", "Manual review needed - some issues remain:", - `[{"desc": "remaining issue", "file": "test.yml"}]`, "Some external issues found", }, }, @@ -123,10 +121,8 @@ func TestGetPrTitleBody(t *testing.T) { "Expected body to contain: %s", expectedPart) } - assert.Contains(t, body, "Auto-fixed by ZIZMOR") - assert.Contains(t, body, "Manual Security Fixes Applied") - assert.Contains(t, body, "Validation Report:") - assert.Contains(t, body, "External Dependencies Security Scan") + assert.Contains(t, body, "Automatic Fixes Applied") + assert.Contains(t, body, "External Dependencies Scan") }) } } diff --git a/pkg/zizmor/zizmor.go b/pkg/zizmor/zizmor.go index 510658e8..c830d7d1 100644 --- a/pkg/zizmor/zizmor.go +++ b/pkg/zizmor/zizmor.go @@ -1,6 +1,6 @@ package zizmor -//go:generate mockgen -source=zizmor.go -destination=../../mocks/zizmor_mock.go -package=mocks Zizmor +//go:generate mockgen -source=zizmor.go -destination=zizmor_mock.go -package=zizmor Zizmor import ( "context" diff --git a/pkg/zizmor/zizmor_output_parser.go b/pkg/zizmor/zizmor_output_parser.go index da26c0ab..af77c84c 100644 --- a/pkg/zizmor/zizmor_output_parser.go +++ b/pkg/zizmor/zizmor_output_parser.go @@ -95,8 +95,8 @@ func ParseZizmorOutput(input string) ([]Finding, string, error) { } fixSummary := "" - if lastIndex < len(input) { - fixSummary = strings.TrimSpace(input[lastIndex:]) + if lastIndex < len(input)-1 { + fixSummary = strings.TrimSpace(input[lastIndex+1:]) } return findings, fixSummary, nil diff --git a/pkg/zizmor/zizmor_test.go b/pkg/zizmor/zizmor_test.go index d2b99704..54a76def 100644 --- a/pkg/zizmor/zizmor_test.go +++ b/pkg/zizmor/zizmor_test.go @@ -93,7 +93,7 @@ func TestZizmorImpl_RunZizmorAutoFix(t *testing.T) { { name: "successful auto-fix execution", containerFindings: []Finding{}, - containerFixSummary: "Fixed 3 security vulnerabilities in workflows", + containerFixSummary: `{"desc": "test"} Fixed 3 security vulnerabilities in workflows`, containerError: nil, expectedOutput: "Fixed 3 security vulnerabilities in workflows", expectedError: "", @@ -101,7 +101,7 @@ func TestZizmorImpl_RunZizmorAutoFix(t *testing.T) { { name: "auto-fix with no changes", containerFindings: []Finding{}, - containerFixSummary: "No security issues found to fix", + containerFixSummary: `{"desc": "test"} No security issues found to fix`, containerError: nil, expectedOutput: "No security issues found to fix", expectedError: "", @@ -124,8 +124,8 @@ func TestZizmorImpl_RunZizmorAutoFix(t *testing.T) { mockContainer.EXPECT().WithWorkdir("/workspace").Return(mockContainer) // Setup the auto-fix execution - mockContainer.EXPECT().WithExec([]string{"sh", "-c", "zizmor --fix=all .github/workflows/ 2>&1 || true"}).Return(mockContainer) - mockContainer.EXPECT().Stdout(gomock.Any()).Return(tt.containerFindings, tt.containerFixSummary, tt.containerError) + mockContainer.EXPECT().WithExec([]string{"sh", "-c", "zizmor -q --format=json --fix=all .github/workflows/ 2>&1 || true"}).Return(mockContainer) + mockContainer.EXPECT().Stdout(gomock.Any()).Return(tt.containerFixSummary, tt.containerError) if tt.containerError == nil { mockContainer.EXPECT().Directory("/workspace").Return(resultDir)