import base64
import logging
import math
from io import BytesIO
from typing import List, Optional, TypeAlias, Union

import requests
import torch
import torch.nn as nn
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForImageTextToText, AutoProcessor

ImageInput: TypeAlias = Union[Image.Image, List[Image.Image]]
BatchImageInput: TypeAlias = Union[List[Image.Image], List[List[Image.Image]]]


class OpsMMEmbeddingV1(nn.Module):
    def __init__(
        self,
        model_name: str,
        device: str = "cuda",
        max_length: Optional[int] = None,
        attn_implementation: Optional[str] = None,
    ):
        super().__init__()
        self.device = device
        self.max_length = max_length
        self.default_instruction = "You are a helpful assistant."
        self.base_model = AutoModelForImageTextToText.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            low_cpu_mem_usage=True,
            attn_implementation=attn_implementation,
        ).to(self.device)

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            min_pixels=256 * 28 * 28,
            max_pixels=1280 * 28 * 28,
        )
        # Left padding keeps the last position a real token, which _pooling relies on.
        self.processor.tokenizer.padding_side = "left"
        self.eval()

    def encode_input(self, inputs):
        outputs = self.base_model(**inputs, return_dict=True, output_hidden_states=True)
        last_hidden_state = outputs.hidden_states[-1]
        pooled_output = self._pooling(last_hidden_state)
        return pooled_output

    def _pooling(self, last_hidden_state):
        # Last-token pooling: with left padding, index -1 is always a real token.
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size), -1, :]
        reps = torch.nn.functional.normalize(reps, p=2, dim=-1)
        return reps

    def _validate_instructions(
        self,
        texts: Optional[List[str]],
        images: Optional[BatchImageInput],
        instruction: Optional[Union[str, List[str]]],
    ) -> List[str]:
        """Validate and format instructions to match the batch size."""
        batch_size = max(len(x) if x is not None else 0 for x in [texts, images])

        if instruction is None:
            return [self.default_instruction] * batch_size

        if isinstance(instruction, str):
            return [instruction] * batch_size

        if isinstance(instruction, list):
            if len(instruction) != batch_size:
                raise ValueError(
                    f"Length of instruction list ({len(instruction)}) must match "
                    f"batch size ({batch_size}) when texts/images are provided"
                )
            return instruction

        raise TypeError("instruction must be str, List[str] or None")

    def _process_images(self, images: ImageInput) -> List[Image.Image]:
        """Convert a single image or a list of images to the processed format."""
        if isinstance(images, (Image.Image, str)):
            return [fetch_image(images)]
        return [fetch_image(i) for i in images]

    def embed(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Generate embeddings for text, images, or combined inputs.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: single image per input
                - List[List[Image.Image]]: multiple images per input
            instruction: Instruction(s) for the model. Can be:
                - None: use the default instruction
                - str: use the same instruction for all inputs
                - List[str]: per-input instructions (must match the batch size)
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        instructions = self._validate_instructions(texts, images, instruction)
        batch_size = len(texts) if texts is not None else len(images)

        input_texts, input_images = [], []
        for i in range(batch_size):
            text = texts[i] if texts is not None else None
            image = images[i] if images is not None else None

            input_str = ""
            processed_image = None
            if image is not None:
                processed_image = self._process_images(image)
                # One vision placeholder per image; the processor expands each into patch tokens.
                input_str += "<|vision_start|><|image_pad|><|vision_end|>" * len(processed_image)

            if text is not None:
                input_str += text

            # ChatML-style prompt; the trailing <|endoftext|> is the token whose hidden state is pooled.
            msg = f"<|im_start|>system\n{instructions[i]}<|im_end|>\n<|im_start|>user\n{input_str}<|im_end|>\n<|im_start|>assistant\n<|endoftext|>"

            input_texts.append(msg)
            input_images.append(processed_image)

        processed_images = input_images if any(img is not None for img in input_images) else None

        inputs = self.processor(
            text=input_texts,
            images=processed_images,
            padding=True,
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.inference_mode():
            embeddings = self.encode_input(inputs)

        return embeddings

    def get_text_embeddings(
        self,
        texts: List[str],
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for text-only embeddings."""
        return self.get_fused_embeddings(texts=texts, instruction=instruction, **kwargs)

    def get_image_embeddings(
        self,
        images: BatchImageInput,
        instruction: Optional[Union[str, List[str]]] = None,
        **kwargs,
    ) -> torch.Tensor:
        """Convenience method for image-only embeddings.

        Args:
            images: Can be:
                - List[Image.Image]: single image per input
                - List[List[Image.Image]]: multiple images per input
        """
        return self.get_fused_embeddings(images=images, instruction=instruction, **kwargs)

    def get_fused_embeddings(
        self,
        texts: Optional[List[str]] = None,
        images: Optional[BatchImageInput] = None,
        instruction: Optional[Union[str, List[str]]] = None,
        batch_size: int = 8,
        show_progress: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Batch processing for large collections of texts/images.

        Args:
            texts: List of text inputs (optional)
            images: Can be:
                - List[Image.Image]: single image per input
                - List[List[Image.Image]]: multiple images per input
            instruction: Instruction(s) for the model
            batch_size: Number of items to process at once
            show_progress: Whether to display a progress bar
        """
        if texts is None and images is None:
            raise ValueError("Either texts or images must be provided")

        total_items = len(texts) if texts is not None else len(images)
        num_batches = math.ceil(total_items / batch_size)

        all_embeddings = []
        progress = tqdm(total=num_batches, disable=not show_progress, desc="Processing")

        for i in range(0, total_items, batch_size):
            batch_texts = texts[i : i + batch_size] if texts is not None else None
            batch_images = images[i : i + batch_size] if images is not None else None
            batch_emb = self.embed(texts=batch_texts, images=batch_images, instruction=instruction)

            # Keep per-batch results on CPU to bound GPU memory while accumulating.
            all_embeddings.append(batch_emb.cpu())
            progress.update(1)

        progress.close()
        return torch.cat(all_embeddings, dim=0).to(self.device)

    def forward(self, **inputs) -> torch.Tensor:
        """Alias for encode_input."""
        return self.encode_input(inputs)


IMAGE_FACTOR = 28
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28
MAX_RATIO = 200


def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    return round(number / factor) * factor


def ceil_by_factor(number: int | float, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int | float, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    return math.floor(number / factor) * factor


def smart_resize(
    height: int,
    width: int,
    factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    h_bar = max(factor, round_by_factor(height, factor))
    w_bar = max(factor, round_by_factor(width, factor))
    if h_bar * w_bar > max_pixels:
        # Too many pixels: scale both sides down by the same factor, then floor to the grid.
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:
        # Too few pixels: scale both sides up by the same factor, then ceil to the grid.
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)

    if max(h_bar, w_bar) / min(h_bar, w_bar) > MAX_RATIO:
        logging.warning(
            f"Absolute aspect ratio must be smaller than {MAX_RATIO}, "
            f"got {max(h_bar, w_bar) / min(h_bar, w_bar)}; clamping to the limit"
        )
        if h_bar > w_bar:
            h_bar = w_bar * MAX_RATIO
        else:
            w_bar = h_bar * MAX_RATIO
    return h_bar, w_bar
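

# A worked example of smart_resize (illustrative numbers, not executed at import time):
# for a 1000x700 input with factor=28, round_by_factor gives 1008x700 (both divisible by 28);
# 1008 * 700 = 705,600 pixels already lies inside [MIN_PIXELS, MAX_PIXELS] = [200,704, 1,003,520],
# so neither rescaling branch fires and (1008, 700) is returned.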


def fetch_image(
    image: str | Image.Image,
    size_factor: int = IMAGE_FACTOR,
    min_pixels: int = MIN_PIXELS,
    max_pixels: int = MAX_PIXELS,
) -> Image.Image:
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        if "base64," in image:
            _, base64_data = image.split("base64,", 1)
            data = base64.b64decode(base64_data)
            image_obj = Image.open(BytesIO(data))
    else:
        # Fall back to treating the string as a local file path.
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    width, height = image.size
    resized_height, resized_width = smart_resize(
        height,
        width,
        factor=size_factor,
        min_pixels=min_pixels,
        max_pixels=max_pixels,
    )
    image = image.resize((resized_width, resized_height))

    return image
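

# fetch_image accepts several source forms; the values below are illustrative, not real resources:
#   fetch_image(Image.open("photo.jpg"))                  # an already-loaded PIL image
#   fetch_image("https://example.com/photo.jpg")          # an HTTP(S) URL
#   fetch_image("file:///tmp/photo.jpg")                  # a local path given as a file:// URI
#   fetch_image("data:image/png;base64,iVBORw0KGgo...")   # a base64 data URI
#   fetch_image("/tmp/photo.jpg")                         # a bare local path (final fallback)
# Every branch ends with an RGB conversion and a smart_resize to the model's pixel budget.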
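

# Minimal usage sketch. The checkpoint name below is a placeholder, not a confirmed model id;
# substitute whatever OpsMMEmbeddingV1-compatible checkpoint you actually use.
if __name__ == "__main__":
    model = OpsMMEmbeddingV1("path/or/hub-id-of-ops-mm-embedding", device="cuda")

    # Text-only embeddings with one shared instruction (the instruction text is illustrative).
    text_emb = model.get_text_embeddings(
        ["a photo of a cat", "a photo of a dog"],
        instruction="Represent this query for retrieval.",
    )

    # Image-only embeddings, one PIL image per input.
    images = [Image.new("RGB", (224, 224), color=c) for c in ("red", "blue")]
    image_emb = model.get_image_embeddings(images)

    # Fused text+image embeddings; rows are L2-normalized, so a dot product is a cosine similarity.
    fused_emb = model.get_fused_embeddings(texts=["a red square", "a blue square"], images=images)
    print(text_emb.shape, image_emb.shape, fused_emb.shape)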