# blip / handler.py
import base64
from io import BytesIO
from typing import Dict, Any
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import torch


class EndpointHandler():
    def __init__(self, path="./"):
        # Load the processor and model, and move the model to CUDA if available
        self.processor = BlipProcessor.from_pretrained(path)
        self.model = BlipForConditionalGeneration.from_pretrained(path).to(
            "cuda" if torch.cuda.is_available() else "cpu"
        )

    def __call__(self, data: Any) -> Dict[str, str]:
        """
        Args:
            data (dict):
                The request payload. Expects an "inputs" dict containing a
                base64-encoded image under "image_base64" and an optional
                "prompt" string, plus an optional "parameters" dict.
        Return:
            dict: {"caption": "<generated caption for the image>"}.
        """
        # Extract inputs and parameters
        inputs = data.pop("inputs", data)
        # "parameters" is accepted for API compatibility but not used below
        parameters = data.pop("parameters", {"mode": "image"})

        # Get the base64 image data and optional prompt from the inputs
        image_base64 = inputs.get("image_base64")
        prompt = inputs.get("prompt", "")  # Optional prompt for conditional captioning

        # Ensure a base64-encoded image was provided
        if not image_base64:
            raise ValueError("No image data provided. Please provide 'image_base64'.")

        # Decode the base64 string and convert to an RGB image
        image_data = BytesIO(base64.b64decode(image_base64))
        image = Image.open(image_data).convert("RGB")

        # Process inputs with or without a prompt
        if prompt:
            processed_inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
        else:
            processed_inputs = self.processor(image, return_tensors="pt").to(self.model.device)

        # Generate and decode the caption
        out = self.model.generate(**processed_inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        # Return the generated caption
        return {"caption": caption}