#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
RNNoise-style speech denoiser working on ERB-band magnitude features.

https://github.com/xiph/rnnoise
https://github.com/xiph/rnnoise/blob/main/torch/rnnoise/rnnoise.py
https://arxiv.org/abs/1709.08243
"""
import os
from typing import Optional, Union, Tuple

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torch.sparsification.gru_sparsifier import GRUSparsifier
from toolbox.torchaudio.models.rnnoise.configuration_rnnoise import RNNoiseConfig
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands


# Schedule shared by the three GRU sparsifiers built in RNNoise.__init__.
# NOTE(review): exact semantics live in GRUSparsifier (not shown here);
# presumably pruning ramps up between step `start` and `stop`, is applied
# every `interval` steps, and `exponent` shapes the density curve.
sparsify_start = 6000
sparsify_stop = 20000
sparsify_interval = 100
sparsify_exponent = 3

# Per-weight sparsification settings, keyed by GRU weight name.
# NOTE(review): the tuple layout is presumably (target density, block size,
# keep-diagonal flag) as in the upstream RNNoise/LPCNet sparsifier -- confirm.
sparse_params1 = {
    "W_hr": (0.3, [8, 4], True),
    "W_hz": (0.2, [8, 4], True),
    "W_hn": (0.5, [8, 4], True),
    "W_ir": (0.3, [8, 4], False),
    "W_iz": (0.2, [8, 4], False),
    "W_in": (0.5, [8, 4], False),
}


def init_weights(module):
    """Orthogonally initialise the hidden-to-hidden weights of an ``nn.GRU``.

    Meant to be passed to ``nn.Module.apply``; modules that are not
    ``nn.GRU`` are left untouched.
    """
    if isinstance(module, nn.GRU):
        for name, param in module.named_parameters():
            if name.startswith("weight_hh_"):
                nn.init.orthogonal_(param)


class RNNoise(nn.Module):
    """Mask-based denoiser: STFT -> ERB dB features -> conv + GRU -> band gains.

    Pipeline:
    1. The noisy waveform goes through a (non-trainable) STFT.
    2. Magnitudes are pooled into ``erb_bins`` bands in dB.
    3. Two kernel-3 ``Conv1d`` layers and three stacked GRUs process the bands.
    4. A sigmoid ``Linear`` head predicts one gain per band.
    5. The gains are mapped back to STFT bins (``erb_scale_inv``) and multiplied
       onto the noisy magnitude.
    6. The noisy phase is reused, and the result is resynthesised with the
       inverse STFT.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",

                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,

                 conv_size: int = 128,
                 gru_size: int = 256,
                 ):
        super(RNNoise, self).__init__()
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.conv_size = conv_size
        self.gru_size = gru_size

        self.input_dim = nfft // 2 + 1
        # Not referenced anywhere in this file; kept for external users.
        self.eps = 1e-12

        self.erb_bands = ErbBands(
            sample_rate=self.sample_rate,
            nfft=self.nfft,
            erb_bins=self.erb_bins,
            min_freq_bins_for_erb=self.min_freq_bins_for_erb,
        )

        self.stft = ConvSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=self.nfft,
            win_size=self.win_size,
            hop_size=self.hop_size,
            win_type=self.win_type,
            requires_grad=False
        )

        # Two "valid" kernel-3 convolutions shrink the time axis by 4 frames;
        # padding 2 frames on each side keeps one output per input frame.
        self.pad = nn.ConstantPad1d(padding=(2, 2), value=0)
        self.conv1 = nn.Conv1d(self.erb_bins, conv_size, kernel_size=3, padding="valid")
        self.conv2 = nn.Conv1d(conv_size, gru_size, kernel_size=3, padding="valid")

        self.gru1 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru2 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)
        self.gru3 = nn.GRU(self.gru_size, self.gru_size, batch_first=True)

        # Input is the conv output plus all three GRU outputs (see forward_gru).
        self.dense_out = nn.Linear(4 * self.gru_size, self.erb_bins)

        nb_params = sum(p.numel() for p in self.parameters())
        print(f"model: {nb_params} weights")
        self.apply(init_weights)

        # One sparsifier per GRU; stepped manually via self.sparsify().
        self.sparsifier = [
            GRUSparsifier(
                task_list=[(gru, sparse_params1)],
                start=sparsify_start, stop=sparsify_stop,
                interval=sparsify_interval, exponent=sparsify_exponent,
            )
            for gru in (self.gru1, self.gru2, self.gru3)
        ]

    def sparsify(self):
        """Advance every GRU sparsifier by one step (call once per training step)."""
        for sparsifier in self.sparsifier:
            sparsifier.step()

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """Reshape ``signal`` to (b, 1, n) and right-pad it with zeros.

        Accepts (n,), (b, n) or (b, 1, n). After padding, ``n >= win_size``
        and ``(n - win_size)`` is a multiple of ``hop_size``, so the STFT
        frames tile the signal exactly.
        """
        if signal.dim() == 1:
            signal = torch.unsqueeze(signal, dim=0)
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape

        if n_samples < self.win_size:
            # Too short for even one frame: pad up to exactly one window.
            n_samples_pad = self.win_size - n_samples
        else:
            remainder = (n_samples - self.win_size) % self.hop_size
            n_samples_pad = self.hop_size - remainder if remainder > 0 else 0

        if n_samples_pad > 0:
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def _erb_features(self, mag_noisy: torch.Tensor) -> torch.Tensor:
        """Map a (b, f, t) magnitude spectrogram to (b, erb_bins, t) dB ERB features."""
        mag_noisy_t = torch.transpose(mag_noisy, dim0=1, dim1=2)
        # shape: (b, t, f)
        mag_noisy_t_erb = self.erb_bands.erb_scale(mag_noisy_t, db=True)
        # shape: (b, t, erb_bins)
        return torch.transpose(mag_noisy_t_erb, dim0=1, dim1=2)

    def _estimate_mask(self,
                       mag_noisy_t_erb: torch.Tensor,
                       states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                       ):
        """Predict a per-bin gain mask from already time-padded ERB features.

        :param mag_noisy_t_erb: (b, erb_bins, t + 4) features; the extra four
            frames are the 2+2 context consumed by the two valid convolutions.
        :param states: optional hidden states for gru1..gru3.
        :return: ``(mask, new_states)``; mask has shape (b, f, t).
        """
        features = self.forward_conv(mag_noisy_t_erb)
        gru_out, states = self.forward_gru(features, states)
        # gru_out shape: [b, t, 4 * gru_size]
        mask_erb = torch.sigmoid(self.dense_out(gru_out))
        # mask_erb shape: (b, t, erb_bins)
        mask = self.erb_bands.erb_scale_inv(mask_erb)
        # mask shape: (b, t, f)
        mask = torch.transpose(mask, dim0=1, dim1=2)
        # mask shape: (b, f, t)
        return mask, states

    def forward(self,
                noisy: torch.Tensor,
                states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                ):
        """Denoise a whole utterance at once.

        :param noisy: waveform accepted by ``signal_prepare``.
        :param states: optional GRU hidden states from a previous call.
        :return: ``(denoise, mask, states)``:
            - ``denoise`` is [b, 1, num_samples], trimmed to the input length.
            - ``mask`` is (b, f, t).
            - ``states`` is the list of the three GRU hidden states.
        """
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)

        mag_noisy, pha_noisy = self.mag_pha_stft(noisy)
        # shape: (b, f, t)
        # t = (num_samples - win_size) / hop_size + 1

        mag_noisy_t_erb = self._erb_features(mag_noisy)
        # shape: (b, erb_bins, t)
        mag_noisy_t_erb = self.pad(mag_noisy_t_erb)
        mask, states = self._estimate_mask(mag_noisy_t_erb, states)
        # mask shape: (b, f, t)

        stft_denoise = self.do_mask(mag_noisy, pha_noisy, mask)
        denoise = self.istft.forward(stft_denoise)
        # denoise shape: [b, 1, num_samples_pad]

        denoise = denoise[:, :, :num_samples]
        # denoise shape: [b, 1, num_samples]
        return denoise, mask, states

    def forward_conv(self, mag_noisy: torch.Tensor):
        """Apply conv1 and conv2 (each followed by tanh) along the time axis.

        Input is [b, erb_bins, t]; output is [b, gru_size, t - 4] because both
        convolutions use kernel 3 with "valid" padding.
        """
        tmp = torch.tanh(self.conv1(mag_noisy))
        tmp = torch.tanh(self.conv2(tmp))
        return tmp

    def forward_gru(self,
                    mag_noisy: torch.Tensor,
                    states: Optional[Tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = None,
                    ):
        """Run the three stacked GRUs and concatenate all intermediate outputs.

        :param mag_noisy: conv features of shape [b, gru_size, t].
        :param states: ``None`` or a sequence of three GRU hidden states.
        :return: ``(gru_out, new_states)``:
            - ``gru_out`` is [b, t, 4 * gru_size]: the conv input followed by the
              gru1, gru2 and gru3 outputs.
            - ``new_states`` is a list of the three final hidden states.
        """
        if states is None:
            gru1_state = gru2_state = gru3_state = None
        else:
            gru1_state, gru2_state, gru3_state = states[0], states[1], states[2]

        tmp = mag_noisy.permute(0, 2, 1)
        # tmp shape: [b, t, gru_size]

        gru1_out, gru1_state = self.gru1(tmp, gru1_state)
        gru2_out, gru2_state = self.gru2(gru1_out, gru2_state)
        gru3_out, gru3_state = self.gru3(gru2_out, gru3_state)
        new_states = [gru1_state, gru2_state, gru3_state]

        gru_out = torch.cat(tensors=[tmp, gru1_out, gru2_out, gru3_out], dim=-1)
        return gru_out, new_states

    def forward_chunk_by_chunk(self,
                               noisy: torch.Tensor,
                               ):
        """Denoise ``noisy`` one STFT frame at a time, emulating streaming use.

        Each frame's mask needs two frames of context on each side (the two
        kernel-3 convolutions), so frames are emitted with a two-frame delay.
        - The left context starts as two zero frames.
        - After the last real frame, two zero frames are fed in so the final
          two frames are also emitted.
        This mirrors the (2, 2) zero padding that ``forward`` applies with
        ``self.pad``.

        :return: the concatenation of every chunk produced by
            ``self.istft.forward_chunk``, shape [b, 1, *].
        """
        noisy = self.signal_prepare(noisy)
        b, _, num_samples = noisy.shape
        # signal_prepare guarantees num_samples >= win_size, so num_frames >= 1.
        num_frames = (num_samples - self.win_size) // self.hop_size + 1

        waveform = torch.zeros(size=(b, 1, 0), dtype=torch.float32, device=noisy.device)
        states = None
        cache_dict = None
        cache_list = list()
        zero_frame = None

        # num_frames real frames, then 2 zero frames to flush the lookahead.
        for i in range(num_frames + 2):
            if i < num_frames:
                begin = i * self.hop_size
                end = begin + self.win_size
                mag_noisy, pha_noisy = self.mag_pha_stft(noisy[:, :, begin:end])
                frame = {
                    "mag_noisy": mag_noisy,
                    "pha_noisy": pha_noisy,
                    "mag_noisy_t_erb": self._erb_features(mag_noisy),
                }
                if zero_frame is None:
                    # Left context for the first frame: zeros, like self.pad.
                    zero_frame = {k: torch.zeros_like(v) for k, v in frame.items()}
                    cache_list.extend([zero_frame] * 2)
            else:
                # Right context for the last two frames, like self.pad.
                frame = zero_frame
            cache_list.append(frame)
            if len(cache_list) < 5:
                continue

            mag_noisy_t_erb = torch.concat(
                tensors=[c["mag_noisy_t_erb"] for c in cache_list], dim=-1
            )
            # mag_noisy_t_erb shape: [b, erb_bins, 5]; the centre frame is output.
            center = cache_list[2]
            cache_list.pop(0)

            mask, states = self._estimate_mask(mag_noisy_t_erb, states)
            # mask shape: [b, f, 1]
            stft_denoise = self.do_mask(center["mag_noisy"], center["pha_noisy"], mask)
            sub_waveform, cache_dict = self.istft.forward_chunk(stft_denoise, cache_dict=cache_dict)
            waveform = torch.concat(tensors=[waveform, sub_waveform], dim=-1)

        return waveform

    def do_mask(self,
                mag_noisy: torch.Tensor,
                pha_noisy: torch.Tensor,
                mask: torch.Tensor,
                ):
        """Scale the noisy magnitude by ``mask`` and rebuild a complex STFT.

        All inputs are (b, f, t). The result is ``mag * mask * exp(j * phase)``,
        so the noisy phase is reused unchanged.
        """
        mag_denoise = mag_noisy * mask
        stft_denoise = mag_denoise * torch.exp((1j * pha_noisy))
        return stft_denoise

    def mag_pha_stft(self, noisy: torch.Tensor):
        """Return ``(magnitude, phase)`` of the complex STFT of ``noisy``.

        Within this class ``noisy`` is the (b, 1, n) output of
        ``signal_prepare``. Both returned tensors are (b, f, t).
        """
        stft_noisy = self.stft.forward(noisy)
        # stft_noisy shape: [b, f, t], torch.complex64

        real = torch.real(stft_noisy)
        imag = torch.imag(stft_noisy)
        mag_noisy = torch.sqrt(real ** 2 + imag ** 2)
        pha_noisy = torch.atan2(imag, real)
        # shape: (b, f, t)
        return mag_noisy, pha_noisy


MODEL_FILE = "model.pt"


class RNNoisePretrainedModel(RNNoise):
    """``RNNoise`` built from an ``RNNoiseConfig`` with save/load helpers."""

    def __init__(self,
                 config: RNNoiseConfig,
                 ):
        super(RNNoisePretrainedModel, self).__init__(
            sample_rate=config.sample_rate,
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            erb_bins=config.erb_bins,
            min_freq_bins_for_erb=config.min_freq_bins_for_erb,
            conv_size=config.conv_size,
            gru_size=config.gru_size,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a config and load its weights (strict, on CPU).

        If ``pretrained_model_name_or_path`` is a directory, weights are read
        from ``MODEL_FILE`` inside it; otherwise the path itself is used as
        the checkpoint file.
        """
        config = RNNoiseConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write ``MODEL_FILE`` and ``CONFIG_FILE`` into ``save_directory``.

        :param state_dict: weights to save; defaults to ``self.state_dict()``.
        :return: ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main1():
    """Smoke test: run the full-utterance ``forward`` on random noise."""
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    waveform, mask, h_state = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    return


def main2():
    """Smoke test: compare ``forward`` with ``forward_chunk_by_chunk``."""
    config = RNNoiseConfig()
    model = RNNoisePretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    noisy = model.signal_prepare(noisy)

    waveform, mask, h_state = model.forward(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    waveform = model.forward_chunk_by_chunk(noisy)
    print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
    print(waveform[:, :, 300: 302])

    return


if __name__ == "__main__":
    main2()