rroset's picture
Update handler.py
c2d2a0b verified
raw
history blame
1.34 kB
from transformers import pipeline
from PIL import Image
import requests
from io import BytesIO
import base64
from typing import Dict, List, Any
class EndpointHandler():
def __init__(self, path=""):
# Crea la pipeline de classificaci贸 d'imatges zero-shot amb el model espec铆fic
self.classifier = pipeline("zero-shot-image-classification", model="rroset/CLIP-ViT-B-32-laion2B-s34B-b79K")
def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
# Obt茅 l'imatge en base64 i els par脿metres de les dades
image_base64 = data.get("inputs", None)
parameters = data.get("parameters", None)
if image_base64 is None or parameters is None:
raise ValueError("Input data or parameters not provided")
# Obt茅 les etiquetes candidates dels par脿metres
candidate_labels = parameters.get("candidate_labels", None)
if candidate_labels is None:
raise ValueError("Candidate labels not provided")
# Decodifica la imatge des de base64
image_data = base64.b64decode(image_base64)
image = Image.open(BytesIO(image_data))
# Realitza la classificaci贸 zero-shot
results = self.classifier(images=image, candidate_labels=candidate_labels)
# Torna els resultats processats
return results