Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion cmd/kontext/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,11 @@ func loginCmd() *cobra.Command {
return fmt.Errorf("save session: %w", err)
}

fmt.Fprintf(os.Stderr, "Logged in as %s (%s)\n", result.Session.User.Name, result.Session.User.Email)
if display := result.Session.DisplayIdentity(); display != "" {
fmt.Fprintf(os.Stderr, "Logged in as %s\n", display)
} else {
fmt.Fprintln(os.Stderr, "Logged in.")
}
return nil
},
}
Expand Down
60 changes: 52 additions & 8 deletions internal/auth/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ var defaultLoginScopes = []string{
"offline_access",
}

var identityLoginScopes = []string{
"openid",
"email",
"profile",
}

// LoginResult is the output of a successful login flow.
type LoginResult struct {
Session *Session
Expand Down Expand Up @@ -162,23 +168,57 @@ func Login(ctx context.Context, issuerURL, clientID string, scopes ...string) (*

// Try to decode ID token for user claims
if rawIDToken, ok := token.Extra("id_token").(string); ok {
session.IDToken = rawIDToken
if claims, err := decodeJWTClaims(rawIDToken); err == nil {
session.User.Name = claims.Name
session.User.Email = claims.Email
if err := applyIDTokenClaims(session, rawIDToken); err != nil {
return nil, err
}
}
if session.Subject == "" {
return nil, fmt.Errorf("id token missing subject claim")
Comment thread
michiosw marked this conversation as resolved.
}

applyTokenExtraEmailFallback(session, token)

return &LoginResult{Session: session}, nil
}

func applyIDTokenClaims(session *Session, rawIDToken string) error {
session.IDToken = rawIDToken
claims, err := decodeJWTClaims(rawIDToken)
if err != nil {
return err
}

if claims.Subject == "" {
return fmt.Errorf("id token missing subject claim")
}
session.Subject = claims.Subject
session.User.Name = claims.Name
session.User.Email = claims.Email
return nil
}

func resolveLoginScopes(scopes []string) []string {
baseScopes := defaultLoginScopes
if len(scopes) > 0 {
return append([]string(nil), scopes...)
baseScopes = identityLoginScopes
}

resolved := append([]string(nil), baseScopes...)
for _, scope := range scopes {
if !hasScope(resolved, scope) {
resolved = append(resolved, scope)
}
}
return resolved
}

return append([]string(nil), defaultLoginScopes...)
func hasScope(scopes []string, scope string) bool {
for _, existing := range scopes {
if existing == scope {
return true
}
}
return false
}

func applyTokenExtraEmailFallback(session *Session, token *oauth2.Token) {
Expand Down Expand Up @@ -234,6 +274,9 @@ func Preflight(ctx context.Context) (*Session, error) {
if err != nil {
return nil, err
}
if _, err := session.IdentityKey(); err != nil {
return nil, err
}

if !session.IsExpired() {
return session, nil
Expand All @@ -254,8 +297,9 @@ func Preflight(ctx context.Context) (*Session, error) {
// --- helpers ---

type jwtClaims struct {
Name string `json:"name"`
Email string `json:"email"`
Subject string `json:"sub"`
Name string `json:"name"`
Email string `json:"email"`
}

// decodeJWTClaims decodes the payload of a JWT without verification.
Expand Down
111 changes: 106 additions & 5 deletions internal/auth/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/base64"
"encoding/json"
"reflect"
"strings"
"testing"

"golang.org/x/oauth2"
Expand All @@ -29,13 +30,19 @@ func TestResolveLoginScopesDefaults(t *testing.T) {
}
}

func TestResolveLoginScopesCustom(t *testing.T) {
func TestResolveLoginScopesAddsCustomScopes(t *testing.T) {
t.Parallel()

input := []string{"gateway:access"}
got := resolveLoginScopes(input)
if !reflect.DeepEqual(got, input) {
t.Fatalf("resolveLoginScopes(%#v) = %#v", input, got)
want := []string{
"openid",
"email",
"profile",
"gateway:access",
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("resolveLoginScopes(%#v) = %#v, want %#v", input, got, want)
}

got[0] = "mutated"
Expand All @@ -44,10 +51,27 @@ func TestResolveLoginScopesCustom(t *testing.T) {
}
}

func TestResolveLoginScopesDeduplicatesDefaultScopes(t *testing.T) {
t.Parallel()

input := []string{"openid", "gateway:access"}
got := resolveLoginScopes(input)
want := []string{
"openid",
"email",
"profile",
"gateway:access",
}
if !reflect.DeepEqual(got, want) {
t.Fatalf("resolveLoginScopes(%#v) = %#v, want %#v", input, got, want)
}
}

func TestDecodeJWTClaims(t *testing.T) {
t.Parallel()

payload, err := json.Marshal(map[string]string{
"sub": "user-123",
"name": "Ada Lovelace",
"email": "ada@example.com",
"role": "admin",
Expand All @@ -61,8 +85,8 @@ func TestDecodeJWTClaims(t *testing.T) {
if err != nil {
t.Fatalf("decodeJWTClaims() error = %v", err)
}
if got.Name != "Ada Lovelace" || got.Email != "ada@example.com" {
t.Fatalf("decodeJWTClaims() = %#v, want name and email extracted", got)
if got.Subject != "user-123" || got.Name != "Ada Lovelace" || got.Email != "ada@example.com" {
t.Fatalf("decodeJWTClaims() = %#v, want subject, name, and email extracted", got)
}
}

Expand All @@ -89,3 +113,80 @@ func TestApplyTokenExtraEmailFallbackKeepsIDTokenEmail(t *testing.T) {
t.Fatalf("session.User.Email = %q, want ID token email", session.User.Email)
}
}

func TestApplyIDTokenClaimsRequiresSubject(t *testing.T) {
t.Parallel()

session := &Session{}
err := applyIDTokenClaims(session, unsignedJWT(map[string]any{
"email": "dev@example.com",
}))
if err == nil {
t.Fatal("applyIDTokenClaims() error = nil, want missing subject error")
}
if !strings.Contains(err.Error(), "subject") {
t.Fatalf("applyIDTokenClaims() error = %q, want subject message", err)
}
}

func TestApplyIDTokenClaimsStoresSubjectAndDisplayClaims(t *testing.T) {
t.Parallel()

session := &Session{}
err := applyIDTokenClaims(session, unsignedJWT(map[string]any{
"sub": "user-123",
"name": "Dev User",
"email": "dev@example.com",
}))
if err != nil {
t.Fatalf("applyIDTokenClaims() error = %v", err)
}
if session.Subject != "user-123" {
t.Fatalf("session.Subject = %q, want user-123", session.Subject)
}
if got := session.DisplayIdentity(); got != "dev@example.com" {
t.Fatalf("DisplayIdentity() = %q, want email", got)
}
}

func TestSessionIdentityKeyUsesIssuerAndSubject(t *testing.T) {
t.Parallel()

session := &Session{IssuerURL: "https://api.kontext.security/"}
session.Subject = "user-123"

got, err := session.IdentityKey()
if err != nil {
t.Fatalf("IdentityKey() error = %v", err)
}
want := "https://api.kontext.security#user-123"
if got != want {
t.Fatalf("IdentityKey() = %q, want %q", got, want)
}
}

func TestSessionIdentityKeyRejectsLegacySession(t *testing.T) {
t.Parallel()

session := &Session{IssuerURL: "https://api.kontext.security"}
_, err := session.IdentityKey()
if err == nil {
t.Fatal("IdentityKey() error = nil, want missing identity error")
}
if !strings.Contains(err.Error(), "kontext login") {
t.Fatalf("IdentityKey() error = %q, want login hint", err)
}
}

func unsignedJWT(claims map[string]any) string {
header := map[string]any{"alg": "none", "typ": "JWT"}
return encodeJWTPart(header) + "." + encodeJWTPart(claims) + "."
}

func encodeJWTPart(v any) string {
data, err := json.Marshal(v)
if err != nil {
panic(err)
}
return base64.RawURLEncoding.EncodeToString(data)
}
20 changes: 20 additions & 0 deletions internal/auth/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package auth
import (
"encoding/json"
"fmt"
"strings"
"time"

"github.com/zalando/go-keyring"
Expand All @@ -18,6 +19,7 @@ const (
type Session struct {
User UserInfo `json:"user"`
IssuerURL string `json:"issuer_url"`
Subject string `json:"subject"`
AccessToken string `json:"access_token"`
IDToken string `json:"id_token"`
RefreshToken string `json:"refresh_token"`
Expand All @@ -35,6 +37,24 @@ func (s *Session) IsExpired() bool {
return time.Now().After(s.ExpiresAt.Add(-refreshBuffer))
}

// IdentityKey returns the stable identity used for backend session attribution.
func (s *Session) IdentityKey() (string, error) {
issuer := strings.TrimRight(strings.TrimSpace(s.IssuerURL), "/")
subject := strings.TrimSpace(s.Subject)
if issuer == "" || subject == "" {
return "", fmt.Errorf("stored session is missing identity information (run `kontext login`)")
}
return issuer + "#" + subject, nil
}

// DisplayIdentity returns the human-readable identity for terminal output.
func (s *Session) DisplayIdentity() string {
if s.User.Email != "" {
return s.User.Email
}
return s.User.Name
}

// LoadSession reads the stored session from the system keyring.
func LoadSession() (*Session, error) {
data, err := keyring.Get(keyringService, keyringUser)
Expand Down
15 changes: 8 additions & 7 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,14 +50,15 @@ func Start(ctx context.Context, opts Options) error {
if err != nil {
return err
}
identity := session.User.Email
if identity == "" {
identity = session.User.Name
identityKey, err := session.IdentityKey()
if err != nil {
return err
}
if identity == "" {
identity = "authenticated"
if display := session.DisplayIdentity(); display != "" {
fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", display)
} else {
fmt.Fprintln(os.Stderr, "✓ Authenticated")
}
fmt.Fprintf(os.Stderr, "✓ Authenticated as %s\n", identity)

// 2. Backend client — token source refreshes automatically on expiry
client := backend.NewClient(backend.BaseURL(), newSessionTokenSource(ctx, session))
Expand All @@ -66,7 +67,7 @@ func Start(ctx context.Context, opts Options) error {
hostname, _ := os.Hostname()
cwd, _ := os.Getwd()
createResp, err := client.CreateSession(ctx, &agentv1.CreateSessionRequest{
UserId: identity,
UserId: identityKey,
Agent: opts.Agent,
Hostname: hostname,
Cwd: cwd,
Expand Down
Loading