"""Gradio demo comparing two speech denoisers side by side.

A trained Demucs checkpoint and its YAML config are downloaded from the
Hugging Face Hub, and two interfaces are served in parallel: one running the
Demucs model and one running a SpectralGating denoiser. Both take microphone
audio and return the denoised recording as a WAV file path.
"""
import os
import uuid
from pathlib import Path

import ffmpeg
import gradio as gr
import torch
import torchaudio
import yaml
from huggingface_hub import hf_hub_download

from denoisers.SpectralGating import SpectralGating
from denoisers.demucs import Demucs

# NOTE(review): an empty CURL_CA_BUNDLE is presumably a workaround for a
# certificate problem during Hub downloads; with `requests` it is a known way
# of effectively disabling TLS verification for the whole process. Confirm it
# is still needed and remove it when possible.
os.environ['CURL_CA_BUNDLE'] = ''

# Hugging Face Hub repository holding the Demucs checkpoints and configs.
HF_REPO_ID = "BorisovMaksim/demucs"
# Sample rate (Hz) every input is resampled to before denoising.
TARGET_SAMPLE_RATE = 22050
# Root directory for per-request intermediate and output WAV files.
CACHE_DIR = Path("cache_wav")


def denoising_transform(audio, model):
    """Denoise one recording with ``model`` and return the output file path.

    The input is first normalized with ffmpeg to mono 16-bit PCM at
    ``TARGET_SAMPLE_RATE`` Hz, and that normalized copy is what the model
    receives.

    Args:
        audio: Filepath of the source recording (the Gradio input component
            is declared with ``type='filepath'``).
        model: Object exposing ``predict(wav)`` that returns a waveform tensor
            accepted by ``torchaudio.save``. Here it is either a ``Demucs`` or
            a ``SpectralGating`` instance.

    Returns:
        ``Path`` of the denoised WAV under ``cache_wav/denoised/``.
    """
    # One id per request, so the original/denoised pair share a file name.
    request_id = uuid.uuid4()
    src_path = CACHE_DIR / "original" / f"{request_id}.wav"
    tgt_path = CACHE_DIR / "denoised" / f"{request_id}.wav"
    src_path.parent.mkdir(exist_ok=True, parents=True)
    tgt_path.parent.mkdir(exist_ok=True, parents=True)

    (ffmpeg.input(audio)
     .output(src_path.as_posix(), acodec='pcm_s16le', ac=1,
             ar=TARGET_SAMPLE_RATE)
     .run())

    # Load the ffmpeg-normalized copy, not the raw input. Previously the
    # conversion result was discarded and the model got the original file at
    # whatever sample rate / channel count it was recorded with.
    wav, rate = torchaudio.load(src_path.as_posix())

    # Pure inference: no autograd graph needed.
    with torch.no_grad():
        reduced_noise = model.predict(wav)

    torchaudio.save(tgt_path.as_posix(), reduced_noise, rate)
    return tgt_path


def run_app(model_filename, config_filename, repo_id=HF_REPO_ID,
            server_name='0.0.0.0', server_port=7860):
    """Download the Demucs model and launch the side-by-side Gradio demo.

    Args:
        model_filename: Checkpoint path inside ``repo_id``. The file must
            contain a ``'model_state_dict'`` entry.
        config_filename: YAML config path inside ``repo_id``. The file must
            contain a ``'demucs'`` section used to build the model.
        repo_id: Hugging Face Hub repository to download both files from.
        server_name: Interface address passed to ``launch``.
        server_port: Port passed to ``launch``.

    This call blocks while the Gradio server is running.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)
    config_path = hf_hub_download(repo_id=repo_id, filename=config_filename)
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    model = Demucs(config['demucs'])
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load checkpoints from repositories you trust.
    checkpoint = torch.load(model_path, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])
    # Inference mode (disables dropout / uses running batch-norm stats, if
    # any). Assumes Demucs is a torch.nn.Module, as load_state_dict implies.
    model.eval()

    interface_demucs = gr.Interface(
        fn=lambda x: denoising_transform(x, model),
        inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
        outputs=gr.Audio(label="Demucs", type='filepath'),
        allow_flagging='never'
    )
    interface_spectral_gating = gr.Interface(
        fn=lambda x: denoising_transform(x, SpectralGating()),
        inputs=gr.Audio(label="Source Audio", source="microphone", type='filepath'),
        outputs=gr.Audio(label="Spectral Gating", type='filepath'),
        allow_flagging='never'
    )

    # Sorted, so the example list has a stable order across runs.
    examples = [[path] for path in sorted(Path("testing/wavs/").glob("*.wav"))]
    gr.Parallel(interface_demucs, interface_spectral_gating,
                title="Denoising",
                examples=examples
                ).launch(server_name=server_name, server_port=server_port)


if __name__ == "__main__":
    model_filename = "paper_replica_10_epoch/Demucs_replicate_paper_continue_epoch29.pt"
    config_filename = "paper_replica_10_epoch/config.yaml"
    run_app(model_filename, config_filename)