not-lain committed on
Commit
c094f91
·
1 Parent(s): 94d1b20

add translation tab

Browse files
Files changed (2) hide show
  1. app.py +86 -10
  2. requirements.txt +3 -2
app.py CHANGED
@@ -3,7 +3,12 @@ import spaces
3
  import torch
4
  from loadimg import load_img
5
  from torchvision import transforms
6
- from transformers import AutoModelForImageSegmentation, pipeline
 
 
 
 
 
7
  from diffusers import FluxFillPipeline
8
  from PIL import Image, ImageOps
9
 
@@ -11,9 +16,11 @@ from PIL import Image, ImageOps
11
  import numpy as np
12
  from simple_lama_inpainting import SimpleLama
13
  from contextlib import contextmanager
 
14
  # import whisperx
15
  import gc
16
 
 
17
  @contextmanager
18
  def float32_high_matmul_precision():
19
  torch.set_float32_matmul_precision("high")
@@ -187,7 +194,7 @@ def erase(image=None, mask=None):
187
  # model = whisperx.load_model("large-v2", device, compute_type=compute_type)
188
  # audio_input = whisperx.load_audio(audio)
189
  # result = model.transcribe(audio_input, batch_size=batch_size)
190
-
191
  # # Clear GPU memory
192
  # del model
193
  # gc.collect()
@@ -205,7 +212,7 @@ def erase(image=None, mask=None):
205
  # # 3. Assign speaker labels
206
  # diarize_model = whisperx.DiarizationPipeline(device=device)
207
  # diarize_segments = diarize_model(audio_input)
208
-
209
  # # Combine transcription with speaker diarization
210
  # result = whisperx.assign_word_speakers(diarize_segments, result)
211
 
@@ -214,7 +221,7 @@ def erase(image=None, mask=None):
214
  # for segment in result["segments"]:
215
  # if not isinstance(segment, dict):
216
  # continue
217
-
218
  # speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
219
  # start_time = f"{float(segment.get('start', 0)):.2f}"
220
  # end_time = f"{float(segment.get('end', 0)):.2f}"
@@ -231,6 +238,32 @@ def erase(image=None, mask=None):
231
  # torch.cuda.empty_cache()
232
 
233
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  @spaces.GPU(duration=120)
235
  def main(*args):
236
  api_num = args[0]
@@ -247,6 +280,8 @@ def main(*args):
247
  return erase(*args)
248
  # elif api_num == 6:
249
  # return transcribe(*args)
 
 
250
 
251
 
252
  rmbg_tab = gr.Interface(
@@ -349,7 +384,49 @@ transcribe_tab = gr.Interface(
349
  title="Audio Transcription",
350
  description="Upload an audio file to extract text using WhisperX with speaker diarization",
351
  api_name="transcribe",
352
- examples=[]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
353
  )
354
 
355
  demo = gr.TabbedInterface(
@@ -357,20 +434,19 @@ demo = gr.TabbedInterface(
357
  rmbg_tab,
358
  outpaint_tab,
359
  inpaint_tab,
360
- # sam2_tab,
361
  erase_tab,
362
  transcribe_tab,
 
363
  ],
364
  [
365
  "remove background",
366
  "outpainting",
367
  "inpainting",
368
- # "sam2",
369
  "erase",
370
- # "transcribe",
 
371
  ],
372
  title="Utilities that require GPU",
373
  )
374
 
375
-
376
- demo.launch()
 
3
  import torch
4
  from loadimg import load_img
5
  from torchvision import transforms
6
+ from transformers import (
7
+ AutoModelForImageSegmentation,
8
+ pipeline,
9
+ MBartForConditionalGeneration,
10
+ MBart50TokenizerFast,
11
+ )
12
  from diffusers import FluxFillPipeline
13
  from PIL import Image, ImageOps
14
 
 
16
  import numpy as np
17
  from simple_lama_inpainting import SimpleLama
18
  from contextlib import contextmanager
19
+
20
  # import whisperx
21
  import gc
22
 
23
+
24
  @contextmanager
25
  def float32_high_matmul_precision():
26
  torch.set_float32_matmul_precision("high")
 
194
  # model = whisperx.load_model("large-v2", device, compute_type=compute_type)
195
  # audio_input = whisperx.load_audio(audio)
196
  # result = model.transcribe(audio_input, batch_size=batch_size)
197
+
198
  # # Clear GPU memory
199
  # del model
200
  # gc.collect()
 
212
  # # 3. Assign speaker labels
213
  # diarize_model = whisperx.DiarizationPipeline(device=device)
214
  # diarize_segments = diarize_model(audio_input)
215
+
216
  # # Combine transcription with speaker diarization
217
  # result = whisperx.assign_word_speakers(diarize_segments, result)
218
 
 
221
  # for segment in result["segments"]:
222
  # if not isinstance(segment, dict):
223
  # continue
224
+
225
  # speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
226
  # start_time = f"{float(segment.get('start', 0)):.2f}"
227
  # end_time = f"{float(segment.get('end', 0)):.2f}"
 
238
  # torch.cuda.empty_cache()
239
 
240
 
241
def translate_text(text, source_lang, target_lang):
    """Translate ``text`` from ``source_lang`` to ``target_lang`` using mBART-50.

    The ``facebook/mbart-large-50-many-to-many-mmt`` checkpoint is loaded for
    the duration of the call and released again before returning, so the
    weights do not stay resident between requests.

    Args:
        text: The text to translate. ``None``, empty, or whitespace-only input
            yields an empty string without loading the model.
        source_lang: mBART-50 language code of ``text`` (e.g. ``"en_XX"``).
        target_lang: mBART-50 language code to translate into (e.g. ``"fr_XX"``).

    Returns:
        The translated text as a single string.

    Raises:
        KeyError: If ``target_lang`` is not a language code known to the
            tokenizer.
    """
    # Nothing to translate: skip the expensive model load entirely.
    if not text or not text.strip():
        return ""

    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"

    tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    # Tell the tokenizer which language the input is written in.
    tokenizer.src_lang = source_lang
    # Resolve the target language *before* loading the model so that an
    # unknown code fails fast instead of after a multi-GB load.
    target_lang_id = tokenizer.lang_code_to_id[target_lang]

    # Run on the GPU when one is available; otherwise stay on the CPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = MBartForConditionalGeneration.from_pretrained(checkpoint).to(device)

    try:
        # Encode the input text and place the tensors next to the model.
        encoded_text = tokenizer(text, return_tensors="pt").to(device)

        # Forcing the first generated token to the target language code is
        # how this checkpoint is told which language to produce.
        generated_tokens = model.generate(
            **encoded_text,
            forced_bos_token_id=target_lang_id,
        )

        # Decode the generated tokens; a single input gives a single output.
        translation = tokenizer.batch_decode(
            generated_tokens, skip_special_tokens=True
        )[0]
    finally:
        # Release the model even if tokenization/generation failed.
        del model
        gc.collect()
        torch.cuda.empty_cache()

    return translation
266
+
267
  @spaces.GPU(duration=120)
268
  def main(*args):
269
  api_num = args[0]
 
280
  return erase(*args)
281
  # elif api_num == 6:
282
  # return transcribe(*args)
283
+ elif api_num == 7:
284
+ return translate_text(*args)
285
 
286
 
287
  rmbg_tab = gr.Interface(
 
384
  title="Audio Transcription",
385
  description="Upload an audio file to extract text using WhisperX with speaker diarization",
386
  api_name="transcribe",
387
+ examples=[],
388
+ )
389
+
390
# Language codes offered in the translation tab. Kept in one place so the
# source and target dropdowns always offer the same set, in the same order.
MBART50_LANGUAGE_CODES = (
    "ar_AR", "cs_CZ", "de_DE", "en_XX", "es_XX", "et_EE", "fi_FI", "fr_XX",
    "gu_IN", "hi_IN", "it_IT", "ja_XX", "kk_KZ", "ko_KR", "lt_LT", "lv_LV",
    "my_MM", "ne_NP", "nl_XX", "ro_RO", "ru_RU", "si_LK", "tr_TR", "vi_VN",
    "zh_CN", "af_ZA", "az_AZ", "bn_IN", "fa_IR", "he_IL", "hr_HR", "id_ID",
    "ka_GE", "km_KH", "mk_MK", "ml_IN", "mn_MN", "mr_IN", "pl_PL", "ps_AF",
    "pt_XX", "sv_SE", "sw_KE", "ta_IN", "te_IN", "th_TH", "tl_XX", "uk_UA",
    "ur_PK", "xh_ZA", "gl_ES", "sl_SI",
)

# Translation tab: the first (fixed, non-editable) input is the API number
# that `main` reads as its first argument to pick the translation branch.
translate_tab = gr.Interface(
    fn=main,
    inputs=[
        gr.Number(value=7, interactive=False),
        gr.Textbox(label="Text to translate"),
        gr.Dropdown(
            # Each dropdown gets its own list copy of the shared codes.
            choices=list(MBART50_LANGUAGE_CODES),
            label="Source Language",
            value="en_XX",
        ),
        gr.Dropdown(
            choices=list(MBART50_LANGUAGE_CODES),
            label="Target Language",
            value="fr_XX",
        ),
    ],
    outputs=gr.Textbox(label="Translated Text"),
    title="Text Translation",
    description="Translate text between multiple languages using mBART-50",
    api_name="translate",
    examples=[
        [7, "Hello, how are you?", "en_XX", "fr_XX"],
        [7, "Bonjour, comment allez-vous?", "fr_XX", "en_XX"],
    ],
)
431
 
432
  demo = gr.TabbedInterface(
 
434
  rmbg_tab,
435
  outpaint_tab,
436
  inpaint_tab,
 
437
  erase_tab,
438
  transcribe_tab,
439
+ translate_tab
440
  ],
441
  [
442
  "remove background",
443
  "outpainting",
444
  "inpainting",
 
445
  "erase",
446
+ "transcribe",
447
+ "translate"
448
  ],
449
  title="Utilities that require GPU",
450
  )
451
 
452
+ demo.launch()
 
requirements.txt CHANGED
@@ -3,7 +3,7 @@ spaces
3
  torch
4
  torchvision
5
  git+https://github.com/huggingface/diffusers.git
6
- transformers
7
  safetensors
8
  accelerate
9
  sentencepiece
@@ -22,4 +22,5 @@ einops
22
  # git+https://github.com/facebookresearch/sam2.git
23
  matplotlib
24
  simple-lama-inpainting
25
- # git+https://github.com/m-bain/whisperX.git
 
 
3
  torch
4
  torchvision
5
  git+https://github.com/huggingface/diffusers.git
6
+ transformers>=4.30.0
7
  safetensors
8
  accelerate
9
  sentencepiece
 
22
  # git+https://github.com/facebookresearch/sam2.git
23
  matplotlib
24
  simple-lama-inpainting
25
+ # git+https://github.com/m-bain/whisperX.git
26
+ sacremoses