diff --git a/.gitignore b/.gitignore index 47bf342..72a17fc 100644 --- a/.gitignore +++ b/.gitignore @@ -38,6 +38,6 @@ Thumbs.db *.log # Runner workdirs (local dev) -runners/ -cache/ -state/ +/runners/ +/cache/ +/state/ diff --git a/README.md b/README.md index 4ea81ad..1984617 100644 --- a/README.md +++ b/README.md @@ -101,15 +101,22 @@ golangci-lint run # lint (if installed) ``` ghr/ -├── cmd/ghr/main.go # Entrypoint +├── cmd/ghr/main.go # Entrypoint ├── internal/ -│ ├── cli/ # Cobra commands -│ ├── auth/ # Credentials management -│ ├── config/ # YAML + env config -│ ├── runner/ # Binary download & process lifecycle -│ ├── github/ # Scale set SDK adapter -│ ├── model/ # Shared data structs -│ └── logging/ # Structured logging +│ ├── cli/ # Cobra commands (start/stop/run/status/purge/login/...) +│ ├── auth/ # Credentials, JWT signing, installation tokens, breaker +│ ├── config/ # YAML + env loading, validation, defaults +│ ├── controller/ # Scale-set orchestration and per-group scaler +│ ├── runner/ # Binary download (with SHA-256 verify) and process lifecycle +│ ├── github/ # scaleset SDK adapter +│ ├── health/ # Health monitor and check functions +│ ├── notification/ # Discord + webhook providers with filtering +│ ├── monitoring/ # Uptime Kuma push reporter +│ ├── api/ # Unix-socket JSON API exposing status/health +│ ├── launchd/ # macOS service install/uninstall via bootstrap/bootout +│ ├── state/ # Centralized daemon-state file paths (pid/sock/state) +│ ├── model/ # Shared data structs (no logic) +│ └── logging/ # Structured logging, rotation, tagged runner output ├── go.mod └── go.sum ``` diff --git a/audit.md b/audit.md new file mode 100644 index 0000000..cbfc1ef --- /dev/null +++ b/audit.md @@ -0,0 +1,798 @@ +# Audit ghr — Code Review & Plan d'amélioration + +> **Périmètre** : audit complet du dépôt `gh-runners-tool` (~5 537 LOC Go hors tests, 4 347 LOC de tests). Toutes les références code utilisent le format `path:line`. 
Les recommandations sont classées par sévérité ; les features proposées sont en fin de document. +> +> **Méthodologie** : exploration symbolique via Serena, lecture ciblée des chemins critiques (auth, runner, controller, api, launchd), revue croisée avec les rules du projet (`.claude/rules/security.md`, `architecture.md`, `code-cleanliness.md`, `go-style.md`) et les conventions de `CLAUDE.md`. + +--- + +## 0. Synthèse exécutive + +### Forces + +- Architecture **package-by-feature** propre, interfaces consumer-side, DI manuelle lisible dans `cmd/ghr/main.go` → `internal/cli/daemon.go:buildDaemon`. +- Bon usage de `oklog/run` pour le lifecycle daemon. +- Structure de tests honnête sur les packages bas-niveau (auth, runner, notification, logging, config, health). +- Conventions de logging structuré (`log/slog`) cohérentes, avec rotation par date et multi-handler. +- Linter strict (`gocritic`, `errorlint`, `nilerr`, `prealloc`, `unparam`, `exhaustive`, `contextcheck`…) et `govulncheck` câblé. +- Configuration YAML + env propre, avec validation explicite et messages d'erreur agrégés (`errors.Join`). +- Retry exponentiel sur le listener de groupe (`internal/controller/group.go:nextBackoff`). + +### Faiblesses majeures + +1. **Path traversal exploitable** dans l'extraction tar du runner (`internal/runner/download.go:69-71`) — une archive malicieusement fabriquée peut créer un symlink hors du dossier de cache. +2. **Aucune vérification d'intégrité** du binaire runner téléchargé (pas de SHA-256, pas de signature). Le code fait confiance aveugle au tarball GitHub. +3. **`pgrep -f workdirBase`** (`internal/runner/cleanup.go:KillOrphanRunners`) peut tuer des processus utilisateur non-ghr si `workdir_base` est court ou trop large (ex. `/tmp`). +4. **Métriques de santé jamais alimentées** : `UpdateGroupStats`, `RecordStartFailure`, `RecordStartSuccess` n'ont aucun call-site en production → `checkGroupDivergence` et `checkConsecutiveFailures` ne déclenchent jamais. 
Sécurité by-design désactivée silencieusement. +5. **Notifications synchrones sous mutex** dans `internal/health/checks.go:runChecks` — un Discord lent bloque toute la boucle de health. +6. **Race condition sur le cache** : 2 groupes lançant `EnsureBits` en parallèle pour la même version peuvent corrompre le cache. Détection « cached » basée sur `run.sh` qui apparaît avant la fin de l'extraction. +7. **Unix socket sans permissions explicites** (`internal/api/server.go:Run`) — autres utilisateurs locaux peuvent lire le statut + PIDs. +8. **Aucune commande de réelle observabilité** : pas de `/metrics`, pas d'audit log, pas de `ghr logs`. + +### Vue d'ensemble (heatmap) + +| Catégorie | Critique | Haute | Moyenne | Basse | +|-------------------|:--------:|:-----:|:-------:|:-----:| +| Sécurité | 3 | 6 | 7 | 5 | +| Bugs / Correctness| 2 | 5 | 8 | 6 | +| Architecture | 0 | 2 | 9 | 7 | +| Résilience | 1 | 4 | 6 | 3 | +| Tests | 0 | 3 | 5 | 2 | +| Observabilité | 1 | 3 | 4 | 2 | +| Doc | 0 | 1 | 3 | 4 | + +--- + +## I. Sécurité + +### I.11 🟠 MOYENNE — `http.Server` sans Read/Write/Idle timeouts + +**Fichier** : `internal/api/server.go:53-55`. + +```go +srv := &http.Server{ Handler: s.routes() } +``` + +Pas de protection contre slowloris (faible risque sur Unix socket local mais bonne pratique). + +```go +srv := &http.Server{ + Handler: s.routes(), + ReadHeaderTimeout: 5 * time.Second, + ReadTimeout: 10 * time.Second, + WriteTimeout: 10 * time.Second, + IdleTimeout: 30 * time.Second, +} +``` + +### I.12 🟠 MOYENNE — Pas de vérification de permissions sur le credentials file à la lecture + +**Fichier** : `internal/auth/store.go:loadFromFile:25-35`. + +`os.ReadFile` accepte n'importe quelles permissions. Pendant ce temps `LoadPrivateKey` (`jwt.go:30-46`) refuse les permissions trop larges. Asymétrie : un attaquant local qui a un `chmod 644` accidentel sur les credentials passe inaperçu. 
+ +**Recommandation** : dans `loadFromFile`, faire un `os.Stat` et émettre un warning (pas une erreur) si `mode & 0o077 != 0`. + +### I.14 🟠 MOYENNE — Erreurs PAT contiennent le body HTTP brut + +**Fichier** : `internal/auth/validate.go:46-48`. + +```go +return nil, fmt.Errorf("validate PAT: GitHub API returned %d: %s", resp.StatusCode, string(body)) +``` + +Un body GitHub d'erreur peut inclure les headers de rate-limit, l'IP, ou des en-têtes de debug. À tronquer comme le fait `installations.go:truncateBody` (déjà existant) → utiliser cette fonction partout. + +### I.15 🟠 MOYENNE — `parseScopes` retourne `nil` quand le header est absent + +**Fichier** : `internal/auth/validate.go:64-77`. + +Pour les **fine-grained PATs**, GitHub ne renvoie pas `X-OAuth-Scopes`. Le token semble valide même s'il n'a pas la permission `administration:write`. ghr ne découvrira l'erreur qu'à l'appel de `CreateScaleSet`. + +**Recommandation** : +1. Détecter `X-GitHub-Token-Type: github-pat` (fine-grained) et avertir l'opérateur que les scopes ne peuvent pas être vérifiés. +2. Tenter un `GET /installation/repositories` ou un appel ciblé (`GET /orgs/{org}/actions/runner-groups`) pour valider l'autorisation effective. + +### I.16 🟠 MOYENNE — `daemon.pid` et `daemon.state.json` en `0o644` + +**Fichier** : `internal/cli/daemon.go:writePIDFile:165-175`, `internal/cli/state.go:writeDaemonState:30-37`. + +PID et `config_path` sont des infos d'environnement modérément sensibles. `0o600` est suffisant et cohérent avec `credentials.json`. + +### I.17 🟡 BASSE — `installations.go:http.DefaultClient` + +Cf. I.5. Déjà couvert. + +### I.18 🟡 BASSE — `MaskedPAT` retourne `****` pour PAT court + +**Fichier** : `internal/auth/validate.go:96-101`. + +Le seuil `< 12` est arbitraire. Les PATs `ghp_` font 40 caractères, les fine-grained `github_pat_` sont plus longs. Aucun PAT légitime ne fait moins de 12 caractères, donc OK. Bonus : ajouter le préfixe de type (`ghp_`, `github_pat_`, `ghs_`) dans le masquage pour aider au debug. 
+ +### I.19 🟡 BASSE — `JWT exp` à 9 minutes au lieu du max 10 + +**Fichier** : `internal/auth/jwt.go:16-19`. OK marge raisonnable pour clock skew, à laisser tel quel. + +### I.20 🟡 BASSE — Goreleaser ne signe pas les binaires + +**Fichier** : `.goreleaser.yml`. + +Aucun bloc `signs:` ni `notarize:`. Sur macOS un binaire non-notarié déclenchera Gatekeeper. Et pas de signature cosign/GPG des assets. + +**Recommandation** : +```yaml +signs: + - cmd: cosign + args: ["sign-blob", "--yes", "--output-signature=${signature}", "${artifact}"] + artifacts: all +notarize: + macos: + - sign: { certificate: "{{ .Env.MACOS_SIGN_P12 }}", password: "{{ .Env.MACOS_SIGN_PASSWORD }}" } + notarize: { issuer_id: "...", key_id: "...", key: "..." } +``` + +### I.21 🟡 BASSE — Pas de `gosec` dans le pipeline CI + +`.github/workflows/ci.yml` n'exécute pas `gosec`. Aurait probablement attrapé I.3, I.7 et I.9. + +--- + +## II. Bugs / Correctness + +### II.1 ⛔ CRITIQUE — Métriques de santé jamais alimentées + +**Fichiers** : `internal/health/group_state.go:19-43` (`UpdateGroupStats`, `RecordStartFailure`, `RecordStartSuccess`). + +**Vérification** : +``` +$ grep -rn "UpdateGroupStats\|RecordStartFailure\|RecordStartSuccess" internal/ +internal/health/group_state.go: (définitions) +internal/health/group_state_test.go: (tests unitaires) +``` + +Aucun call-site en production. Conséquences : + +- `checkGroupDivergence` (`checks.go:165-195`) retourne immédiatement (`gs.lastDesiredCount == 0`). +- `checkConsecutiveFailures` (`checks.go:197-212`) ne peut jamais émettre l'event `EventHealthGroupFailing`. +- Le system reporting Discord/UptimeKuma marche, mais les events `health.group.failing` et `health.group.degraded` ne sortent jamais. + +**Recommandation** : +1. Dans `MacOSScaler.HandleDesiredRunnerCount` (`scaler.go:69-89`), appeler `m.healthMonitor.UpdateGroupStats(s.groupName, target)`. Cela demande de passer le `*health.Monitor` au scaler via une interface consumer-side `groupStatsReporter`. +2. 
Dans `startRunner` (`scaler_ops.go:12-55`), sur erreur de `Start`, appeler `RecordStartFailure`. Sur succès, `RecordStartSuccess`. + +```go +// scaler.go - new field +type groupStatsReporter interface { + UpdateGroupStats(group string, desired int) + RecordStartFailure(group string) + RecordStartSuccess(group string) +} +``` + +Sans ce câblage, deux features documentées du produit (divergence detection et consecutive-failure alert) sont du **vaporware**. + +`Discord.throttle()` sleep jusqu'à 2 s. `UptimeKuma.push` peut prendre 30 s en cas de timeout réseau. Pendant ce temps : +- `Monitor.Status()` (called by `/status` HTTP) est bloqué (RLock attendu). +- Le prochain tick `runChecks` accumule. + +**Recommandation** : collecter les notifications/reports localement, libérer le mutex, puis envoyer : + +```go +func (m *Monitor) runChecks(ctx context.Context) { + start := time.Now() + m.mu.Lock() + // ... compute snapshots & issues ... + pending := snapshotPendingNotifs() + m.mu.Unlock() + // dispatch async + go m.dispatchNotifications(ctx, pending) +} +``` + +### II.9 🟠 MOYENNE — Double check de `GITHUB_TOKEN` dans `auth.Load` + +**Fichier** : `internal/auth/load.go:31-36`. + +```go +if token := os.Getenv("GITHUB_TOKEN"); token != "" { // ligne 31, second check + return ..., "env (.env GITHUB_TOKEN)", nil +} +``` + +Le premier check ligne 16 attrape déjà `GITHUB_TOKEN`. La logique semble vouloir distinguer "défini avant" vs "défini par godotenv ailleurs" mais `godotenv` n'est pas appelé dans `Load`. Code mort. + +### II.10 🟠 MOYENNE — `validateGitHubApp` valide trop peu + +**Fichier** : `internal/auth/validate.go:79-94`. + +Ne vérifie que l'ouverture du fichier. Une clé corrompue passe ; l'opérateur découvre le bug à `SignAppJWT` au démarrage du daemon. + +**Recommandation** : appeler `LoadPrivateKey(app.PrivateKeyPath)` qui fait déjà parse RSA + perms. 
+ +### II.12 🟠 MOYENNE — `labelsChanged` détecte mais n'agit pas + +**Fichier** : `internal/controller/group.go:130-147`. + +```go +if labelsChanged(ss.Labels, labels) { + c.logger.WarnContext(ctx, "scale set label mismatch detected, ...") +} +return &resolvedScaleSet{...}, nil // on continue avec l'ancien +``` + +L'opérateur peut changer `labels:` dans la config, faire `ghr restart`, et croire que c'est appliqué. Ce n'est pas le cas : seul un warning, enterré dans les logs, le signale. + +**Recommandation** : +- Émettre un `model.Event{Type: EventConfigDrift, Level: LevelWarning, ...}` → notifié à Discord. +- Soit auto-recreate (DELETE + CREATE), soit fail-fast au démarrage. + +### II.18 🟡 BASSE — `interactiveApp` accepte URL host vide + +**Fichier** : `internal/cli/login_wizard.go:69-72`. + +L'utilisateur peut entrer "" → valeur remplacée par le défaut `https://github.com` dans `prepareAppLogin`. OK. Le wizard affiche le prompt `"GitHub host URL [https://github.com]"`, qui annonce un défaut visible ; vérification faite, `readLine` trim l'entrée et accepte une valeur vide (`login_wizard.go:107-108`). + +### II.19 🟡 BASSE — `nonInteractivePAT` ignore `--host` + +**Fichier** : `internal/cli/login.go:50-67`. + +Pour PAT, on n'a que `--url`. Cohérent. Mais pour GitHub Enterprise, l'URL d'org diffère du host API. Ajouter une note dans la doc ou un flag `--host` optionnel. + +### II.20 🟡 BASSE — `expandHome` accepte chemins absolus mais pas Windows `%USERPROFILE%` + +**Fichier** : `internal/cli/login_app.go:125-137`. + +Non-issue pour macOS. À ignorer. + +--- + +## III. Architecture / Maintenabilité + +### III.1 🔴 HAUTE — Fichiers au-delà de la limite 200 LOC + +`.claude/rules/code-cleanliness.md` : "Source files must stay under 200 LOC". 
+ +| Fichier | LOC | Statut | +|------------------------------------|-----|--------| +| internal/health/checks.go | 213 | dépasse | +| internal/controller/group.go | 191 | proche | +| internal/cli/daemon.go | 188 | proche | +| internal/controller/scaler.go | 187 | proche | +| internal/cli/purge.go | 181 | proche | +| internal/logging/manager.go | 180 | proche | + +**Recommandations** : +- Split `checks.go` en `checks_liveness.go`, `checks_timeouts.go`, `checks_divergence.go`. +- `daemon.go` → `daemon_build.go`, `daemon_pid.go`, `daemon_url.go`. +- `group.go` → `group_run.go`, `group_resolve.go`, `group_backoff.go`. +- `purge.go` → split `purge_daemon.go`, `purge_scalesets.go`, `purge_cleanup.go`. + +### III.2 🔴 HAUTE — Interface `scaleSetClient` à 7 méthodes + +**Fichier** : `internal/controller/controller.go:16-23`. + +`.claude/rules/architecture.md` : "Consumer-side interfaces are unexported (lowercase) and minimal (1-3 methods)". + +L'interface a 7 méthodes. Acceptable car single-concern (scale set operations), mais à minima la documenter comme "façade volontaire". + +**Recommandation** : si split, `scaleSetLifecycle` (Create/Get/Delete) + `scaleSetSession` (OpenSession/NewListener/GenerateJITConfig). + +### III.3 🟠 MOYENNE — Pas de spec dans `specs/` malgré CLAUDE.md + +`CLAUDE.md` documente : +> All specs in `specs/`. Read before implementing: +> - `00-architecture.md` ... +> - `01-core-scaleset.md` ... + +Mais aucun dossier `specs/` n'existe. Soit le doc est obsolète, soit les specs ont été supprimées. Conséquence : nouvel arrivant lit `CLAUDE.md`, cherche les specs, est perdu. + +**Recommandation** : régénérer les specs ou retirer la section de `CLAUDE.md`. + +### III.4 🟠 MOYENNE — README divergent du code + +**Fichier** : `README.md`. 
+ +```md +Repository Structure +├── internal/ +│ ├── cli/ # Cobra commands +│ ├── auth/ # Credentials management +│ ├── config/ # YAML + env config +│ ├── runner/ # Binary download & process lifecycle +│ ├── github/ # Scale set SDK adapter +│ ├── model/ # Shared data structs +│ └── logging/ # Structured logging +``` + +Manquent : `controller/`, `health/`, `notification/`, `monitoring/`, `api/`, `launchd/`. Et `internal/scaleset` n'existe pas (c'est `internal/github`). + +### III.5 🟠 MOYENNE — Linter exception `nilerr` pour `internal/cli/auth.go` + +**Fichier** : `.golangci.yml:64-65`. + +L'exception est intentionnelle (status command preserves exit 0 on validation errors pour scripting). À minima documenter la raison dans un commentaire au-dessus de la fonction (`auth.go:newAuthStatusCmd`). + +### III.7 🟠 MOYENNE — `Duration.MarshalYAML` non implémenté + +**Fichier** : `internal/config/types.go:96-98`. + +`UnmarshalYAML` existe (visible via overview), mais l'overview ne montre pas si `MarshalYAML` est correct. Symptôme : si on veut dumper la config résolue (commande `ghr config print` à venir), `Duration` deviendrait `0` au lieu de `"30s"`. + +### III.8 🟠 MOYENNE — Pas de `--dry-run` pour les commandes destructrices + +`purge`, `restart`, `stop --force` ne supportent pas `--dry-run`. Pour un outil qui touche aux processus système et au scale set GitHub, c'est précieux. + +### III.11 🟠 MOYENNE — `notification/service.go:Notify` séquentiel sur les providers + +**Fichier** : `internal/notification/service.go:40-54`. + +Si Discord prend 2 s, le webhook attend 2 s. Pas critique car il n'y a souvent qu'1 provider, mais avec 3 providers + retry, la latence cumule. + +**Recommandation** : dispatch parallèle avec `errgroup` ou `sync.WaitGroup`. + +### III.12 🟡 BASSE — `internal/model/event.go` mêle types et constants + +47 LOC mixant struct + 4 levels + 16 events. À garder pour le moment, c'est le "shared types" pattern correct. 
+ +### III.13 🟡 BASSE — `internal/notification/discord_payload.go` couleurs hardcodées + +`colorForLevel` (non lue ici mais évoquée) : à exposer en config si on veut customiser. + +### III.15 🟡 BASSE — `internal/launchd/service.go:Status` substring match + +**Fichier** : `internal/launchd/service.go:77-102`. + +```go +if !strings.Contains(line, label) { continue } +``` + +Si un autre service a un nom contenant `com.ghr.daemon`, faux positif. Fonction suivante check exact `fields[2] != label` qui rattrape — OK mais ordre des vérifications inversé. + +### III.16 🟡 BASSE — `RunnerSnapshot.PID` exposed dans l'API JSON + +Pour un Unix socket local c'est OK, mais si on expose un jour un HTTP authentifié, exposer les PIDs facilite l'exploitation. + +### III.17 🟡 BASSE — Pas de séparation interface/impl pour les notifications + +`Service` et `Provider` cohabitent dans `service.go`. OK mais évoluera mal avec d'autres providers (Slack, Teams, Telegram). Préparer le terrain en isolant `notification/internal/discord/`, `internal/webhook/`, etc. + +### III.18 🟡 BASSE — `pgrep -f` non documenté + +L'opérateur ne sait pas que `KillOrphanRunners` peut tuer ses processus si workdir_base est mal configuré. + +--- + +## IV. Résilience / Robustesse + +Doubling sans jitter → thundering herd si N groupes se cassent en même temps (panne GitHub → tous retry au même tick). + +**Recommandation** : ajouter ±20 % de jitter (`rand.Float64()`). + +### IV.11 🟡 BASSE — Pas de `--max-age` sur le cache + +Cf. IV.5. + +### IV.12 🟡 BASSE — Aucune coordination multi-instance + +Si deux daemons ghr tournent (par accident), ils gèrent les mêmes scale sets → comportement chaotique. Aucune leader election (lockfile, advisory file lock). + +**Recommandation** : `flock` sur `daemon.pid` au démarrage. Si déjà locké → exit avec message clair. + +--- + +## V. Tests + +### V.1 🔴 HAUTE — Aucun test sur la couche CLI + +**Fichiers** : `internal/cli/{login,start,stop,run,status,purge,restart,state,daemon}.go`. 
+ +Tous ces fichiers contiennent de la logique (validation flags, chemins, conditionals). Aucun test. Couverture estimée < 10 % sur `internal/cli/`. + +**Recommandation** : tests d'intégration via `cobra.Command.Execute()` avec args en table-driven, et FS mocké via `t.TempDir()`. + +### V.2 🔴 HAUTE — Pas de test sur l'extraction tar + +**Fichier** : `internal/runner/download.go` (extractTarGz, sanitizeTarPath, extractFile). + +Le code le plus exposé en sécurité n'a aucun test. Le bug I.1 (symlink traversal) aurait été détecté par un test couvrant le cas TypeSymlink hors path. + +**Recommandation** : table-driven tests avec tarballs forgés via `archive/tar` en mémoire : + +```go +tests := []struct { + name string + entries []tarEntry + wantErr bool +}{ + {"normal file", ..., false}, + {"path escape ../etc/passwd", ..., true}, + {"absolute path", ..., true}, + {"symlink to absolute", ..., true}, // <-- attrape I.1 + {"symlink with relative escape", ..., true}, +} +``` + +### V.3 🔴 HAUTE — Pas de test sur `controller/group.go` + +Reconnect logic, label drift, backoff — rien. + +### V.4 🟠 MOYENNE — `monitoring/uptimekuma.go` untested + +URL building, status mapping, push errors. Tests faciles avec `httptest.Server`. + +### V.5 🟠 MOYENNE — Pas de tests E2E + +`tests/complete/validate.sh` est un script bash isolé non exécuté en CI. + +**Recommandation** : un workflow CI `e2e.yml` qui sur macOS lance ghr en foreground, simule un job (curl POST /api), vérifie la création du scale set sur un repo de test. + +### V.6 🟠 MOYENNE — Pas de fuzz tests + +`config.ParseByteSize`, `auth.APIBaseURL`, `notification.matchesPattern` sont d'excellents candidats à `testing.F`. + +### V.7 🟠 MOYENNE — Pas de mock package + +Les tests utilisent des doubles à la main (ex. `controller/scaler_test.go`). Pour la durabilité, soit `gomock`, soit `testify/mock`, soit interfaces locales explicites. + +### V.8 🟡 BASSE — Pas de `t.Parallel()` + +Aucun fichier de test n'appelle `t.Parallel()`. 
Le run time monte vite ; sur CI macOS c'est sensible. + +### V.9 🟡 BASSE — Pas de coverage report + +Pas de `go test -coverprofile=` dans CI ni d'upload codecov. Impossible de prioriser. + +### V.10 🟡 BASSE — `internal/logging/logger_test.go` à 602 LOC + +Le test est plus long que le code testé (180 LOC). Probable redondance — à splitter par concern. + +--- + +## VI. Performance + +### VI.1 🟠 MOYENNE — `copyDir` lent au démarrage + +**Fichier** : `internal/runner/copy.go`. + +Pour chaque runner, on copie ~70 Mo (binaires `Runner.Listener`, `Runner.Worker`, `dotnet`, etc.) → ~200 ms par runner sur SSD, ~5 s sur HDD. Avec 10 runners, 50 s. + +**Recommandation** : hardlinks pour les binaires read-only : + +```go +if info.Mode().IsRegular() && info.Mode()&0o200 == 0 { os.Link(src, dst) /* read-only: hardlink */ +} else { /* writable or non-regular: copy */ } +``` + +Les action runners écrivent dans `_work/`, pas dans les binaires. Hardlink sûr pour ~99 % des fichiers. + +### VI.2 🟠 MOYENNE — Pas de parallélisation du DL multi-version + +Si on a `groups: [{version: 2.310}, {version: 2.311}]`, les DL se font séquentiellement dans chaque `runGroup`. OK pour 2-3 groupes, mauvais pour 20. + +### VI.3 🟡 BASSE — `Status` parse `launchctl list` line by line + +Acceptable pour < 100 services, OK. + +### VI.4 🟡 BASSE — `Discord.throttle()` bloque `mu` pendant le Sleep + +Acceptable (intentionnel), mais pourrait être implémenté avec `golang.org/x/time/rate` (rate.Limiter) pour libérer le mutex et permettre des sends concurrents. + +--- + +## VII. Observabilité + +### VII.1 ⛔ CRITIQUE — Pas de `/metrics` Prometheus + +Le daemon expose `/status` et `/health` mais aucun endpoint Prometheus. Pour un outil ops, c'est limitant : impossible de tracer `runners_idle`, `runners_busy`, `jobs_completed_total`, `github_api_latency_seconds`. + +**Recommandation** : ajouter `prometheus/client_golang` + un handler `/metrics` derrière un feature flag de config `monitoring.prometheus.enabled`. 
+ +```go +runnersGauge := prometheus.NewGaugeVec( + prometheus.GaugeOpts{Name: "ghr_runners", Help: "..."}, + []string{"group", "state"}) +prometheus.MustRegister(runnersGauge) +``` + +### VII.2 🔴 HAUTE — Pas de tracing OpenTelemetry + +Pour des incidents complexes (un job qui timeout vs un runner qui ne start pas vs un network glitch), un span trace serait précieux. Le SDK `actions/scaleset` n'expose pas de hooks OTel, mais on peut wrapper. + +### VII.3 🔴 HAUTE — Pas de log d'audit pour les actions admin + +Login, logout, purge, restart, kill — aucun log dédié structuré. Le daemon log dans `daemon/*.json` mais c'est mélangé. + +**Recommandation** : un logger dédié `audit/*.json` avec format `{timestamp, action, actor, target, result}`. + +### VII.4 🔴 HAUTE — Pas de commande `ghr logs` + +L'opérateur doit `cd ~/.local/share/ghr/logs/...` et `tail -f` à la main. Pour un CLI premium : + +```bash +ghr logs daemon # tail daemon +ghr logs group ci # tail group ci +ghr logs runner ci-abc # tail runner +ghr logs --follow +ghr logs --since 1h --grep "error" +``` + +### VII.5 🟠 MOYENNE — Pas de `ghr inspect <runner>` + +Pour debug, dump l'état d'un runner spécifique (PID, workdir, log path, started, jobs done). + +### VII.7 🟠 MOYENNE — Pas de rate-limit display + +`GET /user` retourne `X-RateLimit-Remaining` mais on ne l'expose nulle part. + +**Recommandation** : log `github.api.rate_limit_remaining` à chaque call (sampling 1/10 pour éviter le bruit). Notifier si < 100. + +### VII.8 🟡 BASSE — Pas de profile pprof + +Pour debug : `import _ "net/http/pprof"` derrière un feature flag dans la config. + +### VII.9 🟡 BASSE — Pas de format `text` pour les logs daemon file + +`fileHandler` n'utilise que JSON. OK car parsable, mais pour `tail -f`, c'est moins lisible que le format `text` du console handler. + +--- + +## VIII. CI / Tooling + +### VIII.1 🟠 MOYENNE — CI ne run pas `gosec` + +Cf. I.21. + +### VIII.2 🟠 MOYENNE — Pas de coverage gate + +Cf. V.9. 
+ +### VIII.3 🟠 MOYENNE — `go.mod`: deps `// indirect` + +``` +require ( + github.com/actions/scaleset v0.4.0 // indirect + ... +) +``` + +Toutes les dépendances apparaissent comme `// indirect`. C'est anormal pour un projet final : `cobra`, `oklog/run`, `joho/godotenv` sont importés directement. Probablement un résidu de `go get` (ou d'une édition manuelle de `go.mod`) non suivi d'un `go mod tidy` après un refactor. + +**Recommandation** : `go mod tidy` + commit. + +### VIII.4 🟠 MOYENNE — Pas de release-please / automatic versioning + +Goreleaser tire la version d'un tag. Mais pas de bot pour proposer le bump à partir des conventional commits. + +### VIII.5 🟡 BASSE — Pas de dependabot / renovate + +Risque de stagnation des deps (sécurité notamment sur `scaleset`). + +### VIII.6 🟡 BASSE — `Makefile`: pas de cible `e2e` + +À ajouter quand V.5 est résolu. + +### VIII.7 🟡 BASSE — Pas de pre-commit hooks (lefthook/husky-go) + +Optionnel, mais utile. + +--- + +## IX. Documentation + +### IX.2 🟠 MOYENNE — `CLAUDE.md` référence des specs absentes + +Cf. III.3. + +### IX.3 🟠 MOYENNE — Pas de page `ARCHITECTURE.md` + +Il manque des diagrammes de séquence pour : startup, runner provisioning, job lifecycle, shutdown. Le lecteur doit les reconstruire à partir du code. + +### IX.4 🟠 MOYENNE — Pas de troubleshooting guide + +"Mon runner ne se connecte pas" → où chercher ? Pas de doc. + +### IX.5 🟡 BASSE — Pas de `CONTRIBUTING.md` + +### IX.6 🟡 BASSE — Pas de `CHANGELOG.md` formel + +goreleaser génère un changelog par release, mais pas dans le repo. + +### IX.7 🟡 BASSE — Licence "Proprietary. All rights reserved." mais pas de `LICENSE` + +À clarifier (interne, MIT, BSL ?). + +--- + +## X. 
Propositions de features + +### X.1 Quick wins (1-2 jours) + +| Feature | Bénéfice | Effort | +|---------|----------|--------| +| `ghr config validate ` | CI lint pré-deploy | XS | +| `ghr config print` (résolu) | Debug config | XS | +| `ghr logs daemon\|group\|runner` | Ops UX | S | +| `ghr inspect ` | Debug | S | +| `--dry-run` pour `purge`, `restart`, `stop --force` | Sécurité ops | XS | +| `--max-age` sur cache binaries | Disk hygiene | XS | +| Audit log file séparé | Compliance | S | +| Notification level filter (`min_level: warn`) | UX notifs | XS | +| SHA256 verification des tarballs runner | Sécu (I.2) | S | +| Chmod 0o600 sur le socket | Sécu (I.4) | XS | + +### X.2 Medium (1-2 semaines) + +| Feature | Bénéfice | Effort | +|---------|----------|--------| +| Prometheus `/metrics` | Observabilité | M | +| Reload via SIGHUP | Ops UX (IV.4) | M | +| Keychain pour le PAT (macOS) | Sécu (I.6) | M | +| Circuit breaker GitHub API | Résilience (IV.2) | S | +| Drain mode (`ghr stop --drain`) | Ops UX | M | +| TUI dashboard (`ghr top`) en Bubble Tea | UX | M | +| Hardlinks au lieu de copy | Perf (VI.1) | S | +| Retry sur webhook/uptimekuma | Résilience (IV.3) | S | +| Linux support (testé, pas juste compile) | Adoption | M | +| Self-update (`ghr update`) | Ops UX | M | + +### X.3 Large (1+ mois) + +| Feature | Bénéfice | Effort | +|---------|----------|--------| +| Multi-runner-group ID resolution | Correctness (II.5) | M | +| Migration de `launchctl load` vers `bootstrap` | macOS forward-compat | M | +| GitHub Enterprise Server end-to-end | Adoption B2B | M | +| Dockerfile + Helm chart (Linux runners) | Adoption cloud | L | +| Web UI minimal (next.js) | Ops UX premium | L | +| Plugins externes (notif Slack, PagerDuty) via gRPC | Ecosystem | L | +| OIDC token issuance pour les runners | Sécu enterprise | L | +| Crash-only design + persistent state DB (BoltDB) | Résilience | L | +| Auto-scaling basé sur les métriques GitHub (jobs queued) | Cost optim | L | +| Mode 
"burst" (max éphémère sous pression, idle 0) | Cost optim | M | +| Distributed mode (plusieurs ghr coordonnés via etcd/Consul) | Scale-out | XL | + +### X.4 Polish & QoL + +- Couleurs sur `ghr status` (déjà en place via codes ANSI dans render — vérifier). +- `--no-color` flag global. +- Auto-completion shell (cobra le supporte). +- Man pages générées via `cobra/doc`. +- Homebrew tap. +- Telegram/Slack notification providers (`internal/notification/*`). + +--- + +## XI. Plan d'action priorisé + +### Phase 1 — Sécurité bloquante (1 sprint, ~5 j) + +1. ✅ Fix tar symlink traversal (I.1) — 2 h. +2. ✅ SHA-256 verification (I.2) — 4 h. +3. ✅ `pgrep` safety guards (I.3) — 4 h (+ valider workdir_base). +4. ✅ Chmod socket 0600 (I.4) — 30 min. +5. ✅ HTTP client timeouts globaux (I.5) — 2 h. +6. ✅ Plist XML escape (I.7) — 1 h. +7. ✅ `gosec` dans CI (I.21) — 1 h. + +### Phase 2 — Bugs correctness (1 sprint, ~5 j) + +8. ✅ Câbler `UpdateGroupStats`/`RecordStart{Failure,Success}` (II.1) — 1 j. +9. ✅ Lock cache versionné + marker `.complete` (II.2) — 4 h. +10. ✅ Loop variable explicite (II.4) — 30 min. +11. ✅ `RunnerGroupID` config-driven (II.5) — 1 j. +12. ✅ Notifications async sous lock (II.6) — 4 h. +13. ✅ Graceful API shutdown (II.7, I.10, I.11) — 2 h. +14. ✅ Process panic recovery (IV.1) — 2 h. +15. ✅ Tests tar extraction (V.2) — 1 j. + +### Phase 3 — Résilience & observabilité (2 sprints, ~10 j) + +16. ✅ Circuit breaker + retry generalized (IV.2, IV.3) — 2 j. +17. ✅ Reload SIGHUP (IV.4) — 2 j. +18. ✅ Cache GC (IV.5) — 4 h. +19. ✅ Liveness watchdog (IV.7) — 4 h. +20. ✅ `/metrics` Prometheus (VII.1) — 1 j. +21. ✅ Tracing OpenTelemetry initial (VII.2) — 2 j. +22. ✅ Audit log (VII.3) — 1 j. +23. ✅ `ghr logs` command (VII.4) — 1 j. + +### Phase 4 — Maintenabilité & tests (1 sprint, ~5 j) + +24. ✅ Split fichiers > 200 LOC (III.1) — 1 j. +25. ✅ Centraliser `state.Paths` (III.10) — 4 h. +26. ✅ Tests CLI (V.1) — 2 j. +27. ✅ Coverage gate dans CI (V.9) — 4 h. +28. 
✅ `go mod tidy` (VIII.3) — 30 min. +29. ✅ README sync (III.4) — 2 h. +30. ✅ Regen specs (III.3) — 1 j. + +### Phase 5 — Features quick wins (1 sprint, ~5 j) + +Cf. table X.1. + +### Estimation totale + +| Phase | Durée | Risque | +|-------|-------|--------| +| 1 | 5 j | bas | +| 2 | 5 j | moyen | +| 3 | 10 j | moyen | +| 4 | 5 j | bas | +| 5 | 5 j | bas | +| **Total** | **~6 semaines** | | + +--- + +## XII. Annexes + +### A. Tableau récapitulatif des findings + +| ID | Sévérité | Catégorie | Titre | Fichier:Ligne | +|-------|----------|---------------|--------------------------------------------------------|-------------------------------------| +| I.1 | Critique | Sécurité | Symlink traversal tar | runner/download.go:68 | +| I.2 | Critique | Sécurité | Pas de checksum tarball | runner/download.go:17 | +| I.3 | Critique | Sécurité | `pgrep -f` non sanitized | runner/cleanup.go:91 | +| II.1 | Critique | Bug | Stats santé jamais alimentées | health/group_state.go | +| II.2 | Critique | Bug | Race cache binaries | runner/binary.go:28 | +| IV.1 | Critique | Résilience | Pas de panic recovery | cli/run.go:74 | +| VII.1 | Critique | Observabilité | Pas de `/metrics` | — | +| I.4 | Haute | Sécurité | Unix socket lisible par tous | api/server.go:46 | +| I.5 | Haute | Sécurité | http.DefaultClient sans timeout | auth/installations.go | +| I.6 | Haute | Sécurité | Credentials clair | auth/store.go:37 | +| I.7 | Haute | Sécurité | Plist XML non échappé | launchd/plist.go:8 | +| I.8 | Haute | Sécurité | launchctl deprecated | launchd/launchctl.go | +| I.9 | Haute | Sécurité | copyDir symlinks | runner/copy.go:26 | +| II.3 | Haute | Bug | Validate erreurs verbeuses | auth/validate.go:46 | +| II.4 | Haute | Bug | Loop variable | controller/controller.go:75 | +| II.5 | Haute | Bug | RunnerGroupID hardcoded | cli/daemon.go:73 | +| II.6 | Haute | Bug | Notifs sous mutex | health/checks.go:11 | +| II.7 | Haute | Bug | API shutdown abrupt | api/server.go:62 | +| III.1 | Haute | Archi 
| Fichiers > 200 LOC | health/checks.go (et 5 autres) | +| III.2 | Haute | Archi | Interface scaleSetClient = 7 méthodes | controller/controller.go:16 | +| IV.2 | Haute | Résilience | Pas de circuit breaker | auth/*.go | +| IV.3 | Haute | Résilience | Webhook/UK sans retry | notification/webhook.go | +| IV.4 | Haute | Résilience | Pas de SIGHUP reload | cli/run.go | +| IV.5 | Haute | Résilience | Pas de cache GC | runner/binary.go | +| V.1 | Haute | Tests | Aucun test CLI | cli/*.go | +| V.2 | Haute | Tests | Aucun test tar extraction | runner/download.go | +| V.3 | Haute | Tests | Aucun test controller/group | controller/group.go | +| VII.2 | Haute | Observabilité | Pas de tracing | — | +| VII.3 | Haute | Observabilité | Pas d'audit log | — | +| VII.4 | Haute | Observabilité | Pas de `ghr logs` | cli/ | +| IX.1 | Haute | Doc | README divergent | README.md | +| ... | ... | ... | (voir corps du doc) | ... | + +### B. Commandes utiles pour valider après corrections + +```bash +# Lint complet +make lint && make vet && make fmt-check + +# Tests avec race + coverage +go test -race -count=1 -coverprofile=cover.out ./... +go tool cover -func=cover.out | tail -1 + +# Vuln check +make vuln + +# Static security +gosec -severity medium -confidence medium ./... + +# Tar extraction fuzz +go test -fuzz=FuzzExtractTarGz -fuzztime=30s ./internal/runner/ + +# Build + sanity check +make build && ./ghr version && ./ghr config validate tests/simple/config.yaml +``` + +### C. Références spec interne + +- `.claude/rules/security.md` — règles secrets/permissions +- `.claude/rules/architecture.md` — package-by-feature, interfaces consumer-side +- `.claude/rules/code-cleanliness.md` — 200 LOC max, no comments, no godoc +- `.claude/rules/go-style.md` — naming, errors, concurrency +- `CLAUDE.md` — vue d'ensemble projet + +--- + +**Fin de l'audit.** Document généré le 2026-05-16. 
À versionner et reviewer à chaque phase pour mettre à jour le statut des items (✅ done / ⏳ in progress / ❌ blocked). diff --git a/config.example.yaml b/config.example.yaml index 3651690..992c9ff 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -1,6 +1,7 @@ github: url: "https://github.com/my-org" runner_group: "default" + runner_group_id: 1 runner: version: "latest" diff --git a/internal/api/server.go b/internal/api/server.go index 6b7f2fa..0b082ce 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -8,10 +8,11 @@ import ( "net" "net/http" "os" - "path/filepath" + "time" "github.com/RedBoardDev/gh-runners-tool/v2/internal/health" "github.com/RedBoardDev/gh-runners-tool/v2/internal/model" + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" ) type controllerState interface { @@ -32,7 +33,7 @@ type Server struct { func NewServer(stateDir string, controller controllerState, healthProvider healthState, logger *slog.Logger) *Server { return &Server{ - socketPath: filepath.Join(stateDir, "ghr.sock"), + socketPath: state.New(stateDir).Socket(), controller: controller, health: healthProvider, logger: logger, @@ -50,6 +51,12 @@ func (s *Server) Run(ctx context.Context) error { } s.listener = ln + if chmodErr := os.Chmod(s.socketPath, 0o600); chmodErr != nil { + ln.Close() + _ = os.Remove(s.socketPath) + return fmt.Errorf("chmod socket %s: %w", s.socketPath, chmodErr) + } + srv := &http.Server{ Handler: s.routes(), } @@ -59,22 +66,23 @@ func (s *Server) Run(ctx context.Context) error { errCh <- srv.Serve(ln) }() + defer func() { + if cleanupErr := os.Remove(s.socketPath); cleanupErr != nil && !os.IsNotExist(cleanupErr) { + s.logger.Warn("failed to remove socket file", "path", s.socketPath, "error", cleanupErr) + } + }() + select { case <-ctx.Done(): - shutdownErr := srv.Close() - cleanupErr := os.Remove(s.socketPath) + shutdownCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), 5*time.Second) + defer cancel() + shutdownErr := 
srv.Shutdown(shutdownCtx) + <-errCh if shutdownErr != nil { return fmt.Errorf("shutdown api server: %w", shutdownErr) } - if cleanupErr != nil && !os.IsNotExist(cleanupErr) { - s.logger.Warn("failed to remove socket file", "path", s.socketPath, "error", cleanupErr) - } return nil case err := <-errCh: - cleanupErr := os.Remove(s.socketPath) - if cleanupErr != nil && !os.IsNotExist(cleanupErr) { - s.logger.Warn("failed to remove socket file", "path", s.socketPath, "error", cleanupErr) - } if errors.Is(err, http.ErrServerClosed) { return nil } diff --git a/internal/api/server_test.go b/internal/api/server_test.go index b1780d6..8074d02 100644 --- a/internal/api/server_test.go +++ b/internal/api/server_test.go @@ -1,11 +1,13 @@ package api import ( + "context" "encoding/json" "log/slog" "net/http" "net/http/httptest" "os" + "path/filepath" "testing" "time" @@ -191,3 +193,46 @@ func TestHandleStatus_EmptyGroups(t *testing.T) { t.Fatalf("expected 0 groups, got %d", len(body.Groups)) } } + +func TestServer_Run_SocketPermissions(t *testing.T) { + // Unix domain sockets on macOS cap at ~104 chars, so avoid t.TempDir() which + // returns long /var/folders/... paths. Use a short directory under /tmp. 
+ stateDir, err := os.MkdirTemp("/tmp", "ghr-sock-") + if err != nil { + t.Fatalf("mkdir temp: %v", err) + } + t.Cleanup(func() { os.RemoveAll(stateDir) }) + + srv := NewServer(stateDir, &mockController{}, &mockHealth{}, + slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{Level: slog.LevelError + 1}))) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { done <- srv.Run(ctx) }() + + socket := filepath.Join(stateDir, "ghr.sock") + deadline := time.Now().Add(2 * time.Second) + var info os.FileInfo + for time.Now().Before(deadline) { + if info, err = os.Stat(socket); err == nil { + break + } + time.Sleep(20 * time.Millisecond) + } + + if err != nil { + cancel() + <-done + t.Fatalf("stat socket: %v", err) + } + mode := info.Mode().Perm() + + cancel() + if err := <-done; err != nil { + t.Fatalf("server run: %v", err) + } + + if mode != 0o600 { + t.Fatalf("socket perm = %#o, want 0600", mode) + } +} diff --git a/internal/auth/auth_test.go b/internal/auth/auth_test.go index dbdaad8..f954963 100644 --- a/internal/auth/auth_test.go +++ b/internal/auth/auth_test.go @@ -3,6 +3,7 @@ package auth import ( "context" "encoding/json" + "io" "net/http" "net/http/httptest" "os" @@ -230,6 +231,39 @@ func TestSave_And_Load(t *testing.T) { } } +func TestLoad_WarnsOnLoosePermissions(t *testing.T) { + dir := t.TempDir() + credFile := filepath.Join(dir, "credentials.json") + t.Setenv("GHR_CREDENTIALS_FILE", credFile) + t.Setenv("GITHUB_TOKEN", "") + + creds := &Credentials{Method: "pat", PAT: "ghp_loose", GitHubURL: "https://github.com/x"} + if err := Save(creds); err != nil { + t.Fatalf("Save: %v", err) + } + if err := os.Chmod(credFile, 0o644); err != nil { + t.Fatalf("chmod: %v", err) + } + + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("pipe: %v", err) + } + origStderr := os.Stderr + os.Stderr = w + defer func() { os.Stderr = origStderr }() + + if _, _, loadErr := Load(LoadOpts{}); loadErr != nil { + t.Fatalf("Load: 
%v", loadErr) + } + w.Close() + + out, _ := io.ReadAll(r) + if !strings.Contains(string(out), "warning") || !strings.Contains(string(out), "chmod 600") { + t.Errorf("expected loose-perm warning, got: %q", out) + } +} + func TestSave_CreatesDirectory(t *testing.T) { dir := t.TempDir() nestedPath := filepath.Join(dir, "nested", "deep", "credentials.json") diff --git a/internal/auth/breaker.go b/internal/auth/breaker.go new file mode 100644 index 0000000..6f72eb5 --- /dev/null +++ b/internal/auth/breaker.go @@ -0,0 +1,81 @@ +package auth + +import ( + "errors" + "net/http" + "sync" + "time" +) + +// ErrCircuitOpen is returned by the circuit breaker while it is tripped. +var ErrCircuitOpen = errors.New("github API circuit open: too many consecutive failures") + +const ( + breakerFailureThreshold = 5 + breakerOpenDuration = 60 * time.Second +) + +type circuitBreaker struct { + mu sync.Mutex + consecutiveFails int + openedAt time.Time + clock func() time.Time +} + +func newCircuitBreaker() *circuitBreaker { + return &circuitBreaker{clock: time.Now} +} + +func (b *circuitBreaker) allow() bool { + b.mu.Lock() + defer b.mu.Unlock() + if b.openedAt.IsZero() { + return true + } + if b.clock().Sub(b.openedAt) >= breakerOpenDuration { + // Half-open: resume traffic one failure short of the threshold, so the next failure re-opens the breaker. + b.openedAt = time.Time{} + b.consecutiveFails = breakerFailureThreshold - 1 + return true + } + return false +} + +func (b *circuitBreaker) recordSuccess() { + b.mu.Lock() + defer b.mu.Unlock() + b.consecutiveFails = 0 + b.openedAt = time.Time{} +} + +func (b *circuitBreaker) recordFailure() { + b.mu.Lock() + defer b.mu.Unlock() + b.consecutiveFails++ + if b.consecutiveFails >= breakerFailureThreshold && b.openedAt.IsZero() { + b.openedAt = b.clock() + } +} + +func isBreakable(resp *http.Response, err error) bool { + if err != nil { + return true + } + return resp.StatusCode >= 500 +} + +var apiBreaker = newCircuitBreaker() + +// doGuarded routes the request through the package-level circuit breaker.
+func doGuarded(req *http.Request) (*http.Response, error) { + if !apiBreaker.allow() { + return nil, ErrCircuitOpen + } + resp, err := httpClient.Do(req) + if isBreakable(resp, err) { + apiBreaker.recordFailure() + } else { + apiBreaker.recordSuccess() + } + return resp, err +} diff --git a/internal/auth/breaker_test.go b/internal/auth/breaker_test.go new file mode 100644 index 0000000..e1070cf --- /dev/null +++ b/internal/auth/breaker_test.go @@ -0,0 +1,71 @@ +package auth + +import ( + "errors" + "net/http" + "testing" + "time" +) + +func TestCircuitBreaker_OpensAfterThreshold(t *testing.T) { + b := newCircuitBreaker() + for i := 0; i < breakerFailureThreshold; i++ { + if !b.allow() { + t.Fatalf("allow() = false before threshold, iteration %d", i) + } + b.recordFailure() + } + if b.allow() { + t.Error("allow() = true after threshold, want false") + } +} + +func TestCircuitBreaker_SuccessResets(t *testing.T) { + b := newCircuitBreaker() + for i := 0; i < breakerFailureThreshold-1; i++ { + b.recordFailure() + } + b.recordSuccess() + if !b.allow() { + t.Error("allow() = false after success reset") + } + for i := 0; i < breakerFailureThreshold-1; i++ { + b.recordFailure() + } + if !b.allow() { + t.Error("allow() = false below threshold") + } +} + +func TestCircuitBreaker_HalfOpenAfterTimeout(t *testing.T) { + b := newCircuitBreaker() + now := time.Now() + b.clock = func() time.Time { return now } + + for i := 0; i < breakerFailureThreshold; i++ { + b.recordFailure() + } + if b.allow() { + t.Fatal("breaker should be open") + } + + now = now.Add(breakerOpenDuration + time.Second) + if !b.allow() { + t.Error("breaker should allow a probe after open duration") + } +} + +func TestIsBreakable(t *testing.T) { + if !isBreakable(nil, errors.New("net err")) { + t.Error("network error should be breakable") + } + if !isBreakable(&http.Response{StatusCode: 503}, nil) { + t.Error("5xx should be breakable") + } + if isBreakable(&http.Response{StatusCode: 401}, nil) { + 
t.Error("4xx should not trip the breaker") + } + if isBreakable(&http.Response{StatusCode: 200}, nil) { + t.Error("2xx should not trip the breaker") + } +} diff --git a/internal/auth/credentials.go b/internal/auth/credentials.go index fc5997f..f95f17c 100644 --- a/internal/auth/credentials.go +++ b/internal/auth/credentials.go @@ -1,6 +1,9 @@ package auth -import "time" +import ( + "log/slog" + "time" +) type Credentials struct { Method string `json:"method"` @@ -14,6 +17,7 @@ type GitHubAppCreds struct { ClientID string `json:"client_id"` InstallationID int64 `json:"installation_id"` PrivateKeyPath string `json:"private_key_path"` + Account string `json:"account,omitempty"` } type LoadOpts struct { @@ -30,3 +34,35 @@ type ValidationResult struct { type githubUserResponse struct { Login string `json:"login"` } + +func (c *Credentials) LogValue() slog.Value { + if c == nil { + return slog.AnyValue(nil) + } + attrs := []slog.Attr{ + slog.String("method", c.Method), + slog.String("github_url", c.GitHubURL), + } + if c.PAT != "" { + attrs = append(attrs, slog.String("pat", MaskedPAT(c.PAT))) + } + if c.GitHubApp != nil { + attrs = append(attrs, slog.Any("github_app", c.GitHubApp)) + } + if !c.CreatedAt.IsZero() { + attrs = append(attrs, slog.Time("created_at", c.CreatedAt)) + } + return slog.GroupValue(attrs...) 
+} + +func (g *GitHubAppCreds) LogValue() slog.Value { + if g == nil { + return slog.AnyValue(nil) + } + return slog.GroupValue( + slog.String("client_id", g.ClientID), + slog.Int64("installation_id", g.InstallationID), + slog.String("private_key_path", g.PrivateKeyPath), + slog.String("account", g.Account), + ) +} diff --git a/internal/auth/credentials_test.go b/internal/auth/credentials_test.go new file mode 100644 index 0000000..bb20080 --- /dev/null +++ b/internal/auth/credentials_test.go @@ -0,0 +1,61 @@ +package auth + +import ( + "bytes" + "log/slog" + "strings" + "testing" +) + +func TestCredentials_LogValue_MasksPAT(t *testing.T) { + creds := &Credentials{ + Method: "pat", + GitHubURL: "https://github.com/example", + PAT: "ghp_abcdefghijklmnopqrstuvwxyz1234567890", + } + + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + logger.Info("loaded", "creds", creds) + + out := buf.String() + if strings.Contains(out, "ghp_abcdefghijklmnopqrstuvwxyz1234567890") { + t.Errorf("log output leaks raw PAT: %s", out) + } + if !strings.Contains(out, "ghp_") || !strings.Contains(out, "7890") { + t.Errorf("log output should contain masked PAT excerpt, got: %s", out) + } +} + +func TestCredentials_LogValue_OmitsEmptyPAT(t *testing.T) { + creds := &Credentials{ + Method: "github_app", + GitHubURL: "https://github.com/example", + GitHubApp: &GitHubAppCreds{ + ClientID: "Iv1.xxx", + InstallationID: 42, + PrivateKeyPath: "/etc/ghr/key.pem", + Account: "octocat", + }, + } + + var buf bytes.Buffer + logger := slog.New(slog.NewTextHandler(&buf, nil)) + logger.Info("loaded", "creds", creds) + + out := buf.String() + if strings.Contains(out, "pat=") { + t.Errorf("empty PAT must not appear in log output: %s", out) + } + if !strings.Contains(out, "client_id=Iv1.xxx") { + t.Errorf("github app fields missing from log output: %s", out) + } +} + +func TestCredentials_LogValue_NilSafe(t *testing.T) { + var c *Credentials + v := c.LogValue() + if v.Kind() != 
slog.KindAny { + t.Errorf("nil receiver should resolve to KindAny, got %v", v.Kind()) + } +} diff --git a/internal/auth/http.go b/internal/auth/http.go new file mode 100644 index 0000000..81ebcf5 --- /dev/null +++ b/internal/auth/http.go @@ -0,0 +1,26 @@ +package auth + +import ( + "io" + "net/http" + "time" +) + +const ( + httpTimeout = 30 * time.Second + maxBodyExcerpt = 500 +) + +var httpClient = &http.Client{Timeout: httpTimeout} + +func drainBody(resp *http.Response) { + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() +} + +func truncateBody(s string) string { + if len(s) > maxBodyExcerpt { + return s[:maxBodyExcerpt] + "..." + } + return s +} diff --git a/internal/auth/installations.go b/internal/auth/installations.go new file mode 100644 index 0000000..c6dad19 --- /dev/null +++ b/internal/auth/installations.go @@ -0,0 +1,153 @@ +package auth + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "strings" +) + +type Installation struct { + ID int64 + Account string + AccountType string + TargetType string + HTMLURL string +} + +type InstallationToken struct { + Token string + ExpiresAt string + Permissions map[string]string +} + +const ( + permAdministration = "administration" + permOrgRunners = "organization_self_hosted_runners" +) + +func ListAppInstallations(ctx context.Context, apiBaseURL, appJWT string) ([]Installation, error) { + endpoint := strings.TrimRight(apiBaseURL, "/") + "/app/installations?per_page=100" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, http.NoBody) + if err != nil { + return nil, fmt.Errorf("create installations request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+appJWT) + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := doGuarded(req) + if err != nil { + return nil, fmt.Errorf("list installations: %w", err) + } + defer drainBody(resp) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read 
installations response: %w", err) + } + if resp.StatusCode == http.StatusUnauthorized { + return nil, fmt.Errorf("list installations: GitHub rejected the JWT (check Client ID and private key belong to the same App)") + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("list installations: HTTP %d: %s", resp.StatusCode, truncateBody(string(body))) + } + + var raw []struct { + ID int64 `json:"id"` + Account struct { + Login string `json:"login"` + Type string `json:"type"` + } `json:"account"` + TargetType string `json:"target_type"` + HTMLURL string `json:"html_url"` + } + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("decode installations: %w", err) + } + + out := make([]Installation, 0, len(raw)) + for _, r := range raw { + out = append(out, Installation{ + ID: r.ID, + Account: r.Account.Login, + AccountType: r.Account.Type, + TargetType: r.TargetType, + HTMLURL: r.HTMLURL, + }) + } + return out, nil +} + +func IssueInstallationToken(ctx context.Context, apiBaseURL, appJWT string, installationID int64) (*InstallationToken, error) { + endpoint := fmt.Sprintf("%s/app/installations/%d/access_tokens", strings.TrimRight(apiBaseURL, "/"), installationID) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, http.NoBody) + if err != nil { + return nil, fmt.Errorf("create access_tokens request: %w", err) + } + req.Header.Set("Authorization", "Bearer "+appJWT) + req.Header.Set("Accept", "application/vnd.github+json") + + resp, err := doGuarded(req) + if err != nil { + return nil, fmt.Errorf("issue installation token: %w", err) + } + defer drainBody(resp) + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read access_tokens response: %w", err) + } + if resp.StatusCode != http.StatusCreated { + return nil, fmt.Errorf("issue installation token: HTTP %d: %s", resp.StatusCode, truncateBody(string(body))) + } + + var raw struct { + Token string `json:"token"` + ExpiresAt string 
`json:"expires_at"` + Permissions map[string]string `json:"permissions"` + } + if err := json.Unmarshal(body, &raw); err != nil { + return nil, fmt.Errorf("decode installation token: %w", err) + } + return &InstallationToken{ + Token: raw.Token, + ExpiresAt: raw.ExpiresAt, + Permissions: raw.Permissions, + }, nil +} + +func CheckRunnerPermissions(perms map[string]string) error { + if hasWrite(perms, permAdministration) || hasWrite(perms, permOrgRunners) { + return nil + } + return fmt.Errorf( + "GitHub App lacks runner permissions: enable %q OR %q with 'write' access in the App settings", + permAdministration, permOrgRunners, + ) +} + +func APIBaseURL(githubURL string) (string, error) { + if githubURL == "" { + return "https://api.github.com", nil + } + u, err := url.Parse(githubURL) + if err != nil { + return "", fmt.Errorf("parse github URL %q: %w", githubURL, err) + } + host := strings.ToLower(u.Host) + if host == "" { + return "", fmt.Errorf("github URL %q has no host", githubURL) + } + if host == "github.com" || host == "api.github.com" { + return "https://api.github.com", nil + } + return fmt.Sprintf("%s://%s/api/v3", u.Scheme, u.Host), nil +} + +func hasWrite(perms map[string]string, key string) bool { + v, ok := perms[key] + return ok && (v == "write" || v == "admin") +} diff --git a/internal/auth/installations_test.go b/internal/auth/installations_test.go new file mode 100644 index 0000000..a521f2f --- /dev/null +++ b/internal/auth/installations_test.go @@ -0,0 +1,196 @@ +package auth + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestListAppInstallations(t *testing.T) { + t.Run("returns installations on 200", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/app/installations" { + t.Errorf("path = %q", r.URL.Path) + } + if got := r.Header.Get("Authorization"); !strings.HasPrefix(got, "Bearer ") { + t.Errorf("Authorization = 
%q, want Bearer prefix", got) + } + w.Header().Set("Content-Type", "application/json") + _, _ = w.Write([]byte(`[ + {"id": 100, "account": {"login": "akord-securite", "type": "Organization"}, "target_type": "Organization", "html_url": "https://github.com/akord-securite"}, + {"id": 200, "account": {"login": "personal", "type": "User"}, "target_type": "User", "html_url": "https://github.com/personal"} + ]`)) + })) + defer srv.Close() + + got, err := ListAppInstallations(context.Background(), srv.URL, "fake-jwt") + if err != nil { + t.Fatalf("ListAppInstallations: %v", err) + } + if len(got) != 2 { + t.Fatalf("len = %d, want 2", len(got)) + } + if got[0].ID != 100 || got[0].Account != "akord-securite" || got[0].AccountType != "Organization" { + t.Errorf("got[0] = %+v", got[0]) + } + if got[1].ID != 200 || got[1].Account != "personal" { + t.Errorf("got[1] = %+v", got[1]) + } + }) + + t.Run("401 returns user-friendly error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusUnauthorized) + _, _ = w.Write([]byte(`{"message":"bad credentials"}`)) + })) + defer srv.Close() + + _, err := ListAppInstallations(context.Background(), srv.URL, "wrong-jwt") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "rejected the JWT") { + t.Errorf("error = %q, want contain 'rejected the JWT'", err) + } + }) + + t.Run("500 returns wrapped HTTP error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusInternalServerError) + _, _ = w.Write([]byte("oops")) + })) + defer srv.Close() + + _, err := ListAppInstallations(context.Background(), srv.URL, "x") + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 500") { + t.Errorf("error = %q, want contain 'HTTP 500'", err) + } + }) + + t.Run("empty list returns empty slice", func(t 
*testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _, _ = w.Write([]byte(`[]`)) + })) + defer srv.Close() + + got, err := ListAppInstallations(context.Background(), srv.URL, "x") + if err != nil { + t.Fatalf("ListAppInstallations: %v", err) + } + if len(got) != 0 { + t.Errorf("len = %d, want 0", len(got)) + } + }) +} + +func TestIssueInstallationToken(t *testing.T) { + t.Run("returns token and permissions on 201", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/app/installations/12345/access_tokens" { + t.Errorf("path = %q", r.URL.Path) + } + if r.Method != http.MethodPost { + t.Errorf("method = %q, want POST", r.Method) + } + w.WriteHeader(http.StatusCreated) + _, _ = w.Write([]byte(`{ + "token": "ghs_abc", + "expires_at": "2026-05-16T12:34:56Z", + "permissions": {"administration": "write", "metadata": "read"} + }`)) + })) + defer srv.Close() + + got, err := IssueInstallationToken(context.Background(), srv.URL, "jwt", 12345) + if err != nil { + t.Fatalf("IssueInstallationToken: %v", err) + } + if got.Token != "ghs_abc" { + t.Errorf("token = %q", got.Token) + } + if got.Permissions["administration"] != "write" { + t.Errorf("permissions = %v", got.Permissions) + } + }) + + t.Run("404 returns wrapped error", func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusNotFound) + _, _ = w.Write([]byte(`{"message":"not found"}`)) + })) + defer srv.Close() + + _, err := IssueInstallationToken(context.Background(), srv.URL, "jwt", 999) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "HTTP 404") { + t.Errorf("error = %q", err) + } + }) +} + +func TestCheckRunnerPermissions(t *testing.T) { + tests := []struct { + name string + perms map[string]string + wantErr bool + }{ + {"administration:write passes", 
map[string]string{"administration": "write"}, false}, + {"administration:admin passes", map[string]string{"administration": "admin"}, false}, + {"org runners write passes", map[string]string{"organization_self_hosted_runners": "write"}, false}, + {"administration:read fails", map[string]string{"administration": "read"}, true}, + {"only metadata fails", map[string]string{"metadata": "read"}, true}, + {"empty fails", map[string]string{}, true}, + {"nil fails", nil, true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := CheckRunnerPermissions(tc.perms) + if tc.wantErr && err == nil { + t.Error("expected error, got nil") + } + if !tc.wantErr && err != nil { + t.Errorf("unexpected error: %v", err) + } + }) + } +} + +func TestAPIBaseURL(t *testing.T) { + tests := []struct { + name string + githubURL string + want string + wantErr bool + }{ + {"empty defaults to github.com API", "", "https://api.github.com", false}, + {"github.com org", "https://github.com/akord-securite", "https://api.github.com", false}, + {"github.com repo", "https://github.com/akord-securite/repo", "https://api.github.com", false}, + {"GHES org", "https://ghe.corp.example/myorg", "https://ghe.corp.example/api/v3", false}, + {"invalid URL", "://broken", "", true}, + {"missing host", "noscheme", "", true}, + } + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := APIBaseURL(tc.githubURL) + if tc.wantErr { + if err == nil { + t.Errorf("expected error for %q, got nil (result=%q)", tc.githubURL, got) + } + return + } + if err != nil { + t.Fatalf("APIBaseURL(%q): %v", tc.githubURL, err) + } + if got != tc.want { + t.Errorf("APIBaseURL(%q) = %q, want %q", tc.githubURL, got, tc.want) + } + }) + } +} diff --git a/internal/auth/jwt.go b/internal/auth/jwt.go new file mode 100644 index 0000000..62552a6 --- /dev/null +++ b/internal/auth/jwt.go @@ -0,0 +1,47 @@ +package auth + +import ( + "fmt" + "os" + "time" + + "github.com/golang-jwt/jwt/v4" +) + +func 
SignAppJWT(clientID string, pemBytes []byte) (string, error) { + key, err := jwt.ParseRSAPrivateKeyFromPEM(pemBytes) + if err != nil { + return "", fmt.Errorf("parse RSA private key: %w", err) + } + + now := time.Now().Add(-30 * time.Second) + claims := jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(now), + ExpiresAt: jwt.NewNumericDate(now.Add(9 * time.Minute)), + Issuer: clientID, + } + token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims) + signed, err := token.SignedString(key) + if err != nil { + return "", fmt.Errorf("sign app JWT: %w", err) + } + return signed, nil +} + +func LoadPrivateKey(path string) ([]byte, error) { + info, err := os.Stat(path) + if err != nil { + return nil, fmt.Errorf("stat private key %s: %w", path, err) + } + if mode := info.Mode().Perm(); mode&0o077 != 0 { + return nil, fmt.Errorf("private key %s has insecure permissions %#o (run: chmod 600 %s)", path, mode, path) + } + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read private key %s: %w", path, err) + } + if _, err := jwt.ParseRSAPrivateKeyFromPEM(data); err != nil { + return nil, fmt.Errorf("parse private key %s: %w", path, err) + } + return data, nil +} diff --git a/internal/auth/jwt_test.go b/internal/auth/jwt_test.go new file mode 100644 index 0000000..9dcfe94 --- /dev/null +++ b/internal/auth/jwt_test.go @@ -0,0 +1,118 @@ +package auth + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/golang-jwt/jwt/v4" +) + +func genTestKey(t *testing.T) (privPEM []byte) { + t.Helper() + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + return pem.EncodeToMemory(&pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(priv), + }) +} + +func writeTempKey(t *testing.T, pemBytes []byte, mode os.FileMode) string { + t.Helper() + path := filepath.Join(t.TempDir(), "test.pem") + if err := 
os.WriteFile(path, pemBytes, mode); err != nil { + t.Fatalf("write key: %v", err) + } + return path +} + +func TestSignAppJWT(t *testing.T) { + t.Run("valid PEM produces parseable JWT", func(t *testing.T) { + pemBytes := genTestKey(t) + + signed, err := SignAppJWT("Iv23liClient", pemBytes) + if err != nil { + t.Fatalf("SignAppJWT: %v", err) + } + + parsed, _, err := jwt.NewParser().ParseUnverified(signed, &jwt.RegisteredClaims{}) + if err != nil { + t.Fatalf("parse signed JWT: %v", err) + } + claims, ok := parsed.Claims.(*jwt.RegisteredClaims) + if !ok { + t.Fatalf("claims wrong type") + } + if claims.Issuer != "Iv23liClient" { + t.Errorf("Issuer = %q, want %q", claims.Issuer, "Iv23liClient") + } + if parsed.Method.Alg() != "RS256" { + t.Errorf("alg = %q, want RS256", parsed.Method.Alg()) + } + }) + + t.Run("garbage PEM returns parse error", func(t *testing.T) { + _, err := SignAppJWT("any", []byte("not a pem")) + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "parse RSA private key") { + t.Errorf("error = %q, want contain 'parse RSA private key'", err) + } + }) +} + +func TestLoadPrivateKey(t *testing.T) { + pemBytes := genTestKey(t) + + t.Run("valid key with 0600 returns bytes", func(t *testing.T) { + path := writeTempKey(t, pemBytes, 0o600) + got, err := LoadPrivateKey(path) + if err != nil { + t.Fatalf("LoadPrivateKey: %v", err) + } + if len(got) == 0 { + t.Error("got empty bytes") + } + }) + + t.Run("insecure permissions are rejected", func(t *testing.T) { + path := writeTempKey(t, pemBytes, 0o644) + _, err := LoadPrivateKey(path) + if err == nil { + t.Fatal("expected error for 0644 perms, got nil") + } + if !strings.Contains(err.Error(), "insecure permissions") { + t.Errorf("error = %q, want contain 'insecure permissions'", err) + } + }) + + t.Run("non-existent file returns stat error", func(t *testing.T) { + _, err := LoadPrivateKey("/nope/missing.pem") + if err == nil { + t.Fatal("expected error, got nil") + } + 
if !strings.Contains(err.Error(), "stat private key") { + t.Errorf("error = %q, want contain 'stat private key'", err) + } + }) + + t.Run("garbage content is rejected", func(t *testing.T) { + path := writeTempKey(t, []byte("not a pem at all"), 0o600) + _, err := LoadPrivateKey(path) + if err == nil { + t.Fatal("expected parse error, got nil") + } + if !strings.Contains(err.Error(), "parse private key") { + t.Errorf("error = %q, want contain 'parse private key'", err) + } + }) +} diff --git a/internal/auth/store.go b/internal/auth/store.go index 639fa3d..070dfec 100644 --- a/internal/auth/store.go +++ b/internal/auth/store.go @@ -23,7 +23,16 @@ func FilePath() string { } func loadFromFile() (*Credentials, error) { - data, err := os.ReadFile(FilePath()) + path := FilePath() + if info, err := os.Stat(path); err == nil { + if mode := info.Mode().Perm(); mode&0o077 != 0 { + fmt.Fprintf(os.Stderr, + "warning: credentials file %s has permissions %#o; tighten with: chmod 600 %s\n", + path, mode, path) + } + } + + data, err := os.ReadFile(path) if err != nil { return nil, err } diff --git a/internal/auth/validate.go b/internal/auth/validate.go index 89a2f87..3b23b11 100644 --- a/internal/auth/validate.go +++ b/internal/auth/validate.go @@ -29,14 +29,11 @@ func validatePAT(ctx context.Context, pat string) (*ValidationResult, error) { req.Header.Set("Authorization", "Bearer "+pat) req.Header.Set("Accept", "application/vnd.github+json") - resp, err := http.DefaultClient.Do(req) + resp, err := doGuarded(req) if err != nil { return nil, fmt.Errorf("validate PAT: request failed: %w", err) } - defer func() { - _, _ = io.Copy(io.Discard, resp.Body) - resp.Body.Close() - }() + defer drainBody(resp) body, err := io.ReadAll(resp.Body) if err != nil { @@ -44,7 +41,7 @@ func validatePAT(ctx context.Context, pat string) (*ValidationResult, error) { } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("validate PAT: GitHub API returned %d: %s", resp.StatusCode, string(body)) + 
return nil, fmt.Errorf("validate PAT: GitHub API returned %d: %s", resp.StatusCode, truncateBody(string(body))) } var user githubUserResponse diff --git a/internal/cli/auth.go b/internal/cli/auth.go index a41e1ff..00ea3f5 100644 --- a/internal/cli/auth.go +++ b/internal/cli/auth.go @@ -32,13 +32,16 @@ func newAuthStatusCmd() *cobra.Command { fmt.Printf("Method: %s\n", creds.Method) fmt.Printf("Source: %s\n", source) if creds.GitHubURL != "" { - fmt.Printf("GitHub: %s\n", creds.GitHubURL) + fmt.Printf("URL: %s\n", creds.GitHubURL) } if creds.Method == "pat" && creds.PAT != "" { fmt.Printf("Token: %s\n", auth.MaskedPAT(creds.PAT)) } if creds.GitHubApp != nil { fmt.Printf("Client: %s\n", creds.GitHubApp.ClientID) + if creds.GitHubApp.Account != "" { + fmt.Printf("Account: @%s\n", creds.GitHubApp.Account) + } fmt.Printf("Install: %d\n", creds.GitHubApp.InstallationID) fmt.Printf("Key: %s\n", creds.GitHubApp.PrivateKeyPath) } diff --git a/internal/cli/daemon.go b/internal/cli/daemon.go index 3170ee5..01e3f22 100644 --- a/internal/cli/daemon.go +++ b/internal/cli/daemon.go @@ -18,6 +18,7 @@ import ( "github.com/RedBoardDev/gh-runners-tool/v2/internal/monitoring" "github.com/RedBoardDev/gh-runners-tool/v2/internal/notification" "github.com/RedBoardDev/gh-runners-tool/v2/internal/runner" + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" ) type daemon struct { @@ -72,7 +73,7 @@ func buildDaemon(cfg *config.Config, creds *auth.Credentials, githubURL string) ghClient, binaryMgr, processMgr, notifSvc, logMgr, cfg.Groups, controller.ControllerConfig{ RunnerVersion: cfg.Runner.Version, - RunnerGroupID: 1, + RunnerGroupID: cfg.GitHub.RunnerGroupID, }, logger, ) @@ -160,7 +161,7 @@ func resolveGitHubURL(creds *auth.Credentials, cfg *config.Config) (string, erro } func pidFilePath(stateDir string) string { - return filepath.Join(stateDir, "daemon.pid") + return state.New(stateDir).PIDFile() } func writePIDFile(path string) error { diff --git a/internal/cli/login.go 
b/internal/cli/login.go index 7de8644..403959f 100644 --- a/internal/cli/login.go +++ b/internal/cli/login.go @@ -4,7 +4,6 @@ import ( "bufio" "fmt" "os" - "strings" "github.com/RedBoardDev/gh-runners-tool/v2/internal/auth" "github.com/spf13/cobra" @@ -14,16 +13,16 @@ func newLoginCmd() *cobra.Command { cmd := &cobra.Command{ Use: "login", Short: "Authenticate with GitHub", - Long: "Interactive wizard to configure GitHub authentication. Supports PAT and GitHub App.", + Long: "Interactive wizard to configure GitHub authentication. GitHub App (recommended) or PAT.", RunE: runLogin, } - cmd.Flags().String("method", "", "auth method: pat or app") - cmd.Flags().String("url", "", "GitHub URL (org, repo, or enterprise)") + cmd.Flags().String("method", "", "auth method: app or pat (interactive if empty)") + cmd.Flags().String("url", "", "GitHub URL for PAT mode (org or repo)") + cmd.Flags().String("host", "", "GitHub host URL for App mode (default https://github.com)") cmd.Flags().String("client-id", "", "GitHub App client ID") - cmd.Flags().Int64("installation-id", 0, "GitHub App installation ID") + cmd.Flags().Int64("installation-id", 0, "GitHub App installation ID (auto-detected if only one)") cmd.Flags().String("private-key", "", "path to GitHub App private key (.pem)") - return cmd } @@ -32,68 +31,79 @@ func runLogin(cmd *cobra.Command, _ []string) error { if err != nil { return fmt.Errorf("get method flag: %w", err) } - if method == "" { - reader := bufio.NewReader(os.Stdin) - return interactiveLogin(cmd, reader) + return interactiveLogin(cmd, bufio.NewReader(os.Stdin)) } - return nonInteractiveLogin(cmd, method) } func nonInteractiveLogin(cmd *cobra.Command, method string) error { + switch method { + case "pat": + return nonInteractivePAT(cmd) + case "app", "github_app": + return nonInteractiveApp(cmd) + default: + return fmt.Errorf("unknown method %q (expected 'app' or 'pat')", method) + } +} + +func nonInteractivePAT(cmd *cobra.Command) error { + if tokenFlag == 
"" { + return fmt.Errorf("--token is required for PAT method") + } url, err := cmd.Flags().GetString("url") if err != nil { return fmt.Errorf("get url flag: %w", err) } + if url == "" { + return fmt.Errorf("--url is required for PAT method") + } + creds := &auth.Credentials{ + Method: "pat", + GitHubURL: url, + PAT: tokenFlag, + } + return validateAndSave(cmd, creds) +} - var creds *auth.Credentials - - switch method { - case "pat": - if tokenFlag == "" { - return fmt.Errorf("--token is required for PAT authentication") - } - if url == "" { - return fmt.Errorf("--url is required") - } - creds = &auth.Credentials{ - Method: "pat", - GitHubURL: url, - PAT: tokenFlag, - } - - case "app": - clientID, flagErr := cmd.Flags().GetString("client-id") - if flagErr != nil { - return fmt.Errorf("get client-id flag: %w", flagErr) - } - installationID, flagErr := cmd.Flags().GetInt64("installation-id") - if flagErr != nil { - return fmt.Errorf("get installation-id flag: %w", flagErr) - } - privateKey, flagErr := cmd.Flags().GetString("private-key") - if flagErr != nil { - return fmt.Errorf("get private-key flag: %w", flagErr) - } - if clientID == "" || installationID == 0 || privateKey == "" || url == "" { - return fmt.Errorf("--client-id, --installation-id, --private-key, and --url are all required for GitHub App authentication") - } - creds = &auth.Credentials{ - Method: "github_app", - GitHubURL: url, - GitHubApp: &auth.GitHubAppCreds{ - ClientID: clientID, - InstallationID: installationID, - PrivateKeyPath: privateKey, - }, - } - - default: - return fmt.Errorf("unknown method %q: must be 'pat' or 'app'", method) +func nonInteractiveApp(cmd *cobra.Command) error { + clientID, err := cmd.Flags().GetString("client-id") + if err != nil { + return fmt.Errorf("get client-id flag: %w", err) + } + privateKey, err := cmd.Flags().GetString("private-key") + if err != nil { + return fmt.Errorf("get private-key flag: %w", err) + } + host, err := cmd.Flags().GetString("host") + if err != 
nil { + return fmt.Errorf("get host flag: %w", err) + } + installationID, err := cmd.Flags().GetInt64("installation-id") + if err != nil { + return fmt.Errorf("get installation-id flag: %w", err) } - return validateAndSave(cmd, creds) + in := appLoginInput{ + clientID: clientID, + privateKeyPath: expandHome(privateKey), + hostURL: host, + installationID: installationID, + } + prep, err := prepareAppLogin(cmd.Context(), in) + if err != nil { + return err + } + inst, err := resolveInstallation(prep.installations, in.installationID) + if err != nil { + return err + } + creds, err := finalizeAppLogin(cmd.Context(), prep, inst, in) + if err != nil { + return err + } + return saveCreds(creds) } func validateAndSave(cmd *cobra.Command, creds *auth.Credentials) error { @@ -102,22 +112,32 @@ func validateAndSave(cmd *cobra.Command, creds *auth.Credentials) error { if err != nil { return fmt.Errorf("validation failed: %w", err) } - if !result.Valid { return fmt.Errorf("credentials are not valid") } - if err := auth.Save(creds); err != nil { return fmt.Errorf("save credentials: %w", err) } - if creds.Method == "pat" && result.Username != "" { fmt.Printf("✓ Authenticated as @%s\n", result.Username) } - if creds.Method == "pat" && len(result.Scopes) > 0 { - fmt.Printf("✓ Scopes: %s\n", strings.Join(result.Scopes, ", ")) - } fmt.Printf("✓ Credentials saved to %s\n", auth.FilePath()) + return nil +} +func saveCreds(creds *auth.Credentials) error { + if err := auth.Save(creds); err != nil { + return fmt.Errorf("save credentials: %w", err) + } + fmt.Println() + fmt.Println("✓ Authentication successful") + if creds.GitHubApp != nil { + fmt.Printf(" Method: github_app\n") + fmt.Printf(" Account: @%s\n", creds.GitHubApp.Account) + fmt.Printf(" Installation: %d\n", creds.GitHubApp.InstallationID) + fmt.Printf(" URL: %s\n", creds.GitHubURL) + fmt.Printf(" Key: %s\n", creds.GitHubApp.PrivateKeyPath) + } + fmt.Printf(" Saved to: %s\n", auth.FilePath()) return nil } diff --git 
a/internal/cli/login_app.go b/internal/cli/login_app.go new file mode 100644 index 0000000..569b067 --- /dev/null +++ b/internal/cli/login_app.go @@ -0,0 +1,138 @@ +package cli + +import ( + "bufio" + "context" + "fmt" + "os" + "path/filepath" + "strings" + + "github.com/RedBoardDev/gh-runners-tool/v2/internal/auth" +) + +type appLoginInput struct { + clientID string + privateKeyPath string + hostURL string + installationID int64 +} + +type appLoginPrepared struct { + apiBase string + jwt string + installations []auth.Installation +} + +func prepareAppLogin(ctx context.Context, in appLoginInput) (*appLoginPrepared, error) { + if in.clientID == "" { + return nil, fmt.Errorf("client ID is required") + } + if in.privateKeyPath == "" { + return nil, fmt.Errorf("private key path is required") + } + if in.hostURL == "" { + in.hostURL = "https://github.com" + } + + pemBytes, err := auth.LoadPrivateKey(in.privateKeyPath) + if err != nil { + return nil, err + } + jwtToken, err := auth.SignAppJWT(in.clientID, pemBytes) + if err != nil { + return nil, err + } + apiBase, err := auth.APIBaseURL(in.hostURL) + if err != nil { + return nil, err + } + installations, err := auth.ListAppInstallations(ctx, apiBase, jwtToken) + if err != nil { + return nil, err + } + if len(installations) == 0 { + return nil, fmt.Errorf("the GitHub App has no installations — install it on an org or repo first at https://github.com/settings/installations") + } + return &appLoginPrepared{apiBase: apiBase, jwt: jwtToken, installations: installations}, nil +} + +func resolveInstallation(installations []auth.Installation, requestedID int64) (*auth.Installation, error) { + if requestedID != 0 { + for i := range installations { + if installations[i].ID == requestedID { + return &installations[i], nil + } + } + return nil, fmt.Errorf("installation %d not found (available: %s)", requestedID, formatInstallationList(installations)) + } + if len(installations) == 1 { + return &installations[0], nil + } + return 
nil, fmt.Errorf("multiple installations found, pass --installation-id (available: %s)", formatInstallationList(installations)) +} + +func selectInstallation(reader *bufio.Reader, installations []auth.Installation) (*auth.Installation, error) { + if len(installations) == 1 { + fmt.Printf(" Using installation @%s (id %d)\n", installations[0].Account, installations[0].ID) + return &installations[0], nil + } + fmt.Println() + fmt.Println("Available installations:") + for i, inst := range installations { + fmt.Printf(" %d) @%s (%s, id %d)\n", i+1, inst.Account, strings.ToLower(inst.AccountType), inst.ID) + } + fmt.Print("? Select installation: ") + raw, err := reader.ReadString('\n') + if err != nil { + return nil, fmt.Errorf("read selection: %w", err) + } + idx := 0 + if _, err := fmt.Sscanf(strings.TrimSpace(raw), "%d", &idx); err != nil || idx < 1 || idx > len(installations) { + return nil, fmt.Errorf("invalid selection %q", strings.TrimSpace(raw)) + } + return &installations[idx-1], nil +} + +func finalizeAppLogin(ctx context.Context, prep *appLoginPrepared, inst *auth.Installation, in appLoginInput) (*auth.Credentials, error) { + token, err := auth.IssueInstallationToken(ctx, prep.apiBase, prep.jwt, inst.ID) + if err != nil { + return nil, err + } + if err := auth.CheckRunnerPermissions(token.Permissions); err != nil { + return nil, err + } + host := strings.TrimRight(in.hostURL, "/") + return &auth.Credentials{ + Method: "github_app", + GitHubURL: fmt.Sprintf("%s/%s", host, inst.Account), + GitHubApp: &auth.GitHubAppCreds{ + ClientID: in.clientID, + InstallationID: inst.ID, + PrivateKeyPath: in.privateKeyPath, + Account: inst.Account, + }, + }, nil +} + +func formatInstallationList(installations []auth.Installation) string { + parts := make([]string, len(installations)) + for i, inst := range installations { + parts[i] = fmt.Sprintf("%d (@%s)", inst.ID, inst.Account) + } + return strings.Join(parts, ", ") +} + +func expandHome(path string) string { + if 
!strings.HasPrefix(path, "~/") && path != "~" { + return path + } + home, err := os.UserHomeDir() + if err != nil { + return path + } + if path == "~" { + return home + } + return filepath.Join(home, path[2:]) +} diff --git a/internal/cli/login_wizard.go b/internal/cli/login_wizard.go index 27ab4ac..63d6279 100644 --- a/internal/cli/login_wizard.go +++ b/internal/cli/login_wizard.go @@ -3,7 +3,6 @@ package cli import ( "bufio" "fmt" - "strconv" "strings" "github.com/RedBoardDev/gh-runners-tool/v2/internal/auth" @@ -12,107 +11,100 @@ import ( func interactiveLogin(cmd *cobra.Command, reader *bufio.Reader) error { fmt.Println() - fmt.Println("? Authentication method") - fmt.Println(" 1) Personal Access Token (PAT)") - fmt.Println(" 2) GitHub App") - fmt.Print("> ") - - choice, err := reader.ReadString('\n') + fmt.Println("Authentication method:") + fmt.Println(" 1) GitHub App (recommended — short-lived tokens, scoped permissions)") + fmt.Println(" 2) Personal Access Token") + choice, err := readLine(reader, "Choose [1]") if err != nil { - return fmt.Errorf("read choice: %w", err) + return err } - choice = strings.TrimSpace(choice) - switch choice { - case "1": - return interactivePAT(cmd, reader) - case "2": + case "", "1": return interactiveApp(cmd, reader) + case "2": + return interactivePAT(cmd, reader) default: - return fmt.Errorf("invalid choice: %q (expected 1 or 2)", choice) + return fmt.Errorf("invalid choice %q (expected 1 or 2)", choice) } } func interactivePAT(cmd *cobra.Command, reader *bufio.Reader) error { - fmt.Print("? GitHub PAT: ") - token, err := reader.ReadString('\n') + token, err := readLine(reader, "GitHub PAT") if err != nil { - return fmt.Errorf("read token: %w", err) + return err } - token = strings.TrimSpace(token) if token == "" { return fmt.Errorf("token cannot be empty") } - - fmt.Print("? 
GitHub URL (org or repo): ") - url, err := reader.ReadString('\n') + url, err := readLine(reader, "GitHub URL (org or repo)") if err != nil { - return fmt.Errorf("read url: %w", err) + return err } - url = strings.TrimSpace(url) if url == "" { return fmt.Errorf("URL cannot be empty") } - creds := &auth.Credentials{ Method: "pat", GitHubURL: url, PAT: token, } - return validateAndSave(cmd, creds) } func interactiveApp(cmd *cobra.Command, reader *bufio.Reader) error { - fmt.Print("? GitHub App Client ID: ") - clientID, err := reader.ReadString('\n') - if err != nil { - return fmt.Errorf("read client ID: %w", err) - } - clientID = strings.TrimSpace(clientID) - if clientID == "" { - return fmt.Errorf("client ID cannot be empty") - } + fmt.Println() + fmt.Println("Don't have a GitHub App yet? Create one at:") + fmt.Println(" https://github.com/organizations/YOUR_ORG/settings/apps/new") + fmt.Println("Required: Organization permissions → Self-hosted runners → Read & Write") + fmt.Println("Then generate a .pem private key (chmod 600) and install the App.") + fmt.Println() - fmt.Print("? Installation ID: ") - installIDStr, err := reader.ReadString('\n') + clientID, err := readLine(reader, "GitHub App Client ID") if err != nil { - return fmt.Errorf("read installation ID: %w", err) + return err } - installID, err := strconv.ParseInt(strings.TrimSpace(installIDStr), 10, 64) + pemPath, err := readLine(reader, "Path to private key (.pem)") if err != nil { - return fmt.Errorf("parse installation ID: %w", err) + return err } - - fmt.Print("? 
Private key path (.pem): ") - keyPath, err := reader.ReadString('\n') + hostURL, err := readLine(reader, "GitHub host URL [https://github.com]") if err != nil { - return fmt.Errorf("read private key path: %w", err) + return err } - keyPath = strings.TrimSpace(keyPath) - if keyPath == "" { - return fmt.Errorf("private key path cannot be empty") + + in := appLoginInput{ + clientID: clientID, + privateKeyPath: expandHome(pemPath), + hostURL: hostURL, } - fmt.Print("? GitHub URL: ") - url, err := reader.ReadString('\n') + fmt.Println(" Validating credentials...") + prep, err := prepareAppLogin(cmd.Context(), in) if err != nil { - return fmt.Errorf("read url: %w", err) + return err } - url = strings.TrimSpace(url) - if url == "" { - return fmt.Errorf("URL cannot be empty") + fmt.Printf(" Found %d installation(s)\n", len(prep.installations)) + + inst, err := selectInstallation(reader, prep.installations) + if err != nil { + return err } - creds := &auth.Credentials{ - Method: "github_app", - GitHubURL: url, - GitHubApp: &auth.GitHubAppCreds{ - ClientID: clientID, - InstallationID: installID, - PrivateKeyPath: keyPath, - }, + fmt.Println(" Generating installation token...") + creds, err := finalizeAppLogin(cmd.Context(), prep, inst, in) + if err != nil { + return err } + return saveCreds(creds) +} - return validateAndSave(cmd, creds) +func readLine(reader *bufio.Reader, label string) (string, error) { + if label != "" { + fmt.Printf("? 
%s: ", label) + } + raw, err := reader.ReadString('\n') + if err != nil { + return "", fmt.Errorf("read input: %w", err) + } + return strings.TrimSpace(raw), nil } diff --git a/internal/cli/purge.go b/internal/cli/purge.go index 916d641..cd5d7f2 100644 --- a/internal/cli/purge.go +++ b/internal/cli/purge.go @@ -12,6 +12,7 @@ import ( "github.com/RedBoardDev/gh-runners-tool/v2/internal/config" "github.com/RedBoardDev/gh-runners-tool/v2/internal/github" "github.com/RedBoardDev/gh-runners-tool/v2/internal/launchd" + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" "github.com/spf13/cobra" ) @@ -99,7 +100,7 @@ func purgeScaleSets(ctx context.Context, ghClient *github.Client, cfg *config.Co deletedSets := 0 for _, g := range cfg.Groups { fmt.Printf("purging scale set %q...\n", g.Name) - ss, getErr := ghClient.GetScaleSet(ctx, 1, g.Name) + ss, getErr := ghClient.GetScaleSet(ctx, cfg.GitHub.RunnerGroupID, g.Name) if getErr != nil { fmt.Printf(" scale set %q not found, skipping\n", g.Name) continue @@ -171,11 +172,9 @@ func cleanupWorkdirs(workdirBase string) int { } func cleanupStateFiles(stateDir string) { - for _, name := range []string{"daemon.pid", "daemon.state.json", "ghr.sock"} { - p := filepath.Join(stateDir, name) - rmErr := os.Remove(p) - if rmErr != nil && !os.IsNotExist(rmErr) { - fmt.Printf(" failed to remove %s: %v\n", p, rmErr) + for _, p := range state.New(stateDir).All() { + if err := os.Remove(p); err != nil && !os.IsNotExist(err) { + fmt.Printf(" failed to remove %s: %v\n", p, err) } } } diff --git a/internal/cli/run.go b/internal/cli/run.go index 9f0526a..1366636 100644 --- a/internal/cli/run.go +++ b/internal/cli/run.go @@ -3,7 +3,9 @@ package cli import ( "context" "fmt" + "log/slog" "os/signal" + "runtime/debug" "syscall" "github.com/RedBoardDev/gh-runners-tool/v2/internal/auth" @@ -78,7 +80,7 @@ func runDaemonGroup(d *daemon) error { { ctx, cancel := context.WithCancel(context.Background()) g.Add( - func() error { return d.ctrl.Run(ctx) 
}, + safeActor(d.logger, "controller", func() error { return d.ctrl.Run(ctx) }), func(error) { cancel() }, ) } @@ -86,7 +88,7 @@ func runDaemonGroup(d *daemon) error { { ctx, cancel := context.WithCancel(context.Background()) g.Add( - func() error { return d.health.Run(ctx) }, + safeActor(d.logger, "health", func() error { return d.health.Run(ctx) }), func(error) { cancel() }, ) } @@ -94,7 +96,7 @@ func runDaemonGroup(d *daemon) error { { ctx, cancel := context.WithCancel(context.Background()) g.Add( - func() error { return d.api.Run(ctx) }, + safeActor(d.logger, "api", func() error { return d.api.Run(ctx) }), func(error) { cancel() }, ) } @@ -102,7 +104,15 @@ func runDaemonGroup(d *daemon) error { { ctx, cancel := context.WithCancel(context.Background()) g.Add( - func() error { return d.logMgr.StartCleanupScheduler(ctx) }, + safeActor(d.logger, "log-cleanup", func() error { return d.logMgr.StartCleanupScheduler(ctx) }), + func(error) { cancel() }, + ) + } + + { + ctx, cancel := context.WithCancel(context.Background()) + g.Add( + safeActor(d.logger, "watchdog", func() error { return runWatchdog(ctx, d.cfg.Daemon.StateDir, d.logger) }), func(error) { cancel() }, ) } @@ -127,3 +137,16 @@ func runDaemonGroup(d *daemon) error { d.logger.Info("ghr stopped") return groupErr } + +func safeActor(logger *slog.Logger, name string, fn func() error) func() error { + return func() (err error) { + defer func() { + if r := recover(); r != nil { + stack := debug.Stack() + logger.Error("actor panicked", "actor", name, "panic", fmt.Sprintf("%v", r), "stack", string(stack)) + err = fmt.Errorf("actor %s panicked: %v", name, r) + } + }() + return fn() + } +} diff --git a/internal/cli/run_test.go b/internal/cli/run_test.go new file mode 100644 index 0000000..7f28ebf --- /dev/null +++ b/internal/cli/run_test.go @@ -0,0 +1,43 @@ +package cli + +import ( + "bytes" + "errors" + "log/slog" + "strings" + "testing" +) + +func TestSafeActor_CapturesPanic(t *testing.T) { + var buf bytes.Buffer + 
logger := slog.New(slog.NewTextHandler(&buf, nil)) + + actor := safeActor(logger, "test", func() error { + panic("boom") + }) + + err := actor() + if err == nil { + t.Fatal("expected error from panicking actor") + } + if !strings.Contains(err.Error(), "actor test panicked") { + t.Errorf("error = %v, want substring 'actor test panicked'", err) + } + if !strings.Contains(buf.String(), "actor panicked") { + t.Errorf("expected log entry 'actor panicked', got: %s", buf.String()) + } +} + +func TestSafeActor_PassesThroughErrors(t *testing.T) { + logger := slog.New(slog.NewTextHandler(&bytes.Buffer{}, nil)) + want := errors.New("normal failure") + + actor := safeActor(logger, "test", func() error { + return want + }) + + got := actor() + if !errors.Is(got, want) { + t.Errorf("err = %v, want %v", got, want) + } +} diff --git a/internal/cli/state.go b/internal/cli/state.go index 3dc7d42..b7d4ac0 100644 --- a/internal/cli/state.go +++ b/internal/cli/state.go @@ -4,38 +4,34 @@ import ( "encoding/json" "fmt" "os" - "path/filepath" "time" -) -const stateFileName = "daemon.state.json" + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" +) type daemonState struct { - ConfigPath string `json:"config_path"` - StartedAt time.Time `json:"started_at"` - PID int `json:"pid"` - Groups map[string]int `json:"groups"` + ConfigPath string `json:"config_path"` + StartedAt time.Time `json:"started_at"` + PID int `json:"pid"` } func writeDaemonState(stateDir, configPath string) error { - state := daemonState{ + ds := daemonState{ ConfigPath: configPath, StartedAt: time.Now(), PID: os.Getpid(), - Groups: make(map[string]int), } - data, err := json.MarshalIndent(state, "", " ") + data, err := json.MarshalIndent(ds, "", " ") if err != nil { return fmt.Errorf("marshal daemon state: %w", err) } - dir := stateDir - if err := os.MkdirAll(dir, 0o755); err != nil { - return fmt.Errorf("create state directory %s: %w", dir, err) + if err := os.MkdirAll(stateDir, 0o755); err != nil { + return 
fmt.Errorf("create state directory %s: %w", stateDir, err) } - path := filepath.Join(dir, stateFileName) + path := state.New(stateDir).StateFile() if err := os.WriteFile(path, data, 0o644); err != nil { return fmt.Errorf("write daemon state %s: %w", path, err) } @@ -43,20 +39,19 @@ func writeDaemonState(stateDir, configPath string) error { } func readDaemonState(stateDir string) (*daemonState, error) { - path := filepath.Join(stateDir, stateFileName) + path := state.New(stateDir).StateFile() data, err := os.ReadFile(path) if err != nil { return nil, fmt.Errorf("read daemon state %s: %w", path, err) } - var state daemonState - if err := json.Unmarshal(data, &state); err != nil { + var ds daemonState + if err := json.Unmarshal(data, &ds); err != nil { return nil, fmt.Errorf("parse daemon state %s: %w", path, err) } - return &state, nil + return &ds, nil } func removeDaemonState(stateDir string) { - path := filepath.Join(stateDir, stateFileName) - _ = os.Remove(path) + _ = os.Remove(state.New(stateDir).StateFile()) } diff --git a/internal/cli/status.go b/internal/cli/status.go index e5af2f5..b7da27f 100644 --- a/internal/cli/status.go +++ b/internal/cli/status.go @@ -11,6 +11,7 @@ import ( "time" "github.com/RedBoardDev/gh-runners-tool/v2/internal/config" + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" "github.com/spf13/cobra" ) @@ -45,7 +46,7 @@ func runStatus(cmd *cobra.Command, _ []string) error { } stateDir := resolveStateDir() - socketPath := filepath.Join(stateDir, "ghr.sock") + socketPath := state.New(stateDir).Socket() if !watch { return renderOnce(socketPath, stateDir, jsonOutput) diff --git a/internal/cli/status_render.go b/internal/cli/status_render.go index 95da2ff..b12a9a5 100644 --- a/internal/cli/status_render.go +++ b/internal/cli/status_render.go @@ -16,7 +16,7 @@ type statusResponse struct { type statusRunner struct { Name string `json:"name"` State string `json:"state"` - PID int `json:"pid"` + PID int32 `json:"pid"` JobName string 
`json:"job_name"` } diff --git a/internal/cli/watchdog.go b/internal/cli/watchdog.go new file mode 100644 index 0000000..a66dd58 --- /dev/null +++ b/internal/cli/watchdog.go @@ -0,0 +1,77 @@ +package cli + +import ( + "context" + "log/slog" + "net" + "net/http" + "os" + "time" + + "github.com/RedBoardDev/gh-runners-tool/v2/internal/state" +) + +const ( + watchdogInterval = 30 * time.Second + watchdogTimeout = 5 * time.Second + watchdogFailureThreshold = 3 +) + +// runWatchdog probes the daemon's own /health endpoint over the unix socket. +// After watchdogFailureThreshold consecutive failures it logs critical and +// exits with code 2 so launchd can respawn the process. +func runWatchdog(ctx context.Context, stateDir string, logger *slog.Logger) error { + socketPath := state.New(stateDir).Socket() + client := &http.Client{ + Timeout: watchdogTimeout, + Transport: &http.Transport{ + DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { + return net.Dial("unix", socketPath) + }, + }, + } + + // Allow the API server to come up before starting to probe. 
+ select { + case <-ctx.Done(): + return nil + case <-time.After(watchdogInterval): + } + + ticker := time.NewTicker(watchdogInterval) + defer ticker.Stop() + + failures := 0 + for { + if probeOK(ctx, client) { + failures = 0 + } else { + failures++ + logger.Warn("watchdog probe failed", "consecutive_failures", failures) + if failures >= watchdogFailureThreshold { + logger.Error("watchdog tripped; exiting for launchd to respawn", + "threshold", watchdogFailureThreshold) + os.Exit(2) + } + } + + select { + case <-ctx.Done(): + return nil + case <-ticker.C: + } + } +} + +func probeOK(ctx context.Context, client *http.Client) bool { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "http://unix/health", http.NoBody) + if err != nil { + return false + } + resp, err := client.Do(req) + if err != nil { + return false + } + defer resp.Body.Close() + return resp.StatusCode >= 200 && resp.StatusCode < 500 +} diff --git a/internal/cli/watchdog_test.go b/internal/cli/watchdog_test.go new file mode 100644 index 0000000..eafd7ca --- /dev/null +++ b/internal/cli/watchdog_test.go @@ -0,0 +1,63 @@ +package cli + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestProbeOK(t *testing.T) { + tests := []struct { + name string + statusCode int + want bool + }{ + {"200", http.StatusOK, true}, + {"204", http.StatusNoContent, true}, + {"4xx is alive (no body issue)", http.StatusNotFound, true}, + {"5xx counts as failure", http.StatusBadGateway, false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + })) + defer srv.Close() + + // rewrite the URL path-only request to point at the httptest server. 
+ req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, srv.URL+"/health", http.NoBody) + if err != nil { + t.Fatalf("new request: %v", err) + } + resp, err := srv.Client().Do(req) + if err != nil { + t.Fatalf("do: %v", err) + } + got := resp.StatusCode >= 200 && resp.StatusCode < 500 + resp.Body.Close() + if got != tc.want { + t.Errorf("probe status %d => %v, want %v", tc.statusCode, got, tc.want) + } + }) + } +} + +func TestProbeOK_ConnectionError(t *testing.T) { + client := &http.Client{Timeout: 50 * time.Millisecond} + // Hit an address that should not have anyone listening. + req, _ := http.NewRequestWithContext(context.Background(), http.MethodGet, "http://127.0.0.1:1/health", http.NoBody) + resp, err := client.Do(req) + if resp != nil { + resp.Body.Close() + } + if err == nil { + t.Skip("unexpected connectivity to port 1") + } + if !strings.Contains(err.Error(), "connect") && !strings.Contains(err.Error(), "refused") && !strings.Contains(err.Error(), "timeout") { + t.Skipf("connection error has unexpected form (env-dependent): %v", err) + } +} diff --git a/internal/config/loader.go b/internal/config/loader.go index 1d69109..ab30f07 100644 --- a/internal/config/loader.go +++ b/internal/config/loader.go @@ -57,6 +57,9 @@ func applyDefaults(cfg *Config) { if cfg.GitHub.RunnerGroup == "" { cfg.GitHub.RunnerGroup = "default" } + if cfg.GitHub.RunnerGroupID == 0 { + cfg.GitHub.RunnerGroupID = 1 + } if cfg.Runner.Version == "" { cfg.Runner.Version = "latest" diff --git a/internal/config/loader_test.go b/internal/config/loader_test.go index e6f19ee..65bb81f 100644 --- a/internal/config/loader_test.go +++ b/internal/config/loader_test.go @@ -368,6 +368,36 @@ groups: - ""`, wantInErr: "labels[0] must not be empty", }, + { + name: "label with invalid character", + yaml: ` +groups: + - name: grp + max_runners: 1 + labels: + - "bad label!"`, + wantInErr: "labels[0] \"bad label!\" must match", + }, + { + name: "label too long", + yaml: ` +groups: + 
- name: grp + max_runners: 1 + labels: + - "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"`, + wantInErr: "must match", + }, + { + name: "label starting with hyphen", + yaml: ` +groups: + - name: grp + max_runners: 1 + labels: + - "-leading"`, + wantInErr: "must match", + }, { name: "invalid logging level", yaml: ` @@ -408,6 +438,36 @@ groups: max_runners: 1`, wantInErr: "health.runner_timeout must be >= 1m", }, + { + name: "workdir_base relative path", + yaml: ` +runner: + workdir_base: "relative/path" +groups: + - name: grp + max_runners: 1`, + wantInErr: "runner.workdir_base must be absolute", + }, + { + name: "workdir_base on /tmp", + yaml: ` +runner: + workdir_base: "/tmp" +groups: + - name: grp + max_runners: 1`, + wantInErr: "must not be a top-level system directory", + }, + { + name: "workdir_base too short", + yaml: ` +runner: + workdir_base: "/abc" +groups: + - name: grp + max_runners: 1`, + wantInErr: "too short", + }, { name: "shutdown_timeout too small", yaml: ` diff --git a/internal/config/types.go b/internal/config/types.go index 5cb4f24..f7ebb29 100644 --- a/internal/config/types.go +++ b/internal/config/types.go @@ -17,8 +17,9 @@ type Config struct { } type GitHubConfig struct { - URL string `yaml:"url"` - RunnerGroup string `yaml:"runner_group"` + URL string `yaml:"url"` + RunnerGroup string `yaml:"runner_group"` + RunnerGroupID int `yaml:"runner_group_id"` } type RunnerConfig struct { diff --git a/internal/config/validate.go b/internal/config/validate.go index 598ca32..0dddba4 100644 --- a/internal/config/validate.go +++ b/internal/config/validate.go @@ -3,9 +3,30 @@ package config import ( "errors" "fmt" + "path/filepath" + "regexp" "time" ) +var labelPattern = regexp.MustCompile(`^[A-Za-z0-9][A-Za-z0-9._-]{0,63}$`) + +// unsafeWorkdirBase enumerates directories ghr must refuse to claim as its +// runner workdir root. Substring matches by tools like `pgrep -f` would +// otherwise sweep up arbitrary user processes. 
+var unsafeWorkdirBase = map[string]struct{}{ + "/": {}, + "/tmp": {}, + "/var": {}, + "/usr": {}, + "/etc": {}, + "/home": {}, + "/root": {}, + "/opt": {}, + "/bin": {}, + "/sbin": {}, + "/dev": {}, +} + func validate(cfg *Config) error { var errs []error @@ -40,8 +61,29 @@ func validate(cfg *Config) error { } for j, label := range g.Labels { - if label == "" { + switch { + case label == "": errs = append(errs, fmt.Errorf("%s (%s): labels[%d] must not be empty", prefix, g.Name, j)) + case !labelPattern.MatchString(label): + errs = append(errs, fmt.Errorf("%s (%s): labels[%d] %q must match %s", prefix, g.Name, j, label, labelPattern.String())) + } + } + } + + if cfg.GitHub.RunnerGroupID < 1 { + errs = append(errs, fmt.Errorf("github.runner_group_id must be >= 1, got %d", cfg.GitHub.RunnerGroupID)) + } + + if cfg.Runner.WorkdirBase != "" { + clean := filepath.Clean(cfg.Runner.WorkdirBase) + switch { + case !filepath.IsAbs(cfg.Runner.WorkdirBase): + errs = append(errs, fmt.Errorf("runner.workdir_base must be absolute, got %q", cfg.Runner.WorkdirBase)) + default: + if _, banned := unsafeWorkdirBase[clean]; banned { + errs = append(errs, fmt.Errorf("runner.workdir_base must not be a top-level system directory, got %q", cfg.Runner.WorkdirBase)) + } else if len(clean) < 8 { + errs = append(errs, fmt.Errorf("runner.workdir_base %q is too short (orphan-process matching would be unsafe)", cfg.Runner.WorkdirBase)) } } } diff --git a/internal/controller/controller.go b/internal/controller/controller.go index 100da0d..0624341 100644 --- a/internal/controller/controller.go +++ b/internal/controller/controller.go @@ -73,14 +73,15 @@ func (c *GroupController) Run(ctx context.Context) error { var wg sync.WaitGroup errCh := make(chan error, len(c.groups)) - for _, g := range c.groups { + for i := range c.groups { + group := &c.groups[i] wg.Add(1) - go func(group *config.GroupConfig) { + go func() { defer wg.Done() if err := c.runGroup(ctx, group); err != nil { errCh <- err } - 
}(&g) + }() } <-ctx.Done() diff --git a/internal/controller/group.go b/internal/controller/group.go index 4ba23ea..9934e10 100644 --- a/internal/controller/group.go +++ b/internal/controller/group.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "log/slog" + "math/rand" "os" "time" @@ -185,7 +186,9 @@ func deduplicateLabels(groupName string, extra []string) []string { func nextBackoff(current time.Duration) time.Duration { next := current * 2 if next > backoffMax { - return backoffMax + next = backoffMax } - return next + // ±20% jitter to spread retries across groups that all failed at the same tick. + jitter := time.Duration((rand.Float64()*0.4 - 0.2) * float64(next)) + return next + jitter } diff --git a/internal/controller/group_test.go b/internal/controller/group_test.go new file mode 100644 index 0000000..79876a0 --- /dev/null +++ b/internal/controller/group_test.go @@ -0,0 +1,43 @@ +package controller + +import ( + "testing" + "time" +) + +func TestNextBackoff_JitteredWithinBounds(t *testing.T) { + const samples = 200 + current := 4 * time.Second + + for i := 0; i < samples; i++ { + got := nextBackoff(current) + // next would be current * 2 = 8s. + low := time.Duration(float64(8*time.Second) * 0.8) + high := time.Duration(float64(8*time.Second) * 1.2) + if got < low || got > high { + t.Fatalf("nextBackoff(%s) = %s, want within [%s, %s]", current, got, low, high) + } + } +} + +func TestNextBackoff_RespectsCap(t *testing.T) { + current := backoffMax + for i := 0; i < 100; i++ { + got := nextBackoff(current) + // Capped at backoffMax (with ±20% jitter on the cap itself). 
+ if got < time.Duration(float64(backoffMax)*0.8) || got > time.Duration(float64(backoffMax)*1.2) { + t.Fatalf("nextBackoff(cap) = %s, expected within ±20%% of %s", got, backoffMax) + } + } +} + +func TestNextBackoff_VariesAcrossCalls(t *testing.T) { + current := 4 * time.Second + seen := make(map[time.Duration]struct{}) + for i := 0; i < 20; i++ { + seen[nextBackoff(current)] = struct{}{} + } + if len(seen) < 2 { + t.Fatalf("expected jittered values across calls, got %d unique results", len(seen)) + } +} diff --git a/internal/controller/scaler.go b/internal/controller/scaler.go index fe3e45a..00d4acf 100644 --- a/internal/controller/scaler.go +++ b/internal/controller/scaler.go @@ -148,6 +148,12 @@ func (s *MacOSScaler) HandleJobCompleted(ctx context.Context, jobInfo *scaleset. "error", cleanupErr, ) } + if logsErr := s.logMgr.RemoveRunnerLogs(s.groupName, jobInfo.RunnerName); logsErr != nil { + s.logger.WarnContext(ctx, "failed to remove runner log dir", + "runner", jobInfo.RunnerName, + "error", logsErr, + ) + } } else { s.logger.WarnContext(ctx, "job completed for unknown runner", "runner", jobInfo.RunnerName, diff --git a/internal/controller/scaler_ops.go b/internal/controller/scaler_ops.go index cee674c..26f790a 100644 --- a/internal/controller/scaler_ops.go +++ b/internal/controller/scaler_ops.go @@ -15,8 +15,7 @@ func (s *MacOSScaler) startRunner(ctx context.Context) error { if _, err := rand.Read(randBytes); err != nil { return fmt.Errorf("generate runner ID: %w", err) } - id := hex.EncodeToString(randBytes) - name := fmt.Sprintf("%s-%s", s.groupName, id) + name := fmt.Sprintf("%s-%s", s.groupName, hex.EncodeToString(randBytes)) jitConfig, err := s.client.GenerateJITConfig(ctx, s.scaleSetID, name) if err != nil { @@ -24,7 +23,6 @@ func (s *MacOSScaler) startRunner(ctx context.Context) error { } instance := model.RunnerInstance{ - ID: id, Name: name, Group: s.groupName, } @@ -82,6 +80,13 @@ func (s *MacOSScaler) killRunner(ctx context.Context, runnerName 
string) error { return fmt.Errorf("cleanup runner %q: %w", runnerName, cleanupErr) } + if logsErr := s.logMgr.RemoveRunnerLogs(s.groupName, runnerName); logsErr != nil { + s.logger.WarnContext(ctx, "failed to remove runner log dir", + "runner", runnerName, + "error", logsErr, + ) + } + s.logger.InfoContext(ctx, "killed runner", "runner", runnerName, "group", s.groupName) return nil } @@ -114,6 +119,12 @@ func (s *MacOSScaler) Shutdown(ctx context.Context) { "error", cleanupErr, ) } + if logsErr := s.logMgr.RemoveRunnerLogs(s.groupName, proc.Name); logsErr != nil { + s.logger.WarnContext(ctx, "failed to remove runner log dir during shutdown", + "runner", proc.Name, + "error", logsErr, + ) + } } } diff --git a/internal/health/checks.go b/internal/health/checks.go index 3a9a94b..ecbc685 100644 --- a/internal/health/checks.go +++ b/internal/health/checks.go @@ -13,14 +13,15 @@ func (m *Monitor) runChecks(ctx context.Context) { start := time.Now() m.mu.Lock() - defer m.mu.Unlock() - m.issues = m.issues[:0] snapshots := m.runners.Snapshots() totalActual := 0 totalDesired := 0 + groupActuals := make(map[string]int, len(snapshots)) + groupDesireds := make(map[string]int, len(snapshots)) + for group, snaps := range snapshots { m.checkRunnerLiveness(ctx, group, snaps) m.checkRunnerTimeouts(ctx, group, snaps) @@ -30,24 +31,54 @@ func (m *Monitor) runChecks(ctx context.Context) { m.checkConsecutiveFailures(group, gs) totalActual += len(snaps) totalDesired += gs.lastDesiredCount + groupActuals[group] = len(snaps) + groupDesireds[group] = gs.lastDesiredCount } m.checkDiskSpace() m.lastCheck = time.Now() checkDuration := time.Since(start) - for _, r := range m.reporters { - r.ReportDaemonHealth(ctx, len(snapshots), totalActual, totalDesired, checkDuration) + issuesCopy := make([]model.HealthIssue, len(m.issues)) + copy(issuesCopy, m.issues) + reporters := m.reporters + notifier := m.notifier + groupsCount := len(snapshots) + m.mu.Unlock() + + go dispatchHealthReports(ctx, 
reporters, notifier, dispatchPayload{ + groupsCount: groupsCount, + totalActual: totalActual, + totalDesired: totalDesired, + checkDuration: checkDuration, + groupActuals: groupActuals, + groupDesireds: groupDesireds, + issues: issuesCopy, + }) +} + +type dispatchPayload struct { + groupsCount int + totalActual int + totalDesired int + checkDuration time.Duration + groupActuals map[string]int + groupDesireds map[string]int + issues []model.HealthIssue +} + +func dispatchHealthReports(ctx context.Context, reporters []Reporter, notifier Notifier, p dispatchPayload) { + for _, r := range reporters { + r.ReportDaemonHealth(ctx, p.groupsCount, p.totalActual, p.totalDesired, p.checkDuration) } - for group, snaps := range snapshots { - gs := m.getOrCreateGroup(group) - for _, r := range m.reporters { - r.ReportGroupHealth(ctx, group, len(snaps), gs.lastDesiredCount) + for group, actual := range p.groupActuals { + desired := p.groupDesireds[group] + for _, r := range reporters { + r.ReportGroupHealth(ctx, group, actual, desired) } } - - for _, issue := range m.issues { - m.notifier.Notify(ctx, &model.Event{ + for _, issue := range p.issues { + notifier.Notify(ctx, &model.Event{ Type: issue.Type, Level: issue.Level, Group: issue.Group, @@ -63,7 +94,7 @@ func (m *Monitor) checkRunnerLiveness(ctx context.Context, group string, snapsho if snap.PID <= 0 { continue } - if err := syscall.Kill(snap.PID, 0); err != nil { + if err := syscall.Kill(int(snap.PID), 0); err != nil { m.issues = append(m.issues, model.HealthIssue{ Level: model.LevelError, Type: model.EventHealthZombieRunner, diff --git a/internal/health/checks_test.go b/internal/health/checks_test.go index 2f99716..a48c213 100644 --- a/internal/health/checks_test.go +++ b/internal/health/checks_test.go @@ -5,6 +5,7 @@ import ( "fmt" "log/slog" "os" + "sync" "testing" "time" @@ -12,13 +13,38 @@ import ( ) type noopNotifier struct { + mu sync.Mutex events []model.Event } func (n *noopNotifier) Notify(_ context.Context, event 
*model.Event) { + n.mu.Lock() + defer n.mu.Unlock() n.events = append(n.events, *event) } +func (n *noopNotifier) snapshot() []model.Event { + n.mu.Lock() + defer n.mu.Unlock() + out := make([]model.Event, len(n.events)) + copy(out, n.events) + return out +} + +func (n *noopNotifier) waitFor(t *testing.T, eventType string, timeout time.Duration) bool { + t.Helper() + deadline := time.Now().Add(timeout) + for time.Now().Before(deadline) { + for _, e := range n.snapshot() { + if e.Type == eventType { + return true + } + } + time.Sleep(10 * time.Millisecond) + } + return false +} + type fakeRunnerState struct { snapshots map[string][]model.RunnerSnapshot } @@ -301,17 +327,55 @@ func TestRunChecks_IntegrationWithNotifier(t *testing.T) { m.runChecks(context.Background()) - foundIdle := false - for _, e := range notif.events { - if e.Type == model.EventHealthIdleTimeout { - foundIdle = true - } - } - if !foundIdle { + if !notif.waitFor(t, model.EventHealthIdleTimeout, 2*time.Second) { t.Error("expected idle timeout event to be notified") } } +func TestRunChecks_StatusObservesCoherentSnapshots(t *testing.T) { + notif := &noopNotifier{} + state := &fakeRunnerState{ + snapshots: map[string][]model.RunnerSnapshot{ + "group-a": { + {Name: "r1", State: "idle", PID: 99999999, StartedAt: time.Now().Add(-2 * time.Hour)}, + }, + }, + } + + m := NewMonitor( + MonitorConfig{ + Enabled: true, + CheckInterval: time.Second, + IdleTimeout: 30 * time.Minute, + }, + notif, + state, + nil, + nil, + noopLogger(), + ) + + // runChecks and Status() are expected to be safe to call concurrently: + // Status() may observe either the previous or the freshly computed issue + // list, but never a torn one. 
+ done := make(chan struct{}) + go func() { + defer close(done) + for i := 0; i < 50; i++ { + m.runChecks(context.Background()) + } + }() + for i := 0; i < 50; i++ { + hs := m.Status() + for _, issue := range hs.Issues { + if issue.Type == "" { + t.Fatalf("torn issue read: %+v", issue) + } + } + } + <-done +} + func timePtr(t time.Time) *time.Time { return &t } diff --git a/internal/launchd/launchctl.go b/internal/launchd/launchctl.go index 0f1d44f..47eac94 100644 --- a/internal/launchd/launchctl.go +++ b/internal/launchd/launchctl.go @@ -2,37 +2,40 @@ package launchd import ( "fmt" + "os" "os/exec" + "strconv" ) -func launchctlLoad(plistPath string) error { - out, err := exec.Command("launchctl", "load", plistPath).CombinedOutput() - if err != nil { - return fmt.Errorf("launchctl load: %w: %s", err, string(out)) +// domainTarget returns the launchctl service target prefix +// (e.g. "gui/501" or "system") suitable for bootstrap/bootout/kickstart. +func domainTarget() string { + if os.Getuid() == 0 { + return "system" } - return nil + return "gui/" + strconv.Itoa(os.Getuid()) } -func launchctlUnload(plistPath string) error { - out, err := exec.Command("launchctl", "unload", plistPath).CombinedOutput() +func launchctlBootstrap(plistPath string) error { + out, err := exec.Command("launchctl", "bootstrap", domainTarget(), plistPath).CombinedOutput() if err != nil { - return fmt.Errorf("launchctl unload: %w: %s", err, string(out)) + return fmt.Errorf("launchctl bootstrap: %w: %s", err, string(out)) } return nil } -func launchctlStart(label string) error { - out, err := exec.Command("launchctl", "start", label).CombinedOutput() +func launchctlBootout(label string) error { + out, err := exec.Command("launchctl", "bootout", domainTarget()+"/"+label).CombinedOutput() if err != nil { - return fmt.Errorf("launchctl start: %w: %s", err, string(out)) + return fmt.Errorf("launchctl bootout: %w: %s", err, string(out)) } return nil } -func launchctlStop(label string) error { - 
out, err := exec.Command("launchctl", "stop", label).CombinedOutput() +func launchctlKickstart(label string) error { + out, err := exec.Command("launchctl", "kickstart", "-k", domainTarget()+"/"+label).CombinedOutput() if err != nil { - return fmt.Errorf("launchctl stop: %w: %s", err, string(out)) + return fmt.Errorf("launchctl kickstart: %w: %s", err, string(out)) } return nil } diff --git a/internal/launchd/plist.go b/internal/launchd/plist.go index 84f2b9d..12cd6fa 100644 --- a/internal/launchd/plist.go +++ b/internal/launchd/plist.go @@ -2,6 +2,7 @@ package launchd import ( "bytes" + "encoding/xml" "fmt" "text/template" ) @@ -11,13 +12,13 @@ const plistTemplate = ` Label - {{.Label}} + {{xml .Label}} ProgramArguments - {{.BinaryPath}} + {{xml .BinaryPath}} run --config - {{.ConfigPath}} + {{xml .ConfigPath}} RunAtLoad @@ -27,11 +28,11 @@ const plistTemplate = ` StandardOutPath - {{.LogDir}}/daemon.log + {{xml .LogDir}}/daemon.log StandardErrorPath - {{.LogDir}}/daemon.err + {{xml .LogDir}}/daemon.err WorkingDirectory - {{.StateDir}} + {{xml .StateDir}} EnvironmentVariables PATH @@ -41,8 +42,20 @@ const plistTemplate = ` ` +var plistFuncs = template.FuncMap{ + "xml": xmlEscape, +} + +func xmlEscape(s string) (string, error) { + var buf bytes.Buffer + if err := xml.EscapeText(&buf, []byte(s)); err != nil { + return "", err + } + return buf.String(), nil +} + func generatePlist(cfg *ServiceConfig) ([]byte, error) { - tmpl, err := template.New("plist").Parse(plistTemplate) + tmpl, err := template.New("plist").Funcs(plistFuncs).Parse(plistTemplate) if err != nil { return nil, fmt.Errorf("parse plist template: %w", err) } diff --git a/internal/launchd/plist_test.go b/internal/launchd/plist_test.go index ee8ff4b..e4cbad1 100644 --- a/internal/launchd/plist_test.go +++ b/internal/launchd/plist_test.go @@ -65,3 +65,32 @@ func TestGeneratePlist_SpecialChars(t *testing.T) { t.Error("plist should preserve paths with spaces") } } + +func 
TestGeneratePlist_EscapesXMLMetacharacters(t *testing.T) { + cfg := ServiceConfig{ + Label: "com.ghr.injected", + BinaryPath: `/tmp/xInjectedKeyyes`, + ConfigPath: "/config/&yaml", + LogDir: "/tmp/log\"dir", + StateDir: "/tmp/state", + } + + data, err := generatePlist(&cfg) + if err != nil { + t.Fatalf("generatePlist() error = %v", err) + } + out := string(data) + + if strings.Contains(out, "InjectedKey") { + t.Errorf("plist must escape XML payload, got: %s", out) + } + for _, needle := range []string{ + "</string>", + "<key>InjectedKey</key>", + "/config/<test>&yaml", + } { + if !strings.Contains(out, needle) { + t.Errorf("expected escaped %q in plist, got: %s", needle, out) + } + } +} diff --git a/internal/launchd/service.go b/internal/launchd/service.go index 0bf494b..e8bf34c 100644 --- a/internal/launchd/service.go +++ b/internal/launchd/service.go @@ -46,12 +46,12 @@ func Install(cfg *ServiceConfig) error { return fmt.Errorf("write plist %s: %w", plistPath, err) } - if err := launchctlLoad(plistPath); err != nil { - return fmt.Errorf("launchctl load: %w", err) + if err := launchctlBootstrap(plistPath); err != nil { + return fmt.Errorf("launchctl bootstrap: %w", err) } - if err := launchctlStart(cfg.Label); err != nil { - return fmt.Errorf("launchctl start: %w", err) + if err := launchctlKickstart(cfg.Label); err != nil { + return fmt.Errorf("launchctl kickstart: %w", err) } return nil @@ -60,8 +60,7 @@ func Install(cfg *ServiceConfig) error { func Uninstall(label string) error { plistPath := PlistPath(label) - _ = launchctlStop(label) - _ = launchctlUnload(plistPath) + _ = launchctlBootout(label) if err := os.Remove(plistPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("remove plist %s: %w", plistPath, err) diff --git a/internal/logging/logger_test.go b/internal/logging/logger_test.go index 900c355..1af5983 100644 --- a/internal/logging/logger_test.go +++ b/internal/logging/logger_test.go @@ -372,6 +372,36 @@ func TestRunnerLogger(t *testing.T) { } 
} +func TestRemoveRunnerLogs(t *testing.T) { + mgr := newTestManager(t) + + w, err := mgr.RunnerOutputFile("group-a", "runner-1") + if err != nil { + t.Fatalf("RunnerOutputFile() error = %v", err) + } + if _, err := w.Write([]byte("hello\n")); err != nil { + t.Fatalf("write runner output: %v", err) + } + + runnerDir := filepath.Join(mgr.rootDir, "groups", "group-a", "runners", "runner-1") + if _, statErr := os.Stat(runnerDir); statErr != nil { + t.Fatalf("runner log dir missing before cleanup: %v", statErr) + } + + if err := mgr.RemoveRunnerLogs("group-a", "runner-1"); err != nil { + t.Fatalf("RemoveRunnerLogs() error = %v", err) + } + + if _, statErr := os.Stat(runnerDir); !os.IsNotExist(statErr) { + t.Errorf("runner log dir still present after cleanup: %v", statErr) + } + + // Removing again must be a no-op (idempotency contract). + if err := mgr.RemoveRunnerLogs("group-a", "runner-1"); err != nil { + t.Errorf("RemoveRunnerLogs() second call error = %v", err) + } +} + // --------------------------------------------------------------------------- // TestDateRotation // --------------------------------------------------------------------------- diff --git a/internal/logging/manager.go b/internal/logging/manager.go index 1e1abb9..5b4ccb2 100644 --- a/internal/logging/manager.go +++ b/internal/logging/manager.go @@ -135,7 +135,31 @@ func (m *LogManager) RunnerOutputFile(group, runner string) (io.WriteCloser, err return nil, fmt.Errorf("logging: runner output file for %q/%q: %w", group, runner, err) } m.trackWriter(w) - return w, nil + return newTaggedWriter(w, group, runner), nil +} + +func (m *LogManager) RemoveRunnerLogs(group, runner string) error { + dir := filepath.Join(m.rootDir, "groups", group, "runners", runner) + + m.mu.Lock() + kept := m.writers[:0] + for _, w := range m.writers { + if w.dir == dir { + if closeErr := w.Close(); closeErr != nil { + m.mu.Unlock() + return fmt.Errorf("logging: close runner writer %q/%q: %w", group, runner, closeErr) + } + 
continue + } + kept = append(kept, w) + } + m.writers = kept + m.mu.Unlock() + + if err := os.RemoveAll(dir); err != nil { + return fmt.Errorf("logging: remove runner log dir %s: %w", dir, err) + } + return nil } func (m *LogManager) StartCleanupScheduler(ctx context.Context) error { diff --git a/internal/logging/tagged_writer.go b/internal/logging/tagged_writer.go new file mode 100644 index 0000000..a59d5dc --- /dev/null +++ b/internal/logging/tagged_writer.go @@ -0,0 +1,89 @@ +package logging + +import ( + "bytes" + "encoding/json" + "io" + "sync" + "time" +) + +// taggedWriter wraps an underlying io.WriteCloser and emits each line of input +// as a JSON object enriched with metadata (group, runner, source). Partial +// lines are buffered until a newline arrives, so structured tools can rely on +// one JSON object per output line. +type taggedWriter struct { + mu sync.Mutex + inner io.WriteCloser + buf bytes.Buffer + group string + runner string + now func() time.Time +} + +func newTaggedWriter(inner io.WriteCloser, group, runner string) *taggedWriter { + return &taggedWriter{ + inner: inner, + group: group, + runner: runner, + now: time.Now, + } +} + +type taggedLine struct { + Time string `json:"time"` + Source string `json:"source"` + Group string `json:"group"` + Runner string `json:"runner"` + Line string `json:"line"` +} + +func (w *taggedWriter) Write(p []byte) (int, error) { + w.mu.Lock() + defer w.mu.Unlock() + w.buf.Write(p) + + for { + idx := bytes.IndexByte(w.buf.Bytes(), '\n') + if idx < 0 { + break + } + line := string(w.buf.Bytes()[:idx]) + w.buf.Next(idx + 1) + if err := w.emit(line); err != nil { + return 0, err + } + } + + return len(p), nil +} + +func (w *taggedWriter) emit(line string) error { + rec := taggedLine{ + Time: w.now().UTC().Format(time.RFC3339Nano), + Source: "runner", + Group: w.group, + Runner: w.runner, + Line: line, + } + encoded, err := json.Marshal(rec) + if err != nil { + return err + } + encoded = append(encoded, '\n') + _, 
err = w.inner.Write(encoded) + return err +} + +func (w *taggedWriter) Close() error { + w.mu.Lock() + defer w.mu.Unlock() + if w.buf.Len() > 0 { + if err := w.emit(w.buf.String()); err != nil { + w.inner.Close() + return err + } + w.buf.Reset() + } + return w.inner.Close() +} diff --git a/internal/logging/tagged_writer_test.go b/internal/logging/tagged_writer_test.go new file mode 100644 index 0000000..4de7af3 --- /dev/null +++ b/internal/logging/tagged_writer_test.go @@ -0,0 +1,107 @@ +package logging + +import ( + "bytes" + "encoding/json" + "io" + "strings" + "testing" +) + +type closingBuffer struct { + bytes.Buffer + closed bool +} + +func (c *closingBuffer) Close() error { + c.closed = true + return nil +} + +func decodeLines(t *testing.T, r io.Reader) []taggedLine { + t.Helper() + var out []taggedLine + scanner := bytes.NewReader(make([]byte, 0)) + _ = scanner + raw, err := io.ReadAll(r) + if err != nil { + t.Fatalf("read: %v", err) + } + for _, line := range strings.Split(strings.TrimSpace(string(raw)), "\n") { + if line == "" { + continue + } + var tl taggedLine + if err := json.Unmarshal([]byte(line), &tl); err != nil { + t.Fatalf("decode %q: %v", line, err) + } + out = append(out, tl) + } + return out +} + +func TestTaggedWriter_EmitsOneJSONLinePerNewline(t *testing.T) { + buf := &closingBuffer{} + w := newTaggedWriter(buf, "g1", "r1") + + if _, err := w.Write([]byte("hello\nworld\n")); err != nil { + t.Fatalf("write: %v", err) + } + + lines := decodeLines(t, bytes.NewReader(buf.Bytes())) + if len(lines) != 2 { + t.Fatalf("got %d lines, want 2", len(lines)) + } + if lines[0].Line != "hello" || lines[1].Line != "world" { + t.Errorf("lines = %+v", lines) + } + for _, l := range lines { + if l.Group != "g1" || l.Runner != "r1" || l.Source != "runner" { + t.Errorf("missing tags on %+v", l) + } + if l.Time == "" { + t.Errorf("missing timestamp on %+v", l) + } + } +} + +func TestTaggedWriter_BuffersPartialLines(t *testing.T) { + buf := &closingBuffer{} + w := 
newTaggedWriter(buf, "g", "r") + + if _, err := w.Write([]byte("partial ")); err != nil { + t.Fatalf("write: %v", err) + } + if buf.Len() != 0 { + t.Fatalf("partial line should be buffered, wrote: %s", buf.String()) + } + + if _, err := w.Write([]byte("line\n")); err != nil { + t.Fatalf("write: %v", err) + } + + lines := decodeLines(t, bytes.NewReader(buf.Bytes())) + if len(lines) != 1 || lines[0].Line != "partial line" { + t.Errorf("got %+v, want 'partial line'", lines) + } +} + +func TestTaggedWriter_CloseFlushesTrailingPartial(t *testing.T) { + buf := &closingBuffer{} + w := newTaggedWriter(buf, "g", "r") + + if _, err := w.Write([]byte("orphan no newline")); err != nil { + t.Fatalf("write: %v", err) + } + if err := w.Close(); err != nil { + t.Fatalf("close: %v", err) + } + + lines := decodeLines(t, bytes.NewReader(buf.Bytes())) + if len(lines) != 1 || lines[0].Line != "orphan no newline" { + t.Errorf("got %+v, want flushed orphan", lines) + } + if !buf.closed { + t.Errorf("inner writer was not closed") + } +} diff --git a/internal/model/group.go b/internal/model/group.go index 8e7e366..5dfa10e 100644 --- a/internal/model/group.go +++ b/internal/model/group.go @@ -11,7 +11,6 @@ type Group struct { } type RunnerInstance struct { - ID string Name string Group string WorkDir string @@ -22,7 +21,7 @@ type RunnerSnapshot struct { Name string `json:"name"` Group string `json:"group"` State string `json:"state"` - PID int `json:"pid"` + PID int32 `json:"pid"` StartedAt time.Time `json:"started_at"` JobName string `json:"job_name"` JobID string `json:"job_id"` diff --git a/internal/notification/discord.go b/internal/notification/discord.go index 38b157e..c6243ac 100644 --- a/internal/notification/discord.go +++ b/internal/notification/discord.go @@ -16,6 +16,10 @@ import ( const discordMinInterval = 400 * time.Millisecond +// discordServerErrorBackoff is the wait between a 5xx response and the retry. 
+// Overridable from tests via the test-only helper in discord_testhelpers_test.go. +var discordServerErrorBackoff = 2 * time.Second + type DiscordConfig struct { WebhookURL string Username string @@ -54,35 +58,51 @@ func (d *DiscordProvider) Send(ctx context.Context, event *model.Event) error { return fmt.Errorf("marshal discord payload: %w", err) } - resp, err := d.doPost(ctx, body) + resp, err := d.postWithRateLimitRetry(ctx, body) if err != nil { return err } defer resp.Body.Close() - if resp.StatusCode == http.StatusTooManyRequests { + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return fmt.Errorf("discord webhook returned status %d", resp.StatusCode) + } + + return nil +} + +func (d *DiscordProvider) postWithRateLimitRetry(ctx context.Context, body []byte) (*http.Response, error) { + resp, err := d.doPost(ctx, body) + if err != nil { + return nil, err + } + + switch { + case resp.StatusCode == http.StatusTooManyRequests: retryAfter := parseRetryAfter(resp.Header.Get("Retry-After")) _, _ = io.Copy(io.Discard, resp.Body) resp.Body.Close() select { case <-ctx.Done(): - return fmt.Errorf("discord rate limited, context canceled: %w", ctx.Err()) + return nil, fmt.Errorf("discord rate limited, context canceled: %w", ctx.Err()) case <-time.After(retryAfter): } + return d.doPost(ctx, body) - resp, err = d.doPost(ctx, body) - if err != nil { - return err - } - defer resp.Body.Close() - } + case resp.StatusCode >= 500: + _, _ = io.Copy(io.Discard, resp.Body) + resp.Body.Close() - if resp.StatusCode < 200 || resp.StatusCode >= 300 { - return fmt.Errorf("discord webhook returned status %d", resp.StatusCode) + select { + case <-ctx.Done(): + return nil, fmt.Errorf("discord 5xx, context canceled: %w", ctx.Err()) + case <-time.After(discordServerErrorBackoff): + } + return d.doPost(ctx, body) } - return nil + return resp, nil } func (d *DiscordProvider) throttle() { diff --git a/internal/notification/discord_test.go b/internal/notification/discord_test.go index 
4572e6b..4bc747b 100644 --- a/internal/notification/discord_test.go +++ b/internal/notification/discord_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "net/http" "net/http/httptest" + "sync/atomic" "testing" "time" @@ -243,3 +244,34 @@ func TestColorForLevel(t *testing.T) { }) } } + +func TestDiscordProvider_Send_RetryOn5xx(t *testing.T) { + var calls int32 + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + n := atomic.AddInt32(&calls, 1) + if n == 1 { + w.WriteHeader(http.StatusBadGateway) + return + } + w.WriteHeader(http.StatusNoContent) + })) + defer srv.Close() + + // Shorten the 5xx backoff to keep the test snappy. + oldBackoff := discordTestBackoffOverride() + t.Cleanup(oldBackoff) + + d := NewDiscord(&DiscordConfig{WebhookURL: srv.URL}) + + err := d.Send(context.Background(), &model.Event{ + Type: "test.event", + Level: model.LevelInfo, + Timestamp: time.Now(), + }) + if err != nil { + t.Fatalf("Send() error = %v", err) + } + if got := atomic.LoadInt32(&calls); got != 2 { + t.Errorf("expected 2 webhook calls (1 5xx + 1 retry), got %d", got) + } +} diff --git a/internal/notification/discord_testhelpers_test.go b/internal/notification/discord_testhelpers_test.go new file mode 100644 index 0000000..fde6ae1 --- /dev/null +++ b/internal/notification/discord_testhelpers_test.go @@ -0,0 +1,11 @@ +package notification + +import "time" + +// discordTestBackoffOverride shortens the 5xx retry backoff for tests and +// returns a function that restores the original value. 
+func discordTestBackoffOverride() func() { + prev := discordServerErrorBackoff + discordServerErrorBackoff = 10 * time.Millisecond + return func() { discordServerErrorBackoff = prev } +} diff --git a/internal/runner/binary.go b/internal/runner/binary.go index 983b304..efdd7b3 100644 --- a/internal/runner/binary.go +++ b/internal/runner/binary.go @@ -9,20 +9,26 @@ import ( "os" "path/filepath" "runtime" + "sort" "strings" + "sync" + "time" ) type BinaryManager struct { cacheDir string logger *slog.Logger httpClient *http.Client + locks sync.Map } func NewBinaryManager(cacheDir string, logger *slog.Logger) *BinaryManager { return &BinaryManager{ - cacheDir: cacheDir, - logger: logger, - httpClient: &http.Client{}, + cacheDir: cacheDir, + logger: logger, + httpClient: &http.Client{ + Timeout: 10 * time.Minute, + }, } } @@ -37,14 +43,25 @@ func (m *BinaryManager) EnsureBits(ctx context.Context, version string) (string, m.logger.InfoContext(ctx, "resolved latest runner version", "version", resolved) } + mu := m.lockFor(resolved) + mu.Lock() + defer mu.Unlock() + destDir := filepath.Join(m.cacheDir, resolved) - runShPath := filepath.Join(destDir, "run.sh") + marker := filepath.Join(destDir, ".complete") - if _, err := os.Stat(runShPath); err == nil { + if _, err := os.Stat(marker); err == nil { m.logger.DebugContext(ctx, "runner binary cached", "version", resolved, "path", destDir) return destDir, nil } + if _, err := os.Stat(destDir); err == nil { + m.logger.WarnContext(ctx, "removing incomplete runner cache", "version", resolved, "path", destDir) + if rmErr := os.RemoveAll(destDir); rmErr != nil { + return "", fmt.Errorf("clean stale cache %s: %w", destDir, rmErr) + } + } + m.logger.InfoContext(ctx, "downloading runner binary", "version", resolved) if err := os.MkdirAll(destDir, 0o755); err != nil { @@ -52,17 +69,78 @@ func (m *BinaryManager) EnsureBits(ctx context.Context, version string) (string, } if err := downloadAndExtract(ctx, m.httpClient, resolved, destDir); 
err != nil { - rmErr := os.RemoveAll(destDir) - if rmErr != nil { + if rmErr := os.RemoveAll(destDir); rmErr != nil { m.logger.WarnContext(ctx, "failed to clean partial download", "path", destDir, "error", rmErr) } return "", fmt.Errorf("download runner %s: %w", resolved, err) } + if err := os.WriteFile(marker, nil, 0o644); err != nil { + if rmErr := os.RemoveAll(destDir); rmErr != nil { + m.logger.WarnContext(ctx, "failed to clean cache after marker write", "path", destDir, "error", rmErr) + } + return "", fmt.Errorf("write completion marker %s: %w", marker, err) + } + m.logger.InfoContext(ctx, "runner binary ready", "version", resolved, "path", destDir) + + if err := m.gcOldVersions(ctx, resolved); err != nil { + m.logger.WarnContext(ctx, "cache GC failed", "error", err) + } + return destDir, nil } +const cacheKeepVersions = 3 + +func (m *BinaryManager) gcOldVersions(ctx context.Context, keep string) error { + entries, err := os.ReadDir(m.cacheDir) + if err != nil { + return fmt.Errorf("read cache dir: %w", err) + } + + type cacheEntry struct { + name string + modTime time.Time + } + + completed := make([]cacheEntry, 0, len(entries)) + for _, e := range entries { + if !e.IsDir() || e.Name() == keep { + continue + } + marker := filepath.Join(m.cacheDir, e.Name(), ".complete") + info, statErr := os.Stat(marker) + if statErr != nil { + continue + } + completed = append(completed, cacheEntry{name: e.Name(), modTime: info.ModTime()}) + } + + if len(completed) <= cacheKeepVersions-1 { + return nil + } + + sort.Slice(completed, func(i, j int) bool { + return completed[i].modTime.After(completed[j].modTime) + }) + + for _, victim := range completed[cacheKeepVersions-1:] { + path := filepath.Join(m.cacheDir, victim.name) + if rmErr := os.RemoveAll(path); rmErr != nil { + m.logger.WarnContext(ctx, "failed to remove old runner cache", "path", path, "error", rmErr) + continue + } + m.logger.InfoContext(ctx, "removed old runner cache", "version", victim.name) + } + return nil 
+} + +func (m *BinaryManager) lockFor(version string) *sync.Mutex { + v, _ := m.locks.LoadOrStore(version, &sync.Mutex{}) + return v.(*sync.Mutex) +} + func (m *BinaryManager) resolveLatestVersion(ctx context.Context) (string, error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/repos/actions/runner/releases/latest", http.NoBody) if err != nil { diff --git a/internal/runner/binary_test.go b/internal/runner/binary_test.go index c651851..edca195 100644 --- a/internal/runner/binary_test.go +++ b/internal/runner/binary_test.go @@ -12,6 +12,7 @@ import ( "path/filepath" "runtime" "testing" + "time" ) func silentLogger() *slog.Logger { @@ -79,6 +80,9 @@ func TestEnsureBits_Cached(t *testing.T) { if err := os.WriteFile(runSh, []byte("#!/bin/bash\n"), 0o755); err != nil { t.Fatalf("write run.sh: %v", err) } + if err := os.WriteFile(filepath.Join(versionDir, ".complete"), nil, 0o644); err != nil { + t.Fatalf("write completion marker: %v", err) + } bm := NewBinaryManager(cacheDir, silentLogger()) @@ -92,6 +96,31 @@ func TestEnsureBits_Cached(t *testing.T) { } } +func TestEnsureBits_IncompleteCacheIsCleaned(t *testing.T) { + cacheDir := t.TempDir() + version := "2.320.0" + + versionDir := filepath.Join(cacheDir, version) + if err := os.MkdirAll(versionDir, 0o755); err != nil { + t.Fatalf("create version dir: %v", err) + } + // run.sh exists but no .complete marker → previous download was interrupted. + if err := os.WriteFile(filepath.Join(versionDir, "run.sh"), []byte("#!/bin/bash\n"), 0o755); err != nil { + t.Fatalf("write run.sh: %v", err) + } + + bm := NewBinaryManager(cacheDir, silentLogger()) + // Give the client a 100ms timeout so the re-download is expected to fail after the stale dir is removed. 
+ bm.httpClient = &http.Client{Timeout: 100 * time.Millisecond} + _, err := bm.EnsureBits(context.Background(), version) + if err == nil { + t.Fatal("expected EnsureBits to attempt re-download, got nil error") + } + if _, statErr := os.Stat(versionDir); !os.IsNotExist(statErr) { + t.Errorf("incomplete cache dir should have been removed, stat err = %v", statErr) + } +} + func TestEnsureBits_Download(t *testing.T) { tarGzPath := createFakeTarGz(t) tarGzData, err := os.ReadFile(tarGzPath) @@ -128,6 +157,9 @@ func TestEnsureBits_Download(t *testing.T) { if err := extractTarGz(resp.Body, versionDir); err != nil { t.Fatalf("extract tar.gz: %v", err) } + if err := os.WriteFile(filepath.Join(versionDir, ".complete"), nil, 0o644); err != nil { + t.Fatalf("write completion marker: %v", err) + } runSh := filepath.Join(versionDir, "run.sh") if _, statErr := os.Stat(runSh); statErr != nil { @@ -149,6 +181,76 @@ func TestEnsureBits_Download(t *testing.T) { } } +func TestGCOldVersions_KeepsRecent(t *testing.T) { + cacheDir := t.TempDir() + bm := NewBinaryManager(cacheDir, silentLogger()) + + versions := []string{"2.310.0", "2.311.0", "2.312.0", "2.313.0", "2.314.0"} + for i, v := range versions { + dir := filepath.Join(cacheDir, v) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir %s: %v", dir, err) + } + marker := filepath.Join(dir, ".complete") + if err := os.WriteFile(marker, nil, 0o644); err != nil { + t.Fatalf("write marker: %v", err) + } + // Stagger mtimes so the sort by modTime is deterministic. + mtime := time.Now().Add(time.Duration(i) * time.Minute) + if err := os.Chtimes(marker, mtime, mtime); err != nil { + t.Fatalf("chtimes: %v", err) + } + } + + if err := bm.gcOldVersions(context.Background(), "2.314.0"); err != nil { + t.Fatalf("gcOldVersions: %v", err) + } + + // 2.314.0 is the active version (excluded from GC); the GC keeps the 2 + // most recent of the remaining versions (cacheKeepVersions-1). 
+ kept := map[string]bool{} + entries, _ := os.ReadDir(cacheDir) + for _, e := range entries { + kept[e.Name()] = true + } + if !kept["2.314.0"] { + t.Errorf("active version 2.314.0 was removed") + } + if !kept["2.313.0"] || !kept["2.312.0"] { + t.Errorf("recent versions were removed: %v", kept) + } + if kept["2.310.0"] || kept["2.311.0"] { + t.Errorf("old versions still present: %v", kept) + } +} + +func TestGCOldVersions_IgnoresIncompleteCaches(t *testing.T) { + cacheDir := t.TempDir() + bm := NewBinaryManager(cacheDir, silentLogger()) + + // "old" has a marker, "incomplete" doesn't. + for _, v := range []string{"2.310.0", "incomplete"} { + dir := filepath.Join(cacheDir, v) + if err := os.MkdirAll(dir, 0o755); err != nil { + t.Fatalf("mkdir: %v", err) + } + } + if err := os.WriteFile(filepath.Join(cacheDir, "2.310.0", ".complete"), nil, 0o644); err != nil { + t.Fatalf("write marker: %v", err) + } + + if err := bm.gcOldVersions(context.Background(), "2.314.0"); err != nil { + t.Fatalf("gcOldVersions: %v", err) + } + + // "incomplete" should remain because gc ignores caches without the marker. 
+ for _, want := range []string{"2.310.0", "incomplete"} { + if _, err := os.Stat(filepath.Join(cacheDir, want)); err != nil { + t.Errorf("expected %s to remain: %v", want, err) + } + } +} + func TestRunnerArch(t *testing.T) { got := runnerArch() switch runtime.GOARCH { diff --git a/internal/runner/cleanup.go b/internal/runner/cleanup.go index 8135a5a..a4944b9 100644 --- a/internal/runner/cleanup.go +++ b/internal/runner/cleanup.go @@ -56,39 +56,37 @@ func (m *ProcessManager) cleanupStaleRunner(ctx context.Context, group, runner s pidBytes, err := os.ReadFile(pidFile) if err != nil { m.logger.DebugContext(ctx, "no PID file found, removing stale workdir", "dir", runnerDir) - removeErr := os.RemoveAll(runnerDir) - if removeErr != nil { - m.logger.WarnContext(ctx, "failed to remove stale workdir", "dir", runnerDir, "error", removeErr) - } + m.removeStaleDir(ctx, runnerDir) return } pid, err := strconv.Atoi(strings.TrimSpace(string(pidBytes))) if err != nil { m.logger.WarnContext(ctx, "invalid PID file content, removing workdir", "dir", runnerDir, "error", err) - removeErr := os.RemoveAll(runnerDir) - if removeErr != nil { - m.logger.WarnContext(ctx, "failed to remove stale workdir", "dir", runnerDir, "error", removeErr) - } + m.removeStaleDir(ctx, runnerDir) return } if isProcessAlive(pid) { m.logger.WarnContext(ctx, "killing stale runner process", "pid", pid, "runner", runner, "group", group) - killErr := syscall.Kill(pid, syscall.SIGKILL) - if killErr != nil { + if killErr := syscall.Kill(pid, syscall.SIGKILL); killErr != nil { m.logger.WarnContext(ctx, "failed to kill stale process", "pid", pid, "error", killErr) } } - removeErr := os.RemoveAll(runnerDir) - if removeErr != nil { - m.logger.WarnContext(ctx, "failed to remove stale workdir", "dir", runnerDir, "error", removeErr) - } else { + if m.removeStaleDir(ctx, runnerDir) { m.logger.InfoContext(ctx, "cleaned up stale runner", "runner", runner, "group", group, "pid", pid) } } +func (m *ProcessManager) 
removeStaleDir(ctx context.Context, dir string) bool { + if err := os.RemoveAll(dir); err != nil { + m.logger.WarnContext(ctx, "failed to remove stale workdir", "dir", dir, "error", err) + return false + } + return true +} + func (m *ProcessManager) KillOrphanRunners(ctx context.Context) { out, err := exec.CommandContext(ctx, "pgrep", "-f", m.workdirBase).Output() if err != nil { @@ -99,9 +97,28 @@ func (m *ProcessManager) KillOrphanRunners(ctx context.Context) { if err != nil || pid <= 0 { continue } + if !m.processBelongsToGhr(ctx, pid) { + m.logger.DebugContext(ctx, "ignoring pgrep match not owned by ghr", "pid", pid) + continue + } m.logger.WarnContext(ctx, "killing orphan runner process", "pid", pid) - _ = syscall.Kill(pid, syscall.SIGKILL) + if killErr := syscall.Kill(pid, syscall.SIGKILL); killErr != nil { + m.logger.WarnContext(ctx, "failed to kill orphan runner", "pid", pid, "error", killErr) + } + } +} + +func (m *ProcessManager) processBelongsToGhr(ctx context.Context, pid int) bool { + out, err := exec.CommandContext(ctx, "ps", "-p", strconv.Itoa(pid), "-o", "command=").Output() + if err != nil { + return false + } + cmd := strings.TrimSpace(string(out)) + if cmd == "" { + return false } + prefix := strings.TrimRight(m.workdirBase, string(os.PathSeparator)) + string(os.PathSeparator) + return strings.Contains(cmd, prefix) && strings.Contains(cmd, "/run.sh") } func isProcessAlive(pid int) bool { diff --git a/internal/runner/copy.go b/internal/runner/copy.go index cef7538..35bf11d 100644 --- a/internal/runner/copy.go +++ b/internal/runner/copy.go @@ -29,6 +29,9 @@ func copyDir(src, dst string) error { if err != nil { return fmt.Errorf("read symlink %s: %w", path, err) } + if filepath.IsAbs(link) || !filepath.IsLocal(link) { + return fmt.Errorf("refusing to copy symlink %s -> %q (non-local target)", path, link) + } return os.Symlink(link, targetPath) } diff --git a/internal/runner/copy_test.go b/internal/runner/copy_test.go new file mode 100644 index 
0000000..65e31f3 --- /dev/null +++ b/internal/runner/copy_test.go @@ -0,0 +1,66 @@ +package runner + +import ( + "os" + "path/filepath" + "strings" + "testing" +) + +func TestCopyDir_RefusesAbsoluteSymlink(t *testing.T) { + src := t.TempDir() + dst := t.TempDir() + + if err := os.Symlink("/etc/passwd", filepath.Join(src, "evil-link")); err != nil { + t.Fatalf("setup symlink: %v", err) + } + + err := copyDir(src, dst) + if err == nil { + t.Fatal("expected error for absolute symlink, got nil") + } + if !strings.Contains(err.Error(), "non-local") { + t.Errorf("error = %v, want substring 'non-local'", err) + } +} + +func TestCopyDir_RefusesEscapeSymlink(t *testing.T) { + src := t.TempDir() + dst := t.TempDir() + + if err := os.Symlink("../../etc/passwd", filepath.Join(src, "escape-link")); err != nil { + t.Fatalf("setup symlink: %v", err) + } + + err := copyDir(src, dst) + if err == nil { + t.Fatal("expected error for escaping symlink, got nil") + } + if !strings.Contains(err.Error(), "non-local") { + t.Errorf("error = %v, want substring 'non-local'", err) + } +} + +func TestCopyDir_AllowsLocalSymlink(t *testing.T) { + src := t.TempDir() + dst := t.TempDir() + + if err := os.WriteFile(filepath.Join(src, "target.sh"), []byte("ok"), 0o755); err != nil { + t.Fatalf("write target: %v", err) + } + if err := os.Symlink("target.sh", filepath.Join(src, "link.sh")); err != nil { + t.Fatalf("create symlink: %v", err) + } + + if err := copyDir(src, dst); err != nil { + t.Fatalf("copyDir: %v", err) + } + + link, err := os.Readlink(filepath.Join(dst, "link.sh")) + if err != nil { + t.Fatalf("readlink dst: %v", err) + } + if link != "target.sh" { + t.Errorf("link = %q, want %q", link, "target.sh") + } +} diff --git a/internal/runner/download.go b/internal/runner/download.go index 0e66ba9..28c35bd 100644 --- a/internal/runner/download.go +++ b/internal/runner/download.go @@ -4,6 +4,8 @@ import ( "archive/tar" "compress/gzip" "context" + "crypto/sha256" + "encoding/hex" "errors" "fmt" 
"io" @@ -18,6 +20,11 @@ const downloadURLTemplate = "https://github.com/actions/runner/releases/download func downloadAndExtract(ctx context.Context, client *http.Client, version, destDir string) error { url := fmt.Sprintf(downloadURLTemplate, version, runnerArch(), version) + expected, err := fetchExpectedSHA256(ctx, client, url+".sha256") + if err != nil { + return fmt.Errorf("fetch checksum for %s: %w", url, err) + } + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) if err != nil { return fmt.Errorf("create download request: %w", err) @@ -33,7 +40,53 @@ func downloadAndExtract(ctx context.Context, client *http.Client, version, destD return fmt.Errorf("download returned HTTP %d for %s", resp.StatusCode, url) } - return extractTarGz(resp.Body, destDir) + hasher := sha256.New() + tee := io.TeeReader(resp.Body, hasher) + + if err := extractTarGz(tee, destDir); err != nil { + return err + } + if _, err := io.Copy(io.Discard, tee); err != nil { + return fmt.Errorf("drain tarball: %w", err) + } + + got := hex.EncodeToString(hasher.Sum(nil)) + if !strings.EqualFold(got, expected) { + return fmt.Errorf("checksum mismatch for %s: got %s, want %s", url, got, expected) + } + return nil +} + +func fetchExpectedSHA256(ctx context.Context, client *http.Client, url string) (string, error) { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, http.NoBody) + if err != nil { + return "", fmt.Errorf("create checksum request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("fetch checksum: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("checksum URL returned HTTP %d for %s", resp.StatusCode, url) + } + + raw, err := io.ReadAll(io.LimitReader(resp.Body, 1024)) + if err != nil { + return "", fmt.Errorf("read checksum response: %w", err) + } + + fields := strings.Fields(string(raw)) + if len(fields) == 0 { + return "", fmt.Errorf("checksum file 
%s is empty", url) + } + hash := strings.ToLower(fields[0]) + if len(hash) != 64 { + return "", fmt.Errorf("checksum %q from %s is not a sha-256 digest", hash, url) + } + return hash, nil } func extractTarGz(r io.Reader, destDir string) error { @@ -68,12 +121,8 @@ func extractTarGz(r io.Reader, destDir string) error { return err } case tar.TypeSymlink: - linkTarget, linkErr := sanitizeTarPath(destDir, header.Linkname) - if linkErr != nil { - linkTarget = header.Linkname - } - if err := os.Symlink(linkTarget, target); err != nil { - return fmt.Errorf("create symlink %s: %w", target, err) + if err := writeSymlink(target, header.Name, header.Linkname); err != nil { + return err } } } @@ -107,3 +156,13 @@ func sanitizeTarPath(destDir, name string) (string, error) { } return target, nil } + +func writeSymlink(target, name, linkname string) error { + if filepath.IsAbs(linkname) || !filepath.IsLocal(linkname) { + return fmt.Errorf("tar symlink %q -> %q is not local", name, linkname) + } + if err := os.Symlink(linkname, target); err != nil { + return fmt.Errorf("create symlink %s: %w", target, err) + } + return nil +} diff --git a/internal/runner/download_test.go b/internal/runner/download_test.go new file mode 100644 index 0000000..19ae4cb --- /dev/null +++ b/internal/runner/download_test.go @@ -0,0 +1,190 @@ +package runner + +import ( + "archive/tar" + "bytes" + "compress/gzip" + "context" + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" +) + +type tarEntry struct { + name string + typeflag byte + body string + linkname string + mode int64 +} + +func buildTarGz(t *testing.T, entries []tarEntry) []byte { + t.Helper() + var buf bytes.Buffer + gz := gzip.NewWriter(&buf) + tw := tar.NewWriter(gz) + for _, e := range entries { + hdr := &tar.Header{ + Name: e.name, + Mode: e.mode, + Typeflag: e.typeflag, + Size: int64(len(e.body)), + Linkname: e.linkname, + } + if e.typeflag == 0 { + hdr.Typeflag = tar.TypeReg + } + if err := 
tw.WriteHeader(hdr); err != nil { + t.Fatalf("write header: %v", err) + } + if e.typeflag == tar.TypeReg || e.typeflag == 0 { + if _, err := tw.Write([]byte(e.body)); err != nil { + t.Fatalf("write body: %v", err) + } + } + } + if err := tw.Close(); err != nil { + t.Fatalf("close tar: %v", err) + } + if err := gz.Close(); err != nil { + t.Fatalf("close gz: %v", err) + } + return buf.Bytes() +} + +func TestExtractTarGz_BlocksAbsoluteSymlink(t *testing.T) { + dest := t.TempDir() + data := buildTarGz(t, []tarEntry{ + {name: "evil-link", typeflag: tar.TypeSymlink, linkname: "/etc/passwd"}, + }) + + err := extractTarGz(bytes.NewReader(data), dest) + if err == nil { + t.Fatal("expected error for absolute symlink, got nil") + } + if !strings.Contains(err.Error(), "not local") { + t.Errorf("error = %v, want substring 'not local'", err) + } + if _, statErr := os.Lstat(filepath.Join(dest, "evil-link")); statErr == nil { + t.Error("symlink should not have been created") + } +} + +func TestExtractTarGz_BlocksRelativeEscapeSymlink(t *testing.T) { + dest := t.TempDir() + data := buildTarGz(t, []tarEntry{ + {name: "escape-link", typeflag: tar.TypeSymlink, linkname: "../../etc/passwd"}, + }) + + err := extractTarGz(bytes.NewReader(data), dest) + if err == nil { + t.Fatal("expected error for escaping symlink, got nil") + } + if !strings.Contains(err.Error(), "not local") { + t.Errorf("error = %v, want substring 'not local'", err) + } +} + +func TestExtractTarGz_BlocksPathTraversal(t *testing.T) { + dest := t.TempDir() + data := buildTarGz(t, []tarEntry{ + {name: "../escape.sh", typeflag: tar.TypeReg, body: "evil", mode: 0o644}, + }) + + err := extractTarGz(bytes.NewReader(data), dest) + if err == nil { + t.Fatal("expected error for path traversal, got nil") + } + if !strings.Contains(err.Error(), "escapes destination") { + t.Errorf("error = %v, want substring 'escapes destination'", err) + } +} + +func TestFetchExpectedSHA256(t *testing.T) { + const valid = 
"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef" + + cases := []struct { + name string + statusCode int + body string + wantHash string + wantErr string + }{ + { + name: "hash with filename", + statusCode: http.StatusOK, + body: valid + " actions-runner-osx-arm64-2.331.0.tar.gz\n", + wantHash: valid, + }, + { + name: "hash only", + statusCode: http.StatusOK, + body: valid + "\n", + wantHash: valid, + }, + { + name: "empty body", + statusCode: http.StatusOK, + body: "", + wantErr: "empty", + }, + { + name: "non-sha digest", + statusCode: http.StatusOK, + body: "shortdigest\n", + wantErr: "not a sha-256", + }, + { + name: "404", + statusCode: http.StatusNotFound, + body: "", + wantErr: "HTTP 404", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(tc.statusCode) + _, _ = w.Write([]byte(tc.body)) + })) + defer srv.Close() + + got, err := fetchExpectedSHA256(context.Background(), srv.Client(), srv.URL) + if tc.wantErr != "" { + if err == nil || !strings.Contains(err.Error(), tc.wantErr) { + t.Fatalf("err = %v, want substring %q", err, tc.wantErr) + } + return + } + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got != tc.wantHash { + t.Errorf("hash = %q, want %q", got, tc.wantHash) + } + }) + } +} + +func TestExtractTarGz_AllowsLocalSymlink(t *testing.T) { + dest := t.TempDir() + data := buildTarGz(t, []tarEntry{ + {name: "target.sh", typeflag: tar.TypeReg, body: "ok", mode: 0o755}, + {name: "link.sh", typeflag: tar.TypeSymlink, linkname: "target.sh"}, + }) + + if err := extractTarGz(bytes.NewReader(data), dest); err != nil { + t.Fatalf("extractTarGz: %v", err) + } + link, err := os.Readlink(filepath.Join(dest, "link.sh")) + if err != nil { + t.Fatalf("readlink: %v", err) + } + if link != "target.sh" { + t.Errorf("link = %q, want %q", link, "target.sh") + } +} diff --git a/internal/runner/process.go 
b/internal/runner/process.go index 60b6451..2e6372b 100644 --- a/internal/runner/process.go +++ b/internal/runner/process.go @@ -18,13 +18,15 @@ import ( const stopGracePeriod = 10 * time.Second +const killTimeout = 5 * time.Second + type Process struct { Name string Group string WorkDir string - PID int + PID int32 StartedAt time.Time - Cmd *exec.Cmd + cmd *exec.Cmd } type ProcessManager struct { @@ -66,31 +68,32 @@ func (m *ProcessManager) Start(ctx context.Context, instance *model.RunnerInstan return nil, fmt.Errorf("start runner %s: %w", instance.Name, err) } + pid := int32(cmd.Process.Pid) pidFile := filepath.Join(workdir, ".ghr-pid") - if err := os.WriteFile(pidFile, []byte(strconv.Itoa(cmd.Process.Pid)), 0o644); err != nil { + if err := os.WriteFile(pidFile, []byte(strconv.FormatInt(int64(pid), 10)), 0o644); err != nil { m.logger.WarnContext(ctx, "failed to write PID file", "path", pidFile, "error", err) } - m.logger.InfoContext(ctx, "runner started", "runner", instance.Name, "pid", cmd.Process.Pid) + m.logger.InfoContext(ctx, "runner started", "runner", instance.Name, "pid", pid) return &Process{ Name: instance.Name, Group: instance.Group, WorkDir: workdir, - PID: cmd.Process.Pid, + PID: pid, StartedAt: time.Now(), - Cmd: cmd, + cmd: cmd, }, nil } func (m *ProcessManager) Stop(ctx context.Context, proc *Process) error { - if proc.Cmd == nil || proc.Cmd.Process == nil { + if proc.cmd == nil || proc.cmd.Process == nil { return nil } m.logger.InfoContext(ctx, "stopping runner", "runner", proc.Name, "pid", proc.PID) - if err := proc.Cmd.Process.Signal(syscall.SIGTERM); err != nil { + if err := proc.cmd.Process.Signal(syscall.SIGTERM); err != nil { if isProcessFinished(err) { return nil } @@ -99,7 +102,7 @@ func (m *ProcessManager) Stop(ctx context.Context, proc *Process) error { done := make(chan error, 1) go func() { - done <- proc.Cmd.Wait() + done <- proc.cmd.Wait() }() select { @@ -110,10 +113,18 @@ func (m *ProcessManager) Stop(ctx context.Context, proc 
*Process) error { return err case <-time.After(stopGracePeriod): m.logger.WarnContext(ctx, "runner did not exit after SIGTERM, sending SIGKILL", "runner", proc.Name, "pid", proc.PID) - if err := proc.Cmd.Process.Kill(); err != nil { + if err := proc.cmd.Process.Kill(); err != nil { return fmt.Errorf("kill runner %s (pid %d): %w", proc.Name, proc.PID, err) } - return <-done + select { + case err := <-done: + if isExpectedExit(err) { + return nil + } + return err + case <-time.After(killTimeout): + return fmt.Errorf("runner %s (pid %d) did not exit after SIGKILL within %s", proc.Name, proc.PID, killTimeout) + } } } diff --git a/internal/runner/process_test.go b/internal/runner/process_test.go index dc457d1..8a883c0 100644 --- a/internal/runner/process_test.go +++ b/internal/runner/process_test.go @@ -25,7 +25,6 @@ func TestPrepare(t *testing.T) { pm := NewProcessManager(workdirBase, silentLogger()) instance := model.RunnerInstance{ - ID: "abc123", Name: "test-group-abc123", Group: "test-group", } diff --git a/internal/state/paths.go b/internal/state/paths.go new file mode 100644 index 0000000..e5b3a68 --- /dev/null +++ b/internal/state/paths.go @@ -0,0 +1,33 @@ +package state + +import "path/filepath" + +const ( + pidFileName = "daemon.pid" + stateFileName = "daemon.state.json" + socketFileName = "ghr.sock" +) + +type Paths struct { + Dir string +} + +func New(dir string) Paths { + return Paths{Dir: dir} +} + +func (p Paths) PIDFile() string { + return filepath.Join(p.Dir, pidFileName) +} + +func (p Paths) StateFile() string { + return filepath.Join(p.Dir, stateFileName) +} + +func (p Paths) Socket() string { + return filepath.Join(p.Dir, socketFileName) +} + +func (p Paths) All() []string { + return []string{p.PIDFile(), p.StateFile(), p.Socket()} +} diff --git a/internal/state/paths_test.go b/internal/state/paths_test.go new file mode 100644 index 0000000..51fb5b7 --- /dev/null +++ b/internal/state/paths_test.go @@ -0,0 +1,25 @@ +package state + +import ( + 
"path/filepath" + "testing" +) + +func TestPaths(t *testing.T) { + p := New("/var/lib/ghr/state") + + if got, want := p.PIDFile(), filepath.Join("/var/lib/ghr/state", "daemon.pid"); got != want { + t.Errorf("PIDFile() = %q, want %q", got, want) + } + if got, want := p.StateFile(), filepath.Join("/var/lib/ghr/state", "daemon.state.json"); got != want { + t.Errorf("StateFile() = %q, want %q", got, want) + } + if got, want := p.Socket(), filepath.Join("/var/lib/ghr/state", "ghr.sock"); got != want { + t.Errorf("Socket() = %q, want %q", got, want) + } + + all := p.All() + if len(all) != 3 { + t.Fatalf("All() returned %d paths, want 3", len(all)) + } +}