diff --git a/src/google/adk/agents/base_agent.py b/src/google/adk/agents/base_agent.py index 80b58ff17..d40554cf4 100644 --- a/src/google/adk/agents/base_agent.py +++ b/src/google/adk/agents/base_agent.py @@ -132,6 +132,26 @@ class BaseAgent(BaseModel): response and appended to event history as agent response. """ + def _clone_list_fields( + self, + cloned_agent: BaseAgent, + update: Mapping[str, Any] | None, + ) -> None: + """Shallow copies fields that are lists and not provided in the update.""" + if (update is None or 'before_agent_callback' not in update) and isinstance( + cloned_agent.before_agent_callback, list + ): + cloned_agent.before_agent_callback = ( + cloned_agent.before_agent_callback.copy() + ) + + if (update is None or 'after_agent_callback' not in update) and isinstance( + cloned_agent.after_agent_callback, list + ): + cloned_agent.after_agent_callback = ( + cloned_agent.after_agent_callback.copy() + ) + def clone( self: SelfAgent, update: Mapping[str, Any] | None = None ) -> SelfAgent: @@ -165,6 +185,11 @@ def clone( cloned_agent = self.model_copy(update=update) + # If any field is stored as list and not provided in the update, need to + # shallow copy it for the cloned agent to avoid sharing the same list object + # with the original agent. + self._clone_list_fields(cloned_agent=cloned_agent, update=update) + if update is None or 'sub_agents' not in update: # If `sub_agents` is not provided in the update, need to recursively clone # the sub-agents to avoid sharing the sub-agents with the original agent. diff --git a/src/google/adk/agents/llm_agent.py b/src/google/adk/agents/llm_agent.py index c20d26963..04ae8f771 100644 --- a/src/google/adk/agents/llm_agent.py +++ b/src/google/adk/agents/llm_agent.py @@ -23,6 +23,7 @@ from typing import Callable from typing import List from typing import Literal +from typing import Mapping from typing import Optional from typing import Type from typing import Union @@ -276,6 +277,43 @@ class LlmAgent(BaseAgent): """ # Callbacks - End + @override + def _clone_list_fields( + self, + cloned_agent: LlmAgent, + update: Mapping[str, Any] | None, + ) -> None: + super()._clone_list_fields(cloned_agent=cloned_agent, update=update) + + if (update is None or 'before_model_callback' not in update) and isinstance( + cloned_agent.before_model_callback, list + ): + cloned_agent.before_model_callback = ( + cloned_agent.before_model_callback.copy() + ) + + if (update is None or 'after_model_callback' not in update) and isinstance( + cloned_agent.after_model_callback, list + ): + cloned_agent.after_model_callback = ( + cloned_agent.after_model_callback.copy() + ) + + if (update is None or 'before_tool_callback' not in update) and isinstance( + cloned_agent.before_tool_callback, list + ): + cloned_agent.before_tool_callback = ( + cloned_agent.before_tool_callback.copy() + ) + + if (update is None or 'after_tool_callback' not in update) and isinstance( + cloned_agent.after_tool_callback, list + ): + cloned_agent.after_tool_callback = cloned_agent.after_tool_callback.copy() + + if update is None or 'tools' not in update: + cloned_agent.tools = cloned_agent.tools.copy() + @override async def _run_async_impl( self, ctx: InvocationContext diff --git a/tests/unittests/agents/test_agent_clone.py b/tests/unittests/agents/test_agent_clone.py index 7bda2a69c..2e5945d1f 100644 --- a/tests/unittests/agents/test_agent_clone.py +++ b/tests/unittests/agents/test_agent_clone.py @@ -374,6 +374,39 @@ def test_clone_with_sub_agents_update(): assert original.sub_agents[1].name == "original_sub2" +def test_clone_with_callbacks(): + """Test cloning with callbacks in list format.""" + before_agent_callback = [lambda *args, **kwargs: None] + after_agent_callback = [lambda *args, **kwargs: None] + before_model_callback = [lambda *args, **kwargs: None] + after_model_callback = [lambda *args, **kwargs: None] + before_tool_callback = [lambda *args, **kwargs: None] + after_tool_callback = [lambda *args, **kwargs: None] + tools = [lambda *args, **kwargs: None] + + original = LlmAgent( + name="original_agent", + description="Original agent", + before_agent_callback=before_agent_callback, + after_agent_callback=after_agent_callback, + before_model_callback=before_model_callback, + after_model_callback=after_model_callback, + before_tool_callback=before_tool_callback, + after_tool_callback=after_tool_callback, + tools=tools, + ) + + cloned = original.clone() + + assert id(original.before_agent_callback) != id(cloned.before_agent_callback) + assert id(original.after_agent_callback) != id(cloned.after_agent_callback) + assert id(original.before_model_callback) != id(cloned.before_model_callback) + assert id(original.after_model_callback) != id(cloned.after_model_callback) + assert id(original.before_tool_callback) != id(cloned.before_tool_callback) + assert id(original.after_tool_callback) != id(cloned.after_tool_callback) + assert id(original.tools) != id(cloned.tools) + + if __name__ == "__main__": # Run a specific test for debugging test_three_level_nested_agent()