# multimodal_rag/colpali_manager.py
# Last commit 241c492 by ej68okap: "new code added"
# Importing required modules and libraries
from colpali_engine.models import ColPali # Main ColPali model for embeddings
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor # Preprocessing for ColPali
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor # Base processor utility
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device # Torch utilities for dataset and device management
from torch.utils.data import DataLoader # PyTorch DataLoader for batching
import torch # PyTorch library
from typing import List, cast # Type annotations
from tqdm import tqdm # Progress bar utility
from PIL import Image # Image processing library
import os # OS module for file path handling
import spaces # Custom decorator module for GPU management
# Checkpoint name for the single ColPali model shared by every ColpaliManager.
model_name = "vidore/colpali-v1.2"
# Target device for that model. NOTE(review): "cuda" is requested explicitly;
# whether get_torch_device falls back when CUDA is absent is up to colpali_engine.
device = get_torch_device("cuda")

# Load the weights once, at import time, in bfloat16 and placed on `device`.
model = ColPali.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map=device)
# Switch to inference mode; nn.Module.eval() works in place and returns the same module.
model = model.eval()

# Companion processor that converts images/queries into model inputs.
processor = ColPaliProcessor.from_pretrained(model_name)
# cast() only narrows the static type for type checkers; it does nothing at runtime.
processor = cast(ColPaliProcessor, processor)
class ColpaliManager:
    """
    Manage the embedding of images and text with the ColPali model.

    All instances share the module-level ``model`` and ``processor``, which are
    loaded once at import time. The constructor arguments are reported but do
    not select a different checkpoint or device.
    """

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        """
        Args:
            device (str): Requested device. Printed only; inference runs on the
                device of the module-level model.
            model_name (str): Requested checkpoint. Printed only; the module-level
                model loaded at import time is always used.
        """
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")
        # NOTE: the model is intentionally not loaded per instance; the
        # module-level `model` / `processor` are reused by every method.

    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """
        Load images from the given file paths.

        Each image is decoded eagerly and its file handle is closed before the
        next one is opened, so loading many images does not keep many files open.
        This is plain file I/O and is therefore not wrapped in ``@spaces.GPU``.

        Args:
            paths (list[str]): File paths of the images to load.

        Returns:
            List[Image.Image]: Loaded PIL images, in the same order as ``paths``.
        """
        images: List[Image.Image] = []
        for path in paths:
            # Image.open is lazy and keeps the file open until pixels are read;
            # load() reads them now, and leaving the `with` closes the file while
            # the decoded image stays usable.
            with Image.open(path) as image:
                image.load()
            images.append(image)
        return images

    def _embed_batches(self, dataloader: DataLoader, show_progress: bool) -> list:
        """Run the shared model over every batch and return one float32 NumPy array per input item."""
        embeddings: list = []
        batches = tqdm(dataloader) if show_progress else dataloader
        for batch in batches:
            with torch.no_grad():  # inference only; no autograd bookkeeping
                batch = {k: v.to(model.device) for k, v in batch.items()}
                output = model(**batch)
                # Move each batch off the accelerator right away so results do not
                # accumulate in GPU memory; then split into per-item embeddings
                # (presumably shaped (num_tokens, embedding_dim) — confirm).
                embeddings.extend(item.float().numpy() for item in torch.unbind(output.cpu()))
        return embeddings

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5) -> list:
        """
        Generate embeddings for a list of image files.

        Args:
            image_paths (list[str]): Image file paths.
            batch_size (int): Number of images per forward pass.

        Returns:
            list: One float32 NumPy array per image, in input order
            (an empty list for empty input).
        """
        print(f"Processing {len(image_paths)} image_paths")
        images = self.get_images(image_paths)
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=processor.process_images,  # turns a list of PIL images into model inputs
        )
        return self._embed_batches(dataloader, show_progress=True)

    @spaces.GPU
    def process_text(self, texts: list[str]) -> list:
        """
        Generate embeddings for a list of text queries.

        Args:
            texts (list[str]): Query strings.

        Returns:
            list: One float32 NumPy array per query, in input order
            (an empty list for empty input).
        """
        print(f"Processing {len(texts)} texts")
        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            # One query per batch. NOTE(review): presumably this avoids padding,
            # so each query's embedding contains only its own tokens — confirm
            # before raising the batch size.
            batch_size=1,
            shuffle=False,
            collate_fn=processor.process_queries,  # tokenizes queries into model inputs
        )
        return self._embed_batches(dataloader, show_progress=False)