import math
import tempfile

import gradio
import gradio.inputs
import gradio.outputs
import matplotlib.pyplot as plt
import numpy as np
import torch

from df import config
from df.enhance import enhance, init_df, load_audio, save_audio

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, df, _ = init_df()
model = model.to(device=device).eval()


def mix_at_snr(clean, noise, snr, eps=1e-10):
    """Mix clean and noise signals at a given SNR.

    Args:
        clean: Tensor of shape [C, T] with the clean signal to mix.
        noise: Tensor of shape [C, T] with the noise signal.
        snr: Signal-to-noise ratio in dB.

    Returns:
        clean: 1 x T Tensor, averaged over the input channels.
        noise: 1 x T Tensor with gain adjusted according to the SNR.
        mix: 1 x T Tensor with the added clean and noise signals.
    """
    # Average multi-channel inputs down to mono.
    clean = torch.as_tensor(clean).mean(0, keepdim=True)
    noise = torch.as_tensor(noise).mean(0, keepdim=True)
    # Tile the noise until it is at least as long as the clean signal.
    if noise.shape[1] < clean.shape[1]:
        noise = noise.repeat((1, int(math.ceil(clean.shape[1] / noise.shape[1]))))
    # Cut a random noise segment of the same length as the clean signal.
    # Guard against max_start == 0, which torch.randint rejects.
    max_start = int(noise.shape[1] - clean.shape[1])
    start = torch.randint(0, max_start, ()).item() if max_start > 0 else 0
    noise = noise[:, start : start + clean.shape[1]]
    # Scale the noise so that 10 * log10(E_speech / E_noise) == snr.
    E_speech = torch.mean(clean.pow(2)) + eps
    E_noise = torch.mean(noise.pow(2))
    K = torch.sqrt((E_noise / E_speech) * 10 ** (snr / 10) + eps)
    noise = noise / K
    mixture = clean + noise
    assert torch.isfinite(mixture).all()
    return clean, noise, mixture


def mix_and_denoise(speech, speech_alt, noise, snr):
    # Fall back to the bundled samples if no recording/upload was provided.
    if noise is None:
        noise = "samples/dkitchen.wav"
    if speech is None or speech == "":
        speech = "samples/p232_013_clean.wav"
    # An uploaded speech sample takes precedence over the microphone input.
    if speech_alt is not None:
        speech = speech_alt
    sr = config("sr", 48000, int, section="df")
    speech, _ = load_audio(speech, sr)
    noise, _ = load_audio(noise, sr)
    speech, noise, noisy = mix_at_snr(speech, noise, snr)
    enhanced = enhance(model, df, noisy)
    # Apply a 150 ms linear fade-in to the enhanced signal.
    lim = torch.linspace(0.0, 1.0, int(sr * 0.15)).unsqueeze(0)
    lim = torch.cat((lim, torch.ones(1, enhanced.shape[1] - lim.shape[1])), dim=1)
    enhanced = enhanced * lim
    noisy_fn = tempfile.NamedTemporaryFile(suffix="noisy.wav", delete=False).name
    save_audio(noisy_fn, noisy, sr)
    enhanced_fn = tempfile.NamedTemporaryFile(suffix="enhanced.wav", delete=False).name
    save_audio(enhanced_fn, enhanced, sr)
    # Return the paths of the saved temporary files, not the bare suffixes.
    return (
        noisy_fn,
        spec_figure(noisy, sr=sr),
        enhanced_fn,
        spec_figure(enhanced, sr=sr),
    )


def specshow(
    spec,
    ax=None,
    title=None,
    xlabel=None,
    ylabel=None,
    sr=48000,
    n_fft=None,
    hop=None,
    t=None,
    f=None,
    vmin=-100,
    vmax=0,
    xlim=None,
    ylim=None,
    cmap="viridis",
):
    """Plot a spectrogram of shape [F, T]."""
    spec_np = spec.cpu().numpy() if isinstance(spec, torch.Tensor) else spec
    # Plot either on a provided Axes or via the pyplot state machine.
    if ax is not None:
        set_title = ax.set_title
        set_xlabel = ax.set_xlabel
        set_ylabel = ax.set_ylabel
        set_xlim = ax.set_xlim
        set_ylim = ax.set_ylim
    else:
        ax = plt
        set_title = plt.title
        set_xlabel = plt.xlabel
        set_ylabel = plt.ylabel
        set_xlim = plt.xlim
        set_ylim = plt.ylim
    # Infer the FFT size from the number of frequency bins if not given.
    if n_fft is None:
        if spec.shape[0] % 2 == 0:
            n_fft = spec.shape[0] * 2
        else:
            n_fft = (spec.shape[0] - 1) * 2
    hop = hop or n_fft // 4
    # Time axis in seconds, frequency axis in kHz.
    if t is None:
        t = np.arange(0, spec_np.shape[-1]) * hop / sr
    if f is None:
        f = np.arange(0, spec_np.shape[0]) * sr // 2 / (n_fft // 2) / 1000
    im = ax.pcolormesh(
        t, f, spec_np, rasterized=True, shading="auto", vmin=vmin, vmax=vmax, cmap=cmap
    )
    if title is not None:
        set_title(title)
    if xlabel is not None:
        set_xlabel(xlabel)
    if ylabel is not None:
        set_ylabel(ylabel)
    if xlim is not None:
        set_xlim(xlim)
    if ylim is not None:
        set_ylim(ylim)
    return im
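
# A minimal sanity check for mix_at_snr defined above (a sketch; the random
# tensors and the 10 dB target are arbitrary choices, not part of the demo):
# the SNR measured on the returned signals should match the requested value.
_clean = torch.randn(1, 48_000)
_noise = torch.randn(1, 96_000)
_c, _n, _ = mix_at_snr(_clean, _noise, snr=10)
_measured = 10 * torch.log10(_c.pow(2).mean() / _n.pow(2).mean()).item()
assert abs(_measured - 10.0) < 0.1, _measured
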
def spec_figure(
    audio: torch.Tensor,
    figsize=(15, 5),
    colorbar=False,
    colorbar_format=None,
    figure=None,
    return_im=False,
    labels=True,
    **kwargs,
) -> plt.Figure:
    audio = torch.as_tensor(audio)
    if labels:
        kwargs.setdefault("xlabel", "Time [s]")
        kwargs.setdefault("ylabel", "Frequency [kHz]")  # specshow plots f in kHz
    n_fft = kwargs.setdefault("n_fft", 1024)
    hop = kwargs.setdefault("hop", 512)
    # Compute a windowed STFT and convert it to a log-magnitude spectrogram in dB.
    w = torch.hann_window(n_fft, device=audio.device)
    spec = torch.stft(audio, n_fft, hop, window=w, return_complex=False)
    spec = spec.div_(w.pow(2).sum())
    spec = torch.view_as_complex(spec).abs().clamp_min(1e-12).log10().mul(10)
    kwargs.setdefault("vmax", max(0.0, spec.max().item()))
    if figure is None:
        figure = plt.figure(figsize=figsize)
        figure.set_tight_layout(True)
    if spec.dim() > 2:
        spec = spec.squeeze(0)
    im = specshow(spec, **kwargs)
    if colorbar:
        ckwargs = {}
        if "ax" in kwargs:
            if colorbar_format is None:
                if (
                    kwargs.get("vmin", None) is not None
                    or kwargs.get("vmax", None) is not None
                ):
                    colorbar_format = "%+2.0f dB"
            ckwargs = {"ax": kwargs["ax"]}
        plt.colorbar(im, format=colorbar_format, **ckwargs)
    if return_im:
        return im
    return figure


inputs = [
    gradio.inputs.Audio(
        source="microphone",
        type="filepath",
        optional=True,
        label="Record your own voice",
    ),
    gradio.inputs.Audio(
        source="upload",
        type="filepath",
        optional=True,
        label="Alternative: Upload speech sample",
    ),
    gradio.inputs.Audio(
        source="upload", type="filepath", optional=True, label="Upload noise sample"
    ),
    gradio.inputs.Slider(minimum=-20, maximum=40, step=5, default=10, label="SNR [dB]"),
]
examples = [
    [
        "samples/p232_013_clean.wav",
        "samples/p232_013_clean.wav",
        "samples/dkitchen.wav",
        10,
    ],
    [
        "samples/p232_013_clean.wav",
        "samples/p232_019_clean.wav",
        "samples/dliving.wav",
        10,
    ],
]
outputs = [
    gradio.outputs.Audio(label="Noisy"),
    gradio.outputs.Image(type="plot"),
    gradio.outputs.Audio(label="Enhanced"),
    gradio.outputs.Image(type="plot"),
]
description = (
    "This demo denoises audio files using DeepFilterNet. Try it with your own voice!"
)
iface = gradio.Interface(
    fn=mix_and_denoise,
    title="DeepFilterNet Demo",
    inputs=inputs,
    outputs=outputs,
    examples=examples,
    description=description,
    layout="horizontal",
    allow_flagging="never",
)
iface.launch(cache_examples=False)
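
# A quick standalone check of spec_figure (a sketch; one second of random
# noise at 48 kHz stands in for real audio). The figure is closed right away
# so it does not interfere with the Gradio plots above.
#
#   _fig = spec_figure(torch.randn(1, 48_000), sr=48_000)
#   plt.close(_fig)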
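
# To run this demo locally (a sketch, assuming the DeepFilterNet Python
# package providing the `df` module and a pre-3.0 gradio release are
# installed, and that the files referenced under samples/ exist next to this
# script; the script name app.py is an assumption):
#
#   python app.py
#
# Gradio prints the local URL (http://127.0.0.1:7860 by default) once the
# interface is up.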