#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/AkenoSyuRi/DTLNPytorch

https://github.com/breizhn/DTLN

Trained on 500 hours of DNS3 data; reaches PESQ 3.04 on the DNS3 test set.
"""
import os
from typing import Optional, Union

import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig


class InstantLayerNormalization(nn.Module):
    """
    Class implementing instant layer normalization. It can also be called
    channel-wise layer normalization and was proposed by
    Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2).
    """
    def __init__(self, channels):
        super(InstantLayerNormalization, self).__init__()
        self.epsilon = 1e-7
        # assigning nn.Parameter attributes registers them automatically;
        # no explicit register_parameter call is needed
        self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
        self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)

    def forward(self, inputs: torch.Tensor):
        # calculate mean of each frame
        mean = torch.mean(inputs, dim=-1, keepdim=True)
        # calculate (biased) variance of each frame
        variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True)
        # calculate standard deviation
        std = torch.sqrt(variance + self.epsilon)
        outputs = (inputs - mean) / std
        # scale with gamma
        outputs = outputs * self.gamma
        # add the bias beta
        outputs = outputs + self.beta
        return outputs
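
# A hedged sanity check (illustrative, not part of the original repo): because
# InstantLayerNormalization normalizes each frame over its channel dimension with
# a biased variance, it should match nn.LayerNorm over the last dimension once the
# epsilon values agree (both modules default to unit gamma and zero beta).
def _check_instant_layer_norm():
    torch.manual_seed(0)
    x = torch.randn(2, 10, 256)
    iln = InstantLayerNormalization(channels=256)
    ln = nn.LayerNorm(normalized_shape=256, eps=1e-7)
    assert torch.allclose(iln.forward(x), ln.forward(x), atol=1e-5)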

class SeperationBlock(nn.Module):
    def __init__(self,
                 input_size: int = 257,
                 hidden_size: int = 128,
                 dropout: float = 0.25,
                 ):
        super(SeperationBlock, self).__init__()
        self.rnn1 = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True,
                            dropout=0.0,
                            bidirectional=False,
                            )
        self.rnn2 = nn.LSTM(input_size=hidden_size,
                            hidden_size=hidden_size,
                            num_layers=1,
                            batch_first=True,
                            dropout=0.0,
                            bidirectional=False,
                            )
        self.drop = nn.Dropout(dropout)

        self.dense = nn.Linear(hidden_size, input_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor, in_states: Optional[torch.Tensor] = None):
        if in_states is None:
            hx1 = None
            hx2 = None
        else:
            h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
            h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]
            hx1 = (h1_in, c1_in)
            hx2 = (h2_in, c2_in)

        x1, (h1, c1) = self.rnn1.forward(x, hx=hx1)
        x1 = self.drop(x1)
        x2, (h2, c2) = self.rnn2.forward(x1, hx=hx2)
        x2 = self.drop(x2)

        mask = self.dense(x2)
        mask = self.sigmoid(mask)

        h = torch.cat((h1, h2), dim=0)
        c = torch.cat((c1, c2), dim=0)
        out_states = torch.stack((h, c), dim=-1)
        # out_states shape: [2, b, hidden_size, 2]
        return mask, out_states


MODEL_FILE = "model.pt"


class DTLNModel(nn.Module):
    def __init__(self,
                 fft_size: int = 512,
                 hop_size: int = 128,
                 win_type: str = "hamming",
                 encoder_size: int = 256,
                 ):
        super(DTLNModel, self).__init__()
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.encoder_size = encoder_size

        self.stft = ConvSTFT(
            nfft=fft_size,
            win_size=fft_size,
            hop_size=hop_size,
            win_type=win_type,
            power=None,
            requires_grad=False,
        )
        self.istft = ConviSTFT(
            nfft=fft_size,
            win_size=fft_size,
            hop_size=hop_size,
            win_type=win_type,
            requires_grad=False,
        )

        self.sep1 = SeperationBlock(input_size=(fft_size // 2 + 1),
                                    hidden_size=128,
                                    dropout=0.25,
                                    )

        self.encoder_conv1 = nn.Conv1d(in_channels=fft_size,
                                       out_channels=self.encoder_size,
                                       kernel_size=1,
                                       stride=1,
                                       bias=False,
                                       )

        # self.encoder_norm1 = nn.InstanceNorm1d(num_features=self.encoder_size, eps=1e-7, affine=True)
        self.encoder_norm1 = InstantLayerNormalization(channels=self.encoder_size)

        self.sep2 = SeperationBlock(input_size=self.encoder_size,
                                    hidden_size=128,
                                    dropout=0.25,
                                    )

        self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size,
                                       out_channels=fft_size,
                                       kernel_size=1,
                                       stride=1,
                                       bias=False,
                                       )

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        remainder = (n_samples - self.fft_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def forward(self,
                noisy: torch.Tensor,
                ):
        num_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)
        batch_size, _, num_samples_pad = noisy.shape
        # print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")

        denoise_frame, _, _ = self.forward_chunk(noisy)
        denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
        # denoise shape: [b, num_samples_pad]
        denoise = denoise[:, :num_samples]
        # denoise shape: [b, num_samples]
        denoise = torch.unsqueeze(denoise, dim=1)
        # denoise shape: [b, 1, num_samples]
        return denoise

    def forward_chunk(self,
                      noisy: torch.Tensor,
                      in_state1: Optional[torch.Tensor] = None,
                      in_state2: Optional[torch.Tensor] = None,
                      ):
        # noisy shape: [b, 1, num_samples]
        spec = self.stft.forward(noisy)
        # spec shape: [b, f, t], torch.complex64
        # t = (num_samples - win_size) / hop_size + 1
        spec = torch.view_as_real(spec)
        # spec shape: [b, f, t, 2]
        real = spec[..., 0]
        imag = spec[..., 1]
        mag = torch.sqrt(real ** 2 + imag ** 2)
        phase = torch.atan2(imag, real)
        # shape: [b, f, t]
        mag = mag.permute(0, 2, 1)
        phase = phase.permute(0, 2, 1)
        # shape: [b, t, f]

        mask, out_state1 = self.sep1.forward(mag, in_state1)
        # mask shape: [b, t, f]
        estimated_mag = mask * mag

        s1_stft = estimated_mag * torch.exp(1j * phase)
        # s1_stft shape: [b, t, f], torch.complex64
        # 1-D inverse rFFT along the frequency axis (irfft, not irfft2,
        # which expects two transform dimensions)
        y1 = torch.fft.irfft(s1_stft, dim=-1)
        # y1 shape: [b, t, fft_size], torch.float32

        y1 = y1.permute(0, 2, 1)
        # y1 shape: [b, fft_size, t]

        encoded_f = self.encoder_conv1.forward(y1)
        # shape: [b, c, t]
        encoded_f = encoded_f.permute(0, 2, 1)
        # shape: [b, t, c]
        encoded_f_norm = self.encoder_norm1.forward(encoded_f)
        # shape: [b, t, c]
        mask_2, out_state2 = self.sep2.forward(encoded_f_norm, in_state2)
        # shape: [b, t, c]
        estimated = mask_2 * encoded_f
        estimated = estimated.permute(0, 2, 1)
        # shape: [b, c, t]

        denoise_frame = self.decoder_conv1.forward(estimated)
        # shape: [b, fft_size, t]
        return denoise_frame, out_state1, out_state2

    def forward_chunk_by_chunk(self, noisy: torch.Tensor):
        noisy = self.signal_prepare(noisy)
        # noisy shape: [b, 1, num_samples]
        batch_size, _, num_samples_pad = noisy.shape

        t = (num_samples_pad - self.fft_size) // self.hop_size + 1

        denoise_list = list()
        out_state1 = None
        out_state2 = None

        overlap_size = self.fft_size - self.hop_size
        denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype, device=noisy.device)
        # denoise_list.append(torch.clone(denoise_cache))
        for i in range(t):
            begin = i * self.hop_size
            end = begin + self.fft_size
            sub_noisy = noisy[:, :, begin: end]
            # sub_noisy shape: [b, 1, fft_size]
            with torch.no_grad():
                sub_denoise_frame, out_state1, out_state2 = self.forward_chunk(sub_noisy, out_state1, out_state2)
            # sub_denoise_frame shape: [b, fft_size, 1]
            sub_denoise_frame = sub_denoise_frame[:, :, 0]
            # sub_denoise_frame shape: [b, fft_size]

            # overlap-add: after adding the cache, the first hop_size samples
            # are complete and can be emitted
            sub_denoise_frame[:, :overlap_size] += denoise_cache
            denoise_out = sub_denoise_frame[:, :self.hop_size]
            denoise_cache = sub_denoise_frame[:, self.hop_size:]
            # denoise_cache shape: [b, fft_size - hop_size]

            denoise_list.append(denoise_out)

        denoise = torch.concat(denoise_list, dim=-1)
        # denoise shape: [b, t * hop_size]
        denoise = torch.unsqueeze(denoise, dim=1)
        # denoise shape: [b, 1, t * hop_size]
        return denoise

    def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
        # overlap and add
        # denoise_frame shape: [b, fft_size, t]
        denoise = torch.nn.functional.fold(
            denoise_frame,
            output_size=(num_samples, 1),
            kernel_size=(self.fft_size, 1),
            padding=(0, 0),
            stride=(self.hop_size, 1),
        )
        # denoise shape: [b, 1, num_samples, 1]
        denoise = denoise.reshape(batch_size, -1)
        # denoise shape: [b, num_samples]
        return denoise
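
# A hedged streaming sketch (illustrative, not part of the original repo): feed one
# fft_size-sized frame at a time through forward_chunk and carry the LSTM states of
# both separation blocks across calls, as forward_chunk_by_chunk does internally.
def _demo_streaming_states():
    model = DTLNModel()
    model.eval()
    state1 = None
    state2 = None
    with torch.no_grad():
        for _ in range(3):
            frame = torch.randn(size=(1, 1, 512), dtype=torch.float32)
            denoise_frame, state1, state2 = model.forward_chunk(frame, state1, state2)
    # denoise_frame shape: [1, fft_size, 1]; each state shape: [2, 1, 128, 2]
    print(denoise_frame.shape, state1.shape, state2.shape)


# A hedged equivalence check (illustrative, not part of the original repo): the
# F.fold call in denoise_frame_to_denoise is overlap-add, i.e. it matches an
# explicit loop that adds each fft_size-long frame at hop_size spacing.
def _check_overlap_add():
    torch.manual_seed(0)
    model = DTLNModel(fft_size=512, hop_size=128)
    t = 5
    num_samples = (t - 1) * 128 + 512
    frames = torch.randn(size=(1, 512, t), dtype=torch.float32)
    folded = model.denoise_frame_to_denoise(frames, batch_size=1, num_samples=num_samples)
    manual = torch.zeros(size=(1, num_samples), dtype=torch.float32)
    for i in range(t):
        manual[:, i * 128: i * 128 + 512] += frames[:, :, i]
    assert torch.allclose(folded, manual, atol=1e-6)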

class DTLNPretrainedModel(DTLNModel):
    def __init__(self,
                 config: DTLNConfig,
                 ):
        super(DTLNPretrainedModel, self).__init__(
            fft_size=config.fft_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            encoder_size=config.encoder_size,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = DTLNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    config = DTLNConfig()
    model = DTLNPretrainedModel(config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    with torch.no_grad():
        denoise = model.forward(noisy)
    print(f"denoise.shape: {denoise.shape}")
    print(denoise[:, :, 300: 302])
    print(denoise[:, :, 15680: 15682])
    print(denoise[:, :, 15760: 15762])
    print(denoise[:, :, 15840: 15842])

    denoise = model.forward_chunk_by_chunk(noisy)
    print(f"denoise.shape: {denoise.shape}")
    # denoise = denoise[:, :, (config.fft_size - config.hop_size):]
    print(denoise[:, :, 300: 302])
    print(denoise[:, :, 15680: 15682])
    print(denoise[:, :, 15760: 15762])
    print(denoise[:, :, 15840: 15842])
    return


if __name__ == "__main__":
    main()
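
# A hedged round-trip sketch (illustrative, not part of the original repo): save a
# randomly initialized model with save_pretrained, restore it with from_pretrained,
# and confirm both copies produce identical outputs. The directory name below is an
# example value.
def _demo_save_and_load(save_directory: str = "./dtln_demo"):
    config = DTLNConfig()
    model = DTLNPretrainedModel(config)
    model.eval()
    model.save_pretrained(save_directory)

    reloaded = DTLNPretrainedModel.from_pretrained(save_directory)
    reloaded.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
    with torch.no_grad():
        assert torch.allclose(model.forward(noisy), reloaded.forward(noisy))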