Spaces:
Running
Running
from transformers import pipeline | |
from pydantic import BaseModel | |
import logging | |
from fastapi import Request, HTTPException | |
import json | |
from typing import Optional | |
class TranslationRequest(BaseModel):
    """JSON body accepted by the translation endpoint.

    Only ``inputs`` is required; the two dictionaries are optional and
    default to ``None`` when omitted from the body.
    """

    # Text to translate.
    inputs: str
    # Extra keyword arguments for the translation pipeline call, if any.
    parameters: Optional[dict] = None
    # Client-side options; carried on the model as-is.
    options: Optional[dict] = None
class TranslationTaskService:
    """Parse translation requests and run them through transformers pipelines.

    Loaded pipelines are kept in a small least-recently-used cache keyed by
    model name, so repeated requests for the same model skip reloading it.
    """

    __logger: logging.Logger

    def __init__(self, logger: logging.Logger, *, max_cached_pipelines: int = 1):
        """Create the service.

        Args:
            logger: Receives model-loading and inference failures.
            max_cached_pipelines: Maximum number of loaded pipelines kept in
                memory between requests. ``0`` (or a negative value) disables
                caching, so every request loads its model afresh.
        """
        self.__logger = logger
        self.__max_cached_pipelines = max(0, max_cached_pipelines)
        # Maps model name -> loaded pipeline. Dict insertion order doubles as
        # recency order: cache hits are re-inserted at the end, so the first
        # key is always the least recently used entry.
        self.__pipelines: dict = {}

    @staticmethod
    def _parse_body(raw: bytes) -> TranslationRequest:
        """Decode a raw JSON request body into a ``TranslationRequest``.

        ``json.loads`` accepts bytes directly and detects UTF-8/16/32 itself,
        so no manual ``.decode()`` fallback is needed.

        Raises:
            HTTPException: 400 if the body is not valid JSON, is not a JSON
                object, or does not match the ``TranslationRequest`` schema.
        """
        try:
            data = json.loads(raw)
            # ``**data`` raises TypeError when the JSON is not an object
            # (e.g. a list); JSONDecodeError, UnicodeDecodeError and pydantic's
            # ValidationError are all ValueError subclasses.
            return TranslationRequest(**data)
        except (ValueError, TypeError) as e:
            raise HTTPException(status_code=400, detail="Invalid request body") from e

    async def get_translation_request(
        self,
        request: Request
    ) -> TranslationRequest:
        """Read and validate the translation request from the HTTP body.

        Bodies sent as ``application/json`` or
        ``application/x-www-form-urlencoded`` are both parsed as JSON.
        The form-encoded case is presumably there for clients that POST a
        JSON body with the default form content type (e.g. ``curl -d``).

        Raises:
            HTTPException: 400 for any other (or missing) content type, or
                when the body cannot be parsed/validated.
        """
        # Media types are case-insensitive, so normalise before matching.
        content_type = request.headers.get("content-type", "").lower()
        if content_type.startswith(
            ("application/json", "application/x-www-form-urlencoded")
        ):
            return self._parse_body(await request.body())
        raise HTTPException(status_code=400, detail="Unsupported content type")

    def _get_pipeline(self, model_name: str):
        """Return a translation pipeline for ``model_name``, loading it if needed.

        Propagates whatever ``transformers.pipeline`` raises on load failure.
        """
        cached = self.__pipelines.pop(model_name, None)
        if cached is not None:
            # Re-insert so this entry becomes the most recently used.
            self.__pipelines[model_name] = cached
            return cached
        if self.__max_cached_pipelines == 0:
            return pipeline("translation", model=model_name)
        # Evict before loading so an evicted model can be garbage-collected
        # before the new one is allocated (keeps peak memory down).
        while len(self.__pipelines) >= self.__max_cached_pipelines:
            self.__pipelines.pop(next(iter(self.__pipelines)))
        pipe = pipeline("translation", model=model_name)
        self.__pipelines[model_name] = pipe
        return pipe

    async def translate(
        self,
        request: Request,
        model_name: str
    ):
        """Translate the request's ``inputs`` with the model ``model_name``.

        ``parameters`` from the request body are forwarded as keyword
        arguments to the pipeline call; ``options`` is not used here.

        Returns:
            The raw output of the transformers translation pipeline.

        Raises:
            HTTPException: 400 for an unparseable request, 404 if the model
                cannot be loaded, 500 if inference fails.
        """
        translation_request = await self.get_translation_request(request)
        # NOTE(review): loading and inference below run synchronously and block
        # the event loop. Offloading them to a thread would need per-pipeline
        # serialisation, since a cached pipeline (and its tokenizer) may not be
        # safe to call from several threads at once.
        try:
            pipe = self._get_pipeline(model_name)
        except Exception as e:
            self.__logger.exception("Failed to load model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=404,
                detail=f"Model '{model_name}' could not be loaded: {e}",
            ) from e
        try:
            return pipe(
                translation_request.inputs,
                **(translation_request.parameters or {}),
            )
        except Exception as e:
            self.__logger.exception("Inference failed for model '%s': %s", model_name, e)
            raise HTTPException(
                status_code=500,
                detail=f"Inference failed: {e}",
            ) from e