modify handler to receive a list of URLs instead of image bytes
d5da980
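With this change the handler no longer reads raw image bytes from the request body; it expects a JSON payload whose "inputs" field carries a list of image URLs and, optionally, a matching list of conditioning prompts. A minimal sketch of such a payload (the URLs and prompts here are placeholders, not values from the commit):

example_payload = {
    "inputs": {
        "image_urls": [
            "https://example.com/cat.jpg",
            "https://example.com/dog.jpg",
        ],
        # Optional; when omitted the handler falls back to empty prompts.
        "texts": ["a photo of", "a photo of"],
    }
}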
import requests
from typing import Dict, Any
from PIL import Image
import torch
from io import BytesIO
from transformers import BlipForConditionalGeneration, BlipProcessor

# Run on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class EndpointHandler():
    def __init__(self, path=""):
        # Load the BLIP captioning processor and model once at startup.
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large")
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        ).to(device)
        self.model.eval()

    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        input_data = data.get("inputs", {})
        image_urls = input_data.get("image_urls", [])
        if not image_urls:
            return {"captions": [], "error": "No images provided"}
        # Optional conditioning prompts; default to one empty prompt per image.
        texts = input_data.get("texts", [""] * len(image_urls))
        if len(image_urls) != len(texts):
            return {
                "captions": [],
                "error": "Texts and images should have the same length"
            }
        try:
            # Download every image inside the try block so that a failed or
            # slow request is reported through the except branch below.
            images_data = []
            for url in image_urls:
                response = requests.get(url, timeout=10)
                response.raise_for_status()
                images_data.append(response.content)
            raw_images = [
                Image.open(BytesIO(img)).convert("RGB")
                for img in images_data
            ]
            # Preprocess each image/prompt pair, then stack the tensors into a
            # single batch below. Concatenation assumes all prompts tokenize to
            # the same length (always true for the default empty prompts).
            processed_inputs = [
                self.processor(image, text, return_tensors="pt")
                for image, text in zip(raw_images, texts)
            ]
            processed_inputs = {
                "pixel_values": torch.cat(
                    [inp["pixel_values"] for inp in processed_inputs],
                    dim=0).to(device),
                "input_ids": torch.cat(
                    [inp["input_ids"] for inp in processed_inputs],
                    dim=0).to(device),
                "attention_mask": torch.cat(
                    [inp["attention_mask"] for inp in processed_inputs],
                    dim=0).to(device)
            }
            # Generate captions for the whole batch and decode them to text.
            with torch.no_grad():
                out = self.model.generate(**processed_inputs)
            captions = self.processor.batch_decode(
                out, skip_special_tokens=True)
            return {"captions": captions}
        except Exception as e:
            print(f"Error during processing: {str(e)}")
            return {"captions": [], "error": str(e)}