# blip/handler.py
# Author: saisriteja — initial commit for checking the handler for the BLIP model
# Commit: 192a99b (verified)
from typing import Dict, List, Any
from transformers import BlipProcessor, BlipForConditionalGeneration
from PIL import Image
import requests
import torch
class EndpointHandler():
    """Inference-endpoint handler that captions an image with a BLIP model.

    Loads a ``BlipProcessor`` and ``BlipForConditionalGeneration`` from a
    local path and serves caption requests via ``__call__``.
    """

    def __init__(self, path="./"):
        """Load processor and model from *path*; move model to GPU if available.

        Args:
            path (str): Directory containing the pretrained BLIP weights
                and processor config (default: current directory).
        """
        self.processor = BlipProcessor.from_pretrained(path)
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = BlipForConditionalGeneration.from_pretrained(path).to(device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            image_url (:obj: `str`): URL of the image to caption
            prompt (:obj: `str`, optional): Text prompt for conditional captioning
        Return:
            A :obj:`list` with caption as `dict`
        Raises:
            ValueError: if ``image_url`` is missing from *data*.
            requests.HTTPError: if the image URL returns a non-2xx status.
        """
        image_url = data.get("image_url")
        if not image_url:
            # Fail fast with a clear message instead of passing None to requests.
            raise ValueError("Missing required field 'image_url' in request data")
        prompt = data.get("prompt", "")  # Optional prompt for conditional captioning

        # Fetch the image with a bounded timeout so a stalled remote cannot
        # hang the endpoint, and surface HTTP errors before image decoding.
        response = requests.get(image_url, stream=True, timeout=30)
        response.raise_for_status()
        image = Image.open(response.raw).convert("RGB")

        # Conditional captioning when a prompt is supplied; unconditional otherwise.
        if prompt:
            inputs = self.processor(image, prompt, return_tensors="pt").to(self.model.device)
        else:
            inputs = self.processor(image, return_tensors="pt").to(self.model.device)

        # Inference only — disable autograd to avoid building a gradient graph.
        with torch.no_grad():
            out = self.model.generate(**inputs)
        caption = self.processor.decode(out[0], skip_special_tokens=True)

        return [{"caption": caption}]