File size: 3,623 Bytes
160cee9
412c852
aa756f5
160cee9
ab662d2
412c852
c0f356c
ab662d2
bbdaaee
aa756f5
412c852
160cee9
 
 
 
 
 
 
 
 
 
 
 
 
 
191d30d
aa756f5
18e36a8
aa756f5
dad77b7
412c852
2e8cc61
 
 
 
 
 
 
 
 
 
160cee9
 
 
 
 
 
 
 
 
 
2e8cc61
 
160cee9
2e8cc61
 
 
 
160cee9
2e8cc61
 
fb40cda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
import gradio as gr
import re
import torch
from pyctcdecode import BeamSearchDecoderCTC

# KenLM n-gram language model (Hugging Face Hub repo) loaded into a
# pyctcdecode beam-search decoder for the ASR pipeline below.
lmID = "aware-ai/german-lowercase-4gram-kenlm"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# German speech recognition: wav2vec2 acoustic model decoded with the
# language-model beam search built above.
p = pipeline(
    "automatic-speech-recognition",
    model="aware-ai/robust-wav2vec2-base-german-lowercase",
    decoder=decoder,
)

# text2text model that post-edits the raw ASR transcript in `transcribe`
# (grammar correction, judging by the model name -- not verified here).
ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")

# Many-to-many translation model and its tokenizer, used by `translate`.
# Both must come from the same checkpoint, so the id is written once.
_M2M_CHECKPOINT = "facebook/m2m100_1.2B"
model = M2M100ForConditionalGeneration.from_pretrained(_M2M_CHECKPOINT)
tokenizer = M2M100Tokenizer.from_pretrained(_M2M_CHECKPOINT)

def _lang_code(label, role):
    """Return the M2M100 language code from a dropdown label such as ``"German (de)"``.

    The code is the final parenthesised token, so labels like
    ``"Occitan (post 1500) (oc)"`` or ``"Catalan; Valencian (ca)"`` work.
    A bare code (e.g. ``"de"``) is accepted as well.

    Raises:
        ValueError: if ``label`` is missing/empty or contains no recognisable code.
    """
    if not label or not label.strip():
        raise ValueError(f"Please select a {role} language.")
    match = re.search(r"\(([^()\s]+)\)\s*$", label)
    if match:
        return match.group(1)
    stripped = label.strip()
    if re.fullmatch(r"[A-Za-z_]+", stripped):
        return stripped
    raise ValueError(f"Could not find a language code in {label!r}.")


def translate(src, tgt, text):
    """Translate ``text`` from the ``src`` language into the ``tgt`` language with M2M100.

    Args:
        src: Source-language label such as ``"German (de)"`` (code in the
            trailing parentheses) or a bare language code.
        tgt: Target-language label in the same format as ``src``.
        text: Text to translate.

    Returns:
        The translated string, or ``""`` when ``text`` is empty or whitespace.

    Raises:
        ValueError: if either language is not selected or has no recognisable code.
    """
    src_code = _lang_code(src, "source")
    tgt_code = _lang_code(tgt, "target")

    # Nothing to translate: avoid running the large model on empty input.
    if not text or not text.strip():
        return ""

    # M2M100: the tokenizer is told the source language, and the target
    # language token is forced as the first generated token.
    tokenizer.src_lang = src_code
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_src,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_code),
        use_cache=True,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

def transcribe(audio):
    """Transcribe a German recording and return the post-edited transcript.

    Args:
        audio: Path to the recorded audio file (the ASR tab's ``Audio`` input
            uses ``type="filepath"``), or ``None``/empty when nothing was recorded.

    Returns:
        The lowercased, cleaned transcript after passing through the ``ttp``
        text2text pipeline, or ``""`` when there is no audio or no
        recognisable text.
    """
    if not audio:
        return ""

    # Long recordings are processed in 10 s chunks with a (4 s, 2 s)
    # left/right stride overlap between neighbouring chunks.
    transcribed = p(audio, chunk_length_s=10, stride_length_s=(4, 2))["text"].lower()

    # Keep Latin letters, digits, German umlauts and ß (previously dropped,
    # which split words like "straße"); everything else becomes a space.
    # The resulting runs of whitespace are then collapsed.
    cleaned = " ".join(re.sub("[^a-zA-Z0-9öäüÖÄÜß ]", " ", transcribed).split())
    if not cleaned:
        return ""

    return ttp(cleaned)[0]["generated_text"]

def get_asr_interface():
    """Build the speech-recognition tab.

    A microphone recording is passed to ``transcribe`` as a file path, and
    the returned text is shown in a textbox.
    """
    microphone = gr.inputs.Audio(source="microphone", type="filepath")
    return gr.Interface(fn=transcribe, inputs=[microphone], outputs=["textbox"])
        

def get_translate_interface():
    """Build the text-translation tab.

    Each dropdown entry has the form ``"<Language name> (<code>)"``. ``translate``
    reads the language code from the trailing parentheses, so names may contain
    semicolons or extra parenthesised text. Names must not contain commas,
    because the list below is split on commas.
    """
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), 
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), 
    Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), 
    Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    # strip() removes the newlines/indentation of the multi-line literal; empty
    # entries (e.g. from a trailing comma) are skipped.
    lang_list = [lang.strip() for lang in langs.split(",") if lang.strip()]
    return gr.Interface(
        translate,
        inputs=[
            gr.inputs.Dropdown(lang_list, label="Source Language"),
            gr.inputs.Dropdown(lang_list, label="Target Language"),
            "text",
        ],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
        

# Tabs of the app in display order; names[i] is the tab title for interfaces[i].
interfaces = [get_asr_interface(), get_translate_interface()]
names = ["ASR", "translate"]

# Bind to all network interfaces (0.0.0.0) so the app is reachable from
# outside the host/container, not only via localhost.
gr.TabbedInterface(interfaces, names).launch(server_name="0.0.0.0")