flozi00 committed on
Commit
160cee9
·
1 Parent(s): fb40cda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -1
app.py CHANGED
@@ -1,10 +1,25 @@
1
- from transformers import pipeline
2
  import gradio as gr
3
  import re
 
4
 
5
  p = pipeline("automatic-speech-recognition", model="aware-ai/robust-wav2vec2-base-german")
6
  ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def transcribe(audio):
9
  transcribed = p(audio, chunk_length_s=10, stride_length_s=(4, 2))["text"].lower()
10
  transcribed_corrected = ttp(re.sub("[^a-zA-Z0-9öäüÖÄÜ ]", " ",transcribed))[0]["generated_text"]
@@ -21,12 +36,24 @@ def get_asr_interface():
21
  "textbox"
22
  ])
23
 
 
 
 
 
 
 
 
 
 
 
24
  interfaces = [
25
  get_asr_interface(),
 
26
  ]
27
 
28
  names = [
29
  "ASR",
 
30
  ]
31
 
32
  gr.TabbedInterface(interfaces, names).launch(server_name = "0.0.0.0")
 
1
+ from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
2
  import gradio as gr
3
  import re
4
+ import torch
5
 
# Speech-to-text pipeline and the text2text pipeline that `transcribe`
# applies to the recognised text afterwards.
p = pipeline(
    task="automatic-speech-recognition",
    model="aware-ai/robust-wav2vec2-base-german",
)
ttp = pipeline(
    task="text2text-generation",
    model="aware-ai/marian-german-grammar",
)

# Translation model and its tokenizer, both loaded from the same checkpoint
# and shared by every call to `translate`.
model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
11
+
12
def _language_code(label):
    """Return the language code from a dropdown label such as ``"German (de)"``.

    The code is the last space-separated token of the label with its first
    and last character (the parentheses) removed, so
    ``"Occitan (post 1500) (oc)"`` yields ``"oc"``.

    Raises:
        ValueError: if *label* is ``None`` or empty.
    """
    if not label:
        # Fail with a readable message instead of an AttributeError on
        # ``None.split`` when no language was chosen.
        raise ValueError("Please select both a source and a target language.")
    return label.split(" ")[-1][1:-1]


def translate(src, tgt, text):
    """Translate *text* from the *src* language into the *tgt* language.

    Args:
        src: Source-language label in the form ``"Name (code)"``.
        tgt: Target-language label in the same form.
        text: The text to translate.

    Returns:
        The translated text, or ``""`` when *text* is empty or contains only
        whitespace (the model is not invoked in that case).

    Raises:
        ValueError: if *src* or *tgt* is missing while there is text to
            translate.
    """
    # Nothing to translate: skip the model call entirely.
    if not text or not text.strip():
        return ""

    src_code = _language_code(src)
    tgt_code = _language_code(tgt)

    # NOTE(review): ``src_lang`` is set on the shared module-level tokenizer;
    # requests handled concurrently with different source languages could
    # interfere with each other -- confirm how the server schedules calls.
    tokenizer.src_lang = src_code
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_src,
        # The target language is selected by forcing its language id as the
        # first generated token.
        forced_bos_token_id=tokenizer.get_lang_id(tgt_code),
        use_cache=True,
    )
    # A single input string produces a batch of one; return its decoding.
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
22
+
23
  def transcribe(audio):
24
  transcribed = p(audio, chunk_length_s=10, stride_length_s=(4, 2))["text"].lower()
25
  transcribed_corrected = ttp(re.sub("[^a-zA-Z0-9öäüÖÄÜ ]", " ",transcribed))[0]["generated_text"]
 
36
  "textbox"
37
  ])
38
 
39
+
40
def get_translate_interface():
    """Build the Gradio interface for the translation tab.

    Returns:
        A ``gr.Interface`` that calls ``translate`` with a source-language
        dropdown, a target-language dropdown and a free-text input, and shows
        the result in a textbox.
    """
    # Every label has the form "Name (code)"; `translate` reads the code from
    # the trailing parentheses. Names may contain ";" but never ",", so the
    # comma split below yields exactly one entry per language.
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
    Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs),
    Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de),
    Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff),
    Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd),
    Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr),
    Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig),
    Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka),
    Kazakh (kk), Central Khmer (km), Kannada (kn),
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo),
    Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml),
    Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl),
    Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or),
    Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt),
    Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si),
    Slovak (sk),
    Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su),
    Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
    Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo),
    Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    # strip() removes the line breaks and indentation left around each entry.
    lang_list = [lang.strip() for lang in langs.split(",")]
    return gr.Interface(
        translate,
        inputs=[
            gr.inputs.Dropdown(lang_list, label="Source Language"),
            gr.inputs.Dropdown(lang_list, label="Target Language"),
            "text",
        ],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
47
+
48
+
# Tab contents and their labels; the two lists are matched up by position.
interfaces = [get_asr_interface(), get_translate_interface()]
names = ["ASR", "translate"]

# Start the tabbed app, listening on address 0.0.0.0.
gr.TabbedInterface(interfaces, names).launch(server_name="0.0.0.0")