Spaces:
Sleeping
Sleeping
import json | |
import warnings | |
from copy import deepcopy | |
from typing import Callable, Dict, List, Union | |
from lagent.actions import ActionExecutor, AsyncActionExecutor, AsyncIPythonInterpreter, IPythonInteractive | |
from lagent.agents.agent import Agent, AsyncAgent | |
from lagent.agents.aggregator import InternLMToolAggregator | |
from lagent.hooks import InternLMActionProcessor | |
from lagent.llms import BaseLLM | |
from lagent.memory import Memory | |
from lagent.prompts.parsers import InterpreterParser, MixedToolParser, PluginParser, ToolStatusCode | |
from lagent.schema import AgentMessage | |
from lagent.utils import create_object | |
# Template wrapped around every tool/API description; filled in with
# ``tool_name`` and ``description`` via ``str.format``.
API_PREFIX = (
    "This is the subfunction for tool '{tool_name}', you can use this tool. "
    'The description of this function is: \n{description}'
)

# Default meta template for the mixed (plugin + interpreter) parser.
# Chinese prompt, roughly: "when tools and code are enabled, pick the
# appropriate tool to call as needed".
META_CN = '当开启工具以及代码时,根据需求选择合适的工具进行调用'

# Default template for the interpreter parser. Chinese prompt describing a
# stateful Jupyter notebook environment in which Python code is executed and
# listing the kinds of task it suits.
INTERPRETER_CN = (
    '你现在已经能够在一个有状态的 Jupyter 笔记本环境中运行 Python 代码。'
    '当你向 python 发送含有 Python 代码的消息时,它将在该环境中执行。'
    '这个工具适用于多种场景,如数据分析或处理(包括数据操作、统计分析、图表绘制),'
    '复杂的计算问题(解决数学和物理难题),编程示例(理解编程概念或特性),'
    '文本处理和分析(比如文本解析和自然语言处理),'
    '机器学习和数据科学(用于展示模型训练和数据可视化),'
    '以及文件操作和数据导入(处理CSV、JSON等格式的文件)。'
)

# Default template for the plugin parser. Contains a ``{prompt}`` placeholder
# between the introductory sentence and the closing instructions (roughly:
# answer directly once enough information is available, avoid unnecessary
# tool calls, and do not invent tools).
PLUGIN_CN = (
    '你可以使用如下工具:'
    '\n{prompt}\n'
    '如果你已经获得足够信息,请直接给出答案. 避免不必要的工具调用! '
    '同时注意你可以使用的工具,不要随意捏造!'
)
def get_plugin_prompt(actions, api_desc_template=API_PREFIX):
    """Render the descriptions of one or more actions as a JSON string.

    Each action config is instantiated with ``create_object`` and a deep copy
    of its ``description`` is rewritten, so the action's own description is
    left untouched:

    * the description text is wrapped with ``api_desc_template``;
    * only parameters whose name appears in ``required`` are kept;
    * for toolkits, every entry of ``api_list`` becomes its own item and its
      name is prefixed with ``"<action name>."``.

    Args:
        actions: A single action (or config) or a list of them. Anything that
            is not a ``list`` is treated as a single action.
        api_desc_template: Format string receiving ``tool_name`` and
            ``description``.

    Returns:
        str: The collected descriptions dumped with ``json.dumps`` using
        ``ensure_ascii=False`` and an indent of 4.
    """
    action_specs = actions if isinstance(actions, list) else [actions]
    rendered = []
    for spec in action_specs:
        tool = create_object(spec)
        description = deepcopy(tool.description)
        toolkit = tool.is_toolkit
        # A toolkit contributes one entry per API; a plain tool contributes
        # its own description as the single entry.
        entries = description['api_list'] if toolkit else [description]
        for entry in entries:
            if toolkit:
                entry['name'] = f"{tool.name}.{entry['name']}"
            entry['description'] = api_desc_template.format(
                tool_name=tool.name, description=entry['description']
            )
            entry['parameters'] = [
                param for param in entry['parameters'] if param['name'] in entry['required']
            ]
            rendered.append(entry)
    return json.dumps(rendered, ensure_ascii=False, indent=4)
class AgentForInternLM(Agent):
    """Tool-using agent that loops between an inner agent and tool executors.

    An inner agent (``_INTERNAL_AGENT_CLS``) is built from the LLM, template,
    output format, memory and aggregator. Optional plugin and interpreter
    executors are stored as ``plugin_executor`` and ``interpreter_executor``;
    ``forward`` picks between them through the ``tool_type`` field of the
    inner agent's formatted output.
    """

    _INTERNAL_AGENT_CLS = Agent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ],
        ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        """Build the inner agent and the optional tool executors.

        Args:
            llm: Passed to the inner agent.
            plugins: Action config(s) wrapped in an ``ActionExecutor``; a
                falsy value is stored as-is and no executor is created.
            interpreter: Same as ``plugins`` but for the interpreter executor.
            template: Passed to the inner agent.
            memory: Passed to the inner agent.
            output_format: Passed to the inner agent.
            aggregator: Passed to the inner agent.
            action_hooks: Hooks handed to every executor that is created.
            finish_condition: Predicate that ends ``forward`` when it returns
                a truthy value for the inner agent's message.
            max_turn: Upper bound on the number of loop iterations in
                ``forward``.
            **kwargs: ``hooks`` (if present) is removed and given to the inner
                agent; everything else goes to the base-class constructor.
        """
        inner_hooks = kwargs.pop('hooks', None)
        self.agent = self._INTERNAL_AGENT_CLS(
            llm=llm,
            template=template,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=inner_hooks,
        )
        # Falsy configs are kept unchanged so that ``forward`` treats the
        # corresponding executor as unavailable.
        self.plugin_executor = ActionExecutor(plugins, hooks=action_hooks) if plugins else plugins
        self.interpreter_executor = (
            ActionExecutor(interpreter, hooks=action_hooks) if interpreter else interpreter
        )
        if not self.plugin_executor and not self.interpreter_executor:
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent call a tool.'
            )
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Run the agent/tool loop for at most ``max_turn`` iterations.

        Args:
            message: Message given to the inner agent on the first iteration.
            session_id: Forwarded to the inner agent and to the executors.
            **kwargs: Forwarded to the inner agent on every iteration.

        Returns:
            The inner agent's message once ``finish_condition`` accepts it,
            otherwise whatever message is current when the turns run out.

        Raises:
            RuntimeError: If the requested tool type has no usable executor.
        """
        for _turn in range(self.max_turn):
            message = self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                break
            requested = message.formatted['tool_type']
            if not requested:
                # No tool asked for: feed the message straight back in.
                continue
            executor = getattr(self, f'{requested}_executor', None)
            if not executor:
                raise RuntimeError(f'No available {requested} executor')
            message = executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        """Summarise the inner agent's memory for a session as a step list.

        Messages sent by the inner agent yield a ``thought`` step and, when a
        tool type is set, a ``tool`` step. Messages from any sender other than
        the inner agent or ``'user'`` yield an ``environment`` step, tagged
        with the most recently seen tool type if there is one.

        Args:
            session_id: Session whose memory is read.

        Returns:
            list[dict]: Steps in memory order.
        """
        trace = []
        active_tool = None
        for record in self.agent.memory.get_memory(session_id):
            if record.sender == self.agent.name:
                parsed = record.formatted
                trace.append({'role': 'thought', 'content': parsed['thought']})
                if parsed['tool_type']:
                    active_tool = parsed['tool_type']
                    trace.append({'role': 'tool', 'content': parsed['action'], 'name': active_tool})
            elif record.sender != 'user':
                observation = {'role': 'environment', 'content': record.content}
                if active_tool:
                    observation['name'] = active_tool
                trace.append(observation)
        return trace
class MathCoder(AgentForInternLM):
    """Interpreter-only variant of ``AgentForInternLM`` for math problems.

    Differs from the parent only in its defaults (an ``IPythonInteractive``
    interpreter, an ``InterpreterParser`` output format with a math-solving
    prompt, ``max_turn=6``) and in discarding any ``plugins`` argument.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(type=IPythonInteractive, timeout=20, max_out_len=8192),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=(
                'Integrate step-by-step reasoning and Python code to solve math problems '
                'using the following guidelines:\n'
                '- Analyze the question and write jupyter code to solve the problem;\n'
                r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
                'units. \n'
            ),
        ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        """Delegate to the parent constructor without plugins.

        All named arguments are passed through unchanged; see the parent
        constructor for their meaning. A ``plugins`` entry in ``kwargs`` is
        removed before delegating, so the parent's ``plugins`` default applies.
        """
        # Plugins are deliberately dropped so only the interpreter is used.
        kwargs.pop('plugins', None)
        pinned = dict(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
        )
        super().__init__(**pinned, **kwargs)
class AsyncAgentForInternLM(AsyncAgent):
    """Asynchronous tool-using agent looping between an inner agent and tools.

    An inner agent (``_INTERNAL_AGENT_CLS``) is built from the LLM, template,
    output format, memory and aggregator. Optional plugin and interpreter
    executors (``AsyncActionExecutor``) are stored as ``plugin_executor`` and
    ``interpreter_executor``; ``forward`` awaits whichever one matches the
    ``tool_type`` field of the inner agent's formatted output.
    """

    _INTERNAL_AGENT_CLS = AsyncAgent

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        plugins: Union[dict, List[dict]] = None,
        interpreter: dict = None,
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=MixedToolParser,
            template=META_CN,
            parsers=[
                dict(type=PluginParser, template=PLUGIN_CN),
                dict(type=InterpreterParser, template=INTERPRETER_CN),
            ],
        ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 4,
        **kwargs,
    ):
        """Build the inner agent and the optional async tool executors.

        Args:
            llm: Passed to the inner agent.
            plugins: Action config(s) wrapped in an ``AsyncActionExecutor``;
                a falsy value is stored as-is and no executor is created.
            interpreter: Same as ``plugins`` but for the interpreter executor.
            template: Passed to the inner agent.
            memory: Passed to the inner agent.
            output_format: Passed to the inner agent.
            aggregator: Passed to the inner agent.
            action_hooks: Hooks handed to every executor that is created.
            finish_condition: Predicate that ends ``forward`` when it returns
                a truthy value for the inner agent's message.
            max_turn: Upper bound on the number of loop iterations in
                ``forward``.
            **kwargs: ``hooks`` (if present) is removed and given to the inner
                agent; everything else goes to the base-class constructor.
        """
        inner_hooks = kwargs.pop('hooks', None)
        self.agent = self._INTERNAL_AGENT_CLS(
            llm=llm,
            template=template,
            output_format=output_format,
            memory=memory,
            aggregator=aggregator,
            hooks=inner_hooks,
        )
        # Falsy configs are kept unchanged so that ``forward`` treats the
        # corresponding executor as unavailable.
        self.plugin_executor = AsyncActionExecutor(plugins, hooks=action_hooks) if plugins else plugins
        self.interpreter_executor = (
            AsyncActionExecutor(interpreter, hooks=action_hooks) if interpreter else interpreter
        )
        if not self.plugin_executor and not self.interpreter_executor:
            warnings.warn(
                'Neither plugin nor interpreter executor is initialized. '
                'An exception will be thrown when the agent call a tool.'
            )
        self.finish_condition = finish_condition
        self.max_turn = max_turn
        super().__init__(**kwargs)

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Run the agent/tool loop for at most ``max_turn`` iterations.

        Args:
            message: Message given to the inner agent on the first iteration.
            session_id: Forwarded to the inner agent and to the executors.
            **kwargs: Forwarded to the inner agent on every iteration.

        Returns:
            The inner agent's message once ``finish_condition`` accepts it,
            otherwise whatever message is current when the turns run out.

        Raises:
            RuntimeError: If the requested tool type has no usable executor.
        """
        for _turn in range(self.max_turn):
            message = await self.agent(message, session_id=session_id, **kwargs)
            assert isinstance(message.formatted, dict)
            if self.finish_condition(message):
                break
            requested = message.formatted['tool_type']
            if not requested:
                # No tool asked for: feed the message straight back in.
                continue
            executor = getattr(self, f'{requested}_executor', None)
            if not executor:
                raise RuntimeError(f'No available {requested} executor')
            message = await executor(message, session_id=session_id)
        return message

    def get_steps(self, session_id=0):
        """Summarise the inner agent's memory for a session as a step list.

        Messages sent by the inner agent yield a ``thought`` step and, when a
        tool type is set, a ``tool`` step. Messages from any sender other than
        the inner agent or ``'user'`` yield an ``environment`` step, tagged
        with the most recently seen tool type if there is one.

        Args:
            session_id: Session whose memory is read.

        Returns:
            list[dict]: Steps in memory order.
        """
        trace = []
        active_tool = None
        for record in self.agent.memory.get_memory(session_id):
            if record.sender == self.agent.name:
                parsed = record.formatted
                trace.append({'role': 'thought', 'content': parsed['thought']})
                if parsed['tool_type']:
                    active_tool = parsed['tool_type']
                    trace.append({'role': 'tool', 'content': parsed['action'], 'name': active_tool})
            elif record.sender != 'user':
                observation = {'role': 'environment', 'content': record.content}
                if active_tool:
                    observation['name'] = active_tool
                trace.append(observation)
        return trace
class AsyncMathCoder(AsyncAgentForInternLM):
    """Interpreter-only async agent for math problems.

    Differs from the parent in its defaults (an ``AsyncIPythonInterpreter``
    interpreter, an ``InterpreterParser`` output format with a math-solving
    prompt, ``max_turn=6``), in discarding any ``plugins`` argument, and in
    closing the interpreter session after every ``forward`` call.
    """

    def __init__(
        self,
        llm: Union[BaseLLM, Dict],
        interpreter: dict = dict(type=AsyncIPythonInterpreter),
        template: Union[str, dict, List[dict]] = None,
        memory: Dict = dict(type=Memory),
        output_format: Dict = dict(
            type=InterpreterParser,
            template=(
                'Integrate step-by-step reasoning and Python code to solve math problems '
                'using the following guidelines:\n'
                '- Analyze the question and write jupyter code to solve the problem;\n'
                r"- Present the final result in LaTeX using a '\boxed{{}}' without any "
                'units. \n'
            ),
        ),
        aggregator: Dict = dict(type=InternLMToolAggregator),
        action_hooks: List = [dict(type=InternLMActionProcessor)],
        finish_condition: Callable[[AgentMessage], bool] = lambda m: m.formatted['status'] == ToolStatusCode.NO_TOOL,
        max_turn: int = 6,
        **kwargs,
    ):
        """Delegate to the parent constructor without plugins.

        All named arguments are passed through unchanged; see the parent
        constructor for their meaning. A ``plugins`` entry in ``kwargs`` is
        removed before delegating, so the parent's ``plugins`` default applies.
        """
        # Plugins are deliberately dropped so only the interpreter is used.
        kwargs.pop('plugins', None)
        super().__init__(
            llm=llm,
            interpreter=interpreter,
            template=template,
            memory=memory,
            output_format=output_format,
            aggregator=aggregator,
            action_hooks=action_hooks,
            finish_condition=finish_condition,
            max_turn=max_turn,
            **kwargs,
        )

    async def forward(self, message: AgentMessage, session_id=0, **kwargs):
        """Run the parent loop, then close the interpreter session.

        Args:
            message: Message given to the inner agent on the first iteration.
            session_id: Session used for the run and for the cleanup.
            **kwargs: Forwarded to the parent ``forward``.

        Returns:
            Whatever the parent ``forward`` returns.

        The cleanup runs whether or not the parent call raises. It is skipped
        when no interpreter executor is configured or the executor holds no
        actions, so the cleanup cannot itself fail and hide the parent's
        result or exception in those cases.
        """
        try:
            return await super().forward(message, session_id, **kwargs)
        finally:
            executor = self.interpreter_executor
            if executor:
                # Only the first registered action is inspected; the default
                # of ``None`` avoids StopIteration on an empty action map.
                interpreter = next(iter(executor.actions.values()), None)
                if interpreter is not None and interpreter.name == 'AsyncIPythonInterpreter':
                    await interpreter.close_session(session_id)