Spaces:
Running
Running
File size: 2,146 Bytes
ad152ab bae6852 ad152ab bae6852 d39f3fd 0b8919f d39f3fd bae6852 d39f3fd bae6852 ad152ab bae6852 d39f3fd bae6852 d39f3fd bae6852 d39f3fd bae6852 ad152ab bae6852 ad152ab bae6852 ad152ab bae6852 0bb4b6a ad152ab bae6852 ad152ab bae6852 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from fastapi import FastAPI, HTTPException
from typing import List
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from fastapi.middleware.cors import CORSMiddleware
import torch
# ASGI application object exposing the translation endpoints below.
app = FastAPI()

# Fully permissive CORS: every origin, method and header is allowed so that
# browser front-ends hosted anywhere can call this API.
# NOTE(review): the Starlette CORS docs say wildcards and
# allow_credentials=True do not combine the way one might expect (the
# middleware echoes the caller's Origin instead of "*"). The endpoints visible
# here read no cookies or auth headers, so credentials look unnecessary —
# consider allow_credentials=False. Confirm no front-end relies on it first.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# Hugging Face Hub checkpoint: IndicTrans2, English -> Indic, 1B variant.
MODEL_CHECKPOINT = "ai4bharat/indictrans2-en-indic-1B"

# Use the GPU when torch can see one; otherwise run inference on the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True lets transformers execute the custom model/tokenizer
# code shipped in the checkpoint repository (required by this checkpoint).
# Both objects are loaded once at import time and shared by all requests.
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_CHECKPOINT,
    trust_remote_code=True,
).to(DEVICE)
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_CHECKPOINT,
    trust_remote_code=True,
)

# IndicTransToolkit pre/post-processor configured for inference.
ip = IndicProcessor(inference=True)
def translate_text(sentences: List[str], target_lang: str):
    """Translate English sentences into ``target_lang`` using IndicTrans2.

    Runs the IndicTrans2 inference pipeline end to end: IndicProcessor
    preprocessing, tokenization, beam-search generation, decoding, and
    IndicProcessor postprocessing. The postprocessing step follows the
    IndicTransToolkit usage example and restores entity placeholders and
    target-script formatting in the decoded text.

    Args:
        sentences: English source sentences, translated together as one batch.
        target_lang: target language code handed to IndicProcessor. It is
            presumably a FLORES-style code like the "eng_Latn" source code
            used below (e.g. "hin_Deva") — confirm against the model card.

    Returns:
        A list with one translated string per input sentence, in input order.
        An empty (or None) input yields an empty list without running the
        model.

    Raises:
        Any exception raised during preprocessing, tokenization, generation
        or postprocessing propagates to the caller. The error is not
        returned as a value, so callers can tell failure apart from output.
    """
    # Nothing to translate: skip the model entirely instead of generating
    # from an empty batch.
    if not sentences:
        return []

    # The checkpoint is English -> Indic only, so the source is fixed.
    src_lang = "eng_Latn"
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)

    # Pad to the longest sentence in the batch. Over-long inputs are
    # truncated to the tokenizer's maximum length.
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,  # cap on generated tokens per sentence
            num_beams=5,
            num_return_sequences=1,  # exactly one output per input sentence
        )

    # Decode with the target-side vocabulary.
    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # Undo the preprocessing transformations (placeholders, script handling)
    # so the caller receives clean target-language text.
    return ip.postprocess_batch(decoded, lang=target_lang)
@app.get("/")
def read_root():
    """Serve the root path with a constant JSON payload.

    The handler does no work, so a successful response only shows that the
    app is up and routing requests (usable as a trivial health check).
    """
    greeting = {"Hello": "World"}
    return greeting
# JSON body accepted by the /translate/ endpoint. Pydantic validates the
# fields when the request is parsed.
# (Documented with comments rather than a docstring on purpose: pydantic
# copies a class docstring into the generated JSON schema / OpenAPI output.)
class TranslateRequest(BaseModel):
    # English source sentences, translated together as one batch.
    sentences: List[str]
    # Target language code forwarded to translate_text / IndicProcessor.
    # NOTE(review): presumably a FLORES-style code such as "hin_Deva",
    # matching the "eng_Latn" source code — confirm the accepted set.
    target_lang: str
@app.get("/translate/")
@app.post("/translate/")
def translate(request: TranslateRequest):
    """Translate a batch of English sentences into ``request.target_lang``.

    The route reads a JSON body, so it is registered for both methods:
    - GET is the original route and is kept for existing clients.
    - POST is added because browser ``fetch`` refuses to send a request body
      with GET, so web front-ends can only reach this route via POST.

    Returns:
        The list produced by ``translate_text``: one translation per input
        sentence.

    Raises:
        HTTPException: status 500 when translation fails, with the error text
            as ``detail``.
    """
    try:
        result = translate_text(request.sentences, request.target_lang)
    except Exception as e:
        # Surface any pipeline failure as a 500, keeping the original cause
        # chained for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e
    # Defensive: a bare string is an error message, not a list of
    # translations. Never hand it to the client as a 200 success.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)
    return result
|