HoneyTian's picture
update
a88ebd1
raw
history blame
9.81 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://huggingface.co/spaces/LeeSangHoon/HierSpeech_TTS/blob/main/denoiser/generator.py
https://huggingface.co/spaces/JacobLinCool/MP-SENet
https://arxiv.org/abs/2305.13686
https://github.com/yxlu-0102/MP-SENet
Note: this architecture most likely cannot be adapted for streaming inference.
"""
import os
from typing import Optional, Union
from pesq import pesq
from joblib import Parallel, delayed
import numpy as np
import torch
import torch.nn as nn
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.conformer import ConformerBlock
from toolbox.torchaudio.models.mpnet.transformers import TransformerBlock
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid2d
class SPConvTranspose2d(nn.Module):
    """Sub-pixel style upsampling along the last axis.

    A stride-1 convolution produces ``out_channels * r`` channels.  Those are
    split into ``r`` groups of ``out_channels`` and interleaved along the last
    axis, so the output has ``out_channels`` channels and a last axis ``r``
    times longer than the convolution output.

    Args:
        in_channels: channels of the input.
        out_channels: channels of the output.
        kernel_size: convolution kernel size.
        r: upsampling factor applied to the last axis.
    """

    def __init__(self, in_channels, out_channels, kernel_size, r=1):
        super(SPConvTranspose2d, self).__init__()
        # (left, right, top, bottom): pad only the last axis, one on each side.
        self.pad1 = nn.ConstantPad2d((1, 1, 0, 0), value=0.)
        self.out_channels = out_channels
        self.conv = nn.Conv2d(
            in_channels,
            out_channels * r,
            kernel_size=kernel_size,
            stride=(1, 1),
        )
        self.r = r

    def forward(self, x):
        conv_out = self.conv(self.pad1(x))
        b, channels, height, width = conv_out.shape
        sub_channels = channels // self.r
        # [b, r*C, H, W] -> [b, r, C, H, W] -> [b, C, H, W, r] -> [b, C, H, W*r]:
        # group k's value for column w lands at column w*r + k.
        grouped = conv_out.view((b, self.r, sub_channels, height, width))
        interleaved = grouped.permute(0, 2, 3, 4, 1).contiguous()
        return interleaved.view((b, sub_channels, height, -1))
class DenseBlock(nn.Module):
    """Stack of densely connected, dilated 2-D convolutions.

    Layer ``i`` (0-based) uses a dilation of ``2 ** i`` on the second-to-last
    axis. Its input is the channel-wise concatenation of every earlier layer's
    output (newest first) followed by the block input, so it consumes
    ``h.dense_channel * (i + 1)`` channels. Only the last layer's output is
    returned.

    Args:
        h: config object; only ``h.dense_channel`` is read here.
        kernel_size: kernel of each convolution.
        depth: number of convolution layers.
    """

    def __init__(self, h, kernel_size=(2, 3), depth=4):
        super(DenseBlock, self).__init__()
        self.h = h
        self.depth = depth
        channels = h.dense_channel
        layers = []
        for idx in range(depth):
            dil = 2 ** idx
            layers.append(
                nn.Sequential(
                    # (left, right, top, bottom): pad the last axis by 1 on each
                    # side and the second-to-last axis only at its start, by the
                    # dilation.  With the default kernel height of 2 this keeps
                    # that axis length and never looks ahead along it.
                    nn.ConstantPad2d((1, 1, dil, 0), value=0.),
                    nn.Conv2d(channels * (idx + 1), channels, kernel_size, dilation=(dil, 1)),
                    nn.InstanceNorm2d(channels, affine=True),
                    nn.PReLU(channels),
                )
            )
        self.dense_block = nn.ModuleList(layers)

    def forward(self, x):
        features = x  # running concatenation fed to the next layer
        out = x       # returned unchanged when depth == 0
        for layer in self.dense_block:
            out = layer(features)
            features = torch.cat([out, features], dim=1)
        return out
class DenseEncoder(nn.Module):
    """Encode the stacked model input into ``h.dense_channel`` feature maps.

    Pipeline: 1x1 conv -> DenseBlock -> (1, 3) conv with stride (1, 2), the
    last of which roughly halves the last axis (the frequency axis, per the
    original shape comments ``[b, C, T, F] -> [b, C, T, F//2]``).

    Args:
        h: config object; ``h.dense_channel`` sets the feature width.
        in_channel: channels of the input tensor.
    """

    def __init__(self, h, in_channel):
        super(DenseEncoder, self).__init__()
        self.h = h
        ch = h.dense_channel
        self.dense_conv_1 = nn.Sequential(
            nn.Conv2d(in_channel, ch, (1, 1)),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )
        self.dense_block = DenseBlock(h, depth=4)
        self.dense_conv_2 = nn.Sequential(
            nn.Conv2d(ch, ch, (1, 3), (1, 2), padding=(0, 1)),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )

    def forward(self, x):
        # [b, in_channel, T, F] -> [b, C, T, F] -> [b, C, T, F] -> [b, C, T, ~F/2]
        for stage in (self.dense_conv_1, self.dense_block, self.dense_conv_2):
            x = stage(x)
        return x
class MaskDecoder(nn.Module):
    """Decode encoder features into the amplitude mask used by ``MPNet``.

    The SPConvTranspose2d with ``r=2`` doubles the last axis, and the final
    (1, 2) convolution (no padding) shortens it by one. The result is laid out
    as ``[B, F, T]`` and passed through ``LearnableSigmoid2d`` (defined in
    ``utils``), which is sized for ``h.n_fft // 2 + 1`` bins.

    Args:
        h: config object; reads ``dense_channel``, ``n_fft`` and ``beta``.
        out_channel: output channels of the last conv; the ``squeeze(-1)`` in
            ``forward`` only removes that axis when this is 1.
    """

    def __init__(self, h, out_channel=1):
        super(MaskDecoder, self).__init__()
        ch = h.dense_channel
        self.dense_block = DenseBlock(h, depth=4)
        self.mask_conv = nn.Sequential(
            SPConvTranspose2d(ch, ch, (1, 3), 2),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
            nn.Conv2d(ch, out_channel, (1, 2)),
        )
        self.lsigmoid = LearnableSigmoid2d(h.n_fft // 2 + 1, beta=h.beta)

    def forward(self, x):
        mask = self.mask_conv(self.dense_block(x))  # presumably [B, out_channel, T, F]
        mask = mask.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
        return self.lsigmoid(mask)
class PhaseDecoder(nn.Module):
    """Decode encoder features into a wrapped phase spectrum.

    Two parallel (1, 2) convolutions predict a pseudo real part and a pseudo
    imaginary part. Their ``atan2`` gives a phase in ``[-pi, pi]``, returned
    as ``[B, F, T]``.

    Args:
        h: config object; reads ``dense_channel``.
        out_channel: output channels of the two heads; the ``squeeze(-1)`` in
            ``forward`` only removes that axis when this is 1.
    """

    def __init__(self, h, out_channel=1):
        super(PhaseDecoder, self).__init__()
        ch = h.dense_channel
        self.dense_block = DenseBlock(h, depth=4)
        self.phase_conv = nn.Sequential(
            SPConvTranspose2d(ch, ch, (1, 3), 2),
            nn.InstanceNorm2d(ch, affine=True),
            nn.PReLU(ch),
        )
        self.phase_conv_r = nn.Conv2d(ch, out_channel, (1, 2))
        self.phase_conv_i = nn.Conv2d(ch, out_channel, (1, 2))

    def forward(self, x):
        feats = self.phase_conv(self.dense_block(x))
        real_part = self.phase_conv_r(feats)
        imag_part = self.phase_conv_i(feats)
        phase = torch.atan2(imag_part, real_part)
        return phase.permute(0, 3, 2, 1).squeeze(-1)  # [B, F, T]
class TSTransformerBlock(nn.Module):
    """Two-stage (time, then frequency) transformer block with residuals.

    Input and output are ``[b, c, t, f]``. First, every frequency bin is
    treated as an independent length-``t`` sequence. Then every frame is
    treated as an independent length-``f`` sequence. Each stage adds its input
    back to the transformer output.

    Args:
        h: config object; ``h.dense_channel`` is the model width.
    """

    def __init__(self, h):
        super(TSTransformerBlock, self).__init__()
        self.h = h
        self.time_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)
        self.freq_transformer = TransformerBlock(d_model=h.dense_channel, n_heads=4)

    def forward(self, x):
        b, c, t, f = x.size()

        # Time pass: one sequence of length t per (batch, frequency bin).
        seq = x.permute(0, 3, 2, 1).reshape(b * f, t, c)
        seq = self.time_transformer(seq) + seq

        # Frequency pass: one sequence of length f per (batch, frame).
        seq = seq.view(b, f, t, c).permute(0, 2, 1, 3).reshape(b * t, f, c)
        seq = self.freq_transformer(seq) + seq

        # Back to [b, c, t, f]; returned as a (non-contiguous) permuted view.
        return seq.view(b, t, f, c).permute(0, 3, 1, 2)
class MPNet(nn.Module):
    """Magnitude/phase enhancement network.

    The noisy amplitude and phase (each ``[B, F, T]``) are stacked into a
    2-channel input, encoded, refined by ``num_tsblocks`` time/frequency
    transformer blocks, and decoded into:

    * an amplitude mask, multiplied onto the noisy amplitude;
    * a phase estimate.

    Args:
        config: model configuration shared by all sub-modules.
        num_tsblocks: number of ``TSTransformerBlock`` layers.
    """

    def __init__(self, config: MPNetConfig, num_tsblocks=4):
        super(MPNet, self).__init__()
        self.num_tscblocks = num_tsblocks
        self.dense_encoder = DenseEncoder(config, in_channel=2)
        self.TSTransformer = nn.ModuleList(
            [TSTransformerBlock(config) for _ in range(num_tsblocks)]
        )
        self.mask_decoder = MaskDecoder(config, out_channel=1)
        self.phase_decoder = PhaseDecoder(config, out_channel=1)

    def forward(self, noisy_amp, noisy_pha):  # [B, F, T]
        """Return ``(denoised_amp, denoised_pha, denoised_com)``.

        ``denoised_com`` stacks ``amp * cos(pha)`` and ``amp * sin(pha)`` on a
        new trailing axis of size 2.
        """
        feats = torch.stack((noisy_amp, noisy_pha), dim=-1)  # [B, F, T, 2]
        feats = feats.permute(0, 3, 2, 1)                     # [B, 2, T, F]
        feats = self.dense_encoder(feats)
        for block in self.TSTransformer:
            feats = block(feats)

        denoised_amp = noisy_amp * self.mask_decoder(feats)
        denoised_pha = self.phase_decoder(feats)

        real = denoised_amp * torch.cos(denoised_pha)
        imag = denoised_amp * torch.sin(denoised_pha)
        denoised_com = torch.stack((real, imag), dim=-1)
        return denoised_amp, denoised_pha, denoised_com
# File name of the serialized state dict inside a pretrained-model directory.
MODEL_FILE: str = "generator.pt"
class MPNetPretrainedModel(MPNet):
    """``MPNet`` with ``from_pretrained`` / ``save_pretrained`` helpers.

    The config is stored on the instance so that ``save_pretrained`` can write
    it next to the weights.
    """

    def __init__(self,
                 config: MPNetConfig,
                 ):
        super(MPNetPretrainedModel, self).__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model and load its weights.

        The config is loaded through ``MPNetConfig.from_pretrained``.

        * If ``pretrained_model_name_or_path`` is a directory, the weights are
          read from ``MODEL_FILE`` inside it.
        * Otherwise the path itself is treated as the checkpoint file.

        NOTE(review): in the file case ``MPNetConfig.from_pretrained`` also
        receives the checkpoint path. Confirm it handles that.
        """
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )
        with open(ckpt_file, "rb") as f:
            # weights_only=True restricts unpickling to tensors and plain containers.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write weights (``MODEL_FILE``) and config (``CONFIG_FILE``) to a directory.

        ``state_dict`` defaults to this model's own ``state_dict()``. The
        directory is created if needed, and its path is returned.
        """
        weights = self.state_dict() if state_dict is None else state_dict
        os.makedirs(save_directory, exist_ok=True)
        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory
def phase_losses(phase_r, phase_g):
    """Return three anti-wrapped phase losses between reference and generated phase.

    * ``ip_loss``: mean wrapped error of the phase itself.
    * ``gd_loss``: mean wrapped error of the first difference along dim 1.
    * ``iaf_loss``: mean wrapped error of the first difference along dim 2.

    Assuming the ``[B, F, T]`` layout produced by ``PhaseDecoder``, dim 1 is
    frequency (group delay) and dim 2 is time (instantaneous angular
    frequency).
    """
    def _wrapped_mean(delta):
        return torch.mean(anti_wrapping_function(delta))

    ip_loss = _wrapped_mean(phase_r - phase_g)
    gd_loss = _wrapped_mean(torch.diff(phase_r, dim=1) - torch.diff(phase_g, dim=1))
    iaf_loss = _wrapped_mean(torch.diff(phase_r, dim=2) - torch.diff(phase_g, dim=2))
    return ip_loss, gd_loss, iaf_loss
def anti_wrapping_function(x):
    """Return ``|x - 2*pi*round(x / (2*pi))|``, element-wise.

    This is the absolute distance from ``x`` to the nearest multiple of
    ``2*pi``. A phase error of ``2*pi*k + d`` therefore costs ``|d|``.
    """
    two_pi = 2 * np.pi
    nearest_turns = torch.round(x / two_pi)
    return torch.abs(x - nearest_turns * two_pi)
# def pesq_score(utts_r, utts_g, h):
#
# pesq_score = Parallel(n_jobs=30)(delayed(eval_pesq)(
# utts_r[i].squeeze().cpu().numpy(),
# utts_g[i].squeeze().cpu().numpy(),
# h.sample_rate, )
# for i in range(len(utts_r)))
# pesq_score = np.mean(pesq_score)
#
# return pesq_score
#
#
# def eval_pesq(clean_utt, esti_utt, sr):
# try:
# mode = "nb" if sr == 8000 else "wb"
# pesq_score = pesq(sr, clean_utt, esti_utt, mode=mode)
# except:
# pesq_score = -1
#
# return pesq_score
def main():
    """Smoke-test ``MPNet`` on random audio and print the tensor shapes."""
    import torchaudio

    config = MPNetConfig()
    model = MPNet(config=config)
    model.eval()

    # power=None makes torchaudio return the complex STFT, so amplitude and
    # phase can be derived separately. The default (power=2.0) returns a power
    # spectrogram, which is neither an amplitude nor a phase.
    transformer = torchaudio.transforms.Spectrogram(
        n_fft=config.n_fft,
        win_length=config.win_size,
        hop_length=config.hop_size,
        window_fn=torch.hamming_window,
        power=None,
    )

    inputs = torch.randn(size=(1, 32000), dtype=torch.float32)
    spec = transformer.forward(inputs)  # complex, [B, F, T]
    print(spec.shape)

    noisy_amp = torch.abs(spec)
    noisy_pha = torch.angle(spec)

    with torch.no_grad():
        denoised_amp, denoised_pha, denoised_com = model.forward(noisy_amp, noisy_pha)
    print(denoised_amp.shape)
    print(denoised_pha.shape)
    print(denoised_com.shape)
if __name__ == "__main__":
    # Running this module directly executes the main() shape smoke test.
    main()