# custom_st.py: custom Sentence Transformers module for mmE5-mllama-11b-instruct
from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

        # Disable the KV cache (only needed for generation) and pad on the
        # right so that last-token pooling indexes the final real token
        config = self.auto_model.config
        if hasattr(config, "use_cache"):
            config.use_cache = False

        padding_side = "right"
        self.processor.tokenizer.padding_side = padding_side
        config.padding_side = padding_side
        self.auto_model.padding_side = padding_side

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the model, keeping all hidden states so the last layer can be pooled
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        # Pool the last hidden state and L2-normalize into a sentence embedding
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Apply last-token pooling and L2 normalization."""
        # With right padding, the last real token of each sequence sits at
        # index (number of non-padding tokens) - 1
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[
            torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths
        ]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)
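    # Worked example (illustrative, not from the original file): for
    # attention_mask = [[1, 1, 1, 0]], sequence_lengths = [2], so the pooled
    # vector is last_hidden_state[0, 2], the state of the last real token.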

    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are text-only inputs with no attached images
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    # Insert the image token plus an instruction before each image
                    text += "<|image|><|begin_of_text|> Represent the given image"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images
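        # Illustrative multimodal item accepted by tokenize() (an example
        # inferred from the parsing above, not taken from the original file):
        #   [{"type": "text", "content": "A photo of a cat"},
        #    {"type": "image_path", "content": "/path/to/cat.jpg"}]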
        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)
        # Batch-encode with the processor; pass `images` only when the batch
        # actually contains at least one image
        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
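
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption for illustration; not part of the
# original file). The repo's modules.json is expected to register this class,
# so loading with trust_remote_code should wire it in. The repo id below is a
# placeholder, and the input schema mirrors tokenize() above.
#
#   from sentence_transformers import SentenceTransformer
#
#   model = SentenceTransformer(
#       "<org>/mmE5-mllama-11b-instruct", trust_remote_code=True
#   )
#   text_emb = model.encode(["Represent this sentence for retrieval."])
#   multimodal_emb = model.encode([[
#       {"type": "image_path", "content": "cat.jpg"},
#       {"type": "text", "content": " A photo of a cat."},
#   ]])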