# NOTE(review): removed non-Python scrape artifacts that preceded this line
# (a "File size" banner, a row of git commit hashes, and a row of line
# numbers); they were extraction residue and made the module unparseable.
import base64
import io
from typing import Any, Dict, List

import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration
class EndpointHandler:
    """HuggingFace Inference Endpoint handler for BLIP-2.

    Loads a BLIP-2 processor/model pair from ``path`` and answers requests
    carrying a base64-encoded image plus an optional text prompt, returning
    the generated caption/answer.
    """

    def __init__(self, path: str = ""):
        # Load model and processor from the checkpoint directory.
        self.processor = Blip2Processor.from_pretrained(path)
        # Fix: device/dtype were hard-coded to cuda/float16, which crashes on
        # CPU-only hosts. Use fp16 only when a GPU is available; fp32 on CPU,
        # where half-precision kernels are poorly supported.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, torch_dtype=self.dtype
        )
        self.model.to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, str]]:
        """Run one generation request.

        Args:
            data: request payload; either the payload itself or wrapped under
                an ``"inputs"`` key. Must contain ``"image"`` (base64-encoded
                image bytes) and may contain ``"text"`` (prompt/question).

        Returns:
            A single-element list: ``[{"answer": <generated text>}]``.

        Raises:
            KeyError: if the payload has no ``"image"`` entry.
        """
        payload = data.get("inputs", data)
        # Fix: the original used data.pop("text", data), which silently passed
        # the ENTIRE payload dict as the text prompt when "text" was missing.
        # None means unconditional captioning, which is the intended fallback.
        text = payload.get("text")
        image_bytes = base64.b64decode(payload["image"])
        # Convert to RGB: the processor expects 3-channel input, and uploaded
        # PNGs are frequently RGBA or palette-mode.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = self.processor(
            images=image, text=text, return_tensors="pt"
        ).to(self.device, self.dtype)
        generated_ids = self.model.generate(**inputs)
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return [{"answer": generated_text}]