File size: 4,126 Bytes
160cee9
412c852
aa756f5
160cee9
ab662d2
922cd73
 
412c852
606f61c
ab662d2
6e3264b
793e132
 
922cd73
 
 
 
 
 
 
412c852
00c9bf5
 
160cee9
 
 
 
 
 
 
 
 
 
 
 
191d30d
922cd73
 
 
 
 
aa756f5
793e132
 
 
412c852
2e8cc61
 
 
 
 
 
 
793e132
2e8cc61
 
 
160cee9
 
 
 
 
 
 
 
 
 
2e8cc61
 
00c9bf5
2e8cc61
 
 
 
00c9bf5
2e8cc61
 
fb40cda
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import re

import gradio as gr
import torch
from pyctcdecode import BeamSearchDecoderCTC
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer


# KenLM 4-gram language model used by the CTC beam-search decoder.
lmID = "aware-ai/german-lowercase-wiki-4gram"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# ASR pipeline (German, lowercase output) plus a text2text model that restores
# punctuation/grammar on the raw transcript.
p = pipeline("automatic-speech-recognition", model="aware-ai/robust-wav2vec2-xls-r-300m-german-lowercase", decoder=decoder)
ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")

# Silero voice-activity-detection model; `utils` is a tuple of helper callables.
vadmodel, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                 model='silero_vad',
                                 force_reload=False)

# Unpack only the helpers we use: VAD timestamp extraction and audio loading.
(get_speech_timestamps,
 _, read_audio,
 *_) = utils

#model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
#tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")

# Lazily-created M2M100 translation model/tokenizer. The eager load above is
# commented out (it would download a 1.2B-parameter model at startup); the
# original function therefore raised NameError on the undefined globals
# `model`/`tokenizer`. We load on first use and cache instead.
_m2m_model = None
_m2m_tokenizer = None


def _load_m2m():
    """Load and cache the facebook/m2m100_1.2B model and tokenizer."""
    global _m2m_model, _m2m_tokenizer
    if _m2m_model is None:
        _m2m_model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
        _m2m_tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
    return _m2m_model, _m2m_tokenizer


def translate(src, tgt, text):
    """Translate `text` from `src` to `tgt`.

    src/tgt: dropdown entries shaped like "German (de)"; the ISO code inside
    the trailing parentheses is extracted.
    Returns the translated string.
    """
    model, tokenizer = _load_m2m()

    # "German (de)" -> "de": last whitespace-separated token, parens stripped.
    src = src.split(" ")[-1][1:-1]
    tgt = tgt.split(" ")[-1][1:-1]

    # Translate: tokenize under the source language, force the target
    # language's BOS token during generation.
    tokenizer.src_lang = src
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_src, forced_bos_token_id=tokenizer.get_lang_id(tgt), use_cache=True)
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]

def transcribe(audio):
    """Transcribe a German recording: VAD-split, ASR per chunk, then punctuate.

    audio: filesystem path to an audio file (Gradio microphone recording).
    Returns a 2-tuple: (raw lowercase transcript, punctuated transcript).
    """
    sampling_rate = 16000
    # Load/resample with the silero helper already unpacked at module level.
    # The original called librosa.load, but librosa is never imported here.
    wav = read_audio(audio, sampling_rate=sampling_rate)

    # The original passed the undefined name `model`; the VAD model is bound
    # to `vadmodel` at module level.
    speech_timestamps = get_speech_timestamps(wav, vadmodel, sampling_rate=sampling_rate)

    # Slice out the voiced segments and run the ASR pipeline over each.
    chunks = [wav[ts["start"]:ts["end"]].numpy() for ts in speech_timestamps]
    transcribed = " ".join(out["text"] for out in p(chunks, chunk_length_s=20, stride_length_s=(0, 0)))

    # Restore punctuation/grammar with the marian text2text model.
    punctuated = ttp(transcribed, max_length=512)[0]["generated_text"]

    return transcribed, punctuated

def get_asr_interface():
    """Build the Gradio interface for microphone-based speech recognition.

    One microphone input (passed to `transcribe` as a file path), two text
    outputs: the raw transcript and the punctuated transcript.
    """
    mic = gr.inputs.Audio(source="microphone", type="filepath")
    return gr.Interface(
        fn=transcribe,
        inputs=[mic],
        outputs=["textbox", "textbox"],
    )
        

def get_translate_interface():
    """Build the Gradio interface for M2M100 translation between 100 languages.

    Source/target languages are chosen from dropdowns of "Name (code)" entries;
    the selections and the input text are forwarded to `translate`.
    """
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greeek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn), 
    Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), 
    Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), 
    Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    # Split the comma-separated catalogue and trim the surrounding whitespace
    # (the literal embeds newlines and indentation).
    lang_list = list(map(str.strip, langs.split(',')))
    src_dropdown = gr.inputs.Dropdown(lang_list, label="Source Language")
    tgt_dropdown = gr.inputs.Dropdown(lang_list, label="Target Language")
    return gr.Interface(
        translate,
        inputs=[src_dropdown, tgt_dropdown, 'text'],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
        

# Tabs exposed by the app. The translation tab is disabled, consistent with
# the commented-out M2M100 model load earlier in the file.
interfaces = [
    get_asr_interface(),
    #get_translate_interface(),
]

# Tab labels, parallel to `interfaces`.
names = [
    "ASR",
    #"translate",
]

# Serve on all network interfaces (required for containerized hosting).
gr.TabbedInterface(interfaces, names).launch(server_name = "0.0.0.0")