diff --git a/internal/pkg/cli/command/target/target.go b/internal/pkg/cli/command/target/target.go index 6f91402..bdb0204 100644 --- a/internal/pkg/cli/command/target/target.go +++ b/internal/pkg/cli/command/target/target.go @@ -199,8 +199,11 @@ func NewTargetCmd() *cobra.Command { // If the org chosen differs from the current orgId in the token, we need to login again if currentTokenOrgId != "" && currentTokenOrgId != targetOrg.Id { + // Fetch SSO connection while the current token is still valid, + // before logout clears it. + ssoConn := login.ResolveSSOConnection(ctx, targetOrg.Id) oauth.Logout() - err = login.GetAndSetAccessToken(ctx, &targetOrg.Id, login.Options{Json: options.json, Wait: true}) + err = login.GetAndSetAccessToken(ctx, &targetOrg.Id, login.Options{Json: options.json, Wait: true, SSOConnection: ssoConn}) if err != nil { msg.FailJSON(options.json, "Failed to get access token: %s", err) exit.Error(err, "Error getting access token") @@ -245,8 +248,11 @@ func NewTargetCmd() *cobra.Command { // If the org chosen differs from the current orgId in the token, we need to login again if currentTokenOrgId != org.Id { + // Fetch SSO connection while the current token is still valid, + // before logout clears it. + ssoConn := login.ResolveSSOConnection(ctx, org.Id) oauth.Logout() - err = login.GetAndSetAccessToken(ctx, &org.Id, login.Options{Json: options.json, Wait: true}) + err = login.GetAndSetAccessToken(ctx, &org.Id, login.Options{Json: options.json, Wait: true, SSOConnection: ssoConn}) if err != nil { msg.FailJSON(options.json, "Failed to get access token: %s", err) exit.Error(err, "Error getting access token") diff --git a/internal/pkg/utils/login/login.go b/internal/pkg/utils/login/login.go index 4237428..2869d4e 100644 --- a/internal/pkg/utils/login/login.go +++ b/internal/pkg/utils/login/login.go @@ -46,6 +46,11 @@ type Options struct { // RunPostAuthSetup is not called in Wait mode; the caller is responsible // for any post-auth state setup and output. Wait bool + // SSOConnection is the Auth0 connection name to pass as `connection=` in the + // authorization URL, routing the browser directly to the org's IdP. + // Callers that hold a valid token before clearing credentials (e.g. pc target) + // should resolve this with FetchSSOConnection before logout, then pass it here. + SSOConnection *string } func Run(ctx context.Context, opts Options) { @@ -62,7 +67,7 @@ func Run(ctx context.Context, opts Options) { exit.Error(err, "Error checking for existing auth session") } if sess != nil { - if err := getAndSetAccessTokenJSON(ctx, nil, false, sess, result); err != nil { + if err := getAndSetAccessTokenJSON(ctx, nil, false, opts.SSOConnection, sess, result); err != nil { msg.FailMsg("Error acquiring access token while logging in: %s", err) exit.Error(err, "Error acquiring access token while logging in") } @@ -79,27 +84,46 @@ func Run(ctx context.Context, opts Options) { exit.Error(err, "Error retrieving oauth token") } + var ssoOrgId *string if !expired && token != nil && token.AccessToken != "" { - if opts.Json { - claims, err := oauth.ParseClaimsUnverified(token) - if err == nil { - fmt.Fprintln(os.Stdout, text.IndentJSON(struct { - Status string `json:"status"` - Email string `json:"email"` - OrgId string `json:"org_id"` - }{Status: "already_authenticated", Email: claims.Email, OrgId: claims.OrgId})) + // Check whether SSO is now enforced for the current org. + needsReauth := false + if claims, claimsErr := oauth.ParseClaimsUnverified(token); claimsErr != nil { + // Can't parse the existing token — log it, but there is no orgId to + // look up SSO against, so fall through and treat as already logged in. + log.Debug().Err(claimsErr).Msg("Run: could not parse existing token claims; skipping SSO enforcement check") + } else if conn := ResolveSSOConnection(ctx, claims.OrgId); conn != nil { + orgId := claims.OrgId + ssoOrgId = &orgId + opts.SSOConnection = conn + oauth.Logout() + needsReauth = true + } + + if !needsReauth { + // Genuinely already authenticated, no SSO enforcement — show "already logged in". + if opts.Json { + claims, err := oauth.ParseClaimsUnverified(token) + if err == nil { + fmt.Fprintln(os.Stdout, text.IndentJSON(struct { + Status string `json:"status"` + Email string `json:"email"` + OrgId string `json:"org_id"` + }{Status: "already_authenticated", Email: claims.Email, OrgId: claims.OrgId})) + } else { + fmt.Fprintln(os.Stdout, text.IndentJSON(struct { + Status string `json:"status"` + }{Status: "already_authenticated"})) + } } else { - fmt.Fprintln(os.Stdout, text.IndentJSON(struct { - Status string `json:"status"` - }{Status: "already_authenticated"})) + msg.WarnMsg("You are already logged in. Please log out first using %s.", style.Code("pc auth logout")) } - } else { - msg.WarnMsg("You are already logged in. Please log out first using %s.", style.Code("pc auth logout")) + return } - return + // Fall through to GetAndSetAccessToken. } - err = GetAndSetAccessToken(ctx, nil, opts) + err = GetAndSetAccessToken(ctx, ssoOrgId, opts) if err != nil { msg.FailMsg("Error acquiring access token while logging in: %s", err) exit.Error(err, "Error acquiring access token while logging in") @@ -188,9 +212,9 @@ func GetAndSetAccessToken(ctx context.Context, orgId *string, opts Options) erro // a terminal (agentic context), always use the JSON/daemon path. opts.Json = opts.Json || !term.IsTerminal(int(os.Stdout.Fd())) if opts.Json { - return getAndSetAccessTokenJSON(ctx, orgId, opts.Wait, nil, nil) + return getAndSetAccessTokenJSON(ctx, orgId, opts.Wait, opts.SSOConnection, nil, nil) } - return getAndSetAccessTokenInteractive(ctx, orgId) + return getAndSetAccessTokenInteractive(ctx, orgId, opts.SSOConnection) } // getAndSetAccessTokenJSON is the agentic path: daemon-backed, non-blocking on stdin. @@ -203,7 +227,7 @@ func GetAndSetAccessToken(ctx context.Context, orgId *string, opts Options) erro // When wait is true (for callers like pc target that need a token on return): spawns // daemon, blocks until auth completes, and returns with the token stored. RunPostAuthSetup // is not called; the caller owns post-auth state and output. -func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, sess *SessionState, result *SessionResult) error { +func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ssoConnection *string, sess *SessionState, result *SessionResult) error { if sess == nil { // No pre-fetched session — look one up now. var err error @@ -214,15 +238,21 @@ func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ses } if sess != nil { // If the caller is requesting a specific org that doesn't match the pending - // session's org, the existing session cannot be used. - if orgId != nil && sess.OrgId != nil && *orgId != *sess.OrgId { + // session's org, the existing session cannot be used. A session with no + // recorded org (sess.OrgId == nil) is also a mismatch when orgId is set, + // since it was started without an org constraint and may yield the wrong org. + if orgId != nil && (sess.OrgId == nil || *orgId != *sess.OrgId) { if result != nil { // Daemon has finished and released the port — clean up and start fresh. CleanupSession(sess.SessionId) // Fall through to start a new flow. } else { // Daemon is still running and holds the callback port. - return fmt.Errorf("an auth session for a different organization (%s) is already in progress; wait for it to expire or complete it first", *sess.OrgId) + sessOrg := "unknown" + if sess.OrgId != nil { + sessOrg = *sess.OrgId + } + return fmt.Errorf("an auth session for a different organization (%s) is already in progress; wait for it to expire or complete it first", sessOrg) } } else { return resumeSession(sess, result, wait) @@ -238,18 +268,19 @@ func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ses return fmt.Errorf("error creating new auth verifier and challenge: %w", err) } - authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId) + authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId, ssoConnection) if err != nil { return fmt.Errorf("error getting auth URL: %w", err) } sessionId := newSessionId() newSess := &SessionState{ - SessionId: sessionId, - CSRFState: csrfState, - AuthURL: authURL, - OrgId: orgId, - CreatedAt: time.Now(), + SessionId: sessionId, + CSRFState: csrfState, + AuthURL: authURL, + OrgId: orgId, + SSOConnection: ssoConnection, + CreatedAt: time.Now(), } if err := writeSessionState(*newSess); err != nil { return fmt.Errorf("error writing session state: %w", err) @@ -266,7 +297,7 @@ func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ses // Print the auth URL to stderr only: stdout must stay clean so the caller // can emit a single JSON document once this function returns. fmt.Fprintf(os.Stderr, "Visit the following URL to authenticate:\n\n %s\n\n", authURL) - return pollForResult(sessionId, newSess.CreatedAt, true) + return pollForResult(sessionId, newSess.CreatedAt, true, ssoConnection) } // Agentic login (first call): print pending and return immediately. The daemon @@ -276,6 +307,39 @@ func getAndSetAccessTokenJSON(ctx context.Context, orgId *string, wait bool, ses return nil } +// finishAuthWithSSO is called when a login session completes successfully. +// If ssoConnection is non-nil, the session was already an SSO round — emit +// authenticated JSON directly. Otherwise check whether SSO is enforced for +// the authenticated org; if so, log out and start a second SSO login round +// (emitting a new pending JSON for agents to follow). +// +// sessionId is the just-completed session. When starting an SSO round it is +// cleaned up eagerly before calling getAndSetAccessTokenJSON, so that +// findResumableSession does not pick up the stale session and resume it +// instead of starting the new SSO flow. The deferred CleanupSession in the +// caller is a no-op once the files are already removed. +func finishAuthWithSSO(ctx context.Context, sessionId string, ssoConnection *string) error { + if ssoConnection == nil { + // Round 1 — check whether SSO enforcement is needed. + token, _ := oauth.Token(ctx) + if token != nil && token.AccessToken != "" { + if claims, err := oauth.ParseClaimsUnverified(token); err == nil { + if conn := ResolveSSOConnection(ctx, claims.OrgId); conn != nil { + // Clean up the completed session before starting the SSO round + // so findResumableSession won't find and re-resume it. + CleanupSession(sessionId) + oauth.Logout() + return getAndSetAccessTokenJSON(ctx, &claims.OrgId, false, conn, nil, nil) + } + } + } + } + // Already in SSO round, or SSO not enforced — emit authenticated JSON. + setupCtx, cancel := context.WithTimeout(ctx, 30*time.Second) + defer cancel() + return RunPostAuthSetup(setupCtx) +} + // resumeSession handles a session that was already started (e.g. after a process restart). // If the daemon already finished, it handles the result immediately. Otherwise it polls. // When wait is true, RunPostAuthSetup is skipped; the caller owns post-auth state and output. @@ -290,14 +354,12 @@ func resumeSession(sess *SessionState, result *SessionResult, wait bool) error { if wait { return nil } - setupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - return RunPostAuthSetup(setupCtx) + return finishAuthWithSSO(context.Background(), sess.SessionId, sess.SSOConnection) } // Still pending — poll until the daemon completes. // Don't re-emit pending here: this call will block until done, keeping stdout // to a single JSON value per invocation. - return pollForResult(sess.SessionId, sess.CreatedAt, wait) + return pollForResult(sess.SessionId, sess.CreatedAt, wait, sess.SSOConnection) } // pollForResult polls the daemon's result file until auth completes or the session expires. @@ -311,7 +373,7 @@ func resumeSession(sess *SessionState, result *SessionResult, wait bool) error { // // The polling loop runs on context.Background() so that the root command's --timeout // flag (default 60s) does not interrupt a user still authenticating in the browser. -func pollForResult(sessionId string, createdAt time.Time, wait bool) error { +func pollForResult(sessionId string, createdAt time.Time, wait bool, ssoConnection *string) error { ticker := time.NewTicker(time.Second) defer ticker.Stop() remaining := time.Until(createdAt.Add(sessionMaxAge)) @@ -345,11 +407,7 @@ func pollForResult(sessionId string, createdAt time.Time, wait bool) error { // Caller handles post-auth state and output. return nil } - // Use a fresh context for the post-auth API calls: the original ctx may - // have expired if the user took longer than --timeout to authenticate. - setupCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - return RunPostAuthSetup(setupCtx) + return finishAuthWithSSO(context.Background(), sessionId, ssoConnection) } } } @@ -370,7 +428,7 @@ func printPendingJSON(authURL, sessionId string) { // getAndSetAccessTokenInteractive is the original interactive path: inline callback server, // optional [Enter]-to-open-browser prompt when stdin is a TTY. -func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { +func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string, ssoConnection *string) error { // If a daemon from a prior JSON-mode login exists, check whether it has // already finished before deciding whether to block interactive login. sess, result, err := findResumableSession() @@ -398,13 +456,16 @@ func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { return fmt.Errorf("error creating new auth verifier and challenge: %w", err) } - authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId) + authURL, err := a.GetAuthURL(ctx, csrfState, challenge, orgId, ssoConnection) if err != nil { return fmt.Errorf("error getting auth URL: %w", err) } codeCh := make(chan string, 1) - serverCtx, cancel := context.WithTimeout(ctx, sessionMaxAge) + // Use context.Background() so the root command's --timeout flag (default 60s) + // does not cut off a user who is still authenticating in the browser. Each + // interactive round gets a fresh sessionMaxAge window, matching the daemon path. + serverCtx, cancel := context.WithTimeout(context.Background(), sessionMaxAge) defer cancel() go func() { @@ -451,7 +512,14 @@ func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { return errors.New("error authenticating CLI and retrieving oauth2 access token") } - token, err := a.ExchangeAuthCode(ctx, verifier, code) + // Use a fresh context for post-callback network operations. The original ctx + // may have already exceeded the root command's 60s timeout if the user took + // a while to authenticate — the server accepted the callback successfully, so + // we must not let a stale deadline fail the code exchange or SSO lookup. + apiCtx, apiCancel := context.WithTimeout(context.Background(), 30*time.Second) + defer apiCancel() + + token, err := a.ExchangeAuthCode(apiCtx, verifier, code) if err != nil { return fmt.Errorf("error exchanging auth code for access token: %w", err) } @@ -477,6 +545,20 @@ func getAndSetAccessTokenInteractive(ctx context.Context, orgId *string) error { }) } + // Round 1 — check whether SSO enforcement is needed. + // ssoConnection being non-nil means we're already in the SSO round; skip. + if ssoConnection == nil { + if conn := ResolveSSOConnection(apiCtx, claims.OrgId); conn != nil { + fmt.Fprintf(os.Stderr, "\nSSO is required for your organization. Re-authenticating with your identity provider...\n") + oauth.Logout() + // Cancel the outer serverCtx before starting the SSO round so the + // stdin-watching goroutine above exits via ctx.Done() and cannot race + // with the new round's goroutine for the next Enter keypress. + cancel() + return getAndSetAccessTokenInteractive(ctx, &claims.OrgId, conn) + } + } + return nil } @@ -706,15 +788,33 @@ func EnsureAuthenticated(ctx context.Context) error { } // Daemon finished. - defer CleanupSession(sess.SessionId) if result.Status == "error" { + defer CleanupSession(sess.SessionId) return fmt.Errorf("authentication failed: %s. Run %s to try again.", result.Error, style.Code("pc login")) } - // Reload credentials written by the daemon process into this process's cache, - // then set the target org/project context so the calling command can proceed - // without a separate `pc login` or `pc target` call. + // Reload credentials so we can check SSO enforcement before finalising. _ = secrets.SecretsViper.ReadInConfig() + + // If this was a round-1 session (no SSO connection recorded), check whether + // the org requires SSO before handing off to the calling command. + // When SSO is required we do NOT clean up the session or clear the token: + // the next `pc login -j` call will find the still-alive session, resume it + // via finishAuthWithSSO, detect SSO, and emit a new pending URL for the SSO + // round — so the user only has to complete one more browser step, not two. + if sess.SSOConnection == nil { + token, _ := oauth.Token(ctx) + if token != nil && token.AccessToken != "" { + if claims, claimsErr := oauth.ParseClaimsUnverified(token); claimsErr == nil { + if ResolveSSOConnection(ctx, claims.OrgId) != nil { + return fmt.Errorf("SSO authentication is required for this organization. Run %s to complete authentication.", style.Code("pc login")) + } + } + } + } + + // No SSO required (or already in SSO round) — finalise lazy completion. + defer CleanupSession(sess.SessionId) if _, err := applyAuthContext(ctx); err != nil { // Non-fatal: credentials are valid, context setup is best-effort. log.Debug().Err(err).Msg("EnsureAuthenticated: applyAuthContext failed after lazy credential reload") diff --git a/internal/pkg/utils/login/session.go b/internal/pkg/utils/login/session.go index cf2d2f5..a048686 100644 --- a/internal/pkg/utils/login/session.go +++ b/internal/pkg/utils/login/session.go @@ -17,7 +17,11 @@ type SessionState struct { CSRFState string `json:"csrf_state"` AuthURL string `json:"auth_url"` OrgId *string `json:"org_id,omitempty"` - CreatedAt time.Time `json:"created_at"` + // SSOConnection is set on the second-round SSO session. A non-nil value + // means this session was started specifically for SSO enforcement, so the + // completion handler should skip the SSO check and emit "authenticated". + SSOConnection *string `json:"sso_connection,omitempty"` + CreatedAt time.Time `json:"created_at"` } type SessionResult struct { diff --git a/internal/pkg/utils/login/sso.go b/internal/pkg/utils/login/sso.go new file mode 100644 index 0000000..80ce16c --- /dev/null +++ b/internal/pkg/utils/login/sso.go @@ -0,0 +1,105 @@ +package login + +import ( + "context" + "encoding/json" + "io" + "net/http" + + "github.com/pinecone-io/cli/internal/pkg/utils/configuration/config" + "github.com/pinecone-io/cli/internal/pkg/utils/environment" + "github.com/pinecone-io/cli/internal/pkg/utils/log" + "github.com/pinecone-io/cli/internal/pkg/utils/oauth" +) + +// dashboardOrg is the subset of the dashboard API org response needed for SSO lookup. +type dashboardOrg struct { + Id string `json:"id"` + SSOConnectionName string `json:"sso_connection_name"` + EnforceSSO bool `json:"enforce_sso_authentication"` +} + +type dashboardOrgsResponse struct { + NewOrgs []dashboardOrg `json:"newOrgs"` +} + +// ResolveSSOConnection is a convenience wrapper around FetchSSOConnection that +// returns a pointer to the connection name when SSO is enforced for the org, or +// nil otherwise. Errors are logged at debug level and treated as "no SSO". +func ResolveSSOConnection(ctx context.Context, orgId string) *string { + conn, err := FetchSSOConnection(ctx, orgId) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO connection lookup failed, proceeding without connection param") + } + if conn == "" { + return nil + } + return &conn +} + +// FetchSSOConnection calls the private dashboard API to retrieve the Auth0 +// connection name for the given orgId. It returns ("", nil) when the org has +// no SSO configured, enforce_sso_authentication is false, or any error occurs. +// Errors are non-fatal: the caller should proceed with a normal login URL. +func FetchSSOConnection(ctx context.Context, orgId string) (string, error) { + token, err := oauth.Token(ctx) + if err != nil || token == nil || token.AccessToken == "" { + log.Debug().Str("orgId", orgId).Msg("SSO lookup skipped: no valid token available") + return "", nil + } + + envConfig, err := environment.GetEnvConfig(config.Environment.Get()) + if err != nil { + return "", nil + } + + return fetchSSOConnectionFromURL(ctx, orgId, token.AccessToken, http.DefaultClient, envConfig.DashboardUrl) +} + +// fetchSSOConnectionFromURL is the testable core: it takes an explicit HTTP +// client and dashboard base URL so tests can inject a local httptest.Server. +func fetchSSOConnectionFromURL(ctx context.Context, orgId string, accessToken string, client *http.Client, dashboardURL string) (string, error) { + url := dashboardURL + "/v2/dashboard/organizations" + req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return "", nil + } + req.Header.Set("Authorization", "Bearer "+accessToken) + + resp, err := client.Do(req) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: dashboard API request failed") + return "", nil + } + defer resp.Body.Close() + + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + log.Debug().Int("status", resp.StatusCode).Str("orgId", orgId).Msg("SSO lookup: dashboard API returned non-2xx") + return "", nil + } + + body, err := io.ReadAll(io.LimitReader(resp.Body, 1<<20)) + if err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: failed to read dashboard API response") + return "", nil + } + + var orgsResp dashboardOrgsResponse + if err := json.Unmarshal(body, &orgsResp); err != nil { + log.Debug().Err(err).Str("orgId", orgId).Msg("SSO lookup: failed to decode dashboard API response") + return "", nil + } + + for _, org := range orgsResp.NewOrgs { + if org.Id == orgId { + if org.EnforceSSO && org.SSOConnectionName != "" { + log.Debug().Str("orgId", orgId).Str("connection", org.SSOConnectionName).Msg("SSO lookup: found connection") + return org.SSOConnectionName, nil + } + return "", nil + } + } + + log.Debug().Str("orgId", orgId).Msg("SSO lookup: org not found in dashboard response") + return "", nil +} diff --git a/internal/pkg/utils/login/sso_test.go b/internal/pkg/utils/login/sso_test.go new file mode 100644 index 0000000..77dbe89 --- /dev/null +++ b/internal/pkg/utils/login/sso_test.go @@ -0,0 +1,115 @@ +package login + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" +) + +// newDashboardServer starts an httptest.Server that returns the given org list. +// Pass a non-zero statusCode to simulate an error response. +func newDashboardServer(t *testing.T, orgs []dashboardOrg, statusCode int) *httptest.Server { + t.Helper() + if statusCode == 0 { + statusCode = http.StatusOK + } + return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if statusCode != http.StatusOK { + http.Error(w, "error", statusCode) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(dashboardOrgsResponse{NewOrgs: orgs}) + })) +} + +func TestFetchSSOConnection_EnforcedWithConnection(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "alby-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "alby-saml" { + t.Errorf("expected %q, got %q", "alby-saml", conn) + } +} + +func TestFetchSSOConnection_NotEnforced(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "alby-saml", EnforceSSO: false}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when SSO not enforced, got %q", conn) + } +} + +func TestFetchSSOConnection_OrgNotFound(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-other", SSOConnectionName: "other-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when org not found, got %q", conn) + } +} + +func TestFetchSSOConnection_NonOKStatus(t *testing.T) { + server := newDashboardServer(t, nil, http.StatusUnauthorized) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection on non-2xx response, got %q", conn) + } +} + +func TestFetchSSOConnection_EmptyConnectionName(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-1", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "" { + t.Errorf("expected empty connection when name is empty, got %q", conn) + } +} + +func TestFetchSSOConnection_MultipleOrgs(t *testing.T) { + server := newDashboardServer(t, []dashboardOrg{ + {Id: "org-1", SSOConnectionName: "org1-saml", EnforceSSO: true}, + {Id: "org-2", SSOConnectionName: "org2-saml", EnforceSSO: true}, + }, 0) + defer server.Close() + + conn, err := fetchSSOConnectionFromURL(context.Background(), "org-2", "fake-token", server.Client(), server.URL) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if conn != "org2-saml" { + t.Errorf("expected %q, got %q", "org2-saml", conn) + } +} diff --git a/internal/pkg/utils/oauth/auth.go b/internal/pkg/utils/oauth/auth.go index 0bb0ebb..fea55a6 100644 --- a/internal/pkg/utils/oauth/auth.go +++ b/internal/pkg/utils/oauth/auth.go @@ -15,7 +15,7 @@ const ( SourceTag = "pinecone_cli" ) -func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge string, orgId *string) (string, error) { +func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge string, orgId *string, ssoConnection *string) (string, error) { conf, err := newOauth2Config() if err != nil { return "", err @@ -34,6 +34,9 @@ func (a *Auth) GetAuthURL(ctx context.Context, csrfState string, codeChallenge s if orgId != nil && *orgId != "" { opts = append(opts, oauth2.SetAuthURLParam("orgId", *orgId)) } + if ssoConnection != nil && *ssoConnection != "" { + opts = append(opts, oauth2.SetAuthURLParam("connection", *ssoConnection)) + } return conf.AuthCodeURL(csrfState, opts...), nil } diff --git a/internal/pkg/utils/oauth/auth_test.go b/internal/pkg/utils/oauth/auth_test.go index e275cff..d5a40a8 100644 --- a/internal/pkg/utils/oauth/auth_test.go +++ b/internal/pkg/utils/oauth/auth_test.go @@ -21,7 +21,7 @@ func TestGetAuthURL_ContainsSourceTag(t *testing.T) { t.Fatalf("failed to create verifier/challenge: %v", err) } - rawURL, err := a.GetAuthURL(ctx, "test-csrf-state", challenge, nil) + rawURL, err := a.GetAuthURL(ctx, "test-csrf-state", challenge, nil, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -48,7 +48,7 @@ func TestGetAuthURL_RequiredParams(t *testing.T) { } csrfState := "test-state-123" - rawURL, err := a.GetAuthURL(ctx, csrfState, challenge, nil) + rawURL, err := a.GetAuthURL(ctx, csrfState, challenge, nil, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -84,7 +84,7 @@ func TestGetAuthURL_WithOrgId(t *testing.T) { } orgId := "test-org-456" - rawURL, err := a.GetAuthURL(ctx, "state", challenge, &orgId) + rawURL, err := a.GetAuthURL(ctx, "state", challenge, &orgId, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -109,7 +109,7 @@ func TestGetAuthURL_WithEmptyOrgId(t *testing.T) { } emptyOrgId := "" - rawURL, err := a.GetAuthURL(ctx, "state", challenge, &emptyOrgId) + rawURL, err := a.GetAuthURL(ctx, "state", challenge, &emptyOrgId, nil) if err != nil { t.Fatalf("GetAuthURL returned error: %v", err) } @@ -123,3 +123,77 @@ func TestGetAuthURL_WithEmptyOrgId(t *testing.T) { t.Errorf("expected orgId to be absent for empty string, got %q", got) } } + +func TestGetAuthURL_WithSSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + connection := "alby-saml" + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, &connection) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != connection { + t.Errorf("expected connection=%q, got %q", connection, got) + } +} + +func TestGetAuthURL_WithNilSSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, nil) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != "" { + t.Errorf("expected connection param to be absent, got %q", got) + } +} + +func TestGetAuthURL_WithEmptySSOConnection(t *testing.T) { + a := &Auth{} + ctx := context.Background() + + _, challenge, err := a.CreateNewVerifierAndChallenge() + if err != nil { + t.Fatalf("failed to create verifier/challenge: %v", err) + } + + emptyConnection := "" + rawURL, err := a.GetAuthURL(ctx, "state", challenge, nil, &emptyConnection) + if err != nil { + t.Fatalf("GetAuthURL returned error: %v", err) + } + + parsed, err := url.Parse(rawURL) + if err != nil { + t.Fatalf("failed to parse auth URL: %v", err) + } + + if got := parsed.Query().Get("connection"); got != "" { + t.Errorf("expected connection param to be absent for empty string, got %q", got) + } +}