import os
import tempfile
import sys
import subprocess
import gradio as gr
import numpy as np
import soundfile as sf
import librosa
import torch
import gc
import json
import datetime
from pathlib import Path
# Check if required packages are installed, if not install them
try:
    from espnet2.bin.s2t_inference import Speech2Text
    import torchaudio

    # Try importing espnet_model_zoo specifically
    try:
        import espnet_model_zoo
        print("All packages already installed.")
    except ModuleNotFoundError:
        print("Installing espnet_model_zoo. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "espnet_model_zoo"])
        import espnet_model_zoo
        print("espnet_model_zoo installed successfully.")

    # Check for opencc-python-reimplemented
    try:
        from opencc import OpenCC
        print("OpenCC already installed.")
    except ModuleNotFoundError:
        print("Installing opencc-python-reimplemented. This may take a moment...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "opencc-python-reimplemented"])
        from opencc import OpenCC
        print("OpenCC installed successfully.")
except ModuleNotFoundError as e:
    missing_module = str(e).split("'")[1]
    print(f"Installing missing module: {missing_module}")
    if missing_module == "espnet2":
        print("Installing ESPnet. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "espnet"])
    elif missing_module == "torchaudio":
        print("Installing torchaudio. This may take a few minutes...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", "torchaudio"])

    # Try importing again
    try:
        from espnet2.bin.s2t_inference import Speech2Text
        import torchaudio

        # Also check for espnet_model_zoo
        try:
            import espnet_model_zoo
        except ModuleNotFoundError:
            print("Installing espnet_model_zoo. This may take a few minutes...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "-U", "espnet_model_zoo"])
            import espnet_model_zoo

        # Also check for OpenCC
        try:
            from opencc import OpenCC
        except ModuleNotFoundError:
            print("Installing opencc-python-reimplemented. This may take a moment...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", "opencc-python-reimplemented"])
            from opencc import OpenCC
        print("All required packages installed successfully.")
    except ModuleNotFoundError as e:
        print(f"Failed to install {str(e).split('No module named ')[1]}. Please install manually.")
        raise
# Initialize the speech-to-text model
def load_model():
    # Force garbage collection before loading
    gc.collect()

    # Check if CUDA is available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    if device == "cuda":
        # These calls require a CUDA device, so only run them on GPU
        torch.cuda.empty_cache()
        torch.cuda.set_per_process_memory_fraction(0.95)  # Use up to 95% of GPU memory

    # For more memory savings, you could try loading with 8-bit quantization,
    # which requires the bitsandbytes library: pip install bitsandbytes
    model = Speech2Text.from_pretrained(
        "espnet/owls_4B_180K",
        task_sym="<asr>",
        beam_size=1,
        device=device
    )
    return model
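
# Optional, untested sketch: retry on CPU if the 4B model does not fit in GPU
# memory. It reuses load_model() and the same from_pretrained() call as above;
# treat it as an illustration rather than a supported code path.
def load_model_with_cpu_fallback():
    try:
        return load_model()
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        print("CUDA out of memory while loading; falling back to CPU (slow).")
        gc.collect()
        torch.cuda.empty_cache()
        return Speech2Text.from_pretrained(
            "espnet/owls_4B_180K",
            task_sym="<asr>",
            beam_size=1,
            device="cpu"
        )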
# Load the model once at startup
print("Loading multilingual model...")
model = load_model()
print("Model loaded successfully!")
def transcribe_audio(audio_file, language):
    """Process the audio file and return the transcription."""
    if audio_file is None:
        return "Please upload an audio file or record audio."

    # If audio is a (sample rate, samples) tuple from a microphone recording,
    # write it to a temporary WAV file first
    if isinstance(audio_file, tuple):
        sr, audio_data = audio_file
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
            temp_path = temp_audio.name
            sf.write(temp_path, audio_data, sr)
        audio_file = temp_path

    # Load and resample the audio file to 16 kHz
    speech, _ = librosa.load(audio_file, sr=16000)

    # Reset beam-search state between calls and set the language symbol
    model.beam_search.hyps = None
    model.beam_search.pre_beam_score_key = None
    if language is not None:
        model.lang_sym = language

    # Perform ASR; the first element of the top hypothesis is the text
    text, *_ = model(speech)[0]

    # Clean up the temporary file if one was created
    if isinstance(audio_file, str) and audio_file.startswith(tempfile.gettempdir()):
        os.unlink(audio_file)

    return text
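
# Optional helper (illustrative sketch): batch-transcribe every .wav file in a
# directory by reusing transcribe_audio() above. The directory layout and the
# default language symbol are assumptions made for this example.
def transcribe_directory(audio_dir, language="<eng>"):
    results = {}
    for wav_path in sorted(Path(audio_dir).glob("*.wav")):
        results[wav_path.name] = transcribe_audio(str(wav_path), language)
    return results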
# Function to handle English transcription
def transcribe_english(audio_file):
    return transcribe_audio(audio_file, "<eng>")
# Function to handle Chinese transcription
def transcribe_chinese(audio_file, chinese_variant="Traditional"):
    """
    Process the audio file and return a Chinese transcription in Simplified
    or Traditional characters.

    Args:
        audio_file: Path to the audio file
        chinese_variant: Either "Simplified" or "Traditional"
    """
    # First get the base transcription
    asr_text = transcribe_audio(audio_file, "<zho>")

    # Convert between Simplified and Traditional Chinese if needed
    if chinese_variant == "Traditional":
        # s2twp: Simplified to Traditional (Taiwan standard, with phrase mapping),
        # which gives a more complete conversion than plain character mapping
        cc = OpenCC('s2twp')
        asr_text = cc.convert(asr_text)
    elif chinese_variant == "Simplified" and not asr_text.isascii():
        # If the text contains non-ASCII characters it might be Traditional,
        # so convert Traditional to Simplified just to be safe
        cc = OpenCC('t2s')  # t2s: Traditional to Simplified
        asr_text = cc.convert(asr_text)

    return asr_text
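
# A quick illustration of the OpenCC directions used above (example strings,
# not output from this app):
#   OpenCC('s2twp').convert("软件")  # -> "軟體": Simplified -> Taiwan Traditional, with phrase mapping
#   OpenCC('t2s').convert("軟體")    # -> "软体": character-level Traditional -> Simplified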
# Function to handle Japanese transcription
def transcribe_japanese(audio_file):
    return transcribe_audio(audio_file, "<jpn>")

# Function to handle Korean transcription
def transcribe_korean(audio_file):
    return transcribe_audio(audio_file, "<kor>")

# Function to handle Thai transcription
def transcribe_thai(audio_file):
    return transcribe_audio(audio_file, "<tha>")

# Function to handle Italian transcription
def transcribe_italian(audio_file):
    return transcribe_audio(audio_file, "<ita>")

# Function to handle German transcription
def transcribe_german(audio_file):
    return transcribe_audio(audio_file, "<deu>")
# Create a function to save feedback
def save_feedback(transcription, rating, language, audio_path=None):
    """Save user feedback to a JSON file."""
    # Create feedback directory if it doesn't exist
    feedback_dir = Path("feedback_data")
    feedback_dir.mkdir(exist_ok=True)

    # Create a unique filename based on timestamp
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    feedback_file = feedback_dir / f"feedback_{timestamp}.json"

    # Prepare feedback data
    feedback_data = {
        "timestamp": timestamp,
        "language": language,
        "transcription": transcription,
        "rating": rating,
        "audio_path": str(audio_path) if audio_path else None
    }

    # Save to JSON file
    with open(feedback_file, "w", encoding="utf-8") as f:
        json.dump(feedback_data, f, ensure_ascii=False, indent=2)

    return "🪂 Thank you for your feedback!"
# Create the Gradio interface with tabs
with gr.Blocks(title="NVIDIA Research Multilingual Demo") as demo:
    gr.Markdown("# NVIDIA Research Multilingual Demo")
    gr.Markdown("Upload or record audio to transcribe speech in up to 150 languages using the NVIDIA Research (NVR) 4B model. Audio is automatically resampled to 16 kHz.")
    gr.Markdown("You can either 🎙️ record with your microphone or 💻 upload an audio file in the tabs next to Microphone Recording. Files are deleted after the demo ends.")

    with gr.Tabs():
        with gr.TabItem("Microphone Recording"):
            language_mic = gr.Radio(
                ["English", "Mandarin", "Japanese", "Korean", "Thai", "Italian", "German"],
                label="Select Language",
                value="English"
            )

            # Chinese variant selection that appears only when Mandarin is selected
            chinese_variant_mic = gr.Radio(
                ["Traditional", "Simplified"],
                label="Mandarin User Desired Output ➡️ zh-cn: Simplified or zh-tw: Traditional",
                value="Traditional",
                visible=False
            )

            # Make the Chinese variant selection visible only when Mandarin is selected
            def update_chinese_variant_visibility(lang):
                return gr.update(visible=(lang == "Mandarin"))

            language_mic.change(
                fn=update_chinese_variant_visibility,
                inputs=language_mic,
                outputs=chinese_variant_mic
            )

            with gr.Row():
                with gr.Column():
                    mic_input = gr.Audio(sources=["microphone"], type="filepath", label="Record Audio")
                    mic_button = gr.Button("Transcribe Recording")
                with gr.Column():
                    mic_output = gr.Textbox(label="Transcription")
                    # Add feedback components
                    with gr.Row():
                        mic_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3,
                                               label="Rate the transcription quality (1=worst, 5=best)")
                    mic_feedback_btn = gr.Button("Submit Feedback")
                    mic_feedback_msg = gr.Textbox(label="Feedback Status", visible=True)

            def transcribe_mic(audio, lang, chinese_variant=None):
                lang_map = {
                    "English": "<eng>",
                    "Mandarin": "<zho>",
                    "Japanese": "<jpn>",
                    "Korean": "<kor>",
                    "Thai": "<tha>",
                    "Italian": "<ita>",
                    "German": "<deu>"
                }
                # Special handling for Chinese with variant selection
                if lang == "Mandarin" and chinese_variant:
                    return transcribe_chinese(audio, chinese_variant)
                return transcribe_audio(audio, lang_map.get(lang, "<eng>"))

            mic_button.click(fn=transcribe_mic, inputs=[mic_input, language_mic, chinese_variant_mic], outputs=mic_output)

            # Feedback submission; only record the Chinese variant for Mandarin
            def submit_mic_feedback(transcription, rating, language, chinese_variant):
                label = f"{language} ({chinese_variant})" if language == "Mandarin" else language
                return save_feedback(transcription, rating, label, None)

            mic_feedback_btn.click(
                fn=submit_mic_feedback,
                inputs=[mic_output, mic_rating, language_mic, chinese_variant_mic],
                outputs=mic_feedback_msg
            )
with gr.TabItem("English"): | |
with gr.Row(): | |
with gr.Column(): | |
en_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
en_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
en_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
en_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
en_feedback_btn = gr.Button("Submit Feedback") | |
en_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_en_sample_48k.wav"): | |
gr.Examples( | |
examples=[["wav_en_sample_48k.wav"]], | |
inputs=en_input | |
) | |
en_button.click(fn=transcribe_english, inputs=en_input, outputs=en_output) | |
# Add feedback submission | |
def submit_en_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "English", audio_path) | |
en_feedback_btn.click( | |
fn=submit_en_feedback, | |
inputs=[en_output, en_rating, en_input], | |
outputs=en_feedback_msg | |
) | |
with gr.TabItem("Mandarin"): | |
# Add Chinese variant selection | |
chinese_variant = gr.Radio( | |
["Traditional", "Simplified"], | |
label="Mandarin User Desired Output ➡️ zh-cn: Simplified or zh-tw: Traditional", | |
value="Traditional" | |
) | |
with gr.Row(): | |
with gr.Column(): | |
zh_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
zh_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
zh_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
zh_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
zh_feedback_btn = gr.Button("Submit Feedback") | |
zh_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_zh_tw_sample_16k.wav"): | |
gr.Examples( | |
examples=[["wav_zh_tw_sample_16k.wav"]], | |
inputs=zh_input | |
) | |
# Update the click function to include the Chinese variant | |
def transcribe_chinese_with_variant(audio_file, variant): | |
return transcribe_chinese(audio_file, variant.lower()) | |
zh_button.click(fn=transcribe_chinese_with_variant, inputs=[zh_input, chinese_variant], outputs=zh_output) | |
# Update feedback submission to include variant | |
def submit_zh_feedback(transcription, rating, audio_path, variant): | |
return save_feedback(transcription, rating, f"Mandarin ({variant})", audio_path) | |
zh_feedback_btn.click( | |
fn=submit_zh_feedback, | |
inputs=[zh_output, zh_rating, zh_input, chinese_variant], | |
outputs=zh_feedback_msg | |
) | |
with gr.TabItem("Japanese"): | |
with gr.Row(): | |
with gr.Column(): | |
jp_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
jp_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
jp_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
jp_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
jp_feedback_btn = gr.Button("Submit Feedback") | |
jp_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_jp_sample_48k.wav"): | |
gr.Examples( | |
examples=[["wav_jp_sample_48k.wav"]], | |
inputs=jp_input | |
) | |
jp_button.click(fn=transcribe_japanese, inputs=jp_input, outputs=jp_output) | |
# Add feedback submission | |
def submit_jp_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "Japanese", audio_path) | |
jp_feedback_btn.click( | |
fn=submit_jp_feedback, | |
inputs=[jp_output, jp_rating, jp_input], | |
outputs=jp_feedback_msg | |
) | |
with gr.TabItem("Korean"): | |
with gr.Row(): | |
with gr.Column(): | |
kr_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
kr_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
kr_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
kr_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
kr_feedback_btn = gr.Button("Submit Feedback") | |
kr_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_kr_sample_48k.wav"): | |
gr.Examples( | |
examples=[["wav_kr_sample_48k.wav"]], | |
inputs=kr_input | |
) | |
kr_button.click(fn=transcribe_korean, inputs=kr_input, outputs=kr_output) | |
# Add feedback submission | |
def submit_kr_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "Korean", audio_path) | |
kr_feedback_btn.click( | |
fn=submit_kr_feedback, | |
inputs=[kr_output, kr_rating, kr_input], | |
outputs=kr_feedback_msg | |
) | |
with gr.TabItem("Thai"): | |
with gr.Row(): | |
with gr.Column(): | |
th_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
th_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
th_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
th_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
th_feedback_btn = gr.Button("Submit Feedback") | |
th_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_thai_sample.wav"): | |
gr.Examples( | |
examples=[["wav_thai_sample.wav"]], | |
inputs=th_input | |
) | |
th_button.click(fn=transcribe_thai, inputs=th_input, outputs=th_output) | |
# Add feedback submission | |
def submit_th_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "Thai", audio_path) | |
th_feedback_btn.click( | |
fn=submit_th_feedback, | |
inputs=[th_output, th_rating, th_input], | |
outputs=th_feedback_msg | |
) | |
with gr.TabItem("Italian"): | |
with gr.Row(): | |
with gr.Column(): | |
it_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
it_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
it_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
it_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
it_feedback_btn = gr.Button("Submit Feedback") | |
it_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_it_sample.wav"): | |
gr.Examples( | |
examples=[["wav_it_sample.wav"]], | |
inputs=it_input | |
) | |
it_button.click(fn=transcribe_italian, inputs=it_input, outputs=it_output) | |
# Add feedback submission | |
def submit_it_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "Italian", audio_path) | |
it_feedback_btn.click( | |
fn=submit_it_feedback, | |
inputs=[it_output, it_rating, it_input], | |
outputs=it_feedback_msg | |
) | |
with gr.TabItem("German"): | |
with gr.Row(): | |
with gr.Column(): | |
de_input = gr.Audio(sources=["upload"], type="filepath", label="Upload Audio") | |
de_button = gr.Button("Transcribe Speech") | |
with gr.Column(): | |
de_output = gr.Textbox(label="Speech Transcription") | |
# Add feedback components | |
with gr.Row(): | |
de_rating = gr.Slider(minimum=1, maximum=5, step=1, value=3, | |
label="Rate the transcription quality (1=worst, 5=best)") | |
de_feedback_btn = gr.Button("Submit Feedback") | |
de_feedback_msg = gr.Textbox(label="Feedback Status", visible=True) | |
# Add example if the file exists | |
if os.path.exists("wav_de_sample.wav"): | |
gr.Examples( | |
examples=[["wav_de_sample.wav"]], | |
inputs=de_input | |
) | |
de_button.click(fn=transcribe_german, inputs=de_input, outputs=de_output) | |
# Add feedback submission | |
def submit_de_feedback(transcription, rating, audio_path): | |
return save_feedback(transcription, rating, "German", audio_path) | |
de_feedback_btn.click( | |
fn=submit_de_feedback, | |
inputs=[de_output, de_rating, de_input], | |
outputs=de_feedback_msg | |
) | |
# Launch the app with Hugging Face Spaces compatible settings
if __name__ == "__main__":
    demo.launch(share=False)