diff --git a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py index 60f222912387..86eef951a188 100644 --- a/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py +++ b/python/packages/autogen-agentchat/src/autogen_agentchat/teams/_group_chat/_base_group_chat.py @@ -1,6 +1,7 @@ import asyncio import uuid from abc import ABC, abstractmethod +from collections.abc import Sequence as ABCSequence from typing import Any, AsyncGenerator, Callable, Dict, List, Mapping, Sequence from autogen_core import ( @@ -63,6 +64,24 @@ class BaseGroupChat(Team, ABC, ComponentBase[BaseModel]): component_type = "team" + @staticmethod + def _validate_participants(participants: object) -> List[ChatAgent | Team]: + if not isinstance(participants, ABCSequence) or isinstance(participants, (str, bytes, bytearray)): + raise TypeError("participants must be a non-empty sequence of ChatAgent or Team instances.") + + validated_participants = list(participants) + if len(validated_participants) == 0: + raise ValueError("At least one participant is required.") + + invalid_participant = next( + (participant for participant in validated_participants if not isinstance(participant, (ChatAgent, Team))), + None, + ) + if invalid_participant is not None: + raise TypeError("participants must be a non-empty sequence of ChatAgent or Team instances.") + + return validated_participants + def __init__( self, name: str, @@ -78,11 +97,10 @@ def __init__( ): self._name = name self._description = description - if len(participants) == 0: - raise ValueError("At least one participant is required.") - if len(participants) != len(set(participant.name for participant in participants)): + validated_participants = self._validate_participants(participants) + if len(validated_participants) != len(set(participant.name for participant in validated_participants)): raise ValueError("The participant names must be unique.") - self._participants = participants + self._participants = validated_participants self._base_group_chat_manager_class = group_chat_manager_class self._termination_condition = termination_condition self._max_turns = max_turns @@ -91,7 +109,7 @@ def __init__( for message_type in custom_message_types: self._message_factory.register(message_type) - for agent in participants: + for agent in validated_participants: if isinstance(agent, ChatAgent): for message_type in agent.produced_message_types: try: @@ -112,15 +130,15 @@ def __init__( # The names are used to identify the agents within the team. # The names may not be unique across different teams. self._group_chat_manager_name = group_chat_manager_name - self._participant_names: List[str] = [participant.name for participant in participants] - self._participant_descriptions: List[str] = [participant.description for participant in participants] + self._participant_names: List[str] = [participant.name for participant in validated_participants] + self._participant_descriptions: List[str] = [participant.description for participant in validated_participants] # The group chat topic type is used for broadcast communication among all participants and the group chat manager. self._group_topic_type = f"group_topic_{self._team_id}" # The group chat manager topic type is used for direct communication with the group chat manager. self._group_chat_manager_topic_type = f"{self._group_chat_manager_name}_{self._team_id}" # The participant topic types are used for direct communication with each participant. self._participant_topic_types: List[str] = [ - f"{participant.name}_{self._team_id}" for participant in participants + f"{participant.name}_{self._team_id}" for participant in validated_participants ] # The output topic type is used for emitting streaming messages from the group chat. # The group chat manager will relay the messages to the output message queue. diff --git a/python/packages/autogen-agentchat/tests/test_group_chat.py b/python/packages/autogen-agentchat/tests/test_group_chat.py index 3ded2e0c2e60..6aeba41aa577 100644 --- a/python/packages/autogen-agentchat/tests/test_group_chat.py +++ b/python/packages/autogen-agentchat/tests/test_group_chat.py @@ -1732,6 +1732,23 @@ async def test_round_robin_group_chat_with_message_list(runtime: AgentRuntime | await team.run(task=[]) +def test_round_robin_group_chat_rejects_non_sequence_participants() -> None: + with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"): + RoundRobinGroupChat(participants=None) # type: ignore[arg-type] + + +def test_round_robin_group_chat_rejects_string_participants() -> None: + with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"): + RoundRobinGroupChat(participants="not a list") # type: ignore[arg-type] + + +def test_round_robin_group_chat_rejects_non_agent_participants() -> None: + agent = _EchoAgent("Agent1", "First agent") + + with pytest.raises(TypeError, match="participants must be a non-empty sequence of ChatAgent or Team instances"): + RoundRobinGroupChat(participants=[agent, "bad"]) # type: ignore[list-item] + + @pytest.mark.asyncio async def test_declarative_groupchats_with_config(runtime: AgentRuntime | None) -> None: # Create basic agents and components for testing