Ricky_fan committed
Commit c8aba9a · 1 Parent(s): a21c161

modify name
actions/__init__.py DELETED
@@ -1,28 +0,0 @@
1
- from .action_executor import ActionExecutor, AsyncActionExecutor
2
- from .arxiv_search import ArxivSearch, AsyncArxivSearch
3
- from .base_action import BaseAction, tool_api
4
- from .bing_map import AsyncBINGMap, BINGMap
5
- from .builtin_actions import FinishAction, InvalidAction, NoAction
6
- from .google_scholar_search import AsyncGoogleScholar, GoogleScholar
7
- from .google_search import AsyncGoogleSearch, GoogleSearch
8
- from .ipython_interactive import AsyncIPythonInteractive, IPythonInteractive
9
- from .ipython_interpreter import AsyncIPythonInterpreter, IPythonInterpreter
10
- from .ipython_manager import IPythonInteractiveManager
11
- from .parser import BaseParser, JsonParser, TupleParser
12
- from .ppt import PPT, AsyncPPT
13
- from .python_interpreter import AsyncPythonInterpreter, PythonInterpreter
14
- from .web_browser import AsyncWebBrowser, WebBrowser
15
- from .weather_query import WeatherQuery
16
-
17
- __all__ = [
18
- 'BaseAction', 'ActionExecutor', 'AsyncActionExecutor', 'InvalidAction',
19
- 'FinishAction', 'NoAction', 'BINGMap', 'AsyncBINGMap', 'ArxivSearch',
20
- 'AsyncArxivSearch', 'GoogleSearch', 'AsyncGoogleSearch', 'GoogleScholar',
21
- 'AsyncGoogleScholar', 'IPythonInterpreter', 'AsyncIPythonInterpreter',
22
- 'IPythonInteractive', 'AsyncIPythonInteractive',
23
- 'IPythonInteractiveManager', 'PythonInterpreter', 'AsyncPythonInterpreter',
24
- 'PPT', 'AsyncPPT', 'WebBrowser', 'AsyncWebBrowser', 'BaseParser',
25
- 'JsonParser', 'TupleParser', 'tool_api', 'WeatherQuery'  # 0214: update weather
26
- ]
27
-
28
-
 
actions/__pycache__/__init__.cpython-310.pyc DELETED
Binary file (1.47 kB)
 
actions/__pycache__/action_executor.cpython-310.pyc DELETED
Binary file (5.84 kB)
 
actions/__pycache__/arxiv_search.cpython-310.pyc DELETED
Binary file (3.19 kB)
 
actions/__pycache__/base_action.cpython-310.pyc DELETED
Binary file (11.6 kB)
 
actions/__pycache__/bing_map.cpython-310.pyc DELETED
Binary file (7.79 kB)
 
actions/__pycache__/builtin_actions.cpython-310.pyc DELETED
Binary file (3.89 kB)
 
actions/__pycache__/google_scholar_search.cpython-310.pyc DELETED
Binary file (13 kB)
 
actions/__pycache__/google_search.cpython-310.pyc DELETED
Binary file (6.93 kB)
 
actions/__pycache__/ipython_interactive.cpython-310.pyc DELETED
Binary file (8.41 kB)
 
actions/__pycache__/ipython_interpreter.cpython-310.pyc DELETED
Binary file (16.6 kB)
 
actions/__pycache__/ipython_manager.cpython-310.pyc DELETED
Binary file (7.11 kB)
 
actions/__pycache__/parser.cpython-310.pyc DELETED
Binary file (5.48 kB)
 
actions/__pycache__/ppt.cpython-310.pyc DELETED
Binary file (6.81 kB)
 
actions/__pycache__/python_interpreter.cpython-310.pyc DELETED
Binary file (5.38 kB)
 
actions/__pycache__/weather_query.cpython-310.pyc DELETED
Binary file (2.66 kB)
 
actions/__pycache__/web_browser.cpython-310.pyc DELETED
Binary file (28.8 kB)
 
actions/action_executor.py DELETED
@@ -1,198 +0,0 @@
1
- import inspect
2
- from collections import OrderedDict
3
- from typing import Callable, Dict, List, Union
4
-
5
- from lagent.actions.base_action import BaseAction
6
- from lagent.actions.builtin_actions import FinishAction, InvalidAction, NoAction
7
- from lagent.hooks import Hook, RemovableHandle
8
- from lagent.schema import ActionReturn, ActionValidCode, AgentMessage, FunctionCall
9
- from lagent.utils import create_object
10
-
11
-
12
- class ActionExecutor:
13
- """The action executor class.
14
-
15
- Args:
16
- actions (Union[BaseAction, List[BaseAction]]): The action or actions.
17
- invalid_action (BaseAction, optional): The invalid action. Defaults to
18
- InvalidAction().
19
- no_action (BaseAction, optional): The no action.
20
- Defaults to NoAction().
21
- finish_action (BaseAction, optional): The finish action. Defaults to
22
- FinishAction().
23
- finish_in_action (bool, optional): Whether the finish action is in the
24
- action list. Defaults to False.
25
- """
26
-
27
- def __init__(
28
- self,
29
- actions: Union[BaseAction, List[BaseAction], Dict, List[Dict]],
30
- invalid_action: BaseAction = dict(type=InvalidAction),
31
- no_action: BaseAction = dict(type=NoAction),
32
- finish_action: BaseAction = dict(type=FinishAction),
33
- finish_in_action: bool = False,
34
- hooks: List[Dict] = None,
35
- ):
36
-
37
- if not isinstance(actions, list):
38
- actions = [actions]
39
- finish_action = create_object(finish_action)
40
- if finish_in_action:
41
- actions.append(finish_action)
42
- for i, action in enumerate(actions):
43
- actions[i] = create_object(action)
44
- self.actions = {action.name: action for action in actions}
45
-
46
- self.invalid_action = create_object(invalid_action)
47
- self.no_action = create_object(no_action)
48
- self.finish_action = finish_action
49
- self._hooks: Dict[int, Hook] = OrderedDict()
50
- if hooks:
51
- for hook in hooks:
52
- hook = create_object(hook)
53
- self.register_hook(hook)
54
-
55
- def description(self) -> List[Dict]:
56
- actions = []
57
- for action_name, action in self.actions.items():
58
- if action.is_toolkit:
59
- for api in action.description['api_list']:
60
- api_desc = api.copy()
61
- api_desc['name'] = f"{action_name}.{api_desc['name']}"
62
- actions.append(api_desc)
63
- else:
64
- action_desc = action.description.copy()
65
- actions.append(action_desc)
66
- return actions
67
-
68
- def __contains__(self, name: str):
69
- return name in self.actions
70
-
71
- def keys(self):
72
- return list(self.actions.keys())
73
-
74
- def __setitem__(self, name: str, action: Union[BaseAction, Dict]):
75
- action = create_object(action)
76
- self.actions[action.name] = action
77
-
78
- def __delitem__(self, name: str):
79
- del self.actions[name]
80
-
81
- def forward(self, name, parameters, **kwargs) -> ActionReturn:
82
- action_name, api_name = (
83
- name.split('.') if '.' in name else (name, 'run'))
84
- action_return: ActionReturn = ActionReturn()
85
- if action_name not in self:
86
- if name == self.no_action.name:
87
- action_return = self.no_action(parameters)
88
- elif name == self.finish_action.name:
89
- action_return = self.finish_action(parameters)
90
- else:
91
- action_return = self.invalid_action(parameters)
92
- else:
93
- action_return = self.actions[action_name](parameters, api_name)
94
- action_return.valid = ActionValidCode.OPEN
95
- return action_return
96
-
97
- def __call__(self,
98
- message: AgentMessage,
99
- session_id=0,
100
- **kwargs) -> AgentMessage:
101
- # message.receiver = self.name
102
- for hook in self._hooks.values():
103
- result = hook.before_action(self, message, session_id)
104
- if result:
105
- message = result
106
-
107
- assert isinstance(message.content, FunctionCall) or (
108
- isinstance(message.content, dict) and 'name' in message.content
109
- and 'parameters' in message.content)
110
- if isinstance(message.content, dict):
111
- name = message.content.get('name')
112
- parameters = message.content.get('parameters')
113
- else:
114
- name = message.content.name
115
- parameters = message.content.parameters
116
-
117
- response_message = self.forward(
118
- name=name, parameters=parameters, **kwargs)
119
- if not isinstance(response_message, AgentMessage):
120
- response_message = AgentMessage(
121
- sender=self.__class__.__name__,
122
- content=response_message,
123
- )
124
-
125
- for hook in self._hooks.values():
126
- result = hook.after_action(self, response_message, session_id)
127
- if result:
128
- response_message = result
129
- return response_message
130
-
131
- def register_hook(self, hook: Callable):
132
- handle = RemovableHandle(self._hooks)
133
- self._hooks[handle.id] = hook
134
- return handle
135
-
136
-
137
- class AsyncActionExecutor(ActionExecutor):
138
-
139
- async def forward(self, name, parameters, **kwargs) -> ActionReturn:
140
- action_name, api_name = (
141
- name.split('.') if '.' in name else (name, 'run'))
142
- action_return: ActionReturn = ActionReturn()
143
- if action_name not in self:
144
- if name == self.no_action.name:
145
- action_return = self.no_action(parameters)
146
- elif name == self.finish_action.name:
147
- action_return = self.finish_action(parameters)
148
- else:
149
- action_return = self.invalid_action(parameters)
150
- else:
151
- action = self.actions[action_name]
152
- if inspect.iscoroutinefunction(action.__call__):
153
- action_return = await action(parameters, api_name)
154
- else:
155
- action_return = action(parameters, api_name)
156
- action_return.valid = ActionValidCode.OPEN
157
- return action_return
158
-
159
- async def __call__(self,
160
- message: AgentMessage,
161
- session_id=0,
162
- **kwargs) -> AgentMessage:
163
- # message.receiver = self.name
164
- for hook in self._hooks.values():
165
- if inspect.iscoroutinefunction(hook.before_action):
166
- result = await hook.before_action(self, message, session_id)
167
- else:
168
- result = hook.before_action(self, message, session_id)
169
- if result:
170
- message = result
171
-
172
- assert isinstance(message.content, FunctionCall) or (
173
- isinstance(message.content, dict) and 'name' in message.content
174
- and 'parameters' in message.content)
175
- if isinstance(message.content, dict):
176
- name = message.content.get('name')
177
- parameters = message.content.get('parameters')
178
- else:
179
- name = message.content.name
180
- parameters = message.content.parameters
181
-
182
- response_message = await self.forward(
183
- name=name, parameters=parameters, **kwargs)
184
- if not isinstance(response_message, AgentMessage):
185
- response_message = AgentMessage(
186
- sender=self.__class__.__name__,
187
- content=response_message,
188
- )
189
-
190
- for hook in self._hooks.values():
191
- if inspect.iscoroutinefunction(hook.after_action):
192
- result = await hook.after_action(self, response_message,
193
- session_id)
194
- else:
195
- result = hook.after_action(self, response_message, session_id)
196
- if result:
197
- response_message = result
198
- return response_message
 
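For orientation, here is a minimal usage sketch of the executor being removed. It uses only classes and signatures visible in this commit's diffs; the `lagent.actions` import path is assumed from the original package layout, and actually running it needs the `arxiv` package plus network access.

```python
# Hedged sketch (not part of this commit): dispatching a tool call through
# the deleted ActionExecutor. Signatures follow the file shown above.
from lagent.actions import ActionExecutor, ArxivSearch

executor = ActionExecutor(actions=[ArxivSearch()])

# Flattened tool schemas, e.g. for LLM function calling.
print(executor.description())

# 'ClassName.api_name' selects a toolkit API; a bare name falls back to 'run'.
ret = executor.forward(
    'ArxivSearch.get_arxiv_article_information',
    {'query': 'large language model agents'},
)
print(ret.result)
```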
 
actions/arxiv_search.py DELETED
@@ -1,79 +0,0 @@
1
- from typing import Optional, Type
2
-
3
- from asyncer import asyncify
4
-
5
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
6
- from lagent.actions.parser import BaseParser, JsonParser
7
- from lagent.schema import ActionReturn, ActionStatusCode
8
-
9
-
10
- class ArxivSearch(BaseAction):
11
- """Search information from Arxiv.org. \
12
- Useful for when you need to answer questions about Physics, Mathematics, \
13
- Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
14
- Electrical Engineering, and Economics from scientific articles on arxiv.org.
15
- """
16
-
17
- def __init__(
18
- self,
19
- top_k_results: int = 3,
20
- max_query_len: int = 300,
21
- doc_content_chars_max: int = 1500,
22
- description: Optional[dict] = None,
23
- parser: Type[BaseParser] = JsonParser,
24
- ):
25
- super().__init__(description, parser)
26
- self.top_k_results = top_k_results
27
- self.max_query_len = max_query_len
28
- self.doc_content_chars_max = doc_content_chars_max
29
-
30
- @tool_api(explode_return=True)
31
- def get_arxiv_article_information(self, query: str) -> dict:
32
- """Run Arxiv search and get the article meta information.
33
-
34
- Args:
35
- query (:class:`str`): the content of search query
36
-
37
- Returns:
38
- :class:`dict`: article information
39
- * content (str): a list of 3 arxiv search papers
40
- """
41
- import arxiv
42
-
43
- try:
44
- results = arxiv.Search( # type: ignore
45
- query[: self.max_query_len], max_results=self.top_k_results
46
- ).results()
47
- except Exception as exc:
48
- return ActionReturn(errmsg=f'Arxiv exception: {exc}', state=ActionStatusCode.HTTP_ERROR)
49
- docs = [
50
- f'Published: {result.updated.date()}\nTitle: {result.title}\n'
51
- f'Authors: {", ".join(a.name for a in result.authors)}\n'
52
- f'Summary: {result.summary[:self.doc_content_chars_max]}'
53
- for result in results
54
- ]
55
- if docs:
56
- return {'content': '\n\n'.join(docs)}
57
- return {'content': 'No good Arxiv Result was found'}
58
-
59
-
60
- class AsyncArxivSearch(AsyncActionMixin, ArxivSearch):
61
- """Search information from Arxiv.org. \
62
- Useful for when you need to answer questions about Physics, Mathematics, \
63
- Computer Science, Quantitative Biology, Quantitative Finance, Statistics, \
64
- Electrical Engineering, and Economics from scientific articles on arxiv.org.
65
- """
66
-
67
- @tool_api(explode_return=True)
68
- @asyncify
69
- def get_arxiv_article_information(self, query: str) -> dict:
70
- """Run Arxiv search and get the article meta information.
71
-
72
- Args:
73
- query (:class:`str`): the content of search query
74
-
75
- Returns:
76
- :class:`dict`: article information
77
- * content (str): a list of 3 arxiv search papers
78
- """
79
- return super().get_arxiv_article_information(query)
 
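A quick sanity check of this tool on its own (a sketch; the decorated method needs the `arxiv` package and network access at runtime):

```python
# Hedged sketch: calling the deleted ArxivSearch tool directly. The
# @tool_api-decorated method returns a plain dict; calling the instance
# itself (tool(inputs, name=...)) instead routes through the parser and
# yields an ActionReturn.
from lagent.actions import ArxivSearch

tool = ArxivSearch(top_k_results=2)
print(tool.description)  # schema that @tool_api built from the docstring
out = tool.get_arxiv_article_information('retrieval augmented generation')
print(out)
```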
 
actions/base_action.py DELETED
@@ -1,434 +0,0 @@
1
- import inspect
2
- import logging
3
- import re
4
- from abc import ABCMeta
5
- from copy import deepcopy
6
- from functools import wraps
7
- from typing import Callable, Optional, Type, get_args, get_origin
8
-
9
- try:
10
- from typing import Annotated
11
- except ImportError:
12
- from typing_extensions import Annotated
13
-
14
- from griffe import Docstring
15
-
16
- try:
17
- from griffe import DocstringSectionKind
18
- except ImportError:
19
- from griffe.enumerations import DocstringSectionKind
20
-
21
- from ..schema import ActionReturn, ActionStatusCode
22
- from .parser import BaseParser, JsonParser, ParseError
23
-
24
- logging.getLogger('griffe').setLevel(logging.ERROR)
25
-
26
-
27
- def tool_api(func: Optional[Callable] = None,
28
- *,
29
- explode_return: bool = False,
30
- returns_named_value: bool = False,
31
- **kwargs):
32
- """Turn functions into tools. It will parse typehints as well as docstrings
33
- to build the tool description and attach it to functions via an attribute
34
- ``api_description``.
35
-
36
- Examples:
37
-
38
- .. code-block:: python
39
-
40
- # typehints has higher priority than docstrings
41
- from typing import Annotated
42
-
43
- @tool_api
44
- def add(a: Annotated[int, 'augend'], b: Annotated[int, 'addend'] = 1):
45
- '''Add operation
46
-
47
- Args:
48
- x (int): a
49
- y (int): b
50
- '''
51
- return a + b
52
-
53
- print(add.api_description)
54
-
55
- Args:
56
- func (Optional[Callable]): function to decorate. Defaults to ``None``.
57
- explode_return (bool): whether to flatten the dictionary or tuple return
58
- as the ``return_data`` field. When enabled, it is recommended to
59
- annotate the member in docstrings. Defaults to ``False``.
60
-
61
- .. code-block:: python
62
-
63
- @tool_api(explode_return=True)
64
- def foo(a, b):
65
- '''A simple function
66
-
67
- Args:
68
- a (int): a
69
- b (int): b
70
-
71
- Returns:
72
- dict: information of inputs
73
- * x: value of a
74
- * y: value of b
75
- '''
76
- return {'x': a, 'y': b}
77
-
78
- print(foo.api_description)
79
-
80
- returns_named_value (bool): whether to parse ``thing: Description`` in
81
- returns sections as a name and description, rather than a type and
82
- description. When true, type must be wrapped in parentheses:
83
- ``(int): Description``. When false, parentheses are optional but
84
- the items cannot be named: ``int: Description``. Defaults to ``False``.
85
-
86
- Returns:
87
- Callable: wrapped function or partial decorator
88
-
89
- Important:
90
- ``return_data`` field will be added to ``api_description`` only
91
- when ``explode_return`` or ``returns_named_value`` is enabled.
92
- """
93
-
94
- def _detect_type(string):
95
- field_type = 'STRING'
96
- if 'list' in string:
97
- field_type = 'Array'
98
- elif 'str' not in string:
99
- if 'float' in string:
100
- field_type = 'FLOAT'
101
- elif 'int' in string:
102
- field_type = 'NUMBER'
103
- elif 'bool' in string:
104
- field_type = 'BOOLEAN'
105
- return field_type
106
-
107
- def _explode(desc):
108
- kvs = []
109
- desc = '\nArgs:\n' + '\n'.join([
110
- ' ' + item.lstrip(' -+*#.')
111
- for item in desc.split('\n')[1:] if item.strip()
112
- ])
113
- docs = Docstring(desc).parse('google')
114
- if not docs:
115
- return kvs
116
- if docs[0].kind is DocstringSectionKind.parameters:
117
- for d in docs[0].value:
118
- d = d.as_dict()
119
- if not d['annotation']:
120
- d.pop('annotation')
121
- else:
122
- d['type'] = _detect_type(d.pop('annotation').lower())
123
- kvs.append(d)
124
- return kvs
125
-
126
- def _parse_tool(function):
127
- # remove rst syntax
128
- docs = Docstring(
129
- re.sub(':(.+?):`(.+?)`', '\\2', function.__doc__ or '')).parse(
130
- 'google', returns_named_value=returns_named_value, **kwargs)
131
- desc = dict(
132
- name=function.__name__,
133
- description=docs[0].value
134
- if docs[0].kind is DocstringSectionKind.text else '',
135
- parameters=[],
136
- required=[],
137
- )
138
- args_doc, returns_doc = {}, []
139
- for doc in docs:
140
- if doc.kind is DocstringSectionKind.parameters:
141
- for d in doc.value:
142
- d = d.as_dict()
143
- d['type'] = _detect_type(d.pop('annotation').lower())
144
- args_doc[d['name']] = d
145
- if doc.kind is DocstringSectionKind.returns:
146
- for d in doc.value:
147
- d = d.as_dict()
148
- if not d['name']:
149
- d.pop('name')
150
- if not d['annotation']:
151
- d.pop('annotation')
152
- else:
153
- d['type'] = _detect_type(d.pop('annotation').lower())
154
- returns_doc.append(d)
155
-
156
- sig = inspect.signature(function)
157
- for name, param in sig.parameters.items():
158
- if name == 'self':
159
- continue
160
- parameter = dict(
161
- name=param.name,
162
- type='STRING',
163
- description=args_doc.get(param.name,
164
- {}).get('description', ''))
165
- annotation = param.annotation
166
- if annotation is inspect.Signature.empty:
167
- parameter['type'] = args_doc.get(param.name,
168
- {}).get('type', 'STRING')
169
- else:
170
- if get_origin(annotation) is Annotated:
171
- annotation, info = get_args(annotation)
172
- if info:
173
- parameter['description'] = info
174
- while get_origin(annotation):
175
- annotation = get_args(annotation)
176
- parameter['type'] = _detect_type(str(annotation))
177
- desc['parameters'].append(parameter)
178
- if param.default is inspect.Signature.empty:
179
- desc['required'].append(param.name)
180
-
181
- return_data = []
182
- if explode_return:
183
- return_data = _explode(returns_doc[0]['description'])
184
- elif returns_named_value:
185
- return_data = returns_doc
186
- if return_data:
187
- desc['return_data'] = return_data
188
- return desc
189
-
190
- if callable(func):
191
-
192
- if inspect.iscoroutinefunction(func):
193
-
194
- @wraps(func)
195
- async def wrapper(self, *args, **kwargs):
196
- return await func(self, *args, **kwargs)
197
-
198
- else:
199
-
200
- @wraps(func)
201
- def wrapper(self, *args, **kwargs):
202
- return func(self, *args, **kwargs)
203
-
204
- wrapper.api_description = _parse_tool(func)
205
- return wrapper
206
-
207
- def decorate(func):
208
-
209
- if inspect.iscoroutinefunction(func):
210
-
211
- @wraps(func)
212
- async def wrapper(self, *args, **kwargs):
213
- return await func(self, *args, **kwargs)
214
-
215
- else:
216
-
217
- @wraps(func)
218
- def wrapper(self, *args, **kwargs):
219
- return func(self, *args, **kwargs)
220
-
221
- wrapper.api_description = _parse_tool(func)
222
- return wrapper
223
-
224
- return decorate
225
-
226
-
227
- class ToolMeta(ABCMeta):
228
- """Metaclass of tools."""
229
-
230
- def __new__(mcs, name, base, attrs):
231
- is_toolkit, tool_desc = True, dict(
232
- name=name,
233
- description=Docstring(attrs.get('__doc__',
234
- '')).parse('google')[0].value)
235
- for key, value in attrs.items():
236
- if callable(value) and hasattr(value, 'api_description'):
237
- api_desc = getattr(value, 'api_description')
238
- if key == 'run':
239
- tool_desc['parameters'] = api_desc['parameters']
240
- tool_desc['required'] = api_desc['required']
241
- if api_desc['description']:
242
- tool_desc['description'] = api_desc['description']
243
- if api_desc.get('return_data'):
244
- tool_desc['return_data'] = api_desc['return_data']
245
- is_toolkit = False
246
- else:
247
- tool_desc.setdefault('api_list', []).append(api_desc)
248
- if not is_toolkit and 'api_list' in tool_desc:
249
- raise KeyError('`run` and other tool APIs can not be implemented '
250
- 'at the same time')
251
- if is_toolkit and 'api_list' not in tool_desc:
252
- is_toolkit = False
253
- if callable(attrs.get('run')):
254
- run_api = tool_api(attrs['run'])
255
- api_desc = run_api.api_description
256
- tool_desc['parameters'] = api_desc['parameters']
257
- tool_desc['required'] = api_desc['required']
258
- if api_desc['description']:
259
- tool_desc['description'] = api_desc['description']
260
- if api_desc.get('return_data'):
261
- tool_desc['return_data'] = api_desc['return_data']
262
- attrs['run'] = run_api
263
- else:
264
- tool_desc['parameters'], tool_desc['required'] = [], []
265
- attrs['_is_toolkit'] = is_toolkit
266
- attrs['__tool_description__'] = tool_desc
267
- return super().__new__(mcs, name, base, attrs)
268
-
269
-
270
- class BaseAction(metaclass=ToolMeta):
271
- """Base class for all actions.
272
-
273
- Args:
274
- description (:class:`Optional[dict]`): The description of the action.
275
- Defaults to ``None``.
276
- parser (:class:`Type[BaseParser]`): The parser class to process the
277
- action's inputs and outputs. Defaults to :class:`JsonParser`.
278
-
279
- Examples:
280
-
281
- * simple tool
282
-
283
- .. code-block:: python
284
-
285
- class Bold(BaseAction):
286
- '''Make text bold'''
287
-
288
- def run(self, text: str):
289
- '''
290
- Args:
291
- text (str): input text
292
-
293
- Returns:
294
- str: bold text
295
- '''
296
- return '**' + text + '**'
297
-
298
- action = Bold()
299
-
300
- * toolkit with multiple APIs
301
-
302
- .. code-block:: python
303
-
304
- class Calculator(BaseAction):
305
- '''Calculator'''
306
-
307
- @tool_api
308
- def add(self, a, b):
309
- '''Add operation
310
-
311
- Args:
312
- a (int): augend
313
- b (int): addend
314
-
315
- Returns:
316
- int: sum
317
- '''
318
- return a + b
319
-
320
- @tool_api
321
- def sub(self, a, b):
322
- '''Subtraction operation
323
-
324
- Args:
325
- a (int): minuend
326
- b (int): subtrahend
327
-
328
- Returns:
329
- int: difference
330
- '''
331
- return a - b
332
-
333
- action = Calculator()
334
- """
335
-
336
- def __init__(
337
- self,
338
- description: Optional[dict] = None,
339
- parser: Type[BaseParser] = JsonParser,
340
- ):
341
- self._description = deepcopy(description or self.__tool_description__)
342
- self._name = self._description['name']
343
- self._parser = parser(self)
344
-
345
- def __call__(self, inputs: str, name='run') -> ActionReturn:
346
- fallback_args = {'inputs': inputs, 'name': name}
347
- if not hasattr(self, name):
348
- return ActionReturn(
349
- fallback_args,
350
- type=self.name,
351
- errmsg=f'invalid API: {name}',
352
- state=ActionStatusCode.API_ERROR)
353
- try:
354
- inputs = self._parser.parse_inputs(inputs, name)
355
- except ParseError as exc:
356
- return ActionReturn(
357
- fallback_args,
358
- type=self.name,
359
- errmsg=exc.err_msg,
360
- state=ActionStatusCode.ARGS_ERROR)
361
- try:
362
- outputs = getattr(self, name)(**inputs)
363
- except Exception as exc:
364
- return ActionReturn(
365
- inputs,
366
- type=self.name,
367
- errmsg=str(exc),
368
- state=ActionStatusCode.API_ERROR)
369
- if isinstance(outputs, ActionReturn):
370
- action_return = outputs
371
- if not action_return.args:
372
- action_return.args = inputs
373
- if not action_return.type:
374
- action_return.type = self.name
375
- else:
376
- result = self._parser.parse_outputs(outputs)
377
- action_return = ActionReturn(inputs, type=self.name, result=result)
378
- return action_return
379
-
380
- @property
381
- def name(self):
382
- return self._name
383
-
384
- @property
385
- def is_toolkit(self):
386
- return self._is_toolkit
387
-
388
- @property
389
- def description(self) -> dict:
390
- """Description of the tool."""
391
- return self._description
392
-
393
- def __repr__(self):
394
- return f'{self.description}'
395
-
396
- __str__ = __repr__
397
-
398
-
399
- class AsyncActionMixin:
400
-
401
- async def __call__(self, inputs: str, name='run') -> ActionReturn:
402
- fallback_args = {'inputs': inputs, 'name': name}
403
- if not hasattr(self, name):
404
- return ActionReturn(
405
- fallback_args,
406
- type=self.name,
407
- errmsg=f'invalid API: {name}',
408
- state=ActionStatusCode.API_ERROR)
409
- try:
410
- inputs = self._parser.parse_inputs(inputs, name)
411
- except ParseError as exc:
412
- return ActionReturn(
413
- fallback_args,
414
- type=self.name,
415
- errmsg=exc.err_msg,
416
- state=ActionStatusCode.ARGS_ERROR)
417
- try:
418
- outputs = await getattr(self, name)(**inputs)
419
- except Exception as exc:
420
- return ActionReturn(
421
- inputs,
422
- type=self.name,
423
- errmsg=str(exc),
424
- state=ActionStatusCode.API_ERROR)
425
- if isinstance(outputs, ActionReturn):
426
- action_return = outputs
427
- if not action_return.args:
428
- action_return.args = inputs
429
- if not action_return.type:
430
- action_return.type = self.name
431
- else:
432
- result = self._parser.parse_outputs(outputs)
433
- action_return = ActionReturn(inputs, type=self.name, result=result)
434
- return action_return
 
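To make the ToolMeta/tool_api machinery above concrete, a small self-contained sketch (the class and method below are illustrative, not taken from the repository):

```python
# Hedged sketch: a toolkit built on the deleted BaseAction/tool_api pair.
# ToolMeta gathers every @tool_api method into description['api_list'].
from lagent.actions import BaseAction, tool_api


class Calculator(BaseAction):
    """Simple calculator toolkit."""

    @tool_api
    def add(self, a: int, b: int):
        """Add two integers.

        Args:
            a (int): augend
            b (int): addend
        """
        return a + b


calc = Calculator()
print(calc.is_toolkit)               # True: no `run`, only @tool_api methods
print(calc.description['api_list'])  # parameter schema parsed from the docstring
print(calc.add(2, 3))                # 5; calc(inputs, name='add') would also
                                     # route the call through JsonParser
```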
 
actions/bing_map.py DELETED
@@ -1,268 +0,0 @@
1
- # flake8: noqa: E501
2
- import json
3
- import os
4
- from typing import Optional, Type
5
-
6
- import aiohttp
7
- import requests
8
-
9
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
- from lagent.actions.parser import BaseParser, JsonParser
11
-
12
-
13
- class BINGMap(BaseAction):
14
- """BING Map plugin for looking up map information."""
15
-
16
- def __init__(
17
- self,
18
- key: Optional[str] = None,
19
- description: Optional[dict] = None,
20
- parser: Type[BaseParser] = JsonParser,
21
- ) -> None:
22
- super().__init__(description, parser)
23
- key = os.environ.get('BING_MAP_KEY', key)
24
- if key is None:
25
- raise ValueError(
26
- 'Please set BING Map API key either in the environment '
27
- 'as BING_MAP_KEY or pass it as `key` parameter.')
28
- self.key = key
29
- self.base_url = 'http://dev.virtualearth.net/REST/V1/'
30
-
31
- @tool_api(explode_return=True)
32
- def get_distance(self, start: str, end: str) -> dict:
33
- """Get the distance between two locations in km.
34
-
35
- Args:
36
- start (:class:`str`): The start location
37
- end (:class:`str`): The end location
38
-
39
- Returns:
40
- :class:`dict`: distance information
41
- * distance (str): the distance in km.
42
- """
43
- # Request URL
44
- url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
45
- # GET request
46
- r = requests.get(url)
47
- # TODO check request status?
48
- data = json.loads(r.text)
49
- # Extract route information
50
- route = data['resourceSets'][0]['resources'][0]
51
- # Extract the travel distance (km, per the docstring above)
52
- distance = route['travelDistance']
53
- return dict(distance=distance)
54
-
55
- @tool_api(explode_return=True)
56
- def get_route(self, start: str, end: str) -> dict:
57
- """Get the route between two locations in km.
58
-
59
- Args:
60
- start (:class:`str`): The start location
61
- end (:class:`str`): The end location
62
-
63
- Returns:
64
- :class:`dict`: route information
65
- * route (list): the route, a list of actions.
66
- """
67
- # Request URL
68
- url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
69
- # GET request
70
- r = requests.get(url)
71
- data = json.loads(r.text)
72
- # Extract route information
73
- route = data['resourceSets'][0]['resources'][0]
74
- itinerary = route['routeLegs'][0]['itineraryItems']
75
- # Extract route text information
76
- route_text = []
77
- for item in itinerary:
78
- if 'instruction' in item:
79
- route_text.append(item['instruction']['text'])
80
- return dict(route=route_text)
81
-
82
- @tool_api(explode_return=True)
83
- def get_coordinates(self, location: str) -> dict:
84
- """Get the coordinates of a location.
85
-
86
- Args:
87
- location (:class:`str`): the location whose coordinates are required.
88
-
89
- Returns:
90
- :class:`dict`: coordinates information
91
- * latitude (float): the latitude of the location.
92
- * longitude (float): the longitude of the location.
93
- """
94
- url = self.base_url + 'Locations'
95
- params = {'query': location, 'key': self.key}
96
- response = requests.get(url, params=params)
97
- json_data = response.json()
98
- coordinates = json_data['resourceSets'][0]['resources'][0]['point'][
99
- 'coordinates']
100
- return dict(latitude=coordinates[0], longitude=coordinates[1])
101
-
102
- @tool_api(explode_return=True)
103
- def search_nearby(self,
104
- search_term: str,
105
- places: str = 'unknown',
106
- latitude: float = 0.0,
107
- longitude: float = 0.0,
108
- radius: int = 5000) -> dict:
109
- """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
110
-
111
- Args:
112
- search_term (:class:`str`): the place name.
113
- places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
114
- latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
115
- longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
116
- radius (:class:`int`): radius in meters. Defaults to ``5000``.
117
-
118
- Returns:
119
- :class:`dict`: places information
120
- * places (list): the list of places, each place is a dict with name and address, at most 5 places.
121
- """
122
- url = self.base_url + 'LocalSearch'
123
- if places != 'unknown':
124
- pos = self.get_coordinates(**{'location': places})
125
- latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
126
- # Build the request query string
127
- params = {
128
- 'query': search_term,
129
- 'userLocation': f'{latitude},{longitude}',
130
- 'radius': radius,
131
- 'key': self.key
132
- }
133
- # Make the request
134
- response = requests.get(url, params=params)
135
- # Parse the response
136
- response_data = json.loads(response.content)
137
- # Get the results
138
- results = response_data['resourceSets'][0]['resources']
139
- addresses = []
140
- for result in results:
141
- name = result['name']
142
- address = result['Address']['formattedAddress']
143
- addresses.append(dict(name=name, address=address))
144
- if len(addresses) == 5:
145
- break
146
- return dict(place=addresses)
147
-
148
-
149
- class AsyncBINGMap(AsyncActionMixin, BINGMap):
150
- """BING Map plugin for looking up map information."""
151
-
152
- @tool_api(explode_return=True)
153
- async def get_distance(self, start: str, end: str) -> dict:
154
- """Get the distance between two locations in km.
155
-
156
- Args:
157
- start (:class:`str`): The start location
158
- end (:class:`str`): The end location
159
-
160
- Returns:
161
- :class:`dict`: distance information
162
- * distance (str): the distance in km.
163
- """
164
- # Request URL
165
- url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
166
- # GET request
167
- async with aiohttp.ClientSession() as session:
168
- async with session.get(url) as resp:
169
- # TODO check request status?
170
- data = await resp.json()
171
- # Extract route information
172
- route = data['resourceSets'][0]['resources'][0]
173
- # Extract the travel distance (km, per the docstring above)
174
- distance = route['travelDistance']
175
- return dict(distance=distance)
176
-
177
- @tool_api(explode_return=True)
178
- async def get_route(self, start: str, end: str) -> dict:
179
- """Get the route between two locations in km.
180
-
181
- Args:
182
- start (:class:`str`): The start location
183
- end (:class:`str`): The end location
184
-
185
- Returns:
186
- :class:`dict`: route information
187
- * route (list): the route, a list of actions.
188
- """
189
- # Request URL
190
- url = self.base_url + 'Routes/Driving?o=json&wp.0=' + start + '&wp.1=' + end + '&key=' + self.key
191
- # GET request
192
- async with aiohttp.ClientSession() as session:
193
- async with session.get(url) as resp:
194
- data = await resp.json()
195
- # Extract route information
196
- route = data['resourceSets'][0]['resources'][0]
197
- itinerary = route['routeLegs'][0]['itineraryItems']
198
- # Extract route text information
199
- route_text = []
200
- for item in itinerary:
201
- if 'instruction' in item:
202
- route_text.append(item['instruction']['text'])
203
- return dict(route=route_text)
204
-
205
- @tool_api(explode_return=True)
206
- async def get_coordinates(self, location: str) -> dict:
207
- """Get the coordinates of a location.
208
-
209
- Args:
210
- location (:class:`str`): the location whose coordinates are required.
211
-
212
- Returns:
213
- :class:`dict`: coordinates information
214
- * latitude (float): the latitude of the location.
215
- * longitude (float): the longitude of the location.
216
- """
217
- url = self.base_url + 'Locations'
218
- params = {'query': location, 'key': self.key}
219
- async with aiohttp.ClientSession() as session:
220
- async with session.get(url, params=params) as resp:
221
- data = await resp.json()
222
- coordinates = data['resourceSets'][0]['resources'][0]['point'][
223
- 'coordinates']
224
- return dict(latitude=coordinates[0], longitude=coordinates[1])
225
-
226
- @tool_api(explode_return=True)
227
- async def search_nearby(self,
228
- search_term: str,
229
- places: str = 'unknown',
230
- latitude: float = 0.0,
231
- longitude: float = 0.0,
232
- radius: int = 5000) -> dict:
233
- """Search for places nearby a location, within a given radius, and return the results into a list. You can use either the places name or the latitude and longitude.
234
-
235
- Args:
236
- search_term (:class:`str`): the place name.
237
- places (:class:`str`): the name of the location. Defaults to ``'unknown'``.
238
- latitude (:class:`float`): the latitude of the location. Defaults to ``0.0``.
239
- longitude (:class:`float`): the longitude of the location. Defaults to ``0.0``.
240
- radius (:class:`int`): radius in meters. Defaults to ``5000``.
241
-
242
- Returns:
243
- :class:`dict`: places information
244
- * places (list): the list of places, each place is a dict with name and address, at most 5 places.
245
- """
246
- url = self.base_url + 'LocalSearch'
247
- if places != 'unknown':
248
- pos = self.get_coordinates(**{'location': places})
249
- latitude, longitude = pos[1]['latitude'], pos[1]['longitude']
250
- # Build the request query string
251
- params = {
252
- 'query': search_term,
253
- 'userLocation': f'{latitude},{longitude}',
254
- 'radius': radius,
255
- 'key': self.key
256
- }
257
- async with aiohttp.ClientSession() as session:
258
- async with session.get(url, params=params) as resp:
259
- data = await resp.json()
260
- results = data['resourceSets'][0]['resources']
261
- addresses = []
262
- for result in results:
263
- name = result['name']
264
- address = result['Address']['formattedAddress']
265
- addresses.append(dict(name=name, address=address))
266
- if len(addresses) == 5:
267
- break
268
- return dict(place=addresses)
 
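For completeness, a sketch of how the map tool was invoked (the key below is a placeholder; a valid Bing Maps key in `BING_MAP_KEY` and network access are needed to actually run it):

```python
# Hedged sketch: the deleted BINGMap tool reads its key from BING_MAP_KEY
# (or the `key` argument) and wraps the Bing Maps REST endpoints used above.
import os

from lagent.actions import BINGMap

os.environ.setdefault('BING_MAP_KEY', '<your-bing-maps-key>')  # placeholder only

bing = BINGMap()
print(bing.get_coordinates('Seattle'))           # {'latitude': ..., 'longitude': ...}
print(bing.get_distance('Seattle', 'Portland'))  # {'distance': ...}
```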
 
actions/builtin_actions.py DELETED
@@ -1,109 +0,0 @@
1
- from typing import Optional
2
-
3
- from lagent.actions.base_action import BaseAction, tool_api
4
- from lagent.actions.parser import BaseParser
5
- from lagent.schema import ActionReturn, ActionStatusCode, ActionValidCode
6
-
7
-
8
- class InvalidAction(BaseAction):
9
- """This is a invalid action class, which is used to return error message
10
- when the action is invalid.
11
-
12
- Args:
13
- err_msg (str): The error message. Defaults to 'The action is invalid,
14
- please check the action name'.
15
-
16
- Returns:
17
- ActionReturn: The action return.
18
- """
19
-
20
- def __init__(self,
21
- err_msg:
22
- str = 'The action is invalid, please check the action name.',
23
- description: Optional[dict] = None,
24
- parser=BaseParser) -> None:
25
- super().__init__(description, parser)
26
- self._err_msg = err_msg
27
-
28
- @tool_api
29
- def run(self, err_msg: Optional[str] = None) -> ActionReturn:
30
- """Return the error message.
31
-
32
- Args:
33
- err_msg (str, optional): The error message. If err_msg is not None,
34
- it will be returned, otherwise the default error message will
35
- be returned. Defaults to None.
36
- """
37
- action_return = ActionReturn(
38
- url=None,
39
- args=dict(text=err_msg),
40
- errmsg=err_msg or self._err_msg,
41
- type=self.name,
42
- valid=ActionValidCode.INVALID,
43
- state=ActionStatusCode.API_ERROR)
44
- return action_return
45
-
46
-
47
- class NoAction(BaseAction):
48
- """This is a no action class, which is used to return error message when
49
- the response does not follow the format.
50
-
51
- Args:
52
- err_msg (str): The error message. Defaults to
53
- 'Please follow the format'.
54
- """
55
-
56
- def __init__(self,
57
- err_msg: str = 'Please follow the format',
58
- description: Optional[dict] = None,
59
- parser=BaseParser):
60
- super().__init__(description, parser)
61
- self._err_msg = err_msg
62
-
63
- @tool_api
64
- def run(self, err_msg: Optional[str] = None) -> ActionReturn:
65
- """Return the error message.
66
-
67
- Args:
68
- err_msg (str, optional): The error message. If err_msg is not None,
69
- it will be returned, otherwise the default error message will
70
- be returned. Defaults to None.
71
-
72
- Returns:
73
- ActionReturn: The action return.
74
- """
75
- action_return = ActionReturn(
76
- url=None,
77
- args=dict(text=err_msg),
78
- type=self.name,
79
- errmsg=err_msg or self._err_msg,
80
- valid=ActionValidCode.INVALID,
81
- state=ActionStatusCode.API_ERROR)
82
- return action_return
83
-
84
-
85
- class FinishAction(BaseAction):
86
- """This is a finish action class, which is used to return the final
87
- result."""
88
-
89
- def __init__(self, description: Optional[dict] = None, parser=BaseParser):
90
- super().__init__(description, parser)
91
-
92
- @tool_api
93
- def run(self, response: str) -> ActionReturn:
94
- """Return the final result.
95
-
96
- Args:
97
- response (str): The final result.
98
-
99
- Returns:
100
- ActionReturn: The action return.
101
- """
102
- action_return = ActionReturn(
103
- url=None,
104
- args=dict(text=response),
105
- result=[dict(type='text', content=response)],
106
- type=self.name,
107
- valid=ActionValidCode.FINISH,
108
- state=ActionStatusCode.SUCCESS)
109
- return action_return
 
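These built-ins are what ActionExecutor falls back to when a tool call cannot be dispatched; a brief sketch of how they behave on their own:

```python
# Hedged sketch: FinishAction wraps the final answer in a successful
# ActionReturn, while InvalidAction/NoAction report API_ERROR states.
from lagent.actions import FinishAction, InvalidAction

finish = FinishAction()
ret = finish.run(response='The distance is 42 km.')
print(ret.state, ret.result)  # ActionStatusCode.SUCCESS, [{'type': 'text', ...}]

invalid = InvalidAction()
print(invalid.run().errmsg)   # default message asking to check the action name
```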
 
actions/google_scholar_search.py DELETED
@@ -1,438 +0,0 @@
1
- # flake8: noqa: E501
2
- import os
3
- from typing import Optional, Type
4
-
5
- from asyncer import asyncify
6
-
7
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
8
- from lagent.schema import ActionReturn, ActionStatusCode
9
- from .parser import BaseParser, JsonParser
10
-
11
-
12
- class GoogleScholar(BaseAction):
13
- """Plugin for google scholar search.
14
-
15
- Args:
16
- api_key (str): API KEY to use serper google search API,
17
- You can create a free API key at https://serper.dev.
18
- description (dict): The description of the action. Defaults to ``None``.
19
- parser (Type[BaseParser]): The parser class to process the
20
- action's inputs and outputs. Defaults to :class:`JsonParser`.
21
- """
22
-
23
- def __init__(
24
- self,
25
- api_key: Optional[str] = None,
26
- description: Optional[dict] = None,
27
- parser: Type[BaseParser] = JsonParser,
28
- ):
29
- super().__init__(description, parser)
30
- api_key = os.environ.get('SERPER_API_KEY', api_key)
31
- if api_key is None:
32
- raise ValueError(
33
- 'Please set Serper API key either in the environment '
34
- 'as SERPER_API_KEY or pass it as `api_key` parameter.'
35
- )
36
- self.api_key = api_key
37
-
38
- @tool_api(explode_return=True)
39
- def search_google_scholar(
40
- self,
41
- query: str,
42
- cites: Optional[str] = None,
43
- as_ylo: Optional[int] = None,
44
- as_yhi: Optional[int] = None,
45
- scisbd: Optional[int] = None,
46
- cluster: Optional[str] = None,
47
- hl: Optional[str] = None,
48
- lr: Optional[str] = None,
49
- start: Optional[int] = None,
50
- num: Optional[int] = None,
51
- as_sdt: Optional[str] = None,
52
- safe: Optional[str] = None,
53
- filter: Optional[str] = None,
54
- as_vis: Optional[str] = None,
55
- ) -> dict:
56
- """Search for scholarly articles based on a query according to the google scholar.
57
-
58
- Args:
59
- query (str): The query to search for.
60
- cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
61
- as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
62
- as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
63
- scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
64
- cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
65
- hl (Optional[str]): The language to use for the Google Scholar search.
66
- lr (Optional[str]): One or multiple languages to limit the search to.
67
- start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
68
- num (Optional[int]): The maximum number of results to return, limited to 20.
69
- as_sdt (Optional[str]): Can be used either as a search type or a filter.
70
- safe (Optional[str]): The level of filtering for adult content.
71
- filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
72
- as_vis (Optional[str]): Defines whether to include citations or not.
73
-
74
- Returns:
75
- :class:`dict`: article information
76
- - title: a list of the titles of the three selected papers
77
- - cited_by: a list of the citation numbers of the three selected papers
78
- - organic_id: a list of the organic results' ids of the three selected papers
79
- - pub_info: publication information of selected papers
80
- """
81
- from serpapi import GoogleSearch
82
-
83
- params = {
84
- 'q': query,
85
- 'engine': 'google_scholar',
86
- 'api_key': self.api_key,
87
- 'cites': cites,
88
- 'as_ylo': as_ylo,
89
- 'as_yhi': as_yhi,
90
- 'scisbd': scisbd,
91
- 'cluster': cluster,
92
- 'hl': hl,
93
- 'lr': lr,
94
- 'start': start,
95
- 'num': num,
96
- 'as_sdt': as_sdt,
97
- 'safe': safe,
98
- 'filter': filter,
99
- 'as_vis': as_vis,
100
- }
101
- search = GoogleSearch(params)
102
- try:
103
- r = search.get_dict()
104
- results = r['organic_results']
105
- title = []
106
- snippets = []
107
- cited_by = []
108
- organic_id = []
109
- pub_info = []
110
- for item in results[:3]:
111
- title.append(item['title'])
112
- pub_info.append(item['publication_info']['summary'])
113
- citation = item['inline_links'].get('cited_by', {'total': ''})
114
- cited_by.append(citation['total'])
115
- snippets.append(item['snippet'])
116
- organic_id.append(item['result_id'])
117
- return dict(title=title, cited_by=cited_by, organic_id=organic_id, snippets=snippets)
118
- except Exception as e:
119
- return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
120
-
121
- @tool_api(explode_return=True)
122
- def get_author_information(
123
- self,
124
- author_id: str,
125
- hl: Optional[str] = None,
126
- view_op: Optional[str] = None,
127
- sort: Optional[str] = None,
128
- citation_id: Optional[str] = None,
129
- start: Optional[int] = None,
130
- num: Optional[int] = None,
131
- no_cache: Optional[bool] = None,
132
- async_req: Optional[bool] = None,
133
- output: Optional[str] = None,
134
- ) -> dict:
135
- """Search for an author's information by author's id provided by get_author_id.
136
-
137
- Args:
138
- author_id (str): Required. The ID of an author.
139
- hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
140
- view_op (Optional[str]): Used for viewing specific parts of a page.
141
- sort (Optional[str]): Used for sorting and refining articles.
142
- citation_id (Optional[str]): Used for retrieving individual article citation.
143
- start (Optional[int]): Defines the result offset. Default is 0.
144
- num (Optional[int]): Defines the number of results to return. Default is 20.
145
- no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
146
- async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
147
- output (Optional[str]): Defines the final output you want. Default is 'json'.
148
-
149
- Returns:
150
- :class:`dict`: author information
151
- * name: author's name
152
- * affiliation: the affiliation of the author
153
- * articles: at most 3 articles by the author
154
- * website: the author's homepage url
155
- """
156
- from serpapi import GoogleSearch
157
-
158
- params = {
159
- 'engine': 'google_scholar_author',
160
- 'author_id': author_id,
161
- 'api_key': self.api_key,
162
- 'hl': hl,
163
- 'view_op': view_op,
164
- 'sort': sort,
165
- 'citation_id': citation_id,
166
- 'start': start,
167
- 'num': num,
168
- 'no_cache': no_cache,
169
- 'async': async_req,
170
- 'output': output,
171
- }
172
- try:
173
- search = GoogleSearch(params)
174
- results = search.get_dict()
175
- author = results['author']
176
- articles = results.get('articles', [])
177
- return dict(
178
- name=author['name'],
179
- affiliations=author.get('affiliations', ''),
180
- website=author.get('website', ''),
181
- articles=[dict(title=article['title'], authors=article['authors']) for article in articles[:3]],
182
- )
183
- except Exception as e:
184
- return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
185
-
186
- @tool_api(explode_return=True)
187
- def get_citation_format(
188
- self,
189
- q: str,
190
- no_cache: Optional[bool] = None,
191
- async_: Optional[bool] = None,
192
- output: Optional[str] = 'json',
193
- ) -> dict:
194
- """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
195
-
196
- Args:
197
- q (str): ID of an individual Google Scholar organic search result.
198
- no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
199
- async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
200
- output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
201
-
202
- Returns:
203
- :class:`dict`: citation format
204
- * authors: the authors of the article
205
- * citation: the citation format of the article
206
- """
207
- from serpapi import GoogleSearch
208
-
209
- params = {
210
- 'q': q,
211
- 'engine': 'google_scholar_cite',
212
- 'api_key': self.api_key,
213
- 'no_cache': no_cache,
214
- 'async': async_,
215
- 'output': output,
216
- }
217
- try:
218
- search = GoogleSearch(params)
219
- results = search.get_dict()
220
- citation = results['citations']
221
- citation_info = citation[0]['snippet']
222
- return citation_info
223
- except Exception as e:
224
- return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
225
-
226
- @tool_api(explode_return=True)
227
- def get_author_id(
228
- self,
229
- mauthors: str,
230
- hl: Optional[str] = 'en',
231
- after_author: Optional[str] = None,
232
- before_author: Optional[str] = None,
233
- no_cache: Optional[bool] = False,
234
- _async: Optional[bool] = False,
235
- output: Optional[str] = 'json',
236
- ) -> dict:
237
- """The getAuthorId function is used to get the author's id by his or her name.
238
-
239
- Args:
240
- mauthors (str): Defines the author you want to search for.
241
- hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
242
- after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
243
- before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
244
- no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
245
- _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
246
- output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
247
-
248
- Returns:
249
- :class:`dict`: author id
250
- * author_id: the author_id of the author
251
- """
252
- from serpapi import GoogleSearch
253
-
254
- params = {
255
- 'mauthors': mauthors,
256
- 'engine': 'google_scholar_profiles',
257
- 'api_key': self.api_key,
258
- 'hl': hl,
259
- 'after_author': after_author,
260
- 'before_author': before_author,
261
- 'no_cache': no_cache,
262
- 'async': _async,
263
- 'output': output,
264
- }
265
- try:
266
- search = GoogleSearch(params)
267
- results = search.get_dict()
268
- profile = results['profiles']
269
- author_info = dict(author_id=profile[0]['author_id'])
270
- return author_info
271
- except Exception as e:
272
- return ActionReturn(errmsg=str(e), state=ActionStatusCode.HTTP_ERROR)
273
-
274
-
275
- class AsyncGoogleScholar(AsyncActionMixin, GoogleScholar):
276
- """Plugin for google scholar search.
277
-
278
- Args:
279
- api_key (str): API KEY to use serper google search API,
280
- You can create a free API key at https://serper.dev.
281
- description (dict): The description of the action. Defaults to ``None``.
282
- parser (Type[BaseParser]): The parser class to process the
283
- action's inputs and outputs. Defaults to :class:`JsonParser`.
284
- """
285
-
286
- @tool_api(explode_return=True)
287
- @asyncify
288
- def search_google_scholar(
289
- self,
290
- query: str,
291
- cites: Optional[str] = None,
292
- as_ylo: Optional[int] = None,
293
- as_yhi: Optional[int] = None,
294
- scisbd: Optional[int] = None,
295
- cluster: Optional[str] = None,
296
- hl: Optional[str] = None,
297
- lr: Optional[str] = None,
298
- start: Optional[int] = None,
299
- num: Optional[int] = None,
300
- as_sdt: Optional[str] = None,
301
- safe: Optional[str] = None,
302
- filter: Optional[str] = None,
303
- as_vis: Optional[str] = None,
304
- ) -> dict:
305
- """Search for scholarly articles based on a query according to the google scholar.
306
-
307
- Args:
308
- query (str): The query to search for.
309
- cites (Optional[str]): The unique ID of an article for triggering "Cited By" searches.
310
- as_ylo (Optional[int]): The starting year for results (e.g., if as_ylo=2018, results before this year will be omitted).
311
- as_yhi (Optional[int]): The ending year for results (e.g., if as_yhi=2018, results after this year will be omitted).
312
- scisbd (Optional[int]): Defines articles added in the last year, sorted by date. It can be set to 1 to include only abstracts, or 2 to include everything.
313
- cluster (Optional[str]): The unique ID of an article for triggering "All Versions" searches.
314
- hl (Optional[str]): The language to use for the Google Scholar search.
315
- lr (Optional[str]): One or multiple languages to limit the search to.
316
- start (Optional[int]): The result offset for pagination (0 is the first page of results, 10 is the 2nd page, etc.)
317
- num (Optional[int]): The maximum number of results to return, limited to 20.
318
- as_sdt (Optional[str]): Can be used either as a search type or a filter.
319
- safe (Optional[str]): The level of filtering for adult content.
320
- filter (Optional[str]): Defines if the filters for 'Similar Results' and 'Omitted Results' are on or off.
321
- as_vis (Optional[str]): Defines whether to include citations or not.
322
-
323
- Returns:
324
- :class:`dict`: article information
325
- - title: a list of the titles of the three selected papers
326
- - cited_by: a list of the citation numbers of the three selected papers
327
- - organic_id: a list of the organic results' ids of the three selected papers
328
- - pub_info: publication information of selected papers
329
- """
330
- return super().search_google_scholar(
331
- query,
332
- cites,
333
- as_ylo,
334
- as_yhi,
335
- scisbd,
336
- cluster,
337
- hl,
338
- lr,
339
- start,
340
- num,
341
- as_sdt,
342
- safe,
343
- filter,
344
- as_vis,
345
- )
346
-
347
- @tool_api(explode_return=True)
348
- @asyncify
349
- def get_author_information(
350
- self,
351
- author_id: str,
352
- hl: Optional[str] = None,
353
- view_op: Optional[str] = None,
354
- sort: Optional[str] = None,
355
- citation_id: Optional[str] = None,
356
- start: Optional[int] = None,
357
- num: Optional[int] = None,
358
- no_cache: Optional[bool] = None,
359
- async_req: Optional[bool] = None,
360
- output: Optional[str] = None,
361
- ) -> dict:
362
- """Search for an author's information by author's id provided by get_author_id.
363
-
364
- Args:
365
- author_id (str): Required. The ID of an author.
366
- hl (Optional[str]): The language to use for the Google Scholar Author search. Default is 'en'.
367
- view_op (Optional[str]): Used for viewing specific parts of a page.
368
- sort (Optional[str]): Used for sorting and refining articles.
369
- citation_id (Optional[str]): Used for retrieving individual article citation.
370
- start (Optional[int]): Defines the result offset. Default is 0.
371
- num (Optional[int]): Defines the number of results to return. Default is 20.
372
- no_cache (Optional[bool]): Forces SerpApi to fetch the results even if a cached version is already present. Default is False.
373
- async_req (Optional[bool]): Defines the way you want to submit your search to SerpApi. Default is False.
374
- output (Optional[str]): Defines the final output you want. Default is 'json'.
375
-
376
- Returns:
377
- :class:`dict`: author information
378
- * name: author's name
379
- * affiliation: the affiliation of the author
380
- * articles: at most 3 articles by the author
381
- * website: the author's homepage url
382
- """
383
- return super().get_author_information(
384
- author_id, hl, view_op, sort, citation_id, start, num, no_cache, async_req, output
385
- )
386
-
387
- @tool_api(explode_return=True)
388
- @asyncify
389
- def get_citation_format(
390
- self,
391
- q: str,
392
- no_cache: Optional[bool] = None,
393
- async_: Optional[bool] = None,
394
- output: Optional[str] = 'json',
395
- ) -> dict:
396
- """Function to get MLA citation format by an identification of organic_result's id provided by search_google_scholar.
397
-
398
- Args:
399
- q (str): ID of an individual Google Scholar organic search result.
400
- no_cache (Optional[bool]): If set to True, will force SerpApi to fetch the Google Scholar Cite results even if a cached version is already present. Defaults to None.
401
- async_ (Optional[bool]): If set to True, will submit search to SerpApi and retrieve results later. Defaults to None.
402
- output (Optional[str]): Final output format. Set to 'json' to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
403
-
404
- Returns:
405
- :class:`dict`: citation format
406
- * authors: the authors of the article
407
- * citation: the citation format of the article
408
- """
409
- return super().get_citation_format(q, no_cache, async_, output)
410
-
411
- @tool_api(explode_return=True)
412
- @asyncify
413
- def get_author_id(
414
- self,
415
- mauthors: str,
416
- hl: Optional[str] = 'en',
417
- after_author: Optional[str] = None,
418
- before_author: Optional[str] = None,
419
- no_cache: Optional[bool] = False,
420
- _async: Optional[bool] = False,
421
- output: Optional[str] = 'json',
422
- ) -> dict:
423
- """The getAuthorId function is used to get the author's id by his or her name.
424
-
425
- Args:
426
- mauthors (str): Defines the author you want to search for.
427
- hl (Optional[str]): Defines the language to use for the Google Scholar Profiles search. It's a two-letter language code. (e.g., 'en' for English, 'es' for Spanish, or 'fr' for French). Defaults to 'en'.
428
- after_author (Optional[str]): Defines the next page token. It is used for retrieving the next page results. The parameter has the precedence over before_author parameter. Defaults to None.
429
- before_author (Optional[str]): Defines the previous page token. It is used for retrieving the previous page results. Defaults to None.
430
- no_cache (Optional[bool]): Will force SerpApi to fetch the Google Scholar Profiles results even if a cached version is already present. Defaults to False.
431
- _async (Optional[bool]): Defines the way you want to submit your search to SerpApi. Defaults to False.
432
- output (Optional[str]): Defines the final output you want. It can be set to 'json' (default) to get a structured JSON of the results, or 'html' to get the raw html retrieved. Defaults to 'json'.
433
-
434
- Returns:
435
- :class:`dict`: author id
436
- * author_id: the author_id of the author
437
- """
438
- return super().get_author_id(mauthors, hl, after_author, before_author, no_cache, _async, output)
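For reference, a minimal usage sketch of the scholar tool chain deleted above. It assumes the synchronous wrapper class is named GoogleScholar, that it remains importable from the upstream lagent package, and that a SerpApi key is available; the author name, key handling, and returned keys below follow the docstrings in this file and are illustrative only.

# Hedged sketch -- assumed import path, constructor argument, and direct-call return shapes.
from lagent.actions import GoogleScholar

scholar = GoogleScholar(api_key='YOUR_SERPAPI_KEY')     # key argument assumed, mirroring the other search actions
ids = scholar.get_author_id(mauthors='A. M. Turing')    # documented to return a dict with 'author_id'
info = scholar.get_author_information(author_id=ids['author_id'])
papers = scholar.search_google_scholar(query='computability', num=3)
cite = scholar.get_citation_format(q=papers['organic_id'][0])
print(info.get('name'), cite.get('citation'))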
 
actions/google_search.py DELETED
@@ -1,244 +0,0 @@
1
- import os
2
- from typing import List, Optional, Tuple, Type, Union
3
-
4
- import aiohttp
5
- import requests
6
-
7
- from lagent.schema import ActionReturn, ActionStatusCode
8
- from .base_action import AsyncActionMixin, BaseAction, tool_api
9
- from .parser import BaseParser, JsonParser
10
-
11
-
12
- class GoogleSearch(BaseAction):
13
- """Wrapper around the Serper.dev Google Search API.
14
-
15
- To use, you should pass your serper API key to the constructor.
16
-
17
- Code is modified from lang-chain GoogleSerperAPIWrapper
18
- (https://github.com/langchain-ai/langchain/blob/ba5f
19
- baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
20
- langchain/utilities/google_serper.py)
21
-
22
- Args:
23
- api_key (str): API key for the Serper Google search API.
24
- You can create a free API key at https://serper.dev.
25
- timeout (int): Upper bound of waiting time for a serper request.
26
- search_type (str): Serper API supports ['search', 'images', 'news',
27
- 'places'] types of search; currently we only support 'search'.
28
- description (dict): The description of the action. Defaults to ``None``.
29
- parser (Type[BaseParser]): The parser class to process the
30
- action's inputs and outputs. Defaults to :class:`JsonParser`.
31
- """
32
- result_key_for_type = {
33
- 'news': 'news',
34
- 'places': 'places',
35
- 'images': 'images',
36
- 'search': 'organic',
37
- }
38
-
39
- def __init__(
40
- self,
41
- api_key: Optional[str] = None,
42
- timeout: int = 5,
43
- search_type: str = 'search',
44
- description: Optional[dict] = None,
45
- parser: Type[BaseParser] = JsonParser,
46
- ):
47
- super().__init__(description, parser)
48
- api_key = os.environ.get('SERPER_API_KEY', api_key)
49
- if api_key is None:
50
- raise ValueError(
51
- 'Please set Serper API key either in the environment '
52
- 'as SERPER_API_KEY or pass it as `api_key` parameter.')
53
- self.api_key = api_key
54
- self.timeout = timeout
55
- self.search_type = search_type
56
-
57
- @tool_api
58
- def run(self, query: str, k: int = 10) -> ActionReturn:
59
- """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。
60
-
61
- Args:
62
- query (str): the search content
63
- k (int): select first k results in the search results as response
64
- """
65
- tool_return = ActionReturn(type=self.name)
66
- status_code, response = self._search(query, k=k)
67
- # convert search results to ToolReturn format
68
- if status_code == -1:
69
- tool_return.errmsg = response
70
- tool_return.state = ActionStatusCode.HTTP_ERROR
71
- elif status_code == 200:
72
- parsed_res = self._parse_results(response, k)
73
- tool_return.result = [dict(type='text', content=str(parsed_res))]
74
- tool_return.state = ActionStatusCode.SUCCESS
75
- else:
76
- tool_return.errmsg = str(status_code)
77
- tool_return.state = ActionStatusCode.API_ERROR
78
- return tool_return
79
-
80
- def _parse_results(self, results: dict, k: int) -> Union[str, List[str]]:
81
- """Parse the search results from Serper API.
82
-
83
- Args:
84
- results (dict): The search content from Serper API
85
- in json format.
86
-
87
- Returns:
88
- List[str]: The parsed search results.
89
- """
90
-
91
- snippets = []
92
-
93
- if results.get('answerBox'):
94
- answer_box = results.get('answerBox', {})
95
- if answer_box.get('answer'):
96
- return [answer_box.get('answer')]
97
- elif answer_box.get('snippet'):
98
- return [answer_box.get('snippet').replace('\n', ' ')]
99
- elif answer_box.get('snippetHighlighted'):
100
- return answer_box.get('snippetHighlighted')
101
-
102
- if results.get('knowledgeGraph'):
103
- kg = results.get('knowledgeGraph', {})
104
- title = kg.get('title')
105
- entity_type = kg.get('type')
106
- if entity_type:
107
- snippets.append(f'{title}: {entity_type}.')
108
- description = kg.get('description')
109
- if description:
110
- snippets.append(description)
111
- for attribute, value in kg.get('attributes', {}).items():
112
- snippets.append(f'{title} {attribute}: {value}.')
113
-
114
- for result in results[self.result_key_for_type[
115
- self.search_type]][:k]:
116
- if 'snippet' in result:
117
- snippets.append(result['snippet'])
118
- for attribute, value in result.get('attributes', {}).items():
119
- snippets.append(f'{attribute}: {value}.')
120
-
121
- if len(snippets) == 0:
122
- return ['No good Google Search Result was found']
123
- return snippets
124
-
125
- def _search(self,
126
- search_term: str,
127
- search_type: Optional[str] = None,
128
- **kwargs) -> Tuple[int, Union[dict, str]]:
129
- """HTTP requests to Serper API.
130
-
131
- Args:
132
- search_term (str): The search query.
133
- search_type (str): search type supported by Serper API,
134
- default to 'search'.
135
-
136
- Returns:
137
- tuple: the return value is a tuple contains:
138
- - status_code (int): HTTP status code from Serper API.
139
- - response (dict): response context with json format.
140
- """
141
- headers = {
142
- 'X-API-KEY': self.api_key or '',
143
- 'Content-Type': 'application/json',
144
- }
145
- params = {
146
- 'q': search_term,
147
- **{
148
- key: value
149
- for key, value in kwargs.items() if value is not None
150
- },
151
- }
152
- try:
153
- response = requests.post(
154
- f'https://google.serper.dev/{search_type or self.search_type}',
155
- headers=headers,
156
- params=params,
157
- timeout=self.timeout)
158
- except Exception as e:
159
- return -1, str(e)
160
- return response.status_code, response.json()
161
-
162
-
163
- class AsyncGoogleSearch(AsyncActionMixin, GoogleSearch):
164
- """Wrapper around the Serper.dev Google Search API.
165
-
166
- To use, you should pass your serper API key to the constructor.
167
-
168
- Code is modified from lang-chain GoogleSerperAPIWrapper
169
- (https://github.com/langchain-ai/langchain/blob/ba5f
170
- baba704a2d729a4b8f568ed70d7c53e799bb/libs/langchain/
171
- langchain/utilities/google_serper.py)
172
-
173
- Args:
174
- api_key (str): API key for the Serper Google search API.
175
- You can create a free API key at https://serper.dev.
176
- timeout (int): Upper bound of waiting time for a serper request.
177
- search_type (str): Serper API supports ['search', 'images', 'news',
178
- 'places'] types of search; currently we only support 'search'.
179
- description (dict): The description of the action. Defaults to ``None``.
180
- parser (Type[BaseParser]): The parser class to process the
181
- action's inputs and outputs. Defaults to :class:`JsonParser`.
182
- """
183
-
184
- @tool_api
185
- async def run(self, query: str, k: int = 10) -> ActionReturn:
186
- """一个可以从谷歌搜索结果的API。当你需要对于一个特定问题找到简短明了的回答时,可以使用它。输入应该是一个搜索查询。
187
-
188
- Args:
189
- query (str): the search content
190
- k (int): select first k results in the search results as response
191
- """
192
- tool_return = ActionReturn(type=self.name)
193
- status_code, response = await self._search(query, k=k)
194
- # convert search results to ToolReturn format
195
- if status_code == -1:
196
- tool_return.errmsg = response
197
- tool_return.state = ActionStatusCode.HTTP_ERROR
198
- elif status_code == 200:
199
- parsed_res = self._parse_results(response, k)
200
- tool_return.result = [dict(type='text', content=str(parsed_res))]
201
- tool_return.state = ActionStatusCode.SUCCESS
202
- else:
203
- tool_return.errmsg = str(status_code)
204
- tool_return.state = ActionStatusCode.API_ERROR
205
- return tool_return
206
-
207
- async def _search(self,
208
- search_term: str,
209
- search_type: Optional[str] = None,
210
- **kwargs) -> Tuple[int, Union[dict, str]]:
211
- """HTTP requests to Serper API.
212
-
213
- Args:
214
- search_term (str): The search query.
215
- search_type (str): search type supported by Serper API,
216
- default to 'search'.
217
-
218
- Returns:
219
- tuple: the return value is a tuple contains:
220
- - status_code (int): HTTP status code from Serper API.
221
- - response (dict): response context with json format.
222
- """
223
- headers = {
224
- 'X-API-KEY': self.api_key or '',
225
- 'Content-Type': 'application/json',
226
- }
227
- params = {
228
- 'q': search_term,
229
- **{
230
- key: value
231
- for key, value in kwargs.items() if value is not None
232
- },
233
- }
234
- timeout = aiohttp.ClientTimeout(total=self.timeout)
235
- async with aiohttp.ClientSession(timeout=timeout) as session:
236
- try:
237
- async with session.post(
238
- f'https://google.serper.dev/{search_type or self.search_type}',
239
- headers=headers,
240
- params=params) as resp:
241
- code, ret = resp.status, await resp.json()
242
- except aiohttp.ClientError as e:
243
- code, ret = -1, str(e)
244
- return code, ret
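A minimal usage sketch of the Serper-backed search actions deleted above, assuming they stay importable from the upstream lagent package and that SERPER_API_KEY is set in the environment; the query string is illustrative.

# Hedged sketch of the sync and async actions defined in this file.
import asyncio

from lagent.actions import AsyncGoogleSearch, GoogleSearch

search = GoogleSearch()                           # falls back to the SERPER_API_KEY environment variable
ret = search.run('population of Shanghai', k=3)   # returns an ActionReturn
print(ret.state, ret.result)                      # result is a list of {'type': 'text', 'content': ...}

async def main():
    ret = await AsyncGoogleSearch().run('population of Shanghai', k=3)
    print(ret.result)

asyncio.run(main())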
 
actions/ipython_interactive.py DELETED
@@ -1,273 +0,0 @@
1
- import re
2
- import signal
3
- from contextlib import contextmanager, redirect_stdout
4
- from dataclasses import dataclass
5
- from enum import Enum
6
- from io import StringIO
7
- from typing import Optional, Type
8
-
9
- from ..schema import ActionReturn, ActionStatusCode
10
- from .base_action import AsyncActionMixin, BaseAction, tool_api
11
- from .parser import BaseParser, JsonParser
12
-
13
-
14
- class Status(str, Enum):
15
- """Execution status."""
16
- SUCCESS = 'success'
17
- FAILURE = 'failure'
18
-
19
-
20
- @dataclass
21
- class ExecutionResult:
22
- """Execution result."""
23
- status: Status
24
- value: Optional[str] = None
25
- msg: Optional[str] = None
26
-
27
-
28
- @contextmanager
29
- def _raise_timeout(timeout):
30
-
31
- def _handler(signum, frame):
32
- raise TimeoutError()
33
-
34
- signal.signal(signal.SIGALRM, _handler)
35
- signal.alarm(timeout)
36
-
37
- try:
38
- yield
39
- finally:
40
- signal.alarm(0)
41
-
42
-
43
- class IPythonInteractive(BaseAction):
44
- """An interactive IPython shell for code execution.
45
-
46
- Args:
47
- timeout (int): Upper bound of waiting time for Python script execution.
48
- Defaults to ``30``.
49
- max_out_len (int): maximum output length. No truncation occurs if negative.
50
- Defaults to ``8192``.
51
- use_signals (bool): whether to use signals (rather than multiprocessing) to
52
- time out code execution. Set to ``False`` when not running in the main
53
- thread, e.g. in web applications. Defaults to ``True``.
54
- description (dict): The description of the action. Defaults to ``None``.
55
- parser (Type[BaseParser]): The parser class to process the
56
- action's inputs and outputs. Defaults to :class:`JsonParser`.
57
- """
58
-
59
- def __init__(
60
- self,
61
- timeout: int = 30,
62
- max_out_len: int = 8192,
63
- use_signals: bool = True,
64
- description: Optional[dict] = None,
65
- parser: Type[BaseParser] = JsonParser,
66
- ):
67
- super().__init__(description, parser)
68
- self.timeout = timeout
69
- self._executor = self.create_shell()
70
- self._highlighting = re.compile(
71
- r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
72
- self._max_out_len = max_out_len if max_out_len >= 0 else None
73
- self._use_signals = use_signals
74
-
75
- def reset(self):
76
- """Clear the context."""
77
- self._executor.reset()
78
-
79
- @tool_api
80
- def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
81
- """Launch an IPython Interactive Shell to execute code.
82
-
83
- Args:
84
- command (:class:`str`): Python code snippet
85
- timeout (:class:`Optional[int]`): timeout for execution.
86
- This argument only works in the main thread. Defaults to ``None``.
87
- """
88
- from timeout_decorator import timeout as timer
89
- tool_return = ActionReturn(args={'text': command}, type=self.name)
90
- ret = (
91
- timer(timeout or self.timeout)(self.exec)(command)
92
- if self._use_signals else self.exec(command))
93
- if ret.status is Status.SUCCESS:
94
- tool_return.result = [{'type': 'text', 'content': ret.value}]
95
- tool_return.state = ActionStatusCode.SUCCESS
96
- else:
97
- tool_return.errmsg = ret.msg
98
- tool_return.state = ActionStatusCode.API_ERROR
99
- return tool_return
100
-
101
- def exec(self, code: str) -> ExecutionResult:
102
- """Run Python scripts in IPython shell.
103
-
104
- Args:
105
- code (:class:`str`): code block
106
-
107
- Returns:
108
- :py:class:`ExecutionResult`: execution result
109
- """
110
- with StringIO() as io:
111
- with redirect_stdout(io):
112
- ret = self._executor.run_cell(self.extract_code(code))
113
- result = ret.result
114
- if result is not None:
115
- return ExecutionResult(Status.SUCCESS,
116
- str(result)[:self._max_out_len])
117
- outs = io.getvalue().strip().split('\n')
118
- if not outs:
119
- return ExecutionResult(Status.SUCCESS, '')
120
- for i, out in enumerate(outs):
121
- if re.search('Error|Traceback', out, re.S):
122
- if 'TimeoutError' in out:
123
- return ExecutionResult(
124
- Status.FAILURE,
125
- msg=('The code interpreter encountered '
126
- 'a timeout error.'))
127
- err_idx = i
128
- break
129
- else:
130
- return ExecutionResult(Status.SUCCESS,
131
- outs[-1].strip()[:self._max_out_len])
132
- return ExecutionResult(
133
- Status.FAILURE,
134
- msg=self._highlighting.sub(
135
- '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
136
- )
137
-
138
- @staticmethod
139
- def create_shell():
140
- from IPython import InteractiveShell
141
- from traitlets.config import Config
142
-
143
- c = Config()
144
- c.HistoryManager.enabled = False
145
- c.HistoryManager.hist_file = ':memory:'
146
- return InteractiveShell(
147
- user_ns={'_raise_timeout': _raise_timeout}, config=c)
148
-
149
- @staticmethod
150
- def extract_code(text: str) -> str:
151
- """Extract Python code from markup languages.
152
-
153
- Args:
154
- text (:class:`str`): Markdown-formatted text
155
-
156
- Returns:
157
- :class:`str`: Python code
158
- """
159
- import json5
160
-
161
- # Match triple backtick blocks first
162
- triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
163
- # Match single backtick blocks second
164
- single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
165
- if triple_match:
166
- text = triple_match.group(1)
167
- elif single_match:
168
- text = single_match.group(1)
169
- else:
170
- try:
171
- text = json5.loads(text)['code']
172
- except Exception:
173
- pass
174
- # If no code blocks found, return original text
175
- return text
176
-
177
- @staticmethod
178
- def wrap_code_with_timeout(code: str, timeout: int) -> str:
179
- if not code.strip():
180
- return code
181
- code = code.strip('\n').rstrip()
182
- indent = len(code) - len(code.lstrip())
183
- handle = ' ' * indent + f'with _raise_timeout({timeout}):\n'
184
- block = '\n'.join([' ' + line for line in code.split('\n')])
185
- wrapped_code = handle + block
186
- last_line = code.split('\n')[-1]
187
- is_expression = True
188
- try:
189
- compile(last_line.lstrip(), '<stdin>', 'eval')
190
- except SyntaxError:
191
- is_expression = False
192
- if is_expression:
193
- wrapped_code += '\n' * 5 + last_line
194
- return wrapped_code
195
-
196
-
197
- class AsyncIPythonInteractive(AsyncActionMixin, IPythonInteractive):
198
- """An interactive IPython shell for code execution.
199
-
200
- Args:
201
- timeout (int): Upper bound of waiting time for Python script execution.
202
- Defaults to ``30``.
203
- max_out_len (int): maximum output length. No truncation occurs if negative.
204
- Defaults to ``8192``.
205
- use_signals (bool): whether to use signals (rather than multiprocessing) to
206
- time out code execution. Set to ``False`` when not running in the main
207
- thread, e.g. in web applications. Defaults to ``True``.
208
- description (dict): The description of the action. Defaults to ``None``.
209
- parser (Type[BaseParser]): The parser class to process the
210
- action's inputs and outputs. Defaults to :class:`JsonParser`.
211
- """
212
-
213
- @tool_api
214
- async def run(self,
215
- command: str,
216
- timeout: Optional[int] = None) -> ActionReturn:
217
- """Launch an IPython Interactive Shell to execute code.
218
-
219
- Args:
220
- command (:class:`str`): Python code snippet
221
- timeout (:class:`Optional[int]`): timeout for execution.
222
- This argument only works in the main thread. Defaults to ``None``.
223
- """
224
- tool_return = ActionReturn(args={'text': command}, type=self.name)
225
- ret = await self.exec(command, timeout)
226
- if ret.status is Status.SUCCESS:
227
- tool_return.result = [{'type': 'text', 'content': ret.value}]
228
- tool_return.state = ActionStatusCode.SUCCESS
229
- else:
230
- tool_return.errmsg = ret.msg
231
- tool_return.state = ActionStatusCode.API_ERROR
232
- return tool_return
233
-
234
- async def exec(self, code: str, timeout: int = None) -> ExecutionResult:
235
- """Asynchronously run Python scripts in IPython shell.
236
-
237
- Args:
238
- code (:class:`str`): code block
239
- timeout (:class:`int`): max waiting time for code execution
240
-
241
- Returns:
242
- :py:class:`ExecutionResult`: execution result
243
- """
244
- with StringIO() as io:
245
- with redirect_stdout(io):
246
- ret = await self._executor.run_cell_async(
247
- # ret = await self.create_shell().run_cell_async(
248
- self.wrap_code_with_timeout(
249
- self.extract_code(code), timeout or self.timeout))
250
- result = ret.result
251
- if result is not None:
252
- return ExecutionResult(Status.SUCCESS,
253
- str(result)[:self._max_out_len])
254
- outs = io.getvalue().strip().split('\n')
255
- if not outs:
256
- return ExecutionResult(Status.SUCCESS, '')
257
- for i, out in enumerate(outs):
258
- if re.search('Error|Traceback', out, re.S):
259
- if 'TimeoutError' in out:
260
- return ExecutionResult(
261
- Status.FAILURE,
262
- msg=('The code interpreter encountered a '
263
- 'timeout error.'))
264
- err_idx = i
265
- break
266
- else:
267
- return ExecutionResult(Status.SUCCESS,
268
- outs[-1].strip()[:self._max_out_len])
269
- return ExecutionResult(
270
- Status.FAILURE,
271
- msg=self._highlighting.sub(
272
- '', '\n'.join(outs[err_idx:])[:self._max_out_len]),
273
- )
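A minimal usage sketch of the interactive shell action deleted above, assuming IPython, json5, and timeout_decorator are installed and the module remains importable from the upstream lagent package.

# Hedged sketch: the action keeps interpreter state across calls.
from lagent.actions import IPythonInteractive

shell = IPythonInteractive(timeout=10, max_out_len=1024)
ret = shell.run('```python\nx = 21 * 2\nprint(x)\n```')   # fenced input is unwrapped by extract_code
print(ret.state, ret.result)                               # [{'type': 'text', 'content': '42'}]
ret = shell.run('print(x + 1)')                            # state persists between calls: prints 43
shell.reset()                                              # clear the interpreter context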
 
actions/ipython_interpreter.py DELETED
@@ -1,584 +0,0 @@
1
- # flake8: noqa: E501
2
- import asyncio
3
- import base64
4
- import io
5
- import json
6
- import logging
7
- import os
8
- import queue
9
- import re
10
- import signal
11
- import sys
12
- import tempfile
13
- import traceback
14
- import uuid
15
- from typing import Optional, Tuple, Type
16
-
17
- from jupyter_client import AsyncKernelClient, AsyncKernelManager, AsyncMultiKernelManager
18
- from tenacity import retry, retry_if_result, stop_after_attempt, wait_fixed
19
-
20
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
21
- from lagent.actions.parser import BaseParser, JsonParser
22
- from lagent.schema import ActionReturn, ActionStatusCode
23
-
24
- logger = logging.getLogger(__name__)
25
-
26
- START_CODE = """
27
- def input(*args, **kwargs):
28
- raise NotImplementedError('Python input() function is disabled.')
29
-
30
- get_ipython().system = lambda *args: print('Assume we have this package, ! is disabled!')
31
- {}
32
- """ # noqa
33
-
34
-
35
- class TimeoutError(Exception):
36
- pass
37
-
38
-
39
- class KernelDeath(Exception):
40
- pass
41
-
42
-
43
- async def async_run_code(
44
- km: AsyncKernelManager,
45
- code,
46
- *,
47
- interrupt_after=30,
48
- iopub_timeout=40,
49
- wait_for_ready_timeout=60,
50
- shutdown_kernel=True,
51
- ):
52
- assert iopub_timeout > interrupt_after
53
- try:
54
-
55
- async def get_iopub_msg_with_death_detection(kc: AsyncKernelClient,
56
- *,
57
- timeout=None):
58
- loop = asyncio.get_running_loop()
59
- dead_fut = loop.create_future()
60
-
61
- def restarting():
62
- assert (
63
- False
64
- ), "Restart shouldn't happen because config.KernelRestarter.restart_limit is expected to be set to 0"
65
-
66
- def dead():
67
- logger.info("Kernel has died, will NOT restart")
68
- dead_fut.set_result(None)
69
-
70
- msg_task = asyncio.create_task(kc.get_iopub_msg(timeout=timeout))
71
- km.add_restart_callback(restarting, "restart")
72
- km.add_restart_callback(dead, "dead")
73
- try:
74
- done, _ = await asyncio.wait(
75
- [dead_fut, msg_task], return_when=asyncio.FIRST_COMPLETED)
76
- if dead_fut in done:
77
- raise KernelDeath()
78
- assert msg_task in done
79
- return await msg_task
80
- finally:
81
- msg_task.cancel()
82
- km.remove_restart_callback(restarting, "restart")
83
- km.remove_restart_callback(dead, "dead")
84
-
85
- async def send_interrupt():
86
- await asyncio.sleep(interrupt_after)
87
- logger.info("Sending interrupt to kernel")
88
- await km.interrupt_kernel()
89
-
90
- @retry(
91
- retry=retry_if_result(lambda ret: ret[-1].strip() in [
92
- 'KeyboardInterrupt',
93
- f"Kernel didn't respond in {wait_for_ready_timeout} seconds",
94
- ] if isinstance(ret, tuple) else False),
95
- stop=stop_after_attempt(3),
96
- wait=wait_fixed(1),
97
- retry_error_callback=lambda state: state.outcome.result())
98
- async def run():
99
- execute_result = None
100
- error_traceback = None
101
- stream_text_list = []
102
- kc = km.client()
103
- assert isinstance(kc, AsyncKernelClient)
104
- kc.start_channels()
105
- try:
106
- await kc.wait_for_ready(timeout=wait_for_ready_timeout)
107
- msg_id = kc.execute(code)
108
- while True:
109
- message = await get_iopub_msg_with_death_detection(
110
- kc, timeout=iopub_timeout)
111
- if logger.isEnabledFor(logging.DEBUG):
112
- logger.debug(
113
- json.dumps(message, indent=2, default=str))
114
- assert message["parent_header"]["msg_id"] == msg_id
115
- msg_type = message["msg_type"]
116
- if msg_type == "status":
117
- if message["content"]["execution_state"] == "idle":
118
- break
119
- elif msg_type == "stream":
120
- stream_name = message["content"]["name"]
121
- stream_text = message["content"]["text"]
122
- stream_text_list.append(stream_text)
123
- elif msg_type == "execute_result":
124
- execute_result = message["content"]["data"]
125
- elif msg_type == "error":
126
- error_traceback_lines = message["content"]["traceback"]
127
- error_traceback = "\n".join(error_traceback_lines)
128
- elif msg_type == "execute_input":
129
- pass
130
- else:
131
- assert False, f"Unknown message_type: {msg_type}"
132
- finally:
133
- kc.stop_channels()
134
- return execute_result, error_traceback, "".join(stream_text_list)
135
-
136
- if interrupt_after:
137
- run_task = asyncio.create_task(run())
138
- send_interrupt_task = asyncio.create_task(send_interrupt())
139
- done, _ = await asyncio.wait([run_task, send_interrupt_task],
140
- return_when=asyncio.FIRST_COMPLETED)
141
- if run_task in done:
142
- send_interrupt_task.cancel()
143
- else:
144
- assert send_interrupt_task in done
145
- result = await run_task
146
- else:
147
- result = await run()
148
- return result
149
- finally:
150
- if shutdown_kernel:
151
- await km.shutdown_kernel()
152
-
153
-
154
- class IPythonInterpreter(BaseAction):
155
- """A IPython executor that can execute Python scripts in a jupyter manner.
156
-
157
- Args:
158
- timeout (int): Upper bound of waiting time for Python script execution.
159
- Defaults to 20.
160
- user_data_dir (str, optional): Specifies the user data directory for file
161
- loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
162
- Defaults to `ENV`.
163
- work_dir (str, optional): Specify which directory to save output images to.
164
- Defaults to ``'./work_dir/tmp_dir'``.
165
- description (dict): The description of the action. Defaults to ``None``.
166
- parser (Type[BaseParser]): The parser class to process the
167
- action's inputs and outputs. Defaults to :class:`JsonParser`.
168
- """
169
-
170
- _KERNEL_CLIENTS = {}
171
-
172
- def __init__(
173
- self,
174
- timeout: int = 20,
175
- user_data_dir: str = 'ENV',
176
- work_dir='./work_dir/tmp_dir',
177
- description: Optional[dict] = None,
178
- parser: Type[BaseParser] = JsonParser,
179
- ):
180
- super().__init__(description, parser)
181
-
182
- self.timeout = timeout
183
- if user_data_dir == 'ENV':
184
- user_data_dir = os.environ.get('USER_DATA_DIR', '')
185
-
186
- if user_data_dir:
187
- user_data_dir = os.path.dirname(user_data_dir)
188
- user_data_dir = f"import os\nos.chdir('{user_data_dir}')"
189
- self.user_data_dir = user_data_dir
190
- self._initialized = False
191
- self.work_dir = work_dir
192
- if not os.path.exists(self.work_dir):
193
- os.makedirs(self.work_dir, exist_ok=True)
194
-
195
- @staticmethod
196
- def start_kernel():
197
- from jupyter_client import KernelManager
198
-
199
- # start the kernel and manager
200
- km = KernelManager()
201
- km.start_kernel()
202
- kc = km.client()
203
- return km, kc
204
-
205
- def initialize(self):
206
- if self._initialized:
207
- return
208
- pid = os.getpid()
209
- if pid not in self._KERNEL_CLIENTS:
210
- self._KERNEL_CLIENTS[pid] = self.start_kernel()
211
- self.kernel_manager, self.kernel_client = self._KERNEL_CLIENTS[pid]
212
- self._initialized = True
213
- self._call(START_CODE.format(self.user_data_dir), None)
214
-
215
- def reset(self):
216
- if not self._initialized:
217
- self.initialize()
218
- else:
219
- code = "get_ipython().run_line_magic('reset', '-f')\n" + \
220
- START_CODE.format(self.user_data_dir)
221
- self._call(code, None)
222
-
223
- def _call(self,
224
- command: str,
225
- timeout: Optional[int] = None) -> Tuple[str, bool]:
226
- self.initialize()
227
- command = extract_code(command)
228
-
229
- # check previous remaining result
230
- while True:
231
- try:
232
- msg = self.kernel_client.get_iopub_msg(timeout=5)
233
- msg_type = msg['msg_type']
234
- if msg_type == 'status':
235
- if msg['content'].get('execution_state') == 'idle':
236
- break
237
- except queue.Empty:
238
- # assume no result
239
- break
240
-
241
- self.kernel_client.execute(command)
242
-
243
- def _inner_call():
244
- result = ''
245
- images = []
246
- succeed = True
247
- image_idx = 0
248
-
249
- while True:
250
- text = ''
251
- image = ''
252
- finished = False
253
- msg_type = 'error'
254
- try:
255
- msg = self.kernel_client.get_iopub_msg(timeout=20)
256
- msg_type = msg['msg_type']
257
- if msg_type == 'status':
258
- if msg['content'].get('execution_state') == 'idle':
259
- finished = True
260
- elif msg_type == 'execute_result':
261
- text = msg['content']['data'].get('text/plain', '')
262
- if 'image/png' in msg['content']['data']:
263
- image_b64 = msg['content']['data']['image/png']
264
- image_url = publish_image_to_local(
265
- image_b64, self.work_dir)
266
- image_idx += 1
267
- image = '![fig-%03d](%s)' % (image_idx, image_url)
268
-
269
- elif msg_type == 'display_data':
270
- if 'image/png' in msg['content']['data']:
271
- image_b64 = msg['content']['data']['image/png']
272
- image_url = publish_image_to_local(
273
- image_b64, self.work_dir)
274
- image_idx += 1
275
- image = '![fig-%03d](%s)' % (image_idx, image_url)
276
-
277
- else:
278
- text = msg['content']['data'].get('text/plain', '')
279
- elif msg_type == 'stream':
280
- msg_type = msg['content']['name'] # stdout, stderr
281
- text = msg['content']['text']
282
- elif msg_type == 'error':
283
- succeed = False
284
- text = escape_ansi('\n'.join(
285
- msg['content']['traceback']))
286
- if 'M6_CODE_INTERPRETER_TIMEOUT' in text:
287
- text = f'Timeout. No response after {timeout} seconds.' # noqa
288
- except queue.Empty:
289
- # stop current task in case break next input.
290
- self.kernel_manager.interrupt_kernel()
291
- succeed = False
292
- text = f'Timeout. No response after {timeout} seconds.'
293
- finished = True
294
- except Exception:
295
- succeed = False
296
- msg = ''.join(traceback.format_exception(*sys.exc_info()))
297
- # text = 'The code interpreter encountered an unexpected error.' # noqa
298
- text = msg
299
- logging.warning(msg)
300
- finished = True
301
- if text:
302
- # result += f'\n\n{msg_type}:\n\n```\n{text}\n```'
303
- result += f'{text}'
304
-
305
- if image:
306
- images.append(image_url)
307
- if finished:
308
- return succeed, dict(text=result, image=images)
309
-
310
- try:
311
- if timeout:
312
-
313
- def handler(signum, frame):
314
- raise TimeoutError()
315
-
316
- signal.signal(signal.SIGALRM, handler)
317
- signal.alarm(timeout)
318
- succeed, result = _inner_call()
319
- except TimeoutError:
320
- succeed = False
321
- text = 'The code interpreter encountered an unexpected error.'
322
- result = f'\n\nerror:\n\n```\n{text}\n```'
323
- finally:
324
- if timeout:
325
- signal.alarm(0)
326
-
327
- # result = result.strip('\n')
328
- return succeed, result
329
-
330
- @tool_api
331
- def run(self, command: str, timeout: Optional[int] = None) -> ActionReturn:
332
- r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
333
-
334
- Args:
335
- command (:class:`str`): Python code
336
- timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
337
- """
338
- tool_return = ActionReturn(url=None, args=None, type=self.name)
339
- tool_return.args = dict(text=command)
340
- succeed, result = self._call(command, timeout)
341
- if succeed:
342
- text = result['text']
343
- image = result.get('image', [])
344
- resp = [dict(type='text', content=text)]
345
- if image:
346
- resp.extend([dict(type='image', content=im) for im in image])
347
- tool_return.result = resp
348
- # tool_return.result = dict(
349
- # text=result['text'], image=result.get('image', [])[0])
350
- tool_return.state = ActionStatusCode.SUCCESS
351
- else:
352
- tool_return.errmsg = result.get('text', '') if isinstance(
353
- result, dict) else result
354
- tool_return.state = ActionStatusCode.API_ERROR
355
- return tool_return
356
-
357
-
358
- class AsyncIPythonInterpreter(AsyncActionMixin, IPythonInterpreter):
359
- """A IPython executor that can execute Python scripts in a jupyter manner.
360
-
361
- Args:
362
- timeout (int): Upper bound of waiting time for Python script execution.
363
- Defaults to 20.
364
- user_data_dir (str, optional): Specifies the user data directory for file
365
- loading. If set to `ENV`, use `USER_DATA_DIR` environment variable.
366
- Defaults to `ENV`.
367
- work_dir (str, optional): Specify which directory to save output images to.
368
- Defaults to ``'./work_dir/tmp_dir'``.
369
- description (dict): The description of the action. Defaults to ``None``.
370
- parser (Type[BaseParser]): The parser class to process the
371
- action's inputs and outputs. Defaults to :class:`JsonParser`.
372
- """
373
-
374
- _UNBOUND_KERNEL_CLIENTS = asyncio.Queue()
375
-
376
- def __init__(
377
- self,
378
- timeout: int = 20,
379
- user_data_dir: str = 'ENV',
380
- work_dir=os.path.join(tempfile.gettempdir(), 'tmp_dir'),
381
- max_kernels: Optional[int] = None,
382
- reuse_kernel: bool = True,
383
- startup_rate: int = 32,
384
- connection_dir: str = tempfile.gettempdir(),
385
- description: Optional[dict] = None,
386
- parser: Type[BaseParser] = JsonParser,
387
- ):
388
- super().__init__(timeout, user_data_dir, work_dir, description, parser)
389
- from traitlets.config import Config
390
-
391
- c = Config()
392
- c.KernelManager.transport = 'ipc'
393
- self._amkm = AsyncMultiKernelManager(
394
- config=c, connection_dir=connection_dir)
395
- self._max_kernels = max_kernels
396
- self._reuse_kernel = reuse_kernel
397
- self._sem = asyncio.Semaphore(startup_rate)
398
- self._lock = asyncio.Lock()
399
-
400
- async def initialize(self, session_id: str):
401
- session_id = str(session_id)
402
- while True:
403
- if session_id in self._KERNEL_CLIENTS:
404
- return self._KERNEL_CLIENTS[session_id]
405
- if self._reuse_kernel and not self._UNBOUND_KERNEL_CLIENTS.empty():
406
- self._KERNEL_CLIENTS[
407
- session_id] = await self._UNBOUND_KERNEL_CLIENTS.get()
408
- return self._KERNEL_CLIENTS[session_id]
409
- async with self._sem:
410
- if self._max_kernels is None or len(
411
- self._KERNEL_CLIENTS
412
- ) + self._UNBOUND_KERNEL_CLIENTS.qsize() < self._max_kernels:
413
- kernel_id = None
414
- try:
415
- kernel_id = await self._amkm.start_kernel()
416
- kernel = self._amkm.get_kernel(kernel_id)
417
- client = kernel.client()
418
- _, error_stacktrace, stream_text = await async_run_code(
419
- kernel,
420
- START_CODE.format(self.user_data_dir),
421
- shutdown_kernel=False)
422
- # check if the output of START_CODE meets expectations
423
- if not (error_stacktrace is None
424
- and stream_text == ''):
425
- raise RuntimeError
426
- except Exception as e:
427
- print(f'Starting kernel error: {e}')
428
- if kernel_id:
429
- await self._amkm.shutdown_kernel(kernel_id)
430
- self._amkm.remove_kernel(kernel_id)
431
- await asyncio.sleep(1)
432
- continue
433
- if self._max_kernels is None:
434
- self._KERNEL_CLIENTS[session_id] = (kernel_id, kernel,
435
- client)
436
- return kernel_id, kernel, client
437
- async with self._lock:
438
- if len(self._KERNEL_CLIENTS
439
- ) + self._UNBOUND_KERNEL_CLIENTS.qsize(
440
- ) < self._max_kernels:
441
- self._KERNEL_CLIENTS[session_id] = (kernel_id,
442
- kernel, client)
443
- return kernel_id, kernel, client
444
- await self._amkm.shutdown_kernel(kernel_id)
445
- self._amkm.remove_kernel(kernel_id)
446
- await asyncio.sleep(1)
447
-
448
- async def reset(self, session_id: str):
449
- session_id = str(session_id)
450
- if session_id not in self._KERNEL_CLIENTS:
451
- return
452
- _, kernel, _ = self._KERNEL_CLIENTS[session_id]
453
- code = "get_ipython().run_line_magic('reset', '-f')\n" + \
454
- START_CODE.format(self.user_data_dir)
455
- await async_run_code(kernel, code, shutdown_kernel=False)
456
-
457
- async def shutdown(self, session_id: str):
458
- session_id = str(session_id)
459
- if session_id in self._KERNEL_CLIENTS:
460
- kernel_id, _, _ = self._KERNEL_CLIENTS.get(session_id)
461
- await self._amkm.shutdown_kernel(kernel_id)
462
- self._amkm.remove_kernel(kernel_id)
463
- del self._KERNEL_CLIENTS[session_id]
464
-
465
- async def close_session(self, session_id: str):
466
- session_id = str(session_id)
467
- if self._reuse_kernel:
468
- if session_id in self._KERNEL_CLIENTS:
469
- await self.reset(session_id)
470
- await self._UNBOUND_KERNEL_CLIENTS.put(
471
- self._KERNEL_CLIENTS.pop(session_id))
472
- else:
473
- await self.shutdown(session_id)
474
-
475
- async def _call(self, command, timeout=None, session_id=None):
476
- _, kernel, _ = await self.initialize(str(session_id))
477
- result = await async_run_code(
478
- kernel,
479
- extract_code(command),
480
- interrupt_after=timeout or self.timeout,
481
- shutdown_kernel=False)
482
- execute_result, error_stacktrace, stream_text = result
483
- if error_stacktrace is not None:
484
- ret = re.sub('^-*\n', '', escape_ansi(error_stacktrace))
485
- if ret.endswith('KeyboardInterrupt: '):
486
- ret = 'The code interpreter encountered a timeout error.'
487
- status, ret = False, ret.strip()
488
- elif execute_result is not None:
489
- status, ret = True, dict(text=execute_result.get('text/plain', ''))
490
- else:
491
- status, ret = True, dict(text=stream_text.strip())
492
- return status, ret
493
-
494
- @tool_api
495
- async def run(self,
496
- command: str,
497
- timeout: Optional[int] = None,
498
- session_id: Optional[str] = None) -> ActionReturn:
499
- r"""When you send a message containing Python code to python, it will be executed in a stateful Jupyter notebook environment. python will respond with the output of the execution or time out after 60.0 seconds. The drive at '/mnt/data' can be used to save and persist user files. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
500
-
501
- Args:
502
- command (:class:`str`): Python code
503
- timeout (:class:`Optional[int]`): Upper bound of waiting time for Python script execution.
504
- """
505
- tool_return = ActionReturn(url=None, args=None, type=self.name)
506
- tool_return.args = dict(text=command)
507
- succeed, result = await self._call(command, timeout, session_id)
508
- if succeed:
509
- text = result['text']
510
- image = result.get('image', [])
511
- resp = [dict(type='text', content=text)]
512
- if image:
513
- resp.extend([dict(type='image', content=im) for im in image])
514
- tool_return.result = resp
515
- # tool_return.result = dict(
516
- # text=result['text'], image=result.get('image', [])[0])
517
- tool_return.state = ActionStatusCode.SUCCESS
518
- else:
519
- tool_return.errmsg = result.get('text', '') if isinstance(
520
- result, dict) else result
521
- tool_return.state = ActionStatusCode.API_ERROR
522
- return tool_return
523
-
524
-
525
- def extract_code(text):
526
- import json5
527
-
528
- # Match triple backtick blocks first
529
- triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
530
- # Match single backtick blocks second
531
- single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
532
- if triple_match:
533
- text = triple_match.group(1)
534
- elif single_match:
535
- text = single_match.group(1)
536
- else:
537
- try:
538
- text = json5.loads(text)['code']
539
- except Exception:
540
- pass
541
- # If no code blocks found, return original text
542
- return text
543
-
544
-
545
- def escape_ansi(line):
546
- ansi_escape = re.compile(r'(?:\x1B[@-_]|[\x80-\x9F])[0-?]*[ -/]*[@-~]')
547
- return ansi_escape.sub('', line)
548
-
549
-
550
- def publish_image_to_local(image_base64: str, work_dir='./work_dir/tmp_dir'):
551
- import PIL.Image
552
- image_file = str(uuid.uuid4()) + '.png'
553
- local_image_file = os.path.join(work_dir, image_file)
554
-
555
- png_bytes = base64.b64decode(image_base64)
556
- assert isinstance(png_bytes, bytes)
557
- bytes_io = io.BytesIO(png_bytes)
558
- PIL.Image.open(bytes_io).save(local_image_file, 'png')
559
-
560
- return local_image_file
561
-
562
-
563
- # local test for code interpreter
564
- def get_multiline_input(hint):
565
- print(hint)
566
- print('// Press ENTER to make a new line. Press CTRL-D to end input.')
567
- lines = []
568
- while True:
569
- try:
570
- line = input()
571
- except EOFError: # CTRL-D
572
- break
573
- lines.append(line)
574
- print('// Input received.')
575
- if lines:
576
- return '\n'.join(lines)
577
- else:
578
- return ''
579
-
580
-
581
- if __name__ == '__main__':
582
- code_interpreter = IPythonInterpreter()
583
- while True:
584
- print(code_interpreter(get_multiline_input('Enter python code:')))
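A minimal usage sketch of the Jupyter-kernel interpreter deleted above, assuming jupyter_client and ipykernel are installed and the module remains importable from the upstream lagent package.

# Hedged sketch: a kernel is started lazily on the first call; figures are written under work_dir.
from lagent.actions import IPythonInterpreter

interpreter = IPythonInterpreter(timeout=20, work_dir='./work_dir/tmp_dir')
ret = interpreter.run('import math\nprint(math.sqrt(2))')
print(ret.state, ret.result)   # text output; image entries are appended when a figure is produced

# The async variant is keyed by session_id and can reuse kernels across sessions, e.g.
#     ret = await AsyncIPythonInterpreter().run(command, session_id='chat-1')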
 
actions/ipython_manager.py DELETED
@@ -1,220 +0,0 @@
1
- import re
2
- import sys
3
- from collections import defaultdict
4
- from contextlib import nullcontext
5
- from io import StringIO
6
- from multiprocessing import Process, Queue
7
- from typing import List, Optional, Type, Union
8
-
9
- from filelock import FileLock
10
- from timeout_decorator import timeout as tm
11
-
12
- from ..schema import ActionReturn, ActionStatusCode
13
- from .base_action import BaseAction
14
- from .parser import BaseParser, JsonParser
15
-
16
-
17
- class IPythonProcess(Process):
18
-
19
- def __init__(self,
20
- in_q: Queue,
21
- out_q: Queue,
22
- timeout: int = 20,
23
- ci_lock: str = None,
24
- daemon: bool = True):
25
- super().__init__(daemon=daemon)
26
- self.in_q = in_q
27
- self.out_q = out_q
28
- self.timeout = timeout
29
- self.session_id2shell = defaultdict(self.create_shell)
30
- self.ci_lock = FileLock(
31
- ci_lock) if ci_lock else nullcontext() # avoid core corruption
32
- self._highlighting = re.compile(r'\x1b\[\d{,3}(;\d{,3}){,3}m')
33
-
34
- def run(self):
35
- while True:
36
- msg = self.in_q.get()
37
- if msg == 'reset':
38
- for session_id, shell in self.session_id2shell.items():
39
- with self.ci_lock:
40
- try:
41
- shell.reset(new_session=False)
42
- # shell.run_line_magic('reset', '-sf')
43
- except Exception:
44
- self.session_id2shell[
45
- session_id] = self.create_shell()
46
- self.out_q.put('ok')
47
- elif isinstance(msg, tuple) and len(msg) == 3:
48
- i, session_id, code = msg
49
- res = self.exec(session_id, code)
50
- self.out_q.put((i, session_id, res))
51
-
52
- def exec(self, session_id, code):
53
- try:
54
- shell = self.session_id2shell[session_id]
55
- with StringIO() as io:
56
- old_stdout = sys.stdout
57
- sys.stdout = io
58
- if self.timeout is False or self.timeout < 0:
59
- shell.run_cell(self.extract_code(code))
60
- else:
61
- tm(self.timeout)(shell.run_cell)(self.extract_code(code))
62
- sys.stdout = old_stdout
63
- output = self._highlighting.sub('', io.getvalue().strip())
64
- output = re.sub(r'^Out\[\d+\]: ', '', output)
65
- if 'Error' in output or 'Traceback' in output:
66
- output = output.lstrip('-').strip()
67
- if output.startswith('TimeoutError'):
68
- output = 'The code interpreter encountered a timeout error.'
69
- return {'status': 'FAILURE', 'msg': output, 'code': code}
70
- return {'status': 'SUCCESS', 'value': output, 'code': code}
71
- except Exception as e:
72
- return {'status': 'FAILURE', 'msg': str(e), 'code': code}
73
-
74
- @staticmethod
75
- def create_shell(enable_history: bool = False, in_memory: bool = True):
76
- from IPython import InteractiveShell
77
- from traitlets.config import Config
78
-
79
- c = Config()
80
- c.HistoryManager.enabled = enable_history
81
- if in_memory:
82
- c.HistoryManager.hist_file = ':memory:'
83
- shell = InteractiveShell(config=c)
84
- return shell
85
-
86
- @staticmethod
87
- def extract_code(text: str) -> str:
88
- """Extract Python code from markup languages.
89
-
90
- Args:
91
- text (:class:`str`): Markdown-formatted text
92
-
93
- Returns:
94
- :class:`str`: Python code
95
- """
96
- import json5
97
-
98
- # Match triple backtick blocks first
99
- triple_match = re.search(r'```[^\n]*\n(.+?)```', text, re.DOTALL)
100
- # Match single backtick blocks second
101
- single_match = re.search(r'`([^`]*)`', text, re.DOTALL)
102
- if triple_match:
103
- text = triple_match.group(1)
104
- elif single_match:
105
- text = single_match.group(1)
106
- else:
107
- try:
108
- text = json5.loads(text)['code']
109
- except Exception:
110
- pass
111
- # If no code blocks found, return original text
112
- return text
113
-
114
-
115
- class IPythonInteractiveManager(BaseAction):
116
- """An interactive IPython shell manager for code execution"""
117
-
118
- def __init__(
119
- self,
120
- max_workers: int = 50,
121
- timeout: int = 20,
122
- ci_lock: str = None,
123
- description: Optional[dict] = None,
124
- parser: Type[BaseParser] = JsonParser,
125
- ):
126
- super().__init__(description, parser)
127
- self.max_workers = max_workers
128
- self.timeout = timeout
129
- self.ci_lock = ci_lock
130
- self.id2queue = defaultdict(Queue)
131
- self.id2process = {}
132
- self.out_queue = Queue()
133
-
134
- def __call__(self,
135
- commands: Union[str, List[str]],
136
- session_ids: Union[int, List[int]] = None):
137
- if isinstance(commands, list):
138
- batch_size = len(commands)
139
- is_batch = True
140
- else:
141
- batch_size = 1
142
- commands = [commands]
143
- is_batch = False
144
- if session_ids is None:
145
- session_ids = range(batch_size)
146
- elif isinstance(session_ids, int):
147
- session_ids = [session_ids]
148
- if len(session_ids) != batch_size or len(session_ids) != len(
149
- set(session_ids)):
150
- raise ValueError(
151
- 'the size of `session_ids` must equal that of `commands`')
152
- try:
153
- exec_results = self.run_code_blocks([
154
- (session_id, command)
155
- for session_id, command in zip(session_ids, commands)
156
- ])
157
- except KeyboardInterrupt:
158
- self.clear()
159
- exit(1)
160
- action_returns = []
161
- for result, code in zip(exec_results, commands):
162
- action_return = ActionReturn({'command': code}, type=self.name)
163
- if result['status'] == 'SUCCESS':
164
- action_return.result = [
165
- dict(type='text', content=result['value'])
166
- ]
167
- action_return.state = ActionStatusCode.SUCCESS
168
- else:
169
- action_return.errmsg = result['msg']
170
- action_return.state = ActionStatusCode.API_ERROR
171
- action_returns.append(action_return)
172
- if not is_batch:
173
- return action_returns[0]
174
- return action_returns
175
-
176
- def process_code(self, index, session_id, code):
177
- ipy_id = session_id % self.max_workers
178
- input_queue = self.id2queue[ipy_id]
179
- proc = self.id2process.setdefault(
180
- ipy_id,
181
- IPythonProcess(
182
- input_queue,
183
- self.out_queue,
184
- self.timeout,
185
- self.ci_lock,
186
- daemon=True))
187
- if not proc.is_alive():
188
- proc.start()
189
- input_queue.put((index, session_id, code))
190
-
191
- def run_code_blocks(self, session_code_pairs):
192
- size = len(session_code_pairs)
193
- for index, (session_id, code) in enumerate(session_code_pairs):
194
- self.process_code(index, session_id, code)
195
- results = []
196
- while len(results) < size:
197
- msg = self.out_queue.get()
198
- if isinstance(msg, tuple) and len(msg) == 3:
199
- index, _, result = msg
200
- results.append((index, result))
201
- results.sort()
202
- return [item[1] for item in results]
203
-
204
- def clear(self):
205
- self.id2queue.clear()
206
- for proc in self.id2process.values():
207
- proc.terminate()
208
- self.id2process.clear()
209
- while not self.out_queue.empty():
210
- self.out_queue.get()
211
-
212
- def reset(self):
213
- cnt = 0
214
- for q in self.id2queue.values():
215
- q.put('reset')
216
- cnt += 1
217
- while cnt > 0:
218
- msg = self.out_queue.get()
219
- if msg == 'ok':
220
- cnt -= 1
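A minimal usage sketch of the multi-process shell manager deleted above, assuming its dependencies (IPython, json5, filelock, timeout_decorator) are installed and the module remains importable from the upstream lagent package; the snippets and session ids are illustrative.

# Hedged sketch: one worker process per slot, one IPython shell per session id.
from lagent.actions import IPythonInteractiveManager

if __name__ == '__main__':   # guard needed for multiprocessing start methods that spawn
    manager = IPythonInteractiveManager(max_workers=4, timeout=10)
    returns = manager(
        ['x = 1\nprint(x)', 'print(undefined_name)'],   # one code snippet per session
        session_ids=[0, 1],
    )
    for ret in returns:
        print(ret.state, ret.result or ret.errmsg)
    manager.reset()   # reset every live shell
    manager.clear()   # terminate the worker processes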
 
actions/parser.py DELETED
@@ -1,146 +0,0 @@
1
- import json
2
- import re
3
- from ast import literal_eval
4
- from typing import Any, List, Union
5
-
6
-
7
- class ParseError(Exception):
8
- """Parsing exception class."""
9
-
10
- def __init__(self, err_msg: str):
11
- self.err_msg = err_msg
12
-
13
-
14
- class BaseParser:
15
- """Base parser to process inputs and outputs of actions.
16
-
17
- Args:
18
- action (:class:`BaseAction`): action to validate
19
-
20
- Attributes:
21
- PARAMETER_DESCRIPTION (:class:`str`): declare the input format which
22
- LLMs should follow when generating arguments for decided tools.
23
- """
24
-
25
- PARAMETER_DESCRIPTION: str = ''
26
-
27
- def __init__(self, action):
28
- self.action = action
29
- self._api2param = {}
30
- self._api2required = {}
31
- # perform basic argument validation
32
- if action.description:
33
- for api in action.description.get('api_list',
34
- [action.description]):
35
- name = (f'{action.name}.{api["name"]}'
36
- if self.action.is_toolkit else api['name'])
37
- required_parameters = set(api['required'])
38
- all_parameters = {j['name'] for j in api['parameters']}
39
- if not required_parameters.issubset(all_parameters):
40
- raise ValueError(
41
- f'unknown parameters for function "{name}": '
42
- f'{required_parameters - all_parameters}')
43
- if self.PARAMETER_DESCRIPTION:
44
- api['parameter_description'] = self.PARAMETER_DESCRIPTION
45
- api_name = api['name'] if self.action.is_toolkit else 'run'
46
- self._api2param[api_name] = api['parameters']
47
- self._api2required[api_name] = api['required']
48
-
49
- def parse_inputs(self, inputs: str, name: str = 'run') -> dict:
50
- """Parse inputs LLMs generate for the action.
51
-
52
- Args:
53
- inputs (:class:`str`): input string extracted from responses
54
-
55
- Returns:
56
- :class:`dict`: processed input
57
- """
58
- inputs = {self._api2param[name][0]['name']: inputs}
59
- return inputs
60
-
61
- def parse_outputs(self, outputs: Any) -> List[dict]:
62
- """Parser outputs returned by the action.
63
-
64
- Args:
65
- outputs (:class:`Any`): raw output of the action
66
-
67
- Returns:
68
- :class:`List[dict]`: processed output of which each member is a
69
- dictionary with two keys - 'type' and 'content'.
70
- """
71
- if isinstance(outputs, dict):
72
- outputs = json.dumps(outputs, ensure_ascii=False)
73
- elif not isinstance(outputs, str):
74
- outputs = str(outputs)
75
- return [{
76
- 'type': 'text',
77
- 'content': outputs.encode('gbk', 'ignore').decode('gbk')
78
- }]
79
-
80
-
81
- class JsonParser(BaseParser):
82
- """Json parser to convert input string into a dictionary.
83
-
84
- Args:
85
- action (:class:`BaseAction`): action to validate
86
- """
87
-
88
- PARAMETER_DESCRIPTION = (
89
- 'If you call this tool, you must pass arguments in '
90
- 'the JSON format {key: value}, where the key is the parameter name.')
91
-
92
- def parse_inputs(self,
93
- inputs: Union[str, dict],
94
- name: str = 'run') -> dict:
95
- if not isinstance(inputs, dict):
96
- try:
97
- match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs,
98
- re.S)
99
- if match:
100
- inputs = match.group(2).strip()
101
- inputs = json.loads(inputs)
102
- except json.JSONDecodeError as exc:
103
- raise ParseError(f'invalid json format: {inputs}') from exc
104
- input_keys = set(inputs)
105
- all_keys = {param['name'] for param in self._api2param[name]}
106
- if not input_keys.issubset(all_keys):
107
- raise ParseError(f'unknown arguments: {input_keys - all_keys}')
108
- required_keys = set(self._api2required[name])
109
- if not input_keys.issuperset(required_keys):
110
- raise ParseError(
111
- f'missing required arguments: {required_keys - input_keys}')
112
- return inputs
113
-
114
-
115
- class TupleParser(BaseParser):
116
- """Tuple parser to convert input string into a tuple.
117
-
118
- Args:
119
- action (:class:`BaseAction`): action to validate
120
- """
121
-
122
- PARAMETER_DESCRIPTION = (
123
- 'If you call this tool, you must pass arguments in the tuple format '
124
- 'like (arg1, arg2, arg3), and the arguments are ordered.')
125
-
126
- def parse_inputs(self,
127
- inputs: Union[str, tuple],
128
- name: str = 'run') -> dict:
129
- if not isinstance(inputs, tuple):
130
- try:
131
- inputs = literal_eval(inputs)
132
- except Exception as exc:
133
- raise ParseError(f'invalid tuple format: {inputs}') from exc
134
- if len(inputs) < len(self._api2required[name]):
135
- raise ParseError(
136
- f'API takes {len(self._api2required[name])} required positional '
137
- f'arguments but {len(inputs)} were given')
138
- if len(inputs) > len(self._api2param[name]):
139
- raise ParseError(
140
- f'API takes {len(self._api2param[name])} positional arguments '
141
- f'but {len(inputs)} were given')
142
- inputs = {
143
- self._api2param[name][i]['name']: item
144
- for i, item in enumerate(inputs)
145
- }
146
- return inputs
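
The two parsers above are the only validation layer between raw LLM output and a tool call. As a reference, here is a minimal, self-contained sketch of the JsonParser.parse_inputs logic that was deleted; `api_spec`, `parse_json_inputs`, and the local `ParseError` are illustrative stand-ins rather than the lagent API.

```python
# Self-contained sketch of the JsonParser.parse_inputs flow shown above.
# `api_spec` stands in for the parameter schema BaseParser builds from an
# action's description; it is illustrative, not the lagent API.
import json
import re


class ParseError(ValueError):
    """Raised when LLM-generated arguments fail validation."""


def parse_json_inputs(inputs: str, api_spec: dict) -> dict:
    # Strip an optional markdown json fence around the arguments.
    match = re.search(r'^\s*(```json\n)?(.*)\n```\s*$', inputs, re.S)
    if match:
        inputs = match.group(2).strip()
    try:
        args = json.loads(inputs)
    except json.JSONDecodeError as exc:
        raise ParseError(f'invalid json format: {inputs}') from exc
    all_keys = {p['name'] for p in api_spec['parameters']}
    required = set(api_spec['required'])
    if not set(args).issubset(all_keys):
        raise ParseError(f'unknown arguments: {set(args) - all_keys}')
    if not set(args).issuperset(required):
        raise ParseError(f'missing required arguments: {required - set(args)}')
    return args


spec = {'parameters': [{'name': 'location'}], 'required': ['location']}
print(parse_json_inputs('{"location": "Beijing"}', spec))  # {'location': 'Beijing'}
```
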
actions/ppt.py DELETED
@@ -1,233 +0,0 @@
1
- from typing import Dict, Optional, Type
2
-
3
- from asyncer import asyncify
4
-
5
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
6
- from lagent.actions.parser import BaseParser, JsonParser
7
-
8
- THEME_MAPPING = {
9
- 'Default': {
10
- 'template': None,
11
- 'title': 'Title Slide',
12
- 'single': 'Title and Content',
13
- 'two': 'Two Content',
14
- }
15
- }
16
-
17
-
18
- class PPT(BaseAction):
19
- """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""
20
-
21
- def __init__(
22
- self,
23
- theme_mapping: Optional[Dict[str, dict]] = None,
24
- description: Optional[dict] = None,
25
- parser: Type[BaseParser] = JsonParser,
26
- ):
27
- super().__init__(description, parser)
28
- self.theme_mapping = theme_mapping or THEME_MAPPING
29
- self.pointer = None
30
- self.location = None
31
-
32
- @tool_api(explode_return=True)
33
- def create_file(self, theme: str, abs_location: str) -> dict:
34
- """Create a pptx file with specific themes.
35
-
36
- Args:
37
- theme (:class:`str`): the theme used. The value should be one of ['Default'].
38
- abs_location (:class:`str`): the ppt file's absolute location
39
-
40
- Returns:
41
- :class:`dict`: operation status
42
- * status: the result of the execution
43
- """
44
- from pptx import Presentation
45
-
46
- self.location = abs_location
47
- try:
48
- self.pointer = Presentation(self.theme_mapping[theme]['template'])
49
- self.pointer.slide_master.name = theme
50
- # print('created')
51
- except Exception as e:
52
- print(e)
53
- return dict(status='created a ppt file.')
54
-
55
- @tool_api(explode_return=True)
56
- def add_first_page(self, title: str, subtitle: str) -> dict:
57
- """Add the first page of ppt.
58
-
59
- Args:
60
- title (:class:`str`): the title of ppt
61
- subtitle (:class:`str`): the subtitle of ppt
62
-
63
- Returns:
64
- :class:`dict`: operation status
65
- * status: the result of the execution
66
- """
67
- layout_name = self.theme_mapping[self.pointer.slide_master.name]['title']
68
- layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
69
- slide = self.pointer.slides.add_slide(layout)
70
- ph_title, ph_subtitle = slide.placeholders
71
- ph_title.text = title
72
- if subtitle:
73
- ph_subtitle.text = subtitle
74
- return dict(status='added page')
75
-
76
- @tool_api(explode_return=True)
77
- def add_text_page(self, title: str, bullet_items: str) -> dict:
78
- """Add text page of ppt.
79
-
80
- Args:
81
- title (:class:`str`): the title of the page
82
- bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
83
-
84
- Returns:
85
- :class:`dict`: operation status
86
- * status: the result of the execution
87
- """ # noqa: E501
88
- layout_name = self.theme_mapping[self.pointer.slide_master.name]['single']
89
- layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
90
- slide = self.pointer.slides.add_slide(layout)
91
- ph_title, ph_body = slide.placeholders
92
- ph_title.text = title
93
- ph = ph_body
94
- tf = ph.text_frame
95
- for i, item in enumerate(bullet_items.split('[SPAN]')):
96
- if i == 0:
97
- p = tf.paragraphs[0]
98
- else:
99
- p = tf.add_paragraph()
100
- p.text = item.strip()
101
- p.level = 0
102
- return dict(status='added page')
103
-
104
- @tool_api(explode_return=True)
105
- def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
106
- """Add a text page with one image. Image should be a path.
107
-
108
- Args:
109
- title (:class:`str`): the title of the page
110
- bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
111
- image (:class:`str`): the path of the image
112
-
113
- Returns:
114
- :class:`dict`: operation status
115
- * status: the result of the execution
116
- """ # noqa: E501
117
- from PIL import Image
118
-
119
- layout_name = self.theme_mapping[self.pointer.slide_master.name]['two']
120
- layout = next(i for i in self.pointer.slide_master.slide_layouts if i.name == layout_name)
121
- slide = self.pointer.slides.add_slide(layout)
122
- ph_title, ph_body1, ph_body2 = slide.placeholders
123
- ph_title.text = title
124
- ph = ph_body2
125
- image = Image.open(image)
126
- image_pil = image.to_pil()
127
- left = ph.left
128
- width = ph.width
129
- height = int(width / image_pil.width * image_pil.height)
130
- top = (ph.top + (ph.top + ph.height)) // 2 - height // 2
131
- slide.shapes.add_picture(image.to_path(), left, top, width, height)
132
-
133
- ph = ph_body1
134
- tf = ph.text_frame
135
- for i, item in enumerate(bullet_items.split('[SPAN]')):
136
- if i == 0:
137
- p = tf.paragraphs[0]
138
- else:
139
- p = tf.add_paragraph()
140
- p.text = item.strip()
141
- p.level = 0
142
-
143
- return dict(status='added page')
144
-
145
- @tool_api(explode_return=True)
146
- def submit_file(self) -> dict:
147
- """When all steps done, YOU MUST use submit_file() to submit your work.
148
-
149
- Returns:
150
- :class:`dict`: operation status
151
- * status: the result of the execution
152
- """
153
- # file_path = os.path.join(self.CACHE_DIR, f'{self._return_timestamp()}.pptx')
154
- # self.pointer.save(file_path)
155
- # retreival_url = upload_file(file_path)
156
- self.pointer.save(self.location)
157
- return dict(status=f'submitted. view ppt at {self.location}')
158
-
159
-
160
- class AsyncPPT(AsyncActionMixin, PPT):
161
- """Plugin to create PPT slides with text, paragraphs, and images in good-looking styles."""
162
-
163
- @tool_api(explode_return=True)
164
- @asyncify
165
- def create_file(self, theme: str, abs_location: str) -> dict:
166
- """Create a pptx file with specific themes.
167
-
168
- Args:
169
- theme (:class:`str`): the theme used. The value should be one of ['Default'].
170
- abs_location (:class:`str`): the ppt file's absolute location
171
-
172
- Returns:
173
- :class:`dict`: operation status
174
- * status: the result of the execution
175
- """
176
- return super().create_file(theme, abs_location)
177
-
178
- @tool_api(explode_return=True)
179
- @asyncify
180
- def add_first_page(self, title: str, subtitle: str) -> dict:
181
- """Add the first page of ppt.
182
-
183
- Args:
184
- title (:class:`str`): the title of ppt
185
- subtitle (:class:`str`): the subtitle of ppt
186
-
187
- Returns:
188
- :class:`dict`: operation status
189
- * status: the result of the execution
190
- """
191
- return super().add_first_page(title, subtitle)
192
-
193
- @tool_api(explode_return=True)
194
- @asyncify
195
- def add_text_page(self, title: str, bullet_items: str) -> dict:
196
- """Add text page of ppt.
197
-
198
- Args:
199
- title (:class:`str`): the title of the page
200
- bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
201
-
202
- Returns:
203
- :class:`dict`: operation status
204
- * status: the result of the execution
205
- """ # noqa: E501
206
- return super().add_text_page(title, bullet_items)
207
-
208
- @tool_api(explode_return=True)
209
- @asyncify
210
- def add_text_image_page(self, title: str, bullet_items: str, image: str) -> dict:
211
- """Add a text page with one image. Image should be a path.
212
-
213
- Args:
214
- title (:class:`str`): the title of the page
215
- bullet_items (:class:`str`): bullet_items should be string, for multiple bullet items, please use [SPAN] to separate them.
216
- image (:class:`str`): the path of the image
217
-
218
- Returns:
219
- :class:`dict`: operation status
220
- * status: the result of the execution
221
- """ # noqa: E501
222
- return super().add_text_image_page(title, bullet_items, image)
223
-
224
- @tool_api(explode_return=True)
225
- @asyncify
226
- def submit_file(self) -> dict:
227
- """When all steps done, YOU MUST use submit_file() to submit your work.
228
-
229
- Returns:
230
- :class:`dict`: operation status
231
- * status: the result of the execution
232
- """
233
- return super().submit_file()
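
The deleted PPT action is a thin wrapper over python-pptx placeholder layouts. Below is a hedged sketch of the same slide-building steps performed directly with python-pptx, assuming the library's stock template; the titles, bullet text, and output filename are made-up examples.

```python
# Minimal python-pptx sketch of what the deleted PPT action does under the
# hood. Layout names ('Title Slide', 'Title and Content') mirror the
# THEME_MAPPING above.
from pptx import Presentation

prs = Presentation()  # THEME_MAPPING['Default']['template'] is None, i.e. the stock template


def layout_by_name(name):
    return next(l for l in prs.slide_master.slide_layouts if l.name == name)


# First page: title + subtitle placeholders, as in add_first_page().
slide = prs.slides.add_slide(layout_by_name('Title Slide'))
ph_title, ph_subtitle = slide.placeholders
ph_title.text = 'Quarterly Report'
ph_subtitle.text = 'Generated by an agent'

# Content page: bullet items separated by [SPAN], as in add_text_page().
slide = prs.slides.add_slide(layout_by_name('Title and Content'))
ph_title, ph_body = slide.placeholders
ph_title.text = 'Highlights'
tf = ph_body.text_frame
for i, item in enumerate('Revenue up[SPAN]Costs down'.split('[SPAN]')):
    p = tf.paragraphs[0] if i == 0 else tf.add_paragraph()
    p.text = item.strip()
    p.level = 0

prs.save('report.pptx')  # submit_file() ends with the same save call
```
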
actions/python_interpreter.py DELETED
@@ -1,176 +0,0 @@
1
- # flake8: noqa: E501
2
- import copy
3
- import io
4
- from contextlib import redirect_stdout
5
- from typing import Any, Optional, Type
6
-
7
- from asyncer import asyncify
8
-
9
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
10
- from lagent.actions.parser import BaseParser, JsonParser
11
- from lagent.schema import ActionReturn, ActionStatusCode
12
-
13
-
14
- class GenericRuntime:
15
- GLOBAL_DICT = {}
16
- LOCAL_DICT = None
17
- HEADERS = []
18
-
19
- def __init__(self):
20
- self._global_vars = copy.copy(self.GLOBAL_DICT)
21
- self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None
22
-
23
- for c in self.HEADERS:
24
- self.exec_code(c)
25
-
26
- def exec_code(self, code_piece: str) -> None:
27
- exec(code_piece, self._global_vars)
28
-
29
- def eval_code(self, expr: str) -> Any:
30
- return eval(expr, self._global_vars)
31
-
32
-
33
- class PythonInterpreter(BaseAction):
34
- """A Python executor that can execute Python scripts.
35
-
36
- Args:
37
- answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
38
- answer_expr (str, Optional): the answer function name of the Python
39
- script. Defaults to ``'solution()'``.
40
- answer_from_stdout (boolean, Optional): whether the execution results is from
41
- stdout. Defaults to ``False``.
42
- timeout (int, Optional): Upper bound of waiting time for Python script execution.
43
- Defaults to ``20``.
44
- description (dict, Optional): The description of the action. Defaults to ``None``.
45
- parser (Type[BaseParser]): The parser class to process the
46
- action's inputs and outputs. Defaults to :class:`JsonParser`.
47
- """
48
-
49
- def __init__(
50
- self,
51
- answer_symbol: Optional[str] = None,
52
- answer_expr: Optional[str] = 'solution()',
53
- answer_from_stdout: bool = False,
54
- timeout: int = 20,
55
- description: Optional[dict] = None,
56
- parser: Type[BaseParser] = JsonParser,
57
- ) -> None:
58
- super().__init__(description, parser)
59
- self.answer_symbol = answer_symbol
60
- self.answer_expr = answer_expr
61
- self.answer_from_stdout = answer_from_stdout
62
- self.timeout = timeout
63
-
64
- @tool_api
65
- def run(self, command: str) -> ActionReturn:
66
- """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
67
-
68
- ```python
69
- # import 依赖包
70
- import xxx
71
- def solution():
72
- # 初始化一些变量
73
- variable_names_with_real_meaning = xxx
74
- # 步骤一
75
- mid_variable = func(variable_names_with_real_meaning)
76
- # 步骤 x
77
- mid_variable = func(mid_variable)
78
- # 最后结果
79
- final_answer = func(mid_variable)
80
- return final_answer
81
- ```
82
-
83
- Args:
84
- command (:class:`str`): Python code snippet
85
- """
86
- from func_timeout import FunctionTimedOut, func_set_timeout
87
-
88
- self.runtime = GenericRuntime()
89
- try:
90
- tool_return = func_set_timeout(self.timeout)(self._call)(command)
91
- except FunctionTimedOut as e:
92
- tool_return = ActionReturn(type=self.name)
93
- tool_return.errmsg = repr(e)
94
- tool_return.state = ActionStatusCode.API_ERROR
95
- return tool_return
96
-
97
- def _call(self, command: str) -> ActionReturn:
98
- tool_return = ActionReturn(type=self.name)
99
- try:
100
- if '```python' in command:
101
- command = command.split('```python')[1].split('```')[0]
102
- elif '```' in command:
103
- command = command.split('```')[1].split('```')[0]
104
- tool_return.args = dict(text='```python\n' + command + '\n```')
105
- command = command.split('\n')
106
-
107
- if self.answer_from_stdout:
108
- program_io = io.StringIO()
109
- with redirect_stdout(program_io):
110
- self.runtime.exec_code('\n'.join(command))
111
- program_io.seek(0)
112
- res = program_io.readlines()[-1]
113
- elif self.answer_symbol:
114
- self.runtime.exec_code('\n'.join(command))
115
- res = self.runtime._global_vars[self.answer_symbol]
116
- elif self.answer_expr:
117
- self.runtime.exec_code('\n'.join(command))
118
- res = self.runtime.eval_code(self.answer_expr)
119
- else:
120
- self.runtime.exec_code('\n'.join(command[:-1]))
121
- res = self.runtime.eval_code(command[-1])
122
- except Exception as e:
123
- tool_return.errmsg = repr(e)
124
- tool_return.type = self.name
125
- tool_return.state = ActionStatusCode.API_ERROR
126
- return tool_return
127
- try:
128
- tool_return.result = [dict(type='text', content=str(res))]
129
- tool_return.state = ActionStatusCode.SUCCESS
130
- except Exception as e:
131
- tool_return.errmsg = repr(e)
132
- tool_return.type = self.name
133
- tool_return.state = ActionStatusCode.API_ERROR
134
- return tool_return
135
-
136
-
137
- class AsyncPythonInterpreter(AsyncActionMixin, PythonInterpreter):
138
- """A Python executor that can execute Python scripts.
139
-
140
- Args:
141
- answer_symbol (str, Optional): the answer symbol from LLM. Defaults to ``None``.
142
- answer_expr (str, Optional): the answer function name of the Python
143
- script. Defaults to ``'solution()'``.
144
- answer_from_stdout (boolean, Optional): whether the execution results is from
145
- stdout. Defaults to ``False``.
146
- timeout (int, Optional): Upper bound of waiting time for Python script execution.
147
- Defaults to ``20``.
148
- description (dict, Optional): The description of the action. Defaults to ``None``.
149
- parser (Type[BaseParser]): The parser class to process the
150
- action's inputs and outputs. Defaults to :class:`JsonParser`.
151
- """
152
-
153
- @tool_api
154
- @asyncify
155
- def run(self, command: str) -> ActionReturn:
156
- """用来执行Python代码。代码必须是一个函数,函数名必须得是 'solution',代码对应你的思考过程。代码实例格式如下:
157
-
158
- ```python
159
- # import 依赖包
160
- import xxx
161
- def solution():
162
- # 初始化一些变量
163
- variable_names_with_real_meaning = xxx
164
- # 步骤一
165
- mid_variable = func(variable_names_with_real_meaning)
166
- # 步骤 x
167
- mid_variable = func(mid_variable)
168
- # 最后结果
169
- final_answer = func(mid_variable)
170
- return final_answer
171
- ```
172
-
173
- Args:
174
- command (:class:`str`): Python code snippet
175
- """
176
- return super().run(command)
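
The interpreter's core loop is: exec the snippet into a scratch namespace, eval the answer expression (by default 'solution()'), and bound the whole thing with func_timeout. Below is a stripped-down sketch of that flow, assuming the func_timeout package is installed; the sample snippet is illustrative.

```python
# Stripped-down sketch of the execution flow above: exec the snippet in a
# scratch namespace, then eval the answer expression, guarded by func_timeout
# exactly as PythonInterpreter.run does.
from func_timeout import FunctionTimedOut, func_set_timeout

SNIPPET = """
def solution():
    radius = 3
    area = 3.14159 * radius ** 2
    return area
"""


def run_snippet(code: str, answer_expr: str = 'solution()', timeout: int = 20):
    scope = {}

    def _call():
        exec(code, scope)                 # GenericRuntime.exec_code
        return eval(answer_expr, scope)   # GenericRuntime.eval_code

    try:
        return func_set_timeout(timeout)(_call)()
    except FunctionTimedOut as exc:
        return f'timed out: {exc!r}'


print(run_snippet(SNIPPET))  # -> 28.27431
```
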
actions/weather_query.py DELETED
@@ -1,71 +0,0 @@
1
- import os
2
- import requests
3
- from lagent.actions.base_action import BaseAction, tool_api
4
- from lagent.schema import ActionReturn, ActionStatusCode
5
-
6
- class WeatherQuery(BaseAction):
7
- def __init__(self):
8
- super().__init__()
9
- self.api_key = os.getenv("weather_token")
10
- # avoid echoing the API key to stdout; it is a secret
11
- if not self.api_key:
12
- raise EnvironmentError("Environment variable 'weather_token' not found. Please set your QWeather API key in the 'weather_token' environment variable, e.g. export weather_token='xxx'")
13
-
14
- @tool_api
15
- def run(self, location: str) -> dict:
16
- """
17
- Query real-time weather information.
18
-
19
- Args:
20
- location (str): the place name, LocationID, or longitude/latitude coordinates to query (e.g. "101010100" or "116.41,39.92").
21
-
22
- Returns:
23
- dict: a dictionary containing the weather information
24
- * location: place name
25
- * weather: weather condition
26
- * temperature: current temperature
27
- * wind_direction: wind direction
28
- * wind_speed: wind speed (km/h)
29
- * humidity: relative humidity (%)
30
- * report_time: time the data was reported
31
- """
32
- try:
33
- # If location is not in coordinate form (e.g. "116.41,39.92"), resolve it to a LocationID via the GeoAPI
34
- if not ("," in location and location.replace(",", "").replace(".", "").isdigit()):
35
- # Look up the LocationID with the GeoAPI
36
- geo_url = f"https://geoapi.qweather.com/v2/city/lookup?location={location}&key={self.api_key}"
37
- geo_response = requests.get(geo_url)
38
- geo_data = geo_response.json()
39
-
40
- if geo_data.get("code") != "200" or not geo_data.get("location"):
41
- raise Exception(f"GeoAPI returned error code {geo_data.get('code')} or no matching location was found")
42
-
43
- location = geo_data["location"][0]["id"]
44
-
45
- # Build the request URL for the real-time weather API
46
- weather_url = f"https://devapi.qweather.com/v7/weather/now?location={location}&key={self.api_key}"
47
- response = requests.get(weather_url)
48
- data = response.json()
49
-
50
- # Check the API response code
51
- if data.get("code") != "200":
52
- raise Exception(f"Weather API returned error code {data.get('code')}")
53
-
54
- # Parse and organize the weather information
55
- weather_info = {
56
- "location": location,
57
- "weather": data["now"]["text"],
58
- "temperature": data["now"]["temp"] + "°C",
59
- "wind_direction": data["now"]["windDir"],
60
- "wind_speed": data["now"]["windSpeed"] + " km/h",
61
- "humidity": data["now"]["humidity"] + "%",
62
- "report_time": data["updateTime"]
63
- }
64
-
65
- return {"result": weather_info}
66
-
67
- except Exception as exc:
68
- return ActionReturn(
69
- errmsg=f"WeatherQuery 异常:{exc}",
70
- state=ActionStatusCode.HTTP_ERROR
71
- )
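
WeatherQuery chains two QWeather requests: a GeoAPI lookup that turns a place name into a LocationID, then /v7/weather/now for current conditions. Below is a hedged sketch of those two calls with plain requests, assuming a valid key exported as weather_token; the city name is only an example.

```python
# Hedged sketch of the two QWeather calls WeatherQuery.run chains together,
# using the same endpoints as the deleted code.
import os

import requests

key = os.environ['weather_token']
city = 'Beijing'  # example query

# Step 1: resolve a place name to a LocationID via the GeoAPI.
geo = requests.get(
    'https://geoapi.qweather.com/v2/city/lookup',
    params={'location': city, 'key': key},
).json()
location_id = geo['location'][0]['id']

# Step 2: fetch current conditions for that LocationID.
now = requests.get(
    'https://devapi.qweather.com/v7/weather/now',
    params={'location': location_id, 'key': key},
).json()['now']

print(f"{city}: {now['text']}, {now['temp']}°C, humidity {now['humidity']}%")
```
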
actions/web_browser.py DELETED
@@ -1,908 +0,0 @@
1
- import asyncio
2
- import hashlib
3
- import hmac
4
- import json
5
- import logging
6
- import random
7
- import re
8
- import time
9
- import warnings
10
- from concurrent.futures import ThreadPoolExecutor, as_completed
11
- from datetime import datetime
12
- from http.client import HTTPSConnection
13
- from typing import List, Optional, Tuple, Type, Union
14
-
15
- import aiohttp
16
- import aiohttp.client_exceptions
17
- import requests
18
- from asyncache import cached as acached
19
- from bs4 import BeautifulSoup
20
- from cachetools import TTLCache, cached
21
- from duckduckgo_search import DDGS, AsyncDDGS
22
-
23
- from lagent.actions.base_action import AsyncActionMixin, BaseAction, tool_api
24
- from lagent.actions.parser import BaseParser, JsonParser
25
- from lagent.utils import async_as_completed
26
-
27
-
28
- class BaseSearch:
29
-
30
- def __init__(self, topk: int = 3, black_list: List[str] = None):
31
- self.topk = topk
32
- self.black_list = black_list
33
-
34
- def _filter_results(self, results: List[tuple]) -> dict:
35
- filtered_results = {}
36
- count = 0
37
- for url, snippet, title in results:
38
- if all(domain not in url
39
- for domain in self.black_list) and not url.endswith('.pdf'):
40
- filtered_results[count] = {
41
- 'url': url,
42
- 'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
43
- 'title': title
44
- }
45
- count += 1
46
- if count >= self.topk:
47
- break
48
- return filtered_results
49
-
50
-
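
Every searcher below funnels its raw (url, snippet, title) tuples through BaseSearch._filter_results, which drops blacklisted domains and PDF links and keeps the top-k hits. Here is a self-contained illustration of that filtering; the sample result tuples are made up.

```python
# Self-contained illustration of the _filter_results logic above.
import json


def filter_results(results, topk=3, black_list=('youtube.com', 'bilibili.com')):
    filtered, count = {}, 0
    for url, snippet, title in results:
        if all(domain not in url for domain in black_list) and not url.endswith('.pdf'):
            filtered[count] = {
                'url': url,
                'summ': json.dumps(snippet, ensure_ascii=False)[1:-1],
                'title': title,
            }
            count += 1
            if count >= topk:
                break
    return filtered


samples = [
    ('https://youtube.com/watch?v=1', 'a video', 'Video'),   # dropped: blacklisted
    ('https://example.com/paper.pdf', 'a pdf', 'Paper'),     # dropped: .pdf
    ('https://example.com/post', 'a useful page', 'Post'),   # kept
]
print(filter_results(samples))  # -> {0: {'url': 'https://example.com/post', ...}}
```
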
51
- class DuckDuckGoSearch(BaseSearch):
52
-
53
- def __init__(self,
54
- topk: int = 3,
55
- black_list: List[str] = [
56
- 'enoN',
57
- 'youtube.com',
58
- 'bilibili.com',
59
- 'researchgate.net',
60
- ],
61
- **kwargs):
62
- self.proxy = kwargs.get('proxy')
63
- self.timeout = kwargs.get('timeout', 30)
64
- super().__init__(topk, black_list)
65
-
66
- @cached(cache=TTLCache(maxsize=100, ttl=600))
67
- def search(self, query: str, max_retry: int = 3) -> dict:
68
- for attempt in range(max_retry):
69
- try:
70
- response = self._call_ddgs(
71
- query, timeout=self.timeout, proxy=self.proxy)
72
- return self._parse_response(response)
73
- except Exception as e:
74
- logging.exception(str(e))
75
- warnings.warn(
76
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
77
- time.sleep(random.randint(2, 5))
78
- raise Exception(
79
- 'Failed to get search results from DuckDuckGo after retries.')
80
-
81
- @acached(cache=TTLCache(maxsize=100, ttl=600))
82
- async def asearch(self, query: str, max_retry: int = 3) -> dict:
83
- for attempt in range(max_retry):
84
- try:
85
- ddgs = AsyncDDGS(timeout=self.timeout, proxy=self.proxy)
86
- response = await ddgs.atext(query.strip("'"), max_results=10)
87
- return self._parse_response(response)
88
- except Exception as e:
89
- if isinstance(e, asyncio.TimeoutError):
90
- logging.exception('Request to DDGS timed out.')
91
- logging.exception(str(e))
92
- warnings.warn(
93
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
94
- await asyncio.sleep(random.randint(2, 5))
95
- raise Exception(
96
- 'Failed to get search results from DuckDuckGo after retries.')
97
-
98
- async def _async_call_ddgs(self, query: str, **kwargs) -> dict:
99
- ddgs = DDGS(**kwargs)
100
- try:
101
- response = await asyncio.wait_for(
102
- asyncio.to_thread(ddgs.text, query.strip("'"), max_results=10),
103
- timeout=self.timeout)
104
- return response
105
- except asyncio.TimeoutError:
106
- logging.exception('Request to DDGS timed out.')
107
- raise
108
-
109
- def _call_ddgs(self, query: str, **kwargs) -> dict:
110
- loop = asyncio.new_event_loop()
111
- asyncio.set_event_loop(loop)
112
- try:
113
- response = loop.run_until_complete(
114
- self._async_call_ddgs(query, **kwargs))
115
- return response
116
- finally:
117
- loop.close()
118
-
119
- def _parse_response(self, response: dict) -> dict:
120
- raw_results = []
121
- for item in response:
122
- raw_results.append(
123
- (item['href'], item['description']
124
- if 'description' in item else item['body'], item['title']))
125
- return self._filter_results(raw_results)
126
-
127
-
128
- class BingSearch(BaseSearch):
129
-
130
- def __init__(self,
131
- api_key: str,
132
- region: str = 'zh-CN',
133
- topk: int = 3,
134
- black_list: List[str] = [
135
- 'enoN',
136
- 'youtube.com',
137
- 'bilibili.com',
138
- 'researchgate.net',
139
- ],
140
- **kwargs):
141
- self.api_key = api_key
142
- self.market = region
143
- self.proxy = kwargs.get('proxy')
144
- super().__init__(topk, black_list)
145
-
146
- @cached(cache=TTLCache(maxsize=100, ttl=600))
147
- def search(self, query: str, max_retry: int = 3) -> dict:
148
- for attempt in range(max_retry):
149
- try:
150
- response = self._call_bing_api(query)
151
- return self._parse_response(response)
152
- except Exception as e:
153
- logging.exception(str(e))
154
- warnings.warn(
155
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
156
- time.sleep(random.randint(2, 5))
157
- raise Exception(
158
- 'Failed to get search results from Bing Search after retries.')
159
-
160
- @acached(cache=TTLCache(maxsize=100, ttl=600))
161
- async def asearch(self, query: str, max_retry: int = 3) -> dict:
162
- for attempt in range(max_retry):
163
- try:
164
- response = await self._async_call_bing_api(query)
165
- return self._parse_response(response)
166
- except Exception as e:
167
- logging.exception(str(e))
168
- warnings.warn(
169
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
170
- await asyncio.sleep(random.randint(2, 5))
171
- raise Exception(
172
- 'Failed to get search results from Bing Search after retries.')
173
-
174
- def _call_bing_api(self, query: str) -> dict:
175
- endpoint = 'https://api.bing.microsoft.com/v7.0/search'
176
- params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
177
- headers = {'Ocp-Apim-Subscription-Key': self.api_key}
178
- response = requests.get(
179
- endpoint, headers=headers, params=params, proxies=self.proxy)
180
- response.raise_for_status()
181
- return response.json()
182
-
183
- async def _async_call_bing_api(self, query: str) -> dict:
184
- endpoint = 'https://api.bing.microsoft.com/v7.0/search'
185
- params = {'q': query, 'mkt': self.market, 'count': f'{self.topk * 2}'}
186
- headers = {'Ocp-Apim-Subscription-Key': self.api_key}
187
- async with aiohttp.ClientSession(raise_for_status=True) as session:
188
- async with session.get(
189
- endpoint,
190
- headers=headers,
191
- params=params,
192
- proxy=self.proxy and
193
- (self.proxy.get('http') or self.proxy.get('https'))) as resp:
194
- return await resp.json()
195
-
196
- def _parse_response(self, response: dict) -> dict:
197
- webpages = {
198
- w['id']: w
199
- for w in response.get('webPages', {}).get('value', [])
200
- }
201
- raw_results = []
202
-
203
- for item in response.get('rankingResponse',
204
- {}).get('mainline', {}).get('items', []):
205
- if item['answerType'] == 'WebPages':
206
- webpage = webpages.get(item['value']['id'])
207
- if webpage:
208
- raw_results.append(
209
- (webpage['url'], webpage['snippet'], webpage['name']))
210
- elif item['answerType'] == 'News' and item['value'][
211
- 'id'] == response.get('news', {}).get('id'):
212
- for news in response.get('news', {}).get('value', []):
213
- raw_results.append(
214
- (news['url'], news['description'], news['name']))
215
-
216
- return self._filter_results(raw_results)
217
-
218
-
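
Each backend wraps its API call in the same pattern: a cachetools TTLCache in front of a bounded retry loop with a randomized sleep. Below is a generic sketch of that pattern; call_api is a placeholder for any of the _call_*_api helpers, not a real endpoint.

```python
# Generic sketch of the cached retry pattern shared by the search backends
# above (cachetools TTLCache + bounded retries).
import random
import time
import warnings

from cachetools import TTLCache, cached


def call_api(query: str) -> dict:
    # Placeholder for a real backend request such as _call_bing_api.
    return {'results': [f'result for {query}']}


@cached(cache=TTLCache(maxsize=100, ttl=600))
def search(query: str, max_retry: int = 3) -> dict:
    for attempt in range(max_retry):
        try:
            return call_api(query)
        except Exception as e:
            warnings.warn(f'Retry {attempt + 1}/{max_retry} due to error: {e}')
            time.sleep(random.randint(2, 5))
    raise Exception('Failed to get search results after retries.')


print(search('open source agents'))  # first call hits the backend
print(search('open source agents'))  # second call comes from the TTL cache
```

The cache keeps identical queries from re-hitting a rate-limited API within ten minutes, while the retry loop absorbs transient failures.
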
219
- class BraveSearch(BaseSearch):
220
- """
221
- Wrapper around the Brave Search API.
222
-
223
- To use, you should pass your Brave Search API key to the constructor.
224
-
225
- Args:
226
- api_key (str): API KEY to use Brave Search API.
227
- You can create a free API key at https://api.search.brave.com/app/keys.
228
- search_type (str): Brave Search API supports ['web', 'news', 'images', 'videos'],
229
- currently only supports 'news' and 'web'.
230
- topk (int): The number of search results returned in response from API search results.
231
- region (str): The country code string. Specifies the country where the search results come from.
232
- language (str): The language code string. Specifies the preferred language for the search results.
233
- extra_snippets (bool): Allows retrieving up to 5 additional snippets, which are alternative excerpts from the search results.
234
- **kwargs: Any other parameters related to the Brave Search API. Find more details at
235
- https://api.search.brave.com/app/documentation/web-search/get-started.
236
- """
237
-
238
- def __init__(self,
239
- api_key: str,
240
- region: str = 'ALL',
241
- language: str = 'zh-hans',
242
- extra_snippets: bool = True,
243
- topk: int = 3,
244
- black_list: List[str] = [
245
- 'enoN',
246
- 'youtube.com',
247
- 'bilibili.com',
248
- 'researchgate.net',
249
- ],
250
- **kwargs):
251
- self.api_key = api_key
252
- self.market = region
253
- self.proxy = kwargs.get('proxy')
254
- self.language = language
255
- self.extra_snippets = extra_snippets
256
- self.search_type = kwargs.get('search_type', 'web')
257
- self.kwargs = kwargs
258
- super().__init__(topk, black_list)
259
-
260
- @cached(cache=TTLCache(maxsize=100, ttl=600))
261
- def search(self, query: str, max_retry: int = 3) -> dict:
262
- for attempt in range(max_retry):
263
- try:
264
- response = self._call_brave_api(query)
265
- return self._parse_response(response)
266
- except Exception as e:
267
- logging.exception(str(e))
268
- warnings.warn(
269
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
270
- time.sleep(random.randint(2, 5))
271
- raise Exception(
272
- 'Failed to get search results from Brave Search after retries.')
273
-
274
- @acached(cache=TTLCache(maxsize=100, ttl=600))
275
- async def asearch(self, query: str, max_retry: int = 3) -> dict:
276
- for attempt in range(max_retry):
277
- try:
278
- response = await self._async_call_brave_api(query)
279
- return self._parse_response(response)
280
- except Exception as e:
281
- logging.exception(str(e))
282
- warnings.warn(
283
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
284
- await asyncio.sleep(random.randint(2, 5))
285
- raise Exception(
286
- 'Failed to get search results from Brave Search after retries.')
287
-
288
- def _call_brave_api(self, query: str) -> dict:
289
- endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
290
- params = {
291
- 'q': query,
292
- 'country': self.market,
293
- 'search_lang': self.language,
294
- 'extra_snippets': self.extra_snippets,
295
- 'count': self.topk,
296
- **{
297
- key: value
298
- for key, value in self.kwargs.items() if value is not None
299
- },
300
- }
301
- headers = {
302
- 'X-Subscription-Token': self.api_key or '',
303
- 'Accept': 'application/json'
304
- }
305
- response = requests.get(
306
- endpoint, headers=headers, params=params, proxies=self.proxy)
307
- response.raise_for_status()
308
- return response.json()
309
-
310
- async def _async_call_brave_api(self, query: str) -> dict:
311
- endpoint = f'https://api.search.brave.com/res/v1/{self.search_type}/search'
312
- params = {
313
- 'q': query,
314
- 'country': self.market,
315
- 'search_lang': self.language,
316
- 'extra_snippets': self.extra_snippets,
317
- 'count': self.topk,
318
- **{
319
- key: value
320
- for key, value in self.kwargs.items() if value is not None
321
- },
322
- }
323
- headers = {
324
- 'X-Subscription-Token': self.api_key or '',
325
- 'Accept': 'application/json'
326
- }
327
- async with aiohttp.ClientSession(raise_for_status=True) as session:
328
- async with session.get(
329
- endpoint,
330
- headers=headers,
331
- params=params,
332
- proxy=self.proxy and
333
- (self.proxy.get('http') or self.proxy.get('https'))) as resp:
334
- return await resp.json()
335
-
336
- def _parse_response(self, response: dict) -> dict:
337
- if self.search_type == 'web':
338
- filtered_result = response.get('web', {}).get('results', [])
339
- else:
340
- filtered_result = response.get('results', {})
341
- raw_results = []
342
-
343
- for item in filtered_result:
344
- raw_results.append((
345
- item.get('url', ''),
346
- ' '.join(
347
- filter(None, [
348
- item.get('description'),
349
- *item.get('extra_snippets', [])
350
- ])),
351
- item.get('title', ''),
352
- ))
353
- return self._filter_results(raw_results)
354
-
355
-
356
- class GoogleSearch(BaseSearch):
357
- """
358
- Wrapper around the Serper.dev Google Search API.
359
-
360
- To use, you should pass your serper API key to the constructor.
361
-
362
- Args:
363
- api_key (str): API KEY to use serper google search API.
364
- You can create a free API key at https://serper.dev.
365
- search_type (str): Serper API supports ['search', 'images', 'news',
366
- 'places'] types of search, currently we only support 'search' and 'news'.
367
- topk (int): The number of search results returned in response from api search results.
368
- **kwargs: Any other parameters related to the Serper API. Find more details at
369
- https://serper.dev/playground
370
- """
371
-
372
- result_key_for_type = {
373
- 'news': 'news',
374
- 'places': 'places',
375
- 'images': 'images',
376
- 'search': 'organic',
377
- }
378
-
379
- def __init__(self,
380
- api_key: str,
381
- topk: int = 3,
382
- black_list: List[str] = [
383
- 'enoN',
384
- 'youtube.com',
385
- 'bilibili.com',
386
- 'researchgate.net',
387
- ],
388
- **kwargs):
389
- self.api_key = api_key
390
- self.proxy = kwargs.get('proxy')
391
- self.search_type = kwargs.get('search_type', 'search')
392
- self.kwargs = kwargs
393
- super().__init__(topk, black_list)
394
-
395
- @cached(cache=TTLCache(maxsize=100, ttl=600))
396
- def search(self, query: str, max_retry: int = 3) -> dict:
397
- for attempt in range(max_retry):
398
- try:
399
- response = self._call_serper_api(query)
400
- return self._parse_response(response)
401
- except Exception as e:
402
- logging.exception(str(e))
403
- warnings.warn(
404
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
405
- time.sleep(random.randint(2, 5))
406
- raise Exception(
407
- 'Failed to get search results from Google Serper Search after retries.'
408
- )
409
-
410
- @acached(cache=TTLCache(maxsize=100, ttl=600))
411
- async def asearch(self, query: str, max_retry: int = 3) -> dict:
412
- for attempt in range(max_retry):
413
- try:
414
- response = await self._async_call_serper_api(query)
415
- return self._parse_response(response)
416
- except Exception as e:
417
- logging.exception(str(e))
418
- warnings.warn(
419
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
420
- await asyncio.sleep(random.randint(2, 5))
421
- raise Exception(
422
- 'Failed to get search results from Google Serper Search after retries.'
423
- )
424
-
425
- def _call_serper_api(self, query: str) -> dict:
426
- endpoint = f'https://google.serper.dev/{self.search_type}'
427
- params = {
428
- 'q': query,
429
- 'num': self.topk,
430
- **{
431
- key: value
432
- for key, value in self.kwargs.items() if value is not None
433
- },
434
- }
435
- headers = {
436
- 'X-API-KEY': self.api_key or '',
437
- 'Content-Type': 'application/json'
438
- }
439
- response = requests.get(
440
- endpoint, headers=headers, params=params, proxies=self.proxy)
441
- response.raise_for_status()
442
- return response.json()
443
-
444
- async def _async_call_serper_api(self, query: str) -> dict:
445
- endpoint = f'https://google.serper.dev/{self.search_type}'
446
- params = {
447
- 'q': query,
448
- 'num': self.topk,
449
- **{
450
- key: value
451
- for key, value in self.kwargs.items() if value is not None
452
- },
453
- }
454
- headers = {
455
- 'X-API-KEY': self.api_key or '',
456
- 'Content-Type': 'application/json'
457
- }
458
- async with aiohttp.ClientSession(raise_for_status=True) as session:
459
- async with session.get(
460
- endpoint,
461
- headers=headers,
462
- params=params,
463
- proxy=self.proxy and
464
- (self.proxy.get('http') or self.proxy.get('https'))) as resp:
465
- return await resp.json()
466
-
467
- def _parse_response(self, response: dict) -> dict:
468
- raw_results = []
469
-
470
- if response.get('answerBox'):
471
- answer_box = response.get('answerBox', {})
472
- if answer_box.get('answer'):
473
- raw_results.append(('', answer_box.get('answer'), ''))
474
- elif answer_box.get('snippet'):
475
- raw_results.append(
476
- ('', answer_box.get('snippet').replace('\n', ' '), ''))
477
- elif answer_box.get('snippetHighlighted'):
478
- raw_results.append(
479
- ('', answer_box.get('snippetHighlighted'), ''))
480
-
481
- if response.get('knowledgeGraph'):
482
- kg = response.get('knowledgeGraph', {})
483
- description = kg.get('description', '')
484
- attributes = '. '.join(
485
- f'{attribute}: {value}'
486
- for attribute, value in kg.get('attributes', {}).items())
487
- raw_results.append(
488
- (kg.get('descriptionLink', ''),
489
- f'{description}. {attributes}' if attributes else description,
490
- f"{kg.get('title', '')}: {kg.get('type', '')}."))
491
-
492
- for result in response[self.result_key_for_type[
493
- self.search_type]][:self.topk]:
494
- description = result.get('snippet', '')
495
- attributes = '. '.join(
496
- f'{attribute}: {value}'
497
- for attribute, value in result.get('attributes', {}).items())
498
- raw_results.append(
499
- (result.get('link', ''),
500
- f'{description}. {attributes}' if attributes else description,
501
- result.get('title', '')))
502
-
503
- return self._filter_results(raw_results)
504
-
505
-
506
- class TencentSearch(BaseSearch):
507
- """Wrapper around the tencentclound Search API.
508
-
509
- To use, you should pass your secret_id and secret_key to the constructor.
510
-
511
- Args:
512
- secret_id (str): Your Tencent Cloud secret ID for accessing the API.
513
- For more details, refer to the documentation: https://cloud.tencent.com/document/product/598/40488.
514
- secret_key (str): Your Tencent Cloud secret key for accessing the API.
515
- api_key (str, optional): Additional API key, if required.
516
- action (str): The action for this interface, use `SearchCommon`.
517
- version (str): The API version, use `2020-12-29`.
518
- service (str): The service name, use `tms`.
519
- host (str): The API host, use `tms.tencentcloudapi.com`.
520
- topk (int): The maximum number of search results to return.
521
- tsn (int): Time filter for search results. Valid values:
522
- 1 (within 1 day), 2 (within 1 week), 3 (within 1 month),
523
- 4 (within 1 year), 5 (within 6 months), 6 (within 3 years).
524
- insite (str): Specify a site to search within (supports only a single site).
525
- If not specified, the entire web is searched. Example: `zhihu.com`.
526
- category (str): Vertical category for filtering results. Optional values include:
527
- `baike` (encyclopedia), `weather`, `calendar`, `medical`, `news`, `train`, `star` (horoscope).
528
- vrid (str): Result card type(s). Different `vrid` values represent different types of result cards.
529
- Supports multiple values separated by commas. Example: `30010255`.
530
- """
531
-
532
- def __init__(self,
533
- secret_id: str = 'Your SecretId',
534
- secret_key: str = 'Your SecretKey',
535
- api_key: str = '',
536
- action: str = 'SearchCommon',
537
- version: str = '2020-12-29',
538
- service: str = 'tms',
539
- host: str = 'tms.tencentcloudapi.com',
540
- topk: int = 3,
541
- tsn: int = None,
542
- insite: str = None,
543
- category: str = None,
544
- vrid: str = None,
545
- black_list: List[str] = [
546
- 'enoN',
547
- 'youtube.com',
548
- 'bilibili.com',
549
- 'researchgate.net',
550
- ]):
551
- self.secret_id = secret_id
552
- self.secret_key = secret_key
553
- self.api_key = api_key
554
- self.action = action
555
- self.version = version
556
- self.service = service
557
- self.host = host
558
- self.tsn = tsn
559
- self.insite = insite
560
- self.category = category
561
- self.vrid = vrid
562
- super().__init__(topk, black_list=black_list)
563
-
564
- @cached(cache=TTLCache(maxsize=100, ttl=600))
565
- def search(self, query: str, max_retry: int = 3) -> dict:
566
- for attempt in range(max_retry):
567
- try:
568
- response = self._call_tencent_api(query)
569
- return self._parse_response(response)
570
- except Exception as e:
571
- logging.exception(str(e))
572
- warnings.warn(
573
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
574
- time.sleep(random.randint(2, 5))
575
- raise Exception(
576
- 'Failed to get search results from Tencent Search after retries.')
577
-
578
- @acached(cache=TTLCache(maxsize=100, ttl=600))
579
- async def asearch(self, query: str, max_retry: int = 3) -> dict:
580
- for attempt in range(max_retry):
581
- try:
582
- response = await self._async_call_tencent_api(query)
583
- return self._parse_response(response)
584
- except Exception as e:
585
- logging.exception(str(e))
586
- warnings.warn(
587
- f'Retry {attempt + 1}/{max_retry} due to error: {e}')
588
- await asyncio.sleep(random.randint(2, 5))
589
- raise Exception(
590
- 'Failed to get search results from Tencent Search after retries.')
591
-
592
- def _get_headers_and_payload(self, query: str) -> tuple:
593
-
594
- def sign(key, msg):
595
- return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()
596
-
597
- params = dict(Query=query)
598
- # if self.topk:
599
- # params['Cnt'] = self.topk
600
- if self.tsn:
601
- params['Tsn'] = self.tsn
602
- if self.insite:
603
- params['Insite'] = self.insite
604
- if self.category:
605
- params['Category'] = self.category
606
- if self.vrid:
607
- params['Vrid'] = self.vrid
608
- payload = json.dumps(params)
609
- algorithm = 'TC3-HMAC-SHA256'
610
- timestamp = int(time.time())
611
- date = datetime.utcfromtimestamp(timestamp).strftime('%Y-%m-%d')
612
-
613
- # ************* Step 1: build the canonical request string *************
614
- http_request_method = 'POST'
615
- canonical_uri = '/'
616
- canonical_querystring = ''
617
- ct = 'application/json; charset=utf-8'
618
- canonical_headers = f'content-type:{ct}\nhost:{self.host}\nx-tc-action:{self.action.lower()}\n'
619
- signed_headers = 'content-type;host;x-tc-action'
620
- hashed_request_payload = hashlib.sha256(
621
- payload.encode('utf-8')).hexdigest()
622
- canonical_request = (
623
- http_request_method + '\n' + canonical_uri + '\n' +
624
- canonical_querystring + '\n' + canonical_headers + '\n' +
625
- signed_headers + '\n' + hashed_request_payload)
626
-
627
- # ************* Step 2: build the string to sign *************
628
- credential_scope = date + '/' + self.service + '/' + 'tc3_request'
629
- hashed_canonical_request = hashlib.sha256(
630
- canonical_request.encode('utf-8')).hexdigest()
631
- string_to_sign = (
632
- algorithm + '\n' + str(timestamp) + '\n' + credential_scope +
633
- '\n' + hashed_canonical_request)
634
-
635
- # ************* Step 3: compute the signature *************
636
- secret_date = sign(('TC3' + self.secret_key).encode('utf-8'), date)
637
- secret_service = sign(secret_date, self.service)
638
- secret_signing = sign(secret_service, 'tc3_request')
639
- signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
640
- hashlib.sha256).hexdigest()
641
-
642
- # ************* Step 4: assemble the Authorization header *************
643
- authorization = (
644
- algorithm + ' ' + 'Credential=' + self.secret_id + '/' +
645
- credential_scope + ', ' + 'SignedHeaders=' + signed_headers +
646
- ', ' + 'Signature=' + signature)
647
-
648
- # ************* Step 5: build and send the request *************
649
- headers = {
650
- 'Authorization': authorization,
651
- 'Content-Type': 'application/json; charset=utf-8',
652
- 'Host': self.host,
653
- 'X-TC-Action': self.action,
654
- 'X-TC-Timestamp': str(timestamp),
655
- 'X-TC-Version': self.version
656
- }
657
- # if self.region:
658
- # headers["X-TC-Region"] = self.region
659
- if self.api_key:
660
- headers['X-TC-Token'] = self.api_key
661
- return headers, payload
662
-
663
- def _call_tencent_api(self, query: str) -> dict:
664
- headers, payload = self._get_headers_and_payload(query)
665
- req = HTTPSConnection(self.host)
666
- req.request('POST', '/', headers=headers, body=payload.encode('utf-8'))
667
- resp = req.getresponse()
668
- try:
669
- resp = json.loads(resp.read().decode('utf-8'))
670
- except Exception as e:
671
- logging.warning(str(e))
672
- import ast
673
- resp = ast.literal_eval(resp)
674
- return resp.get('Response', dict())
675
-
676
- async def _async_call_tencent_api(self, query: str):
677
- headers, payload = self._get_headers_and_payload(query)
678
- async with aiohttp.ClientSession(raise_for_status=True) as session:
679
- async with session.post(
680
- 'https://' + self.host.lstrip('/'),
681
- headers=headers,
682
- data=payload) as resp:
683
- return (await resp.json()).get('Response', {})
684
-
685
- def _parse_response(self, response: dict) -> dict:
686
- raw_results = []
687
- for item in response.get('Pages', []):
688
- display = json.loads(item['Display'])
689
- if not display['url']:
690
- continue
691
- raw_results.append((display['url'], display['content']
692
- or display['abstract_info'], display['title']))
693
- return self._filter_results(raw_results)
694
-
695
-
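
TencentSearch signs its requests with Tencent Cloud's TC3-HMAC-SHA256 scheme, deriving the signing key through the date, the service name, and the literal 'tc3_request' before HMAC-ing the string to sign. Here is a condensed sketch of just that derivation chain; the secret key and string_to_sign are placeholders.

```python
# Condensed sketch of the TC3-HMAC-SHA256 signing chain used in
# _get_headers_and_payload above. Credentials and payload are placeholders.
import hashlib
import hmac
from datetime import datetime, timezone


def sign(key: bytes, msg: str) -> bytes:
    return hmac.new(key, msg.encode('utf-8'), hashlib.sha256).digest()


secret_key = 'Your SecretKey'            # placeholder credential
string_to_sign = 'TC3-HMAC-SHA256\n...'  # built from the canonical request above
date = datetime.now(timezone.utc).strftime('%Y-%m-%d')
service = 'tms'

# Key derivation chain: date -> service -> 'tc3_request', then sign the string.
secret_date = sign(('TC3' + secret_key).encode('utf-8'), date)
secret_service = sign(secret_date, service)
secret_signing = sign(secret_service, 'tc3_request')
signature = hmac.new(secret_signing, string_to_sign.encode('utf-8'),
                     hashlib.sha256).hexdigest()
print(signature)
```
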
696
- class ContentFetcher:
697
-
698
- def __init__(self, timeout: int = 5):
699
- self.timeout = timeout
700
-
701
- @cached(cache=TTLCache(maxsize=100, ttl=600))
702
- def fetch(self, url: str) -> Tuple[bool, str]:
703
- try:
704
- response = requests.get(url, timeout=self.timeout)
705
- response.raise_for_status()
706
- html = response.content
707
- except requests.RequestException as e:
708
- return False, str(e)
709
-
710
- text = BeautifulSoup(html, 'html.parser').get_text()
711
- cleaned_text = re.sub(r'\n+', '\n', text)
712
- return True, cleaned_text
713
-
714
- @acached(cache=TTLCache(maxsize=100, ttl=600))
715
- async def afetch(self, url: str) -> Tuple[bool, str]:
716
- try:
717
- async with aiohttp.ClientSession(
718
- raise_for_status=True,
719
- timeout=aiohttp.ClientTimeout(self.timeout)) as session:
720
- async with session.get(url) as resp:
721
- html = await resp.text(errors='ignore')
722
- text = BeautifulSoup(html, 'html.parser').get_text()
723
- cleaned_text = re.sub(r'\n+', '\n', text)
724
- return True, cleaned_text
725
- except Exception as e:
726
- return False, str(e)
727
-
728
-
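
ContentFetcher reduces a page to its visible text: fetch with requests, parse with BeautifulSoup, collapse blank lines. A minimal sketch of that step using the same calls follows; the URL is only an example.

```python
# Minimal sketch of ContentFetcher.fetch: download a page and keep only its
# visible text.
import re
from typing import Tuple

import requests
from bs4 import BeautifulSoup


def fetch(url: str, timeout: int = 5) -> Tuple[bool, str]:
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
    except requests.RequestException as e:
        return False, str(e)
    text = BeautifulSoup(response.content, 'html.parser').get_text()
    return True, re.sub(r'\n+', '\n', text)  # collapse runs of blank lines


ok, content = fetch('https://example.com')
print(content[:200] if ok else f'fetch failed: {content}')
```
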
729
- class WebBrowser(BaseAction):
730
- """Wrapper around the Web Browser Tool.
731
- """
732
-
733
- def __init__(self,
734
- searcher_type: str = 'DuckDuckGoSearch',
735
- timeout: int = 5,
736
- black_list: Optional[List[str]] = [
737
- 'enoN',
738
- 'youtube.com',
739
- 'bilibili.com',
740
- 'researchgate.net',
741
- ],
742
- topk: int = 20,
743
- description: Optional[dict] = None,
744
- parser: Type[BaseParser] = JsonParser,
745
- **kwargs):
746
- self.searcher = eval(searcher_type)(
747
- black_list=black_list, topk=topk, **kwargs)
748
- self.fetcher = ContentFetcher(timeout=timeout)
749
- self.search_results = None
750
- super().__init__(description, parser)
751
-
752
- @tool_api
753
- def search(self, query: Union[str, List[str]]) -> dict:
754
- """Web search API
755
- Args:
756
- query (List[str]): list of search query strings
757
- """
758
- queries = query if isinstance(query, list) else [query]
759
- search_results = {}
760
-
761
- with ThreadPoolExecutor() as executor:
762
- future_to_query = {
763
- executor.submit(self.searcher.search, q): q
764
- for q in queries
765
- }
766
-
767
- for future in as_completed(future_to_query):
768
- query = future_to_query[future]
769
- try:
770
- results = future.result()
771
- except Exception as exc:
772
- warnings.warn(f'{query} generated an exception: {exc}')
773
- else:
774
- for result in results.values():
775
- if result['url'] not in search_results:
776
- search_results[result['url']] = result
777
- else:
778
- search_results[
779
- result['url']]['summ'] += f"\n{result['summ']}"
780
-
781
- self.search_results = {
782
- idx: result
783
- for idx, result in enumerate(search_results.values())
784
- }
785
- return self.search_results
786
-
787
- @tool_api
788
- def select(self, select_ids: List[int]) -> dict:
789
- """get the detailed content on the selected pages.
790
-
791
- Args:
792
- select_ids (List[int]): list of result indices to select. Select no more than 4 at a time.
793
- """
794
- if not self.search_results:
795
- raise ValueError('No search results to select from.')
796
-
797
- new_search_results = {}
798
- with ThreadPoolExecutor() as executor:
799
- future_to_id = {
800
- executor.submit(self.fetcher.fetch, self.search_results[select_id]['url']): select_id
801
- for select_id in select_ids if select_id in self.search_results
802
- }
803
- for future in as_completed(future_to_id):
804
- select_id = future_to_id[future]
805
- try:
806
- web_success, web_content = future.result()
807
- except Exception as exc:
808
- warnings.warn(f'{select_id} generated an exception: {exc}')
809
- else:
810
- if web_success:
811
- self.search_results[select_id][
812
- 'content'] = web_content[:8192]
813
- new_search_results[select_id] = self.search_results[
814
- select_id].copy()
815
- new_search_results[select_id].pop('summ')
816
-
817
- return new_search_results
818
-
819
- @tool_api
820
- def open_url(self, url: str) -> dict:
821
- print(f'Start Browsing: {url}')
822
- web_success, web_content = self.fetcher.fetch(url)
823
- if web_success:
824
- return {'type': 'text', 'content': web_content}
825
- else:
826
- return {'error': web_content}
827
-
828
-
829
- class AsyncWebBrowser(AsyncActionMixin, WebBrowser):
830
- """Wrapper around the Web Browser Tool.
831
- """
832
-
833
- @tool_api
834
- async def search(self, query: Union[str, List[str]]) -> dict:
835
- """Web search API
836
-
837
- Args:
838
- query (List[str]): list of search query strings
839
- """
840
- queries = query if isinstance(query, list) else [query]
841
- search_results = {}
842
-
843
- tasks = []
844
- for q in queries:
845
- task = asyncio.create_task(self.searcher.asearch(q))
846
- task.query = q
847
- tasks.append(task)
848
- async for future in async_as_completed(tasks):
849
- query = future.query
850
- try:
851
- results = await future
852
- except Exception as exc:
853
- warnings.warn(f'{query} generated an exception: {exc}')
854
- else:
855
- for result in results.values():
856
- if result['url'] not in search_results:
857
- search_results[result['url']] = result
858
- else:
859
- search_results[
860
- result['url']]['summ'] += f"\n{result['summ']}"
861
-
862
- self.search_results = {
863
- idx: result
864
- for idx, result in enumerate(search_results.values())
865
- }
866
- return self.search_results
867
-
868
- @tool_api
869
- async def select(self, select_ids: List[int]) -> dict:
870
- """get the detailed content on the selected pages.
871
-
872
- Args:
873
- select_ids (List[int]): list of result indices to select. Select no more than 4 at a time.
874
- """
875
- if not self.search_results:
876
- raise ValueError('No search results to select from.')
877
-
878
- new_search_results = {}
879
- tasks = []
880
- for select_id in select_ids:
881
- if select_id in self.search_results:
882
- task = asyncio.create_task(
883
- self.fetcher.afetch(self.search_results[select_id]['url']))
884
- task.select_id = select_id
885
- tasks.append(task)
886
- async for future in async_as_completed(tasks):
887
- select_id = future.select_id
888
- try:
889
- web_success, web_content = await future
890
- except Exception as exc:
891
- warnings.warn(f'{select_id} generated an exception: {exc}')
892
- else:
893
- if web_success:
894
- self.search_results[select_id][
895
- 'content'] = web_content[:8192]
896
- new_search_results[select_id] = self.search_results[
897
- select_id].copy()
898
- new_search_results[select_id].pop('summ')
899
- return new_search_results
900
-
901
- @tool_api
902
- async def open_url(self, url: str) -> dict:
903
- print(f'Start Browsing: {url}')
904
- web_success, web_content = await self.fetcher.afetch(url)
905
- if web_success:
906
- return {'type': 'text', 'content': web_content}
907
- else:
908
- return {'error': web_content}
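
Taken together, WebBrowser.search fans queries out over the configured backend and de-duplicates hits by URL, select fetches full text (truncated to 8192 characters) for chosen indices, and open_url fetches a single page. Below is a hedged usage sketch, assuming the class is importable from the upstream lagent package (which exports the same WebBrowser) and that network access plus the duckduckgo_search dependency are available.

```python
# Hedged usage sketch of the WebBrowser flow above: search, inspect the hits,
# then select a few for full-text fetching. Import path assumes the upstream
# lagent package; queries are examples.
from lagent.actions import WebBrowser

browser = WebBrowser(searcher_type='DuckDuckGoSearch', topk=5, timeout=5)

hits = browser.search(['lagent agent framework', 'ReAct tool calling'])
for idx, hit in hits.items():
    print(idx, hit['title'], hit['url'])

# Fetch the page text (truncated to 8192 chars by select) for two results.
pages = browser.select([0, 1])
for idx, page in pages.items():
    print(idx, page['content'][:200])
```
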