# local-inference-service/src/text_to_image.py
# Author: fashxp — commit 7bac21a ("additional tasks"), 1.78 kB
from transformers import pipeline
import logging
from fastapi import Request, HTTPException
import base64
class TextToImageTaskService:
    """Run Hugging Face inference on an image posted over HTTP.

    Despite the class name, the task executed is the transformers
    ``"image-to-text"`` pipeline: an uploaded image goes in and the
    pipeline's output comes back out.
    """

    # Caller-supplied logger; records model-load and inference failures.
    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        """Store *logger* for error reporting."""
        self.__logger = logger

    async def get_encoded_image(
        self,
        request: Request
    ) -> str:
        """Read the image carried by *request* and return it base64-encoded.

        Two request shapes are accepted:

        * ``multipart/form-data`` with the image uploaded as a file field
          named ``image``;
        * a raw body whose ``Content-Type`` is ``image/*``.

        Returns:
            The image bytes, base64-encoded, as a UTF-8 ``str``.

        Raises:
            HTTPException: 400 if the content type is unsupported, the
                multipart form has no ``image`` file field, or the image
                payload is empty.
        """
        # Media types are case-insensitive, so normalise before matching.
        content_type = request.headers.get("content-type", "").lower()

        if content_type.startswith("multipart/form-data"):
            form = await request.form()
            image = form.get("image")
            # A plain text field named "image" comes back as ``str``, which
            # has no ``read()``; treat it the same as a missing upload rather
            # than letting an AttributeError surface as a 500.
            if image is None or not hasattr(image, "read"):
                raise HTTPException(
                    status_code=400,
                    detail="Missing 'image' file field in multipart form data"
                )
            image_bytes = await image.read()
        elif content_type.startswith("image/"):
            image_bytes = await request.body()
        else:
            raise HTTPException(status_code=400, detail="Unsupported content type")

        # An empty upload is a client error; without this check it would only
        # fail later inside the pipeline and be reported as a 500.
        if not image_bytes:
            raise HTTPException(status_code=400, detail="Empty image payload")

        return base64.b64encode(image_bytes).decode("utf-8")

    async def extract(
        self,
        request: Request,
        model_name: str
    ):
        """Run the ``image-to-text`` pipeline of *model_name* on the request's image.

        Args:
            request: Incoming request carrying the image; see
                ``get_encoded_image`` for the accepted shapes.
            model_name: Model identifier passed unchanged to
                ``transformers.pipeline``.

        Returns:
            Whatever the pipeline returns for one image. This is presumably a
            list of ``{"generated_text": ...}`` dicts; confirm against the
            models actually served.

        Raises:
            HTTPException: 400 for an unsupported or empty image payload,
                404 if the model cannot be loaded, 500 if inference fails.
        """
        # Validate the input before paying for a (potentially slow) model load.
        encoded_image = await self.get_encoded_image(request)

        # NOTE(review): the pipeline is rebuilt on every request, and both the
        # load and the inference below are synchronous calls inside an async
        # handler, so they block the event loop while they run. Consider
        # caching pipelines per model and/or offloading to a worker thread.
        # If model_name comes straight from the client, any hub model can be
        # fetched on demand; an allow-list may be warranted.
        try:
            pipe = pipeline("image-to-text", model=model_name, use_fast=True)
        except Exception as e:
            # logger.exception keeps the traceback, which str(e) alone loses.
            self.__logger.exception("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        try:
            # Relies on the pipeline's image loader accepting a base64 string.
            # Recent transformers versions do; confirm for the pinned version.
            result = pipe(encoded_image)
        except Exception as e:
            self.__logger.exception("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e

        return result