import base64
import io
import os
from typing import Any, Dict, List, Optional

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor


class EndpointHandler:
    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the endpoint handler using the ColPali model for OCR extraction.

        If no model path is provided, the model is loaded from the directory
        containing this file (i.e. the deployed 'vidore/colpali-v1.3-hf'
        repository).
        """
        if model_path is None:
            model_path = os.path.dirname(os.path.realpath(__file__))
        try:
            # Use GPU if available, otherwise CPU.
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Load the specialized ColPali model (designed for retrieval but
            # repurposed here for OCR generation). Use half precision on GPU
            # to reduce memory, full precision on CPU.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).to(self.device)
            self.model.eval()
            # Load the processor that handles image preprocessing.
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}")

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data for OCR extraction.

        Expects a dictionary with an "inputs" key containing a list of
        dictionaries. Each dictionary must have an "image" key with a
        base64-encoded image string. For OCR extraction, no text prompt is
        provided by the caller.
        """
        try:
            inputs_list = data.get("inputs", [])
            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with an 'image' key."}

            images: List[Image.Image] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    # Decode the base64 string and convert to an RGB PIL image.
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}

            # Process only images with the processor (to avoid the text+image conflict).
            model_inputs = self.processor(
                images=images,
                return_tensors="pt",
                padding=True,
            ).to(self.device)

            # Manually create a dummy text prompt by inserting a
            # beginning-of-sequence token. This is necessary to trigger text
            # generation even though no prompt is provided.
            bos_token_id = (
                self.processor.tokenizer.bos_token_id
                or self.processor.tokenizer.cls_token_id
                or self.processor.tokenizer.pad_token_id
            )
            if bos_token_id is None:
                raise RuntimeError("No BOS token found in the tokenizer.")
            batch_size = model_inputs["pixel_values"].shape[0]
            dummy_input_ids = torch.full(
                (batch_size, 1), bos_token_id, dtype=torch.long, device=self.device
            )
            model_inputs["input_ids"] = dummy_input_ids

            # Generation parameters (can be overridden via the "config" field).
            config = data.get("config", {})
            max_new_tokens = config.get("max_new_tokens", 256)
            temperature = config.get("temperature", 0.8)
            num_return_sequences = config.get("num_return_sequences", 1)
            do_sample = bool(config.get("do_sample", True))

            # Call generate on the model using the image-only inputs augmented
            # with the dummy text.
            outputs = self.model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
            )

            # Decode generated tokens into text using the processor's tokenizer.
            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            return {"responses": text_output}
        except Exception as e:
            return {"error": f"Unexpected error: {e}"}


# Instantiate the endpoint handler once at import time so the model is loaded
# only on startup.
_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.
    Processes input data and returns the extracted OCR text.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}
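
# ---------------------------------------------------------------------------
# Local smoke test: a minimal sketch of how a client builds the request
# payload and what the handler returns. The image path "sample.png" is a
# hypothetical placeholder, and running this requires the model weights to be
# available (instantiating EndpointHandler above loads them); it is an
# illustration under those assumptions, not part of the deployed endpoint.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    with open("sample.png", "rb") as f:  # hypothetical test image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    payload = {
        "inputs": [{"image": image_b64}],
        # Optional overrides for the generation parameters read above.
        "config": {"max_new_tokens": 128, "temperature": 0.8, "do_sample": True},
    }
    print(handle(payload, context=None))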