gmastrapas committed
Commit d220929
Parent: d956937

fix: bug in custom_st.py

Files changed (2)
  1. config_sentence_transformers.json +3 -3
  2. custom_st.py +43 -43
config_sentence_transformers.json CHANGED
@@ -1,8 +1,8 @@
 {
   "__version__": {
-    "sentence_transformers": "3.1.0",
-    "transformers": "4.41.2",
-    "pytorch": "2.3.1+cu121"
+    "sentence_transformers": "3.3.0",
+    "transformers": "4.46.2",
+    "pytorch": "2.2.2"
   },
   "prompts": {},
   "default_prompt_name": null,
custom_st.py CHANGED
@@ -34,8 +34,8 @@ class Transformer(nn.Module):
         self.model = AutoModel.from_pretrained(
             model_name_or_path, config=config, **model_kwargs
         )
-        if max_seq_length is not None and "model_max_length" not in tokenizer_kwargs:
-            tokenizer_kwargs["model_max_length"] = max_seq_length
+        if max_seq_length is not None and 'model_max_length' not in tokenizer_kwargs:
+            tokenizer_kwargs['model_max_length'] = max_seq_length
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name_or_path or model_name_or_path,
@@ -49,9 +49,9 @@ class Transformer(nn.Module):
         # No max_seq_length set. Try to infer from model
         if max_seq_length is None:
             if (
-                hasattr(self.model, "config")
-                and hasattr(self.model.config, "max_position_embeddings")
-                and hasattr(self.tokenizer, "model_max_length")
+                hasattr(self.model, 'config')
+                and hasattr(self.model.config, 'max_position_embeddings')
+                and hasattr(self.tokenizer, 'model_max_length')
             ):
                 max_seq_length = min(
                     self.model.config.max_position_embeddings,
@@ -63,7 +63,7 @@ class Transformer(nn.Module):
 
     @staticmethod
     def _decode_data_image(data_image_str: str) -> Image.Image:
-        header, data = data_image_str.split(",", 1)
+        header, data = data_image_str.split(',', 1)
         image_data = base64.b64decode(data)
         return Image.open(BytesIO(image_data))
 
@@ -79,62 +79,62 @@ class Transformer(nn.Module):
         _image_or_text_descriptors = []
         for sample in texts:
             if isinstance(sample, str):
-                if sample.startswith("http"):
+                if sample.startswith('http'):
                     response = requests.get(sample)
-                    _images.append(Image.open(BytesIO(response.content)).convert("RGB"))
+                    _images.append(Image.open(BytesIO(response.content)).convert('RGB'))
                     _image_or_text_descriptors.append(0)
-                elif sample.startswith("data:image/"):
-                    _images.append(self._decode_data_image(sample).convert("RGB"))
+                elif sample.startswith('data:image/'):
+                    _images.append(self._decode_data_image(sample).convert('RGB'))
                     _image_or_text_descriptors.append(0)
                 else:
                     try:
-                        _images.append(Image.open(sample).convert("RGB"))
+                        _images.append(Image.open(sample).convert('RGB'))
                         _image_or_text_descriptors.append(0)
                     except Exception as e:
                         _ = str(e)
                         _texts.append(sample)
                         _image_or_text_descriptors.append(1)
             elif isinstance(sample, Image.Image):
-                _images.append(sample.convert("RGB"))
+                _images.append(sample.convert('RGB'))
                 _image_or_text_descriptors.append(0)
 
         encoding = {}
         if len(_texts):
-            encoding["input_ids"] = self.tokenizer(
-                texts,
+            encoding['input_ids'] = self.tokenizer(
+                _texts,
                 padding=padding,
-                truncation="longest_first",
-                return_tensors="pt",
+                truncation='longest_first',
+                return_tensors='pt',
                 max_length=self.max_seq_length,
             ).input_ids
 
         if len(_images):
-            encoding["pixel_values"] = self.image_processor(
-                _images, return_tensors="pt"
+            encoding['pixel_values'] = self.image_processor(
+                _images, return_tensors='pt'
             ).pixel_values
 
-        encoding["image_text_info"] = _image_or_text_descriptors
+        encoding['image_text_info'] = _image_or_text_descriptors
         return encoding
 
     def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
         image_embeddings = []
         text_embeddings = []
 
-        if "pixel_values" in features:
-            image_embeddings = self.model.get_image_features(features["pixel_values"])
-        if "input_ids" in features:
-            text_embeddings = self.model.get_text_features(features["input_ids"])
+        if 'pixel_values' in features:
+            image_embeddings = self.model.get_image_features(features['pixel_values'])
+        if 'input_ids' in features:
+            text_embeddings = self.model.get_text_features(features['input_ids'])
 
         sentence_embedding = []
         image_features = iter(image_embeddings)
         text_features = iter(text_embeddings)
-        for _, _input_type in enumerate(features["image_text_info"]):
+        for _, _input_type in enumerate(features['image_text_info']):
             if _input_type == 0:
                 sentence_embedding.append(next(image_features))
             else:
                 sentence_embedding.append(next(text_features))
 
-        features["sentence_embedding"] = torch.stack(sentence_embedding).float()
+        features['sentence_embedding'] = torch.stack(sentence_embedding).float()
         return features
 
     def save(self, output_path: str, safe_serialization: bool = True) -> None:
@@ -143,16 +143,16 @@ class Transformer(nn.Module):
         self.image_processor.save_pretrained(output_path)
 
     @staticmethod
-    def load(input_path: str) -> "Transformer":
+    def load(input_path: str) -> 'Transformer':
         # Old classes used other config names than 'sentence_bert_config.json'
         for config_name in [
-            "sentence_bert_config.json",
-            "sentence_roberta_config.json",
-            "sentence_distilbert_config.json",
-            "sentence_camembert_config.json",
-            "sentence_albert_config.json",
-            "sentence_xlm-roberta_config.json",
-            "sentence_xlnet_config.json",
+            'sentence_bert_config.json',
+            'sentence_roberta_config.json',
+            'sentence_distilbert_config.json',
+            'sentence_camembert_config.json',
+            'sentence_albert_config.json',
+            'sentence_xlm-roberta_config.json',
+            'sentence_xlnet_config.json',
         ]:
             sbert_config_path = os.path.join(input_path, config_name)
             if os.path.exists(sbert_config_path):
@@ -162,19 +162,19 @@ class Transformer(nn.Module):
             config = json.load(fIn)
 
         # Don't allow configs to set trust_remote_code
-        if "config_kwargs" in config and "trust_remote_code" in config["config_kwargs"]:
-            config["config_kwargs"].pop("trust_remote_code")
-        if "model_kwargs" in config and "trust_remote_code" in config["model_kwargs"]:
-            config["model_kwargs"].pop("trust_remote_code")
+        if 'config_kwargs' in config and 'trust_remote_code' in config['config_kwargs']:
+            config['config_kwargs'].pop('trust_remote_code')
+        if 'model_kwargs' in config and 'trust_remote_code' in config['model_kwargs']:
+            config['model_kwargs'].pop('trust_remote_code')
         if (
-            "tokenizer_kwargs" in config
-            and "trust_remote_code" in config["tokenizer_kwargs"]
+            'tokenizer_kwargs' in config
+            and 'trust_remote_code' in config['tokenizer_kwargs']
         ):
-            config["tokenizer_kwargs"].pop("trust_remote_code")
+            config['tokenizer_kwargs'].pop('trust_remote_code')
         if (
-            "image_processor_kwargs" in config
-            and "trust_remote_code" in config["image_processor_kwargs"]
+            'image_processor_kwargs' in config
+            and 'trust_remote_code' in config['image_processor_kwargs']
         ):
-            config["image_processor_kwargs"].pop("trust_remote_code")
+            config['image_processor_kwargs'].pop('trust_remote_code')
 
         return Transformer(model_name_or_path=input_path, **config)
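
The functional fix in this commit is in tokenize: the tokenizer is now called on _texts, the text-only subset collected in the loop above, rather than on the full mixed texts argument, so batches that combine image URLs or paths with plain strings no longer push the image entries through the tokenizer. The remaining changes are quote-style normalization and the version bump in config_sentence_transformers.json. A minimal usage sketch follows, assuming the repository is loaded as a Sentence Transformers model with remote code enabled; the repo id is a placeholder.

from sentence_transformers import SentenceTransformer

# Placeholder repo id; substitute the model repository that ships this custom_st.py.
model = SentenceTransformer("my-org/my-clip-style-model", trust_remote_code=True)

# A mixed batch: one image URL and one plain text query. Before the fix, the
# whole list (URL included) was passed to the tokenizer; after it, only the
# text entries collected in _texts are tokenized and the URL goes to the
# image branch.
inputs = [
    "https://example.com/cat.jpg",
    "a photo of a cat",
]
embeddings = model.encode(inputs)
print(embeddings.shape)  # one embedding per input, images and texts in the same space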