# NOTE: lines removed here were Hugging Face Spaces page chrome ("Spaces:" /
# "Running") captured during extraction, not part of the source module.
from transformers import pipeline | |
from pydantic import BaseModel | |
import logging | |
from fastapi import Request, HTTPException | |
import json | |
from typing import Optional | |
class ClassificationRequest(BaseModel):
    """Parsed request payload for a classification call."""

    # The payload to classify; passed straight through to the pipeline
    # (text, or an image URL/path for image tasks — depends on the task).
    inputs: str
    # Optional task parameters; for zero-shot tasks the service reads
    # parameters["candidate_labels"] (list or comma-separated string).
    parameters: Optional[dict] = None
class ClassificationTaskService:
    """Runs a Hugging Face `transformers` pipeline task against HTTP requests.

    The service is configured once with a pipeline task name (e.g.
    "zero-shot-classification") and loads the requested model per call.
    """

    __logger: logging.Logger  # injected logger used for load/inference failures
    __task_name: str          # transformers pipeline task identifier

    def __init__(self, logger: logging.Logger, task_name: str):
        self.__logger = logger
        self.__task_name = task_name

    async def get_classification_request(
        self,
        request: Request
    ) -> ClassificationRequest:
        """Parse the incoming request body into a ClassificationRequest.

        Accepts a JSON body sent either as ``application/json`` or as raw
        JSON posted with ``application/x-www-form-urlencoded``.

        Raises:
            HTTPException: 400 for a malformed/invalid body or an
                unsupported content type.
        """
        content_type = request.headers.get("content-type", "")
        if content_type.startswith("application/json"):
            try:
                data = await request.json()
                return ClassificationRequest(**data)
            except Exception:
                # Malformed JSON or failed model validation is a client
                # error, not a 500 — report it as such.
                raise HTTPException(status_code=400, detail="Invalid request body")
        if content_type.startswith("application/x-www-form-urlencoded"):
            raw = await request.body()
            # json.loads accepts bytes directly (it auto-detects
            # UTF-8/16/32), so the previous decode-then-reparse fallback
            # was dead code: if the first parse failed, the retry on the
            # decoded string failed identically.
            try:
                data = json.loads(raw)
                return ClassificationRequest(**data)
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid request body")
        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def classify(
        self,
        request: Request,
        model_name: str
    ):
        """Run the configured task with ``model_name`` on the request payload.

        Returns the raw pipeline output (task-dependent structure).

        Raises:
            HTTPException: 400 for a bad body, 404 if the model cannot be
                loaded, 500 if inference fails.
        """
        classification_request = await self.get_classification_request(request)
        # NOTE(review): a fresh pipeline is constructed on every request;
        # consider caching loaded pipelines per model_name if the same
        # model is hit repeatedly.
        try:
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            # Lazy %-args: the message is only formatted if actually logged.
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            )
        try:
            if self.__task_name in ("zero-shot-image-classification", "zero-shot-classification"):
                candidate_labels = []
                if classification_request.parameters:
                    candidate_labels = classification_request.parameters.get('candidate_labels', [])
                # Accept either a list or a comma-separated string of labels.
                if isinstance(candidate_labels, str):
                    candidate_labels = [label.strip() for label in candidate_labels.split(',')]
                result = pipe(classification_request.inputs, candidate_labels=candidate_labels)
            else:  # pretrained classification: labels come from the model itself
                result = pipe(classification_request.inputs)
        except Exception as e:
            self.__logger.error("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            )
        return result