Skip to content

Add on_tool_result()/on_tool_request() callbacks #48

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 99 additions & 7 deletions chatlas/_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Optional,
Sequence,
TypeVar,
cast,
overload,
)

Expand All @@ -42,7 +43,7 @@
from ._tools import Tool
from ._turn import Turn, user_turn
from ._typing_extensions import TypedDict
from ._utils import html_escape, wrap_async
from ._utils import html_escape, is_async_callable, wrap_async


class AnyTypeDict(TypedDict, total=False):
Expand Down Expand Up @@ -91,6 +92,14 @@ def __init__(
self.provider = provider
self._turns: list[Turn] = list(turns or [])
self._tools: dict[str, Tool] = {}
self._on_tool_request: Optional[
Callable[[ContentToolRequest], None]
| Callable[[ContentToolRequest], Awaitable[None]]
] = None
self._on_tool_result: Optional[
Callable[[ContentToolResult], None]
| Callable[[ContentToolResult], Awaitable[None]]
] = None
self._echo_options: EchoOptions = {
"rich_markdown": {},
"rich_console": {},
Expand Down Expand Up @@ -908,6 +917,34 @@ def add(a: int, b: int) -> int:
tool = Tool(func, model=model)
self._tools[tool.name] = tool

def on_tool_request(
self,
func: Callable[[ContentToolRequest], None]
| Callable[[ContentToolRequest], Awaitable[None]],
):
"""
Register a function to be called when a tool is requested.

This function will be called with a single argument, a `ContentToolRequest`
object, which contains the tool name and the input parameters for the tool.
"""
self._on_tool_request = func

def on_tool_result(
self,
func: Callable[[ContentToolResult], None]
| Callable[[ContentToolResult], Awaitable[None]],
):
"""
Register a function to be called when a tool result is received.

This function will be called with a single argument, a `ContentToolResult`
object, which contains the tool name and the output of the tool.

TODO: explain how to check for errors in the tool result
"""
self._on_tool_result = func

def export(
self,
filename: str | Path,
Expand Down Expand Up @@ -1205,12 +1242,30 @@ def _invoke_tools(self) -> Turn | None:
if turn is None:
return None

on_request = self._on_tool_request
if on_request is not None and is_async_callable(on_request):
raise ValueError(
"Cannot use async on_tool_request callback in a synchronous chat"
)

on_result = self._on_tool_result
if on_result is not None and is_async_callable(on_result):
raise ValueError(
"Cannot use async on_tool_result callback in a synchronous chat"
)

on_result = cast(Callable[[ContentToolResult], None], on_result)

results: list[ContentToolResult] = []
for x in turn.contents:
if isinstance(x, ContentToolRequest):
if on_request is not None:
on_request(x)
tool_def = self._tools.get(x.name, None)
func = tool_def.func if tool_def is not None else None
results.append(self._invoke_tool(func, x.arguments, x.id))
results.append(
self._invoke_tool(func, x.arguments, x.id, x.name, on_result)
)

if not results:
return None
Expand All @@ -1222,17 +1277,33 @@ async def _invoke_tools_async(self) -> Turn | None:
if turn is None:
return None

on_request = self._on_tool_request
if on_request is not None:
on_request = wrap_async(on_request)

on_result = self._on_tool_result
if on_result is not None:
on_result = wrap_async(on_result)

on_result = cast(Callable[[ContentToolResult], Awaitable[None]], on_result)

results: list[ContentToolResult] = []
for x in turn.contents:
if isinstance(x, ContentToolRequest):
if on_request is not None:
await on_request(x)
tool_def = self._tools.get(x.name, None)
func = None
if tool_def:
if tool_def._is_async:
func = tool_def.func
else:
func = wrap_async(tool_def.func)
results.append(await self._invoke_tool_async(func, x.arguments, x.id))
results.append(
await self._invoke_tool_async(
func, x.arguments, x.id, x.name, on_result
)
)

if not results:
return None
Expand All @@ -1244,12 +1315,18 @@ def _invoke_tool(
func: Callable[..., Any] | None,
arguments: object,
id_: str,
name: str,
on_result: Optional[Callable[[ContentToolResult], None]] = None,
) -> ContentToolResult:
if func is None:
return ContentToolResult(id_, value=None, error="Unknown tool")
res = ContentToolResult(id_, value=None, error="Unknown tool", name=name)
if on_result is not None:
on_result(res)
return res

name = func.__name__

res = None
try:
if isinstance(arguments, dict):
result = func(**arguments)
Expand All @@ -1259,19 +1336,30 @@ def _invoke_tool(
return ContentToolResult(id_, value=result, error=None, name=name)
except Exception as e:
log_tool_error(name, str(arguments), e)
return ContentToolResult(id_, value=None, error=str(e), name=name)
res = ContentToolResult(id_, value=None, error=str(e), name=name)

if on_result is not None:
on_result(res)

return res

@staticmethod
async def _invoke_tool_async(
func: Callable[..., Awaitable[Any]] | None,
arguments: object,
id_: str,
name: str,
on_result: Optional[Callable[[ContentToolResult], Awaitable[None]]] = None,
) -> ContentToolResult:
if func is None:
return ContentToolResult(id_, value=None, error="Unknown tool")
res = ContentToolResult(id_, value=None, error="Unknown tool", name=name)
if on_result is not None:
await on_result(res)
return res

name = func.__name__

res = None
try:
if isinstance(arguments, dict):
result = await func(**arguments)
Expand All @@ -1281,7 +1369,11 @@ async def _invoke_tool_async(
return ContentToolResult(id_, value=result, error=None, name=name)
except Exception as e:
log_tool_error(func.__name__, str(arguments), e)
return ContentToolResult(id_, value=None, error=str(e), name=name)
res = ContentToolResult(id_, value=None, error=str(e), name=name)
if on_result is not None:
await on_result(res)

return res

def _markdown_display(
self, echo: Literal["text", "all", "none"]
Expand Down
20 changes: 17 additions & 3 deletions chatlas/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,16 +185,28 @@ class ContentToolResult(Content):
Parameters
----------
id
The unique identifier of the tool request.
The unique identifier for the tool result.
name
The name of the tool/function that was called.
value
The value returned by the tool/function.
name
The name of the tool/function that was called.
error
An error message if the tool/function call failed.

Note
----
If the tool/function call failed, the `value` field will be `None` and the
`error` field will contain the error message.
If the tool/function call succeeded, the `value` field will contain the
return value and the `error` field will be `None`.
To get the actual result sent to the model assistant, use the `get_final_value()`
method.
"""

id: str
name: str
value: Any = None
name: Optional[str] = None
error: Optional[str] = None
Expand All @@ -209,7 +221,7 @@ def _get_value_and_language(self) -> tuple[str, str]:
return str(self.value), ""

def __str__(self):
comment = f"# tool result ({self.id})"
comment = f"# tool ({self.name}) result ({self.id})"
value, language = self._get_value_and_language()

return f"""```{language}\n{comment}\n{value}\n```"""
Expand All @@ -219,7 +231,9 @@ def _repr_markdown_(self):

def __repr__(self, indent: int = 0):
res = " " * indent
res += f"<ContentToolResult value='{self.value}' id='{self.id}'"
res += (
f"<ContentToolResult value='{self.value}' name='{self.name}' id='{self.id}'"
)
if self.error:
res += f" error='{self.error}'"
return res + ">"
Expand Down
Loading