local-inference-service / src /translation_task.py
fashxp's picture
additional tasks
7bac21a
raw
history blame
2.29 kB
from transformers import pipeline
from pydantic import BaseModel
import logging
from fastapi import Request, HTTPException
import json
from typing import Optional
class TranslationRequest(BaseModel):
    """Body of a translation request.

    Only ``inputs`` is required; ``parameters`` and ``options`` default to
    ``None`` when the client leaves them out.
    """

    # Value handed to the translation pipeline as its positional input.
    inputs: str

    # Extra keyword arguments for the pipeline call, if any.
    parameters: Optional[dict] = None

    # Additional request options; not read by the code visible in this file.
    options: Optional[dict] = None
class TranslationTaskService:
    """Parse incoming translation requests and run them through a pipeline.

    The service reads a JSON document from the HTTP request body, validates
    it as a ``TranslationRequest`` and executes a ``"translation"`` pipeline
    for the requested model.
    """

    __logger: logging.Logger

    def __init__(self, logger: logging.Logger):
        """Store the logger used to report model-loading and inference errors."""
        self.__logger = logger

    async def get_translation_request(
        self,
        request: Request
    ) -> TranslationRequest:
        """Build a ``TranslationRequest`` from the body of *request*.

        Bodies sent as ``application/json`` or
        ``application/x-www-form-urlencoded`` are both parsed as a JSON
        document (some clients post JSON under the form content type).

        Args:
            request: The incoming HTTP request.

        Returns:
            The validated translation request.

        Raises:
            HTTPException: 400 if the content type is not supported, or if
                the body is not a JSON object matching ``TranslationRequest``.
        """
        # Media types are case-insensitive, so normalise before matching.
        content_type = request.headers.get("content-type", "").lower()
        supported = ("application/json", "application/x-www-form-urlencoded")
        if not content_type.startswith(supported):
            raise HTTPException(status_code=400, detail="Unsupported content type")

        raw = await request.body()
        try:
            # json.loads accepts UTF-8 encoded bytes directly, so no explicit
            # decode step is needed.
            data = json.loads(raw)
            return TranslationRequest(**data)
        except (ValueError, TypeError) as err:
            # ValueError: malformed JSON / undecodable bytes; pydantic's
            # ValidationError is expected to derive from ValueError as well.
            # TypeError: the JSON document is not an object, so ``**data``
            # cannot be expanded into keyword arguments.
            raise HTTPException(status_code=400, detail="Invalid request body") from err

    async def translate(
        self,
        request: Request,
        model_name: str
    ):
        """Translate the request's ``inputs`` with the model *model_name*.

        Args:
            request: The incoming HTTP request carrying the payload.
            model_name: Identifier of the model passed to ``pipeline``.

        Returns:
            Whatever the pipeline call returns for the given inputs.

        Raises:
            HTTPException: 400 for an unusable request body, 404 if the model
                cannot be loaded, 500 if inference fails.
        """
        translation_request: TranslationRequest = await self.get_translation_request(request)

        # NOTE(review): the pipeline is created on every call and both the
        # load and the inference are synchronous calls inside an async
        # method - presumably this blocks the event loop; consider caching
        # and/or offloading to a worker thread.
        try:
            pipe = pipeline("translation", model=model_name)
        except Exception as e:
            self.__logger.error(f"Failed to load model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {str(e)}"
            ) from e

        # A missing ``parameters`` field means "no extra keyword arguments".
        extra_kwargs = translation_request.parameters or {}
        try:
            return pipe(translation_request.inputs, **extra_kwargs)
        except Exception as e:
            self.__logger.error(f"Inference failed for model '{model_name}': {str(e)}")
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {str(e)}"
            ) from e