Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
""" | |
https://github.com/xiph/rnnoise | |
https://github.com/xiph/rnnoise/blob/main/torch/rnnoise/rnnoise.py | |
https://arxiv.org/abs/1709.08243 | |
""" | |
import os | |
from typing import Optional, Union, Tuple | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier | |
from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig | |
from toolbox.torchaudio.configuration_utils import CONFIG_FILE | |
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT | |
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands | |
# Schedule shared by the three GRUSparsifier instances built in RNNoise.__init__;
# the values are passed as their start / stop / interval / exponent arguments.
# NOTE(review): units are presumably training steps (RNNoise.sparsify() calls
# step() on each sparsifier) — confirm against GRUSparsifier.
sparsify_start = 6000
sparsify_stop = 20000
sparsify_interval = 100
sparsify_exponent = 3

# Per-matrix sparsification settings handed to GRUSparsifier for every GRU layer.
# NOTE(review): keys presumably name the GRU gate matrices (W_h* hidden-to-hidden,
# W_i* input-to-hidden, for the r/z/n gates) and each value is presumably
# (target density, block shape, keep-diagonal flag) — confirm in
# toolbox.torch.sparsification.gru_sparsifier.
sparse_params1 = dict(
    W_hr=(0.3, [8, 4], True),
    W_hz=(0.2, [8, 4], True),
    W_hn=(0.5, [8, 4], True),
    W_ir=(0.3, [8, 4], False),
    W_iz=(0.2, [8, 4], False),
    W_in=(0.5, [8, 4], False),
)
def init_weights(module):
    """Re-initialise the recurrent weights of a GRU orthogonally.

    Intended for ``nn.Module.apply``: modules that are not ``nn.GRU`` are left
    untouched, and within a GRU only parameters whose name starts with
    ``"weight_hh_"`` (the hidden-to-hidden matrices) are reset.
    """
    if not isinstance(module, nn.GRU):
        return
    for param_name, param in module.named_parameters():
        if param_name.startswith("weight_hh_"):
            nn.init.orthogonal_(param)
class RNNoise(nn.Module):
    """RNNoise-style denoiser that predicts per-ERB-band gains.

    Pipeline (see ``forward``): STFT -> ERB-band features (``erb_scale`` with
    ``db=True``) -> two kernel-3 ``Conv1d`` layers -> three stacked GRUs ->
    ``Linear`` + sigmoid gains per ERB band -> gains mapped back to STFT bins
    (``erb_scale_inv``) -> applied to the noisy magnitude, noisy phase kept ->
    inverse STFT.

    Reference: https://github.com/xiph/rnnoise (torch/rnnoise/rnnoise.py).
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",
                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,
                 conv_size: int = 128,
                 gru_size: int = 256,
                 ):
        super(RNNoise, self).__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.conv_size = conv_size
        self.gru_size = gru_size

        self.input_dim = nfft // 2 + 1
        self.eps = 1e-12

        self.erb_bands = ErbBands(
            sample_rate=self.sample_rate,
            nfft=self.nfft,
            erb_bins=self.erb_bins,
            min_freq_bins_for_erb=self.min_freq_bins_for_erb,
        )

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            requires_grad=False
        )

        # The two "valid" kernel-3 convolutions below shorten the time axis by
        # 4 frames; forward() pads 2 zero frames on each side so the conv stack
        # returns one output per STFT frame, each seeing 2 past + 2 future frames.
        self.pad = nn.ConstantPad1d(padding=(2, 2), value=0)
        self.conv1 = nn.Conv1d(self.erb_bins, conv_size, kernel_size=3, padding="valid")
        self.conv2 = nn.Conv1d(conv_size, gru_size, kernel_size=3, padding="valid")

        self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)

        # Input is the conv features concatenated with all three GRU outputs
        # (see forward_gru), hence 4 * gru_size.
        self.dense_out = nn.Linear(4 * self.gru_size, self.erb_bins)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"model: {nb_params} weights")

        self.apply(init_weights)

        # One sparsifier per GRU layer, all sharing the module-level schedule.
        # NOTE(review): sparsify() is presumably called once per training step
        # by the trainer — confirm.
        self.sparsifier = [
            GRUSparsifier(
                task_list=[(gru, sparse_params1)],
                start=sparsify_start,
                stop=sparsify_stop,
                interval=sparsify_interval,
                exponent=sparsify_exponent,
            )
            for gru in (self.gru1, self.gru2, self.gru3)
        ]

    def sparsify(self):
        """Advance every GRU sparsifier by one ``step()``."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """Reshape ``signal`` to (b, 1, n) and zero-pad it on the right so the
        STFT frames tile it exactly.

        After padding, ``n >= win_size`` and ``(n - win_size) % hop_size == 0``,
        so the signal holds ``(n - win_size) // hop_size + 1`` whole frames.

        :param signal: waveform of shape (n,), (b, n) or (b, 1, n).
        :return: zero-padded waveform of shape (b, 1, n_padded).
        :raises ValueError: if ``signal`` has any other rank.
        """
        if signal.dim() == 1:
            signal = signal[None, None, :]
        elif signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        if signal.dim() != 3:
            raise ValueError(f"expected a 1-D, 2-D or 3-D signal, got shape {tuple(signal.shape)}")

        n_samples = signal.shape[-1]
        if n_samples < self.win_size:
            # Too short for even one frame: pad up to exactly one window.
            n_samples_pad = self.win_size - n_samples
        else:
            remainder = (n_samples - self.win_size) % self.hop_size
            n_samples_pad = self.hop_size - remainder if remainder > 0 else 0

        if n_samples_pad > 0:
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def forward(self,
                noisy: torch.Tensor,
                states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                ):
        """Denoise a whole utterance at once.

        :param noisy: waveform of shape (n,), (b, n) or (b, 1, n).
        :param states: hidden states for (gru1, gru2, gru3), or None to start fresh.
        :return: ``(denoise, mask, states)`` — denoised waveform trimmed to the
            input length, shape (b, 1, num_samples); gain mask of shape (b, f, t);
            and the list of final GRU hidden states.
        """
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)

        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
        # shape: (b, f, t), t = (num_samples_pad - win_size) / hop_size + 1
        mag_noisy_t_erb = self._erb_features(mag_noisy)
        # shape: (b, erb_bins, t)
        mag_noisy_t_erb = self.pad(mag_noisy_t_erb)
        # shape: (b, erb_bins, t + 4) — consumed by the two valid convolutions
        mask, states = self._estimate_mask(mag_noisy_t_erb, states)
        # mask shape: (b, f, t)

        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
        denoise = self.istft.forward(stft_denoise)
        # denoise shape: [b, 1, num_samples_pad]
        denoise = denoise[:, :, :num_samples]
        # denoise shape: [b, 1, num_samples]
        return denoise, mask, states

    def forward_conv(self, mag_noisy: torch.Tensor):
        """Apply the two tanh-activated kernel-3 valid convolutions.

        :param mag_noisy: ERB features, shape (b, erb_bins, t_in).
        :return: features of shape (b, gru_size, t_in - 4).
        """
        tmp = torch.tanh(self.conv1(mag_noisy))
        # tmp shape: (b, conv_size, t_in - 2)
        tmp = torch.tanh(self.conv2(tmp))
        # tmp shape: (b, gru_size, t_in - 4)
        return tmp

    def forward_gru(self,
                    mag_noisy: torch.Tensor,
                    states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                    ):
        """Run the three stacked GRUs over the conv features.

        :param mag_noisy: conv features, shape (b, gru_size, t).
        :param states: indexable (gru1, gru2, gru3) hidden states, or None so
            each ``nn.GRU`` starts from its default zero state.
        :return: ``(gru_out, new_states)`` — ``gru_out`` of shape
            (b, t, 4 * gru_size), the conv features concatenated with each GRU's
            output; ``new_states`` is a list of the three final hidden states.
        """
        if states is None:
            gru1_state = gru2_state = gru3_state = None
        else:
            gru1_state, gru2_state, gru3_state = states[0], states[1], states[2]

        tmp = mag_noisy.permute(0, 2, 1)
        # tmp shape: (b, t, gru_size) — the GRUs are batch_first
        gru1_out, gru1_state = self.gru1(tmp, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        gru3_out, gru3_state = self.gru3(gru2_out, gru3_state)
        new_states = [gru1_state, gru2_state, gru3_state]

        gru_out = torch.cat(tensors=[tmp, gru1_out, gru2_out, gru3_out], dim=-1)
        return gru_out, new_states

    def _erb_features(self, mag_noisy: torch.Tensor) -> torch.Tensor:
        """Map a (b, f, t) magnitude spectrogram to (b, erb_bins, t) ERB features
        via ``erb_scale(..., db=True)``."""
        mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2)
        # shape: (b, t, f)
        mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True)
        # shape: (b, t, erb_bins)
        return torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2)

    def _estimate_mask(self, erb_features: torch.Tensor, states):
        """Predict a per-STFT-bin gain mask from ERB features that already carry
        2 frames of context on each side (the convolutions drop 4 frames).

        :return: ``(mask, states)`` with ``mask`` of shape (b, f, t_in - 4).
        """
        features = self.forward_conv(erb_features)
        gru_out, states = self.forward_gru(features, states)
        mask_erb = torch.sigmoid(self.dense_out(gru_out))
        # mask_erb shape: (b, t, erb_bins)
        mask = self.erb_bands.erb_scale_inv(mask_erb)
        # mask shape: (b, t, f)
        return torch.transpose(mask, dim0=1, dim1=2), states

    def forward_chunk_by_chunk(self,
                               noisy: torch.Tensor,
                               ):
        """Denoise ``noisy`` one STFT frame at a time, emulating streaming.

        ERB features are kept in a sliding window of 5 frames (2 past, current,
        2 future), which is the context the two valid convolutions need. A frame
        is therefore enhanced ``lookahead = 2`` frames after it arrives. The
        window is seeded with 2 zero frames and, at the end, flushed with 2 zero
        frames, mirroring the (2, 2) zero padding in ``forward``, so every frame
        (including the last two) is enhanced.

        :param noisy: waveform of shape (n,), (b, n) or (b, 1, n).
        :return: waveform of shape (b, 1, n_out), the concatenation of the chunks
            returned by ``self.istft.forward_chunk``.
        """
        noisy = self.signal_prepare(noisy)
        batch_size, _, num_samples = noisy.shape
        num_frames = (num_samples - self.win_size) // self.hop_size + 1

        # Created on the input's device so the final concatenation also works on GPU.
        waveform = torch.zeros(size=(batch_size, 1, 0), dtype=torch.float32, device=noisy.device)
        if num_frames <= 0:
            return waveform

        lookahead = 2
        context_frames = 2 * lookahead + 1

        states = None
        waveform_cache = None
        coff_cache = None
        zero_frame = None
        cache_list = list()
        chunks = list()
        # The extra `lookahead` iterations flush the window with zero frames.
        for i in range(num_frames + lookahead):
            if i < num_frames:
                begin = i * self.hop_size
                sub_noisy = noisy[:, :, begin:begin + self.win_size]
                mag_noisy, pha_noisy = self.mag_pha_stft(sub_noisy)
                mag_noisy_t_erb = self._erb_features(mag_noisy)
                # mag_noisy_t_erb shape: (b, erb_bins, 1)
                if zero_frame is None:
                    zero_frame = {
                        "mag_noisy": torch.zeros_like(mag_noisy),
                        "pha_noisy": torch.zeros_like(pha_noisy),
                        "mag_noisy_t_erb": torch.zeros_like(mag_noisy_t_erb),
                    }
                    cache_list.extend([zero_frame] * lookahead)
                cache_list.append({
                    "mag_noisy": mag_noisy,
                    "pha_noisy": pha_noisy,
                    "mag_noisy_t_erb": mag_noisy_t_erb,
                })
            else:
                cache_list.append(zero_frame)

            if len(cache_list) < context_frames:
                continue

            mag_noisy_t_erb = torch.concat(
                tensors=[c["mag_noisy_t_erb"] for c in cache_list],
                dim=-1
            )
            # mag_noisy_t_erb shape: (b, erb_bins, 5); the enhanced frame is the middle one
            center = cache_list[lookahead]
            mag_noisy = center["mag_noisy"]
            pha_noisy = center["pha_noisy"]
            cache_list.pop(0)

            mask, states = self._estimate_mask(mag_noisy_t_erb, states)
            # mask shape: (b, f, 1)
            stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
            sub_waveform, waveform_cache, coff_cache = self.istft.forward_chunk(stft_denoise, waveform_cache, coff_cache)
            chunks.append(sub_waveform)

        # Concatenate once instead of growing the tensor on every frame.
        return torch.concat(tensors=[waveform, *chunks], dim=-1)

    def do_mask(self,
                mag_noisy: torch.Tensor,
                pha_noisy: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Scale the magnitude by ``mask``, keep the noisy phase, and return the
        complex spectrum ``mag * mask * exp(1j * phase)`` (all shaped (b, f, t))."""
        mag_denoise = mag_noisy * mask
        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
        return stft_denoise

    def mag_pha_stft(self, noisy: torch.Tensor):
        """Return the magnitude and phase (``atan2(imag, real)``) of ``self.stft(noisy)``.

        Within this class it is called with (b, 1, num_samples) signals from
        ``signal_prepare``. Both outputs have shape (b, f, t).
        """
        stft_noisy = self.stft.forward(noisy)
        # stft_noisy shape: [b, f, t], torch.complex64
        real = torch.real(stft_noisy)
        imag = torch.imag(stft_noisy)
        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
        pha_noisy = torch.atan2(imag, real)
        return mag_noisy, pha_noisy
# File name of the serialized state dict inside a pretrained-model directory.
MODEL_FILE = "model.pt"


class RNNoisePretrainedModel(RNNoise):
    """``RNNoise`` constructed from an ``RNNoiseConfig``, with save/load helpers."""

    def __init__(self,
                 config: RNNoiseConfig,
                 ):
        super(RNNoisePretrainedModel, self).__init__(
            sample_rate=config.sample_rate,
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            erb_bins=config.erb_bins,
            min_freq_bins_for_erb=config.min_freq_bins_for_erb,
            conv_size=config.conv_size,
            gru_size=config.gru_size,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a saved config and load its weights.

        The config is read with ``RNNoiseConfig.from_pretrained`` on the same
        path. If the path is a directory, weights are read from
        ``<path>/model.pt``; otherwise the path itself is the checkpoint file.
        Weights are loaded onto the CPU with ``weights_only=True`` and applied
        with ``strict=True``.

        :param pretrained_model_name_or_path: model directory or checkpoint file.
        :param kwargs: forwarded to ``RNNoiseConfig.from_pretrained``.
        :return: the loaded model.
        """
        config = RNNoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors and plain containers.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights to ``<save_directory>/model.pt`` and the config to
        ``<save_directory>/CONFIG_FILE`` (via ``config.to_yaml_file``).

        :param save_directory: created if missing.
        :param state_dict: weights to save; defaults to ``self.state_dict()``.
        :return: ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
def main1():
    """Smoke test: run the offline ``forward`` pass on 16000 samples of random
    noise and print the output shape, dtype and a couple of samples."""
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        waveform, _, _ = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
    return
def main2():
    """Smoke test comparing offline ``forward`` with ``forward_chunk_by_chunk``.

    Both run on the same random input; the output shape, dtype and a couple of
    samples of each are printed so the two paths can be compared by eye.
    """
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        waveform, _, _ = model.forward(noisy)
        print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
        print(waveform[:, :, 300: 302])

        waveform = model.forward_chunk_by_chunk(noisy)
        print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
        print(waveform[:, :, 300: 302])
    return


if __name__ == "__main__":
    main2()