LinWeizheDragon commited on
Commit
ab8d3d5
·
verified ·
1 Parent(s): a7edceb

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "LinWeizheDragon/PreFLMR_ViT-L",
3
+ "architectures": [
4
+ "FLMRModelForRetrieval"
5
+ ],
6
+ "auto_map": {
7
+ "AutoConfig": "configuration_flmr.FLMRConfig",
8
+ "AutoModel": "modeling_flmr.FLMRModelForRetrieval"
9
+ },
10
+ "context_concat_output_from_text_encoder": true,
11
+ "context_concat_output_from_vision_encoder": false,
12
+ "dim": 128,
13
+ "initializer_range": 0.02,
14
+ "load_cpu_extension": false,
15
+ "mapping_network_prefix_length": 32,
16
+ "mask_instruction_token": ":",
17
+ "mask_punctuation": true,
18
+ "model_type": "flmr",
19
+ "query_concat_output_from_text_encoder": true,
20
+ "query_concat_output_from_vision_encoder": true,
21
+ "separate_query_and_context_text_encoder": true,
22
+ "separate_query_and_context_vision_encoder": false,
23
+ "text_config": {
24
+ "architectures": [
25
+ "BertForMaskedLM"
26
+ ],
27
+ "gradient_checkpointing": false,
28
+ "model_type": "flmr_text_model",
29
+ "use_cache": true
30
+ },
31
+ "torch_dtype": "float32",
32
+ "transformer_mapping_config_base": "bert-base-uncased",
33
+ "transformer_mapping_cross_attention_length": 32,
34
+ "transformer_mapping_num_hidden_layers": 1,
35
+ "transformers_version": "4.37.2",
36
+ "use_transformer_mapping_network": true,
37
+ "use_vision_encoder": true,
38
+ "vision_config": {
39
+ "dropout": 0.0,
40
+ "hidden_size": 1024,
41
+ "intermediate_size": 4096,
42
+ "model_type": "flmr_vision_model",
43
+ "num_attention_heads": 16,
44
+ "num_hidden_layers": 24,
45
+ "patch_size": 14,
46
+ "projection_dim": 768
47
+ },
48
+ "vision_model_version": "openai/clip-vit-large-patch14"
49
+ }
configuration_flmr.py ADDED
@@ -0,0 +1,385 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2010, FLMR authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ FLMR model configuration"""
16
+
17
+ import os
18
+ from typing import Union
19
+
20
+ from transformers.configuration_utils import PretrainedConfig
21
+ from transformers.utils import logging
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ FLMR_PRETRAINED_CONFIG_ARCHIVE_MAP = {
27
+ "LinWeizheDragon/PreFLMR_ViT-L": "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/config.json",
28
+ "LinWeizheDragon/FLMR": "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/config.json",
29
+ }
30
+
31
+
32
+ # Modified from transformers.models.clip.configuration_clip.CLIPVisionConfig with CLIP -> FLMR
33
+ class FLMRVisionConfig(PretrainedConfig):
34
+ r"""
35
+ This is the configuration class to store the configuration of a [`FLMRVisionModel`]. It is used to instantiate a
36
+ FLMR vision encoder according to the specified arguments, defining the model architecture. Instantiating a
37
+ configuration with the defaults will yield a similar configuration to that of the vision encoder of the FLMR
38
+ [openai/flmr-vit-base-patch32](https://huggingface.co/openai/flmr-vit-base-patch32) architecture.
39
+
40
+ Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
41
+ documentation from [`PretrainedConfig`] for more information.
42
+
43
+ Args:
44
+ hidden_size (`int`, *optional*, defaults to 768):
45
+ Dimensionality of the encoder layers and the pooler layer.
46
+ intermediate_size (`int`, *optional*, defaults to 3072):
47
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
48
+ projection_dim (`int`, *optional*, defaults to 512):
49
+ Dimentionality of text and vision projection layers.
50
+ num_hidden_layers (`int`, *optional*, defaults to 12):
51
+ Number of hidden layers in the Transformer encoder.
52
+ num_attention_heads (`int`, *optional*, defaults to 12):
53
+ Number of attention heads for each attention layer in the Transformer encoder.
54
+ num_channels (`int`, *optional*, defaults to 3):
55
+ The number of input channels.
56
+ image_size (`int`, *optional*, defaults to 224):
57
+ The size (resolution) of each image.
58
+ patch_size (`int`, *optional*, defaults to 32):
59
+ The size (resolution) of each patch.
60
+ hidden_act (`str` or `function`, *optional*, defaults to `"quick_gelu"`):
61
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
62
+ `"relu"`, `"selu"` and `"gelu_new"` ``"quick_gelu"` are supported.
63
+ layer_norm_eps (`float`, *optional*, defaults to 1e-05):
64
+ The epsilon used by the layer normalization layers.
65
+ attention_dropout (`float`, *optional*, defaults to 0.0):
66
+ The dropout ratio for the attention probabilities.
67
+ initializer_range (`float`, *optional*, defaults to 0.02):
68
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
69
+ initializer_factor (`float`, *optional*, defaults to 1.0):
70
+ A factor for initializing all weight matrices (should be kept to 1, used internally for initialization
71
+ testing).
72
+
73
+ Example:
74
+
75
+ ```python
76
+ >>> from transformers import FLMRVisionConfig, FLMRVisionModel
77
+
78
+ >>> # Initializing a FLMRVisionConfig with LinWeizheDragon/FLMR style configuration
79
+ >>> configuration = FLMRVisionConfig()
80
+
81
+ >>> # Initializing a FLMRVisionModel (with random weights) from the LinWeizheDragon/FLMR style configuration
82
+ >>> model = FLMRVisionModel(configuration)
83
+
84
+ >>> # Accessing the model configuration
85
+ >>> configuration = model.config
86
+ ```"""
87
+
88
+ model_type = "flmr_vision_model"
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size=768,
93
+ intermediate_size=3072,
94
+ projection_dim=512,
95
+ num_hidden_layers=12,
96
+ num_attention_heads=12,
97
+ num_channels=3,
98
+ image_size=224,
99
+ patch_size=32,
100
+ hidden_act="quick_gelu",
101
+ layer_norm_eps=1e-5,
102
+ attention_dropout=0.0,
103
+ initializer_range=0.02,
104
+ initializer_factor=1.0,
105
+ **kwargs,
106
+ ):
107
+ super().__init__(**kwargs)
108
+
109
+ self.hidden_size = hidden_size
110
+ self.intermediate_size = intermediate_size
111
+ self.projection_dim = projection_dim
112
+ self.num_hidden_layers = num_hidden_layers
113
+ self.num_attention_heads = num_attention_heads
114
+ self.num_channels = num_channels
115
+ self.patch_size = patch_size
116
+ self.image_size = image_size
117
+ self.initializer_range = initializer_range
118
+ self.initializer_factor = initializer_factor
119
+ self.attention_dropout = attention_dropout
120
+ self.layer_norm_eps = layer_norm_eps
121
+ self.hidden_act = hidden_act
122
+
123
+ @classmethod
124
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
125
+ cls._set_token_in_kwargs(kwargs)
126
+
127
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
128
+
129
+ # get the vision config dict if we are loading from a CLIPConfig
130
+ if config_dict.get("model_type") == "clip":
131
+ config_dict = config_dict["vision_config"]
132
+
133
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
134
+ logger.warning(
135
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
136
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
137
+ )
138
+
139
+ return cls.from_dict(config_dict, **kwargs)
140
+
141
+
142
+ # Modified from transformers.models.dpr.configuration_dpr.DPRConfig with DPR -> FLMR
143
+ class FLMRTextConfig(PretrainedConfig):
144
+ r"""
145
+ [`FLMRTextConfig`] is the configuration class to store the configuration of a *FLMRTextModel*.
146
+
147
+ This is the configuration class to store the configuration of a [`FLMRTextModel`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
148
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
149
+ configuration to that of the DPRContextEncoder
150
+ [facebook/dpr-ctx_encoder-single-nq-base](https://huggingface.co/facebook/dpr-ctx_encoder-single-nq-base)
151
+ architecture.
152
+
153
+ This class is a subclass of [`BertConfig`]. Please check the superclass for the documentation of all kwargs.
154
+
155
+ Args:
156
+ vocab_size (`int`, *optional*, defaults to 30522):
157
+ Vocabulary size of the FLMR model. Defines the different tokens that can be represented by the *inputs_ids*
158
+ passed to the forward method of [`BertModel`].
159
+ hidden_size (`int`, *optional*, defaults to 768):
160
+ Dimensionality of the encoder layers and the pooler layer.
161
+ num_hidden_layers (`int`, *optional*, defaults to 12):
162
+ Number of hidden layers in the Transformer encoder.
163
+ num_attention_heads (`int`, *optional*, defaults to 12):
164
+ Number of attention heads for each attention layer in the Transformer encoder.
165
+ intermediate_size (`int`, *optional*, defaults to 3072):
166
+ Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
167
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
168
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
169
+ `"relu"`, `"silu"` and `"gelu_new"` are supported.
170
+ hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
171
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
172
+ attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
173
+ The dropout ratio for the attention probabilities.
174
+ max_position_embeddings (`int`, *optional*, defaults to 512):
175
+ The maximum sequence length that this model might ever be used with. Typically set this to something large
176
+ just in case (e.g., 512 or 1024 or 2048).
177
+ type_vocab_size (`int`, *optional*, defaults to 2):
178
+ The vocabulary size of the *token_type_ids* passed into [`BertModel`].
179
+ initializer_range (`float`, *optional*, defaults to 0.02):
180
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
181
+ layer_norm_eps (`float`, *optional*, defaults to 1e-12):
182
+ The epsilon used by the layer normalization layers.
183
+ pad_token_id (`int`, *optional*, defaults to 0):
184
+ Padding token id.
185
+ position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
186
+ Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
187
+ positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
188
+ [Self-Attention with Relative Position Representations (Shaw et al.)](https://arxiv.org/abs/1803.02155).
189
+ For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
190
+ with Better Relative Position Embeddings (Huang et al.)](https://arxiv.org/abs/2009.13658).
191
+ projection_dim (`int`, *optional*, defaults to 0):
192
+ Dimension of the projection for the context and question encoders. If it is set to zero (default), then no
193
+ projection is done.
194
+
195
+ Example:
196
+
197
+ ```python
198
+ >>> from transformers import FLMRTextConfig, FLMRTextModel
199
+
200
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
201
+ >>> configuration = FLMRTextConfig()
202
+
203
+ >>> # Initializing a model (with random weights) from the LinWeizheDragon/FLMR style configuration
204
+ >>> model = FLMRTextModel(configuration)
205
+
206
+ >>> # Accessing the model configuration
207
+ >>> configuration = model.config
208
+ ```"""
209
+
210
+ model_type = "flmr_text_model"
211
+
212
+ def __init__(
213
+ self,
214
+ vocab_size=30522,
215
+ hidden_size=768,
216
+ num_hidden_layers=12,
217
+ num_attention_heads=12,
218
+ intermediate_size=3072,
219
+ hidden_act="gelu",
220
+ hidden_dropout_prob=0.1,
221
+ attention_probs_dropout_prob=0.1,
222
+ max_position_embeddings=512,
223
+ type_vocab_size=2,
224
+ initializer_range=0.02,
225
+ layer_norm_eps=1e-12,
226
+ pad_token_id=0,
227
+ position_embedding_type="absolute",
228
+ projection_dim: int = 0,
229
+ **kwargs,
230
+ ):
231
+ super().__init__(pad_token_id=pad_token_id, **kwargs)
232
+
233
+ self.vocab_size = vocab_size
234
+ self.hidden_size = hidden_size
235
+ self.num_hidden_layers = num_hidden_layers
236
+ self.num_attention_heads = num_attention_heads
237
+ self.hidden_act = hidden_act
238
+ self.intermediate_size = intermediate_size
239
+ self.hidden_dropout_prob = hidden_dropout_prob
240
+ self.attention_probs_dropout_prob = attention_probs_dropout_prob
241
+ self.max_position_embeddings = max_position_embeddings
242
+ self.type_vocab_size = type_vocab_size
243
+ self.initializer_range = initializer_range
244
+ self.layer_norm_eps = layer_norm_eps
245
+ self.projection_dim = projection_dim
246
+ self.position_embedding_type = position_embedding_type
247
+
248
+
249
+ class FLMRConfig(PretrainedConfig):
250
+ r"""
251
+ [`FLMRConfig`] is the configuration class to store the configuration of a *FLMRModelForRetrieval*.
252
+ This is the configuration class to store the configuration of a [`FLMRModelForRetrieval`]. It is used to instantiate the components of the FLMR model according to the specified arguments,
253
+ defining the model component architectures. Instantiating a configuration with the defaults will yield a similar
254
+ configuration to that of the FLMR
255
+ [LinWeizheDragon/PreFLMR_ViT-G](https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-G)
256
+ architecture.
257
+
258
+ Args:
259
+ vision_config (`FLMRVisionConfig`, *optional*):
260
+ Configuration for the vision encoder.
261
+ text_config (`FLMRTextConfig`, *optional*):
262
+ Configuration for the text encoder.
263
+ mask_punctuation (`bool`, *optional*, defaults to `True`):
264
+ Whether to mask punctuation tokens in the input.
265
+ mapping_network_prefix_length (`int`, *optional*, defaults to 32):
266
+ The output length of the linear mapping network.
267
+ dim (`int`, *optional*, defaults to 128):
268
+ The late-interaction dimension of the model. The output of the text encoder, vision encoder, transformer mapping network should all be projected to this dimension for late-interaction scoring.
269
+ use_vision_encoder (`bool`, *optional*, defaults to `True`):
270
+ Whether to load the vision encoder. When no vision encoder is loaded, `image_features` should be used in the forward pass rather than `pixel_values`.
271
+ initializer_range (`float`, *optional*, defaults to 0.02):
272
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
273
+ separate_query_and_context_text_encoder (`bool`, *optional*, defaults to `False`):
274
+ Whether to use separate text encoders for query and context.
275
+ separate_query_and_context_vision_encoder (`bool`, *optional*, defaults to `False`):
276
+ Whether to use separate vision encoders for query and context.
277
+ query_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `True`):
278
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the query.
279
+ query_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
280
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the query.
281
+ context_concat_output_from_vision_encoder (`bool`, *optional*, defaults to `False`):
282
+ Whether to concatenate the output from the vision encoder to the output from the text encoder for the context.
283
+ context_concat_output_from_text_encoder (`bool`, *optional*, defaults to `True`):
284
+ Whether to concatenate the output from the text encoder to the output from the vision encoder for the context.
285
+ use_transformer_mapping_network (`bool`, *optional*, defaults to `False`):
286
+ Whether to add a transformer mapping network to map the features from the vision encoder to the embedding space. This option is used in PreFLMR.
287
+ transformer_mapping_config_base (`str`, *optional*):
288
+ The base configuration for the transformer mapping network. This option is used in PreFLMR. An example of this argument is `bert-base-uncased`.
289
+ transformer_mapping_num_hidden_layers (`int`, *optional*):
290
+ The number of hidden layers in the transformer mapping network. This option is used in PreFLMR.
291
+ load_cpu_extension (`bool`, *optional*, defaults to `False`):
292
+ Whether to load the CPU extension. Only set this to `True` if a CPU is used in training and inference. In any case, GPU is recommended for training and inference.
293
+ mask_instruction_token (`str`, *optional*):
294
+ The token that indicates the end of the input instruction. All tokens before this token (the first one in a sequence) will be masked. This option is used in PreFLMR.
295
+ transformer_mapping_cross_attention_length (`int`, *optional*, defaults to 32):
296
+ The length of the cross attention in the transformer mapping network. This option is used in PreFLMR.
297
+ vision_model_version (`str`, *optional*, defaults to `"openai/clip-vit-base-patch32"`):
298
+ The version of the vision model being used in this FLMR model.
299
+ This option is used in performing retrieval only. Though it does not affect the model architecture, it is highly recommended to set this argument so that it properly reflects the version of the vision model being used in the FLMR model. This arugment will be saved in the model configuration, and it can be read by the indexing engine. The indexing engine will use this argument to initialize an image processor, which can process the input image files. Find more details under `examples/research_projects/flmr-retrieval`.
300
+
301
+ Example:
302
+
303
+ ```python
304
+ >>> from transformers import FLMRConfig, FLMRModelForRetrieval
305
+
306
+ >>> # Initializing a FLMR LinWeizheDragon/FLMR style configuration
307
+ >>> configuration = FLMRConfig()
308
+
309
+ >>> # Initializing a model (with random weights) from the FLMR style configuration
310
+ >>> model = FLMRModelForRetrieval(configuration)
311
+
312
+ >>> # Accessing the model configuration
313
+ >>> configuration = model.config
314
+ ```"""
315
+
316
+ model_type = "flmr"
317
+
318
+ def __init__(
319
+ self,
320
+ vision_config: FLMRVisionConfig = None,
321
+ text_config: FLMRTextConfig = None,
322
+ mask_punctuation: bool = True,
323
+ mapping_network_prefix_length: int = 32,
324
+ dim: int = 128,
325
+ use_vision_encoder: bool = True,
326
+ initializer_range: float = 0.02,
327
+ separate_query_and_context_text_encoder: bool = False,
328
+ separate_query_and_context_vision_encoder: bool = False,
329
+ query_concat_output_from_vision_encoder: bool = True,
330
+ query_concat_output_from_text_encoder: bool = True,
331
+ context_concat_output_from_vision_encoder: bool = False,
332
+ context_concat_output_from_text_encoder: bool = True,
333
+ use_transformer_mapping_network: bool = False,
334
+ transformer_mapping_config_base: str = None,
335
+ transformer_mapping_num_hidden_layers: int = None,
336
+ load_cpu_extension: bool = False,
337
+ mask_instruction_token: str = None,
338
+ transformer_mapping_cross_attention_length: int = 32,
339
+ vision_model_version: str = "openai/clip-vit-base-patch32",
340
+ **kwargs,
341
+ ):
342
+ super().__init__(**kwargs)
343
+
344
+ if vision_config is None:
345
+ vision_config = {}
346
+ if text_config is None:
347
+ text_config = {}
348
+
349
+ if not isinstance(vision_config, FLMRVisionConfig):
350
+ vision_config = FLMRVisionConfig(**vision_config)
351
+ if not isinstance(text_config, FLMRTextConfig):
352
+ text_config = FLMRTextConfig(**text_config)
353
+
354
+ self.vision_config = vision_config
355
+ self.text_config = text_config
356
+ self.dim = dim
357
+ self.initializer_range = initializer_range
358
+ self.mask_punctuation = mask_punctuation
359
+ self.mapping_network_prefix_length = mapping_network_prefix_length
360
+ self.use_vision_encoder = use_vision_encoder
361
+ self.separate_query_and_context_text_encoder = separate_query_and_context_text_encoder
362
+ self.separate_query_and_context_vision_encoder = separate_query_and_context_vision_encoder
363
+ self.query_concat_output_from_vision_encoder = query_concat_output_from_vision_encoder
364
+ self.query_concat_output_from_text_encoder = query_concat_output_from_text_encoder
365
+ self.context_concat_output_from_vision_encoder = context_concat_output_from_vision_encoder
366
+ self.context_concat_output_from_text_encoder = context_concat_output_from_text_encoder
367
+ self.use_transformer_mapping_network = use_transformer_mapping_network
368
+ self.transformer_mapping_config_base = transformer_mapping_config_base
369
+ self.transformer_mapping_num_hidden_layers = transformer_mapping_num_hidden_layers
370
+ self.load_cpu_extension = load_cpu_extension
371
+ self.mask_instruction_token = mask_instruction_token
372
+ self.transformer_mapping_cross_attention_length = transformer_mapping_cross_attention_length
373
+ self.vision_model_version = vision_model_version
374
+
375
+ @classmethod
376
+ def from_text_vision_configs(cls, text_config: FLMRTextConfig, vision_config: FLMRVisionConfig, **kwargs):
377
+ r"""
378
+ Instantiate a [`FLMRConfig`] (or a derived class) from FLMR text model configuration and FLMR vision model
379
+ configuration.
380
+
381
+ Returns:
382
+ [`FLMRConfig`]: An instance of a configuration object
383
+ """
384
+
385
+ return cls(text_config=text_config, vision_config=vision_config, **kwargs)
context_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
context_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
context_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "auto_map": {
45
+ "AutoTokenizer": [
46
+ "tokenization_flmr.FLMRContextEncoderTokenizer",
47
+ null
48
+ ]
49
+ },
50
+ "clean_up_tokenization_spaces": true,
51
+ "cls_token": "[CLS]",
52
+ "do_basic_tokenize": true,
53
+ "do_lower_case": true,
54
+ "doc_maxlen": 512,
55
+ "mask_token": "[MASK]",
56
+ "model_max_length": 512,
57
+ "never_split": null,
58
+ "pad_token": "[PAD]",
59
+ "sep_token": "[SEP]",
60
+ "strip_accents": null,
61
+ "tokenize_chinese_chars": true,
62
+ "tokenizer_class": "FLMRContextEncoderTokenizer",
63
+ "unk_token": "[UNK]"
64
+ }
context_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
flmr_utils.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This file contains utility functions for the FLMR model. Some of these functions are adapted from the original ColBERT codebase.
3
+ """
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+
8
+
9
+ def get_rank():
10
+ return dist.get_rank()
11
+
12
+
13
+ def get_world_size():
14
+ return dist.get_world_size()
15
+
16
+
17
+ def get_default_group():
18
+ return dist.group.WORLD
19
+
20
+
21
+ # TODO: The masking below might also be applicable in the kNN part
22
+ def colbert_score_reduce(scores_padded, D_mask):
23
+ # print('D_mask', D_mask.shape, D_mask)
24
+ D_padding = ~D_mask.view(scores_padded.size(0), scores_padded.size(1)).bool()
25
+ # print('D_padding', D_padding.shape, D_padding)
26
+ # print(D_padding[0].tolist())
27
+ scores_padded[D_padding] = -9999
28
+ scores = scores_padded.max(1).values
29
+
30
+ return scores.sum(-1)
31
+
32
+
33
+ def colbert_score(Q, D_padded, D_mask, use_gpu=False):
34
+ """
35
+ Supply sizes Q = (1 | num_docs, *, dim) and D = (num_docs, *, dim).
36
+ If Q.size(0) is 1, the matrix will be compared with all passages.
37
+ Otherwise, each query matrix will be compared against the *aligned* passage.
38
+
39
+ EVENTUALLY: Consider masking with -inf for the maxsim (or enforcing a ReLU).
40
+ """
41
+ if use_gpu:
42
+ Q, D_padded, D_mask = Q.cuda(), D_padded.cuda(), D_mask.cuda()
43
+ assert Q.dim() == 3, Q.size()
44
+ assert D_padded.dim() == 3, D_padded.size()
45
+ assert Q.size(0) in [1, D_padded.size(0)]
46
+
47
+ scores = D_padded @ Q.to(dtype=D_padded.dtype).permute(0, 2, 1)
48
+
49
+ return colbert_score_reduce(scores, D_mask)
50
+
51
+
52
+ def _sort_by_length(ids, mask, bsize, *args):
53
+ if ids.size(0) <= bsize:
54
+ return ids, mask, torch.arange(ids.size(0))
55
+
56
+ indices = mask.sum(-1).sort().indices
57
+ reverse_indices = indices.sort().indices
58
+
59
+ return_array = [ids[indices], mask[indices]]
60
+ for arg in args:
61
+ if isinstance(arg, torch.Tensor):
62
+ return_array.append(arg[indices])
63
+ else:
64
+ # arg is a list, and we want to sort the list according to indices
65
+ return_array.append([arg[i] for i in indices])
66
+
67
+ return *return_array, reverse_indices
68
+
69
+
70
+ def _split_into_batches(ids, mask, bsize, *args):
71
+ batches = []
72
+ for offset in range(0, ids.size(0), bsize):
73
+ batch = [ids[offset : offset + bsize], mask[offset : offset + bsize]]
74
+ for arg in args:
75
+ batch.append(arg[offset : offset + bsize])
76
+ batches.append(batch)
77
+ return batches
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dfa4f00731b541a25072b2ea0a9a5388c3744f8389f8f2f2092ab64df66a40a2
3
+ size 2172804160
modeling_flmr.py ADDED
@@ -0,0 +1,1499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 FLMR Authors, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ PyTorch FLMR model for Knowledge-intensive Visual Question Answering."""
16
+
17
+
18
+ import copy
19
+ import os
20
+ import pathlib
21
+ import string
22
+ from dataclasses import dataclass
23
+ from typing import Optional, Tuple, Union
24
+
25
+ import torch
26
+ import torch.distributed as dist
27
+ from torch import Tensor, nn
28
+ from torch.utils.cpp_extension import load
29
+
30
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
31
+ from transformers.modeling_utils import PreTrainedModel
32
+ from transformers.utils import (
33
+ ModelOutput,
34
+ add_start_docstrings,
35
+ add_start_docstrings_to_model_forward,
36
+ logging,
37
+ replace_return_docstrings,
38
+ )
39
+ from transformers.models.bert.modeling_bert import BertModel
40
+ from transformers.models.clip import CLIPVisionModel
41
+ from .configuration_flmr import FLMRConfig, FLMRTextConfig, FLMRVisionConfig
42
+ from .tokenization_flmr import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer
43
+ from .tokenization_flmr_fast import FLMRQueryEncoderTokenizerFast, FLMRContextEncoderTokenizerFast
44
+ from .flmr_utils import (
45
+ colbert_score,
46
+ colbert_score_reduce,
47
+ get_rank,
48
+ get_world_size,
49
+ )
50
+
51
+
52
+ logger = logging.get_logger(__name__)
53
+
54
+ _CONFIG_FOR_DOC = "FLMRConfig"
55
+ _CHECKPOINT_FOR_DOC = "LinWeizheDragon/PreFLMR_ViT-L"
56
+
57
+
58
+ FLMR_PRETRAINED_MODEL_ARCHIVE_LIST = [
59
+ "LinWeizheDragon/PreFLMR_ViT-L",
60
+ "LinWeizheDragon/FLMR",
61
+ # See all FLMR models at https://huggingface.co/models?filter=flmr
62
+ ]
63
+
64
+
65
+ ##########
66
+ # Outputs
67
+ ##########
68
+
69
+
70
+ @dataclass
71
+ class FLMRContextEncoderOutput(ModelOutput):
72
+ """
73
+ Class for outputs of the `doc()` function of [`FLMRModelForRetrieval`].
74
+
75
+ Args:
76
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
77
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the context representation.
78
+ This output can be used to embed questions for nearest neighbors queries with query embeddings.
79
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
80
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
81
+ This output is to be used to embed contexts for late-interaction retrieval with query embeddings.
82
+ context_mask (`torch.FloatTensor` of shape `(batch_size, context_embedding_length)`):
83
+ The FLMR encoder outputs the *context_mask* that corresponds to the mask of the context representation.
84
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
85
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
86
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
87
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
88
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
89
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
90
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
91
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
92
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
93
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
94
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
95
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
96
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
97
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
98
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
99
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
100
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
101
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
102
+ """
103
+
104
+ pooler_output: torch.FloatTensor
105
+ late_interaction_output: torch.FloatTensor = None
106
+ context_mask: torch.FloatTensor = None
107
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
108
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
109
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
110
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
111
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
112
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
113
+
114
+
115
+ @dataclass
116
+ class FLMRQueryEncoderOutput(ModelOutput):
117
+ """
118
+ Class for outputs of the `query()` function of [`FLMRModelForRetrieval.query()`].
119
+
120
+ Args:
121
+ pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
122
+ The FLMR encoder outputs the *pooler_output* that corresponds to the embedding of the first token of the query representation.
123
+ This output can be used to embed questions for nearest neighbors queries with context embeddings.
124
+ late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
125
+ The FLMR encoder outputs the *late_interaction_output* that corresponds to the question representation. The embeddings of all tokens are included for late interaction retrieval.
126
+ This output is to be used to embed questions for late-interaction retrieval with context embeddings.
127
+ text_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
128
+ Tuple of elements containing the attention weights of the text encoder's layers. Each element is a
129
+ tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
130
+ text_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
131
+ Tuple of elements containing the hidden states of the text encoder at each layer plus the initial embedding
132
+ outputs. Each tensor has a shape of `(batch_size, sequence_length, hidden_size)`.
133
+ vision_encoder_attentions (`Tuple[torch.FloatTensor]`, *optional*):
134
+ Tuple of elements containing the attention weights of the vision encoder's layers. Each element is a
135
+ tensor of shape `(batch_size, num_heads, vision_sequence_length, vision_sequence_length)`.
136
+ vision_encoder_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
137
+ Tuple of elements containing the hidden states of the vision encoder at each layer plus the initial embedding
138
+ outputs. Each tensor has a shape of `(batch_size, vision_sequence_length, hidden_size)`.
139
+ transformer_mapping_network_attentions (`Tuple[torch.FloatTensor]`, *optional*):
140
+ Tuple of elements containing the attention weights of the transformer mapping network's layers. Each element
141
+ is a tensor of shape `(batch_size, num_heads, mapping_sequence_length, mapping_sequence_length)`.
142
+ transformer_mapping_network_hidden_states (`Tuple[torch.FloatTensor]`, *optional*):
143
+ Tuple of elements containing the hidden states of the transformer mapping network at each layer plus the
144
+ initial embedding outputs. Each tensor has a shape of `(batch_size, mapping_sequence_length, hidden_size)`.
145
+ """
146
+
147
+ pooler_output: torch.FloatTensor
148
+ late_interaction_output: torch.FloatTensor = None
149
+ text_encoder_attentions: Optional[Tuple[Tensor]] = None
150
+ text_encoder_hidden_states: Optional[Tuple[Tensor]] = None
151
+ vision_encoder_attentions: Optional[Tuple[Tensor]] = None
152
+ vision_encoder_hidden_states: Optional[Tuple[Tensor]] = None
153
+ transformer_mapping_network_attentions: Optional[Tuple[Tensor]] = None
154
+ transformer_mapping_network_hidden_states: Optional[Tuple[Tensor]] = None
155
+
156
+
157
+ @dataclass
158
+ class FLMRModelForRetrievalOutput(ModelOutput):
159
+ """
160
+ Class for outputs of [`FLMRModelForRetrieval.query()`].
161
+
162
+ Args:
163
+ loss (`torch.FloatTensor`):
164
+ contrastive loss of the input queries and positive and negative examples. This output is to be used in model training.
165
+ scores (`torch.FloatTensor` of shape `(batch_size, num_positive_examples + num_negative_examples)`):
166
+ The FLMR model outputs the *scores* that corresponds to the late-interaction scores of the input query and context. Each query is associated with `num_positive_examples` positive examples and `num_negative_examples` negative examples, and the scores are the late-interaction scores of the query and these examples.
167
+ in_batch_negative_loss (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
168
+ The FLMR model outputs the *in_batch_negative_loss* which computes contrastive loss that includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This output is to be used in model training.
169
+ query_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, query_embedding_length, embeddings_size)`):
170
+ The FLMR model outputs the *query_late_interaction_output* that corresponds to the late-interaction representations of the input query.
171
+ context_late_interaction_output (`torch.FloatTensor` of shape `(batch_size, context_embedding_length, embeddings_size)`):
172
+ The FLMR model outputs the *context_late_interaction_output* that corresponds to the late-interaction representations of the input context.
173
+ query_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
174
+ Tuple of elements containing the attention weights of the query's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
175
+ query_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
176
+ Tuple of elements containing the hidden states of the query's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
177
+ context_attentions (`Tuple[Tuple[Tensor]]`, *optional*):
178
+ Tuple of elements containing the attention weights of the context's layers. There are three sub-tuples in this tuple, corresponding to the attentions of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, num_heads, sequence_length, sequence_length)`, with `sequence_length` being the sequence length in the corresponding encoder.
179
+ context_hidden_states (`Tuple[Tuple[Tensor]]`, *optional*):
180
+ Tuple of elements containing the hidden states of the context's layers. There are three sub-tuples in this tuple, corresponding to the hidden states of the text encoder, vision encoder, and transformer mapping network. Each element in the sub-tuple is a tensor of shape `(batch_size, sequence_length, hidden_size)`, with `sequence_length` being the sequence length in the corresponding encoder.
181
+ """
182
+
183
+ loss: torch.FloatTensor
184
+ scores: torch.FloatTensor = None
185
+ in_batch_negative_loss: torch.FloatTensor = None
186
+ query_late_interaction_output: torch.FloatTensor = None
187
+ context_late_interaction_output: torch.FloatTensor = None
188
+ query_attentions: Optional[Tuple[Tuple[Tensor]]] = None
189
+ query_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
190
+ context_attentions: Optional[Tuple[Tuple[Tensor]]] = None
191
+ context_hidden_states: Optional[Tuple[Tuple[Tensor]]] = None
192
+
193
+
194
+ class FLMRPreTrainedModel(PreTrainedModel):
195
+ def _init_weights(self, module):
196
+ """Initialize the weights"""
197
+ if isinstance(module, nn.Linear):
198
+ # Slightly different from the TF version which uses truncated_normal for initialization
199
+ # cf https://github.com/pytorch/pytorch/pull/5617
200
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
201
+ if module.bias is not None:
202
+ module.bias.data.zero_()
203
+ elif isinstance(module, nn.Embedding):
204
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
205
+ if module.padding_idx is not None:
206
+ module.weight.data[module.padding_idx].zero_()
207
+ elif isinstance(module, nn.LayerNorm):
208
+ module.bias.data.zero_()
209
+ module.weight.data.fill_(1.0)
210
+
211
+
212
+ ##################
213
+ # PreTrainedModel
214
+ ##################
215
+
216
+
217
+ class FLMRPretrainedModelForRetrieval(FLMRPreTrainedModel):
218
+ """
219
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
220
+ models.
221
+ """
222
+
223
+ config_class = FLMRConfig
224
+ load_tf_weights = None
225
+ base_model_prefix = "flmr"
226
+
227
+
228
+ ###############
229
+ # Actual Models
230
+ ###############
231
+
232
+
233
+ FLMR_START_DOCSTRING = r"""
234
+
235
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
236
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
237
+ etc.)
238
+
239
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
240
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
241
+ and behavior.
242
+
243
+ Parameters:
244
+ config ([`FLMRConfig`]): Model configuration class with all the parameters of the model.
245
+ Initializing with a config file does not load the weights associated with the model, only the
246
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
247
+ query_tokenizer ([`FLMRQueryEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the query.
248
+ The query tokenizer can be initialized with `FLMRQueryEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
249
+ context_tokenizer ([`FLMRContextEncoderTokenizer`], *optional*): The tokenizer used for tokenizing the context.
250
+ The context tokenizer can be initialized with `FLMRContextEncoderTokenizer.from_pretrained(pretrained_model_name_or_path)`.
251
+ """
252
+
253
+
254
+ FLMR_MODEL_INPUTS_DOCSTRING = r"""
255
+ Args:
256
+ query_input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
257
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
258
+ formatted with [CLS] and Q marker tokens as follows:
259
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
260
+
261
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
262
+ rather than the left.
263
+
264
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
265
+ [`PreTrainedTokenizer.__call__`] for details.
266
+
267
+ [What are input IDs?](../glossary#input-ids)
268
+ query_attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
269
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
270
+
271
+ - 1 for tokens that are **not masked**,
272
+ - 0 for tokens that are **masked**.
273
+
274
+ [What are attention masks?](../glossary#attention-mask)
275
+ query_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
276
+ Pixel values. Pixel values can be obtained using
277
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
278
+ query_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
279
+ Image features are required when `query_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
280
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
281
+ context_input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
282
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
283
+ formatted with [CLS] and D marker tokens as follows:
284
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
285
+
286
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
287
+ rather than the left.
288
+
289
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
290
+ [`PreTrainedTokenizer.__call__`] for details.
291
+
292
+ [What are input IDs?](../glossary#input-ids)
293
+
294
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
295
+
296
+ context_attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
297
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
298
+
299
+ - 1 for tokens that are **not masked**,
300
+ - 0 for tokens that are **masked**.
301
+
302
+ [What are attention masks?](../glossary#attention-mask)
303
+
304
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
305
+ context_pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
306
+ Pixel values. Pixel values can be obtained using
307
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
308
+ context_image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
309
+ Image features are required when `context_pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
310
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
311
+ use_in_batch_negatives (`bool`, *optional*):
312
+ Whether or not to use in-batch negatives. If `True`, the contrastive loss includes in-batch negatives. For each positive example, all other examples in the batch except itself are considered negative examples in computing the contrastive loss. This improves ultimate performance in practice. This input is to be used in model training.
313
+ in_batch_negatives_from_all_gpus (`bool`, *optional*):
314
+ Whether or not to use in-batch negatives from all GPUs. If `True`, the contrastive loss includes in-batch negatives from all GPUs. This input is to be used in model training.
315
+ num_negative_examples (`int`, *optional*):
316
+ The number of negative examples in the batch. For example, if `num_negative_examples` is 4, the batch size of `context_input_ids` and `context_attention_mask` is `batch_size * 5`.
317
+ query_concat_output_from_vision_encoder (`bool`, *optional*):
318
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
319
+ query_concat_output_from_text_encoder (`bool`, *optional*):
320
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
321
+
322
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
323
+ context_concat_output_from_vision_encoder (`bool`, *optional*):
324
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
325
+
326
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
327
+ context_concat_output_from_text_encoder (`bool`, *optional*):
328
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
329
+ return_dict (`bool`, *optional*):
330
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
331
+ output_attentions (`bool`, *optional*):
332
+ Whether or not to return the attentions tensors of all attention layers. See `*_attentions` under returned
333
+ tensors for more detail.
334
+ output_hidden_states (`bool`, *optional*):
335
+ Whether or not to return the hidden states of all layers. See `*_hidden_states` under returned tensors for more detail.
336
+ """
337
+
338
+
339
+ FLMR_MODEL_QUERY_INPUTS_DOCSTRING = r"""
340
+ Args:
341
+ input_ids (`torch.LongTensor` of shape `(batch_size, query_length)`):
342
+ Indices of input query tokens in the vocabulary. To match pretraining, FLMR input sequence should be
343
+ formatted with [CLS] and Q marker tokens as follows:
344
+ [CLS] [unused0] using the provided image, obtain documents that address the subsequent question : what is the capital of france? [SEP] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] ...
345
+
346
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
347
+ rather than the left.
348
+
349
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
350
+ [`PreTrainedTokenizer.__call__`] for details.
351
+
352
+ [What are input IDs?](../glossary#input-ids)
353
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, query_length)`, *optional*):
354
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
355
+
356
+ - 1 for tokens that are **not masked**,
357
+ - 0 for tokens that are **masked**.
358
+
359
+ [What are attention masks?](../glossary#attention-mask)
360
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
361
+ Pixel values. Pixel values can be obtained using
362
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
363
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
364
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
365
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel.__call__`] for details.
366
+ concat_output_from_vision_encoder (`bool`, *optional*):
367
+ Whether or not to concatenate the output from the vision encoder to the final query late-interaction representations. If `True`, the output from the vision encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
368
+ concat_output_from_text_encoder (`bool`, *optional*):
369
+ Whether or not to concatenate the output from the text encoder to the final query late-interaction representations. If `True`, the output from the text encoder is concatenated to the query representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
370
+
371
+ This argument can be set to `False` when performing mapping network pretraining as in FLMR and PreFLMR, in which case the output from the text encoder is not concatenated to the final query representations.
372
+ """
373
+
374
+
375
+ FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING = r"""
376
+ Args:
377
+ input_ids (`torch.LongTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`):
378
+ Indices of input context tokens in the vocabulary. To match pretraining, FLMR input sequence should be
379
+ formatted with [CLS] and D marker tokens as follows:
380
+ [CLS] [unused1] paris is the capital of france. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] ...
381
+
382
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
383
+ rather than the left.
384
+
385
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
386
+ [`PreTrainedTokenizer.__call__`] for details.
387
+
388
+ [What are input IDs?](../glossary#input-ids)
389
+
390
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
391
+ attention_mask (`torch.FloatTensor` of shape `(batch_size * (1 + num_negative_examples), context_length)`, *optional*):
392
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
393
+
394
+ - 1 for tokens that are **not masked**,
395
+ - 0 for tokens that are **masked**.
396
+
397
+ [What are attention masks?](../glossary#attention-mask)
398
+
399
+ The input batch size of this tensor is `batch_size * (1 + num_negative_examples)`. Check the following argument `num_negative_examples` for details.
400
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`, *optional*):
401
+ Pixel values. Pixel values can be obtained using
402
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
403
+ image_features (`torch.FloatTensor` of shape `(batch_size, vision_encoder_hidden_size)`, *optional*):
404
+ Image features are required when `pixel_values` is not provided. In this case, vision encoder outputs are pre-extracted to speed up training and inference by skipping the vision encoder forward pass and the extract image features are directly given to the FLMR model. Image features can be obtained
405
+ using [`CLIPVisionModel`]. See [`CLIPVisionModel
406
+ .__call__`] for details.
407
+ concat_output_from_vision_encoder (`bool`, *optional*):
408
+ Whether or not to concatenate the output from the vision encoder to the final context late-interaction representations. If `True`, the output from the vision encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `False` for FLMR and PreFLMR -style models since the context vision encoder is not used.
409
+
410
+ This can be set to `True` to additionally encode the context images with the vision encoder when context images are provided.
411
+ concat_output_from_text_encoder (`bool`, *optional*):
412
+ Whether or not to concatenate the output from the text encoder to the final context late-interaction representations. If `True`, the output from the text encoder is concatenated to the context representations. When using a pretrained model, this will be read from the model configuration. It should be set to `True` for FLMR and PreFLMR -style models.
413
+ keep_dims (`bool`, *optional*):
414
+ Whether or not to keep the dimensions of the output. If `True`, the output is returned with the same dimensions as the input. If `False`, the output is returned with the batch size of the input and the context length. This input is to be used in model training.
415
+ return_mask (`bool`, *optional*):
416
+ Whether or not to return the mask of the context representation. If `True`, the mask of the context representation is returned. This input is to be used in model training.
417
+ """
418
+
419
+
420
+ FLMR_TEXT_ENCODERS_START_DOCSTRING = r"""
421
+
422
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
423
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
424
+ etc.)
425
+
426
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
427
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
428
+ and behavior.
429
+
430
+ Parameters:
431
+ config ([`FLMRTextConfig`]): Model configuration class with all the parameters of the model.
432
+ Initializing with a config file does not load the weights associated with the model, only the
433
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
434
+ """
435
+
436
+
437
+ # Modified from transformers.models.dpr.modeling_dpr with DPR -> FLMR
438
+ FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING = r"""
439
+ Args:
440
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
441
+ Indices of input sequence tokens in the vocabulary. To match pretraining, FLMR input sequence should be
442
+ formatted with [CLS] and [SEP] tokens as follows:
443
+
444
+ (a) For sequence pairs (for a pair title+text for example):
445
+
446
+ ```
447
+ tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
448
+ token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
449
+ ```
450
+
451
+ (b) For single sequences (for a question for example):
452
+
453
+ ```
454
+ tokens: [CLS] the dog is hairy . [SEP]
455
+ token_type_ids: 0 0 0 0 0 0 0
456
+ ```
457
+
458
+ FLMR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
459
+ rather than the left.
460
+
461
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
462
+ [`PreTrainedTokenizer.__call__`] for details.
463
+
464
+ [What are input IDs?](../glossary#input-ids)
465
+ attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
466
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
467
+
468
+ - 1 for tokens that are **not masked**,
469
+ - 0 for tokens that are **masked**.
470
+
471
+ [What are attention masks?](../glossary#attention-mask)
472
+ token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
473
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
474
+ 1]`:
475
+
476
+ - 0 corresponds to a *sentence A* token,
477
+ - 1 corresponds to a *sentence B* token.
478
+
479
+ [What are token type IDs?](../glossary#token-type-ids)
480
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
481
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
482
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
483
+ model's internal embedding lookup matrix.
484
+ output_attentions (`bool`, *optional*):
485
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
486
+ tensors for more detail.
487
+ output_hidden_states (`bool`, *optional*):
488
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
489
+ more detail.
490
+ return_dict (`bool`, *optional*):
491
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
492
+ """
493
+
494
+ FLMR_VISION_ENCODERS_START_DOCSTRING = r"""
495
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
496
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
497
+ etc.)
498
+
499
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
500
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
501
+ and behavior.
502
+
503
+ Parameters:
504
+ config ([`FLMRVisionConfig`]): Model configuration class with all the parameters of the model.
505
+ Initializing with a config file does not load the weights associated with the model, only the
506
+ configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
507
+ """
508
+
509
+ # Modified from transformers.models.clip.modeling_clip with CLIP -> FLMR
510
+ FLMR_VISION_ENCODERS_INPUTS_DOCSTRING = r"""
511
+ Args:
512
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
513
+ Pixel values. Padding will be ignored by default should you provide it. Pixel values can be obtained using
514
+ [`AutoImageProcessor`]. See [`CLIPImageProcessor.__call__`] for details.
515
+ output_attentions (`bool`, *optional*):
516
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
517
+ tensors for more detail.
518
+ output_hidden_states (`bool`, *optional*):
519
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
520
+ more detail.
521
+ return_dict (`bool`, *optional*):
522
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
523
+ """
524
+
525
+
526
+ class FLMRMultiLayerPerceptron(nn.Module):
527
+ """
528
+ A simple multi-layer perceptron with an activation function. This can be used as the mapping network in the FLMR model.
529
+ """
530
+
531
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
532
+ return self.model(x)
533
+
534
+ def __init__(self, sizes, bias=True, act=nn.Tanh):
535
+ super(FLMRMultiLayerPerceptron, self).__init__()
536
+ layers = []
537
+ for i in range(len(sizes) - 1):
538
+ layers.append(nn.Linear(sizes[i], sizes[i + 1], bias=bias))
539
+ if i < len(sizes) - 2:
540
+ layers.append(act())
541
+ self.model = nn.Sequential(*layers)
542
+
543
+
544
+ @add_start_docstrings(
545
+ "The bare FLMR model that can be used to generate late-interaction embeddings for both multi-modal queries and documents. ",
546
+ FLMR_START_DOCSTRING,
547
+ )
548
+ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
549
+ _keys_to_ignore_on_load_unexpected = [r"cls"]
550
+ main_input_name = "query_input_ids"
551
+ _tied_weights_keys = [] # Added dynamically at initialization depending on the architecture
552
+
553
+ def __init__(self, config: FLMRConfig, query_tokenizer=None, context_tokenizer=None):
554
+ super().__init__(config)
555
+ self.config = config
556
+ self.vision_model_version = config.vision_model_version
557
+
558
+ self.context_text_encoder = FLMRTextModel(config.text_config)
559
+ self.context_text_encoder_linear = nn.Linear(config.text_config.hidden_size, config.dim, bias=False)
560
+
561
+ self.query_tokenizer = query_tokenizer
562
+ self.context_tokenizer = context_tokenizer
563
+
564
+ if self.query_tokenizer is None:
565
+ logger.warning(
566
+ "query_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRQueryEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
567
+ )
568
+ from transformers import FLMRQueryEncoderTokenizer
569
+
570
+ # initialize a FLMRQueryEncoderTokenizer
571
+ self.query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained("bert-base-uncased")
572
+
573
+ if self.context_tokenizer is None:
574
+ logger.warning(
575
+ "context_tokenizer is not provided. A tokenizer is initialized from `bert-base-uncased`. Please pass in an FLMRContextEncoderTokenizer instance if you need to extend the vocabulary beyond the existing ones in the bert tokenizer."
576
+ )
577
+ from transformers import FLMRContextEncoderTokenizer
578
+
579
+ # initialize a FLMRContextEncoderTokenizer
580
+ self.context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained("bert-base-uncased")
581
+
582
+ self.mapping_network_prefix_length = self.config.mapping_network_prefix_length
583
+ self.vision_encoder_embedding_size = self.config.vision_config.hidden_size
584
+ self.text_encoder_embedding_size = self.config.text_config.hidden_size
585
+ self.late_interaction_embedding_size = self.config.dim
586
+
587
+ self.context_vision_projection = FLMRMultiLayerPerceptron(
588
+ (
589
+ self.vision_encoder_embedding_size,
590
+ (self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
591
+ self.late_interaction_embedding_size * self.mapping_network_prefix_length,
592
+ )
593
+ )
594
+
595
+ if self.config.use_vision_encoder:
596
+ self.context_vision_encoder = FLMRVisionModel(config.vision_config)
597
+
598
+ if self.config.use_transformer_mapping_network:
599
+ # This is a PreFLMR style model
600
+ transformer_mapping_config_base = self.config.transformer_mapping_config_base
601
+ try:
602
+ from transformers import BertConfig
603
+ from transformers.models.bert.modeling_bert import BertEncoder
604
+ except Exception as e:
605
+ raise ImportError(f"Failed to import BertConfig and BertEncoder from transformers. {e}")
606
+
607
+ transformer_mapping_config = BertConfig.from_pretrained(transformer_mapping_config_base)
608
+
609
+ assert (
610
+ self.config.text_config.hidden_size == transformer_mapping_config.hidden_size
611
+ ), f"hidden_size {self.config.text_config.hidden_size} != transformer_mapping_config.hidden_size {transformer_mapping_config.hidden_size}. To use cross attention, the dimensions must match."
612
+ # shallow transformer
613
+ transformer_mapping_config.num_hidden_layers = self.config.transformer_mapping_num_hidden_layers
614
+ # add cross attention
615
+ transformer_mapping_config.is_decoder = True
616
+ transformer_mapping_config.add_cross_attention = True
617
+
618
+ # The linear layer from vision encoder to transformer input
619
+ self.transformer_mapping_input_linear = nn.Linear(
620
+ self.vision_encoder_embedding_size, transformer_mapping_config.hidden_size
621
+ )
622
+
623
+ # The transformer encoder
624
+ self.transformer_mapping_network = BertEncoder(transformer_mapping_config)
625
+
626
+ # The linear layer from transformer output to FLMR dim
627
+ self.transformer_mapping_output_linear = nn.Linear(
628
+ transformer_mapping_config.hidden_size, self.late_interaction_embedding_size
629
+ )
630
+
631
+ if self.config.separate_query_and_context_text_encoder:
632
+ self.query_text_encoder = copy.deepcopy(self.context_text_encoder)
633
+ self.query_text_encoder_linear = copy.deepcopy(self.context_text_encoder_linear)
634
+ else:
635
+ self.query_text_encoder = self.context_text_encoder
636
+ self.query_text_encoder_linear = self.context_text_encoder_linear
637
+ self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
638
+
639
+ if self.config.separate_query_and_context_vision_encoder:
640
+ self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
641
+ self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
642
+ else:
643
+ self.query_vision_encoder = self.context_vision_encoder
644
+ self.query_vision_projection = self.context_vision_projection
645
+ self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
646
+
647
+ if self.config.load_cpu_extension:
648
+ FLMRModelForRetrieval.try_load_torch_extensions()
649
+
650
+ if self.config.mask_punctuation:
651
+ self.skiplist = {
652
+ w: True
653
+ for symbol in string.punctuation
654
+ for w in [symbol, self.context_tokenizer.encode(symbol, add_special_tokens=False)[0]]
655
+ }
656
+
657
+ if self.config.mask_instruction_token is not None:
658
+ self.mask_instruction = True
659
+ # obtain the token id of the instruction token
660
+ self.instruction_token_id = self.query_tokenizer.encode(
661
+ self.config.mask_instruction_token, add_special_tokens=False
662
+ )[0]
663
+ else:
664
+ self.mask_instruction = False
665
+
666
+ self.loss_fn = torch.nn.CrossEntropyLoss()
667
+
668
+ # Initialize weights and apply final processing
669
+ self.post_init()
670
+
671
+ @property
672
+ def use_gpu(self):
673
+ return self.device.type == "cuda"
674
+
675
+ @classmethod
676
+ def from_pretrained(self, name_or_path, **kwargs):
677
+ obj = super().from_pretrained(name_or_path, **kwargs)
678
+ return obj
679
+
680
+ @classmethod
681
+ def try_load_torch_extensions(cls):
682
+ if hasattr(cls, "loaded_extensions"):
683
+ return
684
+
685
+ logger.info(
686
+ "Loading segmented_maxsim_cpp extension (set COLBERT_LOAD_TORCH_EXTENSION_VERBOSE=True for more info)..."
687
+ )
688
+ segmented_maxsim_cpp = load(
689
+ name="segmented_maxsim_cpp",
690
+ sources=[
691
+ os.path.join(pathlib.Path(__file__).parent.resolve(), "segmented_maxsim.cpp"),
692
+ ],
693
+ extra_cflags=["-O3"],
694
+ verbose=os.getenv("COLBERT_LOAD_TORCH_EXTENSION_VERBOSE", "False") == "True",
695
+ )
696
+ cls.segmented_maxsim = segmented_maxsim_cpp.segmented_maxsim_cpp
697
+
698
+ cls.loaded_extensions = True
699
+
700
+ def query_mask(self, input_ids, skiplist):
701
+ if not self.mask_instruction:
702
+ return self.mask(input_ids, skiplist)
703
+
704
+ # find the position of end of instruction in input_ids
705
+ # mask the tokens before the position
706
+ sep_id = self.instruction_token_id
707
+ sep_positions = torch.argmax((input_ids == sep_id).int(), dim=1).tolist()
708
+ # if any of the positions is lower than 1, set to 1
709
+ for i, x in enumerate(sep_positions):
710
+ if x < 1:
711
+ sep_positions[i] = 1
712
+ logger.error(f"can not find the separator in the input_ids: {input_ids[i].tolist()}")
713
+ mask = [
714
+ [
715
+ (x not in skiplist) and (x != 0) and (index > sep_positions[seq_index] or index < 2)
716
+ for index, x in enumerate(d)
717
+ ]
718
+ for seq_index, d in enumerate(input_ids.cpu().tolist())
719
+ ]
720
+ return mask
721
+
722
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_INPUTS_DOCSTRING)
723
+ @replace_return_docstrings(output_type=FLMRModelForRetrievalOutput, config_class=_CONFIG_FOR_DOC)
724
+ def forward(
725
+ self,
726
+ query_input_ids: Optional[torch.Tensor] = None,
727
+ query_attention_mask: Optional[torch.Tensor] = None,
728
+ query_pixel_values: Optional[torch.Tensor] = None,
729
+ query_image_features: Optional[torch.Tensor] = None,
730
+ context_input_ids: Optional[torch.Tensor] = None,
731
+ context_attention_mask: Optional[torch.Tensor] = None,
732
+ context_pixel_values: Optional[torch.Tensor] = None,
733
+ context_image_features: Optional[torch.Tensor] = None,
734
+ use_in_batch_negatives: bool = True,
735
+ in_batch_negatives_from_all_gpus: bool = False,
736
+ num_negative_examples: int = 1,
737
+ query_concat_output_from_vision_encoder: Optional[bool] = None,
738
+ query_concat_output_from_text_encoder: Optional[bool] = None,
739
+ context_concat_output_from_vision_encoder: Optional[bool] = None,
740
+ context_concat_output_from_text_encoder: Optional[bool] = None,
741
+ return_dict: bool = None,
742
+ output_attentions: bool = None,
743
+ output_hidden_states: bool = None,
744
+ ) -> Union[FLMRModelForRetrievalOutput, Tuple[Tensor, ...]]:
745
+ r"""
746
+ Return:
747
+
748
+ Examples:
749
+
750
+ ```python
751
+ >>> import torch
752
+ >>> from transformers import FLMRQueryEncoderTokenizer, FLMRContextEncoderTokenizer, FLMRModelForRetrieval, AutoImageProcessor
753
+
754
+ >>> checkpoint_path = "LinWeizheDragon/PreFLMR_ViT-L"
755
+ >>> image_processor_name = "openai/clip-vit-large-patch14"
756
+ >>> query_tokenizer = FLMRQueryEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="query_tokenizer")
757
+ >>> context_tokenizer = FLMRContextEncoderTokenizer.from_pretrained(checkpoint_path, subfolder="context_tokenizer")
758
+
759
+ >>> model = FLMRModelForRetrieval.from_pretrained(checkpoint_path,
760
+ query_tokenizer=query_tokenizer,
761
+ context_tokenizer=context_tokenizer,
762
+ )
763
+ >>> image_processor = AutoImageProcessor.from_pretrained(image_processor_name)
764
+
765
+ >>> Q_encoding = query_tokenizer(["Using the provided image, obtain documents that address the subsequent question: What is the capital of France?", "Extract documents linked to the question provided in conjunction with the image: What is the capital of China?"])
766
+ >>> D_encoding = context_tokenizer(["Paris is the capital of France.", "Beijing is the capital of China.",
767
+ "Paris is the capital of France.", "Beijing is the capital of China."])
768
+ >>> Q_pixel_values = torch.zeros(2, 3, 224, 224)
769
+ >>> inputs = dict(
770
+ query_input_ids=Q_encoding['input_ids'],
771
+ query_attention_mask=Q_encoding['attention_mask'],
772
+ query_pixel_values=Q_pixel_values,
773
+ context_input_ids=D_encoding['input_ids'],
774
+ context_attention_mask=D_encoding['attention_mask'],
775
+ use_in_batch_negatives=True,
776
+ )
777
+
778
+ >>> model.forward(**inputs)
779
+ FLMRModelForRetrievalOutput(loss=tensor(4.5000, device='cuda:0', dtype=torch.float16,
780
+ grad_fn=<NllLossBackward0>), scores=tensor([[44.2188, 40.6562],
781
+ [39.4375, 48.4062]], device='cuda:0', dtype=torch.float16,
782
+ grad_fn=<ViewBackward0>), in_batch_negative_loss=tensor(5.1994, device='cuda:0', grad_fn=<NllLossBackward0>), query_late_interaction_output=tensor(...), context_late_interaction_output=tensor(...)
783
+ ```
784
+ """
785
+
786
+ if query_concat_output_from_vision_encoder is None:
787
+ query_concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
788
+
789
+ if query_concat_output_from_text_encoder is None:
790
+ query_concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
791
+
792
+ if context_concat_output_from_vision_encoder is None:
793
+ context_concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
794
+
795
+ if context_concat_output_from_text_encoder is None:
796
+ context_concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
797
+
798
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
799
+ output_hidden_states = (
800
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
801
+ )
802
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
803
+
804
+ query_outputs = self.query(
805
+ input_ids=query_input_ids,
806
+ attention_mask=query_attention_mask,
807
+ pixel_values=query_pixel_values,
808
+ image_features=query_image_features,
809
+ concat_output_from_vision_encoder=query_concat_output_from_vision_encoder,
810
+ concat_output_from_text_encoder=query_concat_output_from_text_encoder,
811
+ output_attentions=output_attentions,
812
+ output_hidden_states=output_hidden_states,
813
+ )
814
+ Q = query_outputs.late_interaction_output
815
+
816
+ context_outputs = self.doc(
817
+ input_ids=context_input_ids,
818
+ attention_mask=context_attention_mask,
819
+ pixel_values=context_pixel_values,
820
+ image_features=context_image_features,
821
+ concat_output_from_vision_encoder=context_concat_output_from_vision_encoder,
822
+ concat_output_from_text_encoder=context_concat_output_from_text_encoder,
823
+ keep_dims=True,
824
+ return_mask=True,
825
+ output_attentions=output_attentions,
826
+ output_hidden_states=output_hidden_states,
827
+ )
828
+ D, D_mask = context_outputs.late_interaction_output, context_outputs.context_mask
829
+
830
+ # Gather tensors from other GPUs
831
+ if in_batch_negatives_from_all_gpus:
832
+ Q, D, D_mask = self.gather_tensors_from_other_gpus(Q, D, D_mask)
833
+ # Repeat each query encoding for every corresponding document.
834
+ Q_duplicated = Q.repeat_interleave(num_negative_examples + 1, dim=0).contiguous()
835
+
836
+ scores = self.score(Q_duplicated, D, D_mask)
837
+
838
+ # Use contrastive learning
839
+ batch_size = query_input_ids.shape[0]
840
+ scores = scores.view(-1, num_negative_examples + 1)
841
+ labels = torch.zeros(batch_size, dtype=torch.long, device=self.device)
842
+ loss = self.loss_fn(scores, labels)
843
+
844
+ if use_in_batch_negatives:
845
+ ib_loss = self.compute_ib_loss_new(Q, D, D_mask)
846
+ else:
847
+ ib_loss = None
848
+
849
+ if output_attentions:
850
+ query_attentions = (
851
+ query_outputs.text_encoder_attentions if query_outputs.text_encoder_attentions is not None else None,
852
+ query_outputs.vision_encoder_attentions
853
+ if query_outputs.vision_encoder_attentions is not None
854
+ else None,
855
+ query_outputs.transformer_mapping_network_attentions
856
+ if query_outputs.transformer_mapping_network_attentions is not None
857
+ else None,
858
+ )
859
+ context_attentions = (
860
+ context_outputs.text_encoder_attentions
861
+ if context_outputs.text_encoder_attentions is not None
862
+ else None,
863
+ context_outputs.vision_encoder_attentions
864
+ if context_outputs.vision_encoder_attentions is not None
865
+ else None,
866
+ context_outputs.transformer_mapping_network_attentions
867
+ if context_outputs.transformer_mapping_network_attentions is not None
868
+ else None,
869
+ )
870
+ else:
871
+ query_attentions = None
872
+ context_attentions = None
873
+
874
+ if output_hidden_states:
875
+ query_hidden_states = (
876
+ query_outputs.text_encoder_hidden_states
877
+ if query_outputs.text_encoder_hidden_states is not None
878
+ else None,
879
+ query_outputs.vision_encoder_hidden_states
880
+ if query_outputs.vision_encoder_hidden_states is not None
881
+ else None,
882
+ query_outputs.transformer_mapping_network_hidden_states
883
+ if query_outputs.transformer_mapping_network_hidden_states is not None
884
+ else None,
885
+ )
886
+ context_hidden_states = (
887
+ context_outputs.text_encoder_hidden_states
888
+ if context_outputs.text_encoder_hidden_states is not None
889
+ else None,
890
+ context_outputs.vision_encoder_hidden_states
891
+ if context_outputs.vision_encoder_hidden_states is not None
892
+ else None,
893
+ context_outputs.transformer_mapping_network_hidden_states
894
+ if context_outputs.transformer_mapping_network_hidden_states is not None
895
+ else None,
896
+ )
897
+ else:
898
+ query_hidden_states = None
899
+ context_hidden_states = None
900
+
901
+ if not return_dict:
902
+ if output_attentions and output_hidden_states:
903
+ return (
904
+ loss,
905
+ scores,
906
+ ib_loss,
907
+ query_outputs.late_interaction_output,
908
+ context_outputs.late_interaction_output,
909
+ query_attentions,
910
+ query_hidden_states,
911
+ context_attentions,
912
+ context_hidden_states,
913
+ )
914
+ elif output_attentions:
915
+ return (
916
+ loss,
917
+ scores,
918
+ ib_loss,
919
+ query_outputs.late_interaction_output,
920
+ context_outputs.late_interaction_output,
921
+ query_attentions,
922
+ context_attentions,
923
+ )
924
+ elif output_hidden_states:
925
+ return (
926
+ loss,
927
+ scores,
928
+ ib_loss,
929
+ query_outputs.late_interaction_output,
930
+ context_outputs.late_interaction_output,
931
+ query_hidden_states,
932
+ context_hidden_states,
933
+ )
934
+ else:
935
+ return (
936
+ loss,
937
+ scores,
938
+ ib_loss,
939
+ query_outputs.late_interaction_output,
940
+ context_outputs.late_interaction_output,
941
+ )
942
+
943
+ return FLMRModelForRetrievalOutput(
944
+ loss=loss,
945
+ scores=scores,
946
+ in_batch_negative_loss=ib_loss,
947
+ query_late_interaction_output=query_outputs.late_interaction_output,
948
+ context_late_interaction_output=context_outputs.late_interaction_output,
949
+ query_attentions=query_attentions if output_attentions else None,
950
+ query_hidden_states=query_hidden_states if output_hidden_states else None,
951
+ context_attentions=context_attentions if output_attentions else None,
952
+ context_hidden_states=context_hidden_states if output_hidden_states else None,
953
+ )
954
+
955
+ def compute_ib_loss_new(self, Q: torch.Tensor, D: torch.Tensor, D_mask: torch.Tensor) -> torch.Tensor:
956
+ # Q: batch_size x q_len x dim
957
+ # D: batch_size*n_docs x i_len x dim
958
+ # D_mask: batch_size*n_docs x i_len x dim
959
+ # 1 x batch_size*n_docs x i_len x dim matmul batch_size x 1 x q_len x dim
960
+ # = batch_size x batch_size*n_docs x i_len x q_len
961
+
962
+ scores = (D.float().unsqueeze(0) @ Q.float().permute(0, 2, 1).unsqueeze(1)).flatten(
963
+ 0, 1
964
+ ) # query-major unsqueeze
965
+ scores = colbert_score_reduce(scores, D_mask.repeat(Q.size(0), 1, 1))
966
+
967
+ in_batch_scores = scores.reshape(Q.size(0), -1)
968
+
969
+ batch_size = Q.shape[0]
970
+ batch_size_with_pos_and_neg = D.shape[0]
971
+ num_pos_and_neg = batch_size_with_pos_and_neg // batch_size
972
+
973
+ # batch_size x dim matmul dim x (num_pos+num_neg)*batch_size
974
+ # --> batch_size x (num_pos+num_neg)*batch_size
975
+ in_batch_labels = torch.zeros(batch_size, batch_size_with_pos_and_neg).to(scores.device)
976
+ step = num_pos_and_neg
977
+ for i in range(batch_size):
978
+ in_batch_labels[i, step * i] = 1
979
+ # print('in_batch_labels', in_batch_labels)
980
+ in_batch_labels = torch.argmax(in_batch_labels, dim=1)
981
+ # print('in_batch_labels', in_batch_labels)
982
+
983
+ loss = self.loss_fn(in_batch_scores, in_batch_labels)
984
+
985
+ return loss
986
+
987
+ def gather_tensors_from_other_gpus(self, query_embeddings, item_embeddings, item_mask):
988
+ # print("get rank", get_rank())
989
+ # print("get world size", get_world_size())
990
+ # Gather embeddings from other GPUs
991
+ n_nodes = get_world_size()
992
+ if n_nodes == 1:
993
+ return query_embeddings, item_embeddings, item_mask
994
+ # Create placeholder to hold embeddings passed from other ranks
995
+ global_query_embeddings_placeholder = [
996
+ torch.zeros(*query_embeddings.shape, dtype=query_embeddings.dtype).to(query_embeddings.device)
997
+ for _ in range(n_nodes)
998
+ ]
999
+ global_item_embeddings_placeholder = [
1000
+ torch.zeros(*item_embeddings.shape, dtype=item_embeddings.dtype).to(item_embeddings.device)
1001
+ for _ in range(n_nodes)
1002
+ ]
1003
+ global_item_mask_placeholder = [
1004
+ torch.zeros(*item_mask.shape, dtype=item_mask.dtype).to(item_mask.device) for _ in range(n_nodes)
1005
+ ]
1006
+ dist.all_gather(global_query_embeddings_placeholder, query_embeddings.detach())
1007
+ dist.all_gather(global_item_embeddings_placeholder, item_embeddings.detach())
1008
+ dist.all_gather(global_item_mask_placeholder, item_mask.detach())
1009
+
1010
+ global_query_embeddings = []
1011
+ global_item_embeddings = []
1012
+ global_item_mask = []
1013
+ # print(f"rank {get_rank()} global_query_embeddings", global_query_embeddings)
1014
+ # print(f"rank {get_rank()} global_item_embeddings", global_item_embeddings)
1015
+ # input()
1016
+ current_rank = get_rank()
1017
+ for rank_index, remote_q_embeddings in enumerate(global_query_embeddings_placeholder):
1018
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1019
+ if rank_index != current_rank:
1020
+ global_query_embeddings.append(remote_q_embeddings)
1021
+ else:
1022
+ global_query_embeddings.append(query_embeddings)
1023
+
1024
+ for rank_index, remote_item_embeddings in enumerate(global_item_embeddings_placeholder):
1025
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1026
+ if rank_index != current_rank:
1027
+ global_item_embeddings.append(remote_item_embeddings)
1028
+ else:
1029
+ global_item_embeddings.append(item_embeddings)
1030
+
1031
+ for rank_index, remote_item_mask in enumerate(global_item_mask_placeholder):
1032
+ # We append the embeddings from other GPUs if this embedding does not require gradients
1033
+ if rank_index != current_rank:
1034
+ global_item_mask.append(remote_item_mask)
1035
+ else:
1036
+ global_item_mask.append(item_mask)
1037
+
1038
+ # Replace the previous variables with gathered tensors
1039
+ query_embeddings = torch.cat(global_query_embeddings)
1040
+ item_embeddings = torch.cat(global_item_embeddings)
1041
+ item_mask = torch.cat(global_item_mask)
1042
+
1043
+ return query_embeddings, item_embeddings, item_mask
1044
+
1045
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_QUERY_INPUTS_DOCSTRING)
1046
+ @replace_return_docstrings(output_type=FLMRQueryEncoderOutput, config_class=_CONFIG_FOR_DOC)
1047
+ def query(
1048
+ self,
1049
+ input_ids: torch.Tensor,
1050
+ attention_mask: torch.Tensor,
1051
+ pixel_values: Optional[torch.Tensor] = None,
1052
+ image_features: Optional[torch.Tensor] = None,
1053
+ concat_output_from_vision_encoder: Optional[bool] = None,
1054
+ concat_output_from_text_encoder: Optional[bool] = None,
1055
+ output_attentions: Optional[bool] = None,
1056
+ output_hidden_states: Optional[bool] = None,
1057
+ ):
1058
+ r"""
1059
+ Returns:
1060
+
1061
+ """
1062
+
1063
+ if concat_output_from_vision_encoder is None:
1064
+ concat_output_from_vision_encoder = self.config.query_concat_output_from_vision_encoder
1065
+
1066
+ if concat_output_from_text_encoder is None:
1067
+ concat_output_from_text_encoder = self.config.query_concat_output_from_text_encoder
1068
+
1069
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1070
+ output_hidden_states = (
1071
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1072
+ )
1073
+
1074
+ input_modality = []
1075
+ if pixel_values is not None or image_features is not None:
1076
+ input_modality.append("image")
1077
+ if input_ids is not None and attention_mask is not None:
1078
+ input_modality.append("text")
1079
+
1080
+ text_encoder_outputs = None
1081
+ vision_encoder_outputs = None
1082
+ transformer_mapping_outputs = None
1083
+
1084
+ if "image" in input_modality:
1085
+ assert (
1086
+ pixel_values is not None or image_features is not None
1087
+ ), "pixel_values or image_features must be provided if image modality is used"
1088
+ assert (
1089
+ pixel_values is None or image_features is None
1090
+ ), "pixel_values and image_features cannot be provided at the same time"
1091
+
1092
+ if "text" in input_modality:
1093
+ assert (
1094
+ input_ids is not None and attention_mask is not None
1095
+ ), "input_ids and attention_mask must be provided if text modality is used"
1096
+ # Forward the text encoder
1097
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1098
+ text_encoder_outputs = self.query_text_encoder(input_ids, attention_mask=attention_mask)
1099
+ text_encoder_hidden_states = text_encoder_outputs[0]
1100
+ text_embeddings = self.query_text_encoder_linear(text_encoder_hidden_states)
1101
+ mask = torch.tensor(self.query_mask(input_ids, skiplist=[]), device=self.device).unsqueeze(2).float()
1102
+
1103
+ text_embeddings = text_embeddings * mask
1104
+
1105
+ if "image" in input_modality:
1106
+ if pixel_values is not None:
1107
+ batch_size = pixel_values.shape[0]
1108
+ # Forward the vision encoder
1109
+ pixel_values = pixel_values.to(self.device)
1110
+ if len(pixel_values.shape) == 5:
1111
+ # Multiple ROIs are provided
1112
+ # merge the first two dimensions
1113
+ pixel_values = pixel_values.reshape(
1114
+ -1, pixel_values.shape[2], pixel_values.shape[3], pixel_values.shape[4]
1115
+ )
1116
+ vision_encoder_outputs = self.query_vision_encoder(pixel_values, output_hidden_states=True)
1117
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1118
+
1119
+ if image_features is not None:
1120
+ batch_size = image_features.shape[0]
1121
+ vision_embeddings = image_features.to(self.device)
1122
+
1123
+ # Forward the vision projection / mapping network
1124
+ vision_embeddings = self.query_vision_projection(vision_embeddings)
1125
+ vision_embeddings = vision_embeddings.view(batch_size, -1, self.late_interaction_embedding_size)
1126
+
1127
+ if self.config.use_transformer_mapping_network:
1128
+ # select the second last layer
1129
+ vision_second_last_layer_hidden_states = vision_encoder_outputs.hidden_states[-2][:, 1:]
1130
+ # transformer_mapping
1131
+ transformer_mapping_input_features = self.transformer_mapping_input_linear(
1132
+ vision_second_last_layer_hidden_states
1133
+ )
1134
+
1135
+ # Cross attention only attends to the first 32 tokens
1136
+ encoder_mask = torch.ones_like(mask).to(mask.device, dtype=mask.dtype)
1137
+ cross_attention_length = self.config.transformer_mapping_cross_attention_length
1138
+ if text_encoder_hidden_states.shape[1] > cross_attention_length:
1139
+ text_encoder_hidden_states = text_encoder_hidden_states[:, :cross_attention_length]
1140
+ encoder_mask = encoder_mask[:, :cross_attention_length]
1141
+
1142
+ # Obtain cross attention mask
1143
+ encoder_extended_attention_mask = self.invert_attention_mask(encoder_mask.squeeze(-1))
1144
+ # Pass through the transformer mapping
1145
+ transformer_mapping_outputs = self.transformer_mapping_network(
1146
+ transformer_mapping_input_features,
1147
+ encoder_hidden_states=text_encoder_hidden_states,
1148
+ encoder_attention_mask=encoder_extended_attention_mask,
1149
+ )
1150
+ transformer_mapping_output_features = transformer_mapping_outputs.last_hidden_state
1151
+ # Convert the dimension to FLMR dim
1152
+ transformer_mapping_output_features = self.transformer_mapping_output_linear(
1153
+ transformer_mapping_output_features
1154
+ )
1155
+ # Merge with the vision embeddings
1156
+ vision_embeddings = torch.cat([vision_embeddings, transformer_mapping_output_features], dim=1)
1157
+
1158
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1159
+ Q = torch.cat([text_embeddings, vision_embeddings], dim=1)
1160
+ elif concat_output_from_vision_encoder:
1161
+ Q = vision_embeddings
1162
+ elif concat_output_from_text_encoder:
1163
+ Q = text_embeddings
1164
+
1165
+ vision_encoder_attentions = (
1166
+ vision_encoder_outputs.attentions
1167
+ if vision_encoder_outputs is not None
1168
+ and hasattr(vision_encoder_outputs, "attentions")
1169
+ and output_attentions
1170
+ else None
1171
+ )
1172
+ vision_encoder_hidden_states = (
1173
+ vision_encoder_outputs.hidden_states
1174
+ if vision_encoder_outputs is not None
1175
+ and hasattr(vision_encoder_outputs, "hidden_states")
1176
+ and output_hidden_states
1177
+ else None
1178
+ )
1179
+ text_encoder_attentions = (
1180
+ text_encoder_outputs.attentions
1181
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1182
+ else None
1183
+ )
1184
+ text_encoder_hidden_states = (
1185
+ text_encoder_outputs.hidden_states
1186
+ if text_encoder_outputs is not None
1187
+ and hasattr(text_encoder_outputs, "hidden_states")
1188
+ and output_hidden_states
1189
+ else None
1190
+ )
1191
+ transformer_mapping_network_attentions = (
1192
+ transformer_mapping_outputs.attentions
1193
+ if transformer_mapping_outputs is not None
1194
+ and hasattr(transformer_mapping_outputs, "attentions")
1195
+ and output_attentions
1196
+ else None
1197
+ )
1198
+ transformer_mapping_network_hidden_states = (
1199
+ transformer_mapping_outputs.hidden_states
1200
+ if transformer_mapping_outputs is not None
1201
+ and hasattr(transformer_mapping_outputs, "hidden_states")
1202
+ and output_hidden_states
1203
+ else None
1204
+ )
1205
+
1206
+ return FLMRQueryEncoderOutput(
1207
+ pooler_output=Q[:, 0, :],
1208
+ late_interaction_output=torch.nn.functional.normalize(Q, p=2, dim=2),
1209
+ vision_encoder_attentions=vision_encoder_attentions,
1210
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1211
+ text_encoder_attentions=text_encoder_attentions,
1212
+ text_encoder_hidden_states=text_encoder_hidden_states,
1213
+ transformer_mapping_network_attentions=transformer_mapping_network_attentions,
1214
+ transformer_mapping_network_hidden_states=transformer_mapping_network_hidden_states,
1215
+ )
1216
+
1217
+ @add_start_docstrings_to_model_forward(FLMR_MODEL_CONTEXT_INPUTS_DOCSTRING)
1218
+ @replace_return_docstrings(output_type=FLMRContextEncoderOutput, config_class=_CONFIG_FOR_DOC)
1219
+ def doc(
1220
+ self,
1221
+ input_ids: torch.Tensor,
1222
+ attention_mask: torch.Tensor,
1223
+ pixel_values: Optional[torch.Tensor] = None,
1224
+ image_features: Optional[torch.Tensor] = None,
1225
+ concat_output_from_vision_encoder: Optional[bool] = None,
1226
+ concat_output_from_text_encoder: Optional[bool] = None,
1227
+ keep_dims: Optional[bool] = True,
1228
+ return_mask: Optional[bool] = True,
1229
+ output_attentions: Optional[bool] = None,
1230
+ output_hidden_states: Optional[bool] = None,
1231
+ ):
1232
+ r"""
1233
+ Returns:
1234
+
1235
+ """
1236
+ assert keep_dims in [True, False]
1237
+
1238
+ if concat_output_from_vision_encoder is None:
1239
+ concat_output_from_vision_encoder = self.config.context_concat_output_from_vision_encoder
1240
+
1241
+ if concat_output_from_text_encoder is None:
1242
+ concat_output_from_text_encoder = self.config.context_concat_output_from_text_encoder
1243
+
1244
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1245
+ output_hidden_states = (
1246
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1247
+ )
1248
+
1249
+ input_modality = []
1250
+ if pixel_values is not None or image_features is not None:
1251
+ input_modality.append("image")
1252
+ if input_ids is not None and attention_mask is not None:
1253
+ input_modality.append("text")
1254
+
1255
+ text_encoder_outputs = None
1256
+ vision_encoder_outputs = None
1257
+
1258
+ if "image" in input_modality:
1259
+ assert (
1260
+ pixel_values is not None or image_features is not None
1261
+ ), "pixel_values or image_features must be provided if image modality is used"
1262
+ assert (
1263
+ pixel_values is None or image_features is None
1264
+ ), "pixel_values and image_features cannot be provided at the same time"
1265
+
1266
+ if "text" in input_modality:
1267
+ assert (
1268
+ input_ids is not None and attention_mask is not None
1269
+ ), "input_ids and attention_mask must be provided if text modality is used"
1270
+ # Forward the text encoder
1271
+ input_ids, attention_mask = input_ids.to(self.device), attention_mask.to(self.device)
1272
+ text_encoder_outputs = self.context_text_encoder(input_ids, attention_mask=attention_mask)
1273
+ text_embeddings = text_encoder_outputs[0]
1274
+ text_embeddings = self.context_text_encoder_linear(text_embeddings)
1275
+
1276
+ mask = torch.tensor(self.mask(input_ids, skiplist=self.skiplist), device=self.device).unsqueeze(2).float()
1277
+ text_embeddings = text_embeddings * mask
1278
+
1279
+ if "image" in input_modality:
1280
+ if pixel_values is not None:
1281
+ # Forward the vision encoder
1282
+ pixel_values = pixel_values.to(self.device)
1283
+ vision_encoder_outputs = self.context_vision_encoder(pixel_values)
1284
+ vision_embeddings = vision_encoder_outputs.last_hidden_state[:, 0]
1285
+
1286
+ if image_features is not None:
1287
+ vision_embeddings = image_features.to(self.device)
1288
+
1289
+ batch_size = vision_embeddings.shape[0]
1290
+
1291
+ # Forward the vision projection / mapping network
1292
+ vision_embeddings = self.context_vision_projection(vision_embeddings)
1293
+ vision_embeddings = vision_embeddings.view(
1294
+ -1, self.mapping_network_prefix_length, self.late_interaction_embedding_size
1295
+ )
1296
+
1297
+ image_mask = torch.ones(batch_size, vision_embeddings.shape[1], 1).to(self.device)
1298
+
1299
+ if concat_output_from_vision_encoder and concat_output_from_text_encoder:
1300
+ # Note: vision embeddings must be in the front since the ColBERT engine only indexes embeddings up to number of 1's in the mask
1301
+ # TODO: fix the engine to support masks with discontinuous 0 and 1.
1302
+ D = torch.cat([vision_embeddings, text_embeddings], dim=1)
1303
+ # concatenate the mask
1304
+ mask = torch.cat([mask, image_mask], dim=1)
1305
+ elif concat_output_from_vision_encoder:
1306
+ D = vision_embeddings
1307
+ mask = image_mask
1308
+ elif concat_output_from_text_encoder:
1309
+ D = text_embeddings
1310
+ mask = mask
1311
+
1312
+ D = torch.nn.functional.normalize(D, p=2, dim=2)
1313
+
1314
+ if self.use_gpu:
1315
+ D = D.half()
1316
+
1317
+ if keep_dims is False:
1318
+ D, mask = D.cpu(), mask.bool().cpu().squeeze(-1)
1319
+ D = [d[mask[idx]] for idx, d in enumerate(D)]
1320
+
1321
+ vision_encoder_attentions = (
1322
+ vision_encoder_outputs.attentions
1323
+ if vision_encoder_outputs is not None
1324
+ and hasattr(vision_encoder_outputs, "attentions")
1325
+ and output_attentions
1326
+ else None
1327
+ )
1328
+ vision_encoder_hidden_states = (
1329
+ vision_encoder_outputs.hidden_states
1330
+ if vision_encoder_outputs is not None
1331
+ and hasattr(vision_encoder_outputs, "hidden_states")
1332
+ and output_hidden_states
1333
+ else None
1334
+ )
1335
+ text_encoder_attentions = (
1336
+ text_encoder_outputs.attentions
1337
+ if text_encoder_outputs is not None and hasattr(text_encoder_outputs, "attentions") and output_attentions
1338
+ else None
1339
+ )
1340
+ text_encoder_hidden_states = (
1341
+ text_encoder_outputs.hidden_states
1342
+ if text_encoder_outputs is not None
1343
+ and hasattr(text_encoder_outputs, "hidden_states")
1344
+ and output_hidden_states
1345
+ else None
1346
+ )
1347
+
1348
+ return FLMRContextEncoderOutput(
1349
+ pooler_output=D[:, 0, :],
1350
+ late_interaction_output=D,
1351
+ context_mask=mask.bool() if return_mask else None,
1352
+ vision_encoder_attentions=vision_encoder_attentions,
1353
+ vision_encoder_hidden_states=vision_encoder_hidden_states,
1354
+ text_encoder_attentions=text_encoder_attentions,
1355
+ text_encoder_hidden_states=text_encoder_hidden_states,
1356
+ )
1357
+
1358
+ def score(self, Q, D_padded, D_mask):
1359
+ # assert self.colbert_config.similarity == 'cosine'
1360
+ # if self.colbert_config.similarity == 'l2':
1361
+ # assert self.colbert_config.interaction == 'colbert'
1362
+ # return (-1.0 * ((Q.unsqueeze(2) - D_padded.unsqueeze(1))**2).sum(-1)).max(-1).values.sum(-1)
1363
+ return colbert_score(Q, D_padded, D_mask, use_gpu=self.use_gpu)
1364
+
1365
+ def mask(self, input_ids, skiplist):
1366
+ mask = [[(x not in skiplist) and (x != 0) for x in d] for d in input_ids.cpu().tolist()]
1367
+ return mask
1368
+
1369
+
1370
+ @add_start_docstrings(
1371
+ "The bare FLMR text encoder that can be used to generate late-interaction embeddings for texts in queries and contexts. This model is based on a `BertModel`. It can be used like a `BertModel` model for encoding text.",
1372
+ FLMR_TEXT_ENCODERS_START_DOCSTRING,
1373
+ )
1374
+ class FLMRTextModel(FLMRPreTrainedModel):
1375
+ base_model_prefix = "bert_model"
1376
+ config_class = FLMRTextConfig
1377
+
1378
+ def __init__(self, config: FLMRTextConfig, *args, **kwargs):
1379
+ super().__init__(config)
1380
+ self.bert_model = BertModel(config, add_pooling_layer=True)
1381
+ if self.bert_model.config.hidden_size <= 0:
1382
+ raise ValueError("Encoder hidden_size can't be zero")
1383
+ self.projection_dim = config.projection_dim
1384
+ if self.projection_dim > 0:
1385
+ self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
1386
+ # Initialize weights and apply final processing
1387
+ self.post_init()
1388
+
1389
+ @add_start_docstrings_to_model_forward(FLMR_TEXT_ENCODERS_INPUTS_DOCSTRING)
1390
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRTextConfig)
1391
+ def forward(
1392
+ self,
1393
+ input_ids: Optional[Tensor] = None,
1394
+ attention_mask: Optional[Tensor] = None,
1395
+ token_type_ids: Optional[Tensor] = None,
1396
+ inputs_embeds: Optional[Tensor] = None,
1397
+ output_attentions: bool = None,
1398
+ output_hidden_states: bool = None,
1399
+ return_dict: bool = None,
1400
+ ) -> Union[BaseModelOutputWithPooling, Tuple[Tensor, ...]]:
1401
+ r"""
1402
+ Returns:
1403
+
1404
+ """
1405
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1406
+ output_hidden_states = (
1407
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1408
+ )
1409
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1410
+
1411
+ outputs = self.bert_model(
1412
+ input_ids=input_ids,
1413
+ attention_mask=attention_mask,
1414
+ token_type_ids=token_type_ids,
1415
+ inputs_embeds=inputs_embeds,
1416
+ output_attentions=output_attentions,
1417
+ output_hidden_states=output_hidden_states,
1418
+ return_dict=return_dict,
1419
+ )
1420
+ sequence_output = outputs[0]
1421
+ pooled_output = sequence_output[:, 0, :]
1422
+
1423
+ if self.projection_dim > 0:
1424
+ pooled_output = self.encode_proj(pooled_output)
1425
+
1426
+ if not return_dict:
1427
+ return (sequence_output, pooled_output) + outputs[2:]
1428
+
1429
+ return BaseModelOutputWithPooling(
1430
+ last_hidden_state=sequence_output,
1431
+ pooler_output=pooled_output,
1432
+ hidden_states=outputs.hidden_states,
1433
+ attentions=outputs.attentions,
1434
+ )
1435
+
1436
+ @property
1437
+ def embeddings_size(self) -> int:
1438
+ if self.projection_dim > 0:
1439
+ return self.encode_proj.out_features
1440
+ return self.bert_model.config.hidden_size
1441
+
1442
+
1443
+ @add_start_docstrings(
1444
+ "The bare FLMR vision encoder that can be used to generate late-interaction embeddings for images in queries and contexts. This model is based on a `CLIPVisionModel`. It can be used like a `CLIPVisionModel` model for encoding images.",
1445
+ FLMR_VISION_ENCODERS_START_DOCSTRING,
1446
+ )
1447
+ class FLMRVisionModel(FLMRPreTrainedModel):
1448
+ base_model_prefix = "vision_model"
1449
+ config_class = FLMRVisionConfig
1450
+ main_input_name = "pixel_values"
1451
+ _no_split_modules = ["CLIPEncoderLayer"]
1452
+
1453
+ def __init__(self, config: FLMRVisionConfig):
1454
+ super().__init__(config)
1455
+ self.vision_model = CLIPVisionModel(config)
1456
+ self.post_init()
1457
+
1458
+ def get_input_embeddings(self) -> nn.Module:
1459
+ return self.vision_model.vision_model.embeddings.patch_embedding
1460
+
1461
+ @add_start_docstrings_to_model_forward(FLMR_VISION_ENCODERS_INPUTS_DOCSTRING)
1462
+ @replace_return_docstrings(output_type=BaseModelOutputWithPooling, config_class=FLMRVisionConfig)
1463
+ def forward(
1464
+ self,
1465
+ pixel_values: Optional[torch.FloatTensor] = None,
1466
+ output_attentions: Optional[bool] = None,
1467
+ output_hidden_states: Optional[bool] = None,
1468
+ return_dict: Optional[bool] = None,
1469
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
1470
+ r"""
1471
+ Returns:
1472
+
1473
+ Examples:
1474
+
1475
+ ```python
1476
+ >>> from PIL import Image
1477
+ >>> import requests
1478
+ >>> from transformers import AutoProcessor, FLMRVisionModel
1479
+
1480
+ >>> model = FLMRVisionModel.from_pretrained("openai/clip-vit-base-patch32")
1481
+ >>> processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")
1482
+
1483
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1484
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1485
+
1486
+ >>> inputs = processor(images=image, return_tensors="pt")
1487
+
1488
+ >>> outputs = model(**inputs)
1489
+ >>> last_hidden_state = outputs.last_hidden_state
1490
+ >>> pooled_output = outputs.pooler_output # pooled CLS states
1491
+ ```"""
1492
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1493
+
1494
+ return self.vision_model(
1495
+ pixel_values=pixel_values,
1496
+ output_attentions=output_attentions,
1497
+ output_hidden_states=output_hidden_states,
1498
+ return_dict=return_dict,
1499
+ )
query_tokenizer/special_tokens_map.json ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cls_token": {
3
+ "content": "[CLS]",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "mask_token": {
10
+ "content": "[MASK]",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "pad_token": {
17
+ "content": "[PAD]",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "sep_token": {
24
+ "content": "[SEP]",
25
+ "lstrip": false,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "unk_token": {
31
+ "content": "[UNK]",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ }
37
+ }
query_tokenizer/tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
query_tokenizer/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "[PAD]",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "100": {
12
+ "content": "[UNK]",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "101": {
20
+ "content": "[CLS]",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "102": {
28
+ "content": "[SEP]",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "103": {
36
+ "content": "[MASK]",
37
+ "lstrip": false,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "attend_to_mask_tokens": false,
45
+ "auto_map": {
46
+ "AutoTokenizer": [
47
+ "tokenization_flmr.FLMRQueryEncoderTokenizer",
48
+ null
49
+ ]
50
+ },
51
+ "clean_up_tokenization_spaces": true,
52
+ "cls_token": "[CLS]",
53
+ "do_basic_tokenize": true,
54
+ "do_lower_case": true,
55
+ "mask_token": "[MASK]",
56
+ "model_max_length": 512,
57
+ "never_split": null,
58
+ "pad_token": "[PAD]",
59
+ "query_maxlen": 32,
60
+ "sep_token": "[SEP]",
61
+ "strip_accents": null,
62
+ "tokenize_chinese_chars": true,
63
+ "tokenizer_class": "FLMRQueryEncoderTokenizer",
64
+ "unk_token": "[UNK]"
65
+ }
query_tokenizer/vocab.txt ADDED
The diff for this file is too large to render. See raw diff
 
segmented_maxsim.cpp ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #include <pthread.h>
2
+ #include <torch/extension.h>
3
+
4
+ #include <algorithm>
5
+ #include <numeric>
6
+
7
+ typedef struct {
8
+ int tid;
9
+ int nthreads;
10
+
11
+ int ndocs;
12
+ int ndoc_vectors;
13
+ int nquery_vectors;
14
+
15
+ int64_t* lengths;
16
+ float* scores;
17
+ int64_t* offsets;
18
+
19
+ float* max_scores;
20
+ } max_args_t;
21
+
22
+ void* max(void* args) {
23
+ max_args_t* max_args = (max_args_t*)args;
24
+
25
+ int ndocs_per_thread =
26
+ std::ceil(((float)max_args->ndocs) / max_args->nthreads);
27
+ int start = max_args->tid * ndocs_per_thread;
28
+ int end = std::min((max_args->tid + 1) * ndocs_per_thread, max_args->ndocs);
29
+
30
+ auto max_scores_offset =
31
+ max_args->max_scores + (start * max_args->nquery_vectors);
32
+ auto scores_offset =
33
+ max_args->scores + (max_args->offsets[start] * max_args->nquery_vectors);
34
+
35
+ for (int i = start; i < end; i++) {
36
+ for (int j = 0; j < max_args->lengths[i]; j++) {
37
+ std::transform(max_scores_offset,
38
+ max_scores_offset + max_args->nquery_vectors,
39
+ scores_offset, max_scores_offset,
40
+ [](float a, float b) { return std::max(a, b); });
41
+ scores_offset += max_args->nquery_vectors;
42
+ }
43
+ max_scores_offset += max_args->nquery_vectors;
44
+ }
45
+
46
+ return NULL;
47
+ }
48
+
49
+ torch::Tensor segmented_maxsim(const torch::Tensor scores,
50
+ const torch::Tensor lengths) {
51
+ auto lengths_a = lengths.data_ptr<int64_t>();
52
+ auto scores_a = scores.data_ptr<float>();
53
+ auto ndocs = lengths.size(0);
54
+ auto ndoc_vectors = scores.size(0);
55
+ auto nquery_vectors = scores.size(1);
56
+ auto nthreads = at::get_num_threads();
57
+
58
+ torch::Tensor max_scores =
59
+ torch::zeros({ndocs, nquery_vectors}, scores.options());
60
+
61
+ int64_t offsets[ndocs + 1];
62
+ offsets[0] = 0;
63
+ std::partial_sum(lengths_a, lengths_a + ndocs, offsets + 1);
64
+
65
+ pthread_t threads[nthreads];
66
+ max_args_t args[nthreads];
67
+
68
+ for (int i = 0; i < nthreads; i++) {
69
+ args[i].tid = i;
70
+ args[i].nthreads = nthreads;
71
+
72
+ args[i].ndocs = ndocs;
73
+ args[i].ndoc_vectors = ndoc_vectors;
74
+ args[i].nquery_vectors = nquery_vectors;
75
+
76
+ args[i].lengths = lengths_a;
77
+ args[i].scores = scores_a;
78
+ args[i].offsets = offsets;
79
+
80
+ args[i].max_scores = max_scores.data_ptr<float>();
81
+
82
+ int rc = pthread_create(&threads[i], NULL, max, (void*)&args[i]);
83
+ if (rc) {
84
+ fprintf(stderr, "Unable to create thread %d: %d\n", i, rc);
85
+ }
86
+ }
87
+
88
+ for (int i = 0; i < nthreads; i++) {
89
+ pthread_join(threads[i], NULL);
90
+ }
91
+
92
+ return max_scores.sum(1);
93
+ }
94
+
95
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
96
+ m.def("segmented_maxsim_cpp", &segmented_maxsim, "Segmented MaxSim");
97
+ }
tokenization_flmr.py ADDED
@@ -0,0 +1,239 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from typing import List, Optional, Union
19
+
20
+ from transformers.utils import TensorType, logging
21
+ from transformers.models.bert.tokenization_bert import BertTokenizer
22
+
23
+
24
+ logger = logging.get_logger(__name__)
25
+
26
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
27
+
28
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
29
+ "vocab_file": {
30
+ "LinWeizheDragon/PreFLMR_ViT-L": (
31
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
32
+ ),
33
+ "LinWeizheDragon/FLMR": (
34
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
35
+ ),
36
+ },
37
+ "tokenizer_file": {
38
+ "LinWeizheDragon/PreFLMR_ViT-L": (
39
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
40
+ ),
41
+ "LinWeizheDragon/FLMR": (
42
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
43
+ ),
44
+ },
45
+ }
46
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
47
+ "vocab_file": {
48
+ "LinWeizheDragon/PreFLMR_ViT-L": (
49
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
50
+ ),
51
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
52
+ },
53
+ "tokenizer_file": {
54
+ "LinWeizheDragon/PreFLMR_ViT-L": (
55
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
56
+ ),
57
+ "LinWeizheDragon/FLMR": (
58
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
59
+ ),
60
+ },
61
+ }
62
+
63
+
64
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
65
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
66
+ "LinWeizheDragon/FLMR": 512,
67
+ }
68
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
69
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
70
+ "LinWeizheDragon/FLMR": 512,
71
+ }
72
+
73
+
74
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
75
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
76
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
77
+ }
78
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
79
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
80
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
81
+ }
82
+
83
+
84
+ # Modified from colbert.modeling.tokenization
85
+ class FLMRContextEncoderTokenizer(BertTokenizer):
86
+ r"""
87
+ Construct a FLMRContextEncoder tokenizer.
88
+
89
+ [`FLMRContextEncoderTokenizer`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
90
+ splitting and wordpiece.
91
+
92
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
93
+ """
94
+
95
+ vocab_files_names = VOCAB_FILES_NAMES
96
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
97
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
98
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
99
+
100
+ def __init__(
101
+ self,
102
+ doc_maxlen: Optional[int] = 512,
103
+ **kwargs,
104
+ ):
105
+ super().__init__(
106
+ doc_maxlen=doc_maxlen,
107
+ **kwargs,
108
+ )
109
+
110
+ self.doc_maxlen = doc_maxlen
111
+ self.D_marker_token, self.D_marker_token_id = "[D]", self.convert_tokens_to_ids("[unused1]")
112
+
113
+ def __call__(
114
+ self,
115
+ text: List[str],
116
+ padding: Optional[Union[str, bool]] = "max_length",
117
+ truncation: Optional[Union[bool, str]] = "longest_first",
118
+ max_length: Optional[int] = 512,
119
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
120
+ **kwargs,
121
+ ):
122
+ # add placehold for the [D] marker
123
+ text = [". " + x for x in text]
124
+
125
+ if max_length > self.doc_maxlen:
126
+ # can not exceed the pre-set length
127
+ max_length = self.doc_maxlen
128
+
129
+ encoding = super().__call__(
130
+ text,
131
+ padding=padding,
132
+ truncation=truncation,
133
+ return_tensors=return_tensors,
134
+ max_length=max_length,
135
+ **kwargs,
136
+ )
137
+
138
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
139
+
140
+ # postprocess for the [D] marker
141
+ ids[:, 1] = self.D_marker_token_id
142
+
143
+ # if bsize:
144
+ # # This bsize function is used in the original ColBERT codebase to split inputs into multiple batches
145
+ # if image_features is not None:
146
+ # ids, mask, image_features, reverse_indices = _sort_by_length(ids, mask, bsize, image_features=image_features)
147
+ # batches = _split_into_batches(ids, mask, bsize, image_features=image_features)
148
+ # else:
149
+ # ids, mask, reverse_indices = _sort_by_length(ids, mask, bsize)
150
+ # batches = _split_into_batches(ids, mask, bsize)
151
+
152
+ # return batches, reverse_indices
153
+
154
+ encoding["input_ids"] = ids
155
+ encoding["attention_mask"] = mask
156
+
157
+ return encoding
158
+
159
+
160
+ # Modified from colbert.modeling.tokenization
161
+ class FLMRQueryEncoderTokenizer(BertTokenizer):
162
+ r"""
163
+ Constructs a FLMRQueryEncoder tokenizer.
164
+
165
+ [`FLMRQueryEncoder`] is identical to [`BertTokenizer`] and runs end-to-end tokenization: punctuation
166
+ splitting and wordpiece.
167
+
168
+ Refer to superclass [`BertTokenizer`] for usage examples and documentation concerning parameters.
169
+ """
170
+
171
+ vocab_files_names = VOCAB_FILES_NAMES
172
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
173
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
174
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
175
+
176
+ def __init__(
177
+ self,
178
+ *args,
179
+ query_maxlen: Optional[int] = 32,
180
+ attend_to_mask_tokens: Optional[bool] = False,
181
+ **kwargs,
182
+ ):
183
+ super().__init__(
184
+ *args,
185
+ query_maxlen=query_maxlen,
186
+ attend_to_mask_tokens=attend_to_mask_tokens,
187
+ **kwargs,
188
+ )
189
+
190
+ self.query_maxlen = query_maxlen
191
+ self.background_maxlen = 512 - self.query_maxlen + 1 # FIXME: Make this configurable
192
+ self.attend_to_mask_tokens = attend_to_mask_tokens
193
+
194
+ self.Q_marker_token, self.Q_marker_token_id = "[Q]", self.convert_tokens_to_ids("[unused0]")
195
+
196
+ def __call__(
197
+ self,
198
+ text: Union[str, List[str]],
199
+ padding: Optional[Union[str, bool]] = "max_length",
200
+ truncation: Optional[Union[bool, str]] = True,
201
+ max_length: Optional[int] = None,
202
+ return_tensors: Optional[Union[str, TensorType]] = "pt",
203
+ **kwargs,
204
+ ):
205
+ if isinstance(text, str):
206
+ # convert to list if input is a single string
207
+ text = [text]
208
+
209
+ # add placehold for the [Q] marker
210
+ text = [". " + x for x in text]
211
+
212
+ if max_length is not None:
213
+ # use user specified max_length
214
+ pass
215
+ else:
216
+ # use default max length
217
+ max_length = self.query_maxlen
218
+
219
+ encoding = super().__call__(
220
+ text,
221
+ padding=padding,
222
+ truncation=truncation,
223
+ return_tensors=return_tensors,
224
+ max_length=max_length,
225
+ **kwargs,
226
+ )
227
+
228
+ ids, mask = encoding["input_ids"], encoding["attention_mask"]
229
+
230
+ # postprocess for the [Q] marker and the [MASK] augmentation
231
+ ids[:, 1] = self.Q_marker_token_id
232
+ ids[ids == self.pad_token_id] = self.mask_token_id
233
+
234
+ if self.attend_to_mask_tokens:
235
+ # When attend_to_mask_tokens is True, we want to attend to the [MASK] tokens
236
+ mask[ids == self.mask_token_id] = 1
237
+ assert mask.sum().item() == mask.size(0) * mask.size(1), mask
238
+
239
+ return {"input_ids": ids, "attention_mask": mask}
tokenization_flmr_fast.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The HuggingFace Inc. team, The Hugging Face Team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Tokenization classes for FLMR."""
16
+
17
+
18
+ from transformers.utils import logging
19
+ from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
20
+ from .tokenization_flmr import FLMRContextEncoderTokenizer, FLMRQueryEncoderTokenizer
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+ VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer_config.json"}
26
+
27
+ CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
28
+ "vocab_file": {
29
+ "LinWeizheDragon/PreFLMR_ViT-L": (
30
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/vocab.txt"
31
+ ),
32
+ "LinWeizheDragon/FLMR": (
33
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/vocab.txt"
34
+ ),
35
+ },
36
+ "tokenizer_file": {
37
+ "LinWeizheDragon/PreFLMR_ViT-L": (
38
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/context_tokenizer/tokenizer_config.json"
39
+ ),
40
+ "LinWeizheDragon/FLMR": (
41
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/context_tokenizer/tokenizer_config.json"
42
+ ),
43
+ },
44
+ }
45
+ QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP = {
46
+ "vocab_file": {
47
+ "LinWeizheDragon/PreFLMR_ViT-L": (
48
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/vocab.txt"
49
+ ),
50
+ "LinWeizheDragon/FLMR": ("https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/vocab.txt"),
51
+ },
52
+ "tokenizer_file": {
53
+ "LinWeizheDragon/PreFLMR_ViT-L": (
54
+ "https://huggingface.co/LinWeizheDragon/PreFLMR_ViT-L/resolve/main/query_tokenizer/tokenizer_config.json"
55
+ ),
56
+ "LinWeizheDragon/FLMR": (
57
+ "https://huggingface.co/LinWeizheDragon/FLMR/resolve/main/query_tokenizer/tokenizer_config.json"
58
+ ),
59
+ },
60
+ }
61
+
62
+
63
+ CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
64
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
65
+ "LinWeizheDragon/FLMR": 512,
66
+ }
67
+ QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
68
+ "LinWeizheDragon/PreFLMR_ViT-L": 512,
69
+ "LinWeizheDragon/FLMR": 512,
70
+ }
71
+
72
+
73
+ CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
74
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
75
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
76
+ }
77
+ QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION = {
78
+ "LinWeizheDragon/PreFLMR_ViT-L": {"do_lower_case": True},
79
+ "LinWeizheDragon/FLMR": {"do_lower_case": True},
80
+ }
81
+
82
+
83
+ class FLMRContextEncoderTokenizerFast(BertTokenizerFast):
84
+ r"""
85
+ Construct a "fast" FLMRContextEncoder tokenizer (backed by HuggingFace's *tokenizers* library).
86
+
87
+ [`FLMRContextEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
88
+ punctuation splitting and wordpiece.
89
+
90
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
91
+ """
92
+
93
+ vocab_files_names = VOCAB_FILES_NAMES
94
+ pretrained_vocab_files_map = CONTEXT_ENCODER_PRETRAINED_VOCAB_FILES_MAP
95
+ max_model_input_sizes = CONTEXT_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
96
+ pretrained_init_configuration = CONTEXT_ENCODER_PRETRAINED_INIT_CONFIGURATION
97
+ slow_tokenizer_class = FLMRContextEncoderTokenizer
98
+
99
+
100
+ class FLMRQueryEncoderTokenizerFast(BertTokenizerFast):
101
+ r"""
102
+ Constructs a "fast" FLMRQueryEncoderTokenizer tokenizer (backed by HuggingFace's *tokenizers* library).
103
+
104
+ [`FLMRQueryEncoderTokenizerFast`] is identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
105
+ punctuation splitting and wordpiece.
106
+
107
+ Refer to superclass [`BertTokenizerFast`] for usage examples and documentation concerning parameters.
108
+ """
109
+
110
+ vocab_files_names = VOCAB_FILES_NAMES
111
+ pretrained_vocab_files_map = QUESTION_ENCODER_PRETRAINED_VOCAB_FILES_MAP
112
+ max_model_input_sizes = QUESTION_ENCODER_PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
113
+ pretrained_init_configuration = QUESTION_ENCODER_PRETRAINED_INIT_CONFIGURATION
114
+ slow_tokenizer_class = FLMRQueryEncoderTokenizer