Spaces:
Running
Running
File size: 2,913 Bytes
7bac21a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
from transformers import pipeline
from pydantic import BaseModel
import logging
from fastapi import Request, HTTPException
import json
from typing import Optional
class ClassificationRequest(BaseModel):
    """Validated body of a classification request.

    Fields:
        inputs: the raw input to classify (text, or an image reference for
            image tasks — exact interpretation is delegated to the pipeline).
        parameters: optional extra pipeline arguments; for zero-shot tasks
            the service reads ``candidate_labels`` from this dict (either a
            list or a comma-separated string).
    """
    inputs: str
    parameters: Optional[dict] = None
class ClassificationTaskService:
    """Runs a Hugging Face classification pipeline for incoming HTTP requests.

    The task name (e.g. "text-classification", "zero-shot-classification",
    "zero-shot-image-classification") is fixed at construction time; the
    model is chosen per request.
    """

    __logger: logging.Logger
    __task_name: str

    def __init__(self, logger: logging.Logger, task_name: str):
        self.__logger = logger
        self.__task_name = task_name

    async def get_classification_request(
        self,
        request: Request
    ) -> ClassificationRequest:
        """Parse the HTTP request body into a ClassificationRequest.

        Accepts ``application/json``, and tolerates JSON payloads mislabeled
        as ``application/x-www-form-urlencoded`` (the original code parsed
        that branch as JSON too, so some clients evidently send it that way).

        Raises:
            HTTPException: 400 when the body is not valid JSON or fails
                model validation, or when the content type is unsupported.
        """
        content_type = request.headers.get("content-type", "")
        if content_type.startswith(
            ("application/json", "application/x-www-form-urlencoded")
        ):
            raw = await request.body()
            try:
                # json.loads accepts bytes and decodes UTF-8 itself, so the
                # former decode-then-reparse fallback was dead code (it
                # repeated the identical parse and could never succeed where
                # the first attempt failed). Catching here also turns a
                # malformed application/json body into a clean 400 instead
                # of an unhandled 500.
                data = json.loads(raw)
                return ClassificationRequest(**data)
            except Exception as e:
                raise HTTPException(
                    status_code=400, detail="Invalid request body"
                ) from e
        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def classify(
        self,
        request: Request,
        model_name: str
    ):
        """Run the configured classification task against ``model_name``.

        Returns the raw pipeline output.

        Raises:
            HTTPException: 400 for a bad body, 404 when the model cannot be
                loaded, 500 when inference fails.
        """
        classification_request: ClassificationRequest = (
            await self.get_classification_request(request)
        )
        # NOTE(review): the pipeline is rebuilt on every request, which
        # reloads model weights each time — consider caching per model_name
        # if latency matters (left unchanged to preserve behavior).
        try:
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            # Lazy %-formatting so the message is only built when emitted.
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e
        try:
            if self.__task_name in (
                "zero-shot-classification",
                "zero-shot-image-classification",
            ):
                # Zero-shot tasks need candidate labels; accept either a
                # list or a comma-separated string in request parameters.
                candidate_labels = []
                if classification_request.parameters:
                    candidate_labels = classification_request.parameters.get(
                        'candidate_labels', []
                    )
                if isinstance(candidate_labels, str):
                    candidate_labels = [
                        label.strip() for label in candidate_labels.split(',')
                    ]
                result = pipe(
                    classification_request.inputs,
                    candidate_labels=candidate_labels,
                )
            else:  # pretrained classification: labels come from the model
                result = pipe(classification_request.inputs)
        except Exception as e:
            self.__logger.error(
                "Inference failed for model '%s': %s", model_name, e
            )
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e
        return result