gmastrapas committed on
Commit
5695d43
1 Parent(s): 3103208

feat: add option assume_text_inputs in sentence transformers

Files changed (1):
  1. custom_st.py  +37 -19
custom_st.py CHANGED
@@ -22,6 +22,7 @@ class Transformer(nn.Module):
         model_args: Optional[Dict[str, Any]] = None,
         tokenizer_args: Optional[Dict[str, Any]] = None,
         image_processor_args: Optional[Dict[str, Any]] = None,
+        assume_text_inputs: bool = False,
         cache_dir: Optional[str] = None,
         backend: Literal['torch', 'onnx', 'openvino'] = 'torch',
         **_,
@@ -56,6 +57,8 @@ class Transformer(nn.Module):
             image_processor_args (Dict[str, Any], optional): Additional image processor
                 configuration parameters to be passed to the Hugging Face Transformers
                 image processor
+            assume_text_inputs (bool, optional): If set to `True`, all inputs are
+                treated as texts. Defaults to `False`
             cache_dir (str, optional): The Hugging Face Hub cache directory
             backend (str, optional): Computational backend, only 'torch' is supported

@@ -119,6 +122,7 @@ class Transformer(nn.Module):
                 cache_dir=cache_dir,
                 **image_processor_kwargs,
             )
+        self.assume_text_inputs = assume_text_inputs

         # No max_seq_length set. Try to infer from model
         if max_seq_length is None:
@@ -151,26 +155,40 @@ class Transformer(nn.Module):
         _images = []
         _texts = []
         _image_or_text_descriptors = []
-        for sample in texts:
-            if isinstance(sample, str):
-                if sample.startswith('http'):
-                    response = requests.get(sample)
-                    _images.append(Image.open(BytesIO(response.content)).convert('RGB'))
-                    _image_or_text_descriptors.append(0)
-                elif sample.startswith('data:image/'):
-                    _images.append(self._decode_data_image(sample).convert('RGB'))
-                    _image_or_text_descriptors.append(0)
-                else:
-                    try:
-                        _images.append(Image.open(sample).convert('RGB'))
+
+        if self.assume_text_inputs:
+            for sample in texts:
+                if isinstance(sample, str):
+                    _texts.append(sample)
+                    _image_or_text_descriptors.append(1)
+        else:
+            for sample in texts:
+                if isinstance(sample, str):
+                    if sample.startswith('http'):
+                        try:
+                            response = requests.get(sample)
+                            _images.append(
+                                Image.open(BytesIO(response.content)).convert('RGB')
+                            )
+                            _image_or_text_descriptors.append(0)
+                        except Exception as e:
+                            _ = str(e)
+                            _texts.append(sample)
+                            _image_or_text_descriptors.append(1)
+                    elif sample.startswith('data:image/'):
+                        _images.append(self._decode_data_image(sample).convert('RGB'))
                         _image_or_text_descriptors.append(0)
-                    except Exception as e:
-                        _ = str(e)
-                        _texts.append(sample)
-                        _image_or_text_descriptors.append(1)
-            elif isinstance(sample, Image.Image):
-                _images.append(sample.convert('RGB'))
-                _image_or_text_descriptors.append(0)
+                    else:
+                        try:
+                            _images.append(Image.open(sample).convert('RGB'))
+                            _image_or_text_descriptors.append(0)
+                        except Exception as e:
+                            _ = str(e)
+                            _texts.append(sample)
+                            _image_or_text_descriptors.append(1)
+            elif isinstance(sample, Image.Image):
+                _images.append(sample.convert('RGB'))
+                _image_or_text_descriptors.append(0)

         encoding = {}
         if len(_texts):
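
For reference, below is a minimal usage sketch of the new option; it is not part of the commit. It assumes the module is constructed directly from this repository's custom_st.py, that its first positional argument is a model name or path (shown as a placeholder), and that the snippet above belongs to the module's tokenize method. With assume_text_inputs=True, every string is routed to the text tokenizer, so an input that happens to start with 'http' is never fetched or decoded as an image.

# Hypothetical sketch: 'org/model-name' is a placeholder, and the positional
# model name/path argument is assumed from the usual custom-module signature.
from custom_st import Transformer

module = Transformer(
    'org/model-name',            # placeholder checkpoint, not from this commit
    assume_text_inputs=True,     # new option added by this commit
)

# All strings are treated as texts: no requests.get, Image.open, or data-URI
# decoding is attempted, even for inputs that look like URLs or file paths.
features = module.tokenize(['https://example.org/page', 'plain text query'])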