# local-inference-service / src / image_classification.py
from transformers import pipeline
from pydantic import BaseModel
import logging
from fastapi import Request, HTTPException
import json
from typing import Optional
class ImageClassificationRequest(BaseModel):
    """Request payload for an image-classification inference call."""
    # Image to classify — from the code this is an opaque string passed
    # straight to the pipeline (presumably a URL or base64 data — TODO confirm).
    inputs: str
    # Optional task parameters; for zero-shot classification the service
    # reads 'candidate_labels' from this dict (see classify()).
    parameters: Optional[dict] = None
class ImageClassificationTaskService:
    """Serves image-classification (and zero-shot variant) inference requests.

    Parses an incoming FastAPI request into an ImageClassificationRequest,
    loads a Hugging Face pipeline for the configured task, and runs inference.
    """

    __logger: logging.Logger
    __task_name: str

    def __init__(self, logger: logging.Logger, task_name: str = "image-classification"):
        """Store the logger and the pipeline task name to run."""
        self.__logger = logger
        self.__task_name = task_name

    async def get_image_classification_request(
            self,
            request: Request
    ) -> ImageClassificationRequest:
        """Parse the request body into an ImageClassificationRequest.

        Accepts JSON either as application/json or as a JSON document sent
        with an application/x-www-form-urlencoded content type.

        Raises:
            HTTPException(400): malformed body or unsupported content type.
        """
        content_type = request.headers.get("content-type", "")
        if content_type.startswith("application/json"):
            try:
                data = await request.json()
                return ImageClassificationRequest(**data)
            except Exception:
                # Map JSON decode / validation failures to 400, consistent
                # with the urlencoded branch below (previously these leaked
                # out as unhandled 500s).
                raise HTTPException(status_code=400, detail="Invalid request body")
        if content_type.startswith("application/x-www-form-urlencoded"):
            raw = await request.body()
            try:
                # json.loads accepts UTF-8 bytes directly, so no separate
                # decode step (and no second retry) is needed.
                data = json.loads(raw)
                return ImageClassificationRequest(**data)
            except Exception:
                raise HTTPException(status_code=400, detail="Invalid request body")
        raise HTTPException(status_code=400, detail="Unsupported content type")

    async def classify(
            self,
            request: Request,
            model_name: str
    ):
        """Run the configured classification task with the given model.

        Returns the raw pipeline output.

        Raises:
            HTTPException(400): invalid request body (via request parsing).
            HTTPException(404): the model could not be loaded.
            HTTPException(500): inference failed.
        """
        image_request: ImageClassificationRequest = await self.get_image_classification_request(request)

        # NOTE(review): the pipeline is rebuilt on every request — the model
        # is re-loaded each call. Consider caching per model_name if latency
        # matters; left as-is to preserve behavior.
        try:
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            )

        try:
            if self.__task_name == "zero-shot-image-classification":
                # Zero-shot needs candidate labels; accept either a list or a
                # comma-separated string in parameters['candidate_labels'].
                candidate_labels = []
                if image_request.parameters:
                    candidate_labels = image_request.parameters.get('candidate_labels', [])
                if isinstance(candidate_labels, str):
                    candidate_labels = [label.strip() for label in candidate_labels.split(',')]
                result = pipe(image_request.inputs, candidate_labels=candidate_labels)
            else:  # plain image classification
                result = pipe(image_request.inputs)
        except Exception as e:
            self.__logger.error("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            )

        return result