diff --git a/internal/run/run.go b/internal/run/run.go index 75d33f8..a160532 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -220,7 +220,12 @@ func fetchConnectURLWithGatewayLoginFallback( return "", fmt.Errorf("authorize gateway access: %w", err) } - return fetchConnectURLWithGatewayToken(ctx, session.IssuerURL, result.Session.AccessToken) + gatewayToken, err := exchangeGatewayToken(ctx, result.Session, credentialClientID) + if err != nil { + return "", fmt.Errorf("exchange gateway token after authorize: %w", err) + } + + return fetchConnectURLWithGatewayToken(ctx, result.Session.IssuerURL, gatewayToken) } func fetchConnectURL(ctx context.Context, session *auth.Session, clientID string) (string, error) { diff --git a/internal/run/run_test.go b/internal/run/run_test.go index 9152be3..5a08f25 100644 --- a/internal/run/run_test.go +++ b/internal/run/run_test.go @@ -225,6 +225,7 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) { t.Parallel() var server *httptest.Server + var tokenExchangeCalls int server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch r.URL.Path { case "/.well-known/oauth-authorization-server": @@ -246,13 +247,26 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) { return } w.Header().Set("Content-Type", "application/json") - if got := r.Header.Get("Authorization"); got != "Bearer stale-access-token" { - http.Error(w, "wrong stale authorization", http.StatusUnauthorized) + tokenExchangeCalls++ + switch tokenExchangeCalls { + case 1: + if got := r.Header.Get("Authorization"); got != "Bearer stale-access-token" { + http.Error(w, "wrong stale authorization", http.StatusUnauthorized) + return + } + _, _ = w.Write([]byte(`{"error":"invalid_scope","error_description":"Requested scope 'gateway:access' exceeds subject token scopes"}`)) + case 2: + if got := r.Header.Get("Authorization"); got != "Bearer gateway-login-token" { + http.Error(w, "wrong gateway-login authorization", http.StatusUnauthorized) + return + } + _, _ = w.Write([]byte(`{"access_token":"gateway-exchange-token"}`)) + default: + http.Error(w, "unexpected token exchange call", http.StatusBadRequest) return } - _, _ = w.Write([]byte(`{"error":"invalid_scope","error_description":"Requested scope 'gateway:access' exceeds subject token scopes"}`)) case "/mcp/connect-session": - if got := r.Header.Get("Authorization"); got != "Bearer gateway-login-token" { + if got := r.Header.Get("Authorization"); got != "Bearer gateway-exchange-token" { http.Error(w, "wrong gateway authorization", http.StatusUnauthorized) return } @@ -301,6 +315,9 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) { if loginCalls != 1 { t.Fatalf("loginCalls = %d, want 1", loginCalls) } + if tokenExchangeCalls != 2 { + t.Fatalf("tokenExchangeCalls = %d, want 2", tokenExchangeCalls) + } if session.AccessToken != "stale-access-token" { t.Fatalf("session.AccessToken = %q, want stale token unchanged", session.AccessToken) }