HoneyTian's picture
add dfnet2
ed91efa
raw
history blame
13.7 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/xiph/rnnoise
https://github.com/xiph/rnnoise/blob/main/torch/rnnoise/rnnoise.py
https://arxiv.org/abs/1709.08243
"""
import os
from typing import Optional, Union, Tuple
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier
from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
# Schedule handed to every GRUSparsifier constructed in RNNoise.__init__.
# NOTE(review): units (presumably training steps, advanced by RNNoise.sparsify())
# and the exact ramp shape are defined by GRUSparsifier — confirm there.
sparsify_start = 6000
sparsify_stop = 20000
sparsify_interval = 100
sparsify_exponent = 3

# Per-matrix settings for one GRU, keyed by gate weight name ("W_h*" recurrent,
# "W_i*" input). Each value is a tuple (float, [8, 4], bool); presumably
# (target density, block shape, keep-diagonal flag) as in upstream xiph/rnnoise —
# confirm against GRUSparsifier. A fresh [8, 4] list is created for every entry.
sparse_params1 = {
    weight_name: (target, [8, 4], is_recurrent)
    for weight_name, target, is_recurrent in (
        ("W_hr", 0.3, True),
        ("W_hz", 0.2, True),
        ("W_hn", 0.5, True),
        ("W_ir", 0.3, False),
        ("W_iz", 0.2, False),
        ("W_in", 0.5, False),
    )
}
def init_weights(module):
    """Orthogonally initialise the hidden-to-hidden weights of an ``nn.GRU``.

    Intended for ``nn.Module.apply``. Modules that are not ``nn.GRU`` are left
    untouched, and so are a GRU's input weights and biases.
    """
    if not isinstance(module, nn.GRU):
        return
    for param_name, param in module.named_parameters():
        # Matches weight_hh_l0, weight_hh_l1, ... (and *_reverse for bidirectional GRUs).
        if param_name.startswith("weight_hh_"):
            nn.init.orthogonal_(param)
class RNNoise(nn.Module):
    """RNNoise-style denoiser predicting one gain per ERB band and frame.

    Pipeline: STFT -> magnitude -> ERB-band features (``erb_scale(..., db=True)``)
    -> two kernel-3 convolutions with tanh -> three stacked GRUs -> linear layer
    + sigmoid (one gain per ERB band) -> ``erb_scale_inv`` back to the linear
    frequency axis -> gains applied to the noisy magnitude -> inverse STFT using
    the noisy phase.

    Args:
        sample_rate: passed to ErbBands.
        nfft: FFT size passed to ErbBands / ConvSTFT / ConviSTFT.
        win_size: STFT window length in samples.
        hop_size: STFT hop in samples.
        win_type: window name passed to ConvSTFT / ConviSTFT.
        erb_bins: number of ERB bands (conv input channels and mask width).
        min_freq_bins_for_erb: passed to ErbBands.
        conv_size: output channels of the first convolution.
        gru_size: output channels of the second convolution and hidden size of
            each of the three GRUs.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",
                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,
                 conv_size: int = 128,
                 gru_size: int = 256,
                 ):
        super(RNNoise, self).__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.conv_size = conv_size
        self.gru_size = gru_size
        self.input_dim = nfft // 2 + 1
        self.eps = 1e-12

        self.erb_bands = ErbBands(
            sample_rate=self.sample_rate,
            nfft=self.nfft,
            erb_bins=self.erb_bins,
            min_freq_bins_for_erb=self.min_freq_bins_for_erb,
        )
        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            requires_grad=False
        )

        # Two zero frames on each side of the time axis: each kernel-3 "valid"
        # conv below drops 2 frames, so conv1 + conv2 return as many frames as
        # there were before padding.
        self.pad = nn.ConstantPad1d(padding=(2, 2), value=0)
        self.conv1 = nn.Conv1d(self.erb_bins, conv_size, kernel_size=3, padding="valid")
        self.conv2 = nn.Conv1d(conv_size, gru_size, kernel_size=3, padding="valid")
        self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        # Input is the conv features concatenated with all three GRU outputs.
        self.dense_out = nn.Linear(4*self.gru_size, self.erb_bins)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"model: {nb_params} weights")
        self.apply(init_weights)

        # One sparsifier per GRU, all sharing the module-level schedule and
        # sparse_params1. Kept in a plain list, so nn.Module does not register
        # these objects and they contribute nothing to state_dict().
        self.sparsifier = [
            GRUSparsifier(
                task_list=[(gru, sparse_params1)],
                start=sparsify_start,
                stop=sparsify_stop,
                interval=sparsify_interval,
                exponent=sparsify_exponent,
            )
            for gru in (self.gru1, self.gru2, self.gru3)
        ]

    def sparsify(self):
        """Call ``step()`` on each of the three GRU sparsifiers.

        NOTE(review): presumably meant to be called once per training step so the
        start/stop/interval schedule is counted in steps — confirm in GRUSparsifier.
        """
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """Reshape to (b, 1, n) and right-pad with zeros so STFT frames tile the signal.

        After padding, ``n >= win_size`` and ``n - win_size`` is a multiple of
        ``hop_size``, so the frame count ``(n - win_size) // hop_size + 1`` is
        a whole number of at least 1.

        Args:
            signal: (b, n) or (b, 1, n) waveform.

        Returns:
            (b, 1, n_padded) waveform; samples past the original length are zero.
        """
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape

        if n_samples < self.win_size:
            # Shorter than one window: pad up to exactly one full frame so the
            # STFT yields at least one frame.
            n_samples_pad = self.win_size - n_samples
        else:
            remainder = (n_samples - self.win_size) % self.hop_size
            n_samples_pad = (self.hop_size - remainder) % self.hop_size
        if n_samples_pad > 0:
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def forward(self,
                noisy: torch.Tensor,
                states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                ):
        """Denoise a whole signal in one pass.

        Args:
            noisy: (b, n) or (b, 1, n) waveform.
            states: optional hidden states for gru1, gru2 and gru3 (indexable,
                length 3); ``None`` lets nn.GRU start from zeros.

        Returns:
            denoise: (b, 1, n) enhanced waveform, trimmed to the input length.
            mask: (b, f, t) gain mask on the linear frequency axis
                (presumably f == nfft // 2 + 1, cf. ``self.input_dim``).
            states: list of the three GRU hidden states after the last frame.
        """
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)

        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
        # mag_noisy, pha_noisy: (b, f, t), t = (n_padded - win_size) / hop_size + 1
        erb_features = self._erb_features(mag_noisy)
        # erb_features: (b, erb_bins, t) -> padded to t + 4 -> convs return t frames
        conv_out = self.forward_conv(self.pad(erb_features))

        gru_out, states = self.forward_gru(conv_out, states)
        mask = self._gru_out_to_mask(gru_out)
        # mask: (b, f, t)

        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
        denoise = self.istft.forward(stft_denoise)
        # denoise: (b, 1, n_padded); drop the padding added by signal_prepare.
        denoise = denoise[:, :, :num_samples]
        return denoise, mask, states

    def _erb_features(self, mag_noisy: torch.Tensor) -> torch.Tensor:
        """Map a (b, f, t) magnitude spectrogram to (b, erb_bins, t) ERB features (db=True)."""
        mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2)
        # (b, t, f)
        mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True)
        # (b, t, erb_bins)
        return torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2)

    def _gru_out_to_mask(self, gru_out: torch.Tensor) -> torch.Tensor:
        """Turn (b, t, 4 * gru_size) GRU features into a (b, f, t) gain mask."""
        mask_erb = torch.sigmoid(self.dense_out(gru_out))
        # (b, t, erb_bins), gains in (0, 1)
        mask = self.erb_bands.erb_scale_inv(mask_erb)
        # (b, t, f)
        return torch.transpose(mask, dim0=1, dim1=2)

    def forward_conv(self, mag_noisy: torch.Tensor):
        """Apply conv1 and conv2, each followed by tanh.

        Args:
            mag_noisy: (b, erb_bins, t_in) ERB features (already context-padded).

        Returns:
            (b, gru_size, t_in - 4): each kernel-3 "valid" conv drops 2 frames.
        """
        tmp = torch.tanh(self.conv1(mag_noisy))
        tmp = torch.tanh(self.conv2(tmp))
        return tmp

    def forward_gru(self,
                    mag_noisy: torch.Tensor,
                    states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                    ):
        """Run the three stacked GRUs and concatenate every stage's output.

        Args:
            mag_noisy: (b, gru_size, t) conv features.
            states: optional hidden states for gru1, gru2, gru3; ``None`` for zeros.

        Returns:
            gru_out: (b, t, 4 * gru_size) = [conv features, gru1, gru2, gru3] on the last axis.
            new_states: list of the three updated hidden states.
        """
        if states is None:
            gru1_state = None
            gru2_state = None
            gru3_state = None
        else:
            gru1_state = states[0]
            gru2_state = states[1]
            gru3_state = states[2]

        tmp = mag_noisy.permute(0, 2, 1)
        # tmp: (b, t, gru_size) — the GRUs are batch_first
        gru1_out, gru1_state = self.gru1(tmp, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        gru3_out, gru3_state = self.gru3(gru2_out, gru3_state)
        new_states = [gru1_state, gru2_state, gru3_state]
        gru_out = torch.cat(tensors=[tmp, gru1_out, gru2_out, gru3_out], dim=-1)
        return gru_out, new_states

    def forward_chunk_by_chunk(self,
                               noisy: torch.Tensor,
                               ):
        """Streaming-style inference: enhance the signal one STFT window at a time.

        Each ``win_size`` window (step ``hop_size``) is transformed on its own.
        The conv front-end needs two frames of context on each side, so the ERB
        features are buffered: the buffer is seeded with two zero frames, and
        whenever it holds five frames the centre one (index 2) is enhanced. GRU
        states and the ``istft.forward_chunk`` cache are carried from frame to frame.

        NOTE(review): ``forward`` pads two zero frames at *both* ends, but no
        trailing zero frames are fed here, so the last two frames are never
        enhanced and the output is presumably shorter than ``forward``'s —
        confirm whether a final flush is intended.

        Args:
            noisy: (b, n) or (b, 1, n) waveform.

        Returns:
            (b, 1, m) concatenation of the per-frame ``istft.forward_chunk`` outputs,
            on ``noisy``'s device.
        """
        noisy = self.signal_prepare(noisy)
        b, _, num_samples = noisy.shape
        num_frames = (num_samples - self.win_size) // self.hop_size + 1

        # Seed with an empty float32 tensor on the input's device (the original
        # allocated it on CPU, which broke concatenation for CUDA inputs) and
        # concatenate once at the end instead of once per frame.
        pieces = [torch.zeros(size=(b, 1, 0), dtype=torch.float32, device=noisy.device)]
        states = None
        cache_dict = None
        cache_list = list()
        for i in range(num_frames):
            begin = i * self.hop_size
            end = begin + self.win_size
            sub_noisy = noisy[:, :, begin:end]
            mag_noisy, pha_noisy = self.mag_pha_stft(sub_noisy)
            mag_noisy_t_erb = self._erb_features(mag_noisy)
            # mag_noisy_t_erb: (b, erb_bins, 1) — one window gives one frame

            if len(cache_list) == 0:
                # Two zero frames of left context, mirroring self.pad in forward().
                zero_frame = {
                    "mag_noisy": torch.zeros_like(mag_noisy),
                    "pha_noisy": torch.zeros_like(pha_noisy),
                    "mag_noisy_t_erb": torch.zeros_like(mag_noisy_t_erb),
                }
                cache_list.extend([zero_frame] * 2)
            cache_list.append({
                "mag_noisy": mag_noisy,
                "pha_noisy": pha_noisy,
                "mag_noisy_t_erb": mag_noisy_t_erb,
            })
            if len(cache_list) < 5:
                continue

            mag_noisy_t_erb = torch.concat(
                tensors=[c["mag_noisy_t_erb"] for c in cache_list],
                dim=-1
            )
            # (b, erb_bins, 5); the centre frame is the one being enhanced.
            mag_noisy = cache_list[2]["mag_noisy"]
            pha_noisy = cache_list[2]["pha_noisy"]
            cache_list.pop(0)

            conv_out = self.forward_conv(mag_noisy_t_erb)
            # (b, gru_size, 1)
            gru_out, states = self.forward_gru(conv_out, states)
            mask = self._gru_out_to_mask(gru_out)
            stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
            sub_waveform, cache_dict = self.istft.forward_chunk(stft_denoise, cache_dict=cache_dict)
            pieces.append(sub_waveform)
        return torch.concat(tensors=pieces, dim=-1)

    def do_mask(self,
                mag_noisy: torch.Tensor,
                pha_noisy: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Scale the noisy magnitude by ``mask`` and rebuild a complex spectrum.

        Returns ``mag_noisy * mask * exp(1j * pha_noisy)``; all inputs are (b, f, t)
        in this class's callers.
        """
        mag_denoise = mag_noisy * mask
        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
        return stft_denoise

    def mag_pha_stft(self, noisy: torch.Tensor):
        """Return the STFT magnitude and phase (via atan2) of ``noisy``.

        Callers in this class pass the (b, 1, n) output of ``signal_prepare``;
        ``self.stft`` (ConvSTFT with power=None) returns a complex (b, f, t) tensor.
        """
        stft_noisy = self.stft.forward(noisy)
        real = torch.real(stft_noisy)
        imag = torch.imag(stft_noisy)
        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
        pha_noisy = torch.atan2(imag, real)
        # both (b, f, t)
        return mag_noisy, pha_noisy
# File name of the saved state dict: written by RNNoisePretrainedModel.save_pretrained
# and read by from_pretrained when it is given a directory.
MODEL_FILE = "model.pt"
class RNNoisePretrainedModel(RNNoise):
    """RNNoise built from an ``RNNoiseConfig``, with directory-based save/load helpers.

    ``save_pretrained`` writes the state dict to ``MODEL_FILE`` and the config as YAML
    to ``CONFIG_FILE`` inside a directory; ``from_pretrained`` loads the weights back.
    """

    def __init__(self,
                 config: RNNoiseConfig,
                 ):
        hyper_params = dict(
            sample_rate=config.sample_rate,
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            erb_bins=config.erb_bins,
            min_freq_bins_for_erb=config.min_freq_bins_for_erb,
            conv_size=config.conv_size,
            gru_size=config.gru_size,
        )
        super().__init__(**hyper_params)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from its config and strictly load its CPU-mapped weights.

        Args:
            pretrained_model_name_or_path: a directory containing ``MODEL_FILE``, or
                the checkpoint file path itself; also passed to
                ``RNNoiseConfig.from_pretrained``.
            **kwargs: forwarded to ``RNNoiseConfig.from_pretrained``.

        Returns:
            The model with its state dict loaded (``strict=True``).
        """
        config = RNNoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )
        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors and plain containers.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write weights and config into ``save_directory`` (created if missing).

        Args:
            save_directory: target directory.
            state_dict: weights to save; defaults to ``self.state_dict()``.

        Returns:
            ``save_directory``, unchanged.
        """
        weights = self.state_dict() if state_dict is None else state_dict
        os.makedirs(save_directory, exist_ok=True)

        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory
def main1():
    """Smoke test: run full-signal inference on 16000 random samples and print a slice."""
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        waveform, _, _ = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])
    return
def main2():
    """Run full-signal and chunk-by-chunk inference on the same random input.

    Prints the output shape/dtype and the same two-sample slice from each path
    so the two inference modes can be compared by eye.
    """
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        waveform, _, _ = model.forward(noisy)
        print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
        print(waveform[:, :, 300: 302])

        waveform = model.forward_chunk_by_chunk(noisy)
        print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
        print(waveform[:, :, 300: 302])
    return
if __name__ == "__main__":
    # Script entry point: run both inference paths of the demo model on random input.
    main2()