vtiw committed
Commit 0e04945 · verified · 1 Parent(s): 7738d4d

Update app.py

Files changed (1)
  1. app.py +79 -88
app.py CHANGED
@@ -1,21 +1,17 @@
  import gradio as gr
  import nltk
- nltk.download('punkt')
- from lang_list import (
-     LANGUAGE_NAME_TO_CODE,
-     T2TT_TARGET_LANGUAGE_NAMES,
-     TEXT_SOURCE_LANGUAGE_NAMES,
- )
- DEFAULT_TARGET_LANGUAGE = "English"
- from transformers import SeamlessM4TForTextToText
- from transformers import AutoProcessor
- model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-medium")
- processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-medium")

- # text_inputs = processor(text = "Hello, my dog is cute", src_lang="eng", return_tensors="pt")
- # output_tokens = model.generate(**text_inputs, tgt_lang="pan")
- # translated_text_from_text = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
- # print(translated_text_from_text)

  def split_text_into_batches(text, max_tokens_per_batch):
      sentences = nltk.sent_tokenize(text) # Tokenize text into sentences
@@ -31,93 +27,88 @@ def split_text_into_batches(text, max_tokens_per_batch):
      batches.append(current_batch.strip()) # Add the last batch
      return batches

- def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
      if file_uploader is not None:
-         with open(file_uploader, 'r') as file:
-             input_text=file.read()
-     source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
-     target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-     max_tokens_per_batch= 256
      batches = split_text_into_batches(input_text, max_tokens_per_batch)
      translated_text = ""
      for batch in batches:
-         text_inputs = processor(text=batch, src_lang=source_language_code, return_tensors="pt")
-         output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
-         translated_batch = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
-         translated_text += translated_batch + " "
-     output=translated_text.strip()
      _output_name = "result.txt"
-     open(_output_name, 'w').write(output)
-     return str(output), _output_name

- with gr.Blocks() as demo_t2tt:
      with gr.Row():
          with gr.Column():
-             with gr.Group():
-                 file_uploader = gr.File(label="Upload a text file (Optional)")
-                 input_text = gr.Textbox(label="Input text")
-                 with gr.Row():
-                     source_language = gr.Dropdown(
-                         label="Source language",
-                         choices=TEXT_SOURCE_LANGUAGE_NAMES,
-                         value="Punjabi",
-                     )
-                     target_language = gr.Dropdown(
-                         label="Target language",
-                         choices=T2TT_TARGET_LANGUAGE_NAMES,
-                         value=DEFAULT_TARGET_LANGUAGE,
-                     )
              btn = gr.Button("Translate")
          with gr.Column():
-             output_text = gr.Textbox(label="Translated text")
              output_file = gr.File(label="Translated text file")

-     gr.Examples(
-         examples=[
-             [
-                 None,
-                 "The sinister destruction of the holy Akal Takht and the ruthless massacre of thousands of innocent pilgrims had unmasked the deep-seated hatred and animosity that the Indian Government had been nurturing against Sikhs ever since independence",
-                 "English",
-                 "Punjabi",
-             ],
-             [
-                 None,
-                 "It contains. much useful information about administrative, revenue, judicial and ecclesiastical activities in various areas which, it is hoped, would supplement the information available in official records.",
-                 "English",
-                 "Hindi",
-             ],
-             [
-                 None,
-                 "दुनिया में बहुत सी अलग-अलग भाषाएं हैं और उनमें अपने वर्ण और शब्दों का भंडार होता है. इसमें में कुछ उनके अपने शब्द होते हैं तो कुछ ऐसे भी हैं, जो दूसरी भाषाओं से लिए जाते हैं.",
-                 "Hindi",
-                 "Punjabi",
-             ],
-             [
-                 None,
-                 "ਸੂੂਬੇ ਦੇ ਕਈ ਜ਼ਿਲ੍ਹਿਆਂ ’ਚ ਬੁੱਧਵਾਰ ਸਵੇਰੇ ਸੰਘਣੀ ਧੁੰਦ ਛਾਈ ਰਹੀ ਤੇ ਤੇਜ਼ ਹਵਾਵਾਂ ਨੇ ਕਾਂਬਾ ਹੋਰ ਵਧਾ ਦਿੱਤਾ। ਸੱਤ ਸ਼ਹਿਰਾਂ ’ਚ ਦਿਨ ਦਾ ਤਾਪਮਾਨ ਦਸ ਡਿਗਰੀ ਸੈਲਸੀਅਸ ਦੇ ਆਸਪਾਸ ਰਿਹਾ। ਸੂਬੇ ’ਚ ਵੱਧ ਤੋਂ ਵੱਧ ਤਾਪਮਾਨ ’ਚ ਵੀ ਦਸ ਡਿਗਰੀ ਸੈਲਸੀਅਸ ਦੀ ਗਿਰਾਵਟ ਦਰਜ ਕੀਤੀ ਗਈ",
-                 "Punjabi",
-                 "English",
-             ],
-         ],
-         inputs=[file_uploader ,input_text, source_language, target_language],
-         outputs=[output_text, output_file],
-         fn=run_t2tt,
-         cache_examples=False,
-         api_name=False,
-     )
-
-     gr.on(
-         triggers=[input_text.submit, btn.click],
-         fn=run_t2tt,
          inputs=[file_uploader, input_text, source_language, target_language],
          outputs=[output_text, output_file],
-         api_name="t2tt",
      )

- with gr.Blocks() as demo:
-     with gr.Tabs():
-         with gr.Tab(label="Translate"):
-             demo_t2tt.render()
-
  if __name__ == "__main__":
-     demo.launch()
 
  import gradio as gr
  import nltk
+ nltk.download('punkt_tab')
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
+ from IndicTransToolkit.IndicTransToolkit import IndicProcessor
+ import torch

+ # Load IndicTrans2 model
+ model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+ ip = IndicProcessor(inference=True)
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ model.to(DEVICE)

  def split_text_into_batches(text, max_tokens_per_batch):
      sentences = nltk.sent_tokenize(text) # Tokenize text into sentences

      batches.append(current_batch.strip()) # Add the last batch
      return batches

+ def run_translation(file_uploader, input_text, source_language, target_language):
      if file_uploader is not None:
+         with open(file_uploader.name, "r", encoding="utf-8") as file:
+             input_text = file.read()
+
+     # Language mapping
+     lang_code_map = {
+         "Hindi": "hin_Deva",
+         "Punjabi": "pan_Guru",
+         "English": "eng_Latn",
+     }
+
+     src_lang = lang_code_map[source_language]
+     tgt_lang = lang_code_map[target_language]
+
+     max_tokens_per_batch = 256
      batches = split_text_into_batches(input_text, max_tokens_per_batch)
      translated_text = ""
+
      for batch in batches:
+         batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
+         inputs = tokenizer(
+             batch_preprocessed,
+             truncation=True,
+             padding="longest",
+             return_tensors="pt",
+             return_attention_mask=True,
+         ).to(DEVICE)
+
+         with torch.no_grad():
+             generated_tokens = model.generate(
+                 **inputs,
+                 use_cache=True,
+                 min_length=0,
+                 max_length=256,
+                 num_beams=5,
+                 num_return_sequences=1,
+             )
+
+         with tokenizer.as_target_tokenizer():
+             decoded_tokens = tokenizer.batch_decode(
+                 generated_tokens.detach().cpu().tolist(),
+                 skip_special_tokens=True,
+                 clean_up_tokenization_spaces=True,
+             )
+
+         translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
+         translated_text += " ".join(translations) + " "
+
+     output = translated_text.strip()
      _output_name = "result.txt"
+     with open(_output_name, "w", encoding="utf-8") as out_file:
+         out_file.write(output)

+     return output, _output_name
+
+ # Define Gradio UI
+ with gr.Blocks() as demo:
      with gr.Row():
          with gr.Column():
+             file_uploader = gr.File(label="Upload a text file (Optional)")
+             input_text = gr.Textbox(label="Input text", lines=5, placeholder="Enter text here...")
+             source_language = gr.Dropdown(
+                 label="Source language",
+                 choices=["Hindi", "Punjabi", "English"],
+                 value="Hindi",
+             )
+             target_language = gr.Dropdown(
+                 label="Target language",
+                 choices=["Hindi", "Punjabi", "English"],
+                 value="English",
+             )
              btn = gr.Button("Translate")
          with gr.Column():
+             output_text = gr.Textbox(label="Translated text", lines=5)
              output_file = gr.File(label="Translated text file")

+     btn.click(
+         fn=run_translation,
          inputs=[file_uploader, input_text, source_language, target_language],
          outputs=[output_text, output_file],
      )

  if __name__ == "__main__":
+     demo.launch(debug=True)
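
A minimal standalone sketch of the IndicTrans2 path the updated app.py relies on, useful for sanity-checking the translation step outside Gradio. It reuses only what the diff itself shows (the checkpoint name, the IndicTransToolkit import path, and the FLORES-style codes from lang_code_map); the sample Hindi sentence is hypothetical.

# Sanity check of the IndicTrans2 translation path used in app.py
# (same checkpoint, import path, and language codes as in the diff above).
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.IndicTransToolkit import IndicProcessor

model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
ip = IndicProcessor(inference=True)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

sentences = ["यह एक छोटा परीक्षण वाक्य है."]  # hypothetical Hindi input ("This is a small test sentence.")
batch = ip.preprocess_batch(sentences, src_lang="hin_Deva", tgt_lang="pan_Guru")
inputs = tokenizer(batch, truncation=True, padding="longest", return_tensors="pt").to(device)

with torch.no_grad():
    generated = model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5)

with tokenizer.as_target_tokenizer():
    decoded = tokenizer.batch_decode(
        generated.detach().cpu().tolist(),
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )

print(ip.postprocess_batch(decoded, lang="pan_Guru"))  # expected: the Punjabi translation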