# Scraped from the Hugging Face file viewer: author HoneyTian, commit f74ae8e ("update"), file size 9.62 kB.
#!/usr/bin/python3
# -*- coding: utf-8 -*-
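"""
MP-SENet-style speech enhancement network: a dense encoder and time/frequency
transformer blocks feed parallel magnitude-mask and phase decoders.

References:
https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py
https://arxiv.org/abs/2305.13686
https://github.com/yxlu-0102/MP-SENet
"""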
import os
from typing import Optional, Union
from pesq import pesq
from joblib import Parallel, delayed
import numpy as np
import torch
import torch.nn as nn
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
from toolbox.torchaudio.models.mpnet.configuation_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d
class SPConvTranspose2d(nn.Module):
    """Sub-pixel "transposed" convolution that upsamples the last dim by ``r``.

    The input's last dim is zero-padded by one column on each side, then a
    stride-1 convolution produces ``out_channels * r`` channels. Those ``r``
    channel groups are interleaved along the last dim, so the conv output
    ``[B, out_channels * r, H, W]`` becomes ``[B, out_channels, H, W * r]``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super().__init__()
        self.pad1 = nn.ConstantPad2d(padding=(1, 1, 0, 0), value=0.0)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels * r,
            kernel_size=kernel_size,
            stride=(1, 1),
        )
        self.r = r

    def forward(self, x):
        conv_out = self.conv(self.pad1(x))
        batch, total_ch, height, width = conv_out.shape
        group_ch = total_ch // self.r
        # Split channels into r groups, move the group axis last, then fold it
        # into the width so column w*r + k comes from group k.
        grouped = conv_out.view((batch, self.r, group_ch, height, width))
        interleaved = grouped.permute(0, 2, 3, 4, 1).contiguous()
        return interleaved.view((batch, group_ch, height, -1))
class DenseBlock(nn.Module):
    """Densely connected stack of dilated 2-D convolutions.

    Layer ``i`` receives the channel-wise concatenation of the outputs of all
    earlier layers and the block input (``h.dense_channel * (i + 1)`` channels)
    and emits ``h.dense_channel`` channels. The dilation along dim 2 doubles per
    layer (1, 2, 4, ...). Only the last layer's output is returned.

    Padding is derived from ``kernel_size`` so dims 2 and 3 keep their size for
    any kernel. Dim 2 is padded only on the leading side by
    ``dilation * (k_t - 1)``. Dim 3 is padded by ``k_f - 1`` in total, split as
    evenly as possible. For the default ``(2, 3)`` kernel this is exactly
    ``(1, 1, dilation, 0)``.

    Args:
        h: config object; only ``h.dense_channel`` is read here.
        kernel_size: int or ``(k_t, k_f)`` pair used by every convolution.
        depth: number of convolution layers.
    """

    def __init__(self, h, kernel_size=(2, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.h = h
        self.depth = depth
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        k_t, k_f = kernel_size
        # The convolution removes k_f - 1 columns; give them back on both sides.
        pad_left = (k_f - 1) // 2
        pad_right = (k_f - 1) - pad_left
        self.dense_block = nn.ModuleList([])
        for i in range(depth):
            dilation = 2 ** i
            # Leading-only padding: the convolution at position t along dim 2 reads
            # only inputs at positions <= t (InstanceNorm statistics still span
            # the whole input).
            pad_top = dilation * (k_t - 1)
            dense_conv = nn.Sequential(
                nn.ConstantPad2d((pad_left, pad_right, pad_top, 0), value=0.),
                nn.Conv2d(h.dense_channel * (i + 1), h.dense_channel, (k_t, k_f), dilation=(dilation, 1)),
                nn.InstanceNorm2d(h.dense_channel, affine=True),
                nn.PReLU(h.dense_channel),
            )
            self.dense_block.append(dense_conv)

    def forward(self, x):
        skip = x
        for layer in self.dense_block:
            x = layer(skip)
            # Newest features go first: layer i sees [x_{i-1}, ..., x_0, input].
            skip = torch.cat([x, skip], dim=1)
        return x
class DenseEncoder(nn.Module):
    """Encoder: 1x1 projection -> 4-layer DenseBlock -> stride-2 conv over the last dim.

    Maps ``[B, in_channel, T, F]`` to ``[B, h.dense_channel, T, (F + 1) // 2]``.
    """

    def __init__(self, h, in_channel):
        super().__init__()
        self.h = h
        ch = h.dense_channel
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, ch, (1, 1)),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )
        self.dense_block = DenseBlock(h, depth=4)
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(ch, ch, kernel_size=(1, 3), stride=(1, 2), padding=(0, 1)),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )

    def forward(self, x):
        # [B, C_in, T, F] -> [B, C, T, F] -> [B, C, T, F] -> [B, C, T, (F + 1) // 2]
        for stage in (self.dense_conv_1, self.dense_block, self.dense_conv_2):
            x = stage(x)
        return x
class MaskDecoder(nn.Module):
    """Mask decoder.

    DenseBlock -> SPConvTranspose2d (x2 along the last dim) -> InstanceNorm/PReLU
    -> (1, 2) conv to ``out_channel`` channels -> LearnableSigmoid2d sized for
    ``h.n_fft // 2 + 1`` bins. MPNet multiplies the noisy amplitude by this output.
    """

    def __init__(self, h, out_channel=1):
        super().__init__()
        ch = h.dense_channel
        self.dense_block = DenseBlock(h, depth=4)
        self.mask_conv = nn.Sequential(
            SPConvTranspose2d(ch, ch, (1, 3), 2),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
            nn.Conv2d(ch, out_channel, (1, 2)),
        )
        self.lsigmoid = LearnableSigmoid2d(h.n_fft // 2 + 1, beta=h.beta)

    def forward(self, x):
        features = self.mask_conv(self.dense_block(x))
        # [B, C, T, F] -> [B, F, T, C]; the trailing axis is dropped when C == 1,
        # giving [B, F, T].
        mask_logits = features.permute(0, 3, 2, 1).squeeze(-1)
        return self.lsigmoid(mask_logits)
class PhaseDecoder(nn.Module):
    """Phase decoder.

    Predicts a (real, imaginary) pair per time-frequency bin with two parallel
    (1, 2) convolutions and returns their angle via ``torch.atan2``.
    """

    def __init__(self, h, out_channel=1):
        super().__init__()
        ch = h.dense_channel
        self.dense_block = DenseBlock(h, depth=4)
        self.phase_conv = nn.Sequential(
            SPConvTranspose2d(ch, ch, (1, 3), 2),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )
        self.phase_conv_r = nn.Conv2d(ch, out_channel, (1, 2))
        self.phase_conv_i = nn.Conv2d(ch, out_channel, (1, 2))

    def forward(self, x):
        features = self.phase_conv(self.dense_block(x))
        real_part = self.phase_conv_r(features)
        imag_part = self.phase_conv_i(features)
        phase = torch.atan2(imag_part, real_part)
        # [B, C, T, F] -> [B, F, T, C]; squeezes to [B, F, T] when C == 1.
        return phase.permute(0, 3, 2, 1).squeeze(-1)
class TSTransformerBlock(nn.Module):
    """Two-stage transformer: one TransformerBlock along time, then one along frequency.

    Each stage adds a residual connection. Input and output layout is
    ``[B, C, T, F]`` with ``C == h.dense_channel``. The output is a permuted
    (non-contiguous) view.
    """

    def __init__(self, h):
        super().__init__()
        self.h = h
        self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
        self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)

    def forward(self, x):
        batch, channels, frames, bins = x.size()

        # Time stage: each frequency bin becomes a sequence of length T.
        seq_t = x.permute(0, 3, 2, 1).contiguous().view(batch * bins, frames, channels)
        seq_t = self.time_transformer(seq_t) + seq_t

        # Frequency stage: each frame becomes a sequence of length F.
        seq_f = seq_t.view(batch, bins, frames, channels).permute(0, 2, 1, 3)
        seq_f = seq_f.contiguous().view(batch * frames, bins, channels)
        seq_f = self.freq_transformer(seq_f) + seq_f

        # Back to [B, C, T, F].
        return seq_f.view(batch, frames, bins, channels).permute(0, 3, 1, 2)
class MPNet(nn.Module):
    """Magnitude/phase speech enhancement network.

    Pipeline: DenseEncoder -> ``num_tsblocks`` TSTransformerBlocks -> parallel
    MaskDecoder (multiplied onto the noisy amplitude) and PhaseDecoder.

    ``forward(noisy_amp, noisy_pha)`` takes two ``[B, F, T]`` tensors. It returns
    ``(denoised_amp, denoised_pha, denoised_com)``, where ``denoised_com`` stacks
    ``amp * cos(pha)`` and ``amp * sin(pha)`` on a new last axis.
    """

    def __init__(self, config: MPNetConfig, num_tsblocks=4):
        super().__init__()
        self.config = config
        self.num_tscblocks = num_tsblocks
        self.dense_encoder = DenseEncoder(config, in_channel=2)
        self.TSTransformer = nn.ModuleList(
            [TSTransformerBlock(config) for _ in range(num_tsblocks)]
        )
        self.mask_decoder = MaskDecoder(config, out_channel=1)
        self.phase_decoder = PhaseDecoder(config, out_channel=1)

    def forward(self, noisy_amp, noisy_pha):
        # Two [B, F, T] tensors -> [B, F, T, 2] -> [B, 2, T, F]
        features = torch.stack((noisy_amp, noisy_pha), dim=-1).permute(0, 3, 2, 1)
        features = self.dense_encoder(features)
        for ts_block in self.TSTransformer:
            features = ts_block(features)

        mask = self.mask_decoder(features)
        denoised_amp = noisy_amp * mask
        denoised_pha = self.phase_decoder(features)

        real = denoised_amp * torch.cos(denoised_pha)
        imag = denoised_amp * torch.sin(denoised_pha)
        denoised_com = torch.stack((real, imag), dim=-1)
        return denoised_amp, denoised_pha, denoised_com
MODEL_FILE = "model.pt"


class MPNetPretrainedModel(MPNet):
    """MPNet with directory-based loading and saving of weights and config.

    A saved model directory holds the state dict in ``MODEL_FILE`` and the
    config (written via ``config.to_yaml_file``) in ``CONFIG_FILE``.
    """

    def __init__(self,
                 config: MPNetConfig,
                 ):
        super().__init__(config=config)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a config and load its weights with ``strict=True``.

        The config is loaded by ``MPNetConfig.from_pretrained`` from the same
        path. If the path is a directory, the weights are read from
        ``<path>/MODEL_FILE``; otherwise the path itself is treated as the
        checkpoint file. Weights are loaded onto CPU with ``weights_only=True``.
        """
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        is_dir = os.path.isdir(pretrained_model_name_or_path)
        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if is_dir
            else pretrained_model_name_or_path
        )
        with open(ckpt_file, "rb") as ckpt:
            weights = torch.load(ckpt, map_location="cpu", weights_only=True)
        model.load_state_dict(weights, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights and the YAML config into ``save_directory`` (created if missing).

        ``state_dict`` defaults to ``self.state_dict()``. Returns ``save_directory``.
        """
        weights = self.state_dict() if state_dict is None else state_dict
        os.makedirs(save_directory, exist_ok=True)

        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory
def phase_losses(phase_r, phase_g):
    """Anti-wrapped phase losses between reference and generated phase tensors.

    Returns ``(ip_loss, gd_loss, iaf_loss)``. Each is the mean of
    ``anti_wrapping_function`` applied to one of three differences:

    * ``ip_loss``: the difference of the phases themselves.
    * ``gd_loss``: the difference of first differences along dim 1.
    * ``iaf_loss``: the difference of first differences along dim 2.

    Assuming ``[B, F, T]`` inputs (the layout MPNet's phase output uses), the
    dim-1 and dim-2 terms correspond to group delay and instantaneous angular
    frequency. Confirm this against callers.
    """
    def wrapped_mean(delta):
        return torch.mean(anti_wrapping_function(delta))

    ip_loss = wrapped_mean(phase_r - phase_g)
    gd_loss = wrapped_mean(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))
    iaf_loss = wrapped_mean(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))
    return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
    """Return ``|x - 2*pi*round(x / (2*pi))|``.

    This is the absolute distance from ``x`` to the nearest multiple of 2*pi,
    so the result lies in ``[0, pi]``.
    """
    period = 2 * np.pi
    turns = torch.round(x / period)
    return torch.abs(x - turns * period)
def pesq_score(utts_r, utts_g, h, n_jobs=30):
    """Average PESQ over paired reference / generated utterances, scored in parallel.

    Args:
        utts_r: reference (clean) utterances. Each element is squeezed, moved to
            CPU and converted to a numpy array before scoring.
        utts_g: generated (enhanced) utterances, paired with ``utts_r`` by position.
        h: config object; only ``h.sampling_rate`` is read.
        n_jobs: number of joblib workers (default 30).

    Returns:
        ``np.mean`` of the per-utterance scores. NOTE: ``eval_pesq`` returns -1
        for any pair it fails to score, and those -1 values are included in the
        mean. An empty input yields ``nan`` (numpy's mean of an empty list).

    Raises:
        ValueError: if ``utts_r`` and ``utts_g`` differ in length.
    """
    if len(utts_r) != len(utts_g):
        raise ValueError(
            f"utts_r and utts_g must have the same length, got {len(utts_r)} and {len(utts_g)}"
        )
    scores = Parallel(n_jobs=n_jobs)(
        delayed(eval_pesq)(
            ref.squeeze().cpu().numpy(),
            est.squeeze().cpu().numpy(),
            h.sampling_rate,
        )
        for ref, est in zip(utts_r, utts_g)
    )
    return np.mean(scores)
def eval_pesq(clean_utt, esti_utt, sr):
    """Compute PESQ for one (clean, estimated) utterance pair.

    Args:
        clean_utt: reference waveform.
        esti_utt: estimated / enhanced waveform.
        sr: sampling rate passed to ``pesq``.

    Returns:
        The score from ``pesq(sr, clean_utt, esti_utt)``, or -1 if that call raises.
    """
    try:
        return pesq(sr, clean_utt, esti_utt)
    except Exception:
        # Deliberate best-effort: one unscorable pair yields the -1 sentinel
        # instead of aborting a whole batch. Catching Exception (not a bare
        # except) lets KeyboardInterrupt / SystemExit propagate.
        return -1
def main():
    """Smoke test: run an untrained MPNet on random audio and print tensor shapes."""
    import torchaudio

    config = MPNetConfig()
    model = MPNet(config=config)
    model.eval()

    # power=None makes Spectrogram return the complex STFT, so the model gets a
    # real magnitude and a real phase rather than the same power spectrogram twice.
    stft = torchaudio.transforms.Spectrogram(
        n_fft=config.n_fft,
        win_length=config.win_size,
        hop_length=config.hop_size,
        window_fn=torch.hamming_window,
        power=None,
    )
    inputs = torch.randn(size=(1, 32000), dtype=torch.float32)
    spec = stft(inputs)  # complex, [B, F, T]
    print(spec.shape)

    noisy_amp = spec.abs()
    noisy_pha = spec.angle()
    # Call the module (not .forward) so hooks run; gradients are not needed here.
    with torch.no_grad():
        denoised_amp, denoised_pha, denoised_com = model(noisy_amp, noisy_pha)
    print(denoised_amp.shape)
    print(denoised_pha.shape)
    print(denoised_com.shape)
    return


if __name__ == '__main__':
    main()