flozi00 commited on
Commit
25fcb65
·
1 Parent(s): 8113b92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -36
app.py CHANGED
@@ -1,4 +1,4 @@
1
- from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
2
  import gradio as gr
3
  import re
4
  import torch
@@ -13,26 +13,13 @@ decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)
13
  p = pipeline("automatic-speech-recognition", model="aware-ai/robust-wav2vec2-base-german-lowercase", decoder=decoder)
14
  ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")
15
 
16
- #model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
17
- #tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
18
-
19
- def translate(src, tgt, text):
20
- src = src.split(" ")[-1][1:-1]
21
- tgt = tgt.split(" ")[-1][1:-1]
22
-
23
- # translate
24
- tokenizer.src_lang = src
25
- encoded_src = tokenizer(text, return_tensors="pt")
26
- generated_tokens = model.generate(**encoded_src, forced_bos_token_id=tokenizer.get_lang_id(tgt), use_cache=True)
27
- result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
28
- return result
29
-
30
  def transcribe(audio):
31
- transcribed = p(audio[1], chunk_length_s=20, stride_length_s=(0, 0))["text"]
32
-
33
- punctuated = ttp(transcribed, max_length = 512)[0]["generated_text"]
34
 
35
- return transcribed, punctuated
 
 
36
 
37
  def get_asr_interface():
38
  return gr.Interface(
@@ -42,27 +29,21 @@ def get_asr_interface():
42
  ],
43
  outputs=[
44
  "textbox",
45
- "textbox"
46
  ])
47
 
48
-
49
- def get_translate_interface():
50
- langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greeek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
51
- Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
52
- Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
53
- Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
54
- lang_list = [lang.strip() for lang in langs.split(',')]
55
- return gr.Interface(translate, inputs=[gr.inputs.Dropdown(lang_list, label="Source Language"), gr.inputs.Dropdown(lang_list, label="Target Language"), 'text'], outputs=gr.outputs.Textbox(), title="Translate Between 100 languages")
56
-
57
 
58
  interfaces = [
59
  get_asr_interface(),
60
- #get_translate_interface(),
61
- ]
62
-
63
- names = [
64
- "ASR",
65
- #"translate",
66
  ]
67
 
68
- gr.TabbedInterface(interfaces, names).launch(server_name = "0.0.0.0")
 
1
+ from transformers import pipeline
2
  import gradio as gr
3
  import re
4
  import torch
 
13
  p = pipeline("automatic-speech-recognition", model="aware-ai/robust-wav2vec2-base-german-lowercase", decoder=decoder)
14
  ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  def transcribe(audio):
17
+ transcribed = p(audio[1], chunk_length_s=20, stride_length_s=(0, 0))["text"]
18
+ return transcribed
 
19
 
20
+ def punctuate(text):
21
+ punctuated = ttp(text, max_length = 512)[0]["generated_text"]
22
+ return punctuated
23
 
24
  def get_asr_interface():
25
  return gr.Interface(
 
29
  ],
30
  outputs=[
31
  "textbox",
 
32
  ])
33
 
34
+ def get_punctuate_interface():
35
+ return gr.Interface(
36
+ fn=punctuate,
37
+ inputs=[
38
+ "textbox"
39
+ ],
40
+ outputs=[
41
+ "textbox",
42
+ ])
43
 
44
  interfaces = [
45
  get_asr_interface(),
46
+ get_punctuate_interface(),
 
 
 
 
 
47
  ]
48
 
49
+ gradio.Series(interfaces).launch(server_name = "0.0.0.0")