diff --git a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py index 1c0b57a440ed..a056b729030c 100644 --- a/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py +++ b/python/packages/autogen-ext/src/autogen_ext/runtimes/grpc/_worker_runtime_host_servicer.py @@ -2,6 +2,7 @@ import asyncio import logging +import uuid from abc import ABC, abstractmethod from asyncio import Future, Task from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar @@ -106,7 +107,7 @@ def __init__(self) -> None: ] = {} self._agent_type_to_client_id_lock = asyncio.Lock() self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {} - self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {} + self._pending_responses: Dict[ClientConnectionId, Dict[str, Tuple[Future[Any], str]]] = {} self._background_tasks: Set[Task[Any]] = set() self._subscription_manager = SubscriptionManager() self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {} @@ -134,7 +135,7 @@ async def handle_callback(message: agent_worker_pb2.Message) -> None: # Clean up the client connection. del self._data_connections[client_id] # Cancel pending requests sent to this client. - for future in self._pending_responses.pop(client_id, {}).values(): + for future, _ in self._pending_responses.pop(client_id, {}).values(): future.cancel() # Remove the client id from the agent type to client id mapping. await self._on_client_disconnect(client_id) @@ -240,11 +241,21 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id if target_send_queue is None: logger.error(f"Client {target_client_id} not found, failed to deliver message.") return - await target_send_queue.send(agent_worker_pb2.Message(request=request)) + # Generate a host-level unique ID to avoid collisions when different senders + # both start their per-session counters from "1" and target the same runtime. + # The forwarded request carries the UUID; the original request_id is restored + # in the response before it is returned to the sender. + host_request_id = str(uuid.uuid4()) + original_request_id = request.request_id + + forwarded_request = agent_worker_pb2.RpcRequest() + forwarded_request.CopyFrom(request) + forwarded_request.request_id = host_request_id + await target_send_queue.send(agent_worker_pb2.Message(request=forwarded_request)) # Create a future to wait for the response from the target. - future = asyncio.get_event_loop().create_future() - self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future + future: Future[agent_worker_pb2.RpcResponse] = asyncio.get_event_loop().create_future() + self._pending_responses.setdefault(target_client_id, {})[host_request_id] = (future, original_request_id) # Create a task to wait for the response and send it back to the client. send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id)) @@ -265,8 +276,17 @@ async def _wait_and_send_response( async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None: # Setting the result of the future will send the response back to the original sender. - future = self._pending_responses[client_id].pop(response.request_id) - future.set_result(response) + try: + future, original_request_id = self._pending_responses[client_id].pop(response.request_id) + except KeyError: + logger.error(f"No pending response for client {client_id} with request_id {response.request_id}") + return + # Restore the original sender's request_id so the sender's pending_requests map + # can match the response to the request it originally sent. + restored_response = agent_worker_pb2.RpcResponse() + restored_response.CopyFrom(response) + restored_response.request_id = original_request_id + future.set_result(restored_response) async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None: topic_id = TopicId(type=event.type, source=event.source) diff --git a/python/packages/autogen-ext/tests/test_worker_runtime.py b/python/packages/autogen-ext/tests/test_worker_runtime.py index ec57f187e821..b1d02c868e2e 100644 --- a/python/packages/autogen-ext/tests/test_worker_runtime.py +++ b/python/packages/autogen-ext/tests/test_worker_runtime.py @@ -1,6 +1,7 @@ import asyncio import logging import os +from dataclasses import dataclass from typing import Any, List import pytest @@ -17,6 +18,7 @@ TypeSubscription, default_subscription, event, + message_handler, try_get_known_serializers_for_type, type_subscription, ) @@ -710,6 +712,78 @@ async def test_instance_factory_messaging() -> None: # await worker.stop() # await host.stop() +@pytest.mark.grpc +@pytest.mark.asyncio +async def test_cross_runtime_rpc_no_request_id_collision() -> None: + """Regression test for https://github.com/microsoft/autogen/issues/7016. + + When two different runtimes both start their request_id counter from "1" and + send RPC requests whose target agent lives on the same third runtime, the host + used to key _pending_responses by (target_client_id, request_id). Both + requests would share the same key, causing the second pop() to raise KeyError. + + The fix replaces each forwarded request_id with a host-generated UUID so keys + are globally unique inside the host. + """ + host_address = "localhost:50062" + + @dataclass + class PingMessage: + content: str + + class InnerAgent(RoutedAgent): + """Echoes every PingMessage it receives.""" + + def __init__(self) -> None: + super().__init__("Inner echo agent.") + + @message_handler + async def on_ping(self, message: PingMessage, ctx: MessageContext) -> PingMessage: + return PingMessage(content=f"inner:{message.content}") + + class RelayAgent(RoutedAgent): + """Forwards a PingMessage to InnerAgent and returns its response.""" + + def __init__(self) -> None: + super().__init__("Relay agent.") + + @message_handler + async def on_ping(self, message: PingMessage, ctx: MessageContext) -> PingMessage: + inner_response: PingMessage = await self.send_message( + PingMessage(content=message.content), + AgentId("inner_agent", "default"), + ) + return PingMessage(content=f"relay:{inner_response.content}") + + host = GrpcWorkerAgentRuntimeHost(address=host_address) + host.start() + + # runtime1 hosts both agents — intra-runtime forwarding produces a second + # RPC at the host level with request_id == "1", colliding with the one from + # runtime2 (which also starts its counter at "1"). + runtime1 = GrpcWorkerAgentRuntime(host_address=host_address) + runtime1.add_message_serializer(try_get_known_serializers_for_type(PingMessage)) + await runtime1.start() + await RelayAgent.register(runtime1, "relay_agent", lambda: RelayAgent()) + await InnerAgent.register(runtime1, "inner_agent", lambda: InnerAgent()) + + # runtime2 is the external sender — no agents, only used to initiate the RPC. + runtime2 = GrpcWorkerAgentRuntime(host_address=host_address) + runtime2.add_message_serializer(try_get_known_serializers_for_type(PingMessage)) + await runtime2.start() + + result: PingMessage = await runtime2.send_message( + PingMessage(content="hello"), + AgentId("relay_agent", "default"), + ) + + assert result == PingMessage(content="relay:inner:hello") + + await runtime2.stop() + await runtime1.stop() + await host.stop() + + if __name__ == "__main__": os.environ["GRPC_VERBOSITY"] = "DEBUG" os.environ["GRPC_TRACE"] = "all"