import gc
import os
import subprocess
import sys
import tempfile

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
# Check that required packages are installed; install them if not
try:
    from espnet2.bin.s2t_inference import Speech2Text
    import torchaudio

    # Try importing espnet_model_zoo specifically
    try:
        import espnet_model_zoo
        print("All packages already installed.")
    except ModuleNotFoundError:
        print("Installing espnet_model_zoo. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "espnet_model_zoo"])
        import espnet_model_zoo
        print("espnet_model_zoo installed successfully.")
except ModuleNotFoundError as e:
    missing_module = str(e).split("'")[1]
    print(f"Installing missing module: {missing_module}")
    if missing_module == "espnet2":
        print("Installing ESPnet. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "espnet"])
    elif missing_module == "torchaudio":
        print("Installing torchaudio. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "torchaudio"])

    # Try importing again
    try:
        from espnet2.bin.s2t_inference import Speech2Text
        import torchaudio

        # Also check for espnet_model_zoo
        try:
            import espnet_model_zoo
        except ModuleNotFoundError:
            print("Installing espnet_model_zoo. This may take a few minutes...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "espnet_model_zoo"])
            import espnet_model_zoo
        print("All required packages installed successfully.")
    except ModuleNotFoundError as e:
        print(f"Failed to install {str(e).split('No module named ')[1]}. Please install manually.")
        raise
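
# Note: on Hugging Face Spaces, the more typical approach is to pin espnet,
# espnet_model_zoo, and torchaudio in requirements.txt so they are installed
# at build time; the runtime pip bootstrap above is a fallback for local runs.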
# Initialize the multilingual model
def load_model():
    # Force garbage collection and release any cached GPU memory
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # Memory-efficient option: cap this process at 95% of GPU memory
        torch.cuda.set_per_process_memory_fraction(0.95)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # For further memory savings you could try 8-bit quantization
    # (requires the bitsandbytes library: pip install bitsandbytes)
    model = Speech2Text.from_pretrained(
        "espnet/owls_4B_180K",
        task_sym="<asr>",
        beam_size=1,
        device=device,
    )
    return model
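
# A hedged sketch of the 8-bit idea mentioned above, using PyTorch's built-in
# dynamic quantization rather than bitsandbytes. Assumptions: the underlying
# torch.nn.Module is exposed as model.s2t_model (the attribute name may differ
# across ESPnet versions), and dynamic quantization runs on CPU only:
#
# model.s2t_model = torch.ao.quantization.quantize_dynamic(
#     model.s2t_model, {torch.nn.Linear}, dtype=torch.qint8
# )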
# Load the model once at startup; the language symbol is set per request
print("Loading multilingual model...")
model = load_model()
print("Model loaded successfully!")
def transcribe_audio(audio_file, language):
    """Transcribe an uploaded file or microphone recording in the given language."""
    if audio_file is None:
        return "Please upload an audio file or record audio."

    # Microphone recordings may arrive as a (sample_rate, numpy_array) tuple
    if isinstance(audio_file, tuple):
        sr, audio_data = audio_file
        # Save the recording to a temporary WAV file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_path = temp_audio.name
            sf.write(temp_path, audio_data, sr)
        audio_file = temp_path

    # Load and resample the audio to 16 kHz, as expected by the model
    speech, _ = librosa.load(audio_file, sr=16000)

    # Reset beam-search state and set the language symbol if one was given
    model.beam_search.hyps = None
    model.beam_search.pre_beam_score_key = None
    if language is not None:
        model.lang_sym = language

    # Perform ASR and take the text of the best hypothesis
    text, *_ = model(speech)[0]

    # Clean up the temporary file if we created one
    if isinstance(audio_file, str) and audio_file.startswith(tempfile.gettempdir()):
        os.unlink(audio_file)

    return text
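
# Example usage outside Gradio (the file name is only illustrative; point it
# at any local audio clip):
#     print(transcribe_audio("sample.wav", "<eng>"))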
# Per-language wrappers used by the Gradio tabs below
def transcribe_english(audio_file):
    return transcribe_audio(audio_file, "<eng>")

def transcribe_chinese(audio_file):
    return transcribe_audio(audio_file, "<zho>")

def transcribe_japanese(audio_file):
    return transcribe_audio(audio_file, "<jpn>")

def transcribe_korean(audio_file):
    return transcribe_audio(audio_file, "<kor>")

def transcribe_thai(audio_file):
    return transcribe_audio(audio_file, "<tha>")

def transcribe_italian(audio_file):
    return transcribe_audio(audio_file, "<ita>")

def transcribe_german(audio_file):
    return transcribe_audio(audio_file, "<deu>")
# Create the Gradio interface with tabs
demo = gr.Blocks(title="NVIDIA Research Multilingual Demo")

with demo:
    gr.Markdown("# NVIDIA Research Multilingual Demo")
    gr.Markdown(
        "Upload or record audio to transcribe up to 150 human languages using the "
        "NVIDIA Research (NVR) 4B model (espnet/owls_4B_180K). "
        "Audio is automatically resampled to 16 kHz."
    )

    with gr.Tabs():
        with gr.TabItem("Microphone Recording"):
            language_mic = gr.Radio(
                ["English", "Mandarin", "Japanese", "Korean", "Thai", "Italian", "German"],
                label="Select Language",
                value="English",
            )
            with gr.Row():
                with gr.Column():
                    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
                    mic_button = gr.Button("Transcribe Recording")
                with gr.Column():
                    mic_output = gr.Textbox(label="Transcription")

            def transcribe_mic(audio, lang):
                # Map the radio label to the model's language symbol;
                # keys must match the gr.Radio choices above
                lang_map = {
                    "English": "<eng>",
                    "Mandarin": "<zho>",
                    "Japanese": "<jpn>",
                    "Korean": "<kor>",
                    "Thai": "<tha>",
                    "Italian": "<ita>",
                    "German": "<deu>",
                }
                return transcribe_audio(audio, lang_map.get(lang, "<eng>"))

            mic_button.click(fn=transcribe_mic, inputs=[mic_input, language_mic], outputs=mic_output)
        # Upload tabs: one per language, built from a shared template
        upload_tabs = [
            ("English", transcribe_english, "wav_en_sample_48k.wav"),
            ("Mandarin", transcribe_chinese, "wav_zh_tw_sample_16k.wav"),
            ("Japanese", transcribe_japanese, "wav_jp_sample_48k.wav"),
            ("Korean", transcribe_korean, "wav_kr_sample_48k.wav"),
            ("Thai", transcribe_thai, "wav_thai_sample.wav"),
            ("Italian", transcribe_italian, "wav_it_sample.wav"),
            ("German", transcribe_german, "wav_de_sample.wav"),
        ]
        for tab_name, transcribe_fn, example_file in upload_tabs:
            with gr.TabItem(tab_name):
                with gr.Row():
                    with gr.Column():
                        audio_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio")
                        transcribe_button = gr.Button("Transcribe Speech")
                    with gr.Column():
                        transcription_output = gr.Textbox(label="Speech Transcription")
                # Show the bundled example clip if it exists
                if os.path.exists(example_file):
                    gr.Examples(examples=[[example_file]], inputs=audio_input)
                transcribe_button.click(fn=transcribe_fn, inputs=audio_input, outputs=transcription_output)
# Launch the app with Hugging Face Spaces compatible settings
if __name__ == "__main__":
    demo.launch(share=False)
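
# When running locally, demo.launch(share=True) would create a temporary
# public URL via Gradio's tunnel; on Spaces the app is already hosted, so
# share=False is the right setting.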