Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import uuid
from abc import ABC, abstractmethod
from asyncio import Future, Task
from typing import Any, AsyncIterator, Awaitable, Callable, Dict, Generic, Sequence, Set, Tuple, TypeVar
Expand Down Expand Up @@ -106,7 +107,7 @@ def __init__(self) -> None:
] = {}
self._agent_type_to_client_id_lock = asyncio.Lock()
self._agent_type_to_client_id: Dict[str, ClientConnectionId] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Future[Any]]] = {}
self._pending_responses: Dict[ClientConnectionId, Dict[str, Tuple[Future[Any], str]]] = {}
self._background_tasks: Set[Task[Any]] = set()
self._subscription_manager = SubscriptionManager()
self._client_id_to_subscription_id_mapping: Dict[ClientConnectionId, set[str]] = {}
Expand Down Expand Up @@ -134,7 +135,7 @@ async def handle_callback(message: agent_worker_pb2.Message) -> None:
# Clean up the client connection.
del self._data_connections[client_id]
# Cancel pending requests sent to this client.
for future in self._pending_responses.pop(client_id, {}).values():
for future, _ in self._pending_responses.pop(client_id, {}).values():
future.cancel()
# Remove the client id from the agent type to client id mapping.
await self._on_client_disconnect(client_id)
Expand Down Expand Up @@ -240,11 +241,21 @@ async def _process_request(self, request: agent_worker_pb2.RpcRequest, client_id
if target_send_queue is None:
logger.error(f"Client {target_client_id} not found, failed to deliver message.")
return
await target_send_queue.send(agent_worker_pb2.Message(request=request))
# Generate a host-level unique ID to avoid collisions when different senders
# both start their per-session counters from "1" and target the same runtime.
# The forwarded request carries the UUID; the original request_id is restored
# in the response before it is returned to the sender.
host_request_id = str(uuid.uuid4())
original_request_id = request.request_id

forwarded_request = agent_worker_pb2.RpcRequest()
forwarded_request.CopyFrom(request)
forwarded_request.request_id = host_request_id
await target_send_queue.send(agent_worker_pb2.Message(request=forwarded_request))

# Create a future to wait for the response from the target.
future = asyncio.get_event_loop().create_future()
self._pending_responses.setdefault(target_client_id, {})[request.request_id] = future
future: Future[agent_worker_pb2.RpcResponse] = asyncio.get_event_loop().create_future()
self._pending_responses.setdefault(target_client_id, {})[host_request_id] = (future, original_request_id)

# Create a task to wait for the response and send it back to the client.
send_response_task = asyncio.create_task(self._wait_and_send_response(future, client_id))
Expand All @@ -265,8 +276,17 @@ async def _wait_and_send_response(

async def _process_response(self, response: agent_worker_pb2.RpcResponse, client_id: ClientConnectionId) -> None:
# Setting the result of the future will send the response back to the original sender.
future = self._pending_responses[client_id].pop(response.request_id)
future.set_result(response)
try:
future, original_request_id = self._pending_responses[client_id].pop(response.request_id)
except KeyError:
logger.error(f"No pending response for client {client_id} with request_id {response.request_id}")
return
# Restore the original sender's request_id so the sender's pending_requests map
# can match the response to the request it originally sent.
restored_response = agent_worker_pb2.RpcResponse()
restored_response.CopyFrom(response)
restored_response.request_id = original_request_id
future.set_result(restored_response)

async def _process_event(self, event: cloudevent_pb2.CloudEvent) -> None:
topic_id = TopicId(type=event.type, source=event.source)
Expand Down
74 changes: 74 additions & 0 deletions python/packages/autogen-ext/tests/test_worker_runtime.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import logging
import os
from dataclasses import dataclass
from typing import Any, List

import pytest
Expand All @@ -17,6 +18,7 @@
TypeSubscription,
default_subscription,
event,
message_handler,
try_get_known_serializers_for_type,
type_subscription,
)
Expand Down Expand Up @@ -710,6 +712,78 @@ async def test_instance_factory_messaging() -> None:
# await worker.stop()
# await host.stop()

@pytest.mark.grpc
@pytest.mark.asyncio
async def test_cross_runtime_rpc_no_request_id_collision() -> None:
"""Regression test for https://github.com/microsoft/autogen/issues/7016.

When two different runtimes both start their request_id counter from "1" and
send RPC requests whose target agent lives on the same third runtime, the host
used to key _pending_responses by (target_client_id, request_id). Both
requests would share the same key, causing the second pop() to raise KeyError.

The fix replaces each forwarded request_id with a host-generated UUID so keys
are globally unique inside the host.
"""
host_address = "localhost:50062"

@dataclass
class PingMessage:
content: str

class InnerAgent(RoutedAgent):
"""Echoes every PingMessage it receives."""

def __init__(self) -> None:
super().__init__("Inner echo agent.")

@message_handler
async def on_ping(self, message: PingMessage, ctx: MessageContext) -> PingMessage:
return PingMessage(content=f"inner:{message.content}")

class RelayAgent(RoutedAgent):
"""Forwards a PingMessage to InnerAgent and returns its response."""

def __init__(self) -> None:
super().__init__("Relay agent.")

@message_handler
async def on_ping(self, message: PingMessage, ctx: MessageContext) -> PingMessage:
inner_response: PingMessage = await self.send_message(
PingMessage(content=message.content),
AgentId("inner_agent", "default"),
)
return PingMessage(content=f"relay:{inner_response.content}")

host = GrpcWorkerAgentRuntimeHost(address=host_address)
host.start()

# runtime1 hosts both agents — intra-runtime forwarding produces a second
# RPC at the host level with request_id == "1", colliding with the one from
# runtime2 (which also starts its counter at "1").
runtime1 = GrpcWorkerAgentRuntime(host_address=host_address)
runtime1.add_message_serializer(try_get_known_serializers_for_type(PingMessage))
await runtime1.start()
await RelayAgent.register(runtime1, "relay_agent", lambda: RelayAgent())
await InnerAgent.register(runtime1, "inner_agent", lambda: InnerAgent())

# runtime2 is the external sender — no agents, only used to initiate the RPC.
runtime2 = GrpcWorkerAgentRuntime(host_address=host_address)
runtime2.add_message_serializer(try_get_known_serializers_for_type(PingMessage))
await runtime2.start()

result: PingMessage = await runtime2.send_message(
PingMessage(content="hello"),
AgentId("relay_agent", "default"),
)

assert result == PingMessage(content="relay:inner:hello")

await runtime2.stop()
await runtime1.stop()
await host.stop()


if __name__ == "__main__":
os.environ["GRPC_VERBOSITY"] = "DEBUG"
os.environ["GRPC_TRACE"] = "all"
Expand Down