Skip to content

Commit 9a6ba96

Browse files
authored
Refactor duplicated update-check state helpers and centralize repo/API normalization (#41985)
1 parent 04c80c2 commit 9a6ba96

14 files changed

Lines changed: 131 additions & 123 deletions

pkg/cli/codemod_dependabot_permissions.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package cli
22

33
import (
44
"fmt"
5-
"sort"
65
"strings"
76

87
"github.com/github/gh-aw/pkg/logger"
@@ -234,10 +233,9 @@ func findPermissionsInsertIndex(lines []string) int {
234233

235234
func sortedMissingPermissionKeys(missing map[workflow.PermissionScope]workflow.PermissionLevel) []string {
236235
keys := make([]string, 0, len(missing))
237-
for scope := range missing {
236+
for _, scope := range sliceutil.SortedKeys(missing) {
238237
keys = append(keys, string(scope))
239238
}
240-
sort.Strings(keys)
241239
return keys
242240
}
243241

pkg/cli/compile_update_check.go

Lines changed: 3 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -127,31 +127,7 @@ func shouldRunCompileUpdateCheck(noCheckUpdate bool) bool {
127127
}
128128

129129
lastCheckFile := getCompileUpdateCheckFilePath()
130-
if lastCheckFile == "" {
131-
compileUpdateCheckLog.Print("Could not determine compile update check file path")
132-
return false
133-
}
134-
135-
data, err := os.ReadFile(lastCheckFile)
136-
if err != nil {
137-
if !os.IsNotExist(err) {
138-
compileUpdateCheckLog.Printf("Error reading compile update check file: %v", err)
139-
}
140-
return true
141-
}
142-
143-
lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
144-
if err != nil {
145-
compileUpdateCheckLog.Printf("Error parsing compile update check time: %v", err)
146-
return true
147-
}
148-
149-
elapsed := time.Since(lastCheck)
150-
if elapsed < compileUpdateCheckInterval {
151-
compileUpdateCheckLog.Printf("Last compile update check was %v ago, skipping", elapsed)
152-
return false
153-
}
154-
return true
130+
return shouldRunUpdateCheckAtPath(lastCheckFile, compileUpdateCheckInterval, "compile update check", compileUpdateCheckLog)
155131
}
156132

157133
func waitForCompileUpdateNotification(ctx context.Context, results <-chan *compileUpdateNotification, timeout time.Duration) *compileUpdateNotification {
@@ -292,19 +268,11 @@ func getCompileUpdateCheckFilePath() string {
292268
}
293269

294270
func getCompileUpdateCheckFilePathImpl() string {
295-
return getLastCheckFilePathFor(compileUpdateCheckFileName)
271+
return getUpdateCheckFilePathFor(compileUpdateCheckFileName, compileUpdateCheckLog)
296272
}
297273

298274
func updateCompileUpdateCheckTime() {
299-
lastCheckFile := getCompileUpdateCheckFilePath()
300-
if lastCheckFile == "" {
301-
return
302-
}
303-
304-
timestamp := time.Now().Format(time.RFC3339)
305-
if err := os.WriteFile(lastCheckFile, []byte(timestamp), constants.FilePermSensitive); err != nil {
306-
compileUpdateCheckLog.Printf("Error writing compile update check time: %v", err)
307-
}
275+
writeUpdateCheckTime(getCompileUpdateCheckFilePath(), constants.FilePermSensitive, "compile update check", compileUpdateCheckLog)
308276
}
309277

310278
func isMinorVersionBehind(currentVersion string, latestVersion string) bool {

pkg/cli/logs_command.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ import (
1919
"github.com/github/gh-aw/pkg/console"
2020
"github.com/github/gh-aw/pkg/constants"
2121
"github.com/github/gh-aw/pkg/logger"
22+
"github.com/github/gh-aw/pkg/repoutil"
2223
"github.com/github/gh-aw/pkg/workflow"
2324
"github.com/spf13/cobra"
2425
)
@@ -448,7 +449,7 @@ Downloaded artifacts include (when using --artifacts all):
448449
// to the same repository that is checked out locally.
449450
func repoIsLocal(repo string) bool {
450451
// Strip optional HOST/ prefix (e.g. "github.com/owner/repo" → "owner/repo")
451-
ownerRepo, _ := normalizeRepoForAPI(repo)
452+
ownerRepo, _ := repoutil.NormalizeRepoForAPI(repo)
452453

453454
// Fast path: GITHUB_REPOSITORY is always the current repo in MCP server containers.
454455
if envRepo := os.Getenv("GITHUB_REPOSITORY"); envRepo != "" {

pkg/cli/outcome_eval.go

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/github/gh-aw/pkg/github"
1111
"github.com/github/gh-aw/pkg/intent"
1212
"github.com/github/gh-aw/pkg/logger"
13+
"github.com/github/gh-aw/pkg/repoutil"
1314
"github.com/github/gh-aw/pkg/workflow"
1415
)
1516

@@ -220,21 +221,9 @@ func ComputeOutcomeSummary(reports []OutcomeReport, mapping *github.ObjectiveMap
220221
return s
221222
}
222223

223-
// normalizeRepoForAPI splits a repo string of the form "[HOST/]owner/repo" into
224-
// the owner/repo portion and an optional host. Most callers pass plain "owner/repo",
225-
// but GHES and Proxima installs may supply "HOST/owner/repo".
226-
func normalizeRepoForAPI(repo string) (ownerRepo string, host string) {
227-
parts := strings.SplitN(repo, "/", 3)
228-
if len(parts) == 3 {
229-
// HOST/owner/repo
230-
return parts[1] + "/" + parts[2], parts[0]
231-
}
232-
return repo, ""
233-
}
234-
235224
// ghAPIGet calls the GitHub REST API via gh cli and returns the parsed JSON.
236225
func ghAPIGet(endpoint string, repo string) (map[string]any, error) {
237-
ownerRepo, host := normalizeRepoForAPI(repo)
226+
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
238227
outcomeEvalLog.Printf("gh api GET: repo=%s, endpoint=%s, host=%q", ownerRepo, endpoint, host)
239228
args := []string{"api", fmt.Sprintf("repos/%s/%s", ownerRepo, endpoint)}
240229
var output []byte
@@ -257,7 +246,7 @@ func ghAPIGet(endpoint string, repo string) (map[string]any, error) {
257246

258247
// ghAPIGetArray calls the GitHub REST API and returns a JSON array.
259248
func ghAPIGetArray(endpoint string, repo string) ([]map[string]any, error) {
260-
ownerRepo, host := normalizeRepoForAPI(repo)
249+
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
261250
args := []string{"api", fmt.Sprintf("repos/%s/%s", ownerRepo, endpoint)}
262251
var output []byte
263252
var err error
@@ -278,7 +267,7 @@ func ghAPIGetArray(endpoint string, repo string) ([]map[string]any, error) {
278267

279268
// ghAPIGraphQL calls the GitHub GraphQL API via gh cli and returns the parsed JSON.
280269
func ghAPIGraphQL(query string, repo string) (map[string]any, error) {
281-
ownerRepo, host := normalizeRepoForAPI(repo)
270+
ownerRepo, host := repoutil.NormalizeRepoForAPI(repo)
282271
args := []string{"api", "graphql", "-f", "query=" + query}
283272
var output []byte
284273
var err error
@@ -474,7 +463,7 @@ func resolvePullRequestIntent(report OutcomeReport, repo string, resolver intent
474463

475464
func loadPullRequestIntentData(report OutcomeReport, repo string) (intent.PullRequestData, error) {
476465
prNumber := report.ObjectNumber
477-
ownerRepo, _ := normalizeRepoForAPI(repo)
466+
ownerRepo, _ := repoutil.NormalizeRepoForAPI(repo)
478467
owner, name, found := strings.Cut(ownerRepo, "/")
479468
if !found || owner == "" || name == "" {
480469
return intent.PullRequestData{}, fmt.Errorf("invalid repo for root tracing: %s", repo)

pkg/cli/outcome_eval_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"testing"
1313

1414
"github.com/github/gh-aw/pkg/github"
15+
"github.com/github/gh-aw/pkg/repoutil"
1516
"github.com/stretchr/testify/assert"
1617
"github.com/stretchr/testify/require"
1718
)
@@ -114,7 +115,7 @@ func TestNormalizeRepoForAPI(t *testing.T) {
114115

115116
for _, tt := range tests {
116117
t.Run(tt.name, func(t *testing.T) {
117-
ownerRepo, host := normalizeRepoForAPI(tt.repo)
118+
ownerRepo, host := repoutil.NormalizeRepoForAPI(tt.repo)
118119
assert.Equal(t, tt.wantOwnerRepo, ownerRepo, "owner/repo portion")
119120
assert.Equal(t, tt.wantHost, host, "host portion")
120121
})

pkg/cli/update_check.go

Lines changed: 4 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,14 @@ import (
44
"context"
55
"fmt"
66
"os"
7-
"path/filepath"
87
"strings"
98
"time"
109

11-
"github.com/github/gh-aw/pkg/constants"
1210
"golang.org/x/mod/semver"
1311

1412
"github.com/cli/go-gh/v2/pkg/api"
1513
"github.com/github/gh-aw/pkg/console"
14+
"github.com/github/gh-aw/pkg/constants"
1615
"github.com/github/gh-aw/pkg/logger"
1716
"github.com/github/gh-aw/pkg/workflow"
1817
)
@@ -64,31 +63,7 @@ func shouldCheckForUpdate(noCheckUpdate bool) bool {
6463

6564
// Check if we've already checked recently
6665
lastCheckFile := getLastCheckFilePath()
67-
if lastCheckFile == "" {
68-
updateCheckLog.Print("Could not determine last check file path")
69-
return false
70-
}
71-
72-
// Read last check time
73-
data, err := os.ReadFile(lastCheckFile)
74-
if err != nil {
75-
if !os.IsNotExist(err) {
76-
updateCheckLog.Printf("Error reading last check file: %v", err)
77-
}
78-
// File doesn't exist or error reading - perform check
79-
return true
80-
}
81-
82-
lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
83-
if err != nil {
84-
updateCheckLog.Printf("Error parsing last check time: %v", err)
85-
// Invalid timestamp - perform check
86-
return true
87-
}
88-
89-
// Check if enough time has passed
90-
if time.Since(lastCheck) < checkInterval {
91-
updateCheckLog.Printf("Last check was %v ago, skipping", time.Since(lastCheck))
66+
if !shouldRunUpdateCheckAtPath(lastCheckFile, checkInterval, "update check", updateCheckLog) {
9267
return false
9368
}
9469

@@ -115,38 +90,12 @@ func getLastCheckFilePath() string {
11590

11691
// getLastCheckFilePathImpl is the actual implementation
11792
func getLastCheckFilePathImpl() string {
118-
return getLastCheckFilePathFor(lastCheckFileName)
119-
}
120-
121-
func getLastCheckFilePathFor(fileName string) string {
122-
// Use OS temp directory for cross-platform compatibility
123-
tmpDir := os.TempDir()
124-
if tmpDir == "" {
125-
updateCheckLog.Print("Could not determine temp directory")
126-
return ""
127-
}
128-
129-
// Create a gh-aw subdirectory in temp
130-
ghAwTmpDir := filepath.Join(tmpDir, "gh-aw")
131-
if err := os.MkdirAll(ghAwTmpDir, constants.DirPermPublic); err != nil {
132-
updateCheckLog.Printf("Error creating gh-aw temp directory: %v", err)
133-
return ""
134-
}
135-
136-
return filepath.Join(ghAwTmpDir, fileName)
93+
return getUpdateCheckFilePathFor(lastCheckFileName, updateCheckLog)
13794
}
13895

13996
// updateLastCheckTime updates the timestamp of the last update check
14097
func updateLastCheckTime() {
141-
lastCheckFile := getLastCheckFilePath()
142-
if lastCheckFile == "" {
143-
return
144-
}
145-
146-
timestamp := time.Now().Format(time.RFC3339)
147-
if err := os.WriteFile(lastCheckFile, []byte(timestamp), constants.FilePermPublic); err != nil {
148-
updateCheckLog.Printf("Error writing last check time: %v", err)
149-
}
98+
writeUpdateCheckTime(getLastCheckFilePath(), constants.FilePermPublic, "update check", updateCheckLog)
15099
}
151100

152101
// checkForUpdates checks if a newer version of gh-aw is available

pkg/cli/update_check_state.go

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
package cli
2+
3+
import (
4+
"os"
5+
"path/filepath"
6+
"strings"
7+
"time"
8+
9+
"github.com/github/gh-aw/pkg/constants"
10+
"github.com/github/gh-aw/pkg/logger"
11+
)
12+
13+
func getUpdateCheckFilePathFor(fileName string, log *logger.Logger) string {
14+
tmpDir := os.TempDir()
15+
if tmpDir == "" {
16+
log.Print("Could not determine temp directory")
17+
return ""
18+
}
19+
20+
ghAwTmpDir := filepath.Join(tmpDir, "gh-aw")
21+
if err := os.MkdirAll(ghAwTmpDir, constants.DirPermPublic); err != nil {
22+
log.Printf("Error creating gh-aw temp directory: %v", err)
23+
return ""
24+
}
25+
26+
return filepath.Join(ghAwTmpDir, fileName)
27+
}
28+
29+
func shouldRunUpdateCheckAtPath(lastCheckFile string, interval time.Duration, label string, log *logger.Logger) bool {
30+
if lastCheckFile == "" {
31+
log.Printf("Could not determine %s file path", label)
32+
return false
33+
}
34+
35+
data, err := os.ReadFile(lastCheckFile)
36+
if err != nil {
37+
if !os.IsNotExist(err) {
38+
log.Printf("Error reading %s file: %v", label, err)
39+
}
40+
return true
41+
}
42+
43+
lastCheck, err := time.Parse(time.RFC3339, strings.TrimSpace(string(data)))
44+
if err != nil {
45+
log.Printf("Error parsing %s time: %v", label, err)
46+
return true
47+
}
48+
49+
elapsed := time.Since(lastCheck)
50+
if elapsed < interval {
51+
log.Printf("Last %s was %v ago, skipping", label, elapsed)
52+
return false
53+
}
54+
55+
return true
56+
}
57+
58+
func writeUpdateCheckTime(path string, perm os.FileMode, label string, log *logger.Logger) {
59+
if path == "" {
60+
return
61+
}
62+
63+
timestamp := time.Now().Format(time.RFC3339)
64+
if err := os.WriteFile(path, []byte(timestamp), perm); err != nil {
65+
log.Printf("Error writing %s time: %v", label, err)
66+
}
67+
}

pkg/repoutil/README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ The `repoutil` package provides utility functions for working with GitHub reposi
44

55
## Overview
66

7-
This package offers a single focused helper for parsing and validating `owner/repo` repository slug strings, which are used throughout the codebase wherever GitHub repositories are referenced.
7+
This package offers focused helpers for parsing and normalizing repository identifiers, which are used throughout the codebase wherever GitHub repositories are referenced.
88

99
## Public API
1010

@@ -13,6 +13,7 @@ This package offers a single focused helper for parsing and validating `owner/re
1313
| Function | Signature | Description |
1414
|----------|-----------|-------------|
1515
| `SplitRepoSlug` | `func(slug string) (owner, repo string, err error)` | Splits a repository slug of the form `owner/repo` into its two components; returns an error when the slug does not contain exactly one `/` or when either component is empty |
16+
| `NormalizeRepoForAPI` | `func(repo string) (ownerRepo string, host string)` | Splits a repository string of the form `[HOST/]owner/repo` into the `owner/repo` portion and an optional host name for GHES/Proxima API calls |
1617

1718
## Usage Examples
1819

pkg/repoutil/repoutil.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,14 @@ func SplitRepoSlug(slug string) (owner, repo string, err error) {
2222
repoutilLog.Printf("Split result: owner=%s, repo=%s", parts[0], parts[1])
2323
return parts[0], parts[1], nil
2424
}
25+
26+
// NormalizeRepoForAPI splits a repo string of the form "[HOST/]owner/repo" into
27+
// the owner/repo portion and an optional host. Most callers pass plain
28+
// "owner/repo", but GHES and Proxima installs may supply "HOST/owner/repo".
29+
func NormalizeRepoForAPI(repo string) (ownerRepo string, host string) {
30+
parts := strings.SplitN(repo, "/", 3)
31+
if len(parts) == 3 {
32+
return parts[1] + "/" + parts[2], parts[0]
33+
}
34+
return repo, ""
35+
}

pkg/repoutil/repoutil_test.go

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,3 +213,24 @@ func BenchmarkSplitRepoSlug_Invalid(b *testing.B) {
213213
_, _, _ = SplitRepoSlug(slug)
214214
}
215215
}
216+
217+
func TestNormalizeRepoForAPI(t *testing.T) {
218+
tests := []struct {
219+
name string
220+
repo string
221+
wantOwnerRepo string
222+
wantHost string
223+
}{
224+
{"plain owner/repo", "owner/repo", "owner/repo", ""},
225+
{"GHES HOST/owner/repo", "myhost.com/owner/repo", "owner/repo", "myhost.com"},
226+
{"github.com/owner/repo treated as host prefix", "github.com/owner/repo", "owner/repo", "github.com"},
227+
}
228+
229+
for _, tt := range tests {
230+
t.Run(tt.name, func(t *testing.T) {
231+
ownerRepo, host := NormalizeRepoForAPI(tt.repo)
232+
assert.Equal(t, tt.wantOwnerRepo, ownerRepo, "owner/repo portion")
233+
assert.Equal(t, tt.wantHost, host, "host portion")
234+
})
235+
}
236+
}

0 commit comments

Comments
 (0)