"""Image and text embedding helpers built on the ColPali model.

Importing this module loads the ColPali model and its processor once, at
module level; ``ColpaliManager`` methods use those shared objects.
"""

# Standard library
import os  # OS module for file path handling
from typing import List, cast  # Type annotations

# Third-party
import spaces  # Provides the ``spaces.GPU`` decorator applied to the methods below
import torch  # PyTorch library
from colpali_engine.models import ColPali  # Main ColPali model for embeddings
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor  # Preprocessing for ColPali
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor  # Base processor utility
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device  # Torch dataset/device utilities
from PIL import Image  # Image processing library
from torch.utils.data import DataLoader  # PyTorch DataLoader for batching
from tqdm import tqdm  # Progress bar utility

# Model name and target device shared by everything in this module
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")

# Load the ColPali model in bfloat16 on the selected device, in evaluation mode
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map=device,
).eval()

# Processor that turns images / query strings into model inputs
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))


class ColpaliManager:
    """
    A class to manage the processing of images and text using the ColPali model.

    All methods use the module-level ``model`` and ``processor``; instances
    hold no state of their own.
    """

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        """
        Log the requested configuration.

        Args:
            device (str): Only logged; the module-level ``device`` is what the
                model was actually loaded on.
            model_name (str): Only logged; the module-level ``model`` is what
                is actually used.
        """
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """
        Load images from the given file paths.

        Each image is fully read into memory and its file handle is closed
        before returning, so no open files outlive this call.

        Args:
            paths (list[str]): List of file paths to images.

        Returns:
            List[Image.Image]: List of loaded PIL Image objects.
        """
        images: List[Image.Image] = []
        for path in paths:
            # Image.open is lazy: force the pixel data to be read while the
            # file is open, then let the context manager close the handle.
            with Image.open(path) as img:
                img.load()
            images.append(img)
        return images

    def _embed_batches(self, batches) -> list:
        """
        Run the model over already-collated batches and collect embeddings.

        Args:
            batches: Iterable yielding dict-like batches of tensors, as
                produced by the processor's collate functions.

        Returns:
            list: One float32 NumPy array per input item, in input order.
        """
        embeddings: List[torch.Tensor] = []
        with torch.no_grad():  # Inference only; no gradients needed
            for batch in batches:
                # Move the batch to the model's device
                batch = {k: v.to(model.device) for k, v in batch.items()}
                output = model(**batch)
                # Offload each batch to the CPU right away so GPU memory use
                # does not grow with the number of inputs.
                embeddings.extend(torch.unbind(output.to("cpu")))
        # Cast to float32 before the NumPy conversion (the model is loaded in bfloat16)
        return [e.float().numpy() for e in embeddings]

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """
        Process a list of image paths to generate embeddings.

        Args:
            image_paths (list[str]): List of image file paths.
            batch_size (int): Batch size for processing images.

        Returns:
            list: List of image embeddings as NumPy arrays.
        """
        print(f"Processing {len(image_paths)} image_paths")

        images = self.get_images(image_paths)

        # Batch the images; the processor converts each batch into model inputs
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),
        )

        # tqdm shows progress over the batches
        return self._embed_batches(tqdm(dataloader))

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """
        Process a list of text inputs to generate embeddings.

        Args:
            texts (list[str]): List of text inputs.

        Returns:
            list: List of text embeddings as NumPy arrays.
        """
        print(f"Processing {len(texts)} texts")

        # Texts are embedded one at a time (batch_size=1)
        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),
        )

        return self._embed_batches(dataloader)