# Source: Hugging Face Space "german-asr", file app.py
# Uploaded by flozi00 -- commit 922cd73 ("Update app.py"), 4.13 kB
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
import gradio as gr
import re
import torch
from pyctcdecode import BeamSearchDecoderCTC
import torch
# CTC beam-search decoder loaded from the Hub repo named by lmID
# (the repo name suggests a lowercase German 4-gram language model).
lmID = "aware-ai/german-lowercase-wiki-4gram"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# Speech-recognition pipeline that decodes with the LM-backed decoder above.
p = pipeline(
    "automatic-speech-recognition",
    model="aware-ai/robust-wav2vec2-xls-r-300m-german-lowercase",
    decoder=decoder,
)

# Text-to-text pipeline used to post-process the raw transcript
# (grammar / punctuation model, judging by the checkpoint name).
ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")

# Silero voice-activity-detection model plus its helper functions.
vadmodel, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    force_reload=False,
)

# Only the first and third helpers are needed; the rest are discarded.
get_speech_timestamps, _, read_audio, *_ = utils
#model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
#tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
# Checkpoint used for translation; loaded lazily so that start-up stays cheap
# when the translate tab is disabled.
_M2M100_CHECKPOINT = "facebook/m2m100_1.2B"
_m2m100_cache = {}


def _load_m2m100():
    """Return the ``(model, tokenizer)`` pair for M2M100, loading it on first use."""
    if not _m2m100_cache:
        _m2m100_cache["tokenizer"] = M2M100Tokenizer.from_pretrained(_M2M100_CHECKPOINT)
        _m2m100_cache["model"] = M2M100ForConditionalGeneration.from_pretrained(
            _M2M100_CHECKPOINT
        )
    return _m2m100_cache["model"], _m2m100_cache["tokenizer"]


def translate(src, tgt, text):
    """Translate ``text`` from the ``src`` language to the ``tgt`` language.

    Args:
        src: Source-language label whose last space-separated token is the
            language code in parentheses, e.g. ``"German (de)"``.
        tgt: Target-language label in the same format.
        text: The text to translate.

    Returns:
        The translated text as a string.
    """
    # Labels look like "Name (code)": take the last token and drop the
    # surrounding parentheses to obtain the bare language code.
    src_code = src.split(" ")[-1][1:-1]
    tgt_code = tgt.split(" ")[-1][1:-1]

    # The module-level model/tokenizer definitions are commented out, so
    # obtain them here instead of relying on undefined globals.
    model, tokenizer = _load_m2m100()

    tokenizer.src_lang = src_code
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_src,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_code),
        use_cache=True,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def transcribe(audio):
    """Transcribe a recorded audio file and produce a corrected version.

    Args:
        audio: Path to the recorded audio file, or ``None`` when nothing was
            recorded.

    Returns:
        A ``(transcribed, punctuated)`` tuple: the raw transcript joined from
        all detected speech segments, and the output of the text-to-text
        correction pipeline for that transcript. Both are empty strings when
        there is no recording or no speech is detected.
    """
    if audio is None:
        # Nothing was recorded; there is nothing to load or transcribe.
        return "", ""

    sampling_rate = 16000

    # Load with the silero-vad helper already unpacked at module level
    # (librosa is not imported in this file). It is expected to return a
    # mono 1-D torch tensor at the requested sampling rate.
    waveform = read_audio(audio, sampling_rate=sampling_rate)

    # The VAD model is bound to `vadmodel`; no global named `model` exists.
    speech_timestamps = get_speech_timestamps(
        waveform, vadmodel, sampling_rate=sampling_rate
    )
    if not speech_timestamps:
        # No speech found: skip both pipelines instead of feeding them
        # an empty batch.
        return "", ""

    # Cut out each speech segment and hand it to the ASR pipeline as a
    # NumPy array.
    chunks = [
        waveform[segment["start"]:segment["end"]].numpy()
        for segment in speech_timestamps
    ]
    results = p(chunks, chunk_length_s=20, stride_length_s=(0, 0))
    transcribed = " ".join(result["text"] for result in results)

    punctuated = ttp(transcribed, max_length=512)[0]["generated_text"]
    return transcribed, punctuated
def get_asr_interface():
    """Build the Gradio interface for the speech-recognition tab.

    The interface records from the microphone, passes the recording to
    ``transcribe`` as a file path, and shows its two results in two text
    boxes.
    """
    microphone = gr.inputs.Audio(source="microphone", type="filepath")
    result_boxes = ["textbox", "textbox"]
    return gr.Interface(fn=transcribe, inputs=[microphone], outputs=result_boxes)
def get_translate_interface():
    """Build the Gradio interface for the translation tab.

    The interface offers two dropdowns (source and target language) filled
    with ``"Name (code)"`` labels, plus a free-text input, and passes the
    three values to ``translate``.
    """
    # Comma-separated "Name (code)" labels; surrounding whitespace and line
    # breaks are stripped below. ("Greeek" was corrected to "Greek".)
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    lang_list = [lang.strip() for lang in langs.split(',')]

    source_dropdown = gr.inputs.Dropdown(lang_list, label="Source Language")
    target_dropdown = gr.inputs.Dropdown(lang_list, label="Target Language")
    return gr.Interface(
        translate,
        inputs=[source_dropdown, target_dropdown, 'text'],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
# Tabs shown in the demo, as (tab name, interface) pairs so that a name can
# never drift out of line with its interface. The translate tab is disabled.
tabs = [
    ("ASR", get_asr_interface()),
    # ("translate", get_translate_interface()),
]
interfaces = [interface for _name, interface in tabs]
names = [name for name, _interface in tabs]

# Listen on all network interfaces so the app is reachable from outside
# the host/container.
gr.TabbedInterface(interfaces, names).launch(server_name="0.0.0.0")