local-inference-service / src /classification.py
fashxp's picture
additional tasks
7bac21a
raw
history blame
2.91 kB
from transformers import pipeline
from pydantic import BaseModel
import logging
from fastapi import Request, HTTPException
import json
from typing import Optional
class ClassificationRequest(BaseModel):
    """Request payload for a classification task."""
    # Raw input to classify: text, or (for image tasks) an image URL/path.
    inputs: str
    # Optional task parameters, e.g. 'candidate_labels' for zero-shot tasks.
    parameters: Optional[dict] = None
class ClassificationTaskService:
    """Runs a Hugging Face classification pipeline for incoming FastAPI requests.

    Handles both pretrained classification tasks and zero-shot
    (text/image) classification, which additionally takes candidate labels
    from the request parameters.
    """

    __logger: logging.Logger  # service logger supplied by the caller
    __task_name: str          # transformers pipeline task identifier
    __pipelines: dict         # model_name -> loaded pipeline (cache)

    def __init__(self, logger: logging.Logger, task_name: str):
        self.__logger = logger
        self.__task_name = task_name
        # Cache loaded pipelines so repeated requests for the same model do
        # not reload the weights on every call.
        self.__pipelines = {}

    async def get_classification_request(
        self,
        request: Request
    ) -> ClassificationRequest:
        """Parse the request body into a ClassificationRequest.

        Accepts a JSON body sent either as ``application/json`` or as
        ``application/x-www-form-urlencoded`` whose raw body is actually JSON
        (some lenient clients do this).

        Raises:
            HTTPException(400): the body is not valid JSON, does not match
                the ClassificationRequest schema, or the content type is
                unsupported.
        """
        content_type = request.headers.get("content-type", "")
        if content_type.startswith("application/json"):
            data = await request.json()
            try:
                return ClassificationRequest(**data)
            except Exception:
                # Schema-invalid JSON now yields a 400 like the other
                # bad-body paths, instead of bubbling up as a 500.
                raise HTTPException(status_code=400, detail="Invalid request body")
        if content_type.startswith("application/x-www-form-urlencoded"):
            raw = await request.body()
            # json.loads accepts bytes directly (Python >= 3.6), so the
            # former decode-and-retry fallback was redundant.
            try:
                data = json.loads(raw)
                return ClassificationRequest(**data)
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid request body")
        raise HTTPException(status_code=400, detail="Unsupported content type")

    def _get_pipeline(self, model_name: str):
        """Return the (cached) pipeline for *model_name*.

        Raises:
            HTTPException(404): the model could not be loaded.
        """
        if model_name not in self.__pipelines:
            try:
                self.__pipelines[model_name] = pipeline(self.__task_name, model=model_name)
            except Exception as e:
                self.__logger.error(f"Failed to load model '{model_name}': {str(e)}")
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{model_name}' could not be loaded: {str(e)}"
                )
        return self.__pipelines[model_name]

    async def classify(
        self,
        request: Request,
        model_name: str
    ):
        """Run the configured classification task on the request payload.

        Returns the raw pipeline output.

        Raises:
            HTTPException(400): invalid request body or content type.
            HTTPException(404): model could not be loaded.
            HTTPException(500): inference failed.
        """
        classification_request: ClassificationRequest = await self.get_classification_request(request)
        pipe = self._get_pipeline(model_name)
        try:
            if self.__task_name in ("zero-shot-image-classification", "zero-shot-classification"):
                # Zero-shot tasks need candidate labels; accept either a list
                # or a comma-separated string in the request parameters.
                candidate_labels = []
                if classification_request.parameters:
                    candidate_labels = classification_request.parameters.get('candidate_labels', [])
                if isinstance(candidate_labels, str):
                    candidate_labels = [label.strip() for label in candidate_labels.split(',')]
                result = pipe(classification_request.inputs, candidate_labels=candidate_labels)
            else:  # pretrained classification
                result = pipe(classification_request.inputs)
        except Exception as e:
            self.__logger.error(f"Inference failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            )
        return result