import json
import logging
from typing import Optional

from fastapi import HTTPException, Request
from pydantic import BaseModel
from transformers import pipeline


class ClassificationRequest(BaseModel):
    inputs: str
    parameters: Optional[dict] = None


class ClassificationTaskService:
    __logger: logging.Logger
    __task_name: str

    def __init__(self, logger: logging.Logger, task_name: str):
        self.__logger = logger
        self.__task_name = task_name

    async def get_classification_request(
        self, request: Request
    ) -> ClassificationRequest:
        content_type = request.headers.get("content-type", "")
        if content_type.startswith("application/json"):
            data = await request.json()
            return ClassificationRequest(**data)
        if content_type.startswith("application/x-www-form-urlencoded"):
            # The body is treated as raw JSON despite the form content type.
            # json.loads accepts bytes directly, so no separate decode pass
            # is needed; any parse or validation failure maps to a 400.
            raw = await request.body()
            try:
                data = json.loads(raw)
                return ClassificationRequest(**data)
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid request body")
        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def classify(self, request: Request, model_name: str):
        classification_request = await self.get_classification_request(request)

        try:
            # Note: pipeline() loads the model on every call; caching the
            # pipeline per model_name would avoid repeated load cost.
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            self.__logger.error(f"Failed to load model '{model_name}': {e}")
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {e}",
            )

        try:
            if self.__task_name in (
                "zero-shot-classification",
                "zero-shot-image-classification",
            ):
                # Zero-shot pipelines need candidate labels, which may arrive
                # as a list or as a comma-separated string.
                candidate_labels = []
                if classification_request.parameters:
                    candidate_labels = classification_request.parameters.get(
                        "candidate_labels", []
                    )
                if isinstance(candidate_labels, str):
                    candidate_labels = [
                        label.strip() for label in candidate_labels.split(",")
                    ]
                result = pipe(
                    classification_request.inputs,
                    candidate_labels=candidate_labels,
                )
            else:
                # Pretrained classification: the model's own label set applies.
                result = pipe(classification_request.inputs)
        except Exception as e:
            self.__logger.error(f"Inference failed for model '{model_name}': {e}")
            raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

        return result
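
# --- Usage sketch (an assumption, not part of the original service) ---
# A minimal example of wiring ClassificationTaskService into a FastAPI
# route. The app instance, route path, logger name, model id, and the
# "zero-shot-classification" task are hypothetical choices for
# illustration only.
from fastapi import FastAPI

app = FastAPI()
service = ClassificationTaskService(
    logging.getLogger("classification"), task_name="zero-shot-classification"
)


# The :path converter lets Hugging Face model ids containing a slash
# (e.g. "org/model") pass through as a single path parameter.
@app.post("/classify/{model_name:path}")
async def classify_route(model_name: str, request: Request):
    # Delegates body parsing, model loading, and inference to the service;
    # failures surface as the HTTPExceptions raised inside classify().
    return await service.classify(request, model_name)


# Example request, assuming the server was started with `uvicorn module:app`
# and candidate labels supplied as a comma-separated string:
#   curl -X POST http://localhost:8000/classify/facebook/bart-large-mnli \
#     -H "Content-Type: application/json" \
#     -d '{"inputs": "I love this!", "parameters": {"candidate_labels": "positive, negative"}}'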