Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,10 @@ func credentialFailureSummary(err error) string {
}

func connectFailureSummary(err error) string {
var mismatch *identityMismatchError
if errors.As(err, &mismatch) {
return mismatch.Error()
}
if needsGatewayAccessReauthentication(err) {
return "gateway access needs authorization"
}
Expand Down Expand Up @@ -560,6 +564,9 @@ func fetchConnectURLWithGatewayLoginFallback(
if err != nil {
return "", fmt.Errorf("authorize gateway access: %w", err)
}
if err := ensureSameIdentity(session, result.Session); err != nil {
return "", err
Comment thread
michiosw marked this conversation as resolved.
Comment thread
michiosw marked this conversation as resolved.
}

gatewayToken, err := exchangeGatewayToken(ctx, result.Session, credentialClientID)
if err != nil {
Expand All @@ -569,6 +576,43 @@ func fetchConnectURLWithGatewayLoginFallback(
return fetchConnectURLWithGatewayToken(ctx, result.Session.IssuerURL, gatewayToken)
}

func ensureSameIdentity(active, browser *auth.Session) error {
activeKey, err := active.IdentityKey()
if err != nil {
return err
}
browserKey, err := browser.IdentityKey()
if err != nil {
return fmt.Errorf("browser authorization session is missing identity information: %w", err)
}
if activeKey == browserKey {
return nil
}

activeLabel := active.DisplayIdentity()
if activeLabel == "" {
activeLabel = activeKey
}
browserLabel := browser.DisplayIdentity()
if browserLabel == "" {
browserLabel = browserKey
}
return &identityMismatchError{activeLabel: activeLabel, browserLabel: browserLabel}
}

type identityMismatchError struct {
activeLabel string
browserLabel string
}

func (e *identityMismatchError) Error() string {
return fmt.Sprintf(
"browser authorization used a different account (active CLI account: %s; browser account: %s). Run `kontext login` with the account you want to use, then retry",
e.activeLabel,
e.browserLabel,
)
}

func fetchConnectURL(ctx context.Context, session *auth.Session, clientID string) (string, error) {
gatewayToken, err := exchangeGatewayToken(ctx, session, clientID)
if err != nil {
Expand Down
74 changes: 74 additions & 0 deletions internal/run/run_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,7 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) {

session := &auth.Session{
IssuerURL: server.URL,
Subject: "user-123",
AccessToken: "stale-access-token",
}

Expand All @@ -430,6 +431,7 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) {

result := &auth.LoginResult{Session: &auth.Session{
IssuerURL: server.URL,
Subject: "user-123",
AccessToken: "gateway-login-token",
}}
return result, nil
Expand Down Expand Up @@ -460,6 +462,78 @@ func TestFetchConnectURLWithGatewayLoginFallback(t *testing.T) {
}
}

func TestFetchConnectURLWithGatewayLoginFallbackRejectsAccountMismatch(t *testing.T) {
t.Parallel()

var server *httptest.Server
var tokenExchangeCalls int
server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/.well-known/oauth-authorization-server":
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(fmt.Sprintf(`{"issuer":"%s","authorization_endpoint":"%s/oauth2/auth","token_endpoint":"%s/oauth2/token","jwks_uri":"%s/.well-known/jwks.json"}`, server.URL, server.URL, server.URL, server.URL)))
case "/oauth2/token":
tokenExchangeCalls++
w.Header().Set("Content-Type", "application/json")
_, _ = w.Write([]byte(`{"error":"invalid_scope","error_description":"Requested scope 'gateway:access' exceeds subject token scopes"}`))
default:
http.NotFound(w, r)
}
}))
defer server.Close()

session := &auth.Session{
IssuerURL: server.URL,
Subject: "active-user",
AccessToken: "stale-access-token",
}
session.User.Email = "active@example.com"

login := func(ctx context.Context, issuerURL, clientID string, scopes ...string) (*auth.LoginResult, error) {
result := &auth.LoginResult{Session: &auth.Session{
IssuerURL: server.URL,
Subject: "browser-user",
AccessToken: "gateway-login-token",
}}
result.Session.User.Email = "browser@example.com"
return result, nil
}

_, err := fetchConnectURLWithGatewayLoginFallback(
context.Background(),
session,
"app_agent-123",
login,
)
if err == nil {
t.Fatal("fetchConnectURLWithGatewayLoginFallback() error = nil, want mismatch")
}
if !strings.Contains(err.Error(), "different account") {
t.Fatalf("error = %q, want different account message", err)
}
if summary := connectFailureSummary(err); !strings.Contains(summary, "active@example.com") || !strings.Contains(summary, "browser@example.com") {
t.Fatalf("connectFailureSummary() = %q, want account labels", summary)
}
if tokenExchangeCalls != 1 {
t.Fatalf("tokenExchangeCalls = %d, want only stale-token attempt", tokenExchangeCalls)
}
}

func TestEnsureSameIdentityComparesIssuerAndSubject(t *testing.T) {
t.Parallel()

active := &auth.Session{IssuerURL: "https://issuer-a.example", Subject: "same-subject"}
browser := &auth.Session{IssuerURL: "https://issuer-b.example", Subject: "same-subject"}

err := ensureSameIdentity(active, browser)
if err == nil {
t.Fatal("ensureSameIdentity() error = nil, want issuer mismatch")
}
if !strings.Contains(err.Error(), "different account") {
t.Fatalf("ensureSameIdentity() error = %q, want different account", err)
}
}

func TestFetchConnectURLForConnectFlowSkipsLoginWhenNonInteractive(t *testing.T) {
t.Parallel()

Expand Down
Loading