vtiw committed on
Commit a013987 · verified · 1 Parent(s): 5157dfe

Replaced MetaSeamless with IndicTrans2

Files changed (1)
  1. app.py +60 -16
app.py CHANGED
```diff
@@ -14,19 +14,26 @@ from happytransformer import HappyTextToText, TTSettings
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM,logging
 from transformers.integrations import deepspeed
 import re
+from IndicTransToolkit import IndicProcessor
+import torch
 import torch
 from lang_list import (
     LANGUAGE_NAME_TO_CODE,
     T2TT_TARGET_LANGUAGE_NAMES,
     TEXT_SOURCE_LANGUAGE_NAMES,
 )
+
 logging.set_verbosity_error()
 
 DEFAULT_TARGET_LANGUAGE = "English"
-from transformers import SeamlessM4TForTextToText
-from transformers import AutoProcessor
-model = SeamlessM4TForTextToText.from_pretrained("facebook/hf-seamless-m4t-large")
-processor = AutoProcessor.from_pretrained("facebook/hf-seamless-m4t-large")
+# Load IndicTrans2 model
+model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
+tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
+ip = IndicProcessor(inference=True)
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+model.to(DEVICE)
+
 
 
 import pytesseract as pt
@@ -174,22 +181,59 @@ def split_text_into_batches(text, max_tokens_per_batch):
 @spaces.GPU(duration=60)
 def run_t2tt(file_uploader , input_text: str, source_language: str, target_language: str) -> (str, bytes):
     if file_uploader is not None:
-        with open(file_uploader, 'r') as file:
-            input_text=file.read()
-    source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
-    target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
-    max_tokens_per_batch= 2048
+        with open(file_uploader.name, "r", encoding="utf-8") as file:
+            input_text = file.read()
+
+    # Language mapping
+    lang_code_map = {
+        "Hindi": "hin_Deva",
+        "Punjabi": "pan_Guru",
+        "English": "eng_Latn",
+    }
+
+    src_lang = lang_code_map[source_language]
+    tgt_lang = lang_code_map[target_language]
+
+    max_tokens_per_batch = 256
     batches = split_text_into_batches(input_text, max_tokens_per_batch)
     translated_text = ""
+
     for batch in batches:
-        text_inputs = processor(text=batch, src_lang=source_language_code, return_tensors="pt")
-        output_tokens = model.generate(**text_inputs, tgt_lang=target_language_code)
-        translated_batch = processor.decode(output_tokens[0].tolist(), skip_special_tokens=True)
-        translated_text += translated_batch + " "
-    output=translated_text.strip()
+        batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
+        inputs = tokenizer(
+            batch_preprocessed,
+            truncation=True,
+            padding="longest",
+            return_tensors="pt",
+            return_attention_mask=True,
+        ).to(DEVICE)
+
+        with torch.no_grad():
+            generated_tokens = model.generate(
+                **inputs,
+                use_cache=True,
+                min_length=0,
+                max_length=256,
+                num_beams=5,
+                num_return_sequences=1,
+            )
+
+        with tokenizer.as_target_tokenizer():
+            decoded_tokens = tokenizer.batch_decode(
+                generated_tokens.detach().cpu().tolist(),
+                skip_special_tokens=True,
+                clean_up_tokenization_spaces=True,
+            )
+
+        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
+        translated_text += " ".join(translations) + " "
+
+    output = translated_text.strip()
     _output_name = "result.txt"
-    open(_output_name, 'w').write(output)
-    return str(output), _output_name
+    with open(_output_name, "w", encoding="utf-8") as out_file:
+        out_file.write(output)
+
+    return output, _output_name
 
 with gr.Blocks() as demo_t2tt:
     with gr.Row():
```
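For quick review, here is a minimal sketch of the IndicTrans2 flow this commit introduces, exercised end to end on a single sentence. It reuses the checkpoint, `IndicProcessor` calls, and FLORES-style language codes that appear in the diff; the sample sentence and the Hindi → Punjabi direction are illustrative assumptions. One caveat worth flagging: `indictrans2-indic-indic-dist-320M` serves Indic-to-Indic directions, so the `"English": "eng_Latn"` entry in `lang_code_map` would likely need the separate en-indic/indic-en checkpoints to actually translate to or from English.

```python
# Sketch of the committed pipeline on one sentence (example inputs assumed).
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
ip = IndicProcessor(inference=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

# Hypothetical input: Hindi -> Punjabi ("This is a test sentence.")
sentences = ["यह एक परीक्षण वाक्य है।"]
batch = ip.preprocess_batch(sentences, src_lang="hin_Deva", tgt_lang="pan_Guru")
inputs = tokenizer(batch, padding="longest", truncation=True, return_tensors="pt").to(device)

# Same generation settings as run_t2tt in the diff.
with torch.no_grad():
    tokens = model.generate(**inputs, use_cache=True, min_length=0, max_length=256, num_beams=5)

# Decode against the target-side vocabulary, then let the processor
# restore script and strip IndicTrans2's language tags.
with tokenizer.as_target_tokenizer():
    decoded = tokenizer.batch_decode(tokens, skip_special_tokens=True)

print(ip.postprocess_batch(decoded, lang="pan_Guru"))
```

The committed `run_t2tt` wraps the same steps in a loop: long inputs are first chunked with `split_text_into_batches(input_text, 256)` and the per-chunk translations are concatenated, which keeps each `generate` call within the model's 256-token window.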