|
import torch |
|
import torch.nn.functional as F |
|
from torch import nn |
|
|
|
from ...core import AudioSignal |
|
from ...core import STFTParams |
|
from ...core import util |
|
|
|
|
|
class SpectralGate(nn.Module):
    """Spectral gating algorithm for noise reduction,
    as in Audacity/Ocenaudio. The steps are as follows:

    1.  An FFT is calculated over the noise audio clip
    2.  Statistics are calculated over the FFT of the noise
        (in frequency)
    3.  A threshold is calculated based upon the statistics
        of the noise (and the desired sensitivity of the algorithm)
    4.  An FFT is calculated over the signal
    5.  A mask is determined by comparing the signal FFT to the
        threshold
    6.  The mask is smoothed with a filter over frequency and time
    7.  The mask is applied to the FFT of the signal, and is inverted

    Implementation inspired by Tim Sainburg's noisereduce:

    https://timsainburg.com/noise-reduction-python.html

    The smoothing kernel is the outer product of two triangular windows
    and has shape ``(2 * n_freq + 1, 2 * n_time + 1)``; it is normalized
    to sum to one and stored as the ``smoothing_filter`` buffer.

    Parameters
    ----------
    n_freq : int, optional
        Number of frequency bins to smooth by, by default 3
    n_time : int, optional
        Number of time bins to smooth by, by default 5

    Raises
    ------
    ValueError
        If ``n_freq`` or ``n_time`` is negative.
    """

    def __init__(self, n_freq: int = 3, n_time: int = 5):
        super().__init__()

        smoothing_filter = torch.outer(
            self._triangular_window(n_freq),
            self._triangular_window(n_time),
        )
        # Normalize so that smoothing a constant mask leaves it unchanged.
        smoothing_filter = smoothing_filter / smoothing_filter.sum()
        # conv2d expects a (out_channels, in_channels, kH, kW) kernel.
        smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
        self.register_buffer("smoothing_filter", smoothing_filter)

    @staticmethod
    def _triangular_window(n_bins: int) -> torch.Tensor:
        """Build a symmetric triangular window of length ``2 * n_bins + 1``.

        The window rises linearly to 1 at its center and falls back; the
        zero-valued end points of the two ramps are trimmed off.
        """
        if n_bins < 0:
            raise ValueError(
                f"Number of smoothing bins must be non-negative, got {n_bins}"
            )
        ramp = torch.cat(
            [
                torch.linspace(0, 1, n_bins + 2)[:-1],
                torch.linspace(1, 0, n_bins + 2),
            ]
        )
        return ramp[..., 1:-1]

    def forward(
        self,
        audio_signal: AudioSignal,
        nz_signal: AudioSignal,
        denoise_amount: float = 1.0,
        n_std: float = 3.0,
        win_length: int = 2048,
        hop_length: int = 512,
    ):
        """Perform noise reduction.

        Both signals are cloned before any processing, and the denoised
        clone of ``audio_signal`` is returned.

        Parameters
        ----------
        audio_signal : AudioSignal
            Audio signal that noise will be removed from.
        nz_signal : AudioSignal
            Noise signal to compute noise statistics from.
        denoise_amount : float, optional
            Amount to denoise by, by default 1.0. The smoothed mask is
            scaled by this value, so 0 leaves the signal untouched.
        n_std : float, optional
            Number of standard deviations above the mean noise level
            (in dB, per frequency bin) at which the gating threshold is
            placed, by default 3.0
        win_length : int, optional
            Length of window for STFT, by default 2048
        hop_length : int, optional
            Hop length for STFT, by default 512

        Returns
        -------
        AudioSignal
            Denoised audio signal.
        """
        stft_params = STFTParams(win_length, hop_length, "sqrt_hann")

        # Work on copies, and drop any cached STFT so that the magnitudes
        # below are computed with ``stft_params``. The noise clone gets the
        # same treatment as the audio clone so both spectra share one
        # frequency resolution.
        audio_signal = audio_signal.clone()
        audio_signal.stft_data = None
        audio_signal.stft_params = stft_params

        nz_signal = nz_signal.clone()
        nz_signal.stft_data = None
        nz_signal.stft_params = stft_params

        # Noise statistics in dB, reduced over the last axis (one value per
        # frequency bin). The clamp avoids log10(0).
        nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
        nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
        nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)

        nz_thresh = nz_freq_mean + nz_freq_std * n_std

        stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
        nb, nac, nf, nt = stft_db.shape
        # Broadcast the per-frequency threshold across every frame.
        db_thresh = nz_thresh.expand(nb, nac, -1, nt)

        # 1.0 where a bin falls below the noise threshold (to be gated).
        stft_mask = (stft_db < db_thresh).float()
        shape = stft_mask.shape

        # Fold batch and channel together so that a single-channel conv2d
        # smooths each spectrogram independently; "same" padding keeps the
        # frequency/time dimensions unchanged (the kernel sides are odd).
        stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
        pad_tuple = (
            self.smoothing_filter.shape[-2] // 2,
            self.smoothing_filter.shape[-1] // 2,
        )
        stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
        stft_mask = stft_mask.reshape(*shape)
        stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
            audio_signal.device
        )
        # Invert: from "how much to remove" to "how much to keep".
        stft_mask = 1 - stft_mask

        audio_signal.stft_data *= stft_mask
        audio_signal.istft()

        return audio_signal
|
|