Skip to content

fix: updating BaseAgent.clone() and LlmAgent.clone() to properly clone fields that are lists #2091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions src/google/adk/agents/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,26 @@ class BaseAgent(BaseModel):
response and appended to event history as agent response.
"""

def _clone_list_fields(
self,
cloned_agent: BaseAgent,
update: Mapping[str, Any] | None,
) -> None:
"""Shallow copies fields that are lists and not provided in the update."""
if (update is None or 'before_agent_callback' not in update) and isinstance(
cloned_agent.before_agent_callback, list
):
cloned_agent.before_agent_callback = (
cloned_agent.before_agent_callback.copy()
)

if (update is None or 'after_agent_callback' not in update) and isinstance(
cloned_agent.after_agent_callback, list
):
cloned_agent.after_agent_callback = (
cloned_agent.after_agent_callback.copy()
)

def clone(
self: SelfAgent, update: Mapping[str, Any] | None = None
) -> SelfAgent:
Expand Down Expand Up @@ -165,6 +185,11 @@ def clone(

cloned_agent = self.model_copy(update=update)

# If any field is stored as list and not provided in the update, need to
# shallow copy it for the cloned agent to avoid sharing the same list object
# with the original agent.
self._clone_list_fields(cloned_agent=cloned_agent, update=update)

if update is None or 'sub_agents' not in update:
# If `sub_agents` is not provided in the update, need to recursively clone
# the sub-agents to avoid sharing the sub-agents with the original agent.
Expand Down
38 changes: 38 additions & 0 deletions src/google/adk/agents/llm_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing import Callable
from typing import List
from typing import Literal
from typing import Mapping
from typing import Optional
from typing import Type
from typing import Union
Expand Down Expand Up @@ -276,6 +277,43 @@ class LlmAgent(BaseAgent):
"""
# Callbacks - End

@override
def _clone_list_fields(
self,
cloned_agent: LlmAgent,
update: Mapping[str, Any] | None,
) -> None:
super()._clone_list_fields(cloned_agent=cloned_agent, update=update)

if (update is None or 'before_model_callback' not in update) and isinstance(
cloned_agent.before_model_callback, list
):
cloned_agent.before_model_callback = (
cloned_agent.before_model_callback.copy()
)

if (update is None or 'after_model_callback' not in update) and isinstance(
cloned_agent.after_model_callback, list
):
cloned_agent.after_model_callback = (
cloned_agent.after_model_callback.copy()
)

if (update is None or 'before_tool_callback' not in update) and isinstance(
cloned_agent.before_tool_callback, list
):
cloned_agent.before_tool_callback = (
cloned_agent.before_tool_callback.copy()
)

if (update is None or 'after_tool_callback' not in update) and isinstance(
cloned_agent.after_tool_callback, list
):
cloned_agent.after_tool_callback = cloned_agent.after_tool_callback.copy()

if update is None or 'tools' not in update:
cloned_agent.tools = cloned_agent.tools.copy()

@override
async def _run_async_impl(
self, ctx: InvocationContext
Expand Down
33 changes: 33 additions & 0 deletions tests/unittests/agents/test_agent_clone.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,39 @@ def test_clone_with_sub_agents_update():
assert original.sub_agents[1].name == "original_sub2"


def test_clone_with_callbacks():
"""Test cloning with callbacks in list format."""
before_agent_callback = [lambda *args, **kwargs: None]
after_agent_callback = [lambda *args, **kwargs: None]
before_model_callback = [lambda *args, **kwargs: None]
after_model_callback = [lambda *args, **kwargs: None]
before_tool_callback = [lambda *args, **kwargs: None]
after_tool_callback = [lambda *args, **kwargs: None]
tools = [lambda *args, **kwargs: None]

original = LlmAgent(
name="original_agent",
description="Original agent",
before_agent_callback=before_agent_callback,
after_agent_callback=after_agent_callback,
before_model_callback=before_model_callback,
after_model_callback=after_model_callback,
before_tool_callback=before_tool_callback,
after_tool_callback=after_tool_callback,
tools=tools,
)

cloned = original.clone()

assert id(original.before_agent_callback) != id(cloned.before_agent_callback)
assert id(original.after_agent_callback) != id(cloned.after_agent_callback)
assert id(original.before_model_callback) != id(cloned.before_model_callback)
assert id(original.after_model_callback) != id(cloned.after_model_callback)
assert id(original.before_tool_callback) != id(cloned.before_tool_callback)
assert id(original.after_tool_callback) != id(cloned.after_tool_callback)
assert id(original.tools) != id(cloned.tools)


if __name__ == "__main__":
# Run a specific test for debugging
test_three_level_nested_agent()