File size: 2,146 Bytes
ad152ab
bae6852
ad152ab
bae6852
 
d39f3fd
0b8919f
d39f3fd
bae6852
d39f3fd
 
 
 
 
 
 
 
 
bae6852
 
 
 
 
 
ad152ab
bae6852
 
 
 
d39f3fd
bae6852
 
 
 
 
 
 
 
 
 
 
d39f3fd
bae6852
 
 
 
 
 
 
 
 
d39f3fd
bae6852
 
 
 
 
 
ad152ab
bae6852
ad152ab
 
 
 
 
 
bae6852
 
ad152ab
 
 
bae6852
 
0bb4b6a
ad152ab
bae6852
ad152ab
bae6852
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from fastapi import FastAPI, HTTPException
from typing import List
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor
from fastapi.middleware.cors import CORSMiddleware
import torch

# Application instance; the route decorators further down register on it.
app = FastAPI()

# CORS policy gathered in one mapping so it can be read at a glance.
# NOTE(review): wildcard origins together with allow_credentials=True is
# advised against by the FastAPI CORS docs — confirm whether credentialed
# cross-origin requests are actually needed before tightening this.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)

# Checkpoint identifier shared by the model and its tokenizer.
_CHECKPOINT = "ai4bharat/indictrans2-en-indic-1B"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# trust_remote_code=True is passed for both loads, as in the checkpoint's
# own usage instructions.
model = AutoModelForSeq2SeqLM.from_pretrained(_CHECKPOINT, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT, trust_remote_code=True)

# Pre-/post-processing helper, created once and reused by translate_text.
ip = IndicProcessor(inference=True)

# Move the weights onto the selected device once, at import time.
model = model.to(DEVICE)


def translate_text(sentences: List[str], target_lang: str):
    """Translate English sentences into ``target_lang`` as one batch.

    Args:
        sentences: Input sentences; the source language is fixed to
            ``"eng_Latn"``.
        target_lang: Target language code, passed to the processor as
            ``tgt_lang`` for preprocessing and as ``lang`` for postprocessing.

    Returns:
        A list with one translated string per input sentence
        (``num_return_sequences=1``). Empty input yields an empty list.

    Raises:
        Exception: Any failure during preprocessing, tokenization, generation,
            decoding or postprocessing propagates to the caller. Errors are
            deliberately not converted into a returned string, so an error
            message can never be mistaken for a translation.
    """
    # Nothing to translate: avoid handing an empty batch to the tokenizer.
    if not sentences:
        return []

    src_lang = "eng_Latn"
    batch = ip.preprocess_batch(sentences, src_lang=src_lang, tgt_lang=target_lang)
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)

    # Inference only: no gradients are needed for beam-search generation.
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )

    with tokenizer.as_target_tokenizer():
        decoded = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
        )

    # Final step of the IndicTransToolkit usage pattern: every
    # preprocess_batch call is paired with a postprocess_batch call on the
    # decoded text.
    # NOTE(review): the shared processor appears to carry per-batch state
    # between these two calls — confirm behaviour under concurrent requests.
    return ip.postprocess_batch(decoded, lang=target_lang)


@app.get("/")
def read_root():
    """Return the fixed greeting payload served at the root path."""
    greeting = {"Hello": "World"}
    return greeting


class TranslateRequest(BaseModel):
    """Request body accepted by the ``/translate/`` endpoint."""

    # Sentences to translate; translate_text uses "eng_Latn" as their source language.
    sentences: List[str]
    # Target language code; translate_text forwards it to the processor as tgt_lang.
    target_lang: str


# POST is the conventional method for a JSON request body; the original GET
# route stays registered so existing clients keep working.
@app.post("/translate/")
@app.get("/translate/")
def translate(request: TranslateRequest):
    """Translate the sentences in the request body into the target language.

    Args:
        request: Body carrying ``sentences`` and ``target_lang``.

    Returns:
        The list of translated sentences produced by ``translate_text``.

    Raises:
        HTTPException: Status 500 with the error text as ``detail`` when
            translation raises, or when ``translate_text`` hands back an
            error message string instead of a list of translations.
    """
    try:
        result = translate_text(request.sentences, request.target_lang)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Defensive: a plain string (rather than a list) is an error message, not
    # a translation, so it must not be returned with a 200 status.
    if isinstance(result, str):
        raise HTTPException(status_code=500, detail=result)
    return result