diff --git a/internal/backend/backend.go b/internal/backend/backend.go index 879430f..bacbcd9 100644 --- a/internal/backend/backend.go +++ b/internal/backend/backend.go @@ -17,7 +17,9 @@ import ( ) // TokenSource returns a valid access token, refreshing if necessary. -type TokenSource func() (string, error) +// If forceRefresh is true, the source must obtain a new token regardless of +// whether the cached one appears valid (used for retry-on-401). +type TokenSource func(forceRefresh bool) (string, error) // Client wraps the ConnectRPC AgentService client. type Client struct { @@ -39,7 +41,7 @@ func NewClient(baseURL string, ts TokenSource) *Client { // StaticToken returns a TokenSource that always returns the same token. // Useful for tests or short-lived commands. func StaticToken(token string) TokenSource { - return func() (string, error) { return token, nil } + return func(_ bool) (string, error) { return token, nil } } // BaseURL returns the API base URL from env or default. @@ -94,17 +96,38 @@ func (c *Client) IngestEvent(ctx context.Context, req *agentv1.ProcessHookEventR } // bearerTransport fetches a fresh token for every outgoing request. +// On 401, it forces a token refresh and retries once. type bearerTransport struct { tokenSource TokenSource base http.RoundTripper } func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) { - token, err := t.tokenSource() + token, err := t.tokenSource(false) if err != nil { return nil, fmt.Errorf("token refresh: %w", err) } + r := req.Clone(req.Context()) r.Header.Set("Authorization", "Bearer "+token) - return t.base.RoundTrip(r) + resp, err := t.base.RoundTrip(r) + if err != nil { + return nil, err + } + + // Retry once with a forced refresh on 401 — the cached token may be + // stale even though IsExpired() said it was fine (server-side revocation, + // clock skew, Hydra TTL mismatch, etc.). + if resp.StatusCode == http.StatusUnauthorized { + resp.Body.Close() + token, err = t.tokenSource(true) + if err != nil { + return nil, fmt.Errorf("token refresh (retry): %w", err) + } + r2 := req.Clone(req.Context()) + r2.Header.Set("Authorization", "Bearer "+token) + return t.base.RoundTrip(r2) + } + + return resp, nil } diff --git a/internal/run/run.go b/internal/run/run.go index 2e3edc1..1b7ed1f 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -902,13 +902,15 @@ func supportedAgents() []string { // newSessionTokenSource returns a TokenSource that transparently refreshes // the OIDC access token when it expires, so long-running sessions keep working. +// If forceRefresh is true, the token is refreshed unconditionally (used by +// the transport layer after receiving a 401 from the server). func newSessionTokenSource(ctx context.Context, session *auth.Session, diagnostics diagnostic.Logger) backend.TokenSource { mu := &sync.Mutex{} - return func() (string, error) { + return func(forceRefresh bool) (string, error) { mu.Lock() defer mu.Unlock() - if !session.IsExpired() { + if !forceRefresh && !session.IsExpired() { return session.AccessToken, nil } @@ -917,6 +919,8 @@ func newSessionTokenSource(ctx context.Context, session *auth.Session, diagnosti return "", fmt.Errorf("token expired and refresh failed: %w", err) } + fmt.Fprintf(os.Stderr, "✓ Token refreshed\n") + // Persist so other processes (and the next `kontext start`) see the new token if saveErr := auth.SaveSession(refreshed); saveErr != nil { diagnostics.Printf("persist refreshed session failed: %v\n", saveErr)