"""Gradio demo that denoises audio with a Demucs model and with spectral gating."""
import argparse
import os
import uuid
from pathlib import Path

import ffmpeg
import gradio as gr
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download

from denoisers.demucs import Demucs
from denoisers.SpectralGating import SpectralGating

# NOTE(review): an empty CURL_CA_BUNDLE is presumably here to bypass TLS
# certificate verification for the Hugging Face downloads below. That weakens
# transport security -- confirm it is still required before shipping.
os.environ['CURL_CA_BUNDLE'] = ''

# Sample rate (Hz) every input is resampled to before being denoised.
SAMPLE_RATE = 32000


def _convert_to_wav(audio):
    """Transcode *audio* to a mono 16-bit PCM WAV at ``SAMPLE_RATE``.

    Args:
        audio: Path of the input audio file, in any format ffmpeg can read.

    Returns:
        ``Path`` of the newly written file under ``cache_wav/original/``.
    """
    src_path = Path("cache_wav/original/{}.wav".format(str(uuid.uuid4())))
    src_path.parent.mkdir(exist_ok=True, parents=True)
    (
        ffmpeg.input(audio)
        .output(src_path.as_posix(), acodec='pcm_s16le', ac=1, ar=SAMPLE_RATE)
        .run()
    )
    return src_path


def _denoise_file(src_path, model):
    """Run ``model.predict`` on the WAV at *src_path* and save the result.

    Args:
        src_path: Path of a WAV file readable by ``torchaudio.load``.
        model: Object exposing ``predict(wav)`` that returns the denoised
            waveform.

    Returns:
        ``Path`` of the newly written file under ``cache_wav/denoised/``.
    """
    tgt_path = Path("cache_wav/denoised/{}.wav".format(str(uuid.uuid4())))
    tgt_path.parent.mkdir(exist_ok=True, parents=True)
    wav, rate = torchaudio.load(src_path)
    # Pure inference: there is no reason to record an autograd graph.
    with torch.no_grad():
        reduced_noise = model.predict(wav)
    torchaudio.save(tgt_path, reduced_noise, rate)
    return tgt_path


def denoising_transform(audio, model):
    """Normalise *audio* to WAV, denoise it with *model*, and cache both files.

    Args:
        audio: Path of the input audio file.
        model: Object exposing ``predict(wav)``.

    Returns:
        Tuple ``(src_path, tgt_path)``: the normalised input WAV and the
        denoised WAV, each stored under a fresh UUID file name.
    """
    src_path = _convert_to_wav(audio)
    tgt_path = _denoise_file(src_path, model)
    return src_path, tgt_path


def run_app(model_filename, config_filename, port, concurrency_count, max_size):
    """Download the Demucs checkpoint, build the Gradio UI and serve it.

    Args:
        model_filename: Checkpoint file name inside the
            ``BorisovMaksim/demucs`` Hugging Face repository.
        config_filename: YAML config file name inside the same repository;
            its ``demucs`` section is passed to the ``Demucs`` constructor.
        port: TCP port the Gradio server listens on.
        concurrency_count: Forwarded to ``app.queue``.
        max_size: Forwarded to ``app.queue``.
    """
    model_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=model_filename)
    config_path = hf_hub_download(repo_id="BorisovMaksim/demucs", filename=config_filename)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    model = Demucs(config['demucs'])
    # NOTE(review): torch.load unpickles the downloaded file; only point this
    # at checkpoints from a trusted repository.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    # Switch to inference mode so train-only layers (if any) are disabled.
    model.eval()

    title = "Denoising"

    with gr.Blocks(title=title) as app:
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """
                    # Denoising
                    ## Instruction: \n
                    1. Press "Record from microphone"
                    2. Press "Stop recording"
                    3. Press "Enhance" \n
                    - You can switch to the tab "File" to upload a prerecorded .wav audio instead of recording from microphone.
                    """
                )
                with gr.Tab("Microphone"):
                    microphone = gr.Audio(label="Source Audio", source="microphone", type='filepath')
                    with gr.Row():
                        microphone_button = gr.Button("Enhance", variant="primary")
                with gr.Tab("File"):
                    upload = gr.Audio(label="Upload Audio", source="upload", type='filepath')
                    with gr.Row():
                        upload_button = gr.Button("Enhance", variant="primary")
                clear_btn = gr.Button("Clear")
                # Sorted so the example list has a stable order across runs.
                gr.Examples(
                    examples=[[path] for path in sorted(Path("testing/wavs/").glob("*.wav"))],
                    inputs=[microphone, upload],
                )

            with gr.Column():
                outputs = [
                    gr.Audio(label="Input Audio", type='filepath'),
                    gr.Audio(label="Demucs Enhancement", type='filepath'),
                    gr.Audio(label="Spectral Gating Enhancement", type='filepath'),
                ]

        def submit(audio):
            """Denoise *audio* with both methods and hide the input widgets."""
            if audio is None:
                # Nothing recorded or uploaded yet: report it instead of
                # letting ffmpeg fail on a missing input path.
                raise gr.Error('No audio provided: record or upload a clip before pressing "Enhance".')
            # Transcode once and feed the same WAV to both denoisers.
            src_path = _convert_to_wav(audio)
            demucs_tgt_path = _denoise_file(src_path, model)
            spectral_gating_tgt_path = _denoise_file(src_path, SpectralGating())
            return (
                src_path,
                demucs_tgt_path,
                spectral_gating_tgt_path,
                gr.update(visible=False),
                gr.update(visible=False),
            )

        microphone_button.click(
            submit, microphone, outputs + [microphone, upload]
        )
        upload_button.click(
            submit, upload, outputs + [microphone, upload]
        )

        def restart():
            """Show both input widgets again, emptied, and clear the outputs."""
            return (
                microphone.update(visible=True, value=None),
                upload.update(visible=True, value=None),
                None,
                None,
                None,
            )

        clear_btn.click(restart, inputs=[], outputs=[microphone, upload] + outputs)

    app.queue(concurrency_count=concurrency_count, max_size=max_size)
    app.launch(
        server_name='0.0.0.0',
        server_port=port,
    )


def main():
    """Parse command-line options and start the demo."""
    parser = argparse.ArgumentParser(description='Running demo.')
    parser.add_argument('--port', type=int, default=7860)
    parser.add_argument('--model_filename', type=str,
                        default="paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch45.pt")
    parser.add_argument('--config_filename', type=str,
                        default="paper_replica_10_epoch/config.yaml")
    parser.add_argument('--concurrency_count', type=int, default=4)
    parser.add_argument('--max_size', type=int, default=15)
    args = parser.parse_args()

    run_app(args.model_filename, args.config_filename,
            args.port, args.concurrency_count, args.max_size)


if __name__ == "__main__":
    main()