Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import uuid
from abc import ABC, abstractmethod
from collections.abc import Sequence as ABCSequence
from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence

from autogen_core import (
Expand Down Expand Up @@ -63,6 +64,24 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]):

component_type = "team"

@staticmethod
def _validate_participants(participants: object) -> List[ChatAgent | Team]:
if not isinstance(participants, ABCSequence) or isinstance(participants, (str, bytes, bytearray)):
raise TypeError("participants must be a non-empty sequence of ChatAgent or Team instances.")

validated_participants = list(participants)
if len(validated_participants) == 0:
raise ValueError("At least one participant is required.")

invalid_participant = next(
(participant for participant in validated_participants if not isinstance(participant, (ChatAgent, Team))),
None,
)
if invalid_participant is not None:
raise TypeError("participants must be a non-empty sequence of ChatAgent or Team instances.")

return validated_participants

def __init__(
self,
name: str,
Expand All @@ -78,11 +97,10 @@ def __init__(
):
self._name = name
self._description = description
if len(participants) == 0:
raise ValueError("At least one participant is required.")
if len(participants) != len(set(participant.name for participant in participants)):
validated_participants = self._validate_participants(participants)
if len(validated_participants) != len(set(participant.name for participant in validated_participants)):
raise ValueError("The participant names must be unique.")
self._participants = participants
self._participants = validated_participants
self._base_group_chat_manager_class = group_chat_manager_class
self._termination_condition = termination_condition
self._max_turns = max_turns
Expand All @@ -91,7 +109,7 @@ def __init__(
for message_type in custom_message_types:
self._message_factory.register(message_type)

for agent in participants:
for agent in validated_participants:
if isinstance(agent, ChatAgent):
for message_type in agent.produced_message_types:
try:
Expand All @@ -112,15 +130,15 @@ def __init__(
# The names are used to identify the agents within the team.
# The names may not be unique across different teams.
self._group_chat_manager_name = group_chat_manager_name
self._participant_names: List[str] = [participant.name for participant in participants]
self._participant_descriptions: List[str] = [participant.description for participant in participants]
self._participant_names: List[str] = [participant.name for participant in validated_participants]
self._participant_descriptions: List[str] = [participant.description for participant in validated_participants]
# The group chat topic type is used for broadcast communication among all participants and the group chat manager.
self._group_topic_type = f"group_topic_{self._team_id}"
# The group chat manager topic type is used for direct communication with the group chat manager.
self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}"
# The participant topic types are used for direct communication with each participant.
self._participant_topic_types: List[str] = [
f"{participant.name}_{self._team_id}" for participant in participants
f"{participant.name}_{self._team_id}" for participant in validated_participants
]
# The output topic type is used for emitting streaming messages from the group chat.
# The group chat manager will relay the messages to the output message queue.
Expand Down
17 changes: 17 additions & 0 deletions python/packages/autogen-agentchat/tests/test_group_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,23 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime |
await team.run(task=[])


def test_round_robin_group_chat_rejects_non_sequence_participants() -> None:
with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"):
RoundRobinGroupChat(participants=None) # type: ignore[arg-type]


def test_round_robin_group_chat_rejects_string_participants() -> None:
with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"):
RoundRobinGroupChat(participants="not a list") # type: ignore[arg-type]


def test_round_robin_group_chat_rejects_non_agent_participants() -> None:
agent = _EchoAgent("Agent1", "First agent")

with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"):
RoundRobinGroupChat(participants=[agent, "bad"]) # type: ignore[list-item]


@pytest.mark.asyncio
async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None) -> None:
# Create basic agents and components for testing
Expand Down