Spaces:
Sleeping
Sleeping
File size: 3,447 Bytes
4d6d915 40b6bb8 4d6d915 d3f18e8 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import os
import threading
from functools import lru_cache
from typing import Dict, List, Union

import gradio as gr
import jwt
import torch
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.security import APIKeyQuery
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

from flores200_codes import flores_codes
# URL prefix under which the Gradio UI is mounted on the FastAPI app.
CUSTOM_PATH = "/gradio"

app = FastAPI()

# Secret used to verify incoming HS256 JWTs. It is read from the environment so
# a real key never has to be committed to source control. The literal fallback
# is a development placeholder only and must not be relied on in production.
SECRET_KEY = os.environ.get("SECRET_KEY", "your_secret_key_here")

# The JWT arrives as the ``jwtToken`` query parameter. auto_error=False lets a
# missing parameter reach verify_token, which returns its own 401 response.
api_key_query = APIKeyQuery(name="jwtToken", auto_error=False)
class TranslationRequest(BaseModel):
    """Schema describing the ``/translate/`` request body.

    Each entry of ``strings`` is either a plain string or a string-to-string
    mapping; the ``translate`` handler reads the text of a mapping entry from
    its ``"text"`` key.

    NOTE(review): ``translate`` parses ``request.json()`` by hand and does not
    declare this model as a parameter, so FastAPI does not validate the body
    against it; it currently serves as documentation only.
    """

    # Strings to translate: plain text or {"text": ...}-style mappings.
    strings: List[Union[str, Dict[str, str]]]
class TranslationResponse(BaseModel):
    """Schema of the ``/translate/`` response body.

    The handler populates ``data`` as ``{"translations": [...]}``, holding the
    translated strings returned by ``translate_text``.
    """

    # Maps a result key (currently only "translations") to a list of strings.
    data: Dict[str, List[str]]
@lru_cache()
def load_model():
    """Load the distilled NLLB-200 (600M) model and its tokenizer.

    The model is moved to CUDA when ``torch.cuda.is_available()`` reports a
    GPU, and to the CPU otherwise. Because of ``lru_cache``, every call after
    the first returns the same cached pair instead of reloading weights.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    alias = "nllb-distilled-600M"
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
    }
    checkpoint = checkpoints[alias]
    print(f"\tLoading model: {alias}")

    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"

    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
    loaded_model = loaded_model.to(target_device)
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return loaded_model, loaded_tokenizer


# Load eagerly at import time; translate_text uses these module globals.
model, tokenizer = load_model()
def translate_text(
    text: List[str],
    source_lang: str,
    target_lang: str,
    max_length: int = 400,
) -> List[str]:
    """Translate a batch of strings with the module-level model and tokenizer.

    Args:
        text: Strings to translate, in order.
        source_lang: Key into ``flores_codes`` for the source language
            (presumably a human-readable language name; verify against
            ``flores200_codes``).
        target_lang: Key into ``flores_codes`` for the target language.
        max_length: Passed through to the translation pipeline call as its
            ``max_length`` generation argument. The default of 400 matches
            the value that used to be hard-coded.

    Returns:
        The ``"translation_text"`` of each pipeline output, in output order.

    Raises:
        KeyError: If either language is not a key of ``flores_codes``.
    """
    # Resolve codes before anything else so an unknown language always fails
    # the same way, even for an empty batch.
    source = flores_codes[source_lang]
    target = flores_codes[target_lang]
    if not text:
        # Nothing to translate; skip building a pipeline.
        return []

    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(text, max_length=max_length)
    return [item["translation_text"] for item in output]
async def verify_token(token: str = Depends(api_key_query)):
    """FastAPI dependency that requires a valid HS256 JWT signed with SECRET_KEY.

    Args:
        token: Value of the ``jwtToken`` query parameter. It is falsy when the
            parameter is absent, because ``api_key_query`` is declared with
            ``auto_error=False``.

    Returns:
        The raw token string. The decoded payload is not used.

    Raises:
        HTTPException: 401 when the token is missing or fails decoding or
            verification.
    """
    if not token:
        raise HTTPException(status_code=401, detail={"message": "Token is missing"})
    try:
        jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
    except jwt.InvalidTokenError as exc:
        # Catch only token failures. A bare ``except:`` would also swallow
        # BaseExceptions such as asyncio.CancelledError, and would report
        # server-side errors as a client 401.
        raise HTTPException(
            status_code=401, detail={"message": "Token is invalid"}
        ) from exc
    return token
# Serializes use of the shared module-level model/tokenizer. Inference now runs
# on worker threads, and each call hands new src_lang/tgt_lang values to a
# pipeline built around those shared objects.
# NOTE(review): HF tokenizers are not documented as thread-safe; relax only if
# concurrent inference is confirmed safe.
_translation_lock = threading.Lock()


def _translate_serialized(texts: List[str], source: str, target: str) -> List[str]:
    """Call translate_text while holding _translation_lock (runs off the event loop)."""
    with _translation_lock:
        return translate_text(texts, source, target)


def _extract_texts(strings: object) -> List[str]:
    """Normalize the body's ``strings`` value into a list of plain strings.

    Items may be plain strings or dicts whose ``"text"`` value is a string.
    A bare string is treated as a single item.

    Raises:
        HTTPException: 400 for a non-list payload or any other item shape.
    """
    if isinstance(strings, str):
        strings = [strings]
    if not isinstance(strings, list):
        raise HTTPException(
            status_code=400, detail={"message": "'strings' must be a list"}
        )
    texts: List[str] = []
    for item in strings:
        if isinstance(item, str):
            texts.append(item)
        elif isinstance(item, dict) and isinstance(item.get("text"), str):
            texts.append(item["text"])
        else:
            raise HTTPException(
                status_code=400,
                detail={
                    "message": "Each entry in 'strings' must be a string "
                    "or an object with a string 'text' field"
                },
            )
    return texts


@app.get("/translate/", response_model=TranslationResponse)
@app.post("/translate/", response_model=TranslationResponse)
async def translate(
    request: Request,
    source: str,
    target: str,
    project_id: str,
    token: str = Depends(verify_token),
):
    """Translate the strings in the JSON request body from ``source`` to ``target``.

    Registered for both GET and POST on ``/translate/``. In both cases the
    strings are read from a JSON body of the form
    ``{"strings": [<str> | {"text": <str>}, ...]}``. Access is gated by
    the ``verify_token`` dependency.

    Query parameters:
        source: Source language; must be a key of ``flores_codes``.
        target: Target language; must be a key of ``flores_codes``.
        project_id: Must be non-empty; not otherwise used here.

    Returns:
        ``TranslationResponse`` with ``data={"translations": [...]}``.

    Raises:
        HTTPException: 400 in these cases:
            - a query parameter is empty;
            - a language is unknown;
            - the body is not valid JSON or not a JSON object;
            - ``strings`` is missing or empty;
            - an item has an unsupported shape.
        HTTPException: 500 when the translation itself fails; the message is
            the underlying error text.
    """
    if not all([source, target, project_id]):
        raise HTTPException(
            status_code=400, detail={"message": "Missing required parameters"}
        )
    # Reject unknown languages as a client error up front, instead of letting
    # a KeyError from the code lookup surface as a 500.
    for language in (source, target):
        if language not in flores_codes:
            raise HTTPException(
                status_code=400,
                detail={"message": f"Unsupported language: {language}"},
            )

    try:
        data = await request.json()
    except ValueError as e:  # JSONDecodeError / UnicodeDecodeError, e.g. empty GET body
        raise HTTPException(
            status_code=400, detail={"message": "Request body must be valid JSON"}
        ) from e
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail={"message": "Request body must be a JSON object"}
        )

    strings = data.get("strings", [])
    if not strings:
        raise HTTPException(
            status_code=400, detail={"message": "No strings provided for translation"}
        )
    texts = _extract_texts(strings)

    try:
        # translate_text is synchronous model inference. Awaiting it in a
        # worker thread keeps the event loop free to serve other requests.
        translations = await run_in_threadpool(
            _translate_serialized, texts, source, target
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail={"message": str(e)}) from e
    return TranslationResponse(data={"translations": translations})
@app.get("/logo.png")
async def logo():
    """Placeholder handler for the service logo.

    Responds with a fixed string instead of image bytes for now.
    """
    # TODO: serve the real logo image instead of this placeholder text.
    placeholder = "Logo placeholder"
    return placeholder
# Minimal demo UI: greets whatever text is entered into the textbox.
io = gr.Interface(
    fn=lambda x: "Hello, " + x + "!",
    inputs="textbox",
    outputs="textbox",
)

# Mount the Gradio UI under CUSTOM_PATH on the existing FastAPI app. The
# returned application is the one served below.
app = gr.mount_gradio_app(app, io, path=CUSTOM_PATH)

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
|