Files changed (2)
  1. README.md +46 -39
  2. custom_st.py +110 -0
README.md CHANGED
@@ -44,57 +44,64 @@ pip install -r requirements.txt
 
 Then you can enter the directory to run the following command.
 ```python
-from src.model import MMEBModel
-from src.arguments import ModelArguments
-from src.utils import load_processor
+from transformers import MllamaForConditionalGeneration, AutoProcessor
 import torch
-from transformers import HfArgumentParser, AutoProcessor
 from PIL import Image
-import numpy as np
-model_args = ModelArguments(
-    model_name='intfloat/mmE5-mllama-11b-instruct',
-    pooling='last',
-    normalize=True,
-    model_backbone='mllama')
-processor = load_processor(model_args)
-model = MMEBModel.load(model_args)
+
+# Pooling and Normalization
+def last_pooling(last_hidden_state, attention_mask, normalize=True):
+    sequence_lengths = attention_mask.sum(dim=1) - 1
+    batch_size = last_hidden_state.shape[0]
+    reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
+    if normalize:
+        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
+    return reps
+
+def compute_similarity(q_reps, p_reps):
+    return torch.matmul(q_reps, p_reps.transpose(0, 1))
+
+model_name = "intfloat/mmE5-mllama-11b-instruct"
+
+# Load Processor and Model
+processor = AutoProcessor.from_pretrained(model_name)
+model = MllamaForConditionalGeneration.from_pretrained(
+    model_name, torch_dtype=torch.bfloat16
+).to("cuda")
 model.eval()
-model = model.to('cuda', dtype=torch.bfloat16)
+
 # Image + Text -> Text
 inputs = processor(text='<|image|><|begin_of_text|> Represent the given image with the following question: What is in the image', images=[Image.open(
-    'figures/example.jpg')], return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-qry_output = model(qry=inputs)["qry_reps"]
+    'figures/example.jpg')], return_tensors="pt").to("cuda")
+qry_output = last_pooling(model(**inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], inputs['attention_mask'])
+
 string = 'A cat and a dog'
-inputs = processor(text=string, return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-tgt_output = model(tgt=inputs)["tgt_reps"]
-print(string, '=', model.compute_similarity(qry_output, tgt_output))
+text_inputs = processor(text=string, return_tensors="pt").to("cuda")
+tgt_output = last_pooling(model(**text_inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], text_inputs['attention_mask'])
+print(string, '=', compute_similarity(qry_output, tgt_output))
 ## A cat and a dog = tensor([[0.3965]], device='cuda:0', dtype=torch.bfloat16)
+
 string = 'A cat and a tiger'
-inputs = processor(text=string, return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-tgt_output = model(tgt=inputs)["tgt_reps"]
-print(string, '=', model.compute_similarity(qry_output, tgt_output))
+text_inputs = processor(text=string, return_tensors="pt").to("cuda")
+tgt_output = last_pooling(model(**text_inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], text_inputs['attention_mask'])
+print(string, '=', compute_similarity(qry_output, tgt_output))
 ## A cat and a tiger = tensor([[0.3105]], device='cuda:0', dtype=torch.bfloat16)
+
 # Text -> Image
-inputs = processor(text='Find me an everyday image that matches the given caption: A cat and a dog.', return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-qry_output = model(qry=inputs)["qry_reps"]
+inputs = processor(text='Find me an everyday image that matches the given caption: A cat and a dog.', return_tensors="pt").to("cuda")
+qry_output = last_pooling(model(**inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], inputs['attention_mask'])
+
 string = '<|image|><|begin_of_text|> Represent the given image.'
-inputs = processor(text=string, images=[Image.open('figures/example.jpg')], return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-tgt_output = model(tgt=inputs)["tgt_reps"]
-print(string, '=', model.compute_similarity(qry_output, tgt_output))
+tgt_inputs = processor(text=string, images=[Image.open('figures/example.jpg')], return_tensors="pt").to("cuda")
+tgt_output = last_pooling(model(**tgt_inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], tgt_inputs['attention_mask'])
+print(string, '=', compute_similarity(qry_output, tgt_output))
 ## <|image|><|begin_of_text|> Represent the given image. = tensor([[0.4219]], device='cuda:0', dtype=torch.bfloat16)
-inputs = processor(text='Find me an everyday image that matches the given caption: A cat and a tiger.', return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-qry_output = model(qry=inputs)["qry_reps"]
+
+inputs = processor(text='Find me an everyday image that matches the given caption: A cat and a tiger.', return_tensors="pt").to("cuda")
+qry_output = last_pooling(model(**inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], inputs['attention_mask'])
 string = '<|image|><|begin_of_text|> Represent the given image.'
-inputs = processor(text=string, images=[Image.open('figures/example.jpg')], return_tensors="pt")
-inputs = {key: value.to('cuda') for key, value in inputs.items()}
-tgt_output = model(tgt=inputs)["tgt_reps"]
-print(string, '=', model.compute_similarity(qry_output, tgt_output))
+tgt_inputs = processor(text=string, images=[Image.open('figures/example.jpg')], return_tensors="pt").to("cuda")
+tgt_output = last_pooling(model(**tgt_inputs, return_dict=True, output_hidden_states=True).hidden_states[-1], tgt_inputs['attention_mask'])
+print(string, '=', compute_similarity(qry_output, tgt_output))
 ## <|image|><|begin_of_text|> Represent the given image. = tensor([[0.3887]], device='cuda:0', dtype=torch.bfloat16)
 ```
 
@@ -106,4 +113,4 @@ print(string, '=', model.compute_similarity(qry_output, tgt_output))
   journal={arXiv preprint arXiv:2502.08468},
   year={2025}
 }
-```
+```
custom_st.py ADDED
@@ -0,0 +1,110 @@
+from io import BytesIO
+from typing import Any, Dict, Optional, List
+import torch
+from PIL import Image
+from transformers import AutoProcessor, MllamaForConditionalGeneration
+from sentence_transformers.models import Transformer as BaseTransformer
+
+
+class MultiModalTransformer(BaseTransformer):
+    def __init__(
+        self,
+        model_name_or_path: str,
+        cache_dir: Optional[str] = None,
+        tokenizer_args: Optional[Dict[str, Any]] = None,
+        **kwargs,
+    ):
+        super().__init__(model_name_or_path, **kwargs)
+        if tokenizer_args is None:
+            tokenizer_args = {}
+
+        # Initialize processor
+        self.processor = AutoProcessor.from_pretrained(
+            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
+        )
+
+    def _load_model(
+        self,
+        model_name_or_path: str,
+        config,
+        cache_dir: str,
+        backend: str,
+        is_peft_model: bool,
+        **model_args,
+    ) -> None:
+        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
+            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
+        )
+
+    def forward(
+        self, features: Dict[str, torch.Tensor], **kwargs
+    ) -> Dict[str, torch.Tensor]:
+        # Process inputs through the model
+        outputs = self.auto_model(
+            **features,
+            return_dict=True,
+            output_hidden_states=True,
+            **kwargs
+        )
+
+        # Apply last pooling and normalization
+        last_hidden_state = outputs.hidden_states[-1]
+        attention_mask = features["attention_mask"]
+        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
+
+        features.update({"sentence_embedding": sentence_embedding})
+        return features
+
+    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+        """Apply last token pooling and L2 normalization"""
+        sequence_lengths = attention_mask.sum(dim=1) - 1
+        batch_size = last_hidden_state.shape[0]
+        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
+        return torch.nn.functional.normalize(reps, p=2, dim=-1)
+
+    def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
+        def process_text_item(item):
+            if isinstance(item, str):
+                return item, []
+
+            text, images = "", []
+            for sub_item in item:
+                if sub_item["type"] == "text":
+                    text += sub_item["content"]
+                elif sub_item["type"] in ["image_bytes", "image_path"]:
+                    text += "<|image|>"
+                    if sub_item["type"] == "image_bytes":
+                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
+                    else:
+                        img = Image.open(sub_item["content"]).convert("RGB")
+                    images.append(img)
+                else:
+                    raise ValueError(f"Unknown data type {sub_item['type']}")
+            return text, images
+
+        all_texts, all_images = [], []
+        for item in texts:
+            text, images = process_text_item(item)
+            all_texts.append(text)
+            all_images.extend(images)
+
+        # Process inputs through the processor
+        if all_images:
+            inputs = self.processor(
+                text=all_texts,
+                images=all_images,
+                padding="longest",
+                truncation=True,
+                max_length=self.max_seq_length,
+                return_tensors="pt"
+            )
+        else:
+            inputs = self.processor(
+                text=all_texts,
+                padding="longest",
+                truncation=True,
+                max_length=self.max_seq_length,
+                return_tensors="pt"
+            )
+
+        return inputs
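custom_st.py wraps the checkpoint as a Sentence Transformers module. Below is a minimal usage sketch, not part of this change: it assumes the repository's Sentence Transformers configuration (e.g. modules.json) points at custom_st.MultiModalTransformer, and it reuses the model id, image path, and prompt text from the README example above.

```python
from sentence_transformers import SentenceTransformer

# Assumption: loading executes custom_st.py from the repo, so trust_remote_code is required.
model = SentenceTransformer("intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True)

# tokenize() in custom_st.py accepts plain strings or lists of {"type", "content"} items;
# an "image_path" / "image_bytes" item inserts the <|image|> placeholder automatically.
query = [[
    {"type": "image_path", "content": "figures/example.jpg"},
    {"type": "text", "content": " Represent the given image with the following question: What is in the image"},
]]
candidates = ["A cat and a dog", "A cat and a tiger"]

q_emb = model.encode(query, convert_to_tensor=True)       # embeddings are L2-normalized by the module
c_emb = model.encode(candidates, convert_to_tensor=True)
print(q_emb @ c_emb.T)  # dot product of normalized embeddings = cosine similarity
```

Because _last_pooling L2-normalizes the embeddings, the plain dot product here plays the same role as compute_similarity in the README snippet.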