diff --git a/packages/controller/src/controller.ts b/packages/controller/src/controller.ts index b8930c180..f9b1727e0 100644 --- a/packages/controller/src/controller.ts +++ b/packages/controller/src/controller.ts @@ -789,6 +789,7 @@ export default class ControllerProvider extends BaseProvider { const iframe = new KeychainIFrame({ ...this.options, rpcUrl: this.rpcUrl(), + chainId: this.selectedChain, onClose: () => { this.keychain?.reset?.(); }, diff --git a/packages/controller/src/iframe/keychain.ts b/packages/controller/src/iframe/keychain.ts index ef27d906f..1733d8a71 100644 --- a/packages/controller/src/iframe/keychain.ts +++ b/packages/controller/src/iframe/keychain.ts @@ -6,6 +6,7 @@ import { IFrame, IFrameOptions } from "./base"; type KeychainIframeOptions = IFrameOptions & KeychainOptions & { version?: string; + chainId?: string; ref?: string; refGroup?: string; needsSessionCreation?: boolean; @@ -31,6 +32,7 @@ export class KeychainIFrame extends IFrame { preset, shouldOverridePresetPolicies, rpcUrl, + chainId, ref, refGroup, needsSessionCreation, @@ -77,6 +79,10 @@ export class KeychainIFrame extends IFrame { _url.searchParams.set("rpc_url", encodeURIComponent(rpcUrl)); } + if (chainId) { + _url.searchParams.set("chain_id", chainId); + } + if (ref) { _url.searchParams.set("ref", encodeURIComponent(ref)); } diff --git a/packages/controller/src/node/provider.ts b/packages/controller/src/node/provider.ts index 8f5964312..a6a410329 100644 --- a/packages/controller/src/node/provider.ts +++ b/packages/controller/src/node/provider.ts @@ -146,7 +146,7 @@ export default class SessionProvider extends BaseProvider { redirectUri, )}&redirect_query_name=startapp&policies=${encodeURIComponent( JSON.stringify(this._policies), - )}&rpc_url=${encodeURIComponent(this._rpcUrl)}`; + )}&rpc_url=${encodeURIComponent(this._rpcUrl)}&chain_id=${encodeURIComponent(this._chainId)}`; this._backend.openLink(url); diff --git a/packages/controller/src/session/provider.ts b/packages/controller/src/session/provider.ts index 31ebba76a..8915192eb 100644 --- a/packages/controller/src/session/provider.ts +++ b/packages/controller/src/session/provider.ts @@ -295,7 +295,8 @@ export default class SessionProvider extends BaseProvider { `/session?public_key=${this._publicKey}` + `&redirect_uri=${this._redirectUrl}` + `&redirect_query_name=startapp` + - `&rpc_url=${this._rpcUrl}`; + `&rpc_url=${this._rpcUrl}` + + `&chain_id=${this._chainId}`; if (this._preset) { url += `&preset=${encodeURIComponent(this._preset)}`; diff --git a/packages/controller/src/telegram/provider.ts b/packages/controller/src/telegram/provider.ts index e24832b13..7a64290ca 100644 --- a/packages/controller/src/telegram/provider.ts +++ b/packages/controller/src/telegram/provider.ts @@ -82,7 +82,7 @@ export default class TelegramProvider extends BaseProvider { this._tmaUrl }&redirect_query_name=startapp&policies=${JSON.stringify( this._policies, - )}&rpc_url=${this._rpcUrl}`; + )}&rpc_url=${this._rpcUrl}&chain_id=${this._chainId}`; localStorage.setItem("lastUsedConnector", this.id); openLink(url); diff --git a/packages/keychain/src/hooks/connection.test.ts b/packages/keychain/src/hooks/connection.test.ts index b79d5a166..151de76b3 100644 --- a/packages/keychain/src/hooks/connection.test.ts +++ b/packages/keychain/src/hooks/connection.test.ts @@ -396,44 +396,59 @@ describe("URL rpc_url priority over stored controller rpcUrl", () => { }); describe("Controller disconnect on chain mismatch", () => { - it("should disconnect controller when URL rpc_url differs from stored controller", () => { - const urlRpcUrl = "https://api.cartridge.gg/x/starknet/mainnet"; + it("should disconnect controller when chain IDs differ", () => { + // Controller is on sepolia (0x534e5f534550), URL requests mainnet + const controllerChainId = mockController.chainId(); + const urlChainId = "0x534e5f4d41494e"; // SN_MAIN - // Simulate the effect logic: URL rpc_url differs from controller's rpcUrl - const controllerRpcUrl = mockController.rpcUrl(); - expect(controllerRpcUrl).not.toBe(urlRpcUrl); + expect(controllerChainId).not.toBe(urlChainId); - // The effect should trigger disconnect when rpcUrls don't match + // The effect should trigger disconnect when chain IDs don't match const shouldDisconnect = - mockController && urlRpcUrl && controllerRpcUrl !== urlRpcUrl; + mockController && urlChainId && controllerChainId !== urlChainId; expect(shouldDisconnect).toBeTruthy(); }); - it("should NOT disconnect controller when URL rpc_url matches stored controller", () => { - const urlRpcUrl = "https://rpc.sepolia.example.com"; - const controllerRpcUrl = mockController.rpcUrl(); + it("should NOT disconnect controller when chain IDs match", () => { + // Same chain, potentially different RPC endpoints + const controllerChainId = mockController.chainId(); + const urlChainId = "0x534e5f534550"; // SN_SEPOLIA (matches mock) - // URLs match - no disconnect should happen - expect(controllerRpcUrl).toBe(urlRpcUrl); + expect(controllerChainId).toBe(urlChainId); - const shouldDisconnect = controllerRpcUrl !== urlRpcUrl; + const shouldDisconnect = controllerChainId !== urlChainId; expect(shouldDisconnect).toBe(false); }); - it("should NOT disconnect controller when no URL rpc_url is provided", () => { - const urlRpcUrl = null; + it("should NOT disconnect when chain IDs match despite different RPC URLs", () => { + // Both on sepolia but using different RPC endpoints + const controllerChainId = mockController.chainId(); // 0x534e5f534550 + const urlChainId = "0x534e5f534550"; // SN_SEPOLIA + + // RPC URLs differ, but chain IDs match — no disconnect + const controllerRpcUrl = mockController.rpcUrl(); // sepolia.example.com + const urlRpcUrl = "https://other-provider.com/sepolia"; + expect(controllerRpcUrl).not.toBe(urlRpcUrl); + expect(controllerChainId).toBe(urlChainId); - // No urlRpcUrl - the effect guard returns early - const shouldDisconnect = urlRpcUrl !== null; + const shouldDisconnect = controllerChainId !== urlChainId; expect(shouldDisconnect).toBe(false); }); + it("should NOT disconnect controller when chainId is undefined", () => { + const chainId = undefined; + + // No chainId - the effect guard returns early + const shouldDisconnect = mockController && chainId; + expect(shouldDisconnect).toBeFalsy(); + }); + it("should NOT disconnect controller when controller is not set", () => { const controller = undefined; - const urlRpcUrl = "https://api.cartridge.gg/x/starknet/mainnet"; + const chainId = "0x534e5f4d41494e"; // No controller - the effect guard returns early - const shouldDisconnect = controller && urlRpcUrl; + const shouldDisconnect = controller && chainId; expect(shouldDisconnect).toBeFalsy(); }); }); diff --git a/packages/keychain/src/hooks/connection.ts b/packages/keychain/src/hooks/connection.ts index c5a557a18..99487c61c 100644 --- a/packages/keychain/src/hooks/connection.ts +++ b/packages/keychain/src/hooks/connection.ts @@ -244,7 +244,12 @@ export function useConnectionValue() { ...defaultTheme, }); const [controller, setController] = useState(window.controller); - const [chainId, setChainId] = useState(); + const [chainId, setChainId] = useState(() => + typeof window !== "undefined" + ? (new URLSearchParams(window.location.search).get("chain_id") ?? + undefined) + : undefined, + ); const [controllerVersion, setControllerVersion] = useState(); const connectionStateRef = useRef({ origin, @@ -277,20 +282,23 @@ export function useConnectionValue() { } }, [controller, setRpcUrl, urlRpcUrl]); - // When URL provides an rpc_url that differs from the controller's, log the user - // out so they re-authenticate on the correct chain. The account may not be deployed - // on the target chain, so we can't simply recreate the controller. + // When the requested chain differs from the controller's, log the user out so + // they re-authenticate on the correct chain. Compares chain IDs (semantic + // equality) rather than RPC URLs (string equality) to avoid false mismatches + // from URL normalization differences. useEffect(() => { - if (!controller || !urlRpcUrl) return; - if (controller.rpcUrl() === urlRpcUrl) return; + if (!controller || !chainId) return; + if (controller.chainId() === chainId) return; - setRpcUrl(urlRpcUrl); + if (urlRpcUrl) { + setRpcUrl(urlRpcUrl); + } (async () => { await controller.disconnect(); setController(undefined); })(); - }, [controller, urlRpcUrl, setController, setRpcUrl]); + }, [controller, chainId, urlRpcUrl, setController, setRpcUrl]); const urlParamsRef = useRef<{ theme: string | null;