import base64
import io
import os
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import ColPaliForRetrieval, ColPaliProcessor
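
# Expected request payload (illustrative sketch; field names match the handler below):
# {
#     "inputs": [{"image": "<base64-encoded image>"}, ...],
#     "config": {"max_new_tokens": 256, "temperature": 0.8,
#                "num_return_sequences": 1, "do_sample": true}
# }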


class EndpointHandler:
    def __init__(self, model_path: str = None):
        """
        Initialize the endpoint handler using a ColPali checkpoint
        (e.g. 'vidore/colpali-v1.3-hf') for OCR extraction.

        If no model path is provided, the directory containing this file is
        used, which is the standard layout for a Hugging Face dedicated
        inference endpoint repository.
        """
        if model_path is None:
            model_path = os.path.dirname(os.path.realpath(__file__))
        try:
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            # Passing device_map and then calling .to() can conflict once
            # accelerate dispatches the model, so move the model explicitly.
            self.model = ColPaliForRetrieval.from_pretrained(
                model_path,
                trust_remote_code=True,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            ).to(self.device)
            self.processor = ColPaliProcessor.from_pretrained(model_path, trust_remote_code=True)
        except Exception as e:
            raise RuntimeError(f"Error loading model or processor: {e}") from e

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """
        Process the input data for OCR extraction.

        Expects a dictionary with an "inputs" key containing a list of
        dictionaries, each of which must have an "image" key holding a
        base64-encoded image string. For OCR extraction, no text prompt
        is provided.
        """
        try:
            inputs_list = data.get("inputs", [])
            if not inputs_list or not isinstance(inputs_list, list):
                return {"error": "Inputs should be a list of dictionaries with an 'image' key."}

            images: List[Image.Image] = []
            for item in inputs_list:
                image_b64 = item.get("image")
                if not image_b64:
                    return {"error": "One of the input items is missing 'image' data."}
                try:
                    image = Image.open(io.BytesIO(base64.b64decode(image_b64))).convert("RGB")
                    images.append(image)
                except Exception as e:
                    return {"error": f"Failed to decode one of the images: {e}"}

            model_inputs = self.processor(
                images=images,
                return_tensors="pt",
                padding=True,
            ).to(self.device)
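
            # Note: depending on the transformers version, ColPaliProcessor may
            # already attach input_ids (its fixed visual prompt) alongside
            # pixel_values; the guard below handles both cases.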

            # Generation needs at least one text token to start from. A plain
            # `or` chain would skip a legitimate token id of 0, so fall back
            # through the candidates with explicit None checks instead.
            tokenizer = self.processor.tokenizer
            bos_token_id = next(
                (
                    token_id
                    for token_id in (tokenizer.bos_token_id, tokenizer.cls_token_id, tokenizer.pad_token_id)
                    if token_id is not None
                ),
                None,
            )
            if bos_token_id is None:
                raise RuntimeError("No BOS token found in the tokenizer.")
            if "input_ids" not in model_inputs:
                batch_size = model_inputs["pixel_values"].shape[0]
                model_inputs["input_ids"] = torch.full(
                    (batch_size, 1), bos_token_id, dtype=torch.long, device=self.device
                )

            # Generation settings, overridable per request via an optional "config" dict.
            config = data.get("config", {})
            max_new_tokens = config.get("max_new_tokens", 256)
            temperature = config.get("temperature", 0.8)
            num_return_sequences = config.get("num_return_sequences", 1)
            do_sample = bool(config.get("do_sample", True))

            # ColPaliForRetrieval is a retrieval head that outputs embeddings and
            # does not implement generate(); generation goes through the wrapped
            # vision-language model, exposed as `.vlm` in the transformers
            # implementation.
            outputs = self.model.vlm.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                num_return_sequences=num_return_sequences,
                do_sample=do_sample,
            )

            text_output = self.processor.tokenizer.batch_decode(outputs, skip_special_tokens=True)
            return {"responses": text_output}

        except Exception as e:
            return {"error": f"Unexpected error: {e}"}


_service = EndpointHandler()


def handle(data, context):
    """
    Entry point for the Hugging Face dedicated inference endpoint.

    Processes input data and returns the extracted OCR text.
    """
    try:
        if data is None:
            return {"error": "No input data received"}
        return _service(data)
    except Exception as e:
        return {"error": f"Exception in handler: {e}"}