Skip to content
8 changes: 8 additions & 0 deletions examples/internlm2_agent_web_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import streamlit as st

from lagent.actions import ActionExecutor, ArxivSearch, IPythonInterpreter
from lagent.actions.agentlego_wrapper import AgentLegoToolkit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from lagent.actions.agentlego_wrapper import AgentLegoToolkit
# from lagent.actions.agentlego_wrapper import AgentLegoToolkit

from lagent.agents.internlm2_agent import INTERPRETER_CN, META_CN, PLUGIN_CN, Internlm2Agent, Internlm2Protocol
from lagent.llms.lmdepoly_wrapper import LMDeployClient
from lagent.llms.meta_template import INTERNLM2_META as META
Expand All @@ -23,6 +24,13 @@ def init_state(self):

action_list = [
ArxivSearch(),
AgentLegoToolkit(
type='ImageDescription',
url='http://127.0.0.1:16180/openapi.json'),
AgentLegoToolkit(
type='Calculator', url='http://127.0.0.1:16181/openapi.json'),
AgentLegoToolkit(
type='PluginMarket', url='http://127.0.0.1:16182/openapi.json')
]
st.session_state['plugin_map'] = {
action.name: action
Expand Down
51 changes: 51 additions & 0 deletions lagent/actions/agentlego_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from typing import Optional

# from agentlego.parsers import DefaultParser
from agentlego.tools.remote import RemoteTool

from lagent import BaseAction
from lagent.actions.parser import JsonParser


class AgentLegoToolkit(BaseAction):

def __init__(self,
type: str,
url: Optional[str] = None,
text: Optional[str] = None,
spec_dict: Optional[dict] = None,
parser=JsonParser,
enable: bool = True):

if url is not None:
spec = dict(url=url)
elif text is not None:
spec = dict(text=text)
else:
assert spec_dict is not None
spec = dict(spec_dict=spec_dict)
if url is not None and not url.endswith('.json'):
api_list = [RemoteTool.from_url(url).to_lagent()]
else:
api_list = [
api.to_lagent() for api in RemoteTool.from_openapi(**spec)
]
api_desc = []
for api in api_list:
api_desc.append(api.description)
if len(api_list) > 1:
tool_description = dict(name=type, api_list=api_desc)
self.add_method(api_list)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
self.add_method(api_list)
self.add_method(api_list)
Suggested change
self.add_method(api_list)
for func in api_list:
setattr(self, func.name, func.run)

else:
tool_description = api_desc[0]
setattr(self, 'run', api_list[0].run)
super().__init__(
description=tool_description, parser=parser, enable=enable)

@property
def is_toolkit(self):
return 'api_list' in self.description

def add_method(self, funcs):
for func in funcs:
setattr(self, func.name, func.run)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove