|
from transformers import CLIPProcessor, CLIPModel, pipeline |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import base64 |
|
from typing import Dict, List, Any |
|
|
|
class EndpointHandler():
    """Zero-shot image classification handler backed by a CLIP checkpoint.

    The handler loads the model and processor once at construction time and
    is then called with a request payload for every inference.
    """

    # Seconds to wait for a remote image before giving up, so that a stalled
    # host cannot block the worker indefinitely.
    REQUEST_TIMEOUT = 10

    def __init__(self, path="."):
        """Load the CLIP model/processor from *path* and build the pipeline.

        Args:
            path: Directory or identifier passed to ``from_pretrained``.
        """
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)

        # The pipeline needs the tokenizer and the image processor as
        # separate components; the combined CLIPProcessor is not a tokenizer,
        # and with an already-instantiated model the pipeline cannot infer
        # the missing pieces on its own.
        self.classifier = pipeline(
            "zero-shot-image-classification",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            image_processor=self.processor.image_processor,
        )

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image against the supplied candidate labels.

        Args:
            data: Request payload. ``data["inputs"]`` is an http(s) URL or a
                base64-encoded image (optionally as a ``data:`` URI);
                ``data["candidate_labels"]`` holds the labels to score.

        Returns:
            Whatever the zero-shot image classification pipeline returns for
            the image (per the transformers docs, a list of dicts with
            ``label`` and ``score``).

        Raises:
            ValueError: If the image or the labels are missing/empty, or the
                image input is not a string.
            requests.RequestException: If downloading a remote image fails.
        """
        image_input = data.get("inputs", None)
        candidate_labels = data.get("candidate_labels", None)

        # An empty label collection cannot be scored, so it is rejected
        # together with the "missing" case.
        if image_input is None or not candidate_labels:
            raise ValueError("Image input or candidate labels not provided")

        image = self._load_image(image_input)

        # Positional image argument: the keyword name has varied between
        # transformers releases (``images`` vs ``image``).
        return self.classifier(image, candidate_labels=candidate_labels)

    def _load_image(self, image_input: Any) -> "Image.Image":
        """Turn a URL or base64 string into a PIL image."""
        if not isinstance(image_input, str):
            raise ValueError(
                "Image input must be a URL or a base64-encoded string"
            )

        if image_input.startswith(("http://", "https://")):
            # NOTE(security): the URL is caller-controlled; restrict egress
            # at the network level if this endpoint is exposed to untrusted
            # clients.
            response = requests.get(image_input, timeout=self.REQUEST_TIMEOUT)
            # Fail loudly on 4xx/5xx instead of feeding an error page to PIL.
            response.raise_for_status()
            raw_bytes = response.content
        else:
            if image_input.startswith("data:"):
                # Drop the "data:<mime>;base64," header of a data URI.
                image_input = image_input.partition(",")[2]
            raw_bytes = base64.b64decode(image_input)

        return Image.open(BytesIO(raw_bytes))
|
|