rroset's picture
Update handler.py
266aad6 verified
raw
history blame
1.49 kB
from transformers import CLIPProcessor, CLIPModel, pipeline
from PIL import Image
import requests
from io import BytesIO
import base64
from typing import Dict, List, Any
class EndpointHandler:
    """Zero-shot image classification handler backed by a CLIP checkpoint.

    Loads a CLIP model and processor from ``path`` and, when called with a
    request payload, classifies the supplied image against the supplied
    candidate labels.
    """

    # Seconds to wait for a remote image URL before giving up.
    request_timeout = 10

    def __init__(self, path=""):
        # Load the CLIP-specific model and processor from the given path.
        self.model = CLIPModel.from_pretrained(path)
        self.processor = CLIPProcessor.from_pretrained(path)
        # Build the zero-shot image-classification pipeline. The pipeline takes
        # the text tokenizer and the image processor as separate components, so
        # hand it the two halves of CLIPProcessor instead of passing the combined
        # processor as the tokenizer.
        self.classifier = pipeline(
            "zero-shot-image-classification",
            model=self.model,
            tokenizer=self.processor.tokenizer,
            image_processor=self.processor.image_processor,
        )

    @staticmethod
    def _parse_labels(candidate_labels: Any) -> Any:
        """Split a comma-separated label string into a list; pass anything else through.

        A bare string would otherwise be iterated character by character by the
        pipeline.
        """
        if isinstance(candidate_labels, str):
            return [label.strip() for label in candidate_labels.split(",") if label.strip()]
        return candidate_labels

    @staticmethod
    def _read_image_bytes(image_input: Any, timeout: float = 10) -> bytes:
        """Return the encoded image bytes referenced by ``image_input``.

        Accepts raw ``bytes``/``bytearray``, an ``http://``/``https://`` URL, a
        ``data:...;base64,`` URI, or a bare base64 string.

        Raises:
            TypeError: if ``image_input`` is neither text nor bytes.
            requests.HTTPError: if downloading a URL returns an error status.
        """
        if isinstance(image_input, (bytes, bytearray)):
            return bytes(image_input)
        if not isinstance(image_input, str):
            raise TypeError("Image input must be a URL, a base64 string or raw bytes")
        if image_input.startswith(("http://", "https://")):
            response = requests.get(image_input, timeout=timeout)
            # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            return response.content
        # Drop a "data:image/...;base64," prefix if the client sent a data URI.
        if image_input.startswith("data:"):
            image_input = image_input.partition(",")[2]
        # Otherwise assume the input is base64 and decode it.
        return base64.b64decode(image_input)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Classify one image against the requested candidate labels.

        Args:
            data: Request payload. ``"inputs"`` holds the image (URL, base64
                string, data URI or raw bytes); ``"candidate_labels"`` holds a
                list of labels or a single comma-separated string.

        Returns:
            The zero-shot pipeline's output for the image.

        Raises:
            ValueError: if the image input or the candidate labels are missing
                or empty.
        """
        image_input = data.get("inputs", None)
        candidate_labels = self._parse_labels(data.get("candidate_labels", None))
        if not image_input or not candidate_labels:
            raise ValueError("Image input or candidate labels not provided")

        image_bytes = self._read_image_bytes(image_input, timeout=self.request_timeout)
        image = Image.open(BytesIO(image_bytes))

        # Run zero-shot classification and return the pipeline's results as-is.
        return self.classifier(images=image, candidate_labels=candidate_labels)