gmastrapas
commited on
Commit
•
5695d43
1
Parent(s):
3103208
feat: add option assume_text_inputs in sentence transformers
Browse files- custom_st.py +37 -19
custom_st.py
CHANGED
@@ -22,6 +22,7 @@ class Transformer(nn.Module):
|
|
22 |
model_args: Optional[Dict[str, Any]] = None,
|
23 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
24 |
image_processor_args: Optional[Dict[str, Any]] = None,
|
|
|
25 |
cache_dir: Optional[str] = None,
|
26 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
27 |
**_,
|
@@ -56,6 +57,8 @@ class Transformer(nn.Module):
|
|
56 |
image_processor_args (Dict[str, Any], optional): Additional image processor
|
57 |
configuration parameters to be passed to the Hugging Face Transformers
|
58 |
image processor
|
|
|
|
|
59 |
cache_dir (str, optional): The Hugging Face Hub cache directory
|
60 |
backend (str, optional): Computational backend, only 'torch' is supported
|
61 |
|
@@ -119,6 +122,7 @@ class Transformer(nn.Module):
|
|
119 |
cache_dir=cache_dir,
|
120 |
**image_processor_kwargs,
|
121 |
)
|
|
|
122 |
|
123 |
# No max_seq_length set. Try to infer from model
|
124 |
if max_seq_length is None:
|
@@ -151,26 +155,40 @@ class Transformer(nn.Module):
|
|
151 |
_images = []
|
152 |
_texts = []
|
153 |
_image_or_text_descriptors = []
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
_image_or_text_descriptors.append(
|
160 |
-
|
161 |
-
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
166 |
_image_or_text_descriptors.append(0)
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
|
|
|
|
|
|
|
|
174 |
|
175 |
encoding = {}
|
176 |
if len(_texts):
|
|
|
22 |
model_args: Optional[Dict[str, Any]] = None,
|
23 |
tokenizer_args: Optional[Dict[str, Any]] = None,
|
24 |
image_processor_args: Optional[Dict[str, Any]] = None,
|
25 |
+
assume_text_inputs: bool = False,
|
26 |
cache_dir: Optional[str] = None,
|
27 |
backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
|
28 |
**_,
|
|
|
57 |
image_processor_args (Dict[str, Any], optional): Additional image processor
|
58 |
configuration parameters to be passed to the Hugging Face Transformers
|
59 |
image processor
|
60 |
+
assume_text_inputs (bool, optional): If set to `True`, all inputs are
|
61 |
+
treated as texts. Defaults to `False`
|
62 |
cache_dir (str, optional): The Hugging Face Hub cache directory
|
63 |
backend (str, optional): Computational backend, only 'torch' is supported
|
64 |
|
|
|
122 |
cache_dir=cache_dir,
|
123 |
**image_processor_kwargs,
|
124 |
)
|
125 |
+
self.assume_text_inputs = assume_text_inputs
|
126 |
|
127 |
# No max_seq_length set. Try to infer from model
|
128 |
if max_seq_length is None:
|
|
|
155 |
_images = []
|
156 |
_texts = []
|
157 |
_image_or_text_descriptors = []
|
158 |
+
|
159 |
+
if self.assume_text_inputs:
|
160 |
+
for sample in texts:
|
161 |
+
if isinstance(sample, str):
|
162 |
+
_texts.append(sample)
|
163 |
+
_image_or_text_descriptors.append(1)
|
164 |
+
else:
|
165 |
+
for sample in texts:
|
166 |
+
if isinstance(sample, str):
|
167 |
+
if sample.startswith('http'):
|
168 |
+
try:
|
169 |
+
response = requests.get(sample)
|
170 |
+
_images.append(
|
171 |
+
Image.open(BytesIO(response.content)).convert('RGB')
|
172 |
+
)
|
173 |
+
_image_or_text_descriptors.append(0)
|
174 |
+
except Exception as e:
|
175 |
+
_ = str(e)
|
176 |
+
_texts.append(sample)
|
177 |
+
_image_or_text_descriptors.append(1)
|
178 |
+
elif sample.startswith('data:image/'):
|
179 |
+
_images.append(self._decode_data_image(sample).convert('RGB'))
|
180 |
_image_or_text_descriptors.append(0)
|
181 |
+
else:
|
182 |
+
try:
|
183 |
+
_images.append(Image.open(sample).convert('RGB'))
|
184 |
+
_image_or_text_descriptors.append(0)
|
185 |
+
except Exception as e:
|
186 |
+
_ = str(e)
|
187 |
+
_texts.append(sample)
|
188 |
+
_image_or_text_descriptors.append(1)
|
189 |
+
elif isinstance(sample, Image.Image):
|
190 |
+
_images.append(sample.convert('RGB'))
|
191 |
+
_image_or_text_descriptors.append(0)
|
192 |
|
193 |
encoding = {}
|
194 |
if len(_texts):
|