Samoed committed
Commit d4b7753 · verified · 1 Parent(s): d553fa4

Create custom_st.py

Based on the updated [demo](https://github.com/haon-chen/mmE5/blob/main/demo.py), I've created a custom `SentenceTransformer` module. However, I don't have the hardware to test it. I've modeled it on the [jasper implementation](https://huggingface.co/NovaSearch/jasper_en_vision_language_v1/blob/main/custom_st.py).

@tomaarsen, could you review it as well, please?

Files changed (1)
  1. custom_st.py +105 -0
custom_st.py ADDED
from io import BytesIO
from typing import Any, Dict, Optional, List

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize processor and set padding side
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

        # Configure model settings
        config = self.auto_model.config
        if hasattr(config, "use_cache"):
            config.use_cache = False

        padding_side = "right"
        self.processor.tokenizer.padding_side = padding_side
        config.padding_side = padding_side
        self.auto_model.padding_side = padding_side

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Process inputs through the model
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        # Apply last-token pooling and normalization
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply last token pooling and L2 normalization."""
        # Index of the last non-padding token per sequence (inputs are right-padded)
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are passed through unchanged, with no images
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    # Add the image placeholder tokens and instruction for this image
                    text += "<|image|><|begin_of_text|> Represent the given image"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Process inputs through the processor
        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )

        return inputs
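
As a rough illustration of how this module is meant to be called (untested, per the note above), here is a minimal sketch. It assumes the repo's `modules.json` registers `custom_st.MultiModalTransformer` and that the checkpoint id below is the right one; the prompt wording and file path are placeholders only.

```python
from sentence_transformers import SentenceTransformer

# trust_remote_code=True is required so Sentence Transformers imports custom_st.py from the Hub.
# The repo id is an assumption - adjust it to the actual mmE5 checkpoint.
model = SentenceTransformer("intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True)

# A plain string is passed through tokenize() unchanged (text-only input).
query_emb = model.encode(["Represent the given query: a cat sitting on a sofa"])

# An interleaved item is a list of {"type", "content"} parts, matching tokenize():
# "text" appends raw text, while "image_path"/"image_bytes" load an image and
# insert the <|image|><|begin_of_text|> placeholder.
doc = [
    {"type": "image_path", "content": "cat.jpg"},
    {"type": "text", "content": "A photo of a cat."},
]
doc_emb = model.encode([doc])

print(query_emb.shape, doc_emb.shape)  # embeddings come out L2-normalized
```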