darshankr committed on
Commit 2b53f2b · verified · 1 Parent(s): b864750

Update app.py

Files changed (1):
1. app.py +86 -74
app.py CHANGED
@@ -1,84 +1,52 @@
+# app.py
+import streamlit as st
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 from typing import List
 import torch
 from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
 from IndicTransToolkit import IndicProcessor
+import uvicorn
+import nest_asyncio
+import threading
 
-# Initialize FastAPI app
-app = FastAPI(
-    title="Indic Translation API",
-    description="API for translating text between English and Indic languages",
-    version="1.0.0"
+# Initialize FastAPI
+app = FastAPI()
+
+# Initialize models and processors
+model = AutoModelForSeq2SeqLM.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
+)
+tokenizer = AutoTokenizer.from_pretrained(
+    "ai4bharat/indictrans2-en-indic-1B",
+    trust_remote_code=True
 )
+ip = IndicProcessor(inference=True)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+model = model.to(DEVICE)
 
-# Define request body model
 class InputData(BaseModel):
     sentences: List[str]
     target_lang: str
 
-    class Config:
-        schema_extra = {
-            "example": {
-                "sentences": ["Hello, how are you?", "What is your name?"],
-                "target_lang": "hin_Deva"
-            }
-        }
-
-# Initialize models and processors
-try:
-    model = AutoModelForSeq2SeqLM.from_pretrained(
-        "ai4bharat/indictrans2-en-indic-1B",
-        trust_remote_code=True
-    )
-    tokenizer = AutoTokenizer.from_pretrained(
-        "ai4bharat/indictrans2-en-indic-1B",
-        trust_remote_code=True
-    )
-    ip = IndicProcessor(inference=True)
-    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-    model = model.to(DEVICE)
-except Exception as e:
-    raise RuntimeError(f"Failed to load models: {str(e)}")
-
-@app.get("/")
-async def root():
-    """Root endpoint returning API information"""
-    return {
-        "message": "Welcome to the Indic Translation API",
-        "status": "active",
-        "supported_languages": [
-            "hin_Deva",  # Hindi
-            "ben_Beng",  # Bengali
-            "tam_Taml",  # Tamil
-            # Add other supported languages here
-        ]
-    }
+# FastAPI endpoints
+@app.get("/health")
+async def health_check():
+    return {"status": "healthy"}
 
 @app.post("/translate/")
 async def translate(input_data: InputData):
-    """
-    Translate text from English to specified Indic language
-
-    Args:
-        input_data: InputData object containing sentences and target language
-
-    Returns:
-        Dictionary containing translated text
-    """
     try:
-        # Source language is always English
         src_lang = "eng_Latn"
         tgt_lang = input_data.target_lang
-
-        # Preprocess the input sentences
+
         batch = ip.preprocess_batch(
             input_data.sentences,
             src_lang=src_lang,
             tgt_lang=tgt_lang
         )
-
-        # Tokenize the sentences
+
         inputs = tokenizer(
             batch,
             truncation=True,
@@ -86,8 +54,7 @@ async def translate(input_data: InputData):
             return_tensors="pt",
             return_attention_mask=True
         ).to(DEVICE)
-
-        # Generate translations
+
         with torch.no_grad():
             generated_tokens = model.generate(
                 **inputs,
@@ -97,32 +64,77 @@
                 num_beams=5,
                 num_return_sequences=1
             )
-
-        # Decode the generated tokens
+
         with tokenizer.as_target_tokenizer():
             generated_tokens = tokenizer.batch_decode(
                 generated_tokens.detach().cpu().tolist(),
                 skip_special_tokens=True,
                 clean_up_tokenization_spaces=True
             )
-
-        # Postprocess the translations
+
         translations = ip.postprocess_batch(generated_tokens, lang=tgt_lang)
-
+
         return {
             "translations": translations,
             "source_language": src_lang,
             "target_language": tgt_lang
         }
-
+
     except Exception as e:
-        raise HTTPException(
-            status_code=500,
-            detail=f"Translation error: {str(e)}"
-        )
+        raise HTTPException(status_code=500, detail=str(e))
 
-# Add health check endpoint
-@app.get("/health")
-async def health_check():
-    """Health check endpoint"""
-    return {"status": "healthy"}
+# Streamlit interface
+def main():
+    st.title("Indic Language Translator")
+
+    # Input text
+    text_input = st.text_area("Enter text to translate:", "Hello, how are you?")
+
+    # Language selection
+    target_languages = {
+        "Hindi": "hin_Deva",
+        "Bengali": "ben_Beng",
+        "Tamil": "tam_Taml",
+        "Telugu": "tel_Telu",
+        "Marathi": "mar_Deva",
+        "Gujarati": "guj_Gujr",
+        "Kannada": "kan_Knda",
+        "Malayalam": "mal_Mlym",
+        "Punjabi": "pan_Guru",
+        "Odia": "ori_Orya"
+    }
+
+    target_lang = st.selectbox(
+        "Select target language:",
+        options=list(target_languages.keys())
+    )
+
+    if st.button("Translate"):
+        try:
+            # Prepare input data
+            input_data = InputData(
+                sentences=[text_input],
+                target_lang=target_languages[target_lang]
+            )
+
+            # Call translation function directly
+            result = translate(input_data)
+
+            # Display result
+            st.success("Translation:")
+            st.write(result["translations"][0])
+
+        except Exception as e:
+            st.error(f"Translation failed: {str(e)}")
+
+def run_fastapi():
+    nest_asyncio.apply()
+    uvicorn.run(app, host="0.0.0.0", port=8000)
+
+if __name__ == "__main__":
+    # Start FastAPI in a separate thread
+    api_thread = threading.Thread(target=run_fastapi, daemon=True)
+    api_thread.start()
+
+    # Run Streamlit interface
+    main()
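
For reference, a minimal client-side sketch of exercising the updated service once run_fastapi() has uvicorn listening on port 8000; the localhost base URL and the use of the requests library are assumptions here, not part of this commit:

# Hypothetical smoke test for the /translate/ endpoint added above.
# Payload shape follows the InputData model; the base URL is an assumption.
import requests

payload = {
    "sentences": ["Hello, how are you?", "What is your name?"],
    "target_lang": "hin_Deva",
}
resp = requests.post("http://localhost:8000/translate/", json=payload)
resp.raise_for_status()
print(resp.json()["translations"])

Going through the HTTP endpoint also sidesteps the fact that translate() is declared async: a synchronous caller such as the Streamlit button handler would otherwise need to resolve the coroutine (for example with asyncio.run(translate(input_data))) before indexing into the result.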