From 9b1d31cb671c6b35093dc15a362e4a8fe25b1696 Mon Sep 17 00:00:00 2001 From: Michel Osswald Date: Thu, 16 Apr 2026 23:46:00 +0200 Subject: [PATCH] fix(start): preflight agent launch setup --- cmd/kontext/main.go | 4 +- internal/run/run.go | 90 ++++++++++++++++++++++++++++--- internal/run/run_test.go | 112 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 197 insertions(+), 9 deletions(-) diff --git a/cmd/kontext/main.go b/cmd/kontext/main.go index 80593c0..7a0bff8 100644 --- a/cmd/kontext/main.go +++ b/cmd/kontext/main.go @@ -7,7 +7,6 @@ import ( "fmt" "net" "os" - "os/exec" "time" "github.com/spf13/cobra" @@ -62,7 +61,8 @@ func startCmd() *cobra.Command { ClientID: auth.DefaultClientID, Args: args, }) - if exitErr, ok := err.(*exec.ExitError); ok { + if exitErr, ok := err.(*run.AgentExitError); ok { + fmt.Fprintf(os.Stderr, "Error: %v\n", exitErr) os.Exit(exitErr.ExitCode()) } return err diff --git a/internal/run/run.go b/internal/run/run.go index 66206d3..b512c27 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -5,6 +5,7 @@ import ( "bufio" "context" "encoding/json" + "errors" "fmt" "io" "net/http" @@ -41,8 +42,9 @@ type Options struct { // Start is the main entry point for `kontext start`. func Start(ctx context.Context, opts Options) error { - if _, ok := agent.Get(opts.Agent); !ok { - return fmt.Errorf("unsupported agent %q (supported: %s)", opts.Agent, strings.Join(supportedAgents(), ", ")) + agentPath, err := preflightAgent(opts.Agent) + if err != nil { + return err } // 1. Auth @@ -212,11 +214,25 @@ func Start(ctx context.Context, opts Options) error { // 9. Launch agent with hooks fmt.Fprintf(os.Stderr, "\nLaunching %s...\n\n", opts.Agent) - agentErr := launchAgentWithSettings(ctx, opts.Agent, env, opts.Args, settingsPath) + agentErr := launchAgentWithSettings(ctx, opts.Agent, agentPath, env, opts.Args, settingsPath) return agentErr } +// AgentExitError reports an agent process that launched but exited unsuccessfully. +type AgentExitError struct { + Agent string + Err *exec.ExitError +} + +func (e *AgentExitError) Error() string { + return fmt.Sprintf("%s exited with code %d after Kontext setup completed", e.Agent, e.Err.ExitCode()) +} + +func (e *AgentExitError) Unwrap() error { return e.Err } + +func (e *AgentExitError) ExitCode() int { return e.Err.ExitCode() } + type sessionEnder interface { EndSession(context.Context, string) error } @@ -804,12 +820,68 @@ func newSessionTokenSource(ctx context.Context, session *auth.Session) backend.T } } -func launchAgentWithSettings(_ context.Context, agentName string, env, extraArgs []string, settingsPath string) error { - binaryPath, err := exec.LookPath(agentName) +func preflightAgent(agentName string) (string, error) { + if _, ok := agent.Get(agentName); !ok { + return "", fmt.Errorf("unsupported agent %q (supported: %s)", agentName, strings.Join(supportedAgents(), ", ")) + } + return findExecutable(agentName, os.Getenv("PATH")) +} + +func findExecutable(agentName, pathEnv string) (string, error) { + if strings.ContainsRune(agentName, os.PathSeparator) || strings.Contains(agentName, string(os.PathSeparator)) { + return validateExecutable(agentName, agentName) + } + + var permissionErr error + for _, dir := range filepath.SplitList(pathEnv) { + var candidate string + if dir == "" { + candidate = "." + string(os.PathSeparator) + agentName + } else { + candidate = filepath.Join(dir, agentName) + } + path, err := validateExecutable(agentName, candidate) + if err == nil { + if !filepath.IsAbs(path) { + return "", fmt.Errorf("agent %q resolved from relative PATH entry at %s: %w", agentName, path, exec.ErrDot) + } + return path, nil + } + if errors.Is(err, exec.ErrDot) { + return "", err + } + if errors.Is(err, os.ErrPermission) && permissionErr == nil { + permissionErr = err + } + } + + if permissionErr != nil { + return "", permissionErr + } + return "", fmt.Errorf("agent %q not found in PATH", agentName) +} + +func validateExecutable(agentName, path string) (string, error) { + info, err := os.Stat(path) if err != nil { - return fmt.Errorf("agent %q not found in PATH: %w", agentName, err) + if os.IsNotExist(err) { + return "", err + } + return "", fmt.Errorf("inspect agent %q: %w", agentName, err) } + if info.IsDir() || info.Mode()&0o111 == 0 { + return "", fmt.Errorf("agent %q found at %s but is not executable: %w", agentName, path, os.ErrPermission) + } + if _, err := exec.LookPath(path); err != nil { + if errors.Is(err, exec.ErrDot) { + return "", err + } + return "", fmt.Errorf("agent %q found at %s but is not executable by the current user: %w", agentName, path, os.ErrPermission) + } + return path, nil +} +func launchAgentWithSettings(_ context.Context, agentName, binaryPath string, env, extraArgs []string, settingsPath string) error { var args []string if settingsPath != "" { args = append(args, "--settings", settingsPath) @@ -834,10 +906,14 @@ func launchAgentWithSettings(_ context.Context, agentName string, env, extraArgs } }() - err = cmd.Wait() + err := cmd.Wait() signal.Stop(sigCh) close(sigCh) + var exitErr *exec.ExitError + if errors.As(err, &exitErr) { + return &AgentExitError{Agent: agentName, Err: exitErr} + } return err } diff --git a/internal/run/run_test.go b/internal/run/run_test.go index b43b823..c81e729 100644 --- a/internal/run/run_test.go +++ b/internal/run/run_test.go @@ -4,13 +4,17 @@ import ( "bytes" "context" "encoding/json" + "errors" "fmt" "io" "net/http" "net/http/httptest" "net/url" "os" + "os/exec" + "path/filepath" "reflect" + "runtime" "strings" "testing" @@ -38,6 +42,114 @@ func TestFilterArgs(t *testing.T) { } } +func TestFindExecutableReturnsExecutablePath(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + path := filepath.Join(dir, "test-agent") + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatal(err) + } + + got, err := findExecutable("test-agent", dir) + if err != nil { + t.Fatalf("findExecutable() error = %v", err) + } + if got != path { + t.Fatalf("findExecutable() = %q, want %q", got, path) + } +} + +func TestFindExecutableDistinguishesNonExecutable(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + if err := os.WriteFile(filepath.Join(dir, "test-agent"), []byte("not executable"), 0o644); err != nil { + t.Fatal(err) + } + + _, err := findExecutable("test-agent", dir) + if err == nil { + t.Fatal("findExecutable() error = nil, want non-executable error") + } + if !strings.Contains(err.Error(), "not executable") { + t.Fatalf("findExecutable() error = %q, want not executable", err) + } +} + +func TestFindExecutableSkipsNonExecutablePathMatch(t *testing.T) { + t.Parallel() + + firstDir := t.TempDir() + secondDir := t.TempDir() + if err := os.WriteFile(filepath.Join(firstDir, "test-agent"), []byte("not executable"), 0o644); err != nil { + t.Fatal(err) + } + + want := filepath.Join(secondDir, "test-agent") + if err := os.WriteFile(want, []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatal(err) + } + + got, err := findExecutable("test-agent", firstDir+string(os.PathListSeparator)+secondDir) + if err != nil { + t.Fatalf("findExecutable() error = %v", err) + } + if got != want { + t.Fatalf("findExecutable() = %q, want %q", got, want) + } +} + +func TestFindExecutableRejectsRelativePathMatch(t *testing.T) { + dir := t.TempDir() + t.Chdir(dir) + if err := os.Mkdir("bin", 0o755); err != nil { + t.Fatal(err) + } + if err := os.WriteFile(filepath.Join("bin", "test-agent"), []byte("#!/bin/sh\nexit 0\n"), 0o755); err != nil { + t.Fatal(err) + } + + _, err := findExecutable("test-agent", "bin") + if !errors.Is(err, exec.ErrDot) { + t.Fatalf("findExecutable() error = %v, want exec.ErrDot", err) + } +} + +func TestFindExecutableDistinguishesMissing(t *testing.T) { + t.Parallel() + + _, err := findExecutable("test-agent", t.TempDir()) + if err == nil { + t.Fatal("findExecutable() error = nil, want missing error") + } + if !strings.Contains(err.Error(), "not found") { + t.Fatalf("findExecutable() error = %q, want not found", err) + } +} + +func TestLaunchAgentWithSettingsReturnsAgentExitError(t *testing.T) { + t.Parallel() + if runtime.GOOS == "windows" { + t.Skip("shell script launch test is POSIX-specific") + } + + dir := t.TempDir() + path := filepath.Join(dir, "test-agent") + if err := os.WriteFile(path, []byte("#!/bin/sh\nexit 42\n"), 0o755); err != nil { + t.Fatal(err) + } + + err := launchAgentWithSettings(context.Background(), "test-agent", path, os.Environ(), nil, "") + var exitErr *AgentExitError + if !errors.As(err, &exitErr) { + t.Fatalf("launchAgentWithSettings() error = %T, want *AgentExitError", err) + } + if exitErr.ExitCode() != 42 { + t.Fatalf("ExitCode() = %d, want 42", exitErr.ExitCode()) + } +} + func TestFetchConnectURL(t *testing.T) { t.Parallel()