Spaces:
Runtime error
Runtime error
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer | |
import gradio as gr | |
import re | |
import torch | |
from pyctcdecode import BeamSearchDecoderCTC | |
import torch | |
# --- Model setup (runs once at import time; downloads weights on first use) ---

# Hub repository holding the language model used by the beam-search CTC decoder.
lmID = "aware-ai/german-lowercase-wiki-4gram"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# Speech-recognition pipeline; the decoder loaded above is handed to it.
p = pipeline(
    "automatic-speech-recognition",
    model="aware-ai/robust-wav2vec2-xls-r-300m-german-lowercase",
    decoder=decoder,
)

# Text-to-text pipeline that transcribe() applies to the raw transcript.
ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")

# Silero voice-activity-detection model plus its helper functions.
vadmodel, utils = torch.hub.load(
    repo_or_dir="snakers4/silero-vad",
    model="silero_vad",
    force_reload=False,
)

# Only the first and third helpers are kept; the remaining ones are discarded.
get_speech_timestamps, _, read_audio, *_ = utils
#model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B") | |
#tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B") | |
# Checkpoint used for translation. It is loaded lazily because it is large and
# the translation tab may be disabled, in which case it is never needed.
_M2M100_CHECKPOINT = "facebook/m2m100_1.2B"
_m2m100_cache = {}


def _load_m2m100():
    """Return the ``(tokenizer, model)`` pair for M2M100, loading it on first use."""
    if not _m2m100_cache:
        # Load both parts before publishing them so that a failed download
        # never leaves a half-populated cache behind.
        tokenizer = M2M100Tokenizer.from_pretrained(_M2M100_CHECKPOINT)
        model = M2M100ForConditionalGeneration.from_pretrained(_M2M100_CHECKPOINT)
        _m2m100_cache["tokenizer"] = tokenizer
        _m2m100_cache["model"] = model
    return _m2m100_cache["tokenizer"], _m2m100_cache["model"]


def translate(src, tgt, text):
    """Translate ``text`` from the source language to the target language.

    Args:
        src: Source-language label of the form ``"<name> (<code>)"``,
            e.g. ``"German (de)"``.
        tgt: Target-language label in the same form.
        text: The text to translate.

    Returns:
        The translated text as a string.
    """
    # The previous version read module-level ``model``/``tokenizer`` globals
    # that are not defined in this file, so every call raised NameError.
    tokenizer, model = _load_m2m100()

    # The language code is the last space-separated token with its
    # surrounding parentheses stripped: "German (de)" -> "de".
    src_code = src.split(" ")[-1][1:-1]
    tgt_code = tgt.split(" ")[-1][1:-1]

    tokenizer.src_lang = src_code
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_src,
        forced_bos_token_id=tokenizer.get_lang_id(tgt_code),
        use_cache=True,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
def transcribe(audio):
    """Transcribe a recorded audio file.

    Args:
        audio: Path to the audio file to transcribe. May be ``None`` or empty
            when nothing was recorded.

    Returns:
        A ``(transcribed, punctuated)`` tuple of strings: the raw output of
        the speech-recognition pipeline ``p`` and the result of passing that
        text through the text-to-text pipeline ``ttp``. Both are empty when
        there is no audio or no speech was detected.
    """
    # This function has always called librosa.load, but the module never
    # imports librosa; import it here so the call can resolve.
    import librosa

    if not audio:
        return "", ""

    sampling_rate = 16000
    waveform, _ = librosa.load(audio, sr=sampling_rate)

    # Use the Silero VAD network (``vadmodel``). The previous code passed an
    # undefined ``model`` name here.
    speech_timestamps = get_speech_timestamps(
        waveform, vadmodel, sampling_rate=sampling_rate
    )

    # Keep only the detected speech regions; start/end are used as sample
    # indices into the waveform.
    chunks = [waveform[ts["start"]:ts["end"]] for ts in speech_timestamps]
    if not chunks:
        # Nothing to recognise; skip both pipelines.
        return "", ""

    results = p(chunks, chunk_length_s=20, stride_length_s=(0, 0))
    transcribed = " ".join(result["text"] for result in results)
    punctuated = ttp(transcribed, max_length=512)[0]["generated_text"]
    return transcribed, punctuated
def get_asr_interface():
    """Build the speech-recognition interface: microphone in, two text boxes out."""
    microphone = gr.inputs.Audio(source="microphone", type="filepath")
    # One text box per element of the pair returned by transcribe().
    text_outputs = ["textbox", "textbox"]
    return gr.Interface(fn=transcribe, inputs=[microphone], outputs=text_outputs)
def get_translate_interface():
    """Build the translation interface.

    The interface has a source-language dropdown, a target-language dropdown
    and a text input, and shows the result of ``translate`` in a text box.
    """
    # Each label has the form "<name> (<code>)". Entries are comma-separated;
    # whitespace and line breaks around them are stripped below.
    # "Greek" was previously misspelled "Greeek" in the dropdown.
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az),
    Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs),
    Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de),
    Greek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff),
    Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd),
    Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr),
    Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig),
    Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka),
    Kazakh (kk), Central Khmer (km), Kannada (kn), Korean (ko),
    Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt),
    Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr),
    Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no),
    Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa),
    Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro),
    Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk), Slovenian (sl),
    Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv),
    Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn), Turkish (tr),
    Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh),
    Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    lang_list = [lang.strip() for lang in langs.split(",")]
    return gr.Interface(
        translate,
        inputs=[
            gr.inputs.Dropdown(lang_list, label="Source Language"),
            gr.inputs.Dropdown(lang_list, label="Target Language"),
            "text",
        ],
        outputs=gr.outputs.Textbox(),
        title="Translate Between 100 languages",
    )
# Tabs of the app. Each interface is paired by position with its label in
# `names`. The translation tab is currently disabled; to enable it, add
# get_translate_interface() here and "translate" to `names`.
interfaces = [get_asr_interface()]
names = ["ASR"]

# Bind to all network interfaces rather than only localhost.
gr.TabbedInterface(interfaces, names).launch(server_name="0.0.0.0")