diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 25be36f..bf89e52 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -17,7 +17,9 @@ import ( ) // TokenSource returns a valid access token, refreshing if necessary. -type TokenSource func() (string, error) +// If forceRefresh is true, the source must obtain a new token regardless of +// whether the cached one appears valid (used for retry-on-401). +type TokenSource func(forceRefresh bool) (string, error) // Client wraps the ConnectRPC AgentService client. type Client struct { @@ -39,7 +41,7 @@ func NewClient(baseURL string, ts TokenSource) *Client { // StaticToken returns a TokenSource that always returns the same token. // Useful for tests or short-lived commands. func StaticToken(token string) TokenSource { - return func() (string, error) { return token, nil } + return func(_ bool) (string, error) { return token, nil } } // BaseURL returns the API base URL from env or default. @@ -85,17 +87,38 @@ func (c *Client) IngestEvent(ctx context.Context, req *agentv1.ProcessHookEventR } // bearerTransport fetches a fresh token for every outgoing request. +// On 401, it forces a token refresh and retries once. type bearerTransport struct { tokenSource TokenSource base http.RoundTripper } func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - token, err := t.tokenSource() + token, err := t.tokenSource(false) if err != nil { return nil, fmt.Errorf("token refresh: %w", err) } + r := req.Clone(req.Context()) r.Header.Set("Authorization", "Bearer "+token) - return t.base.RoundTrip(r) + resp, err := t.base.RoundTrip(r) + if err != nil { + return nil, err + } + + // Retry once with a forced refresh on 401 — the cached token may be + // stale even though IsExpired() said it was fine (server-side revocation, + // clock skew, Hydra TTL mismatch, etc.). + if resp.StatusCode == http.StatusUnauthorized { + resp.Body.Close() + token, err = t.tokenSource(true) + if err != nil { + return nil, fmt.Errorf("token refresh (retry): %w", err) + } + r2 := req.Clone(req.Context()) + r2.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(r2) + } + + return resp, nil } diff --git a/internal/run/run.go b/internal/run/run.go index 4705ff9..5b5e37a 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -18,8 +18,6 @@ import ( "syscall" "time" - "github.com/cli/browser" - agentv1 "github.com/kontext-dev/kontext-cli/gen/kontext/agent/v1" "github.com/kontext-dev/kontext-cli/internal/auth" "github.com/kontext-dev/kontext-cli/internal/backend" @@ -173,11 +171,9 @@ func resolveCredentials(ctx context.Context, session *auth.Session, entries []cr value, err := exchangeCredential(ctx, session, entry) if err != nil { if isNotConnectedError(err) { - fmt.Fprintln(os.Stderr, "not connected") - fmt.Fprintf(os.Stderr, " Opening browser to connect %s...\n", entry.Provider) - connectURL := fmt.Sprintf("%s/connect/%s", auth.DefaultIssuerURL, entry.Provider) - _ = browser.OpenURL(connectURL) - fmt.Fprint(os.Stderr, " Press Enter after connecting...") + fmt.Fprintln(os.Stderr, "needs authorization") + fmt.Fprintf(os.Stderr, " → Connect %s via an MCP client (e.g. Claude Desktop) or the hosted connect flow.\n", entry.Provider) + fmt.Fprintf(os.Stderr, " → Then press Enter to retry, or press Enter now to skip.\n") bufio.NewReader(os.Stdin).ReadString('\n') value, err = exchangeCredential(ctx, session, entry) } @@ -263,13 +259,15 @@ func buildEnv(resolved []credential.Resolved) []string { // newSessionTokenSource returns a TokenSource that transparently refreshes // the OIDC access token when it expires, so long-running sessions keep working. +// If forceRefresh is true, the token is refreshed unconditionally (used by +// the transport layer after receiving a 401 from the server). func newSessionTokenSource(ctx context.Context, session *auth.Session) backend.TokenSource { mu := &sync.Mutex{} - return func() (string, error) { + return func(forceRefresh bool) (string, error) { mu.Lock() defer mu.Unlock() - if !session.IsExpired() { + if !forceRefresh && !session.IsExpired() { return session.AccessToken, nil } @@ -278,6 +276,8 @@ func newSessionTokenSource(ctx context.Context, session *auth.Session) backend.T return "", fmt.Errorf("token expired and refresh failed: %w", err) } + fmt.Fprintf(os.Stderr, "✓ Token refreshed\n") + // Persist so other processes (and the next `kontext start`) see the new token if saveErr := auth.SaveSession(refreshed); saveErr != nil { fmt.Fprintf(os.Stderr, "⚠ Could not persist refreshed session: %v\n", saveErr)