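"""Custom sentence-transformers module that wraps Mllama for multimodal embeddings.

Pairs an ``AutoProcessor`` with ``MllamaForConditionalGeneration`` so that plain
strings or ``{"image": ..., "text": ...}`` dicts can be tokenized and encoded into
token embeddings for downstream pooling.
"""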
from io import BytesIO
from typing import Any, Dict, List, Optional, Union

import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration

from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}
        tokenizer_args.pop("trust_remote_code", None)

        # Initialize the multimodal processor (tokenizer + image processor)
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
        self,
        model_name_or_path: str,
        config,
        cache_dir: str,
        backend: str,
        is_peft_model: bool,
        **model_args,
    ) -> None:
        model_args.pop("trust_remote_code", None)
        # Load the Mllama vision-language model in bfloat16
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Run the model and keep all hidden states
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )
        # Expose the last hidden state as token embeddings for downstream pooling
        features.update({"token_embeddings": outputs.hidden_states[-1]})
        return features

    def tokenize(self, texts: List[Union[str, Dict]]) -> Dict[str, torch.Tensor]:
        def process_text_item(item):
            if isinstance(item, str):
                return item, None

            text, img = "", None
            if "image" in item:
                text += "<|image|>"
                img = item["image"]
                if isinstance(img, bytes):
                    img = Image.open(BytesIO(img)).convert("RGB")
                elif isinstance(img, str):
                    img = Image.open(img).convert("RGB")
                elif not isinstance(img, Image.Image):
                    raise ValueError(f"Unknown image type {type(img)}")
            if "text" in item:
                if text:
                    text += "<|begin_of_text|> "
                text += item["text"].lstrip()
            return text, img

        all_texts, all_images = [], []
        for item in texts:
            text, image = process_text_item(item)
            all_texts.append(text)
            all_images.append(image)

        # Only pass images to the processor if at least one item actually has one
        if any(img is not None for img in all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt",
            )
        return inputs
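

# Usage sketch (illustrative, not part of the module above): assumes this file is
# wired into a sentence-transformers checkpoint as a custom module (e.g. via
# modules.json); the model ID and image path below are placeholders.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    # Placeholder model ID; replace with a checkpoint that ships this module.
    model = SentenceTransformer("your-org/your-mllama-embedding-model", trust_remote_code=True)

    inputs = [
        "A photo of a cat",                                # plain text
        {"image": "cat.jpg", "text": "A photo of a cat"},  # image file path + text
    ]
    embeddings = model.encode(inputs)
    print(embeddings.shape)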