from io import BytesIO
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Load the multimodal processor (tokenizer + image processor).
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

        # The KV cache is only useful for generation; disable it for embedding.
        config = self.auto_model.config
        if hasattr(config, "use_cache"):
            config.use_cache = False

        # Right-pad so the last non-padding token is the final content token,
        # which is what last-token pooling reads out.
        padding_side = "right"
        self.processor.tokenizer.padding_side = padding_side
        config.padding_side = padding_side
        self.auto_model.padding_side = padding_side

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )

        # Pool the final hidden states into one embedding per input.
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply last-token pooling and L2 normalization."""
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    def tokenize(self, texts: List[Dict] | List[str]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            # Plain strings are text-only inputs with no attached images.
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    # Insert the image placeholder token and instruction prompt,
                    # then load the image itself.
                    text += "<|image|><|begin_of_text|> Represent the given image"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )

        return inputs
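

# Usage sketch (illustrative, not part of the class above). The checkpoint name
# and image path are placeholders, and device/dtype handling is omitted; it
# exercises the class directly through tokenize() and forward() rather than the
# full SentenceTransformer.encode pipeline.
if __name__ == "__main__":
    model = MultiModalTransformer("path/to/mllama-checkpoint")  # placeholder path
    batch = model.tokenize(
        [
            "a plain text query",
            [
                {"type": "image_path", "content": "photo.jpg"},  # hypothetical image
                {"type": "text", "content": " a photo of a cat"},
            ],
        ]
    )
    with torch.no_grad():
        embeddings = model.forward(batch)["sentence_embedding"]
    print(embeddings.shape)  # (2, hidden_size), each row L2-normalized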