Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 24 additions & 66 deletions cmd/amux/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"github.com/charmbracelet/x/term"

"github.com/andyrewlee/amux/internal/app"
"github.com/andyrewlee/amux/internal/cli"
"github.com/andyrewlee/amux/internal/logging"
"github.com/andyrewlee/amux/internal/safego"
)
Expand All @@ -32,89 +31,48 @@ var (
date = "unknown"
)

// CLI subcommands that route to the headless CLI.
var cliCommands = map[string]bool{
"status": true, "doctor": true, "logs": true,
"workspace": true, "agent": true, "session": true, "project": true,
"terminal": true,
"capabilities": true,
"version": true, "help": true,
}

func main() {
// Handle --version flag
if len(os.Args) > 1 && (os.Args[1] == "--version" || os.Args[1] == "-v") {
args := os.Args[1:]

if isVersionInvocation(args) {
fmt.Printf("amux %s (commit: %s, built: %s)\n", version, commit, date)
os.Exit(0)
}

sub, parseErr := classifyInvocation(os.Args[1:])
if parseErr != nil {
// Let the headless CLI render the canonical parse error response.
code := cli.Run(os.Args[1:], version, commit, date)
os.Exit(code)
}

// Route to CLI if a known subcommand is given (even with leading global flags).
if sub != "" {
if cliCommands[sub] {
code := cli.Run(os.Args[1:], version, commit, date)
os.Exit(code)
}
if sub == "tui" {
// Launch TUI unconditionally.
runTUI()
return
}
if len(args) > 0 {
fmt.Fprintln(os.Stderr, unsupportedInvocationMessage(args[0]))
os.Exit(2)
}

// No subcommand: TTY → TUI, non-TTY → delegate to headless CLI.
if sub == "" {
launchTUI := shouldLaunchTUI(
term.IsTerminal(os.Stdin.Fd()),
term.IsTerminal(os.Stdout.Fd()),
term.IsTerminal(os.Stderr.Fd()),
)
if handled, code := handleNoSubcommand(os.Args[1:], launchTUI); handled {
os.Exit(code)
}
runTUI()
return
if !shouldLaunchTUI(
term.IsTerminal(os.Stdin.Fd()),
term.IsTerminal(os.Stdout.Fd()),
term.IsTerminal(os.Stderr.Fd()),
) {
fmt.Fprintln(os.Stderr, nonInteractiveMessage())
os.Exit(1)
}

// Unknown argument: route through CLI for JSON-aware error handling
code := cli.Run(os.Args[1:], version, commit, date)
os.Exit(code)
}

func firstCLIArg(args []string) string {
sub, _ := classifyInvocation(args)
return sub
runTUI()
}

func classifyInvocation(args []string) (string, error) {
_, rest, err := cli.ParseGlobalFlags(args)
if err != nil {
return "", err
}
if len(rest) == 0 {
return "", nil
}
return rest[0], nil
func isVersionInvocation(args []string) bool {
return len(args) == 1 && (args[0] == "--version" || args[0] == "-v")
}

func shouldLaunchTUI(stdinIsTTY, stdoutIsTTY, stderrIsTTY bool) bool {
return stdinIsTTY && stdoutIsTTY && stderrIsTTY
}

func handleNoSubcommand(args []string, launchTUI bool) (bool, int) {
if len(args) > 0 {
return true, cli.Run(args, version, commit, date)
}
if launchTUI {
return false, 0
func unsupportedInvocationMessage(arg string) string {
if arg == "tui" {
return "run `amux` directly to start the terminal UI."
}
return true, cli.Run(args, version, commit, date)
return fmt.Sprintf("unexpected argument %q. Run `amux` to start the terminal UI or `amux --version`.", arg)
}

func nonInteractiveMessage() string {
return "amux starts an interactive terminal UI and requires stdin, stdout, and stderr to be TTYs."
}

func runTUI() {
Expand Down
222 changes: 24 additions & 198 deletions cmd/amux/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,6 @@
package main

import (
"encoding/json"
"io"
"os"
"strings"
"testing"
"time"
Expand Down Expand Up @@ -46,155 +43,50 @@ func TestMouseWheelThrottleIndependent(t *testing.T) {
}
}

func TestFirstCLIArgSkipsLeadingGlobalFlags(t *testing.T) {
func TestIsVersionInvocation(t *testing.T) {
tests := []struct {
name string
args []string
want string
want bool
}{
{
name: "json status",
args: []string{"--json", "status"},
want: "status",
},
{
name: "quiet doctor",
args: []string{"-q", "doctor"},
want: "doctor",
},
{
name: "cwd workspace list",
args: []string{"--cwd", "/tmp/repo", "workspace", "list"},
want: "workspace",
},
{
name: "timeout logs tail",
args: []string{"--timeout=5s", "logs", "tail"},
want: "logs",
},
{
name: "request-id capabilities",
args: []string{"--request-id", "req-1", "capabilities"},
want: "capabilities",
},
{
name: "only globals",
args: []string{"--json"},
want: "",
},
{name: "long flag", args: []string{"--version"}, want: true},
{name: "short flag", args: []string{"-v"}, want: true},
{name: "no args", args: nil, want: false},
{name: "unexpected command", args: []string{"status"}, want: false},
{name: "extra args after version", args: []string{"--version", "status"}, want: false},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := firstCLIArg(tt.args); got != tt.want {
t.Fatalf("firstCLIArg() = %q, want %q", got, tt.want)
if got := isVersionInvocation(tt.args); got != tt.want {
t.Fatalf("isVersionInvocation() = %v, want %v", got, tt.want)
}
})
}
}

func TestClassifyInvocation(t *testing.T) {
func TestUnsupportedInvocationMessage(t *testing.T) {
tests := []struct {
name string
args []string
wantSub string
wantErr bool
name string
arg string
want string
}{
{
name: "global-only",
args: []string{"--json"},
wantSub: "",
},
{
name: "global-prefix-with-subcommand",
args: []string{"--json", "status"},
wantSub: "status",
},
{
name: "malformed-timeout",
args: []string{"--timeout=abc"},
wantErr: true,
},
{name: "unexpected command", arg: "status", want: `unexpected argument "status"`},
{name: "tui subcommand hint", arg: "tui", want: "run `amux` directly to start the terminal UI"},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotSub, err := classifyInvocation(tt.args)
if tt.wantErr {
if err == nil {
t.Fatalf("classifyInvocation() expected error, got nil")
}
return
}
if err != nil {
t.Fatalf("classifyInvocation() unexpected error: %v", err)
}
if gotSub != tt.wantSub {
t.Fatalf("classifyInvocation() = %q, want %q", gotSub, tt.wantSub)
if got := unsupportedInvocationMessage(tt.arg); !strings.Contains(got, tt.want) {
t.Fatalf("unsupportedInvocationMessage() = %q, want substring %q", got, tt.want)
}
})
}
}

func TestHandleNoSubcommandNonTTYRoutesThroughCLIJSON(t *testing.T) {
code, stdout, stderr := runHandleNoSubcommandCaptured(t, []string{"--json"}, false)
if code != 2 {
t.Fatalf("handleNoSubcommand() code = %d, want 2", code)
}
if strings.TrimSpace(stderr) != "" {
t.Fatalf("expected empty stderr in --json mode, got %q", stderr)
}

var env struct {
OK bool `json:"ok"`
Error *struct {
Code string `json:"code"`
} `json:"error"`
}
if err := json.Unmarshal([]byte(stdout), &env); err != nil {
t.Fatalf("json.Unmarshal() error = %v\nraw: %s", err, stdout)
}
if env.OK {
t.Fatalf("expected ok=false")
}
if env.Error == nil || env.Error.Code != "usage_error" {
t.Fatalf("expected usage_error, got %#v", env.Error)
}
}

func TestHandleNoSubcommandTTYSignalsTUIFlow(t *testing.T) {
handled, code := handleNoSubcommand(nil, true)
if handled {
t.Fatalf("expected handled=false when stdin is a TTY")
}
if code != 0 {
t.Fatalf("expected code=0 for TTY path, got %d", code)
}
}

func TestHandleNoSubcommandTTYWithJSONRoutesThroughCLI(t *testing.T) {
code, stdout, stderr := runHandleNoSubcommandCaptured(t, []string{"--json"}, true)
if code != 2 {
t.Fatalf("handleNoSubcommand() code = %d, want 2", code)
}
if strings.TrimSpace(stderr) != "" {
t.Fatalf("expected empty stderr in --json mode, got %q", stderr)
}

var env struct {
OK bool `json:"ok"`
Error *struct {
Code string `json:"code"`
} `json:"error"`
}
if err := json.Unmarshal([]byte(stdout), &env); err != nil {
t.Fatalf("json.Unmarshal() error = %v\nraw: %s", err, stdout)
}
if env.OK {
t.Fatalf("expected ok=false")
}
if env.Error == nil || env.Error.Code != "usage_error" {
t.Fatalf("expected usage_error, got %#v", env.Error)
func TestNonInteractiveMessage(t *testing.T) {
if got := nonInteractiveMessage(); !strings.Contains(got, "interactive terminal") {
t.Fatalf("nonInteractiveMessage() = %q, want interactive-terminal guidance", got)
}
}

Expand All @@ -206,34 +98,10 @@ func TestShouldLaunchTUIRequiresAllTTYStreams(t *testing.T) {
stderrTTY bool
want bool
}{
{
name: "all tty",
stdinTTY: true,
stdoutTTY: true,
stderrTTY: true,
want: true,
},
{
name: "stdout redirected",
stdinTTY: true,
stdoutTTY: false,
stderrTTY: true,
want: false,
},
{
name: "stdin non tty",
stdinTTY: false,
stdoutTTY: true,
stderrTTY: true,
want: false,
},
{
name: "stderr non tty",
stdinTTY: true,
stdoutTTY: true,
stderrTTY: false,
want: false,
},
{name: "all tty", stdinTTY: true, stdoutTTY: true, stderrTTY: true, want: true},
{name: "stdout redirected", stdinTTY: true, stdoutTTY: false, stderrTTY: true, want: false},
{name: "stdin non tty", stdinTTY: false, stdoutTTY: true, stderrTTY: true, want: false},
{name: "stderr non tty", stdinTTY: true, stdoutTTY: true, stderrTTY: false, want: false},
}

for _, tt := range tests {
Expand All @@ -244,45 +112,3 @@ func TestShouldLaunchTUIRequiresAllTTYStreams(t *testing.T) {
})
}
}

func runHandleNoSubcommandCaptured(t *testing.T, args []string, stdinIsTTY bool) (int, string, string) {
t.Helper()

origStdout := os.Stdout
origStderr := os.Stderr
stdoutR, stdoutW, err := os.Pipe()
if err != nil {
t.Fatalf("os.Pipe(stdout) error = %v", err)
}
stderrR, stderrW, err := os.Pipe()
if err != nil {
t.Fatalf("os.Pipe(stderr) error = %v", err)
}
os.Stdout = stdoutW
os.Stderr = stderrW
defer func() {
os.Stdout = origStdout
os.Stderr = origStderr
}()

handled, code := handleNoSubcommand(args, stdinIsTTY)
if !handled {
t.Fatalf("expected handled=true for non-TTY path")
}

_ = stdoutW.Close()
_ = stderrW.Close()

stdoutBytes, readStdoutErr := io.ReadAll(stdoutR)
if readStdoutErr != nil {
t.Fatalf("io.ReadAll(stdout) error = %v", readStdoutErr)
}
stderrBytes, readStderrErr := io.ReadAll(stderrR)
if readStderrErr != nil {
t.Fatalf("io.ReadAll(stderr) error = %v", readStderrErr)
}
_ = stdoutR.Close()
_ = stderrR.Close()

return code, string(stdoutBytes), string(stderrBytes)
}
Loading
Loading