File size: 1,784 Bytes
7bac21a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import asyncio
import base64
import logging

from fastapi import Request, HTTPException
from transformers import pipeline


class TextToImageTaskService:
    """Run a ``transformers`` ``image-to-text`` pipeline on an image sent over HTTP.

    NOTE(review): despite the class name, the pipeline task used below is
    ``image-to-text`` (image in, text out).
    """

    # Logger used to report model-loading and inference failures.
    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        """Store the logger used by this service."""
        self.__logger = logger

    async def get_encoded_image(
        self,
        request: Request
    ) -> str:
        """Read the image carried by *request* and return it base64-encoded.

        Two request shapes are accepted:

        * ``multipart/form-data`` with a file field named ``image``;
        * a raw body whose ``Content-Type`` starts with ``image/``.

        Args:
            request: The incoming FastAPI request.

        Returns:
            The image bytes encoded as a base64 string.

        Raises:
            HTTPException: 400 when the content type is unsupported, the
                multipart form has no ``image`` file, or the image is empty.
        """
        content_type = request.headers.get("content-type", "")

        if content_type.startswith("multipart/form-data"):
            form = await request.form()
            image = form.get("image")
            # Form values are either plain strings or uploaded files; only
            # uploads can be read. A text field named "image" used to crash
            # with AttributeError (surfacing as a 500).
            if image is None or isinstance(image, str):
                raise HTTPException(
                    status_code=400,
                    detail="Missing 'image' file in multipart form data"
                )
            image_bytes = await image.read()
        elif content_type.startswith("image/"):
            image_bytes = await request.body()
        else:
            raise HTTPException(status_code=400, detail="Unsupported content type")

        # An empty payload cannot be decoded by any model; reject it as a
        # client error instead of letting inference fail with a 500.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image payload")

        return base64.b64encode(image_bytes).decode("utf-8")

    async def extract(
        self,
        request: Request,
        model_name: str
    ):
        """Run the ``image-to-text`` pipeline of *model_name* on the request's image.

        Args:
            request: The incoming request; see ``get_encoded_image`` for the
                accepted shapes.
            model_name: Model identifier passed to ``transformers.pipeline``.

        Returns:
            The raw pipeline output, returned unchanged.

        Raises:
            HTTPException: 400 for a bad request payload, 404 if the model
                cannot be loaded, 500 if inference fails.
        """
        encoded_image = await self.get_encoded_image(request)

        # Model loading and inference are blocking, compute-heavy calls.
        # Running them in a worker thread keeps the event loop responsive
        # for other requests.
        # NOTE(review): the pipeline is rebuilt (weights reloaded) on every
        # call; consider a bounded per-model cache. If model_name comes from
        # the client, consider an allow-list as well.
        try:
            pipe = await asyncio.to_thread(
                pipeline, "image-to-text", model=model_name, use_fast=True
            )
        except Exception as e:
            self.__logger.exception("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {e}"
            ) from e

        # Relies on the pipeline accepting a base64 string as image input
        # (transformers' image loader decodes base64 strings in recent
        # versions; confirm for the pinned version).
        try:
            result = await asyncio.to_thread(pipe, encoded_image)
        except Exception as e:
            self.__logger.exception("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {e}"
            ) from e

        return result