m-ric HF staff committed on
Commit
17c4fb7
·
verified ·
1 Parent(s): 3fb089b

Delete types.py

Browse files
Files changed (1) hide show
  1. types.py +0 -270
types.py DELETED
@@ -1,270 +0,0 @@
1
- # coding=utf-8
2
- # Copyright 2024 HuggingFace Inc.
3
- #
4
- # Licensed under the Apache License, Version 2.0 (the "License");
5
- # you may not use this file except in compliance with the License.
6
- # You may obtain a copy of the License at
7
- #
8
- # http://www.apache.org/licenses/LICENSE-2.0
9
- #
10
- # Unless required by applicable law or agreed to in writing, software
11
- # distributed under the License is distributed on an "AS IS" BASIS,
12
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
- # See the License for the specific language governing permissions and
14
- # limitations under the License.
15
- import os
16
- import pathlib
17
- import tempfile
18
- import uuid
19
-
20
- import numpy as np
21
-
22
- from transformers.utils import (
23
- is_soundfile_availble,
24
- is_torch_available,
25
- is_vision_available,
26
- )
27
- import logging
28
-
29
-
30
- logger = logging.getLogger(__name__)
31
-
32
- if is_vision_available():
33
- from PIL import Image
34
- from PIL.Image import Image as ImageType
35
- else:
36
- ImageType = object
37
-
38
- if is_torch_available():
39
- import torch
40
- from torch import Tensor
41
- else:
42
- Tensor = object
43
-
44
- if is_soundfile_availble():
45
- import soundfile as sf
46
-
47
-
48
class AgentType:
    """
    Base class for the values an agent can return; subclass it to define a concrete type.

    Objects of these types serve three purposes:

    - They behave as if they were the type they stand for, e.g. a string for text, a PIL.Image for images
    - They can be stringified: str(object) returns a string describing the object
    - They should be displayed correctly in ipython notebooks/colab/jupyter
    """

    # Logged by the base implementations of `to_raw` and `to_string`, which
    # concrete subclasses are expected to override.
    _UNKNOWN_TYPE_MESSAGE = (
        "This is a raw AgentType of unknown type. Display in notebooks and string conversion will be unreliable"
    )

    def __init__(self, value):
        # The wrapped object is stored untouched.
        self._value = value

    def __str__(self):
        # Stringification is delegated so subclasses only need to override `to_string`.
        return self.to_string()

    def to_raw(self):
        """Log an error about the unknown type, then return the wrapped value unchanged."""
        logger.error(self._UNKNOWN_TYPE_MESSAGE)
        return self._value

    def to_string(self) -> str:
        """Log an error about the unknown type, then return `str()` of the wrapped value."""
        logger.error(self._UNKNOWN_TYPE_MESSAGE)
        return str(self._value)
76
-
77
-
78
class AgentText(AgentType, str):
    """
    Text type returned by the agent. It subclasses `str`, so it behaves as a string.
    """

    def to_string(self):
        """Return the wrapped value converted with `str()`."""
        text = str(self._value)
        return text

    def to_raw(self):
        """Return the wrapped value exactly as it was passed to the constructor."""
        raw_value = self._value
        return raw_value
88
-
89
-
90
class AgentImage(AgentType, ImageType):
    """
    Image type returned by the agent. Behaves as a PIL.Image.

    The constructor accepts a PIL image, a path (`str` or `pathlib.Path`), a
    `torch.Tensor` or a `numpy.ndarray`. The value is kept in whichever form it
    was given and converted lazily by `to_raw` / `to_string`.
    """

    def __init__(self, value):
        """
        Args:
            value: The image, as a PIL image, a path, a `torch.Tensor` or a `numpy.ndarray`.

        Raises:
            ImportError: If PIL is not installed, or if a numpy array is given while torch is not installed.
            TypeError: If `value` is of none of the supported types.
        """
        # Check the hard requirement first so nothing is initialised when PIL is missing.
        if not is_vision_available():
            raise ImportError("PIL must be installed in order to handle images.")

        AgentType.__init__(self, value)
        ImageType.__init__(self)

        self._path = None
        self._raw = None
        self._tensor = None

        if isinstance(value, ImageType):
            self._raw = value
        elif isinstance(value, (str, pathlib.Path)):
            self._path = value
        # `torch` is only bound at module level when it is installed, so guard
        # before touching it (same pattern as AgentAudio).
        elif is_torch_available() and isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, np.ndarray):
            if not is_torch_available():
                raise ImportError("torch must be installed in order to handle images given as numpy arrays.")
            self._tensor = torch.from_numpy(value)
        else:
            raise TypeError(
                f"Unsupported type for {self.__class__.__name__}: {type(value)}"
            )

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        # Local import: IPython is only needed when actually displaying. The name
        # `Image` here is IPython's, shadowing PIL's inside this method only.
        from IPython.display import Image, display

        display(Image(self.to_string()))

    def _write_temp_png(self, img):
        """Save `img` as a PNG in a fresh temporary directory, remember the path and return it."""
        directory = tempfile.mkdtemp()
        path = os.path.join(directory, str(uuid.uuid4()) + ".png")
        img.save(path, format="png")
        # Only record the path once the file has been written successfully.
        self._path = path
        return path

    def to_raw(self):
        """
        Returns the "raw" version of that object. In the case of an AgentImage, it is a PIL.Image.
        """
        if self._raw is not None:
            return self._raw

        if self._path is not None:
            # Opened once and cached for later calls.
            self._raw = Image.open(self._path)
            return self._raw

        if self._tensor is not None:
            array = self._tensor.cpu().detach().numpy()
            # NOTE(review): the 255 - x * 255 mapping presumably expects floats in
            # [0, 1] and inverts intensities - confirm this is intended.
            return Image.fromarray((255 - array * 255).astype(np.uint8))

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentImage, it is a path to the serialized
        version of the image.
        """
        if self._path is not None:
            # `_path` may be a pathlib.Path; `__str__` relies on this method
            # returning a real `str`, otherwise `str(obj)` raises TypeError.
            return str(self._path)

        if self._raw is not None:
            return self._write_temp_png(self._raw)

        if self._tensor is not None:
            # With no raw image and no path set, `to_raw` builds the image from the tensor.
            return self._write_temp_png(self.to_raw())

    def save(self, output_bytes, format: str = None, **params):
        """
        Saves the image through PIL.

        Args:
            output_bytes: Destination, forwarded as the first argument of `PIL.Image.Image.save`.
            format (str): The format to use for the output image. The format is the same as in PIL.Image.save.
            **params: Additional parameters to pass to PIL.Image.save.
        """
        img = self.to_raw()
        img.save(output_bytes, format=format, **params)
178
-
179
-
180
class AgentAudio(AgentType, str):
    """
    Audio type returned by the agent.

    The constructor accepts a path (`str` or `pathlib.Path`), a `torch.Tensor`,
    or a `(samplerate, data)` tuple. The value is converted lazily by `to_raw`
    (tensor) and `to_string` (path to a `.wav` file).
    """

    def __new__(cls, value, samplerate=16_000):
        # `str` is immutable, so its content is fixed in `__new__`. Without this
        # override `samplerate` would be forwarded to `str.__new__`, which raises
        # TypeError as soon as a sample rate is passed explicitly.
        return str.__new__(cls, value)

    def __init__(self, value, samplerate=16_000):
        """
        Args:
            value: The audio, as a path, a `torch.Tensor` or a `(samplerate, data)` tuple.
            samplerate: Sampling rate to use; overridden by the first item when `value` is a tuple.

        Raises:
            ImportError: If soundfile is not installed, or if a tuple is given while torch is not installed.
            ValueError: If `value` is of none of the supported types.
        """
        super().__init__(value)

        if not is_soundfile_availble():
            raise ImportError("soundfile must be installed in order to handle audio.")

        self._path = None
        self._tensor = None

        self.samplerate = samplerate
        if isinstance(value, (str, pathlib.Path)):
            self._path = value
        elif is_torch_available() and isinstance(value, torch.Tensor):
            self._tensor = value
        elif isinstance(value, tuple):
            # `torch` is only bound at module level when it is installed.
            if not is_torch_available():
                raise ImportError("torch must be installed in order to handle audio given as a tuple.")
            self.samplerate = value[0]
            if isinstance(value[1], np.ndarray):
                self._tensor = torch.from_numpy(value[1])
            else:
                self._tensor = torch.tensor(value[1])
        else:
            raise ValueError(f"Unsupported audio type: {type(value)}")

    def _ipython_display_(self, include=None, exclude=None):
        """
        Displays correctly this type in an ipython notebook (ipython, colab, jupyter, ...)
        """
        # Local import: IPython is only needed when actually displaying.
        from IPython.display import Audio, display

        display(Audio(self.to_string(), rate=self.samplerate))

    def to_raw(self):
        """
        Returns the "raw" version of that object. It is a `torch.Tensor` object.
        """
        if self._tensor is not None:
            return self._tensor

        if self._path is not None:
            # Reading the file also replaces the sample rate with the one read from it.
            tensor, self.samplerate = sf.read(self._path)
            self._tensor = torch.tensor(tensor)
            return self._tensor

    def to_string(self):
        """
        Returns the stringified version of that object. In the case of an AgentAudio, it is a path to the serialized
        version of the audio.
        """
        if self._path is not None:
            # `_path` may be a pathlib.Path; `__str__` relies on this method
            # returning a real `str`, otherwise `str(obj)` raises TypeError.
            return str(self._path)

        if self._tensor is not None:
            directory = tempfile.mkdtemp()
            path = os.path.join(directory, str(uuid.uuid4()) + ".wav")
            # Hand soundfile a plain CPU numpy array rather than the tensor itself.
            samples = self._tensor.detach().cpu().numpy()
            sf.write(path, samples, samplerate=self.samplerate)
            # Only record the path once the file has been written successfully.
            self._path = path
            return path
241
-
242
-
243
# Maps a declared output-type name to the wrapper class to instantiate.
AGENT_TYPE_MAPPING = {"string": AgentText, "image": AgentImage, "audio": AgentAudio}

# Maps raw Python types to wrapper classes; used when no output type is declared.
INSTANCE_TYPE_MAPPING = {str: AgentText}

# Without PIL, `ImageType` falls back to `object`. Registering it would make every
# non-string output match `isinstance(output, object)` and be sent to AgentImage,
# whose constructor then raises ImportError.
if is_vision_available():
    INSTANCE_TYPE_MAPPING[ImageType] = AgentImage

if is_torch_available():
    INSTANCE_TYPE_MAPPING[Tensor] = AgentAudio
248
-
249
-
250
def handle_agent_inputs(*args, **kwargs):
    """
    Replace every AgentType argument by its raw value, leaving other arguments untouched.

    Returns:
        A `(args, kwargs)` pair: the positional arguments as a list and the keyword arguments as a dict.
    """

    def _unwrap(item):
        # Only AgentType wrappers expose `to_raw`; anything else passes through.
        if isinstance(item, AgentType):
            return item.to_raw()
        return item

    raw_args = [_unwrap(positional) for positional in args]
    raw_kwargs = {name: _unwrap(keyword_value) for name, keyword_value in kwargs.items()}
    return raw_args, raw_kwargs
256
-
257
-
258
def handle_agent_outputs(output, output_type=None):
    """
    Wrap `output` in the matching agent type.

    Args:
        output: The value to wrap.
        output_type: Optional key of `AGENT_TYPE_MAPPING` naming the wrapper to use.

    Returns:
        The wrapped output, or `output` itself when no wrapper applies.
    """
    # A declared output type takes precedence: map directly from its name.
    if output_type in AGENT_TYPE_MAPPING:
        wrapper_cls = AGENT_TYPE_MAPPING[output_type]
        return wrapper_cls(output)

    # Otherwise pick the first registered raw type the output is an instance of.
    for raw_type, wrapper_cls in INSTANCE_TYPE_MAPPING.items():
        if isinstance(output, raw_type):
            return wrapper_cls(output)

    # Nothing matched: hand the value back untouched.
    return output
269
-
270
- __all__ = ["AgentType", "AgentImage", "AgentText", "AgentAudio"]