m-ric commited on
Commit
3fb089b
·
verified ·
1 Parent(s): ade7903

Delete tools.py

Browse files
Files changed (1) hide show
  1. tools.py +0 -1098
tools.py DELETED
@@ -1,1098 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding=utf-8
3
-
4
- # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
5
- #
6
- # Licensed under the Apache License, Version 2.0 (the "License");
7
- # you may not use this file except in compliance with the License.
8
- # You may obtain a copy of the License at
9
- #
10
- # http://www.apache.org/licenses/LICENSE-2.0
11
- #
12
- # Unless required by applicable law or agreed to in writing, software
13
- # distributed under the License is distributed on an "AS IS" BASIS,
14
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
- # See the License for the specific language governing permissions and
16
- # limitations under the License.
17
- import ast
18
- import base64
19
- import importlib
20
- import inspect
21
- import io
22
- import json
23
- import os
24
- import tempfile
25
- import textwrap
26
- from functools import lru_cache, wraps
27
- from pathlib import Path
28
- from typing import Any, Callable, Dict, List, Optional, Union
29
-
30
- from huggingface_hub import (
31
- create_repo,
32
- get_collection,
33
- hf_hub_download,
34
- metadata_update,
35
- upload_folder,
36
- )
37
- from huggingface_hub.utils import RepositoryNotFoundError, build_hf_headers, get_session
38
- from packaging import version
39
-
40
- from transformers.dynamic_module_utils import (
41
- custom_object_save,
42
- get_class_from_dynamic_module,
43
- get_imports,
44
- )
45
- from transformers import AutoProcessor
46
- from transformers.utils import (
47
- CONFIG_NAME,
48
- TypeHintParsingException,
49
- cached_file,
50
- get_json_schema,
51
- is_accelerate_available,
52
- is_torch_available,
53
- is_vision_available,
54
- )
55
- from .types import ImageType, handle_agent_inputs, handle_agent_outputs
56
- import logging
57
-
58
- logger = logging.getLogger(__name__)
59
-
60
-
61
- if is_torch_available():
62
- import torch
63
-
64
- if is_accelerate_available():
65
- from accelerate import PartialState
66
- from accelerate.utils import send_to_device
67
-
68
-
69
- TOOL_CONFIG_FILE = "tool_config.json"
70
-
71
-
72
- def get_repo_type(repo_id, repo_type=None, **hub_kwargs):
73
- if repo_type is not None:
74
- return repo_type
75
- try:
76
- hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="space", **hub_kwargs)
77
- return "space"
78
- except RepositoryNotFoundError:
79
- try:
80
- hf_hub_download(repo_id, TOOL_CONFIG_FILE, repo_type="model", **hub_kwargs)
81
- return "model"
82
- except RepositoryNotFoundError:
83
- raise EnvironmentError(
84
- f"`{repo_id}` does not seem to be a valid repo identifier on the Hub."
85
- )
86
- except Exception:
87
- return "model"
88
- except Exception:
89
- return "space"
90
-
91
-
92
- def setup_default_tools():
93
- default_tools = {}
94
- main_module = importlib.import_module("transformers")
95
- tools_module = main_module.agents
96
-
97
- for task_name, tool_class_name in TOOL_MAPPING.items():
98
- tool_class = getattr(tools_module, tool_class_name)
99
- tool_instance = tool_class()
100
- default_tools[tool_class.name] = tool_instance
101
-
102
- return default_tools
103
-
104
-
105
- # docstyle-ignore
106
- APP_FILE_TEMPLATE = """from transformers import launch_gradio_demo
107
- from {module_name} import {class_name}
108
-
109
- launch_gradio_demo({class_name})
110
- """
111
-
112
-
113
- def validate_after_init(cls, do_validate_forward: bool = True):
114
- original_init = cls.__init__
115
-
116
- @wraps(original_init)
117
- def new_init(self, *args, **kwargs):
118
- original_init(self, *args, **kwargs)
119
- self.validate_arguments(do_validate_forward=do_validate_forward)
120
-
121
- cls.__init__ = new_init
122
- return cls
123
-
124
- def validate_forward_method_args(cls):
125
- """Validates that all names in forward method are properly defined.
126
- In particular it will check that all imports are done within the function."""
127
- if 'forward' not in cls.__dict__:
128
- return
129
-
130
- forward = cls.__dict__['forward']
131
- source_code = textwrap.dedent(inspect.getsource(forward))
132
- tree = ast.parse(source_code)
133
-
134
- # Get function arguments
135
- func_node = tree.body[0]
136
- arg_names = {arg.arg for arg in func_node.args.args}
137
-
138
-
139
- import builtins
140
- builtin_names = set(vars(builtins))
141
-
142
-
143
- # Find all used names that aren't arguments or self attributes
144
- class NameChecker(ast.NodeVisitor):
145
- def __init__(self):
146
- self.undefined_names = set()
147
- self.imports = {}
148
- self.from_imports = {}
149
-
150
- def visit_Import(self, node):
151
- """Handle simple imports like 'import datetime'."""
152
- for name in node.names:
153
- actual_name = name.asname or name.name
154
- self.imports[actual_name] = (name.name, actual_name)
155
-
156
- def visit_ImportFrom(self, node):
157
- """Handle from imports like 'from datetime import datetime'."""
158
- module = node.module or ''
159
- for name in node.names:
160
- actual_name = name.asname or name.name
161
- self.from_imports[actual_name] = (module, name.name, actual_name)
162
-
163
- def visit_Name(self, node):
164
- if (isinstance(node.ctx, ast.Load) and not (
165
- node.id == "tool" or
166
- node.id in builtin_names or
167
- node.id in arg_names or
168
- node.id == 'self'
169
- )):
170
- if node.id not in self.from_imports and node.id not in self.imports:
171
- self.undefined_names.add(node.id)
172
-
173
- def visit_Attribute(self, node):
174
- # Skip self.something
175
- if not (isinstance(node.value, ast.Name) and node.value.id == 'self'):
176
- self.generic_visit(node)
177
-
178
- checker = NameChecker()
179
- checker.visit(tree)
180
-
181
- if checker.undefined_names:
182
- raise ValueError(
183
- f"""The following names in forward method are not defined: {', '.join(checker.undefined_names)}.
184
- Make sure all imports and variables are defined within the method.
185
- For instance:
186
-
187
- """
188
- )
189
-
190
- AUTHORIZED_TYPES = [
191
- "string",
192
- "boolean",
193
- "integer",
194
- "number",
195
- "image",
196
- "audio",
197
- "any",
198
- ]
199
-
200
- CONVERSION_DICT = {"str": "string", "int": "integer", "float": "number"}
201
-
202
-
203
- class Tool:
204
- """
205
- A base class for the functions used by the agent. Subclass this and implement the `forward` method as well as the
206
- following class attributes:
207
-
208
- - **description** (`str`) -- A short description of what your tool does, the inputs it expects and the output(s) it
209
- will return. For instance 'This is a tool that downloads a file from a `url`. It takes the `url` as input, and
210
- returns the text contained in the file'.
211
- - **name** (`str`) -- A performative name that will be used for your tool in the prompt to the agent. For instance
212
- `"text-classifier"` or `"image_generator"`.
213
- - **inputs** (`Dict[str, Dict[str, Union[str, type]]]`) -- The dict of modalities expected for the inputs.
214
- It has one `type`key and a `description`key.
215
- This is used by `launch_gradio_demo` or to make a nice space from your tool, and also can be used in the generated
216
- description for your tool.
217
- - **output_type** (`type`) -- The type of the tool output. This is used by `launch_gradio_demo`
218
- or to make a nice space from your tool, and also can be used in the generated description for your tool.
219
-
220
- You can also override the method [`~Tool.setup`] if your tool has an expensive operation to perform before being
221
- usable (such as loading a model). [`~Tool.setup`] will be called the first time you use your tool, but not at
222
- instantiation.
223
- """
224
-
225
- name: str
226
- description: str
227
- inputs: Dict[str, Dict[str, Union[str, type]]]
228
- output_type: str
229
-
230
- def __init__(self, *args, **kwargs):
231
- self.is_initialized = False
232
-
233
- def __init_subclass__(cls, **kwargs):
234
- super().__init_subclass__(**kwargs)
235
- validate_forward_method_args(cls)
236
- validate_after_init(cls, do_validate_forward=False)
237
-
238
-
239
- def validate_arguments(self, do_validate_forward: bool = True):
240
- required_attributes = {
241
- "description": str,
242
- "name": str,
243
- "inputs": dict,
244
- "output_type": str,
245
- }
246
-
247
- for attr, expected_type in required_attributes.items():
248
- attr_value = getattr(self, attr, None)
249
- if attr_value is None:
250
- raise TypeError(f"You must set an attribute {attr}.")
251
- if not isinstance(attr_value, expected_type):
252
- raise TypeError(
253
- f"Attribute {attr} should have type {expected_type.__name__}, got {type(attr_value)} instead."
254
- )
255
- for input_name, input_content in self.inputs.items():
256
- assert isinstance(
257
- input_content, dict
258
- ), f"Input '{input_name}' should be a dictionary."
259
- assert (
260
- "type" in input_content and "description" in input_content
261
- ), f"Input '{input_name}' should have keys 'type' and 'description', has only {list(input_content.keys())}."
262
- if input_content["type"] not in AUTHORIZED_TYPES:
263
- raise Exception(
264
- f"Input '{input_name}': type '{input_content['type']}' is not an authorized value, should be one of {AUTHORIZED_TYPES}."
265
- )
266
-
267
- assert getattr(self, "output_type", None) in AUTHORIZED_TYPES
268
- if do_validate_forward:
269
- signature = inspect.signature(self.forward)
270
- if not set(signature.parameters.keys()) == set(self.inputs.keys()):
271
- raise Exception(
272
- "Tool's 'forward' method should take 'self' as its first argument, then its next arguments should match the keys of tool attribute 'inputs'."
273
- )
274
-
275
- def forward(self, *args, **kwargs):
276
- return NotImplementedError("Write this method in your subclass of `Tool`.")
277
-
278
- def __call__(self, *args, **kwargs):
279
- if not self.is_initialized:
280
- self.setup()
281
- args, kwargs = handle_agent_inputs(*args, **kwargs)
282
- outputs = self.forward(*args, **kwargs)
283
- return handle_agent_outputs(outputs, self.output_type)
284
-
285
- def setup(self):
286
- """
287
- Overwrite this method here for any operation that is expensive and needs to be executed before you start using
288
- your tool. Such as loading a big model.
289
- """
290
- self.is_initialized = True
291
-
292
- def save(self, output_dir):
293
- """
294
- Saves the relevant code files for your tool so it can be pushed to the Hub. This will copy the code of your
295
- tool in `output_dir` as well as autogenerate:
296
-
297
- - an `app.py` file so that your tool can be converted to a space
298
- - a `requirements.txt` containing the names of the module used by your tool (as detected when inspecting its
299
- code)
300
-
301
- You should only use this method to save tools that are defined in a separate module (not `__main__`).
302
-
303
- Args:
304
- output_dir (`str`): The folder in which you want to save your tool.
305
- """
306
- os.makedirs(output_dir, exist_ok=True)
307
- # Save module file
308
- if self.__module__ == "__main__":
309
- raise ValueError(
310
- f"We can't save the code defining {self} in {output_dir} as it's been defined in __main__. You "
311
- "have to put this code in a separate module so we can include it in the saved folder."
312
- )
313
- module_files = custom_object_save(self, output_dir)
314
-
315
- module_name = self.__class__.__module__
316
- last_module = module_name.split(".")[-1]
317
- full_name = f"{last_module}.{self.__class__.__name__}"
318
-
319
- # Save config file
320
- config_file = os.path.join(output_dir, "tool_config.json")
321
- if os.path.isfile(config_file):
322
- with open(config_file, "r", encoding="utf-8") as f:
323
- tool_config = json.load(f)
324
- else:
325
- tool_config = {}
326
-
327
- tool_config = {
328
- "tool_class": full_name,
329
- "description": self.description,
330
- "name": self.name,
331
- "inputs": self.inputs,
332
- "output_type": str(self.output_type),
333
- }
334
- with open(config_file, "w", encoding="utf-8") as f:
335
- f.write(json.dumps(tool_config, indent=2, sort_keys=True) + "\n")
336
-
337
- # Save app file
338
- app_file = os.path.join(output_dir, "app.py")
339
- with open(app_file, "w", encoding="utf-8") as f:
340
- f.write(
341
- APP_FILE_TEMPLATE.format(
342
- module_name=last_module, class_name=self.__class__.__name__
343
- )
344
- )
345
-
346
- # Save requirements file
347
- requirements_file = os.path.join(output_dir, "requirements.txt")
348
- imports = []
349
- for module in module_files:
350
- imports.extend(get_imports(module))
351
- imports = list(set(imports))
352
- with open(requirements_file, "w", encoding="utf-8") as f:
353
- f.write("\n".join(imports) + "\n")
354
-
355
- @classmethod
356
- def from_hub(
357
- cls,
358
- repo_id: str,
359
- token: Optional[str] = None,
360
- **kwargs,
361
- ):
362
- """
363
- Loads a tool defined on the Hub.
364
-
365
- <Tip warning={true}>
366
-
367
- Loading a tool from the Hub means that you'll download the tool and execute it locally.
368
- ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
369
- installing a package using pip/npm/apt.
370
-
371
- </Tip>
372
-
373
- Args:
374
- repo_id (`str`):
375
- The name of the repo on the Hub where your tool is defined.
376
- token (`str`, *optional*):
377
- The token to identify you on hf.co. If unset, will use the token generated when running
378
- `huggingface-cli login` (stored in `~/.huggingface`).
379
- kwargs (additional keyword arguments, *optional*):
380
- Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
381
- `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the
382
- others will be passed along to its init.
383
- """
384
- hub_kwargs_names = [
385
- "cache_dir",
386
- "force_download",
387
- "resume_download",
388
- "proxies",
389
- "revision",
390
- "repo_type",
391
- "subfolder",
392
- "local_files_only",
393
- ]
394
- hub_kwargs = {k: v for k, v in kwargs.items() if k in hub_kwargs_names}
395
-
396
- # Try to get the tool config first.
397
- hub_kwargs["repo_type"] = get_repo_type(repo_id, **hub_kwargs)
398
- resolved_config_file = cached_file(
399
- repo_id,
400
- TOOL_CONFIG_FILE,
401
- token=token,
402
- **hub_kwargs,
403
- _raise_exceptions_for_gated_repo=False,
404
- _raise_exceptions_for_missing_entries=False,
405
- _raise_exceptions_for_connection_errors=False,
406
- )
407
- is_tool_config = resolved_config_file is not None
408
- if resolved_config_file is None:
409
- resolved_config_file = cached_file(
410
- repo_id,
411
- CONFIG_NAME,
412
- token=token,
413
- **hub_kwargs,
414
- _raise_exceptions_for_gated_repo=False,
415
- _raise_exceptions_for_missing_entries=False,
416
- _raise_exceptions_for_connection_errors=False,
417
- )
418
- if resolved_config_file is None:
419
- raise EnvironmentError(
420
- f"{repo_id} does not appear to provide a valid configuration in `tool_config.json` or `config.json`."
421
- )
422
-
423
- with open(resolved_config_file, encoding="utf-8") as reader:
424
- config = json.load(reader)
425
-
426
- if not is_tool_config:
427
- if "custom_tool" not in config:
428
- raise EnvironmentError(
429
- f"{repo_id} does not provide a mapping to custom tools in its configuration `config.json`."
430
- )
431
- custom_tool = config["custom_tool"]
432
- else:
433
- custom_tool = config
434
-
435
- tool_class = custom_tool["tool_class"]
436
- tool_class = get_class_from_dynamic_module(
437
- tool_class, repo_id, token=token, **hub_kwargs
438
- )
439
-
440
- if len(tool_class.name) == 0:
441
- tool_class.name = custom_tool["name"]
442
- if tool_class.name != custom_tool["name"]:
443
- logger.warning(
444
- f"{tool_class.__name__} implements a different name in its configuration and class. Using the tool "
445
- "configuration name."
446
- )
447
- tool_class.name = custom_tool["name"]
448
-
449
- if len(tool_class.description) == 0:
450
- tool_class.description = custom_tool["description"]
451
- if tool_class.description != custom_tool["description"]:
452
- logger.warning(
453
- f"{tool_class.__name__} implements a different description in its configuration and class. Using the "
454
- "tool configuration description."
455
- )
456
- tool_class.description = custom_tool["description"]
457
-
458
- if tool_class.inputs != custom_tool["inputs"]:
459
- tool_class.inputs = custom_tool["inputs"]
460
- if tool_class.output_type != custom_tool["output_type"]:
461
- tool_class.output_type = custom_tool["output_type"]
462
-
463
- if not isinstance(tool_class.inputs, dict):
464
- tool_class.inputs = ast.literal_eval(tool_class.inputs)
465
-
466
- return tool_class(**kwargs)
467
-
468
- def push_to_hub(
469
- self,
470
- repo_id: str,
471
- commit_message: str = "Upload tool",
472
- private: Optional[bool] = None,
473
- token: Optional[Union[bool, str]] = None,
474
- create_pr: bool = False,
475
- ) -> str:
476
- """
477
- Upload the tool to the Hub.
478
-
479
- For this method to work properly, your tool must have been defined in a separate module (not `__main__`).
480
- For instance:
481
- ```
482
- from my_tool_module import MyTool
483
- my_tool = MyTool()
484
- my_tool.push_to_hub("my-username/my-space")
485
- ```
486
-
487
- Parameters:
488
- repo_id (`str`):
489
- The name of the repository you want to push your tool to. It should contain your organization name when
490
- pushing to a given organization.
491
- commit_message (`str`, *optional*, defaults to `"Upload tool"`):
492
- Message to commit while pushing.
493
- private (`bool`, *optional*):
494
- Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
495
- token (`bool` or `str`, *optional*):
496
- The token to use as HTTP bearer authorization for remote files. If unset, will use the token generated
497
- when running `huggingface-cli login` (stored in `~/.huggingface`).
498
- create_pr (`bool`, *optional*, defaults to `False`):
499
- Whether or not to create a PR with the uploaded files or directly commit.
500
- """
501
- repo_url = create_repo(
502
- repo_id=repo_id,
503
- token=token,
504
- private=private,
505
- exist_ok=True,
506
- repo_type="space",
507
- space_sdk="gradio",
508
- )
509
- repo_id = repo_url.repo_id
510
- metadata_update(repo_id, {"tags": ["tool"]}, repo_type="space")
511
-
512
- with tempfile.TemporaryDirectory() as work_dir:
513
- # Save all files.
514
- self.save(work_dir)
515
- logger.info(
516
- f"Uploading the following files to {repo_id}: {','.join(os.listdir(work_dir))}"
517
- )
518
- return upload_folder(
519
- repo_id=repo_id,
520
- commit_message=commit_message,
521
- folder_path=work_dir,
522
- token=token,
523
- create_pr=create_pr,
524
- repo_type="space",
525
- )
526
-
527
- @staticmethod
528
- def from_space(
529
- space_id: str,
530
- name: str,
531
- description: str,
532
- api_name: Optional[str] = None,
533
- token: Optional[str] = None,
534
- ):
535
- """
536
- Creates a [`Tool`] from a Space given its id on the Hub.
537
-
538
- Args:
539
- space_id (`str`):
540
- The id of the Space on the Hub.
541
- name (`str`):
542
- The name of the tool.
543
- description (`str`):
544
- The description of the tool.
545
- api_name (`str`, *optional*):
546
- The specific api_name to use, if the space has several tabs. If not precised, will default to the first available api.
547
- token (`str`, *optional*):
548
- Add your token to access private spaces or increase your GPU quotas.
549
- Returns:
550
- [`Tool`]:
551
- The Space, as a tool.
552
-
553
- Examples:
554
- ```
555
- image_generator = Tool.from_space(
556
- space_id="black-forest-labs/FLUX.1-schnell",
557
- name="image-generator",
558
- description="Generate an image from a prompt"
559
- )
560
- image = image_generator("Generate an image of a cool surfer in Tahiti")
561
- ```
562
- ```
563
- face_swapper = Tool.from_space(
564
- "tuan2308/face-swap",
565
- "face_swapper",
566
- "Tool that puts the face shown on the first image on the second image. You can give it paths to images.",
567
- )
568
- image = face_swapper('./aymeric.jpeg', './ruth.jpg')
569
- ```
570
- """
571
- from gradio_client import Client, handle_file
572
- from gradio_client.utils import is_http_url_like
573
-
574
- class SpaceToolWrapper(Tool):
575
- def __init__(
576
- self,
577
- space_id: str,
578
- name: str,
579
- description: str,
580
- api_name: Optional[str] = None,
581
- token: Optional[str] = None,
582
- ):
583
- self.client = Client(space_id, hf_token=token)
584
- self.name = name
585
- self.description = description
586
- space_description = self.client.view_api(
587
- return_format="dict", print_info=False
588
- )["named_endpoints"]
589
-
590
- # If api_name is not defined, take the first of the available APIs for this space
591
- if api_name is None:
592
- api_name = list(space_description.keys())[0]
593
- logger.warning(
594
- f"Since `api_name` was not defined, it was automatically set to the first avilable API: `{api_name}`."
595
- )
596
- self.api_name = api_name
597
-
598
- try:
599
- space_description_api = space_description[api_name]
600
- except KeyError:
601
- raise KeyError(
602
- f"Could not find specified {api_name=} among available api names."
603
- )
604
-
605
- self.inputs = {}
606
- for parameter in space_description_api["parameters"]:
607
- if not parameter["parameter_has_default"]:
608
- parameter_type = parameter["type"]["type"]
609
- if parameter_type == "object":
610
- parameter_type = "any"
611
- self.inputs[parameter["parameter_name"]] = {
612
- "type": parameter_type,
613
- "description": parameter["python_type"]["description"],
614
- }
615
- output_component = space_description_api["returns"][0]["component"]
616
- if output_component == "Image":
617
- self.output_type = "image"
618
- elif output_component == "Audio":
619
- self.output_type = "audio"
620
- else:
621
- self.output_type = "any"
622
-
623
- def sanitize_argument_for_prediction(self, arg):
624
- if isinstance(arg, ImageType):
625
- temp_file = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
626
- arg.save(temp_file.name)
627
- arg = temp_file.name
628
- if (
629
- isinstance(arg, (str, Path))
630
- and Path(arg).exists()
631
- and Path(arg).is_file()
632
- ) or is_http_url_like(arg):
633
- arg = handle_file(arg)
634
- return arg
635
-
636
- def forward(self, *args, **kwargs):
637
- # Preprocess args and kwargs:
638
- args = list(args)
639
- for i, arg in enumerate(args):
640
- args[i] = self.sanitize_argument_for_prediction(arg)
641
- for arg_name, arg in kwargs.items():
642
- kwargs[arg_name] = self.sanitize_argument_for_prediction(arg)
643
-
644
- output = self.client.predict(*args, api_name=self.api_name, **kwargs)
645
- if isinstance(output, tuple) or isinstance(output, list):
646
- return output[
647
- 0
648
- ] # Sometime the space also returns the generation seed, in which case the result is at index 0
649
- return output
650
-
651
- return SpaceToolWrapper(
652
- space_id, name, description, api_name=api_name, token=token
653
- )
654
-
655
- @staticmethod
656
- def from_gradio(gradio_tool):
657
- """
658
- Creates a [`Tool`] from a gradio tool.
659
- """
660
- import inspect
661
-
662
- class GradioToolWrapper(Tool):
663
- def __init__(self, _gradio_tool):
664
- self.name = _gradio_tool.name
665
- self.description = _gradio_tool.description
666
- self.output_type = "string"
667
- self._gradio_tool = _gradio_tool
668
- func_args = list(inspect.signature(_gradio_tool.run).parameters.items())
669
- self.inputs = {
670
- key: {"type": CONVERSION_DICT[value.annotation], "description": ""}
671
- for key, value in func_args
672
- }
673
- self.forward = self._gradio_tool.run
674
-
675
- return GradioToolWrapper(gradio_tool)
676
-
677
- @staticmethod
678
- def from_langchain(langchain_tool):
679
- """
680
- Creates a [`Tool`] from a langchain tool.
681
- """
682
-
683
- class LangChainToolWrapper(Tool):
684
- def __init__(self, _langchain_tool):
685
- self.name = _langchain_tool.name.lower()
686
- self.description = _langchain_tool.description
687
- self.inputs = _langchain_tool.args.copy()
688
- for input_content in self.inputs.values():
689
- if "title" in input_content:
690
- input_content.pop("title")
691
- input_content["description"] = ""
692
- self.output_type = "string"
693
- self.langchain_tool = _langchain_tool
694
-
695
- def forward(self, *args, **kwargs):
696
- tool_input = kwargs.copy()
697
- for index, argument in enumerate(args):
698
- if index < len(self.inputs):
699
- input_key = next(iter(self.inputs))
700
- tool_input[input_key] = argument
701
- return self.langchain_tool.run(tool_input)
702
-
703
- return LangChainToolWrapper(langchain_tool)
704
-
705
-
706
- DEFAULT_TOOL_DESCRIPTION_TEMPLATE = """
707
- - {{ tool.name }}: {{ tool.description }}
708
- Takes inputs: {{tool.inputs}}
709
- Returns an output of type: {{tool.output_type}}
710
- """
711
-
712
-
713
- def get_tool_description_with_args(
714
- tool: Tool, description_template: Optional[str] = None
715
- ) -> str:
716
- if description_template is None:
717
- description_template = DEFAULT_TOOL_DESCRIPTION_TEMPLATE
718
- compiled_template = compile_jinja_template(description_template)
719
- rendered = compiled_template.render(
720
- tool=tool,
721
- )
722
- return rendered
723
-
724
-
725
- @lru_cache
726
- def compile_jinja_template(template):
727
- try:
728
- import jinja2
729
- from jinja2.exceptions import TemplateError
730
- from jinja2.sandbox import ImmutableSandboxedEnvironment
731
- except ImportError:
732
- raise ImportError("template requires jinja2 to be installed.")
733
-
734
- if version.parse(jinja2.__version__) < version.parse("3.1.0"):
735
- raise ImportError(
736
- "template requires jinja2>=3.1.0 to be installed. Your version is "
737
- f"{jinja2.__version__}."
738
- )
739
-
740
- def raise_exception(message):
741
- raise TemplateError(message)
742
-
743
- jinja_env = ImmutableSandboxedEnvironment(trim_blocks=True, lstrip_blocks=True)
744
- jinja_env.globals["raise_exception"] = raise_exception
745
- return jinja_env.from_string(template)
746
-
747
-
748
- def launch_gradio_demo(tool_class: Tool):
749
- """
750
- Launches a gradio demo for a tool. The corresponding tool class needs to properly implement the class attributes
751
- `inputs` and `output_type`.
752
-
753
- Args:
754
- tool_class (`type`): The class of the tool for which to launch the demo.
755
- """
756
- try:
757
- import gradio as gr
758
- except ImportError:
759
- raise ImportError(
760
- "Gradio should be installed in order to launch a gradio demo."
761
- )
762
-
763
- tool = tool_class()
764
-
765
- def fn(*args, **kwargs):
766
- return tool(*args, **kwargs)
767
-
768
- TYPE_TO_COMPONENT_CLASS_MAPPING = {
769
- "image": gr.Image,
770
- "audio": gr.Audio,
771
- "string": gr.Textbox,
772
- "integer": gr.Textbox,
773
- "number": gr.Textbox,
774
- }
775
-
776
- gradio_inputs = []
777
- for input_name, input_details in tool_class.inputs.items():
778
- input_gradio_component_class = TYPE_TO_COMPONENT_CLASS_MAPPING[
779
- input_details["type"]
780
- ]
781
- new_component = input_gradio_component_class(label=input_name)
782
- gradio_inputs.append(new_component)
783
-
784
- output_gradio_componentclass = TYPE_TO_COMPONENT_CLASS_MAPPING[
785
- tool_class.output_type
786
- ]
787
- gradio_output = output_gradio_componentclass(label=input_name)
788
-
789
- gr.Interface(
790
- fn=fn,
791
- inputs=gradio_inputs,
792
- outputs=gradio_output,
793
- title=tool_class.__name__,
794
- article=tool.description,
795
- ).launch()
796
-
797
-
798
- TOOL_MAPPING = {
799
- "python_interpreter": "PythonInterpreterTool",
800
- "web_search": "DuckDuckGoSearchTool",
801
- }
802
-
803
-
804
- def load_tool(task_or_repo_id, model_repo_id=None, token=None, **kwargs):
805
- """
806
- Main function to quickly load a tool, be it on the Hub or in the Transformers library.
807
-
808
- <Tip warning={true}>
809
-
810
- Loading a tool means that you'll download the tool and execute it locally.
811
- ALWAYS inspect the tool you're downloading before loading it within your runtime, as you would do when
812
- installing a package using pip/npm/apt.
813
-
814
- </Tip>
815
-
816
- Args:
817
- task_or_repo_id (`str`):
818
- The task for which to load the tool or a repo ID of a tool on the Hub. Tasks implemented in Transformers
819
- are:
820
-
821
- - `"document_question_answering"`
822
- - `"image_question_answering"`
823
- - `"speech_to_text"`
824
- - `"text_to_speech"`
825
- - `"translation"`
826
-
827
- model_repo_id (`str`, *optional*):
828
- Use this argument to use a different model than the default one for the tool you selected.
829
- token (`str`, *optional*):
830
- The token to identify you on hf.co. If unset, will use the token generated when running `huggingface-cli
831
- login` (stored in `~/.huggingface`).
832
- kwargs (additional keyword arguments, *optional*):
833
- Additional keyword arguments that will be split in two: all arguments relevant to the Hub (such as
834
- `cache_dir`, `revision`, `subfolder`) will be used when downloading the files for your tool, and the others
835
- will be passed along to its init.
836
- """
837
- if task_or_repo_id in TOOL_MAPPING:
838
- tool_class_name = TOOL_MAPPING[task_or_repo_id]
839
- main_module = importlib.import_module("agents")
840
- tools_module = main_module
841
- tool_class = getattr(tools_module, tool_class_name)
842
- return tool_class(model_repo_id, token=token, **kwargs)
843
- else:
844
- logger.warning_once(
845
- f"You're loading a tool from the Hub from {model_repo_id}. Please make sure this is a source that you "
846
- f"trust as the code within that tool will be executed on your machine. Always verify the code of "
847
- f"the tools that you load. We recommend specifying a `revision` to ensure you're loading the "
848
- f"code that you have checked."
849
- )
850
- return Tool.from_hub(
851
- task_or_repo_id, model_repo_id=model_repo_id, token=token, **kwargs
852
- )
853
-
854
-
855
- def add_description(description):
856
- """
857
- A decorator that adds a description to a function.
858
- """
859
-
860
- def inner(func):
861
- func.description = description
862
- func.name = func.__name__
863
- return func
864
-
865
- return inner
866
-
867
-
868
- ## Will move to the Hub
869
- class EndpointClient:
870
- def __init__(self, endpoint_url: str, token: Optional[str] = None):
871
- self.headers = {
872
- **build_hf_headers(token=token),
873
- "Content-Type": "application/json",
874
- }
875
- self.endpoint_url = endpoint_url
876
-
877
- @staticmethod
878
- def encode_image(image):
879
- _bytes = io.BytesIO()
880
- image.save(_bytes, format="PNG")
881
- b64 = base64.b64encode(_bytes.getvalue())
882
- return b64.decode("utf-8")
883
-
884
- @staticmethod
885
- def decode_image(raw_image):
886
- if not is_vision_available():
887
- raise ImportError(
888
- "This tool returned an image but Pillow is not installed. Please install it (`pip install Pillow`)."
889
- )
890
-
891
- from PIL import Image
892
-
893
- b64 = base64.b64decode(raw_image)
894
- _bytes = io.BytesIO(b64)
895
- return Image.open(_bytes)
896
-
897
- def __call__(
898
- self,
899
- inputs: Optional[Union[str, Dict, List[str], List[List[str]]]] = None,
900
- params: Optional[Dict] = None,
901
- data: Optional[bytes] = None,
902
- output_image: bool = False,
903
- ) -> Any:
904
- # Build payload
905
- payload = {}
906
- if inputs:
907
- payload["inputs"] = inputs
908
- if params:
909
- payload["parameters"] = params
910
-
911
- # Make API call
912
- response = get_session().post(
913
- self.endpoint_url, headers=self.headers, json=payload, data=data
914
- )
915
-
916
- # By default, parse the response for the user.
917
- if output_image:
918
- return self.decode_image(response.content)
919
- else:
920
- return response.json()
921
-
922
-
923
- class ToolCollection:
924
- """
925
- Tool collections enable loading all Spaces from a collection in order to be added to the agent's toolbox.
926
-
927
- > [!NOTE]
928
- > Only Spaces will be fetched, so you can feel free to add models and datasets to your collection if you'd
929
- > like for this collection to showcase them.
930
-
931
- Args:
932
- collection_slug (str):
933
- The collection slug referencing the collection.
934
- token (str, *optional*):
935
- The authentication token if the collection is private.
936
-
937
- Example:
938
-
939
- ```py
940
- >>> from transformers import ToolCollection, CodeAgent
941
-
942
- >>> image_tool_collection = ToolCollection(collection_slug="huggingface-tools/diffusion-tools-6630bb19a942c2306a2cdb6f")
943
- >>> agent = CodeAgent(tools=[*image_tool_collection.tools], add_base_tools=True)
944
-
945
- >>> agent.run("Please draw me a picture of rivers and lakes.")
946
- ```
947
- """
948
-
949
- def __init__(self, collection_slug: str, token: Optional[str] = None):
950
- self._collection = get_collection(collection_slug, token=token)
951
- self._hub_repo_ids = {
952
- item.item_id for item in self._collection.items if item.item_type == "space"
953
- }
954
- self.tools = {Tool.from_hub(repo_id) for repo_id in self._hub_repo_ids}
955
-
956
-
957
- def tool(tool_function: Callable) -> Tool:
958
- """
959
- Converts a function into an instance of a Tool subclass.
960
-
961
- Args:
962
- tool_function: Your function. Should have type hints for each input and a type hint for the output.
963
- Should also have a docstring description including an 'Args:' part where each argument is described.
964
- """
965
- parameters = get_json_schema(tool_function)["function"]
966
- if "return" not in parameters:
967
- raise TypeHintParsingException(
968
- "Tool return type not found: make sure your function has a return type hint!"
969
- )
970
- class_name = f"{parameters['name'].capitalize()}Tool"
971
- if parameters["return"]["type"] == "object":
972
- parameters["return"]["type"] = "any"
973
-
974
- class SpecificTool(Tool):
975
- name = parameters["name"]
976
- description = parameters["description"]
977
- inputs = parameters["parameters"]["properties"]
978
- output_type = parameters["return"]["type"]
979
-
980
- @wraps(tool_function)
981
- def forward(self, *args, **kwargs):
982
- return tool_function(*args, **kwargs)
983
-
984
- original_signature = inspect.signature(tool_function)
985
- new_parameters = [
986
- inspect.Parameter("self", inspect.Parameter.POSITIONAL_OR_KEYWORD)
987
- ] + list(original_signature.parameters.values())
988
- new_signature = original_signature.replace(parameters=new_parameters)
989
- SpecificTool.forward.__signature__ = new_signature
990
- SpecificTool.__name__ = class_name
991
- return SpecificTool()
992
-
993
-
994
- HUGGINGFACE_DEFAULT_TOOLS = {}
995
-
996
-
997
- class Toolbox:
998
- """
999
- The toolbox contains all tools that the agent can perform operations with, as well as a few methods to
1000
- manage them.
1001
-
1002
- Args:
1003
- tools (`List[Tool]`):
1004
- The list of tools to instantiate the toolbox with
1005
- add_base_tools (`bool`, defaults to `False`, *optional*, defaults to `False`):
1006
- Whether to add the tools available within `transformers` to the toolbox.
1007
- """
1008
-
1009
- def __init__(self, tools: List[Tool], add_base_tools: bool = False):
1010
- self._tools = {tool.name: tool for tool in tools}
1011
- if add_base_tools:
1012
- self.add_base_tools()
1013
-
1014
- def add_base_tools(self, add_python_interpreter: bool = False):
1015
- global HUGGINGFACE_DEFAULT_TOOLS
1016
- if len(HUGGINGFACE_DEFAULT_TOOLS.keys()) == 0:
1017
- HUGGINGFACE_DEFAULT_TOOLS = setup_default_tools()
1018
- for tool in HUGGINGFACE_DEFAULT_TOOLS.values():
1019
- if tool.name != "python_interpreter" or add_python_interpreter:
1020
- self.add_tool(tool)
1021
-
1022
- @property
1023
- def tools(self) -> Dict[str, Tool]:
1024
- """Get all tools currently in the toolbox"""
1025
- return self._tools
1026
-
1027
- def show_tool_descriptions(self, tool_description_template: Optional[str] = None) -> str:
1028
- """
1029
- Returns the description of all tools in the toolbox
1030
-
1031
- Args:
1032
- tool_description_template (`str`, *optional*):
1033
- The template to use to describe the tools. If not provided, the default template will be used.
1034
- """
1035
- return "\n".join(
1036
- [
1037
- get_tool_description_with_args(tool, tool_description_template)
1038
- for tool in self._tools.values()
1039
- ]
1040
- )
1041
-
1042
- def add_tool(self, tool: Tool):
1043
- """
1044
- Adds a tool to the toolbox
1045
-
1046
- Args:
1047
- tool (`Tool`):
1048
- The tool to add to the toolbox.
1049
- """
1050
- if tool.name in self._tools:
1051
- raise KeyError(f"Error: tool '{tool.name}' already exists in the toolbox.")
1052
- self._tools[tool.name] = tool
1053
-
1054
- def remove_tool(self, tool_name: str):
1055
- """
1056
- Removes a tool from the toolbox
1057
-
1058
- Args:
1059
- tool_name (`str`):
1060
- The tool to remove from the toolbox.
1061
- """
1062
- if tool_name not in self._tools:
1063
- raise KeyError(
1064
- f"Error: tool {tool_name} not found in toolbox for removal, should be instead one of {list(self._tools.keys())}."
1065
- )
1066
- del self._tools[tool_name]
1067
-
1068
- def update_tool(self, tool: Tool):
1069
- """
1070
- Updates a tool in the toolbox according to its name.
1071
-
1072
- Args:
1073
- tool (`Tool`):
1074
- The tool to update to the toolbox.
1075
- """
1076
- if tool.name not in self._tools:
1077
- raise KeyError(
1078
- f"Error: tool {tool.name} not found in toolbox for update, should be instead one of {list(self._tools.keys())}."
1079
- )
1080
- self._tools[tool.name] = tool
1081
-
1082
- def clear_toolbox(self):
1083
- """Clears the toolbox"""
1084
- self._tools = {}
1085
-
1086
- # def _load_tools_if_needed(self):
1087
- # for name, tool in self._tools.items():
1088
- # if not isinstance(tool, Tool):
1089
- # task_or_repo_id = tool.task if tool.repo_id is None else tool.repo_id
1090
- # self._tools[name] = load_tool(task_or_repo_id)
1091
-
1092
- def __repr__(self):
1093
- toolbox_description = "Toolbox contents:\n"
1094
- for tool in self._tools.values():
1095
- toolbox_description += f"\t{tool.name}: {tool.description}\n"
1096
- return toolbox_description
1097
-
1098
- __all__ = ["AUTHORIZED_TYPES", "Tool", "tool", "load_tool", "launch_gradio_demo", "Toolbox"]