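# Cascaded speech-to-speech translation (STST) demo: transcribe source speech
# with Whisper, identify the source language, translate the transcript into
# Spanish with a Helsinki-NLP OPUS-MT model, and synthesise Spanish speech
# with Meta's MMS VITS model. Served as a Gradio app.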
import gradio as gr
import numpy as np
import torch
from transformers import pipeline, VitsModel, AutoTokenizer
device = "cuda:0" if torch.cuda.is_available() else "cpu"
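# OPUS-MT checkpoints keyed by ISO 639-1 source-language code; each model
# translates from the detected language into Spanish.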
translation_models = {
"en": "Helsinki-NLP/opus-mt-en-es", # Inglés a Español
"fr": "Helsinki-NLP/opus-mt-fr-es", # Francés a Español
"de": "Helsinki-NLP/opus-mt-de-es", # Alemán a Español
"it": "Helsinki-NLP/opus-mt-it-es", # Italiano a Español
"pt": "Helsinki-NLP/opus-mt-pt-es", # Portugués a Español
"nl": "Helsinki-NLP/opus-mt-nl-es", # Neerlandés (Holandés) a Español
"fi": "Helsinki-NLP/opus-mt-fi-es", # Finés a Español
"sv": "Helsinki-NLP/opus-mt-sv-es", # Sueco a Español
"da": "Helsinki-NLP/opus-mt-da-es", # Danés a Español
"no": "Helsinki-NLP/opus-mt-no-es", # Noruego a Español
"ru": "Helsinki-NLP/opus-mt-ru-es", # Ruso a Español
"pl": "Helsinki-NLP/opus-mt-pl-es", # Polaco a Español
"cs": "Helsinki-NLP/opus-mt-cs-es", # Checo a Español
"tr": "Helsinki-NLP/opus-mt-tr-es", # Turco a Español
"zh": "Helsinki-NLP/opus-mt-zh-es", # Chino a Español
"ja": "Helsinki-NLP/opus-mt-ja-es", # Japonés a Español
"ar": "Helsinki-NLP/opus-mt-ar-es", # Árabe a Español
"ro": "Helsinki-NLP/opus-mt-ro-es", # Rumano a Español
"el": "Helsinki-NLP/opus-mt-el-es", # Griego a Español
"bg": "Helsinki-NLP/opus-mt-bg-es", # Búlgaro a Español
"uk": "Helsinki-NLP/opus-mt-uk-es", # Ucraniano a Español
"he": "Helsinki-NLP/opus-mt-he-es", # Hebreo a Español
"lt": "Helsinki-NLP/opus-mt-lt-es", # Lituano a Español
"et": "Helsinki-NLP/opus-mt-et-es", # Estonio a Español
"hr": "Helsinki-NLP/opus-mt-hr-es", # Croata a Español
"hu": "Helsinki-NLP/opus-mt-hu-es", # Húngaro a Español
"lv": "Helsinki-NLP/opus-mt-lv-es", # Letón a Español
"sl": "Helsinki-NLP/opus-mt-sl-es", # Esloveno a Español
"sk": "Helsinki-NLP/opus-mt-sk-es", # Eslovaco a Español
"sr": "Helsinki-NLP/opus-mt-sr-es", # Serbio a Español
"fa": "Helsinki-NLP/opus-mt-fa-es", # Persa a Español
}
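# Models used by the cascade: Whisper for speech recognition, MMS VITS for
# Spanish text-to-speech, and an XLM-RoBERTa classifier for language ID.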
asr_pipe = pipeline("automatic-speech-recognition", model="openai/whisper-base", device=device)
vits_model = VitsModel.from_pretrained("facebook/mms-tts-spa")
vits_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-spa")
lang_detector = pipeline("text-classification", model="papluca/xlm-roberta-base-language-detection")
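# Return the ISO 639-1 code of the most likely language for `text`.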
def language_detector(text):
    resultado = lang_detector(text)
    idioma_detectado = resultado[0]['label']
    return idioma_detectado
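# Transcribe the audio, detect the transcript's language, and translate it
# into Spanish. When no translation model exists for the detected language,
# fall back to returning the raw transcription.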
def translate(audio):
    transcribe = asr_pipe(audio, max_new_tokens=256)
    codigo_idioma = language_detector(transcribe['text'])
    if codigo_idioma in translation_models:
        translator = pipeline("translation", model=translation_models[codigo_idioma])
        traduccion = translator(transcribe['text'])
    else:
        transcribe = transcribe['text']
        print(f"No translation model is available for the detected language {codigo_idioma}")
        return transcribe
    return traduccion
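# Note: translate() builds a fresh translation pipeline on every call, which is
# slow. A minimal caching sketch (not part of the original app; the names are
# illustrative) could reuse loaded pipelines across requests:
#
#     _translator_cache = {}
#
#     def get_translator(codigo_idioma):
#         if codigo_idioma not in _translator_cache:
#             _translator_cache[codigo_idioma] = pipeline(
#                 "translation", model=translation_models[codigo_idioma]
#             )
#         return _translator_cache[codigo_idioma]

# Convert the translated (or fallback-transcribed) text to Spanish speech.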
def synthesise(text):
    if isinstance(text, list):
        text = text[0]['translation_text']
    print(text)
    inputs = vits_tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        output = vits_model(**inputs).waveform[0]
    return output
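# Full cascade: source audio -> Spanish text -> Spanish speech. Gradio's numpy
# audio output expects a (sample_rate, int16 array) tuple, hence the scaling to
# the int16 range. MMS TTS generates 16 kHz waveforms; instead of hardcoding
# the rate, it could also be read off the model (a sketch, same value for this
# checkpoint):
#
#     sampling_rate = vits_model.config.sampling_rate  # 16000 for mms-tts-spa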
def speech_to_speech_translation(audio):
    translated_text = translate(audio)
    synthesised_speech = synthesise(translated_text)
    synthesised_speech = (synthesised_speech.numpy() * 32767).astype(np.int16)
    return 16000, synthesised_speech
title = "Cascaded STST"
description = """
Demo for cascaded speech-to-speech translation (STST), mapping from source speech in any language to target speech in Spanish.
![Cascaded STST](https://huggingface.co/datasets/huggingface-course/audio-course-images/resolve/main/s2st_cascaded.png "Diagram of cascaded speech to speech translation")
"""
demo = gr.Blocks()
mic_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="microphone", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    title=title,
    description=description,
)
file_translate = gr.Interface(
    fn=speech_to_speech_translation,
    inputs=gr.Audio(sources="upload", type="filepath"),
    outputs=gr.Audio(label="Generated Speech", type="numpy"),
    examples=[["./example.wav"]],
    title=title,
    description=description,
)
with demo:
    gr.TabbedInterface([mic_translate, file_translate], ["Microphone", "Audio File"])

demo.launch()