LinWeizheDragon
committed on
Update modeling_flmr.py
Browse files
- modeling_flmr.py +16 -14
modeling_flmr.py
CHANGED
@@ -584,13 +584,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
|
|
584 |
self.text_encoder_embedding_size = self.config.text_config.hidden_size
|
585 |
self.late_interaction_embedding_size = self.config.dim
|
586 |
|
587 |
-
self.
|
588 |
-
(
|
589 |
-
|
590 |
-
|
591 |
-
|
|
|
|
|
592 |
)
|
593 |
-
)
|
594 |
|
595 |
if self.config.use_vision_encoder:
|
596 |
self.context_vision_encoder = FLMRVisionModel(config.vision_config)
|
@@ -636,13 +637,14 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
|
|
636 |
self.query_text_encoder_linear = self.context_text_encoder_linear
|
637 |
self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
|
638 |
|
639 |
-
if self.config.
|
640 |
-
self.
|
641 |
-
|
642 |
-
|
643 |
-
|
644 |
-
|
645 |
-
|
|
|
646 |
|
647 |
if self.config.load_cpu_extension:
|
648 |
try:
|
@@ -1304,7 +1306,7 @@ class FLMRModelForRetrieval(FLMRPretrainedModelForRetrieval):
|
|
1304 |
# TODO: fix the engine to support masks with discontinuous 0 and 1.
|
1305 |
D = torch.cat([vision_embeddings, text_embeddings], dim=1)
|
1306 |
# concatenate the mask
|
1307 |
-
mask = torch.cat([
|
1308 |
elif concat_output_from_vision_encoder:
|
1309 |
D = vision_embeddings
|
1310 |
mask = image_mask
|
|
|
584 |
self.text_encoder_embedding_size = self.config.text_config.hidden_size
|
585 |
self.late_interaction_embedding_size = self.config.dim
|
586 |
|
587 |
+
if self.config.use_vision_encoder:
|
588 |
+
self.context_vision_projection = FLMRMultiLayerPerceptron(
|
589 |
+
(
|
590 |
+
self.vision_encoder_embedding_size,
|
591 |
+
(self.late_interaction_embedding_size * self.mapping_network_prefix_length) // 2,
|
592 |
+
self.late_interaction_embedding_size * self.mapping_network_prefix_length,
|
593 |
+
)
|
594 |
)
|
|
|
595 |
|
596 |
if self.config.use_vision_encoder:
|
597 |
self.context_vision_encoder = FLMRVisionModel(config.vision_config)
|
|
|
637 |
self.query_text_encoder_linear = self.context_text_encoder_linear
|
638 |
self._tied_weights_keys += ["context_text_encoder", "context_text_encoder_linear"]
|
639 |
|
640 |
+
if self.config.use_vision_encoder:
|
641 |
+
if self.config.separate_query_and_context_vision_encoder:
|
642 |
+
self.query_vision_encoder = copy.deepcopy(self.context_vision_encoder)
|
643 |
+
self.query_vision_projection = copy.deepcopy(self.context_vision_projection)
|
644 |
+
else:
|
645 |
+
self.query_vision_encoder = self.context_vision_encoder
|
646 |
+
self.query_vision_projection = self.context_vision_projection
|
647 |
+
self._tied_weights_keys += ["context_vision_encoder", "context_vision_projection"]
|
648 |
|
649 |
if self.config.load_cpu_extension:
|
650 |
try:
|
|
|
1306 |
# TODO: fix the engine to support masks with discontinuous 0 and 1.
|
1307 |
D = torch.cat([vision_embeddings, text_embeddings], dim=1)
|
1308 |
# concatenate the mask
|
1309 |
+
mask = torch.cat([image_mask, mask], dim=1)
|
1310 |
elif concat_output_from_vision_encoder:
|
1311 |
D = vision_embeddings
|
1312 |
mask = image_mask
|