Skip to content

Commit 493846f

Browse files
author
Judd
committed
clean up.
1 parent 4112ae6 commit 493846f

File tree

6 files changed

+213
-513
lines changed

6 files changed

+213
-513
lines changed

scripts/tool_definition.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from collections.abc import Callable
2+
import inspect
3+
import traceback
4+
from types import GenericAlias
5+
from typing import Any, get_origin, Annotated
6+
from dataclasses import dataclass
7+
import json
8+
9+
_TOOL_HOOKS = {}
10+
_TOOL_DESCRIPTIONS = []
11+
12+
def json_try_decode(s: str) -> dict | None:
13+
try:
14+
return json.loads(s)
15+
except:
16+
return None
17+
18+
def json_decode_ignore_extra(s: str) -> dict | None:
19+
for i in range(len(s)):
20+
d = json_try_decode(s[:i + 1])
21+
if d is not None: return d
22+
return None
23+
24+
def register_tool(func: Callable):
25+
tool_name = func.__name__
26+
tool_description = inspect.getdoc(func).strip()
27+
python_params = inspect.signature(func).parameters
28+
tool_params = []
29+
30+
tpye_mapping = {
31+
"str": "string",
32+
"int": "integer",
33+
}
34+
35+
for name, param in python_params.items():
36+
annotation = param.annotation
37+
if annotation is inspect.Parameter.empty:
38+
raise TypeError(f"Parameter `{name}` missing type annotation")
39+
if get_origin(annotation) != Annotated:
40+
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
41+
42+
typ, (description, required) = annotation.__origin__, annotation.__metadata__
43+
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
44+
if not isinstance(description, str):
45+
raise TypeError(f"Description for `{name}` must be a string")
46+
if not isinstance(required, bool):
47+
raise TypeError(f"Required for `{name}` must be a bool")
48+
49+
if typ in tpye_mapping:
50+
typ = tpye_mapping[typ]
51+
52+
tool_params.append({
53+
"name": name,
54+
"description": description,
55+
"type": typ,
56+
"required": required,
57+
})
58+
59+
tool_def = {
60+
"name": tool_name,
61+
"description": tool_description,
62+
"parameters": tool_params,
63+
}
64+
# print("[registered tool] " + pformat(tool_def))
65+
_TOOL_HOOKS[tool_name] = func
66+
_TOOL_DESCRIPTIONS.append(tool_def)
67+
68+
return func
69+
70+
@register_tool
71+
def get_weather(
72+
city_name: Annotated[str, "The name of the city to be queried", True],
73+
) -> str:
74+
"""
75+
Get the current weather for `city_name`
76+
"""
77+
78+
if not isinstance(city_name, str):
79+
raise TypeError("City name must be a string")
80+
81+
key_selection = {
82+
"current_condition": [
83+
"temp_C",
84+
"FeelsLikeC",
85+
"humidity",
86+
"weatherDesc",
87+
"observation_time",
88+
],
89+
}
90+
import requests
91+
92+
try:
93+
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
94+
resp.raise_for_status()
95+
resp = resp.json()
96+
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
97+
except:
98+
import traceback
99+
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
100+
101+
return str(ret)
102+
103+
@dataclass
104+
class ToolObservation:
105+
content_type: str
106+
text: str
107+
image_url: str | None = None
108+
role_metadata: str | None = None
109+
metadata: Any = None
110+
id: str | None = None
111+
112+
def dispatch_tool(tool_name: str, tool_params: dict, session_id: str | None = None) -> ToolObservation:
113+
if tool_name not in _TOOL_HOOKS:
114+
err = f"Tool `{tool_name}` not found. Please use a provided tool."
115+
return ToolObservation("system_error", err, id=session_id)
116+
117+
tool_hook = _TOOL_HOOKS[tool_name]
118+
try:
119+
ret = tool_hook(**tool_params)
120+
if isinstance(ret, dict):
121+
ret = json.dumps(ret, ensure_ascii=False)
122+
else:
123+
ret = str(ret)
124+
return ToolObservation(tool_name, ret, id=session_id)
125+
except:
126+
err = traceback.format_exc()
127+
return ToolObservation("system_error", err, id=session_id)
128+
129+
if __name__ == '__main__':
130+
print(_TOOL_DESCRIPTIONS)

scripts/tool_glm3.py

Lines changed: 11 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -1,106 +1,21 @@
1-
"""
2-
Copied from: https://github.com/THUDM/ChatGLM3/blob/main/composite_demo/tool_registry.py
3-
4-
This code is the tool registration part. By registering the tool, the model can call the tool.
5-
This code provides extended functionality to the model, enabling it to call and interact with a variety of utilities
6-
through defined interfaces.
7-
"""
8-
91
import copy
10-
import inspect
11-
from pprint import pformat
12-
import traceback
13-
from types import GenericAlias
14-
from typing import get_origin, Annotated
152
import json, sys, re
163

17-
import binding
184
from binding import PATH_BINDS
195

20-
_TOOL_HOOKS = {}
21-
_TOOL_DESCRIPTIONS = {}
22-
23-
24-
def register_tool(func: callable):
25-
tool_name = func.__name__
26-
tool_description = inspect.getdoc(func).strip()
27-
python_params = inspect.signature(func).parameters
28-
tool_params = []
29-
for name, param in python_params.items():
30-
annotation = param.annotation
31-
if annotation is inspect.Parameter.empty:
32-
raise TypeError(f"Parameter `{name}` missing type annotation")
33-
if get_origin(annotation) != Annotated:
34-
raise TypeError(f"Annotation type for `{name}` must be typing.Annotated")
35-
36-
typ, (description, required) = annotation.__origin__, annotation.__metadata__
37-
typ: str = str(typ) if isinstance(typ, GenericAlias) else typ.__name__
38-
if not isinstance(description, str):
39-
raise TypeError(f"Description for `{name}` must be a string")
40-
if not isinstance(required, bool):
41-
raise TypeError(f"Required for `{name}` must be a bool")
42-
43-
tool_params.append({
44-
"name": name,
45-
"description": description,
46-
"type": typ,
47-
"required": required
48-
})
49-
tool_def = {
50-
"name": tool_name,
51-
"description": tool_description,
52-
"params": tool_params
53-
}
54-
55-
# print("[registered tool] " + pformat(tool_def))
56-
_TOOL_HOOKS[tool_name] = func
57-
_TOOL_DESCRIPTIONS[tool_name] = tool_def
58-
59-
return func
60-
61-
62-
def dispatch_tool(tool_name: str, tool_params: dict) -> str:
63-
if tool_name not in _TOOL_HOOKS:
64-
return f"Tool `{tool_name}` not found. Please use a provided tool."
65-
tool_call = _TOOL_HOOKS[tool_name]
66-
try:
67-
ret = tool_call(**tool_params)
68-
except:
69-
ret = traceback.format_exc()
70-
return str(ret)
71-
6+
import tool_definition
7+
from tool_definition import dispatch_tool
728

739
def get_tools() -> dict:
74-
return copy.deepcopy(_TOOL_DESCRIPTIONS)
7510

11+
def convert(tool: dict):
12+
r = copy.deepcopy(tool)
7613

77-
# Tool Definitions
78-
79-
@register_tool
80-
def get_weather(
81-
city_name: Annotated[str, 'The name of the city to be queried', True],
82-
) -> str:
83-
"""
84-
Get the current weather for `city_name`
85-
"""
86-
87-
if not isinstance(city_name, str):
88-
raise TypeError("City name must be a string")
89-
90-
key_selection = {
91-
"current_condition": ["temp_C", "FeelsLikeC", "humidity", "weatherDesc", "observation_time"],
92-
}
93-
import requests
94-
try:
95-
resp = requests.get(f"https://wttr.in/{city_name}?format=j1")
96-
resp.raise_for_status()
97-
resp = resp.json()
98-
ret = {k: {_v: resp[k][0][_v] for _v in v} for k, v in key_selection.items()}
99-
except:
100-
import traceback
101-
ret = "Error encountered while fetching weather data!\n" + traceback.format_exc()
14+
r['params'] = r['parameters']
15+
del r['parameters']
16+
return r
10217

103-
return str(ret)
18+
return [convert(t) for t in tool_definition._TOOL_DESCRIPTIONS]
10419

10520
def build_sys_prompt():
10621
return "Answer the following questions as best as you can. You have access to the following tools: \n\n" + \
@@ -124,9 +39,9 @@ def tool_call(*args, **kwargs) -> dict:
12439
code = extract_code('\n'.join(call_args_text))
12540
args = eval(code, {'tool_call': tool_call}, {})
12641
observation = dispatch_tool(tool_name, args)
127-
return observation
128-
except:
129-
print("error occurs")
42+
return observation.text
43+
except Exception as e:
44+
print(f"error occurs: {e}")
13045
return "failed to call the function"
13146

13247
class ToolChatLLM(ChatLLM):

0 commit comments

Comments
 (0)