# Importing required modules and libraries
from colpali_engine.models import ColPali  # Main ColPali model for embeddings
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor  # Preprocessing for ColPali
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor  # Base processor utility
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device  # Torch utilities for dataset and device management
from torch.utils.data import DataLoader  # PyTorch DataLoader for batching
import torch  # PyTorch library
from typing import List, cast  # Type annotations
from tqdm import tqdm  # Progress bar utility
from PIL import Image  # Image processing library
import os  # OS module for file path handling
import spaces  # Hugging Face Spaces ZeroGPU decorator for GPU allocation

# Setting model name and initializing device
model_name = "vidore/colpali-v1.2"
device = get_torch_device("cuda")  # Get the available CUDA device

# Load the ColPali model with the specified configuration
model = ColPali.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # Use bfloat16 for reduced precision
    device_map=device,  # Map the model to the selected device
).eval()  # Set the model to evaluation mode

# Initialize the processor for handling image and text inputs
processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))


class ColpaliManager:
    """
    A class to manage the processing of images and text using the ColPali model.
    """

    def __init__(self, device="cuda", model_name="vidore/colpali-v1.2"):
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

        # Uncomment the lines below to have the class initialize its own model
        # and processor instead of relying on the module-level `model` and `processor`
        # self.device = get_torch_device(device)
        # self.model = ColPali.from_pretrained(
        #     model_name,
        #     torch_dtype=torch.bfloat16,
        #     device_map=self.device,
        # ).eval()
        # self.processor = cast(ColPaliProcessor, ColPaliProcessor.from_pretrained(model_name))

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """
        Load images from the given file paths.

        Args:
            paths (list[str]): List of file paths to images.

        Returns:
            List[Image.Image]: List of loaded PIL Image objects.
        """
        return [Image.open(path) for path in paths]

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """
        Process a list of image paths to generate embeddings.

        Args:
            image_paths (list[str]): List of image file paths.
            batch_size (int): Batch size for processing images.

        Returns:
            list: List of image embeddings as NumPy arrays.
        """
        print(f"Processing {len(image_paths)} image_paths")

        # Load images
        images = self.get_images(image_paths)

        # Create a DataLoader for batching the images
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=lambda x: processor.process_images(x),  # Process images using the processor
        )

        ds: List[torch.Tensor] = []  # Initialize a list to store embeddings
        for batch_doc in tqdm(dataloader):  # Iterate through batches with a progress bar
            with torch.no_grad():  # Disable gradient calculations for inference
                # Move batch to the model's device
                batch_doc = {k: v.to(model.device) for k, v in batch_doc.items()}
                # Generate embeddings
                embeddings_doc = model(**batch_doc)
            ds.extend(list(torch.unbind(embeddings_doc.to(device))))  # Append each embedding to the list

        # Convert embeddings to NumPy arrays
        ds_np = [d.float().cpu().numpy() for d in ds]

        return ds_np

    @spaces.GPU
    def process_text(self, texts: list[str]):
        """
        Process a list of text inputs to generate embeddings.

        Args:
            texts (list[str]): List of text inputs.

        Returns:
            list: List of text embeddings as NumPy arrays.
        """
        print(f"Processing {len(texts)} texts")

        # Create a DataLoader for batching the texts
        dataloader = DataLoader(
            dataset=ListDataset[str](texts),
            batch_size=1,  # Process texts one at a time
            shuffle=False,
            collate_fn=lambda x: processor.process_queries(x),  # Process texts using the processor
        )

        qs: List[torch.Tensor] = []  # Initialize a list to store text embeddings
        for batch_query in dataloader:  # Iterate through batches
            with torch.no_grad():  # Disable gradient calculations for inference
                # Move batch to the model's device
                batch_query = {k: v.to(model.device) for k, v in batch_query.items()}
                # Generate embeddings
                embeddings_query = model(**batch_query)

            qs.extend(list(torch.unbind(embeddings_query.to(device))))  # Append each embedding to the list

        # Convert embeddings to NumPy arrays
        qs_np = [q.float().cpu().numpy() for q in qs]

        return qs_np
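

# ---------------------------------------------------------------------------
# Example usage (illustrative sketch, not part of the original module).
# Assumptions: "page1.png" and "page2.png" are hypothetical local files, and
# this colpali_engine version exposes score_multi_vector (defined on
# BaseVisualRetrieverProcessor) for MaxSim late-interaction scoring.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    manager = ColpaliManager()

    # Embed a small set of document page images and a query.
    image_embeddings = manager.process_images(["page1.png", "page2.png"])
    query_embeddings = manager.process_text(["What is the total revenue?"])

    # Both methods return NumPy arrays, while score_multi_vector expects
    # lists of tensors, so convert back before scoring.
    scores = processor.score_multi_vector(
        [torch.from_numpy(q) for q in query_embeddings],
        [torch.from_numpy(d) for d in image_embeddings],
    )
    print(scores)  # One row per query, one column per image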