bwang0911 commited on
Commit
e03d74b
·
verified ·
1 Parent(s): 903aaac

feat: add sbert support (#25)

Browse files

- feat: add sbert support (32864adfa27340706f76018abfd0b6e3a424335b)

Files changed (3) hide show
  1. config_sentence_transformers.json +10 -0
  2. custom_st.py +197 -0
  3. modules.json +14 -0
config_sentence_transformers.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "__version__": {
3
+ "sentence_transformers": "3.1.0.dev0",
4
+ "transformers": "4.41.2",
5
+ "pytorch": "2.3.1+cu121"
6
+ },
7
+ "prompts": {},
8
+ "default_prompt_name": null,
9
+ "similarity_fn_name": "cosine"
10
+ }
custom_st.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
+ import os
4
+ from io import BytesIO
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+
7
+ import requests
8
+ import torch
9
+ from PIL import Image
10
+ from torch import nn
11
+ from transformers import AutoConfig, AutoImageProcessor, AutoModel, AutoTokenizer
12
+
13
+
14
+ class Transformer(nn.Module):
15
+ """Huggingface AutoModel to generate token embeddings.
16
+ Loads the correct class, e.g. BERT / RoBERTa etc.
17
+
18
+ Args:
19
+ model_name_or_path: Huggingface models name
20
+ (https://huggingface.co/models)
21
+ max_seq_length: Truncate any inputs longer than max_seq_length
22
+ model_args: Keyword arguments passed to the Huggingface
23
+ Transformers model
24
+ tokenizer_args: Keyword arguments passed to the Huggingface
25
+ Transformers tokenizer
26
+ config_args: Keyword arguments passed to the Huggingface
27
+ Transformers config
28
+ cache_dir: Cache dir for Huggingface Transformers to store/load
29
+ models
30
+ do_lower_case: If true, lowercases the input (independent if the
31
+ model is cased or not)
32
+ tokenizer_name_or_path: Name or path of the tokenizer. When
33
+ None, then model_name_or_path is used
34
+ """
35
+
36
+ def __init__(
37
+ self,
38
+ model_name_or_path: str,
39
+ max_seq_length: Optional[int] = None,
40
+ model_args: Optional[Dict[str, Any]] = None,
41
+ tokenizer_args: Optional[Dict[str, Any]] = None,
42
+ config_args: Optional[Dict[str, Any]] = None,
43
+ cache_dir: Optional[str] = None,
44
+ do_lower_case: bool = False,
45
+ tokenizer_name_or_path: str = None,
46
+ ) -> None:
47
+ super(Transformer, self).__init__()
48
+ self.config_keys = ["max_seq_length", "do_lower_case"]
49
+ self.do_lower_case = do_lower_case
50
+ if model_args is None:
51
+ model_args = {}
52
+ if tokenizer_args is None:
53
+ tokenizer_args = {}
54
+ if config_args is None:
55
+ config_args = {}
56
+
57
+ config = AutoConfig.from_pretrained(
58
+ model_name_or_path, **config_args, cache_dir=cache_dir
59
+ )
60
+ self.jina_clip = AutoModel.from_pretrained(
61
+ model_name_or_path, config=config, cache_dir=cache_dir, **model_args
62
+ )
63
+
64
+ if max_seq_length is not None and "model_max_length" not in tokenizer_args:
65
+ tokenizer_args["model_max_length"] = max_seq_length
66
+ self.tokenizer = AutoTokenizer.from_pretrained(
67
+ (
68
+ tokenizer_name_or_path
69
+ if tokenizer_name_or_path is not None
70
+ else model_name_or_path
71
+ ),
72
+ cache_dir=cache_dir,
73
+ **tokenizer_args,
74
+ )
75
+ self.preprocessor = AutoImageProcessor.from_pretrained(
76
+ (
77
+ tokenizer_name_or_path
78
+ if tokenizer_name_or_path is not None
79
+ else model_name_or_path
80
+ ),
81
+ cache_dir=cache_dir,
82
+ **tokenizer_args,
83
+ )
84
+
85
+ # No max_seq_length set. Try to infer from model
86
+ if max_seq_length is None:
87
+ if (
88
+ hasattr(self.jina_clip, "config")
89
+ and hasattr(self.jina_clip.config, "max_position_embeddings")
90
+ and hasattr(self.tokenizer, "model_max_length")
91
+ ):
92
+ max_seq_length = min(
93
+ self.jina_clip.config.max_position_embeddings,
94
+ self.tokenizer.model_max_length,
95
+ )
96
+
97
+ self.max_seq_length = max_seq_length
98
+
99
+ if tokenizer_name_or_path is not None:
100
+ self.jina_clip.config.tokenizer_class = self.tokenizer.__class__.__name__
101
+
102
+ def forward(
103
+ self, features: Dict[str, torch.Tensor], task_type: Optional[str] = None
104
+ ) -> Dict[str, torch.Tensor]:
105
+ """Returns token_embeddings, cls_token"""
106
+ print("task_type in the custom Transformer:", task_type)
107
+ if "input_ids" in features:
108
+ embedding = self.jina_clip.get_text_features(
109
+ input_ids=features["input_ids"]
110
+ )
111
+ else:
112
+ embedding = self.jina_clip.get_image_features(
113
+ pixel_values=features["pixel_values"]
114
+ )
115
+ return {"sentence_embedding": embedding}
116
+
117
+ def get_word_embedding_dimension(self) -> int:
118
+ return self.config.text_config.embed_dim
119
+
120
+ def decode_data_image(data_image_str):
121
+ header, data = data_image_str.split(',', 1)
122
+ image_data = base64.b64decode(data)
123
+ return Image.open(BytesIO(image_data))
124
+
125
+ def tokenize(
126
+ self, batch: Union[List[str]], padding: Union[str, bool] = True
127
+ ) -> Dict[str, torch.Tensor]:
128
+ """Tokenizes a text and maps tokens to token-ids"""
129
+ images = []
130
+ texts = []
131
+ for sample in batch:
132
+ if isinstance(sample, str):
133
+ if sample.startswith('http'):
134
+ response = requests.get(sample)
135
+ images.append(Image.open(BytesIO(response.content)).convert('RGB'))
136
+ elif sample.startswith('data:image/'):
137
+ images.append(self.decode_data_image(sample).convert('RGB'))
138
+ else:
139
+ # TODO: Make sure that Image.open fails for non-image files
140
+ try:
141
+ images.append(Image.open(sample).convert('RGB'))
142
+ except:
143
+ texts.append(sample)
144
+ elif isinstance(sample, Image.Image):
145
+ images.append(sample.convert('RGB'))
146
+
147
+ if images and texts:
148
+ raise ValueError('Batch must contain either images or texts, not both')
149
+
150
+ if texts:
151
+ return self.tokenizer(
152
+ texts,
153
+ padding=padding,
154
+ truncation="longest_first",
155
+ return_tensors="pt",
156
+ max_length=self.max_seq_length,
157
+ )
158
+ elif images:
159
+ return self.preprocessor(images)
160
+ return {}
161
+
162
+ def save(self, output_path: str, safe_serialization: bool = True) -> None:
163
+ self.jina_clip.save_pretrained(
164
+ output_path, safe_serialization=safe_serialization
165
+ )
166
+ self.tokenizer.save_pretrained(output_path)
167
+ self.preprocessor.save_pretrained(output_path)
168
+
169
+ @staticmethod
170
+ def load(input_path: str) -> "Transformer":
171
+ # Old classes used other config names than 'sentence_bert_config.json'
172
+ for config_name in [
173
+ "sentence_bert_config.json",
174
+ "sentence_roberta_config.json",
175
+ "sentence_distilbert_config.json",
176
+ "sentence_camembert_config.json",
177
+ "sentence_albert_config.json",
178
+ "sentence_xlm-roberta_config.json",
179
+ "sentence_xlnet_config.json",
180
+ ]:
181
+ sbert_config_path = os.path.join(input_path, config_name)
182
+ if os.path.exists(sbert_config_path):
183
+ break
184
+
185
+ with open(sbert_config_path) as fIn:
186
+ config = json.load(fIn)
187
+ # Don't allow configs to set trust_remote_code
188
+ if "model_args" in config and "trust_remote_code" in config["model_args"]:
189
+ config["model_args"].pop("trust_remote_code")
190
+ if (
191
+ "tokenizer_args" in config
192
+ and "trust_remote_code" in config["tokenizer_args"]
193
+ ):
194
+ config["tokenizer_args"].pop("trust_remote_code")
195
+ if "config_args" in config and "trust_remote_code" in config["config_args"]:
196
+ config["config_args"].pop("trust_remote_code")
197
+ return Transformer(model_name_or_path=input_path, **config)
modules.json ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "idx":0,
4
+ "name":"0",
5
+ "path":"",
6
+ "type":"custom_st.Transformer"
7
+ },
8
+ {
9
+ "idx":2,
10
+ "name":"2",
11
+ "path":"2_Normalize",
12
+ "type":"sentence_transformers.models.Normalize"
13
+ }
14
+ ]