# demo / colpali_manager.py
# Kazel
# change
# 58d195f
from colpali_engine.models import ColPali
from colpali_engine.models.paligemma.colpali.processing_colpali import ColPaliProcessor
from colpali_engine.utils.processing_utils import BaseVisualRetrieverProcessor
from colpali_engine.utils.torch_utils import ListDataset, get_torch_device
from torch.utils.data import DataLoader
import torch
from typing import List, cast
#from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor
from colpali_engine.models import ColIdefics3, ColIdefics3Processor
from tqdm import tqdm
from PIL import Image
import os
import spaces
# This part is for local runs: release cached, unused CUDA memory before the
# model is loaded.
torch.cuda.empty_cache()

# The checkpoint name comes from the `colpali` environment variable (loaded
# from the .env file when one is found), e.g. "vidore/colSmol-256M". The same
# name is reused for the local model and processor directories below.
import dotenv

dotenv_file = dotenv.find_dotenv()
dotenv.load_dotenv(dotenv_file)
model_name = os.environ['colpali']
device = get_torch_device("cuda")  # try using cpu instead of cuda?

# Models are downloaded once, saved under the current working directory and
# loaded from there on later starts instead of being fetched from the hub.
current_working_directory = os.getcwd()
save_directory = os.path.join(current_working_directory, model_name)
processor_directory = os.path.join(current_working_directory, model_name + '_processor')

# "colSmol" checkpoints use the Idefics3 classes; everything else
# (ColPali v1.3 etc.) uses the ColPali classes.
if "colSmol" in model_name:
    _model_class, _processor_class = ColIdefics3, ColIdefics3Processor
else:
    _model_class, _processor_class = ColPali, ColPaliProcessor

# Both directories must exist for the local copy to be usable; if the
# processor directory is missing (e.g. an interrupted first run) the model and
# processor are fetched and saved again rather than crashing on load.
_is_cached = os.path.isdir(save_directory) and os.path.isdir(processor_directory)

if not _is_cached:
    model = _model_class.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map=device,
    ).eval()
    processor = _processor_class.from_pretrained(model_name)
    # exist_ok=True so a retry after a partial save does not fail here.
    os.makedirs(save_directory, exist_ok=True)
    print(f"Directory '{save_directory}' created.")
    model.save_pretrained(save_directory)
    os.makedirs(processor_directory, exist_ok=True)
    processor.save_pretrained(processor_directory)
else:
    # Load with the same dtype and eval mode as the first-run branch so that
    # embeddings do not change precision between the first and later starts.
    # No device_map here: the model stays on the CPU until a ColpaliManager
    # method moves it to the GPU.
    model = _model_class.from_pretrained(
        save_directory,
        torch_dtype=torch.bfloat16,
    ).eval()
    processor = _processor_class.from_pretrained(processor_directory, use_fast=True)
class ColpaliManager:
    """Embed page images and text queries with the module-level model.

    All methods use the module globals ``model`` and ``processor``; an
    instance holds no state of its own.
    """

    def __init__(self, device="cuda", model_name=model_name):
        """Log the requested configuration.

        Args:
            device: Only printed; the model is loaded at module import time,
                not per instance.
            model_name: Only printed, for the same reason.
        """
        # TODO: need to hot potato / use different GPUs between colpali & ollama.
        print(f"Initializing ColpaliManager with device {device} and model {model_name}")

    @spaces.GPU
    def get_images(self, paths: list[str]) -> List[Image.Image]:
        """Open every path in *paths* with PIL and return the images in order."""
        # NOTE(review): opening files does not need the model on the GPU;
        # the move is kept so callers see the same side effect as before.
        model.to("cuda")
        return [Image.open(path) for path in paths]

    def _embed_batches(self, batches):
        """Run the model over *batches* and return one NumPy array per row.

        Args:
            batches: Iterable of dicts mapping input names to tensors, as
                produced by the processor's collate functions.

        Returns:
            A list of float32 NumPy arrays, obtained by unbinding each model
            output along its first dimension.
        """
        rows: List[torch.Tensor] = []
        with torch.no_grad():
            for batch in batches:
                batch = {k: v.to(model.device) for k, v in batch.items()}
                output = model(**batch)
                # Offload each batch straight away so embeddings do not pile
                # up in GPU memory while the remaining batches are processed.
                rows.extend(torch.unbind(output.to("cpu")))
        return [row.float().numpy() for row in rows]

    @spaces.GPU
    def process_images(self, image_paths: list[str], batch_size=5):
        """Embed the images stored at *image_paths*.

        Args:
            image_paths: Paths of the image files to embed.
            batch_size: Number of images passed to the model per forward pass.

        Returns:
            A list with one float32 NumPy array per image, in input order.
        """
        model.to("cuda")
        print(f"Processing {len(image_paths)} image_paths")
        images = self.get_images(image_paths)
        dataloader = DataLoader(
            dataset=ListDataset[Image.Image](images),
            batch_size=batch_size,
            shuffle=False,
            collate_fn=processor.process_images,
        )
        return self._embed_batches(tqdm(dataloader))

    @spaces.GPU
    def process_text(self, texts: list[str], batch_size=5):
        """Embed the query strings in *texts*.

        Args:
            texts: Query strings to embed.
            batch_size: Number of queries passed to the model per forward
                pass (defaults to the previously hard-coded 5).

        Returns:
            A list with one float32 NumPy array per query, in input order.
        """
        # The model must be on the GPU here so several queries can be
        # processed; it is handed back to the CPU below.
        model.to("cuda")
        print(f"Processing {len(texts)} texts")
        try:
            dataloader = DataLoader(
                dataset=ListDataset[str](texts),
                batch_size=batch_size,
                shuffle=False,
                collate_fn=processor.process_queries,
            )
            return self._embed_batches(dataloader)
        finally:
            # Move all parameters and buffers to the CPU so the GPU is free
            # for the ollama call that follows, even if embedding failed.
            model.to("cpu")