File size: 2,905 Bytes
7bac21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from transformers import pipeline
from pydantic import BaseModel
import logging
from fastapi import Request, HTTPException
import json
from typing import Optional


# Request payload parsed from the HTTP body by ImageClassificationTaskService.
class ImageClassificationRequest(BaseModel):
    # Image reference; passed unchanged as the first positional argument of the
    # transformers pipeline call.
    # NOTE(review): presumably a URL, local path or base64 string -- confirm
    # against the transformers pipeline documentation.
    inputs: str
    # Optional free-form options. The only key read by the service is
    # 'candidate_labels' (a list, or a comma-separated string), and only when
    # the service runs the "zero-shot-image-classification" task.
    parameters: Optional[dict] = None

class ImageClassificationTaskService:
    """Serve image-classification requests through a transformers pipeline.

    The service parses the HTTP request body into an
    ``ImageClassificationRequest``, loads a pipeline for the configured task
    and the requested model, and returns the raw pipeline output.
    """

    __logger: logging.Logger
    __task_name: str

    # Content types whose body is parsed as a JSON document. Clients that send
    # JSON under the form-urlencoded content type are accepted as well.
    _JSON_BODY_CONTENT_TYPES = (
        "application/json",
        "application/x-www-form-urlencoded",
    )
    # Task name that requires a 'candidate_labels' parameter.
    _ZERO_SHOT_TASK = "zero-shot-image-classification"

    def __init__(self, logger: logging.Logger, task_name: str = "image-classification"):
        """Store the logger and the pipeline task name used by ``classify``.

        Args:
            logger: Logger that receives model-loading and inference errors.
            task_name: Task identifier handed to ``transformers.pipeline``.
        """
        self.__logger = logger
        self.__task_name = task_name

    async def get_image_classification_request(
        self,
        request: Request
    ) -> ImageClassificationRequest:
        """Parse the request body into an ``ImageClassificationRequest``.

        Args:
            request: Incoming HTTP request whose body holds a JSON object.

        Returns:
            The validated request model.

        Raises:
            HTTPException: 400 when the content type is not supported, or when
                the body is not valid JSON, is not a JSON object, or does not
                match the request model.
        """
        # Media types are case-insensitive, so normalise before matching.
        content_type = request.headers.get("content-type", "").strip().lower()
        if not content_type.startswith(self._JSON_BODY_CONTENT_TYPES):
            raise HTTPException(status_code=400, detail="Unsupported content type")

        raw = await request.body()
        try:
            # json.loads accepts bytes directly, so no explicit decode is needed.
            return ImageClassificationRequest(**json.loads(raw))
        except (ValueError, TypeError, RecursionError) as exc:
            # ValueError covers JSON/Unicode decoding errors and pydantic's
            # ValidationError; TypeError covers a non-object JSON body;
            # RecursionError covers pathologically nested documents.
            raise HTTPException(status_code=400, detail="Invalid request body") from exc

    @staticmethod
    def _extract_candidate_labels(parameters: Optional[dict]) -> list:
        """Return the zero-shot candidate labels from the request parameters.

        A comma-separated string is split into stripped, non-empty labels; a
        list or tuple is returned as a list.

        Raises:
            HTTPException: 400 when no usable labels were supplied.
        """
        labels = (parameters or {}).get("candidate_labels", [])
        if isinstance(labels, str):
            stripped = (label.strip() for label in labels.split(","))
            labels = [label for label in stripped if label]
        if not isinstance(labels, (list, tuple)) or not labels:
            raise HTTPException(
                status_code=400,
                detail=(
                    "Parameter 'candidate_labels' must be a non-empty list "
                    "or comma-separated string"
                ),
            )
        return list(labels)

    async def classify(
        self,
        request: Request,
        model_name: str
    ):
        """Classify the image described by the request with ``model_name``.

        Args:
            request: Incoming HTTP request carrying the classification payload.
            model_name: Model identifier handed to ``transformers.pipeline``.

        Returns:
            Whatever the pipeline call returns, unmodified.

        Raises:
            HTTPException: 400 for an invalid request (including missing
                candidate labels in zero-shot mode), 404 when the model cannot
                be loaded, 500 when inference fails.
        """
        image_request = await self.get_image_classification_request(request)

        # Validate client input before the model load so that a bad request is
        # reported as 400 instead of surfacing later as an inference failure.
        is_zero_shot = self.__task_name == self._ZERO_SHOT_TASK
        candidate_labels = (
            self._extract_candidate_labels(image_request.parameters)
            if is_zero_shot
            else None
        )

        # NOTE(review): the pipeline is rebuilt on every call, and both the
        # load and the inference below are synchronous calls inside a
        # coroutine -- consider caching and/or a worker thread.
        try:
            pipe = pipeline(self.__task_name, model=model_name)
        except Exception as e:
            self.__logger.error("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        try:
            if is_zero_shot:
                result = pipe(image_request.inputs, candidate_labels=candidate_labels)
            else:  # image classification
                result = pipe(image_request.inputs)
        except Exception as e:
            self.__logger.error("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e

        return result