Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
118 changes: 102 additions & 16 deletions web/backend/api/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,14 @@ import (

// gateway holds the state for the managed gateway process.
var gateway = struct {
mu sync.Mutex
cmd *exec.Cmd
owned bool // true if we started the process, false if we attached to an existing one
bootDefaultModel string
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
mu sync.Mutex
cmd *exec.Cmd
owned bool // true if we started the process, false if we attached to an existing one
bootDefaultModel string
bootConfigSignature string
runtimeStatus string
startupDeadline time.Time
logs *LogBuffer
}{
runtimeStatus: "stopped",
logs: NewLogBuffer(200),
Expand Down Expand Up @@ -177,14 +178,93 @@ func lookupModelConfig(cfg *config.Config, modelName string) *config.ModelConfig
return modelCfg
}

func gatewayRestartRequired(configDefaultModel, bootDefaultModel, gatewayStatus string) bool {
func computeConfigSignature(cfg *config.Config) string {
if cfg == nil {
return ""
}
var parts []string
defaultModel := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
if defaultModel != "" {
parts = append(parts, "model:"+defaultModel)
}
toolSignatures := []string{}
if cfg.Tools.ReadFile.Enabled {
toolSignatures = append(toolSignatures, "read_file")
}
if cfg.Tools.WriteFile.Enabled {
toolSignatures = append(toolSignatures, "write_file")
}
if cfg.Tools.ListDir.Enabled {
toolSignatures = append(toolSignatures, "list_dir")
}
if cfg.Tools.EditFile.Enabled {
toolSignatures = append(toolSignatures, "edit_file")
}
if cfg.Tools.AppendFile.Enabled {
toolSignatures = append(toolSignatures, "append_file")
}
if cfg.Tools.Exec.Enabled {
toolSignatures = append(toolSignatures, "exec")
}
if cfg.Tools.Cron.Enabled {
toolSignatures = append(toolSignatures, "cron")
}
if cfg.Tools.Web.Enabled {
toolSignatures = append(toolSignatures, "web")
}
if cfg.Tools.WebFetch.Enabled {
toolSignatures = append(toolSignatures, "web_fetch")
}
if cfg.Tools.Message.Enabled {
toolSignatures = append(toolSignatures, "message")
}
if cfg.Tools.SendFile.Enabled {
toolSignatures = append(toolSignatures, "send_file")
}
if cfg.Tools.FindSkills.Enabled {
toolSignatures = append(toolSignatures, "find_skills")
}
if cfg.Tools.InstallSkill.Enabled {
toolSignatures = append(toolSignatures, "install_skill")
}
if cfg.Tools.Spawn.Enabled {
toolSignatures = append(toolSignatures, "spawn")
}
if cfg.Tools.SpawnStatus.Enabled {
toolSignatures = append(toolSignatures, "spawn_status")
}
if cfg.Tools.I2C.Enabled {
toolSignatures = append(toolSignatures, "i2c")
}
if cfg.Tools.SPI.Enabled {
toolSignatures = append(toolSignatures, "spi")
}
if cfg.Tools.MCP.Enabled {
toolSignatures = append(toolSignatures, "mcp")
}
if cfg.Tools.MCP.Discovery.Enabled {
toolSignatures = append(toolSignatures, "mcp_discovery")
}
if cfg.Tools.MCP.Discovery.UseRegex {
toolSignatures = append(toolSignatures, "mcp_discovery_regex")
}
if cfg.Tools.MCP.Discovery.UseBM25 {
toolSignatures = append(toolSignatures, "mcp_discovery_bm25")
}
if len(toolSignatures) > 0 {
parts = append(parts, "tools:"+strings.Join(toolSignatures, ","))
}
return strings.Join(parts, ";")
}

func gatewayRestartRequiredBySignature(bootSignature, currentSignature, gatewayStatus string) bool {
if gatewayStatus != "running" {
return false
}
if strings.TrimSpace(configDefaultModel) == "" || strings.TrimSpace(bootDefaultModel) == "" {
if bootSignature == "" || currentSignature == "" {
return false
}
return configDefaultModel != bootDefaultModel
return bootSignature != currentSignature
}

func isCmdProcessAliveLocked(cmd *exec.Cmd) bool {
Expand Down Expand Up @@ -228,10 +308,11 @@ func attachToGatewayProcessLocked(pid int, cfg *config.Config) error {
gateway.owned = false // We didn't start this process
setGatewayRuntimeStatusLocked("running")

// Update bootDefaultModel from config
// Update bootDefaultModel and bootConfigSignature from config
if cfg != nil {
defaultModelName := strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
gateway.bootDefaultModel = defaultModelName
gateway.bootConfigSignature = computeConfigSignature(cfg)
}

logger.InfoC("gateway", fmt.Sprintf("Attached to gateway process (PID: %d)", pid))
Expand Down Expand Up @@ -419,6 +500,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
gateway.cmd = cmd
gateway.owned = true // We started this process
gateway.bootDefaultModel = defaultModelName
gateway.bootConfigSignature = computeConfigSignature(cfg)
setGatewayRuntimeStatusLocked(initialStatus)
pid = cmd.Process.Pid
logger.InfoC("gateway", fmt.Sprintf("Started picoclaw gateway (PID: %d) from %s", pid, execPath))
Expand All @@ -439,6 +521,7 @@ func (h *Handler) startGatewayLocked(initialStatus string, existingPid int) (int
if gateway.cmd == cmd {
gateway.cmd = nil
gateway.bootDefaultModel = ""
gateway.bootConfigSignature = ""
if gateway.runtimeStatus != "restarting" {
setGatewayRuntimeStatusLocked("stopped")
}
Expand Down Expand Up @@ -713,7 +796,7 @@ func (h *Handler) handleGatewayStatus(w http.ResponseWriter, r *http.Request) {

func (h *Handler) gatewayStatusData() map[string]any {
data := map[string]any{}
configDefaultModel := ""
var configDefaultModel string
cfg, cfgErr := config.LoadConfig(h.configPath)
if cfgErr == nil && cfg != nil {
configDefaultModel = strings.TrimSpace(cfg.Agents.Defaults.GetModelName())
Expand Down Expand Up @@ -784,11 +867,14 @@ func (h *Handler) gatewayStatusData() map[string]any {
}
}

bootDefaultModel, _ := data["boot_default_model"].(string)
gatewayStatus, _ := data["gateway_status"].(string)
data["gateway_restart_required"] = gatewayRestartRequired(
configDefaultModel,
bootDefaultModel,
currentConfigSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
bootConfigSignature := gateway.bootConfigSignature
gateway.mu.Unlock()
data["gateway_restart_required"] = gatewayRestartRequiredBySignature(
bootConfigSignature,
currentConfigSignature,
gatewayStatus,
)

Expand Down
185 changes: 185 additions & 0 deletions web/backend/api/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ func resetGatewayTestState(t *testing.T) {
gateway.mu.Lock()
gateway.cmd = nil
gateway.bootDefaultModel = ""
gateway.bootConfigSignature = ""
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()
})
Expand Down Expand Up @@ -502,9 +503,11 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
t.Fatalf("FindProcess() error = %v", err)
}

bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
gateway.bootConfigSignature = bootSignature
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()

Expand Down Expand Up @@ -548,6 +551,188 @@ func TestGatewayStatusRequiresRestartAfterDefaultModelChange(t *testing.T) {
}
}

func TestGatewayStatusRequiresRestartAfterToolChange(t *testing.T) {
resetGatewayTestState(t)

configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].SetAPIKey("test-key")
cfg.Tools.WriteFile.Enabled = true
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)

process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}

bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
gateway.bootConfigSignature = bootSignature
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()

updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Tools.WriteFile.Enabled = false
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}

if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["gateway_restart_required"]; got != true {
t.Fatalf("gateway_restart_required = %#v, want true", got)
}
}

func TestGatewayStatusNoRestartRequiredForNonSensitiveChanges(t *testing.T) {
resetGatewayTestState(t)

configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].SetAPIKey("test-key")
cfg.Agents.Defaults.MaxTokens = 1000
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)

process, err := os.FindProcess(os.Getpid())
if err != nil {
t.Fatalf("FindProcess() error = %v", err)
}

bootSignature := computeConfigSignature(cfg)
gateway.mu.Lock()
gateway.cmd = &exec.Cmd{Process: process}
gateway.bootDefaultModel = cfg.ModelList[0].ModelName
gateway.bootConfigSignature = bootSignature
setGatewayRuntimeStatusLocked("running")
gateway.mu.Unlock()

updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Agents.Defaults.MaxTokens = 2000
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return mockGatewayHealthResponse(http.StatusOK, os.Getpid()), nil
}

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}

if got := body["gateway_status"]; got != "running" {
t.Fatalf("gateway_status = %#v, want %q", got, "running")
}
if got := body["gateway_restart_required"]; got != false {
t.Fatalf("gateway_restart_required = %#v, want false", got)
}
}

func TestGatewayStatusNoRestartRequiredWhenNotRunning(t *testing.T) {
resetGatewayTestState(t)

configPath := filepath.Join(t.TempDir(), "config.json")
cfg := config.DefaultConfig()
cfg.Agents.Defaults.ModelName = cfg.ModelList[0].ModelName
cfg.ModelList[0].SetAPIKey("test-key")
if err := config.SaveConfig(configPath, cfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

h := NewHandler(configPath)
mux := http.NewServeMux()
h.RegisterRoutes(mux)

gateway.mu.Lock()
gateway.cmd = nil
gateway.bootDefaultModel = ""
gateway.bootConfigSignature = ""
setGatewayRuntimeStatusLocked("stopped")
gateway.mu.Unlock()

updatedCfg, err := config.LoadConfig(configPath)
if err != nil {
t.Fatalf("LoadConfig() error = %v", err)
}
updatedCfg.Agents.Defaults.ModelName = "different-model"
if err := config.SaveConfig(configPath, updatedCfg); err != nil {
t.Fatalf("SaveConfig() error = %v", err)
}

gatewayHealthGet = func(string, time.Duration) (*http.Response, error) {
return nil, errors.New("no gateway running")
}

rec := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/api/gateway/status", nil)
mux.ServeHTTP(rec, req)

if rec.Code != http.StatusOK {
t.Fatalf("status = %d, want %d", rec.Code, http.StatusOK)
}

var body map[string]any
if err := json.Unmarshal(rec.Body.Bytes(), &body); err != nil {
t.Fatalf("unmarshal response: %v", err)
}

if got := body["gateway_status"]; got != "stopped" {
t.Fatalf("gateway_status = %#v, want %q", got, "stopped")
}
if got := body["gateway_restart_required"]; got != false {
t.Fatalf("gateway_restart_required = %#v, want false", got)
}
}

func TestGatewayStatusReturnsErrorAfterStartupWindowExpires(t *testing.T) {
resetGatewayTestState(t)

Expand Down
Loading
Loading