Skip to content

Implement FunctionTool #1604

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
72 changes: 62 additions & 10 deletions src/smolagents/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,18 @@
# limitations under the License.
from __future__ import annotations


__all__ = [
"AUTHORIZED_TYPES",
"FunctionTool",
"Tool",
"function_tool",
"tool",
"load_tool",
"launch_gradio_demo",
"ToolCollection",
]

import ast
import inspect
import json
Expand Down Expand Up @@ -103,6 +115,56 @@ def __call__(self, *args, **kwargs) -> Any:
pass


class FunctionTool(BaseTool):
def __init__(self, func: Callable, name: str | None = None):
self.func = func
self.name = name or func.__name__

def __call__(self, *args, **kwargs):
return self.func(*args, **kwargs)

def to_code_prompt(self) -> str:
tool_doc = f'"""{inspect.getdoc(self.func)}\n"""'
return f"def {self.name}{inspect.signature(self.func)}:\n{textwrap.indent(tool_doc, ' ')}"

def to_tool_calling_prompt(self) -> str:
schema = get_json_schema(self.func)
return json.dumps(schema["function"])


def function_tool(tool_function: Callable) -> FunctionTool:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like this refactoring!
For the decorator, why not keep the simple name @tool? "function_" is not needed since users already know that what they are using as source is a function, they just want to convert it to a tool, thus @tool is enough IMO

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is because of backward compatibility: we should first deprecate the old tool decorator and eventually replace it with the new one.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But if you manage to maintain functionality (as judged by all the tests passing), switching the logic behind the decorator shouldn't break anything?

"""
Convert a function into a [`FunctionTool`] instance.

This decorator wraps a regular function to create a [`FunctionTool`] that can be used
with agents. The function should have proper type hints and a docstring with
argument descriptions.

Args:
tool_function (`Callable`): Function to convert into a FunctionTool instance.
Should have type hints for each input and a type hint for the output.
Should also have a docstring including the description of the function
and an 'Args:' part where each argument is described.

Returns:
`FunctionTool`: FunctionTool instance wrapping the provided function.

Example:
```python
@function_tool
def calculate_sum(a: int, b: int) -> int:
'''Calculate the sum of two numbers.

Args:
a: First number
b: Second number
'''
return a + b
```
"""
return FunctionTool(tool_function)


class Tool(BaseTool):
"""
A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
Expand Down Expand Up @@ -1313,13 +1375,3 @@ def validate_tool_arguments(tool: Tool, arguments: Any) -> None:
expected_type = list(tool.inputs.values())[0]["type"]
if _get_json_schema_type(type(arguments))["type"] != expected_type and not expected_type == "any":
raise TypeError(f"Argument has type '{type(arguments).__name__}' but should be '{expected_type}'")


__all__ = [
"AUTHORIZED_TYPES",
"Tool",
"tool",
"load_tool",
"launch_gradio_demo",
"ToolCollection",
]
180 changes: 179 additions & 1 deletion tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,189 @@
import pytest

from smolagents.agent_types import _AGENT_TYPE_MAPPING
from smolagents.tools import AUTHORIZED_TYPES, Tool, ToolCollection, launch_gradio_demo, tool, validate_tool_arguments
from smolagents.tools import (
AUTHORIZED_TYPES,
FunctionTool,
Tool,
ToolCollection,
launch_gradio_demo,
tool,
validate_tool_arguments,
)

from .utils.markers import require_run_all


def simple_func(x):
return x + 1


def complex_func(a, b=1):
return a + b


def documented_func(x: int) -> int:
"""This function adds one to the input.

Args:
x: Integer input.
"""
return x + 1


def multi_param(a: str, b: int = 1, *args, **kwargs) -> str:
"""Complex function.

Args:
a: String input.
b: Integer input, defaults to 1.
args: Variable length argument list.
kwargs: Arbitrary keyword arguments.
"""
return a * b


def typed_func(x: int, y: str = "hello") -> dict[str, Any]:
"""Function with type hints.

Args:
x: Integer input.
y: String input, defaults to 'hello'.
"""
return {"result": x}


def multi_typed(numbers: list[int], flag: bool = False) -> int | None:
"""Function with complex types.

Args:
numbers: List of integers to sum.
flag: Boolean flag to control the output.
"""
return sum(numbers) if flag else None


class TestFunctionTool:
@pytest.mark.parametrize(
"func, custom_name, expected_name",
[
(lambda x: x + 1, None, "<lambda>"),
(lambda x: x * 2, "multiplier", "multiplier"),
(simple_func, None, "simple_func"),
(complex_func, "adder", "adder"),
],
)
def test_initialization(self, func, custom_name, expected_name):
"""Test initialization of FunctionTool with different functions and names."""
tool = FunctionTool(func, name=custom_name)
assert tool.func == func
assert tool.name == expected_name

@pytest.mark.parametrize(
"func, args, kwargs, expected_result",
[
(lambda x: x + 1, [2], {}, 3),
(lambda x, y: x * y, [3, 4], {}, 12),
(lambda x, y=2: x * y, [3], {}, 6),
(lambda x, y=2: x * y, [3], {"y": 4}, 12),
(lambda **kwargs: sum(kwargs.values()), [], {"a": 1, "b": 2}, 3),
],
)
def test_call_method(self, func, args, kwargs, expected_result):
"""Test the __call__ method correctly passes arguments to the wrapped function."""
tool = FunctionTool(func)
result = tool(*args, **kwargs)
assert result == expected_result

@pytest.mark.parametrize(
"func, expected",
[
(
documented_func,
dedent('''\
def documented_func(x: int) -> int:
"""This function adds one to the input.

Args:
x: Integer input.
"""'''),
),
(
multi_param,
dedent('''\
def multi_param(a: str, b: int = 1, *args, **kwargs) -> str:
"""Complex function.

Args:
a: String input.
b: Integer input, defaults to 1.
args: Variable length argument list.
kwargs: Arbitrary keyword arguments.
"""'''),
),
],
)
def test_to_code_prompt(self, func, expected):
"""Test that to_code_prompt correctly formats the function as code."""
tool = FunctionTool(func)
code_prompt = tool.to_code_prompt()
assert code_prompt == expected

@pytest.mark.parametrize(
"func, expected",
[
(
typed_func,
'{"name": "typed_func", "description": "Function with type hints.", "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "Integer input."}, "y": {"type": "string", "nullable": true, "description": "String input, defaults to \'hello\'."}}, "required": ["x"]}, "return": {"type": "object", "additionalProperties": {"type": "any"}}}',
),
(
multi_typed,
'{"name": "multi_typed", "description": "Function with complex types.", "parameters": {"type": "object", "properties": {"numbers": {"type": "array", "items": {"type": "integer"}, "description": "List of integers to sum."}, "flag": {"type": "boolean", "nullable": true, "description": "Boolean flag to control the output."}}, "required": ["numbers"]}, "return": {"type": "integer", "nullable": true}}',
),
],
)
def test_to_tool_calling_prompt(self, func, expected):
"""Test that to_tool_calling_prompt returns valid JSON schema representation."""
tool = FunctionTool(func)
tool_calling_prompt = tool.to_tool_calling_prompt()
assert tool_calling_prompt == expected

def test_with_real_function(self):
"""Test FunctionTool with a real function to ensure complete integration."""

def calculate_area(length: float, width: float = 1.0) -> float:
"""Calculate the area of a rectangle.

Args:
length: The length of the rectangle
width: The width of the rectangle (defaults to 1.0)

Returns:
The calculated area
"""
return length * width

tool = FunctionTool(calculate_area)

# Test calling
assert tool(5, 3) == 15
assert tool(5) == 5

# Test code prompt
code_prompt = tool.to_code_prompt()
assert "def calculate_area" in code_prompt
assert "length: float" in code_prompt
assert "width: float = 1.0" in code_prompt
assert "Calculate the area of a rectangle" in code_prompt

# Test tool calling prompt
schema_str = tool.to_tool_calling_prompt()
assert "calculate_area" in schema_str
assert "length" in schema_str
assert "width" in schema_str
assert "float" in schema_str.lower() or "number" in schema_str.lower()


class ToolTesterMixin:
def test_inputs_output(self):
assert hasattr(self.tool, "inputs")
Expand Down