#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Generator network that predicts a denoised magnitude spectrum and a denoised
phase spectrum from a noisy magnitude / phase pair.

References:
https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py
https://huggingface.co/spaces/JacobLinCool/MP-SENet
https://arxiv.org/abs/2305.13686
https://github.com/yxlu-0102/MP-SENet

Note: this model probably cannot be adapted for streaming inference.
"""
import os
from typing import Optional, Union

from pesq import pesq
from joblib import Parallel, delayed
import numpy as np
import torch
import torch.nn as nn

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d


class SPConvTranspose2d(nn.Module):
    """
    Widen the last axis by a factor ``r`` using a plain convolution.

    The convolution emits ``out_channels * r`` channels; ``forward`` then
    folds the ``r`` channel groups into the last axis, so the output has
    ``out_channels`` channels and a last axis ``r`` times longer than the
    convolution output.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super().__init__()
        # One zero column on each side of the last axis; the height axis is not padded.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(in_channels, out_channels * r, kernel_size=kernel_size, stride=(1, 1))
        self.r = r

    def forward(self, x):
        features = self.conv(self.pad1(x))
        batch, channels, height, width = features.shape
        group_channels = channels // self.r

        # (B, r, C, H, W) -> (B, C, H, W, r) -> (B, C, H, W * r):
        # the r channel groups become neighbouring positions on the last axis.
        features = features.view((batch, self.r, group_channels, height, width))
        features = features.permute(0, 2, 3, 4, 1).contiguous()
        return features.view((batch, group_channels, height, -1))


class DenseBlock(nn.Module):
    """
    Stack of ``depth`` dilated convolutions with dense (concatenated) skip inputs.

    Layer ``i`` uses dilation ``2 ** i`` along the height axis and receives the
    block input concatenated with the outputs of all previous layers, i.e.
    ``h.dense_channel * (i + 1)`` input channels. Only ``h.dense_channel`` is
    read from ``h``.
    """

    def __init__(self, h, kernel_size=(2, 3), depth=4):
        super().__init__()
        self.h = h
        self.depth = depth
        self.dense_block = nn.ModuleList([])
        for layer_idx in range(depth):
            dilation = 2 ** layer_idx
            layer = nn.Sequential(
                # Pad only the leading side of the height axis (by the dilation)
                # and one column on each side of the width axis.
                nn.ConstantPad2d((1, 1, dilation, 0), value=0.),
                nn.Conv2d(h.dense_channel * (layer_idx + 1), h.dense_channel, kernel_size, dilation=(dilation, 1)),
                nn.InstanceNorm2d(h.dense_channel, affine=True),
                nn.PReLU(h.dense_channel),
            )
            self.dense_block.append(layer)

    def forward(self, x):
        # `stacked` accumulates every feature map produced so far (newest first).
        stacked = x
        for layer in self.dense_block:
            x = layer(stacked)
            stacked = torch.cat([x, stacked], dim=1)
        # Only the output of the last layer is returned.
        return x


class DenseEncoder(nn.Module):
    """
    Encoder: 1x1 conv -> DenseBlock -> strided conv that roughly halves the last axis.

    Shape comments below follow the original annotations ([b, C, T, F] with
    C = ``h.dense_channel``).
    """

    def __init__(self, h, in_channel):
        super().__init__()
        self.h = h
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, h.dense_channel, (1, 1)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel),
        )

        self.dense_block = DenseBlock(h, depth=4)

        self.dense_conv_2 = nn.Sequential(
            # Stride 2 on the last axis only.
            nn.Conv2d(h.dense_channel, h.dense_channel, (1, 3), (1, 2), padding=(0, 1)),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel),
        )

    def forward(self, x):
        projected = self.dense_conv_1(x)        # [b, C, T, F]
        dense = self.dense_block(projected)     # [b, C, T, F]
        return self.dense_conv_2(dense)         # [b, C, T, F//2]


class MaskDecoder(nn.Module):
    """
    Decoder producing a magnitude mask.

    DenseBlock -> width upsampling (SPConvTranspose2d, r=2) -> conv to
    ``out_channel`` -> learnable sigmoid over ``h.n_fft // 2 + 1`` bins.
    """

    def __init__(self, h, out_channel=1):
        super().__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.mask_conv = nn.Sequential(
            SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel),
            nn.Conv2d(h.dense_channel, out_channel, (1, 2)),
        )
        self.lsigmoid = LearnableSigmoid2d(h.n_fft // 2 + 1, beta=h.beta)

    def forward(self, x):
        mask = self.mask_conv(self.dense_block(x))
        # Swap the channel and last axes, then drop the trailing axis if it has size 1.
        mask = mask.permute(0, 3, 2, 1).squeeze(-1)     # [B, F, T]
        return self.lsigmoid(mask)


class PhaseDecoder(nn.Module):
    """
    Decoder producing a phase estimate.

    Two parallel convolutions predict a pseudo real and a pseudo imaginary
    part; the phase is ``atan2(imag, real)``.
    """

    def __init__(self, h, out_channel=1):
        super().__init__()
        self.dense_block = DenseBlock(h, depth=4)
        self.phase_conv = nn.Sequential(
            SPConvTranspose2d(h.dense_channel, h.dense_channel, (1, 3), 2),
            nn.InstanceNorm2d(h.dense_channel, affine=True),
            nn.PReLU(h.dense_channel),
        )
        self.phase_conv_r = nn.Conv2d(h.dense_channel, out_channel, (1, 2))
        self.phase_conv_i = nn.Conv2d(h.dense_channel, out_channel, (1, 2))

    def forward(self, x):
        features = self.phase_conv(self.dense_block(x))
        real_part = self.phase_conv_r(features)
        imag_part = self.phase_conv_i(features)
        phase = torch.atan2(imag_part, real_part)
        # Swap the channel and last axes, then drop the trailing axis if it has size 1.
        return phase.permute(0, 3, 2, 1).squeeze(-1)    # [B, F, T]


class TSTransformerBlock(nn.Module):
    """
    Two-stage transformer block: attends along the time axis, then along the
    frequency axis, each with a residual connection around the transformer.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
        self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)

    def forward(self, x):
        b, c, t, f = x.size()

        # Time stage: every frequency bin becomes an independent sequence of length t.
        x = x.permute(0, 3, 2, 1).contiguous().view(b * f, t, c)
        x = self.time_transformer(x) + x

        # Frequency stage: every time frame becomes an independent sequence of length f.
        x = x.view(b, f, t, c).permute(0, 2, 1, 3).contiguous().view(b * t, f, c)
        x = self.freq_transformer(x) + x

        # Restore the [b, c, t, f] layout.
        return x.view(b, t, f, c).permute(0, 3, 1, 2)


class MPNet(nn.Module):
    """
    Full generator: DenseEncoder -> ``num_tsblocks`` TSTransformerBlocks ->
    parallel MaskDecoder and PhaseDecoder.
    """

    def __init__(self, config: MPNetConfig, num_tsblocks=4):
        super().__init__()
        self.num_tscblocks = num_tsblocks
        self.dense_encoder = DenseEncoder(config, in_channel=2)

        self.TSTransformer = nn.ModuleList(
            [TSTransformerBlock(config) for _ in range(num_tsblocks)]
        )

        self.mask_decoder = MaskDecoder(config, out_channel=1)
        self.phase_decoder = PhaseDecoder(config, out_channel=1)

    def forward(self, noisy_amp, noisy_pha):
        """
        :param noisy_amp: noisy magnitude spectrum, [B, F, T].
        :param noisy_pha: noisy phase spectrum, [B, F, T].
        :return: tuple ``(denoised_amp, denoised_pha, denoised_com)`` where
            ``denoised_com`` stacks ``amp * cos(pha)`` and ``amp * sin(pha)``
            on a new last axis.
        """
        # Magnitude and phase become the two input channels.
        x = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1)    # [B, 2, T, F]
        x = self.dense_encoder(x)

        for ts_block in self.TSTransformer:
            x = ts_block(x)

        # The mask decoder output multiplies the noisy magnitude.
        denoised_amp = noisy_amp * self.mask_decoder(x)
        denoised_pha = self.phase_decoder(x)

        real_part = denoised_amp * torch.cos(denoised_pha)
        imag_part = denoised_amp * torch.sin(denoised_pha)
        denoised_com = torch.stack(tensors=(real_part, imag_part), dim=-1)

        return denoised_amp, denoised_pha, denoised_com


MODEL_FILE = "generator.pt"


class MPNetPretrainedModel(MPNet):
    """MPNet with helpers to load from / save to a checkpoint directory."""

    def __init__(self, config: MPNetConfig):
        super().__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a model from a saved config and load its weights.

        :param pretrained_model_name_or_path: either a directory containing
            ``MODEL_FILE`` or a path used directly as the checkpoint file.
        :param kwargs: forwarded to ``MPNetConfig.from_pretrained``.
        :return: the model with weights loaded (strict key matching).
        """
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # A directory holds the checkpoint under MODEL_FILE; anything else is
        # treated as the checkpoint file itself.
        ckpt_file = pretrained_model_name_or_path
        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write the weights (``MODEL_FILE``) and the config (``CONFIG_FILE``)
        into ``save_directory``, creating the directory if needed.

        :param save_directory: target directory.
        :param state_dict: weights to save; defaults to this model's own state dict.
        :return: ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # Weights.
        torch.save(state_dict, os.path.join(save_directory, MODEL_FILE))

        # Config.
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory


def phase_losses(phase_r, phase_g):
    """
    Compute three anti-wrapped phase losses between ``phase_r`` and ``phase_g``.

    :return: ``(ip_loss, gd_loss, iaf_loss)``: the mean anti-wrapped
        difference of the raw phases, of their first differences along dim 1,
        and of their first differences along dim 2.
        NOTE(review): presumably instantaneous phase, group delay and
        instantaneous angular frequency with inputs shaped [B, F, T];
        confirm against the training code.
    """
    direct_diff = phase_r - phase_g
    dim1_diff = torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1)
    dim2_diff = torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2)

    ip_loss = torch.mean(anti_wrapping_function(direct_diff))
    gd_loss = torch.mean(anti_wrapping_function(dim1_diff))
    iaf_loss = torch.mean(anti_wrapping_function(dim2_diff))

    return ip_loss, gd_loss, iaf_loss


def anti_wrapping_function(x):
    """Return the absolute distance from ``x`` to its nearest multiple of 2*pi."""
    nearest_turns = torch.round(x / (2 * np.pi))
    return torch.abs(x - nearest_turns * 2 * np.pi)


# NOTE: disabled PESQ helpers, kept for reference. They are the only users of
# the `pesq` / `joblib` imports at the top of this file.
#
# def pesq_score(utts_r, utts_g, h):
#
#     pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
#                             utts_r[i].squeeze().cpu().numpy(),
#                             utts_g[i].squeeze().cpu().numpy(),
#                             h.sample_rate, )
#                           for i in range(len(utts_r)))
#     pesq_score = np.mean(pesq_score)
#
#     return pesq_score
#
#
# def eval_pesq(clean_utt, esti_utt, sr):
#     try:
#         mode = "nb" if sr == 8000 else "wb"
#         pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
#     except:
#         pesq_score = -1
#
#     return pesq_score


def main():
    """Shape smoke test: push a random signal's spectrogram through an untrained MPNet."""
    import torchaudio

    config = MPNetConfig()
    model = MPNet(config=config)

    transformer = torchaudio.transforms.Spectrogram(
        n_fft=config.n_fft,
        win_length=config.win_size,
        hop_length=config.hop_size,
        window_fn=torch.hamming_window,
    )

    inputs = torch.randn(size=(1, 32000), dtype=torch.float32)
    spec = transformer.forward(inputs)
    print(spec.shape)

    # The same spectrogram stands in for both magnitude and phase; only the
    # output shapes are inspected here.
    denoised_amp, denoised_pha, denoised_com = model.forward(spec, spec)
    for output in (denoised_amp, denoised_pha, denoised_com):
        print(output.shape)


if __name__ == "__main__":
    main()