Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions internal/backend/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ import (
)

// TokenSource returns a valid access token, refreshing if necessary.
type TokenSource func() (string, error)
// If forceRefresh is true, the source must obtain a new token regardless of
// whether the cached one appears valid (used for retry-on-401).
type TokenSource func(forceRefresh bool) (string, error)

// Client wraps the ConnectRPC AgentService client.
type Client struct {
Expand All @@ -39,7 +41,7 @@ func NewClient(baseURL string, ts TokenSource) *Client {
// StaticToken returns a TokenSource that always returns the same token.
// Useful for tests or short-lived commands.
func StaticToken(token string) TokenSource {
return func() (string, error) { return token, nil }
return func(_ bool) (string, error) { return token, nil }
}

// BaseURL returns the API base URL from env or default.
Expand Down Expand Up @@ -85,17 +87,38 @@ func (c *Client) IngestEvent(ctx context.Context, req *agentv1.ProcessHookEventR
}

// bearerTransport fetches a fresh token for every outgoing request.
// On 401, it forces a token refresh and retries once.
type bearerTransport struct {
tokenSource TokenSource
base http.RoundTripper
}

func (t *bearerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
token, err := t.tokenSource()
token, err := t.tokenSource(false)
if err != nil {
return nil, fmt.Errorf("token refresh: %w", err)
}

r := req.Clone(req.Context())
r.Header.Set("Authorization", "Bearer "+token)
return t.base.RoundTrip(r)
resp, err := t.base.RoundTrip(r)
if err != nil {
return nil, err
}

// Retry once with a forced refresh on 401 — the cached token may be
// stale even though IsExpired() said it was fine (server-side revocation,
// clock skew, Hydra TTL mismatch, etc.).
if resp.StatusCode == http.StatusUnauthorized {
resp.Body.Close()
token, err = t.tokenSource(true)
if err != nil {
return nil, fmt.Errorf("token refresh (retry): %w", err)
}
r2 := req.Clone(req.Context())
r2.Header.Set("Authorization", "Bearer "+token)
return t.base.RoundTrip(r2)
}

return resp, nil
}
18 changes: 9 additions & 9 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ import (
"syscall"
"time"

"github.com/cli/browser"

agentv1 "github.com/kontext-dev/kontext-cli/gen/kontext/agent/v1"
"github.com/kontext-dev/kontext-cli/internal/auth"
"github.com/kontext-dev/kontext-cli/internal/backend"
Expand Down Expand Up @@ -173,11 +171,9 @@ func resolveCredentials(ctx context.Context, session *auth.Session, entries []cr
value, err := exchangeCredential(ctx, session, entry)
if err != nil {
if isNotConnectedError(err) {
fmt.Fprintln(os.Stderr, "not connected")
fmt.Fprintf(os.Stderr, " Opening browser to connect %s...\n", entry.Provider)
connectURL := fmt.Sprintf("%s/connect/%s", auth.DefaultIssuerURL, entry.Provider)
_ = browser.OpenURL(connectURL)
fmt.Fprint(os.Stderr, " Press Enter after connecting...")
fmt.Fprintln(os.Stderr, "needs authorization")
fmt.Fprintf(os.Stderr, " → Connect %s via an MCP client (e.g. Claude Desktop) or the hosted connect flow.\n", entry.Provider)
fmt.Fprintf(os.Stderr, " → Then press Enter to retry, or press Enter now to skip.\n")
bufio.NewReader(os.Stdin).ReadString('\n')
value, err = exchangeCredential(ctx, session, entry)
}
Expand Down Expand Up @@ -263,13 +259,15 @@ func buildEnv(resolved []credential.Resolved) []string {

// newSessionTokenSource returns a TokenSource that transparently refreshes
// the OIDC access token when it expires, so long-running sessions keep working.
// If forceRefresh is true, the token is refreshed unconditionally (used by
// the transport layer after receiving a 401 from the server).
func newSessionTokenSource(ctx context.Context, session *auth.Session) backend.TokenSource {
mu := &sync.Mutex{}
return func() (string, error) {
return func(forceRefresh bool) (string, error) {
mu.Lock()
defer mu.Unlock()

if !session.IsExpired() {
if !forceRefresh && !session.IsExpired() {
return session.AccessToken, nil
}

Expand All @@ -278,6 +276,8 @@ func newSessionTokenSource(ctx context.Context, session *auth.Session) backend.T
return "", fmt.Errorf("token expired and refresh failed: %w", err)
}

fmt.Fprintf(os.Stderr, "✓ Token refreshed\n")

// Persist so other processes (and the next `kontext start`) see the new token
if saveErr := auth.SaveSession(refreshed); saveErr != nil {
fmt.Fprintf(os.Stderr, "⚠ Could not persist refreshed session: %v\n", saveErr)
Expand Down
Loading