# colpali-v1.3-hf / handler.py
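"""Custom inference handler for vidore/colpali-v1.3-hf.

Accepts JSON payloads of the form
    {"inputs": [{"image": "<base64-encoded image>"}], "config": {...}}
and returns {"responses": [<generated text, one entry per image>]}.
"""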
import base64
import io
import os
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor


class EndpointHandler:
    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the endpoint handler using the ColPali model for OCR extraction.

        If no model path is provided, it defaults to the directory containing this
        file, i.e. the root of the deployed vidore/colpali-v1.3-hf repository.
        """
        if model_path is None:
            model_path = os.path.dirname(os.path.realpath(__file__))
        try:
            # Use GPU if available, otherwise CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Load the specialized ColPali model (designed for retrieval but
            # repurposed here for OCR generation). Half precision is used on GPU
            # to reduce memory; full precision is kept on CPU. The model is placed
            # explicitly with .to() instead of device_map, so the two placement
            # mechanisms cannot conflict.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if self.device.type == "cuda" else torch.float32,
            ).to(self.device)
            # Load the processor that handles image preprocessing.
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data for OCR extraction.

        Expects a dictionary with an "inputs" key containing a list of
        dictionaries, each of which must have an "image" key holding a
        base64-encoded image string. No text prompt is required.
        """
        try:
            inputs_list = data.get("inputs", [])
            if not isinstance(inputs_list, list) or not inputs_list:
                return {"error": "Inputs should be a non-empty list of dictionaries with an 'image' key."}
            images: List[Image.Image] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    # Decode the base64 string and convert to an RGB PIL image.
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}
            # Preprocess only images (avoids the text+image conflict in the processor).
            model_inputs = self.processor(
                images=images,
                return_tensors="pt",
                padding=True,
            ).to(self.device)
            # The processor may already return input_ids for image-only inputs
            # (ColPali builds a default document prompt). Only fall back to a
            # dummy beginning-of-sequence token when it does not, so any image
            # placeholder tokens are never overwritten.
            if "input_ids" not in model_inputs:
                bos_token_id = (
                    self.processor.tokenizer.bos_token_id
                    or self.processor.tokenizer.cls_token_id
                    or self.processor.tokenizer.pad_token_id
                )
                if bos_token_id is None:
                    raise RuntimeError("No BOS token found in the tokenizer.")
                batch_size = model_inputs["pixel_values"].shape[0]
                dummy_input_ids = torch.full(
                    (batch_size, 1), bos_token_id, dtype=torch.long, device=self.device
                )
                model_inputs["input_ids"] = dummy_input_ids
                model_inputs["attention_mask"] = torch.ones_like(dummy_input_ids)
            # Generation parameters (can be overridden via the "config" field).
            config = data.get("config", {})
            max_new_tokens = int(config.get("max_new_tokens", 256))
            temperature = float(config.get("temperature", 0.8))
            num_return_sequences = int(config.get("num_return_sequences", 1))
            do_sample = bool(config.get("do_sample", True))
            # Generate from the image inputs (plus the prompt tokens prepared above).
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
            )
            # Decode generated tokens into text using the processor's tokenizer.
            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            return {"responses": text_output}
        except Exception as e:
            return {"error": f"Unexpected error: {e}"}

# Instantiate the endpoint handler once at import time, so the model is loaded
# a single time per worker.
_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.
    Processes input data and returns the extracted OCR text.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}