# blip2-opt-2.7b-coco / handler.py
from typing import Dict, List, Any

import requests
import torch
from PIL import Image
from transformers import Blip2Processor, Blip2ForConditionalGeneration


class EndpointHandler():
    def __init__(self, path=""):
        """
        path: local directory of the model repository provided by the Inference Endpoint.
        """
        # Preload the processor and model so they are ready at inference time.
        self.path = path
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # float16 only makes sense on GPU; fall back to float32 on CPU.
        dtype = torch.float16 if self.device == "cuda" else torch.float32
        self.processor = Blip2Processor.from_pretrained(path)
        self.model = Blip2ForConditionalGeneration.from_pretrained(
            path, torch_dtype=dtype
        ).to(self.device)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """
        data args:
            inputs (:obj:`dict`): expects an "image_url" key pointing at the image to caption.
        Return:
            A :obj:`list` of :obj:`dict` that will be serialized and returned.
        """
        inputs = data.pop("inputs", data)
        image_url = inputs["image_url"]
        # Download the image and prepare pixel values for BLIP-2.
        image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")
        processed_image = self.processor(images=image, return_tensors="pt").to(
            self.device, self.model.dtype
        )
        # Generate a caption and decode the token ids back to text.
        generated_ids = self.model.generate(**processed_image)
        generated_text = self.processor.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0].strip()
        return [{"image_url": image_url, "generated_text": generated_text}]
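

# Minimal local smoke test sketch; not used by the Inference Endpoint runtime.
# The model id and image URL below are placeholder assumptions for illustration only.
if __name__ == "__main__":
    handler = EndpointHandler(path="Salesforce/blip2-opt-2.7b-coco")  # assumed hub id
    payload = {"inputs": {"image_url": "https://example.com/sample.jpg"}}  # placeholder URL
    print(handler(payload))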