darshankr committed on
Commit
a173ade
·
verified ·
1 Parent(s): cafe5d4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -6
app.py CHANGED
@@ -1,20 +1,64 @@
1
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import pipeline

# Hugging Face Hub model id used by the classification pipeline.
_MODEL_ID = "your-username/your-model-name"

# FastAPI application and the text-classification pipeline it serves.
app = FastAPI()
model = pipeline("text-classification", model=_MODEL_ID)
8
 
9
# Request schema (Pydantic) for the prediction endpoint.
class InputData(BaseModel):
    """Request body: a single text string to run through the model."""

    text: str
 
12
 
13
# API endpoint: run the classification pipeline on the submitted text.
@app.post("/predict/")
async def predict(input_data: InputData):
    """Return the pipeline's prediction for ``input_data.text``.

    Any failure during inference is surfaced as an HTTP 500 whose
    detail carries the original error message.
    """
    try:
        prediction = model(input_data.text)
    except Exception as exc:  # boundary handler: report, don't crash the server
        raise HTTPException(status_code=500, detail=str(exc))
    else:
        return {"output": prediction}
 
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

# IndicTrans2 English->Indic translation model (AI4Bharat).
MODEL_NAME = "ai4bharat/indictrans2-en-indic-1B"

# Run on GPU when available; the endpoint moves its inputs to the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Tokenizer and model are custom code shipped with the checkpoint, hence
# trust_remote_code=True. The model is placed on DEVICE once at startup.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True
).to(DEVICE)

# Pre/post-processor for IndicTrans2 (normalization, entity replacement).
ip = IndicProcessor(inference=True)

app = FastAPI()
 
10
 
11
# Request schema (Pydantic) for the translation endpoint.
class InputData(BaseModel):
    """Request body: English sentences to translate and a target language."""

    sentences: list[str]  # batch of input sentences (assumed English — "eng_Latn")
    target_lang: str      # target language code, e.g. "hin_Deva" — TODO confirm codes with IndicTrans2
15
 
16
# API endpoint: translate a batch of English sentences into the target language.
@app.post("/translate/")
async def predict(input_data: InputData):
    """Translate ``input_data.sentences`` from English to ``input_data.target_lang``.

    Returns ``{"output": [translated sentences]}``. Any failure during
    preprocessing, generation, or decoding is surfaced as an HTTP 500
    whose detail carries the original error message.
    """
    try:
        # Source is fixed to English (Latin script); target comes from the request.
        src_lang, tgt_lang = "eng_Latn", input_data.target_lang

        # Normalize/transliterate the raw sentences for the model.
        batch = ip.preprocess_batch(
            input_data.sentences,
            src_lang=src_lang,
            tgt_lang=tgt_lang,
        )

        DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

        # Tokenize the sentences and generate input encodings on the model's device.
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model (beam search, no gradients).
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text.
        with tokenizer.as_target_tokenizer():
            generated_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )

        # Postprocess the translations, including entity replacement.
        translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
        return {"output": translations}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))