m-ric HF Staff commited on
Commit
9bb5689
·
verified ·
1 Parent(s): 68bd5bd

Upload tool

Browse files
Files changed (5) hide show
  1. app.py +4 -0
  2. requirements.txt +25 -0
  3. tool_config.json +7 -0
  4. tools.py +1098 -0
  5. types.py +270 -0
app.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from transformers import launch_gradio_demo
2
+ from tools import Get_current_timeTool
3
+
4
+ launch_gradio_demo(Get_current_timeTool)
requirements.txt ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ my_tool_module
2
+ transformers
3
+ importlib
4
+ uuid
5
+ io
6
+ huggingface_hub
7
+ inspect
8
+ builtins
9
+ pathlib
10
+ os
11
+ PIL
12
+ functools
13
+ torch
14
+ typing
15
+ {module_name}
16
+ packaging
17
+ gradio_client
18
+ ast
19
+ IPython
20
+ json
21
+ logging
22
+ base64
23
+ accelerate
24
+ textwrap
25
+ tempfile
tool_config.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "description": "Gets the current time.",
3
+ "inputs": {},
4
+ "name": "get_current_time",
5
+ "output_type": "string",
6
+ "tool_class": "tools.Get_current_timeTool"
7
+ }
tools.py ADDED
@@ -0,0 +1,1098 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+
4
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+ import ast
18
+ import base64
19
+ import importlib
20
+ import inspect
21
+ import io
22
+ import json
23
+ import os
24
+ import tempfile
25
+ import textwrap
26
+ from functools import lru_cache, wraps
27
+ from pathlib import Path
28
+ from typing import Any, Callable, Dict, List, Optional, Union
29
+
30
+ from huggingface_hub import (
31
+ create_repo,
32
+ get_collection,
33
+ hf_hub_download,
34
+ metadata_update,
35
+ upload_folder,
36
+ )
37
+ from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
38
+ from packaging import version
39
+
40
+ from transformers.dynamic_module_utils import (
41
+ custom_object_save,
42
+ get_class_from_dynamic_module,
43
+ get_imports,
44
+ )
45
+ from transformers import AutoProcessor
46
+ from transformers.utils import (
47
+ CONFIG_NAME,
48
+ TypeHintParsingException,
49
+ cached_file,
50
+ get_json_schema,
51
+ is_accelerate_available,
52
+ is_torch_available,
53
+ is_vision_available,
54
+ )
55
+ from .types import ImageType, handle_agent_inputs, handle_agent_outputs
56
+ import logging
57
+
58
+ logger = logging.getLogger(__name__)
59
+
60
+
61
+ if is_torch_available():
62
+ import torch
63
+
64
+ if is_accelerate_available():
65
+ from accelerate import PartialState
66
+ from accelerate.utils import send_to_device
67
+
68
+
69
+ TOOL_CONFIG_FILE = "tool_config.json"
70
+
71
+
72
+ def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
73
+ if repo_type is not None:
74
+ return repo_type
75
+ try:
76
+ hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
77
+ return "space"
78
+ except RepositoryNotFoundError:
79
+ try:
80
+ hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
81
+ return "model"
82
+ except RepositoryNotFoundError:
83
+ raise EnvironmentError(
84
+ f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
85
+ )
86
+ except Exception:
87
+ return "model"
88
+ except Exception:
89
+ return "space"
90
+
91
+
92
+ def setup_default_tools():
93
+ default_tools = {}
94
+ main_module = importlib.import_module("transformers")
95
+ tools_module = main_module.agents
96
+
97
+ for task_name, tool_class_name in TOOL_MAPPING.items():
98
+ tool_class = getattr(tools_module, tool_class_name)
99
+ tool_instance = tool_class()
100
+ default_tools[tool_class.name] = tool_instance
101
+
102
+ return default_tools
103
+
104
+
105
+ # docstyle-ignore
106
+ APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
107
+ from {module_name} import {class_name}
108
+
109
+ launch_gradio_demo({class_name})
110
+ """
111
+
112
+
113
+ def validate_after_init(cls, do_validate_forward: bool = True):
114
+ original_init = cls.__init__
115
+
116
+ @wraps(original_init)
117
+ def new_init(self, *args, **kwargs):
118
+ original_init(self, *args, **kwargs)
119
+ self.validate_arguments(do_validate_forward=do_validate_forward)
120
+
121
+ cls.__init__ = new_init
122
+ return cls
123
+
124
+ def validate_forward_method_args(cls):
125
+ """Validates that all names in forward method are properly defined.
126
+ In particular it will check that all imports are done within the function."""
127
+ if 'forward' not in cls.__dict__:
128
+ return
129
+
130
+ forward = cls.__dict__['forward']
131
+ source_code = textwrap.dedent(inspect.getsource(forward))
132
+ tree = ast.parse(source_code)
133
+
134
+ # Get function arguments
135
+ func_node = tree.body[0]
136
+ arg_names = {arg.arg for arg in func_node.args.args}
137
+
138
+
139
+ import builtins
140
+ builtin_names = set(vars(builtins))
141
+
142
+
143
+ # Find all used names that aren't arguments or self attributes
144
+ class NameChecker(ast.NodeVisitor):
145
+ def __init__(self):
146
+ self.undefined_names = set()
147
+ self.imports = {}
148
+ self.from_imports = {}
149
+
150
+ def visit_Import(self, node):
151
+ """Handle simple imports like 'import datetime'."""
152
+ for name in node.names:
153
+ actual_name = name.asname or name.name
154
+ self.imports[actual_name] = (name.name, actual_name)
155
+
156
+ def visit_ImportFrom(self, node):
157
+ """Handle from imports like 'from datetime import datetime'."""
158
+ module = node.module or ''
159
+ for name in node.names:
160
+ actual_name = name.asname or name.name
161
+ self.from_imports[actual_name] = (module, name.name, actual_name)
162
+
163
+ def visit_Name(self, node):
164
+ if (isinstance(node.ctx, ast.Load) and not (
165
+ node.id == "tool" or
166
+ node.id in builtin_names or
167
+ node.id in arg_names or
168
+ node.id == 'self'
169
+ )):
170
+ if node.id not in self.from_imports and node.id not in self.imports:
171
+ self.undefined_names.add(node.id)
172
+
173
+ def visit_Attribute(self, node):
174
+ # Skip self.something
175
+ if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
176
+ self.generic_visit(node)
177
+
178
+ checker = NameChecker()
179
+ checker.visit(tree)
180
+
181
+ if checker.undefined_names:
182
+ raise ValueError(
183
+ f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
184
+ Make sure all imports and variables are defined within the method.
185
+ For instance:
186
+
187
+ """
188
+ )
189
+
190
+ AUTHORIZED_TYPES = [
191
+ "string",
192
+ "boolean",
193
+ "integer",
194
+ "number",
195
+ "image",
196
+ "audio",
197
+ "any",
198
+ ]
199
+
200
+ CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
201
+
202
+
203
+ class Tool:
204
+ """
205
+ A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
206
+ following class attributes:
207
+
208
+ - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
209
+ will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
210
+ returns the text contained in the file'.
211
+ - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
212
+ `"text-classifier"` or `"image_generator"`.
213
+ - **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
214
+ It has one `type`key and a `description`key.
215
+ This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
216
+ description for your tool.
217
+ - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
218
+ or to make a nice space from your tool, and also can be used in the generated description for your tool.
219
+
220
+ You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
221
+ usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
222
+ instantiation.
223
+ """
224
+
225
+ name: str
226
+ description: str
227
+ inputs: Dict[str, Dict[str, Union[str, type]]]
228
+ output_type: str
229
+
230
+ def __init__(self, *args, **kwargs):
231
+ self.is_initialized = False
232
+
233
+ def __init_subclass__(cls, **kwargs):
234
+ super().__init_subclass__(**kwargs)
235
+ validate_forward_method_args(cls)
236
+ validate_after_init(cls, do_validate_forward=False)
237
+
238
+
239
+ def validate_arguments(self, do_validate_forward: bool = True):
240
+ required_attributes = {
241
+ "description": str,
242
+ "name": str,
243
+ "inputs": dict,
244
+ "output_type": str,
245
+ }
246
+
247
+ for attr, expected_type in required_attributes.items():
248
+ attr_value = getattr(self, attr, None)
249
+ if attr_value is None:
250
+ raise TypeError(f"You must set an attribute {attr}.")
251
+ if not isinstance(attr_value, expected_type):
252
+ raise TypeError(
253
+ f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
254
+ )
255
+ for input_name, input_content in self.inputs.items():
256
+ assert isinstance(
257
+ input_content, dict
258
+ ), f"Input '{input_name}' should be a dictionary."
259
+ assert (
260
+ "type" in input_content and "description" in input_content
261
+ ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
262
+ if input_content["type"] not in AUTHORIZED_TYPES:
263
+ raise Exception(
264
+ f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
265
+ )
266
+
267
+ assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
268
+ if do_validate_forward:
269
+ signature = inspect.signature(self.forward)
270
+ if not set(signature.parameters.keys()) == set(self.inputs.keys()):
271
+ raise Exception(
272
+ "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
273
+ )
274
+
275
+ def forward(self, *args, **kwargs):
276
+ return NotImplementedError("Write this method in your subclass of `Tool`.")
277
+
278
+ def __call__(self, *args, **kwargs):
279
+ if not self.is_initialized:
280
+ self.setup()
281
+ args, kwargs = handle_agent_inputs(*args, **kwargs)
282
+ outputs = self.forward(*args, **kwargs)
283
+ return handle_agent_outputs(outputs, self.output_type)
284
+
285
+ def setup(self):
286
+ """
287
+ Overwrite this method here for any operation that is expensive and needs to be executed before you start using
288
+ your tool. Such as loading a big model.
289
+ """
290
+ self.is_initialized = True
291
+
292
+ def save(self, output_dir):
293
+ """
294
+ Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
295
+ tool in `output_dir` as well as autogenerate:
296
+
297
+ - an `app.py` file so that your tool can be converted to a space
298
+ - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
299
+ code)
300
+
301
+ You should only use this method to save tools that are defined in a separate module (not `__main__`).
302
+
303
+ Args:
304
+ output_dir (`str`): The folder in which you want to save your tool.
305
+ """
306
+ os.makedirs(output_dir, exist_ok=True)
307
+ # Save module file
308
+ if self.__module__ == "__main__":
309
+ raise ValueError(
310
+ f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
311
+ "have to put this code in a separate module so we can include it in the saved folder."
312
+ )
313
+ module_files = custom_object_save(self, output_dir)
314
+
315
+ module_name = self.__class__.__module__
316
+ last_module = module_name.split(".")[-1]
317
+ full_name = f"{last_module}.{self.__class__.__name__}"
318
+
319
+ # Save config file
320
+ config_file = os.path.join(output_dir, "tool_config.json")
321
+ if os.path.isfile(config_file):
322
+ with open(config_file, "r", encoding="utf-8") as f:
323
+ tool_config = json.load(f)
324
+ else:
325
+ tool_config = {}
326
+
327
+ tool_config = {
328
+ "tool_class": full_name,
329
+ "description": self.description,
330
+ "name": self.name,
331
+ "inputs": self.inputs,
332
+ "output_type": str(self.output_type),
333
+ }
334
+ with open(config_file, "w", encoding="utf-8") as f:
335
+ f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
336
+
337
+ # Save app file
338
+ app_file = os.path.join(output_dir, "app.py")
339
+ with open(app_file, "w", encoding="utf-8") as f:
340
+ f.write(
341
+ APP_FILE_TEMPLATE.format(
342
+ module_name=last_module, class_name=self.__class__.__name__
343
+ )
344
+ )
345
+
346
+ # Save requirements file
347
+ requirements_file = os.path.join(output_dir, "requirements.txt")
348
+ imports = []
349
+ for module in module_files:
350
+ imports.extend(get_imports(module))
351
+ imports = list(set(imports))
352
+ with open(requirements_file, "w", encoding="utf-8") as f:
353
+ f.write("\n".join(imports) + "\n")
354
+
355
+ @classmethod
356
+ def from_hub(
357
+ cls,
358
+ repo_id: str,
359
+ token: Optional[str] = None,
360
+ **kwargs,
361
+ ):
362
+ """
363
+ Loads a tool defined on the Hub.
364
+
365
+ <Tip warning={true}>
366
+
367
+ Loading a tool from the Hub means that you'll download the tool and execute it locally.
368
+ ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
369
+ installing a package using pip/npm/apt.
370
+
371
+ </Tip>
372
+
373
+ Args:
374
+ repo_id (`str`):
375
+ The name of the repo on the Hub where your tool is defined.
376
+ token (`str`, *optional*):
377
+ The token to identify you on hf.co. If unset, will use the token generated when running
378
+ `huggingface-cli login` (stored in `~/.huggingface`).
379
+ kwargs (additional keyword arguments, *optional*):
380
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
381
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
382
+ others will be passed along to its init.
383
+ """
384
+ hub_kwargs_names = [
385
+ "cache_dir",
386
+ "force_download",
387
+ "resume_download",
388
+ "proxies",
389
+ "revision",
390
+ "repo_type",
391
+ "subfolder",
392
+ "local_files_only",
393
+ ]
394
+ hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
395
+
396
+ # Try to get the tool config first.
397
+ hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
398
+ resolved_config_file = cached_file(
399
+ repo_id,
400
+ TOOL_CONFIG_FILE,
401
+ token=token,
402
+ **hub_kwargs,
403
+ _raise_exceptions_for_gated_repo=False,
404
+ _raise_exceptions_for_missing_entries=False,
405
+ _raise_exceptions_for_connection_errors=False,
406
+ )
407
+ is_tool_config = resolved_config_file is not None
408
+ if resolved_config_file is None:
409
+ resolved_config_file = cached_file(
410
+ repo_id,
411
+ CONFIG_NAME,
412
+ token=token,
413
+ **hub_kwargs,
414
+ _raise_exceptions_for_gated_repo=False,
415
+ _raise_exceptions_for_missing_entries=False,
416
+ _raise_exceptions_for_connection_errors=False,
417
+ )
418
+ if resolved_config_file is None:
419
+ raise EnvironmentError(
420
+ f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
421
+ )
422
+
423
+ with open(resolved_config_file, encoding="utf-8") as reader:
424
+ config = json.load(reader)
425
+
426
+ if not is_tool_config:
427
+ if "custom_tool" not in config:
428
+ raise EnvironmentError(
429
+ f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
430
+ )
431
+ custom_tool = config["custom_tool"]
432
+ else:
433
+ custom_tool = config
434
+
435
+ tool_class = custom_tool["tool_class"]
436
+ tool_class = get_class_from_dynamic_module(
437
+ tool_class, repo_id, token=token, **hub_kwargs
438
+ )
439
+
440
+ if len(tool_class.name) == 0:
441
+ tool_class.name = custom_tool["name"]
442
+ if tool_class.name != custom_tool["name"]:
443
+ logger.warning(
444
+ f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
445
+ "configuration name."
446
+ )
447
+ tool_class.name = custom_tool["name"]
448
+
449
+ if len(tool_class.description) == 0:
450
+ tool_class.description = custom_tool["description"]
451
+ if tool_class.description != custom_tool["description"]:
452
+ logger.warning(
453
+ f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
454
+ "tool configuration description."
455
+ )
456
+ tool_class.description = custom_tool["description"]
457
+
458
+ if tool_class.inputs != custom_tool["inputs"]:
459
+ tool_class.inputs = custom_tool["inputs"]
460
+ if tool_class.output_type != custom_tool["output_type"]:
461
+ tool_class.output_type = custom_tool["output_type"]
462
+
463
+ if not isinstance(tool_class.inputs, dict):
464
+ tool_class.inputs = ast.literal_eval(tool_class.inputs)
465
+
466
+ return tool_class(**kwargs)
467
+
468
+ def push_to_hub(
469
+ self,
470
+ repo_id: str,
471
+ commit_message: str = "Upload tool",
472
+ private: Optional[bool] = None,
473
+ token: Optional[Union[bool, str]] = None,
474
+ create_pr: bool = False,
475
+ ) -> str:
476
+ """
477
+ Upload the tool to the Hub.
478
+
479
+ For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
480
+ For instance:
481
+ ```
482
+ from my_tool_module import MyTool
483
+ my_tool = MyTool()
484
+ my_tool.push_to_hub("my-username/my-space")
485
+ ```
486
+
487
+ Parameters:
488
+ repo_id (`str`):
489
+ The name of the repository you want to push your tool to. It should contain your organization name when
490
+ pushing to a given organization.
491
+ commit_message (`str`, *optional*, defaults to `"Upload tool"`):
492
+ Message to commit while pushing.
493
+ private (`bool`, *optional*):
494
+ Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
495
+ token (`bool` or `str`, *optional*):
496
+ The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
497
+ when running `huggingface-cli login` (stored in `~/.huggingface`).
498
+ create_pr (`bool`, *optional*, defaults to `False`):
499
+ Whether or not to create a PR with the uploaded files or directly commit.
500
+ """
501
+ repo_url = create_repo(
502
+ repo_id=repo_id,
503
+ token=token,
504
+ private=private,
505
+ exist_ok=True,
506
+ repo_type="space",
507
+ space_sdk="gradio",
508
+ )
509
+ repo_id = repo_url.repo_id
510
+ metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
511
+
512
+ with tempfile.TemporaryDirectory() as work_dir:
513
+ # Save all files.
514
+ self.save(work_dir)
515
+ logger.info(
516
+ f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
517
+ )
518
+ return upload_folder(
519
+ repo_id=repo_id,
520
+ commit_message=commit_message,
521
+ folder_path=work_dir,
522
+ token=token,
523
+ create_pr=create_pr,
524
+ repo_type="space",
525
+ )
526
+
527
+ @staticmethod
528
+ def from_space(
529
+ space_id: str,
530
+ name: str,
531
+ description: str,
532
+ api_name: Optional[str] = None,
533
+ token: Optional[str] = None,
534
+ ):
535
+ """
536
+ Creates a [`Tool`] from a Space given its id on the Hub.
537
+
538
+ Args:
539
+ space_id (`str`):
540
+ The id of the Space on the Hub.
541
+ name (`str`):
542
+ The name of the tool.
543
+ description (`str`):
544
+ The description of the tool.
545
+ api_name (`str`, *optional*):
546
+ The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api.
547
+ token (`str`, *optional*):
548
+ Add your token to access private spaces or increase your GPU quotas.
549
+ Returns:
550
+ [`Tool`]:
551
+ The Space, as a tool.
552
+
553
+ Examples:
554
+ ```
555
+ image_generator = Tool.from_space(
556
+ space_id="black-forest-labs/FLUX.1-schnell",
557
+ name="image-generator",
558
+ description="Generate an image from a prompt"
559
+ )
560
+ image = image_generator("Generate an image of a cool surfer in Tahiti")
561
+ ```
562
+ ```
563
+ face_swapper = Tool.from_space(
564
+ "tuan2308/face-swap",
565
+ "face_swapper",
566
+ "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
567
+ )
568
+ image = face_swapper('./aymeric.jpeg', './ruth.jpg')
569
+ ```
570
+ """
571
+ from gradio_client import Client, handle_file
572
+ from gradio_client.utils import is_http_url_like
573
+
574
+ class SpaceToolWrapper(Tool):
575
+ def __init__(
576
+ self,
577
+ space_id: str,
578
+ name: str,
579
+ description: str,
580
+ api_name: Optional[str] = None,
581
+ token: Optional[str] = None,
582
+ ):
583
+ self.client = Client(space_id, hf_token=token)
584
+ self.name = name
585
+ self.description = description
586
+ space_description = self.client.view_api(
587
+ return_format="dict", print_info=False
588
+ )["named_endpoints"]
589
+
590
+ # If api_name is not defined, take the first of the available APIs for this space
591
+ if api_name is None:
592
+ api_name = list(space_description.keys())[0]
593
+ logger.warning(
594
+ f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`."
595
+ )
596
+ self.api_name = api_name
597
+
598
+ try:
599
+ space_description_api = space_description[api_name]
600
+ except KeyError:
601
+ raise KeyError(
602
+ f"Could not find specified {api_name=} among available api names."
603
+ )
604
+
605
+ self.inputs = {}
606
+ for parameter in space_description_api["parameters"]:
607
+ if not parameter["parameter_has_default"]:
608
+ parameter_type = parameter["type"]["type"]
609
+ if parameter_type == "object":
610
+ parameter_type = "any"
611
+ self.inputs[parameter["parameter_name"]] = {
612
+ "type": parameter_type,
613
+ "description": parameter["python_type"]["description"],
614
+ }
615
+ output_component = space_description_api["returns"][0]["component"]
616
+ if output_component == "Image":
617
+ self.output_type = "image"
618
+ elif output_component == "Audio":
619
+ self.output_type = "audio"
620
+ else:
621
+ self.output_type = "any"
622
+
623
+ def sanitize_argument_for_prediction(self, arg):
624
+ if isinstance(arg, ImageType):
625
+ temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
626
+ arg.save(temp_file.name)
627
+ arg = temp_file.name
628
+ if (
629
+ isinstance(arg, (str, Path))
630
+ and Path(arg).exists()
631
+ and Path(arg).is_file()
632
+ ) or is_http_url_like(arg):
633
+ arg = handle_file(arg)
634
+ return arg
635
+
636
+ def forward(self, *args, **kwargs):
637
+ # Preprocess args and kwargs:
638
+ args = list(args)
639
+ for i, arg in enumerate(args):
640
+ args[i] = self.sanitize_argument_for_prediction(arg)
641
+ for arg_name, arg in kwargs.items():
642
+ kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
643
+
644
+ output = self.client.predict(*args, api_name=self.api_name, **kwargs)
645
+ if isinstance(output, tuple) or isinstance(output, list):
646
+ return output[
647
+ 0
648
+ ] # Sometime the space also returns the generation seed, in which case the result is at index 0
649
+ return output
650
+
651
+ return SpaceToolWrapper(
652
+ space_id, name, description, api_name=api_name, token=token
653
+ )
654
+
655
+ @staticmethod
656
+ def from_gradio(gradio_tool):
657
+ """
658
+ Creates a [`Tool`] from a gradio tool.
659
+ """
660
+ import inspect
661
+
662
+ class GradioToolWrapper(Tool):
663
+ def __init__(self, _gradio_tool):
664
+ self.name = _gradio_tool.name
665
+ self.description = _gradio_tool.description
666
+ self.output_type = "string"
667
+ self._gradio_tool = _gradio_tool
668
+ func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
669
+ self.inputs = {
670
+ key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
671
+ for key, value in func_args
672
+ }
673
+ self.forward = self._gradio_tool.run
674
+
675
+ return GradioToolWrapper(gradio_tool)
676
+
677
+ @staticmethod
678
+ def from_langchain(langchain_tool):
679
+ """
680
+ Creates a [`Tool`] from a langchain tool.
681
+ """
682
+
683
+ class LangChainToolWrapper(Tool):
684
+ def __init__(self, _langchain_tool):
685
+ self.name = _langchain_tool.name.lower()
686
+ self.description = _langchain_tool.description
687
+ self.inputs = _langchain_tool.args.copy()
688
+ for input_content in self.inputs.values():
689
+ if "title" in input_content:
690
+ input_content.pop("title")
691
+ input_content["description"] = ""
692
+ self.output_type = "string"
693
+ self.langchain_tool = _langchain_tool
694
+
695
+ def forward(self, *args, **kwargs):
696
+ tool_input = kwargs.copy()
697
+ for index, argument in enumerate(args):
698
+ if index < len(self.inputs):
699
+ input_key = next(iter(self.inputs))
700
+ tool_input[input_key] = argument
701
+ return self.langchain_tool.run(tool_input)
702
+
703
+ return LangChainToolWrapper(langchain_tool)
704
+
705
+
706
+ DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
707
+ - {{ tool.name }}: {{ tool.description }}
708
+ Takes inputs: {{tool.inputs}}
709
+ Returns an output of type: {{tool.output_type}}
710
+ """
711
+
712
+
713
+ def get_tool_description_with_args(
714
+ tool: Tool, description_template: Optional[str] = None
715
+ ) -> str:
716
+ if description_template is None:
717
+ description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
718
+ compiled_template = compile_jinja_template(description_template)
719
+ rendered = compiled_template.render(
720
+ tool=tool,
721
+ )
722
+ return rendered
723
+
724
+
725
+ @lru_cache
726
+ def compile_jinja_template(template):
727
+ try:
728
+ import jinja2
729
+ from jinja2.exceptions import TemplateError
730
+ from jinja2.sandbox import ImmutableSandboxedEnvironment
731
+ except ImportError:
732
+ raise ImportError("template requires jinja2 to be installed.")
733
+
734
+ if version.parse(jinja2.__version__) < version.parse("3.1.0"):
735
+ raise ImportError(
736
+ "template requires jinja2>=3.1.0 to be installed. Your version is "
737
+ f"{jinja2.__version__}."
738
+ )
739
+
740
+ def raise_exception(message):
741
+ raise TemplateError(message)
742
+
743
+ jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
744
+ jinja_env.globals["raise_exception"] = raise_exception
745
+ return jinja_env.from_string(template)
746
+
747
+
748
+ def launch_gradio_demo(tool_class: Tool):
749
+ """
750
+ Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
751
+ `inputs` and `output_type`.
752
+
753
+ Args:
754
+ tool_class (`type`): The class of the tool for which to launch the demo.
755
+ """
756
+ try:
757
+ import gradio as gr
758
+ except ImportError:
759
+ raise ImportError(
760
+ "Gradio should be installed in order to launch a gradio demo."
761
+ )
762
+
763
+ tool = tool_class()
764
+
765
+ def fn(*args, **kwargs):
766
+ return tool(*args, **kwargs)
767
+
768
+ TYPE_TO_COMPONENT_CLASS_MAPPING = {
769
+ "image": gr.Image,
770
+ "audio": gr.Audio,
771
+ "string": gr.Textbox,
772
+ "integer": gr.Textbox,
773
+ "number": gr.Textbox,
774
+ }
775
+
776
+ gradio_inputs = []
777
+ for input_name, input_details in tool_class.inputs.items():
778
+ input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
779
+ input_details["type"]
780
+ ]
781
+ new_component = input_gradio_component_class(label=input_name)
782
+ gradio_inputs.append(new_component)
783
+
784
+ output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
785
+ tool_class.output_type
786
+ ]
787
+ gradio_output = output_gradio_componentclass(label=input_name)
788
+
789
+ gr.Interface(
790
+ fn=fn,
791
+ inputs=gradio_inputs,
792
+ outputs=gradio_output,
793
+ title=tool_class.__name__,
794
+ article=tool.description,
795
+ ).launch()
796
+
797
+
798
+ TOOL_MAPPING = {
799
+ "python_interpreter": "PythonInterpreterTool",
800
+ "web_search": "DuckDuckGoSearchTool",
801
+ }
802
+
803
+
804
+ def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
805
+ """
806
+ Main function to quickly load a tool, be it on the Hub or in the Transformers library.
807
+
808
+ <Tip warning={true}>
809
+
810
+ Loading a tool means that you'll download the tool and execute it locally.
811
+ ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
812
+ installing a package using pip/npm/apt.
813
+
814
+ </Tip>
815
+
816
+ Args:
817
+ task_or_repo_id (`str`):
818
+ The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
819
+ are:
820
+
821
+ - `"document_question_answering"`
822
+ - `"image_question_answering"`
823
+ - `"speech_to_text"`
824
+ - `"text_to_speech"`
825
+ - `"translation"`
826
+
827
+ model_repo_id (`str`, *optional*):
828
+ Use this argument to use a different model than the default one for the tool you selected.
829
+ token (`str`, *optional*):
830
+ The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
831
+ login` (stored in `~/.huggingface`).
832
+ kwargs (additional keyword arguments, *optional*):
833
+ Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
834
+ `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
835
+ will be passed along to its init.
836
+ """
837
+ if task_or_repo_id in TOOL_MAPPING:
838
+ tool_class_name = TOOL_MAPPING[task_or_repo_id]
839
+ main_module = importlib.import_module("agents")
840
+ tools_module = main_module
841
+ tool_class = getattr(tools_module, tool_class_name)
842
+ return tool_class(model_repo_id, token=token, **kwargs)
843
+ else:
844
+ logger.warning_once(
845
+ f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
846
+ f"trust as the code within that tool will be executed on your machine. Always verify the code of "
847
+ f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
848
+ f"code that you have checked."
849
+ )
850
+ return Tool.from_hub(
851
+ task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
852
+ )
853
+
854
+
855
+ def add_description(description):
856
+ """
857
+ A decorator that adds a description to a function.
858
+ """
859
+
860
+ def inner(func):
861
+ func.description = description
862
+ func.name = func.__name__
863
+ return func
864
+
865
+ return inner
866
+
867
+
868
+ ## Will move to the Hub
869
+ class EndpointClient:
870
+ def __init__(self, endpoint_url: str, token: Optional[str] = None):
871
+ self.headers = {
872
+ **build_hf_headers(token=token),
873
+ "Content-Type": "application/json",
874
+ }
875
+ self.endpoint_url = endpoint_url
876
+
877
+ @staticmethod
878
+ def encode_image(image):
879
+ _bytes = io.BytesIO()
880
+ image.save(_bytes, format="PNG")
881
+ b64 = base64.b64encode(_bytes.getvalue())
882
+ return b64.decode("utf-8")
883
+
884
+ @staticmethod
885
+ def decode_image(raw_image):
886
+ if not is_vision_available():
887
+ raise ImportError(
888
+ "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
889
+ )
890
+
891
+ from PIL import Image
892
+
893
+ b64 = base64.b64decode(raw_image)
894
+ _bytes = io.BytesIO(b64)
895
+ return Image.open(_bytes)
896
+
897
+ def __call__(
898
+ self,
899
+ inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
900
+ params: Optional[Dict] = None,
901
+ data: Optional[bytes] = None,
902
+ output_image: bool = False,
903
+ ) -> Any:
904
+ # Build payload
905
+ payload = {}
906
+ if inputs:
907
+ payload["inputs"] = inputs
908
+ if params:
909
+ payload["parameters"] = params
910
+
911
+ # Make API call
912
+ response = get_session().post(
913
+ self.endpoint_url, headers=self.headers, json=payload, data=data
914
+ )
915
+
916
+ # By default, parse the response for the user.
917
+ if output_image:
918
+ return self.decode_image(response.content)
919
+ else:
920
+ return response.json()
921
+
922
+
923
+ class ToolCollection:
924
+ """
925
+ Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
926
+
927
+ > [!NOTE]
928
+ > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
929
+ > like for this collection to showcase them.
930
+
931
+ Args:
932
+ collection_slug (str):
933
+ The collection slug referencing the collection.
934
+ token (str, *optional*):
935
+ The authentication token if the collection is private.
936
+
937
+ Example:
938
+
939
+ ```py
940
+ >>> from transformers import ToolCollection, CodeAgent
941
+
942
+ >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
943
+ >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
944
+
945
+ >>> agent.run("Please draw me a picture of rivers and lakes.")
946
+ ```
947
+ """
948
+
949
+ def __init__(self, collection_slug: str, token: Optional[str] = None):
950
+ self._collection = get_collection(collection_slug, token=token)
951
+ self._hub_repo_ids = {
952
+ item.item_id for item in self._collection.items if item.item_type == "space"
953
+ }
954
+ self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
955
+
956
+
957
+ def tool(tool_function: Callable) -> Tool:
958
+ """
959
+ Converts a function into an instance of a Tool subclass.
960
+
961
+ Args:
962
+ tool_function: Your function. Should have type hints for each input and a type hint for the output.
963
+ Should also have a docstring description including an 'Args:' part where each argument is described.
964
+ """
965
+ parameters = get_json_schema(tool_function)["function"]
966
+ if "return" not in parameters:
967
+ raise TypeHintParsingException(
968
+ "Tool return type not found: make sure your function has a return type hint!"
969
+ )
970
+ class_name = f"{parameters['name'].capitalize()}Tool"
971
+ if parameters["return"]["type"] == "object":
972
+ parameters["return"]["type"] = "any"
973
+
974
+ class SpecificTool(Tool):
975
+ name = parameters["name"]
976
+ description = parameters["description"]
977
+ inputs = parameters["parameters"]["properties"]
978
+ output_type = parameters["return"]["type"]
979
+
980
+ @wraps(tool_function)
981
+ def forward(self, *args, **kwargs):
982
+ return tool_function(*args, **kwargs)
983
+
984
+ original_signature = inspect.signature(tool_function)
985
+ new_parameters = [
986
+ inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
987
+ ] + list(original_signature.parameters.values())
988
+ new_signature = original_signature.replace(parameters=new_parameters)
989
+ SpecificTool.forward.__signature__ = new_signature
990
+ SpecificTool.__name__ = class_name
991
+ return SpecificTool()
992
+
993
+
994
+ HUGGINGFACE_DEFAULT_TOOLS = {}
995
+
996
+
997
+ class Toolbox:
998
+ """
999
+ The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
1000
+ manage them.
1001
+
1002
+ Args:
1003
+ tools (`List[Tool]`):
1004
+ The list of tools to instantiate the toolbox with
1005
+ add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
1006
+ Whether to add the tools available within `transformers` to the toolbox.
1007
+ """
1008
+
1009
+ def __init__(self, tools: List[Tool], add_base_tools: bool = False):
1010
+ self._tools = {tool.name: tool for tool in tools}
1011
+ if add_base_tools:
1012
+ self.add_base_tools()
1013
+
1014
+ def add_base_tools(self, add_python_interpreter: bool = False):
1015
+ global HUGGINGFACE_DEFAULT_TOOLS
1016
+ if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
1017
+ HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
1018
+ for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
1019
+ if tool.name != "python_interpreter" or add_python_interpreter:
1020
+ self.add_tool(tool)
1021
+
1022
+ @property
1023
+ def tools(self) -> Dict[str, Tool]:
1024
+ """Get all tools currently in the toolbox"""
1025
+ return self._tools
1026
+
1027
+ def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
1028
+ """
1029
+ Returns the description of all tools in the toolbox
1030
+
1031
+ Args:
1032
+ tool_description_template (`str`, *optional*):
1033
+ The template to use to describe the tools. If not provided, the default template will be used.
1034
+ """
1035
+ return "\n".join(
1036
+ [
1037
+ get_tool_description_with_args(tool, tool_description_template)
1038
+ for tool in self._tools.values()
1039
+ ]
1040
+ )
1041
+
1042
+ def add_tool(self, tool: Tool):
1043
+ """
1044
+ Adds a tool to the toolbox
1045
+
1046
+ Args:
1047
+ tool (`Tool`):
1048
+ The tool to add to the toolbox.
1049
+ """
1050
+ if tool.name in self._tools:
1051
+ raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
1052
+ self._tools[tool.name] = tool
1053
+
1054
+ def remove_tool(self, tool_name: str):
1055
+ """
1056
+ Removes a tool from the toolbox
1057
+
1058
+ Args:
1059
+ tool_name (`str`):
1060
+ The tool to remove from the toolbox.
1061
+ """
1062
+ if tool_name not in self._tools:
1063
+ raise KeyError(
1064
+ f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
1065
+ )
1066
+ del self._tools[tool_name]
1067
+
1068
+ def update_tool(self, tool: Tool):
1069
+ """
1070
+ Updates a tool in the toolbox according to its name.
1071
+
1072
+ Args:
1073
+ tool (`Tool`):
1074
+ The tool to update to the toolbox.
1075
+ """
1076
+ if tool.name not in self._tools:
1077
+ raise KeyError(
1078
+ f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
1079
+ )
1080
+ self._tools[tool.name] = tool
1081
+
1082
+ def clear_toolbox(self):
1083
+ """Clears the toolbox"""
1084
+ self._tools = {}
1085
+
1086
+ # def _load_tools_if_needed(self):
1087
+ # for name, tool in self._tools.items():
1088
+ # if not isinstance(tool, Tool):
1089
+ # task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
1090
+ # self._tools[name] = load_tool(task_or_repo_id)
1091
+
1092
+ def __repr__(self):
1093
+ toolbox_description = "Toolbox contents:\n"
1094
+ for tool in self._tools.values():
1095
+ toolbox_description += f"\t{tool.name}: {tool.description}\n"
1096
+ return toolbox_description
1097
+
1098
+ __all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]
types.py ADDED
@@ -0,0 +1,270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 HuggingFace Inc.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ import os
16
+ import pathlib
17
+ import tempfile
18
+ import uuid
19
+
20
+ import numpy as np
21
+
22
+ from transformers.utils import (
23
+ is_soundfile_availble,
24
+ is_torch_available,
25
+ is_vision_available,
26
+ )
27
+ import logging
28
+
29
+
30
+ logger = logging.getLogger(__name__)
31
+
32
+ if is_vision_available():
33
+ from PIL import Image
34
+ from PIL.Image import Image as ImageType
35
+ else:
36
+ ImageType = object
37
+
38
+ if is_torch_available():
39
+ import torch
40
+ from torch import Tensor
41
+ else:
42
+ Tensor = object
43
+
44
+ if is_soundfile_availble():
45
+ import soundfile as sf
46
+
47
+
48
+ class AgentType:
49
+ """
50
+ Abstract class to be reimplemented to define types that can be returned by agents.
51
+
52
+ These objects serve three purposes:
53
+
54
+ - They behave as they were the type they're meant to be, e.g., a string for text, a PIL.Image for images
55
+ - They can be stringified: str(object) in order to return a string defining the object
56
+ - They should be displayed correctly in ipython notebooks/colab/jupyter
57
+ """
58
+
59
+ def __init__(self, value):
60
+ self._value = value
61
+
62
+ def __str__(self):
63
+ return self.to_string()
64
+
65
+ def to_raw(self):
66
+ logger.error(
67
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
68
+ )
69
+ return self._value
70
+
71
+ def to_string(self) -> str:
72
+ logger.error(
73
+ "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
74
+ )
75
+ return str(self._value)
76
+
77
+
78
+ class AgentText(AgentType, str):
79
+ """
80
+ Text type returned by the agent. Behaves as a string.
81
+ """
82
+
83
+ def to_raw(self):
84
+ return self._value
85
+
86
+ def to_string(self):
87
+ return str(self._value)
88
+
89
+
90
+ class AgentImage(AgentType, ImageType):
91
+ """
92
+ Image type returned by the agent. Behaves as a PIL.Image.
93
+ """
94
+
95
+ def __init__(self, value):
96
+ AgentType.__init__(self, value)
97
+ ImageType.__init__(self)
98
+
99
+ if not is_vision_available():
100
+ raise ImportError("PIL must be installed in order to handle images.")
101
+
102
+ self._path = None
103
+ self._raw = None
104
+ self._tensor = None
105
+
106
+ if isinstance(value, ImageType):
107
+ self._raw = value
108
+ elif isinstance(value, (str, pathlib.Path)):
109
+ self._path = value
110
+ elif isinstance(value, torch.Tensor):
111
+ self._tensor = value
112
+ elif isinstance(value, np.ndarray):
113
+ self._tensor = torch.from_numpy(value)
114
+ else:
115
+ raise TypeError(
116
+ f"Unsupported type for {self.__class__.__name__}: {type(value)}"
117
+ )
118
+
119
+ def _ipython_display_(self, include=None, exclude=None):
120
+ """
121
+ Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
122
+ """
123
+ from IPython.display import Image, display
124
+
125
+ display(Image(self.to_string()))
126
+
127
+ def to_raw(self):
128
+ """
129
+ Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
130
+ """
131
+ if self._raw is not None:
132
+ return self._raw
133
+
134
+ if self._path is not None:
135
+ self._raw = Image.open(self._path)
136
+ return self._raw
137
+
138
+ if self._tensor is not None:
139
+ array = self._tensor.cpu().detach().numpy()
140
+ return Image.fromarray((255 - array * 255).astype(np.uint8))
141
+
142
+ def to_string(self):
143
+ """
144
+ Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
145
+ version of the image.
146
+ """
147
+ if self._path is not None:
148
+ return self._path
149
+
150
+ if self._raw is not None:
151
+ directory = tempfile.mkdtemp()
152
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
153
+ self._raw.save(self._path, format="png")
154
+ return self._path
155
+
156
+ if self._tensor is not None:
157
+ array = self._tensor.cpu().detach().numpy()
158
+
159
+ # There is likely simpler than load into image into save
160
+ img = Image.fromarray((255 - array * 255).astype(np.uint8))
161
+
162
+ directory = tempfile.mkdtemp()
163
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".png")
164
+ img.save(self._path, format="png")
165
+
166
+ return self._path
167
+
168
+ def save(self, output_bytes, format: str = None, **params):
169
+ """
170
+ Saves the image to a file.
171
+ Args:
172
+ output_bytes (bytes): The output bytes to save the image to.
173
+ format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
174
+ **params: Additional parameters to pass to PIL.Image.save.
175
+ """
176
+ img = self.to_raw()
177
+ img.save(output_bytes, format=format, **params)
178
+
179
+
180
+ class AgentAudio(AgentType, str):
181
+ """
182
+ Audio type returned by the agent.
183
+ """
184
+
185
+ def __init__(self, value, samplerate=16_000):
186
+ super().__init__(value)
187
+
188
+ if not is_soundfile_availble():
189
+ raise ImportError("soundfile must be installed in order to handle audio.")
190
+
191
+ self._path = None
192
+ self._tensor = None
193
+
194
+ self.samplerate = samplerate
195
+ if isinstance(value, (str, pathlib.Path)):
196
+ self._path = value
197
+ elif is_torch_available() and isinstance(value, torch.Tensor):
198
+ self._tensor = value
199
+ elif isinstance(value, tuple):
200
+ self.samplerate = value[0]
201
+ if isinstance(value[1], np.ndarray):
202
+ self._tensor = torch.from_numpy(value[1])
203
+ else:
204
+ self._tensor = torch.tensor(value[1])
205
+ else:
206
+ raise ValueError(f"Unsupported audio type: {type(value)}")
207
+
208
+ def _ipython_display_(self, include=None, exclude=None):
209
+ """
210
+ Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
211
+ """
212
+ from IPython.display import Audio, display
213
+
214
+ display(Audio(self.to_string(), rate=self.samplerate))
215
+
216
+ def to_raw(self):
217
+ """
218
+ Returns the "raw" version of that object. It is a `torch.Tensor` object.
219
+ """
220
+ if self._tensor is not None:
221
+ return self._tensor
222
+
223
+ if self._path is not None:
224
+ tensor, self.samplerate = sf.read(self._path)
225
+ self._tensor = torch.tensor(tensor)
226
+ return self._tensor
227
+
228
+ def to_string(self):
229
+ """
230
+ Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
231
+ version of the audio.
232
+ """
233
+ if self._path is not None:
234
+ return self._path
235
+
236
+ if self._tensor is not None:
237
+ directory = tempfile.mkdtemp()
238
+ self._path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
239
+ sf.write(self._path, self._tensor, samplerate=self.samplerate)
240
+ return self._path
241
+
242
+
243
+ AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}
244
+ INSTANCE_TYPE_MAPPING = {str: AgentText, ImageType: AgentImage}
245
+
246
+ if is_torch_available():
247
+ INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
248
+
249
+
250
+ def handle_agent_inputs(*args, **kwargs):
251
+ args = [(arg.to_raw() if isinstance(arg, AgentType) else arg) for arg in args]
252
+ kwargs = {
253
+ k: (v.to_raw() if isinstance(v, AgentType) else v) for k, v in kwargs.items()
254
+ }
255
+ return args, kwargs
256
+
257
+
258
+ def handle_agent_outputs(output, output_type=None):
259
+ if output_type in AGENT_TYPE_MAPPING:
260
+ # If the class has defined outputs, we can map directly according to the class definition
261
+ decoded_outputs = AGENT_TYPE_MAPPING[output_type](output)
262
+ return decoded_outputs
263
+ else:
264
+ # If the class does not have defined output, then we map according to the type
265
+ for _k, _v in INSTANCE_TYPE_MAPPING.items():
266
+ if isinstance(output, _k):
267
+ return _v(output)
268
+ return output
269
+
270
+ __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]