Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 65 additions & 20 deletions internal/run/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func Start(ctx context.Context, opts Options) error {
}

sessionID := createResp.SessionId
credentialClientID := resolveCredentialClientID(createResp.AgentId, opts.ClientID)
fmt.Fprintf(os.Stderr, "✓ Session: %s (%s)\n", createResp.SessionName, truncateID(sessionID))

// 4. Resolve credentials (before sidecar starts — no background goroutines yet,
Expand All @@ -91,7 +92,7 @@ func Start(ctx context.Context, opts Options) error {
return fmt.Errorf("parse template: %w", err)
}
if len(entries) > 0 {
resolved, err = resolveCredentials(ctx, session, entries, opts.ClientID)
resolved, err = resolveCredentials(ctx, session, entries, credentialClientID)
if err != nil {
return err
}
Expand Down Expand Up @@ -175,7 +176,7 @@ func resolveCredentials(ctx context.Context, session *auth.Session, entries []cr
if err != nil {
if isNotConnectedError(err) {
fmt.Fprintln(os.Stderr, "not connected")
connectURL, connectErr := fetchConnectURL(ctx, session)
connectURL, connectErr := fetchConnectURL(ctx, session, clientID)
if connectErr != nil {
err = fmt.Errorf("create connect session: %w", connectErr)
} else {
Expand All @@ -200,15 +201,20 @@ func resolveCredentials(ctx context.Context, session *auth.Session, entries []cr
return resolved, nil
}

func fetchConnectURL(ctx context.Context, session *auth.Session) (string, error) {
func fetchConnectURL(ctx context.Context, session *auth.Session, clientID string) (string, error) {
gatewayToken, err := exchangeGatewayToken(ctx, session, clientID)
if err != nil {
return "", err
}

connectSessionURL := strings.TrimRight(session.IssuerURL, "/") + "/mcp/connect-session"

req, err := http.NewRequestWithContext(ctx, "POST", connectSessionURL, strings.NewReader("{}"))
if err != nil {
return "", fmt.Errorf("build request: %w", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("Authorization", "Bearer "+session.AccessToken)
req.Header.Set("Authorization", "Bearer "+gatewayToken)

resp, err := http.DefaultClient.Do(req)
if err != nil {
Expand Down Expand Up @@ -242,47 +248,86 @@ func fetchConnectURL(ctx context.Context, session *auth.Session) (string, error)
return result.ConnectURL, nil
}

// exchangeCredential calls POST /oauth2/token with RFC 8693 token exchange
// to resolve a provider credential. The user's access token serves as both
// the subject_token and the Bearer auth — no client secret needed.
func exchangeCredential(ctx context.Context, session *auth.Session, entry credential.Entry, clientID string) (string, error) {
func agentOAuthClientID(agentID string) string {
return "app_" + agentID
}

func resolveCredentialClientID(agentID, fallback string) string {
if strings.TrimSpace(agentID) == "" {
return fallback
}
return agentOAuthClientID(agentID)
}

type tokenExchangeResponse struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ProviderKind string `json:"provider_kind"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
}

func exchangeToken(ctx context.Context, session *auth.Session, clientID, resource string, scopes ...string) (*tokenExchangeResponse, error) {
meta, err := auth.DiscoverEndpoints(ctx, session.IssuerURL)
if err != nil {
return "", fmt.Errorf("oauth discovery: %w", err)
return nil, fmt.Errorf("oauth discovery: %w", err)
}

form := url.Values{
"grant_type": {"urn:ietf:params:oauth:grant-type:token-exchange"},
"client_id": {clientID},
"subject_token": {session.AccessToken},
"subject_token_type": {"urn:ietf:params:oauth:token-type:access_token"},
"resource": {entry.Target()},
"resource": {resource},
}
if len(scopes) > 0 {
form.Set("scope", strings.Join(scopes, " "))
}

req, err := http.NewRequestWithContext(ctx, "POST", meta.TokenEndpoint, strings.NewReader(form.Encode()))
if err != nil {
return "", fmt.Errorf("build request: %w", err)
return nil, fmt.Errorf("build request: %w", err)
}
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
req.Header.Set("Authorization", "Bearer "+session.AccessToken)

resp, err := http.DefaultClient.Do(req)
if err != nil {
return "", fmt.Errorf("token exchange request: %w", err)
return nil, fmt.Errorf("token exchange request: %w", err)
}
defer resp.Body.Close()

var result struct {
AccessToken string `json:"access_token"`
TokenType string `json:"token_type"`
ProviderKind string `json:"provider_kind"`
Error string `json:"error"`
ErrorDesc string `json:"error_description"`
}
var result tokenExchangeResponse
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
return "", fmt.Errorf("decode token exchange response: %w", err)
return nil, fmt.Errorf("decode token exchange response: %w", err)
}

return &result, nil
}

func exchangeGatewayToken(ctx context.Context, session *auth.Session, clientID string) (string, error) {
result, err := exchangeToken(ctx, session, clientID, "mcp-gateway", "gateway:access")
if err != nil {
return "", err
}
if result.Error != "" {
return "", fmt.Errorf("gateway token exchange failed: %s: %s", result.Error, result.ErrorDesc)
}
if result.AccessToken == "" {
return "", fmt.Errorf("gateway token exchange returned empty access_token")
}

return result.AccessToken, nil
}

// exchangeCredential calls POST /oauth2/token with RFC 8693 token exchange
// to resolve a provider credential. The user's access token serves as both
// the subject_token and the Bearer auth — no client secret needed.
func exchangeCredential(ctx context.Context, session *auth.Session, entry credential.Entry, clientID string) (string, error) {
result, err := exchangeToken(ctx, session, clientID, entry.Target())
if err != nil {
return "", err
}
if result.Error != "" {
if result.Error == "invalid_target" && strings.Contains(result.ErrorDesc, "not allowed") {
return "", fmt.Errorf("provider not connected: %s", entry.Provider)
Expand Down
Loading
Loading