Spaces:
Runtime error
Runtime error
File size: 4,126 Bytes
160cee9 412c852 aa756f5 160cee9 ab662d2 922cd73 412c852 606f61c ab662d2 6e3264b 793e132 922cd73 412c852 00c9bf5 160cee9 191d30d 922cd73 aa756f5 793e132 412c852 2e8cc61 793e132 2e8cc61 160cee9 2e8cc61 00c9bf5 2e8cc61 00c9bf5 2e8cc61 fb40cda |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
from transformers import pipeline, M2M100ForConditionalGeneration, M2M100Tokenizer
import gradio as gr
import re
import torch
from pyctcdecode import BeamSearchDecoderCTC
import torch
# --- Module-level model setup (runs once at import time; downloads weights) ---

# KenLM n-gram language model used for beam-search decoding of CTC output.
lmID = "aware-ai/german-lowercase-wiki-4gram"
decoder = BeamSearchDecoderCTC.load_from_hf_hub(lmID)

# German ASR pipeline (wav2vec2 XLS-R) decoded with the LM-backed beam search.
p = pipeline("automatic-speech-recognition", model="aware-ai/robust-wav2vec2-xls-r-300m-german-lowercase", decoder=decoder)

# text2text model that restores punctuation/grammar on the raw transcript.
ttp = pipeline("text2text-generation", model="aware-ai/marian-german-grammar")

# Silero voice-activity-detection model plus its helper utilities from torch.hub.
vadmodel, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad',
                                 model='silero_vad',
                                 force_reload=False)
# Unpack only the helpers this script uses: speech-segment detection and
# the 16 kHz audio loader.
(get_speech_timestamps,
 _, read_audio,
 *_) = utils

# NOTE(review): translate() below reads the globals `model` and `tokenizer`;
# while these two lines stay commented out, calling translate() raises NameError.
#model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_1.2B")
#tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_1.2B")
def translate(src, tgt, text):
    """Translate *text* from *src* to *tgt* with the M2M100 model.

    src/tgt are dropdown labels of the form "German (de)"; the two-letter
    language code is extracted from the trailing "(xx)" token.

    NOTE(review): the globals `model` and `tokenizer` used here are commented
    out at module level, so calling this function currently raises NameError.
    Returns the first decoded translation string.
    """
    # "German (de)" -> last space-separated token "(de)" -> drop parens -> "de"
    src = src.split(" ")[-1][1:-1]
    tgt = tgt.split(" ")[-1][1:-1]
    # translate
    tokenizer.src_lang = src
    encoded_src = tokenizer(text, return_tensors="pt")
    generated_tokens = model.generate(**encoded_src, forced_bos_token_id=tokenizer.get_lang_id(tgt), use_cache=True)
    result = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
    return result
def transcribe(audio):
    """Transcribe a German audio file and return (raw, punctuated) text.

    Pipeline: Silero VAD finds speech segments, the wav2vec2 ASR pipeline
    transcribes each voiced chunk, and the text2text pipeline restores
    punctuation/grammar.

    audio: filesystem path to the recording (Gradio passes a filepath).
    Returns a (transcribed, punctuated) tuple of strings.
    """
    sampling_rate = 16000
    # FIX: was `librosa.load(...)`, but librosa is never imported in this file.
    # Use the Silero `read_audio` helper (unpacked from `utils` above), which
    # loads the file resampled to 16 kHz as a 1-D tensor.
    audio = read_audio(audio, sampling_rate=sampling_rate)
    # FIX: was `get_speech_timestamps(audio, model, ...)`, but `model` is
    # undefined (the VAD model is bound to `vadmodel`; the M2M100 `model`
    # assignment is commented out at module level).
    speech_timestamps = get_speech_timestamps(audio, vadmodel, sampling_rate=sampling_rate)
    # Keep only the voiced regions; convert each slice to numpy for the
    # Hugging Face ASR pipeline.
    chunks = [audio[i["start"]:i["end"]].numpy() for i in speech_timestamps]
    transcribed = " ".join([text["text"] for text in p(chunks, chunk_length_s=20, stride_length_s=(0, 0))])
    punctuated = ttp(transcribed, max_length=512)[0]["generated_text"]
    return transcribed, punctuated
def get_asr_interface():
    """Build the Gradio interface for microphone-based German ASR.

    Records audio from the microphone (delivered as a filepath) and shows
    two text boxes: the raw transcript and the punctuated version.
    """
    mic_input = gr.inputs.Audio(source="microphone", type="filepath")
    return gr.Interface(
        fn=transcribe,
        inputs=[mic_input],
        outputs=["textbox", "textbox"],
    )
def get_translate_interface():
    """Build the Gradio interface for M2M100 translation.

    NOTE(review): this interface is commented out of the app below, and the
    `translate` function it wraps depends on a model/tokenizer that are not
    loaded — enable both before re-adding this tab.
    """
    # "Name (code)" labels for every language M2M100 supports, comma-separated.
    # The string's continuation lines stay unindented on purpose: they are part
    # of the literal, and each entry is .strip()-ed after splitting anyway.
    langs = """Afrikaans (af), Amharic (am), Arabic (ar), Asturian (ast), Azerbaijani (az), Bashkir (ba), Belarusian (be), Bulgarian (bg), Bengali (bn), Breton (br), Bosnian (bs), Catalan; Valencian (ca), Cebuano (ceb), Czech (cs), Welsh (cy), Danish (da), German (de), Greeek (el), English (en), Spanish (es), Estonian (et), Persian (fa), Fulah (ff), Finnish (fi), French (fr), Western Frisian (fy), Irish (ga), Gaelic; Scottish Gaelic (gd), Galician (gl), Gujarati (gu), Hausa (ha), Hebrew (he), Hindi (hi), Croatian (hr), Haitian; Haitian Creole (ht), Hungarian (hu), Armenian (hy), Indonesian (id), Igbo (ig), Iloko (ilo), Icelandic (is), Italian (it), Japanese (ja), Javanese (jv), Georgian (ka), Kazakh (kk), Central Khmer (km), Kannada (kn),
Korean (ko), Luxembourgish; Letzeburgesch (lb), Ganda (lg), Lingala (ln), Lao (lo), Lithuanian (lt), Latvian (lv), Malagasy (mg), Macedonian (mk), Malayalam (ml), Mongolian (mn), Marathi (mr), Malay (ms), Burmese (my), Nepali (ne), Dutch; Flemish (nl), Norwegian (no), Northern Sotho (ns), Occitan (post 1500) (oc), Oriya (or), Panjabi; Punjabi (pa), Polish (pl), Pushto; Pashto (ps), Portuguese (pt), Romanian; Moldavian; Moldovan (ro), Russian (ru), Sindhi (sd), Sinhala; Sinhalese (si), Slovak (sk),
Slovenian (sl), Somali (so), Albanian (sq), Serbian (sr), Swati (ss), Sundanese (su), Swedish (sv), Swahili (sw), Tamil (ta), Thai (th), Tagalog (tl), Tswana (tn),
Turkish (tr), Ukrainian (uk), Urdu (ur), Uzbek (uz), Vietnamese (vi), Wolof (wo), Xhosa (xh), Yiddish (yi), Yoruba (yo), Chinese (zh), Zulu (zu)"""
    # One dropdown entry per label; strip() removes surrounding spaces/newlines.
    lang_list = [lang.strip() for lang in langs.split(',')]
    return gr.Interface(translate, inputs=[gr.inputs.Dropdown(lang_list, label="Source Language"), gr.inputs.Dropdown(lang_list, label="Target Language"), 'text'], outputs=gr.outputs.Textbox(), title="Translate Between 100 languages")
# --- App assembly and launch ---
# The translation tab is disabled because its M2M100 model/tokenizer are not
# loaded (see the commented-out lines near the top of the file).
interfaces = [
    get_asr_interface(),
    #get_translate_interface(),
]
names = [
    "ASR",
    #"translate",
]
# Bind to all network interfaces so the app is reachable inside a container.
# FIX: removed the stray trailing "|" artifact, which made this line a
# syntax error.
gr.TabbedInterface(interfaces, names).launch(server_name="0.0.0.0")