ancztxi2 / app.py
Geek7's picture
Update app.py
41eafd2 verified
raw
history blame
931 Bytes
import gradio as gr
import torchaudio
from speechbrain.inference.enhancement import WaveformEnhancement
import torch
# Pretrained SpeechBrain enhancement model, instantiated once at import time
# and shared by every request. `source` is the model repo id and `savedir`
# is the local directory handed to from_hparams for the downloaded files.
_MODEL_SOURCE = "speechbrain/mtl-mimic-voicebank"
_MODEL_SAVEDIR = "pretrained_models/mtl-mimic-voicebank"
enhance_model = WaveformEnhancement.from_hparams(
    source=_MODEL_SOURCE, savedir=_MODEL_SAVEDIR
)
def enhance_audio(input_audio):
    """Enhance an uploaded audio file and return the path of the result.

    Args:
        input_audio: Filesystem path of the audio to enhance, as delivered
            by the ``gr.Audio(type="filepath")`` input, or ``None``/empty
            when the request carried no audio.

    Returns:
        Path of the enhanced WAV file (written at 16 kHz), or ``None`` when
        no input audio was supplied.
    """
    # Nothing uploaded or recorded: return an empty result instead of
    # crashing inside torchaudio.load().
    if not input_audio:
        return None

    # The mtl-mimic-voicebank model is trained on 16 kHz recordings, and
    # enhance_batch() (unlike enhance_file()) does not resample its input,
    # so the waveform has to be brought to that rate here.
    model_sample_rate = 16000

    waveform, sample_rate = torchaudio.load(input_audio)
    if sample_rate != model_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, sample_rate, model_sample_rate
        )

    # torchaudio.load() yields (channels, time); enhance_batch() takes
    # (batch, time), so each channel is enhanced as its own batch item.
    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        enhanced_waveform = enhance_model.enhance_batch(waveform)

    # Save at the rate the enhanced samples actually have.
    output_path = "enhanced_audio.wav"
    torchaudio.save(output_path, enhanced_waveform.cpu(), model_sample_rate)
    return output_path
# Gradio UI: one audio input handed to enhance_audio() as a file path, and
# one audio output fed from the file path that enhance_audio() returns.
noisy_audio_input = gr.Audio(type="filepath")
enhanced_audio_output = gr.Audio(type="filepath")
demo = gr.Interface(
    fn=enhance_audio,
    inputs=noisy_audio_input,
    outputs=enhanced_audio_output,
)
# Start the web server for the interface.
demo.launch()