From 6010307d55960232d6d3b3d470ae5c53f1ee2c2d Mon Sep 17 00:00:00 2001 From: Michel Osswald Date: Thu, 16 Apr 2026 23:43:47 +0200 Subject: [PATCH] fix(auth): use stable session identity --- cmd/kontext/main.go | 6 +- internal/auth/oidc.go | 60 +++++++++++++++++--- internal/auth/oidc_test.go | 111 +++++++++++++++++++++++++++++++++++-- internal/auth/session.go | 20 +++++++ internal/run/run.go | 15 ++--- 5 files changed, 191 insertions(+), 21 deletions(-) diff --git a/cmd/kontext/main.go b/cmd/kontext/main.go index e71e641..80593c0 100644 --- a/cmd/kontext/main.go +++ b/cmd/kontext/main.go @@ -93,7 +93,11 @@ func loginCmd() *cobra.Command { return fmt.Errorf("save session: %w", err) } - fmt.Fprintf(os.Stderr, "Logged in as %s (%s)\n", result.Session.User.Name, result.Session.User.Email) + if display := result.Session.DisplayIdentity(); display != "" { + fmt.Fprintf(os.Stderr, "Logged in as %s\n", display) + } else { + fmt.Fprintln(os.Stderr, "Logged in.") + } return nil }, } diff --git a/internal/auth/oidc.go b/internal/auth/oidc.go index 999c607..2c6124b 100644 --- a/internal/auth/oidc.go +++ b/internal/auth/oidc.go @@ -32,6 +32,12 @@ var defaultLoginScopes = []string{ "offline_access", } +var identityLoginScopes = []string{ + "openid", + "email", + "profile", +} + // LoginResult is the output of a successful login flow. type LoginResult struct { Session *Session @@ -162,23 +168,57 @@ func Login(ctx context.Context, issuerURL, clientID string, scopes ...string) (* // Try to decode ID token for user claims if rawIDToken, ok := token.Extra("id_token").(string); ok { - session.IDToken = rawIDToken - if claims, err := decodeJWTClaims(rawIDToken); err == nil { - session.User.Name = claims.Name - session.User.Email = claims.Email + if err := applyIDTokenClaims(session, rawIDToken); err != nil { + return nil, err } } + if session.Subject == "" { + return nil, fmt.Errorf("id token missing subject claim") + } + applyTokenExtraEmailFallback(session, token) return &LoginResult{Session: session}, nil } +func applyIDTokenClaims(session *Session, rawIDToken string) error { + session.IDToken = rawIDToken + claims, err := decodeJWTClaims(rawIDToken) + if err != nil { + return err + } + + if claims.Subject == "" { + return fmt.Errorf("id token missing subject claim") + } + session.Subject = claims.Subject + session.User.Name = claims.Name + session.User.Email = claims.Email + return nil +} + func resolveLoginScopes(scopes []string) []string { + baseScopes := defaultLoginScopes if len(scopes) > 0 { - return append([]string(nil), scopes...) + baseScopes = identityLoginScopes + } + + resolved := append([]string(nil), baseScopes...) + for _, scope := range scopes { + if !hasScope(resolved, scope) { + resolved = append(resolved, scope) + } } + return resolved +} - return append([]string(nil), defaultLoginScopes...) +func hasScope(scopes []string, scope string) bool { + for _, existing := range scopes { + if existing == scope { + return true + } + } + return false } func applyTokenExtraEmailFallback(session *Session, token *oauth2.Token) { @@ -234,6 +274,9 @@ func Preflight(ctx context.Context) (*Session, error) { if err != nil { return nil, err } + if _, err := session.IdentityKey(); err != nil { + return nil, err + } if !session.IsExpired() { return session, nil @@ -254,8 +297,9 @@ func Preflight(ctx context.Context) (*Session, error) { // --- helpers --- type jwtClaims struct { - Name string `json:"name"` - Email string `json:"email"` + Subject string `json:"sub"` + Name string `json:"name"` + Email string `json:"email"` } // decodeJWTClaims decodes the payload of a JWT without verification. diff --git a/internal/auth/oidc_test.go b/internal/auth/oidc_test.go index 46c33ee..2e1eee7 100644 --- a/internal/auth/oidc_test.go +++ b/internal/auth/oidc_test.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "reflect" + "strings" "testing" "golang.org/x/oauth2" @@ -29,13 +30,19 @@ func TestResolveLoginScopesDefaults(t *testing.T) { } } -func TestResolveLoginScopesCustom(t *testing.T) { +func TestResolveLoginScopesAddsCustomScopes(t *testing.T) { t.Parallel() input := []string{"gateway:access"} got := resolveLoginScopes(input) - if !reflect.DeepEqual(got, input) { - t.Fatalf("resolveLoginScopes(%#v) = %#v", input, got) + want := []string{ + "openid", + "email", + "profile", + "gateway:access", + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("resolveLoginScopes(%#v) = %#v, want %#v", input, got, want) } got[0] = "mutated" @@ -44,10 +51,27 @@ func TestResolveLoginScopesCustom(t *testing.T) { } } +func TestResolveLoginScopesDeduplicatesDefaultScopes(t *testing.T) { + t.Parallel() + + input := []string{"openid", "gateway:access"} + got := resolveLoginScopes(input) + want := []string{ + "openid", + "email", + "profile", + "gateway:access", + } + if !reflect.DeepEqual(got, want) { + t.Fatalf("resolveLoginScopes(%#v) = %#v, want %#v", input, got, want) + } +} + func TestDecodeJWTClaims(t *testing.T) { t.Parallel() payload, err := json.Marshal(map[string]string{ + "sub": "user-123", "name": "Ada Lovelace", "email": "ada@example.com", "role": "admin", @@ -61,8 +85,8 @@ func TestDecodeJWTClaims(t *testing.T) { if err != nil { t.Fatalf("decodeJWTClaims() error = %v", err) } - if got.Name != "Ada Lovelace" || got.Email != "ada@example.com" { - t.Fatalf("decodeJWTClaims() = %#v, want name and email extracted", got) + if got.Subject != "user-123" || got.Name != "Ada Lovelace" || got.Email != "ada@example.com" { + t.Fatalf("decodeJWTClaims() = %#v, want subject, name, and email extracted", got) } } @@ -89,3 +113,80 @@ func TestApplyTokenExtraEmailFallbackKeepsIDTokenEmail(t *testing.T) { t.Fatalf("session.User.Email = %q, want ID token email", session.User.Email) } } + +func TestApplyIDTokenClaimsRequiresSubject(t *testing.T) { + t.Parallel() + + session := &Session{} + err := applyIDTokenClaims(session, unsignedJWT(map[string]any{ + "email": "dev@example.com", + })) + if err == nil { + t.Fatal("applyIDTokenClaims() error = nil, want missing subject error") + } + if !strings.Contains(err.Error(), "subject") { + t.Fatalf("applyIDTokenClaims() error = %q, want subject message", err) + } +} + +func TestApplyIDTokenClaimsStoresSubjectAndDisplayClaims(t *testing.T) { + t.Parallel() + + session := &Session{} + err := applyIDTokenClaims(session, unsignedJWT(map[string]any{ + "sub": "user-123", + "name": "Dev User", + "email": "dev@example.com", + })) + if err != nil { + t.Fatalf("applyIDTokenClaims() error = %v", err) + } + if session.Subject != "user-123" { + t.Fatalf("session.Subject = %q, want user-123", session.Subject) + } + if got := session.DisplayIdentity(); got != "dev@example.com" { + t.Fatalf("DisplayIdentity() = %q, want email", got) + } +} + +func TestSessionIdentityKeyUsesIssuerAndSubject(t *testing.T) { + t.Parallel() + + session := &Session{IssuerURL: "https://api.kontext.security/"} + session.Subject = "user-123" + + got, err := session.IdentityKey() + if err != nil { + t.Fatalf("IdentityKey() error = %v", err) + } + want := "https://api.kontext.security#user-123" + if got != want { + t.Fatalf("IdentityKey() = %q, want %q", got, want) + } +} + +func TestSessionIdentityKeyRejectsLegacySession(t *testing.T) { + t.Parallel() + + session := &Session{IssuerURL: "https://api.kontext.security"} + _, err := session.IdentityKey() + if err == nil { + t.Fatal("IdentityKey() error = nil, want missing identity error") + } + if !strings.Contains(err.Error(), "kontext login") { + t.Fatalf("IdentityKey() error = %q, want login hint", err) + } +} + +func unsignedJWT(claims map[string]any) string { + header := map[string]any{"alg": "none", "typ": "JWT"} + return encodeJWTPart(header) + "." + encodeJWTPart(claims) + "." +} + +func encodeJWTPart(v any) string { + data, err := json.Marshal(v) + if err != nil { + panic(err) + } + return base64.RawURLEncoding.EncodeToString(data) +} diff --git a/internal/auth/session.go b/internal/auth/session.go index a9a6731..8cd04b9 100644 --- a/internal/auth/session.go +++ b/internal/auth/session.go @@ -3,6 +3,7 @@ package auth import ( "encoding/json" "fmt" + "strings" "time" "github.com/zalando/go-keyring" @@ -18,6 +19,7 @@ const ( type Session struct { User UserInfo `json:"user"` IssuerURL string `json:"issuer_url"` + Subject string `json:"subject"` AccessToken string `json:"access_token"` IDToken string `json:"id_token"` RefreshToken string `json:"refresh_token"` @@ -35,6 +37,24 @@ func (s *Session) IsExpired() bool { return time.Now().After(s.ExpiresAt.Add(-refreshBuffer)) } +// IdentityKey returns the stable identity used for backend session attribution. +func (s *Session) IdentityKey() (string, error) { + issuer := strings.TrimRight(strings.TrimSpace(s.IssuerURL), "/") + subject := strings.TrimSpace(s.Subject) + if issuer == "" || subject == "" { + return "", fmt.Errorf("stored session is missing identity information (run `kontext login`)") + } + return issuer + "#" + subject, nil +} + +// DisplayIdentity returns the human-readable identity for terminal output. +func (s *Session) DisplayIdentity() string { + if s.User.Email != "" { + return s.User.Email + } + return s.User.Name +} + // LoadSession reads the stored session from the system keyring. func LoadSession() (*Session, error) { data, err := keyring.Get(keyringService, keyringUser) diff --git a/internal/run/run.go b/internal/run/run.go index 9dd0794..66206d3 100644 --- a/internal/run/run.go +++ b/internal/run/run.go @@ -50,14 +50,15 @@ func Start(ctx context.Context, opts Options) error { if err != nil { return err } - identity := session.User.Email - if identity == "" { - identity = session.User.Name + identityKey, err := session.IdentityKey() + if err != nil { + return err } - if identity == "" { - identity = "authenticated" + if display := session.DisplayIdentity(); display != "" { + fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", display) + } else { + fmt.Fprintln(os.Stderr, "✓ Authenticated") } - fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", identity) // 2. Backend client — token source refreshes automatically on expiry client := backend.NewClient(backend.BaseURL(), newSessionTokenSource(ctx, session)) @@ -66,7 +67,7 @@ func Start(ctx context.Context, opts Options) error { hostname, _ := os.Hostname() cwd, _ := os.Getwd() createResp, err := client.CreateSession(ctx, &agentv1.CreateSessionRequest{ - UserId: identity, + UserId: identityKey, Agent: opts.Agent, Hostname: hostname, Cwd: cwd,