#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/AkenoSyuRi/DTLNPytorch
https://github.com/breizhn/DTLN
Trained on 500 hours of DNS3 data; reaches PESQ 3.04 on the DNS3 test set.
"""
import os
from typing import Optional, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.models.dtln.configuration_dtln import DTLNConfig
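
# DTLN processes audio in two cascaded stages: stage 1 estimates a magnitude
# mask in the STFT domain; stage 2 re-encodes the masked time-domain frames
# with a learned 1x1-conv basis and masks those features before decoding back
# to time-domain frames (see DTLNModel.forward_chunk below).
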
class InstantLayerNormalization(nn.Module):
"""
Class implementing instant layer normalization. It can also be called
channel-wise layer normalization and was proposed by
Luo & Mesgarani (https://arxiv.org/abs/1809.07454v2)
"""
def __init__(self, channels):
super(InstantLayerNormalization, self).__init__()
self.epsilon = 1e-7
self.gamma = nn.Parameter(torch.ones(1, 1, channels), requires_grad=True)
self.beta = nn.Parameter(torch.zeros(1, 1, channels), requires_grad=True)
self.register_parameter("gamma", self.gamma)
self.register_parameter("beta", self.beta)
def forward(self, inputs: torch.Tensor):
# calculate mean of each frame
mean = torch.mean(inputs, dim=-1, keepdim=True)
# calculate variance of each frame
variance = torch.mean(torch.square(inputs - mean), dim=-1, keepdim=True)
# calculate standard deviation
std = torch.sqrt(variance + self.epsilon)
outputs = (inputs - mean) / std
# scale with gamma
outputs = outputs * self.gamma
# add the bias beta
outputs = outputs + self.beta
# return output
return outputs
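
# A minimal usage sketch for InstantLayerNormalization (shapes are assumptions
# based on how the module is used below, not part of the original source):
#   iln = InstantLayerNormalization(channels=256)
#   x = torch.randn(2, 100, 256)   # [b, t, c]
#   y = iln.forward(x)             # same shape; each frame has ~zero mean
#                                  # and ~unit std over the channel dim
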
class SeparationBlock(nn.Module):
def __init__(self,
input_size: int = 257,
hidden_size: int = 128,
dropout: float = 0.25,
):
        super(SeparationBlock, self).__init__()
self.rnn1 = nn.LSTM(input_size=input_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
dropout=0.0,
bidirectional=False,
)
self.rnn2 = nn.LSTM(input_size=hidden_size,
hidden_size=hidden_size,
num_layers=1,
batch_first=True,
dropout=0.0,
bidirectional=False,
)
self.drop = nn.Dropout(dropout)
self.dense = nn.Linear(hidden_size, input_size)
self.sigmoid = nn.Sigmoid()
    def forward(self, x: torch.Tensor, in_states: Optional[torch.Tensor] = None):
        if in_states is None:
            hx1 = None
            hx2 = None
        else:
            # in_states shape: [2, b, hidden_size, 2]; the last dim stacks (h, c)
            h1_in, c1_in = in_states[:1, :, :, 0], in_states[:1, :, :, 1]
            h2_in, c2_in = in_states[1:, :, :, 0], in_states[1:, :, :, 1]
            # LSTM requires contiguous hidden-state tensors
            hx1 = (h1_in.contiguous(), c1_in.contiguous())
            hx2 = (h2_in.contiguous(), c2_in.contiguous())
x1, (h1, c1) = self.rnn1.forward(x, hx=hx1)
x1 = self.drop(x1)
x2, (h2, c2) = self.rnn2.forward(x1, hx=hx2)
x2 = self.drop(x2)
mask = self.dense(x2)
mask = self.sigmoid(mask)
h = torch.cat((h1, h2), dim=0)
c = torch.cat((c1, c2), dim=0)
out_states = torch.stack((h, c), dim=-1)
return mask, out_states
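
# A hedged streaming sketch for SeparationBlock (names are illustrative only):
#   sep = SeparationBlock(input_size=257, hidden_size=128)
#   mask, states = sep.forward(torch.randn(1, 1, 257))          # first frame
#   mask, states = sep.forward(torch.randn(1, 1, 257), states)  # next frame
# states shape: [num_layers=2, b, hidden_size, 2]
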
MODEL_FILE = "model.pt"
class DTLNModel(nn.Module):
def __init__(self,
fft_size: int = 512,
hop_size: int = 128,
win_type: str = "hamming",
encoder_size: int = 256,
):
super(DTLNModel, self).__init__()
self.fft_size = fft_size
self.hop_size = hop_size
self.encoder_size = encoder_size
self.stft = ConvSTFT(
nfft=fft_size,
win_size=fft_size,
hop_size=hop_size,
win_type=win_type,
power=None,
requires_grad=False
)
self.istft = ConviSTFT(
nfft=fft_size,
win_size=fft_size,
hop_size=hop_size,
win_type=win_type,
requires_grad=False
)
        self.sep1 = SeparationBlock(input_size=(fft_size // 2 + 1),
                                    hidden_size=128,
                                    dropout=0.25,
                                    )
self.encoder_conv1 = nn.Conv1d(in_channels=fft_size,
out_channels=self.encoder_size,
kernel_size=1,
stride=1,
bias=False,
)
# self.encoder_norm1 = nn.InstanceNorm1d(num_features=self.encoder_size, eps=1e-7, affine=True)
self.encoder_norm1 = InstantLayerNormalization(channels=self.encoder_size)
        self.sep2 = SeparationBlock(input_size=self.encoder_size,
                                    hidden_size=128,
                                    dropout=0.25,
                                    )
self.decoder_conv1 = nn.Conv1d(in_channels=self.encoder_size,
out_channels=fft_size,
kernel_size=1,
stride=1,
bias=False,
)
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
if signal.dim() == 2:
signal = torch.unsqueeze(signal, dim=1)
_, _, n_samples = signal.shape
remainder = (n_samples - self.fft_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
return signal
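
    # Worked example of the padding arithmetic (values follow the defaults):
    # with fft_size=512, hop_size=128 and n_samples=16000,
    # remainder = (16000 - 512) % 128 = 0, so no padding is added and the
    # signal yields t = (16000 - 512) // 128 + 1 = 122 frames.
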
def forward(self,
noisy: torch.Tensor,
):
num_samples = noisy.shape[-1]
noisy = self.signal_prepare(noisy)
batch_size, _, num_samples_pad = noisy.shape
# print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
denoise_frame, _, _ = self.forward_chunk(noisy)
denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
# denoise shape: [b, num_samples_pad]
denoise = denoise[:, :num_samples]
# denoise shape: [b, num_samples]
denoise = torch.unsqueeze(denoise, dim=1)
# denoise shape: [b, 1, num_samples]
return denoise
    def forward_chunk(self,
                      noisy: torch.Tensor,
                      in_state1: Optional[torch.Tensor] = None,
                      in_state2: Optional[torch.Tensor] = None,
                      ):
# noisy shape: [b, 1, num_samples]
spec = self.stft.forward(noisy)
# spec shape: [b, f, t], torch.complex64
# t = (num_samples - win_size) / hop_size + 1
spec = torch.view_as_real(spec)
# spec shape: [b, f, t, 2]
real = spec[..., 0]
imag = spec[..., 1]
mag = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag, real)
# shape: [b, f, t]
mag = mag.permute(0, 2, 1)
phase = phase.permute(0, 2, 1)
# shape: [b, t, f]
mask, out_state1 = self.sep1.forward(mag, in_state1)
# mask shape: [b, t, f]
estimated_mag = mask * mag
        s1_stft = estimated_mag * torch.exp(1j * phase)
        # s1_stft shape: [b, t, f], torch.complex64
        # per-frame inverse rfft along the frequency dim back to the time domain
        y1 = torch.fft.irfft(s1_stft, n=self.fft_size, dim=-1)
        # y1 shape: [b, t, fft_size], torch.float32
y1 = y1.permute(0, 2, 1)
# y1 shape: [b, fft_size, t]
encoded_f = self.encoder_conv1.forward(y1)
# shape: [b, c, t]
encoded_f = encoded_f.permute(0, 2, 1)
# shape: [b, t, c]
encoded_f_norm = self.encoder_norm1.forward(encoded_f)
# shape: [b, t, c]
mask_2, out_state2 = self.sep2.forward(encoded_f_norm, in_state2)
# shape: [b, t, c]
estimated = mask_2 * encoded_f
estimated = estimated.permute(0, 2, 1)
# shape: [b, c, t]
denoise_frame = self.decoder_conv1.forward(estimated)
# shape: [b, fft_size, t]
return denoise_frame, out_state1, out_state2
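
    # forward_chunk also supports streaming: pass one frame of fft_size samples
    # at a time together with the previous states, e.g. (hedged sketch, names
    # are illustrative):
    #   frame = noisy[:, :, :512]  # [b, 1, fft_size]
    #   out, s1, s2 = model.forward_chunk(frame)
    #   out, s1, s2 = model.forward_chunk(next_frame, s1, s2)
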
def forward_chunk_by_chunk(self, noisy: torch.Tensor):
noisy = self.signal_prepare(noisy)
# noisy shape: [b, 1, num_samples]
batch_size, _, num_samples_pad = noisy.shape
# print(f"num_samples: {num_samples}, num_samples_pad: {num_samples_pad}")
t = (num_samples_pad - self.fft_size) // self.hop_size + 1
denoise_list = list()
out_state1 = None
out_state2 = None
overlap_size = self.fft_size - self.hop_size
        denoise_cache = torch.zeros(size=(batch_size, overlap_size), dtype=noisy.dtype, device=noisy.device)
# denoise_list.append(torch.clone(denoise_cache))
for i in range(t):
begin = i * self.hop_size
end = begin + self.fft_size
sub_noisy = noisy[:, :, begin: end]
# noisy shape: [b, 1, frame_size]
with torch.no_grad():
sub_denoise_frame, out_state1, out_state2 = self.forward_chunk(sub_noisy, out_state1, out_state2)
# sub_denoise_frame shape: [b, fft_size, 1]
sub_denoise_frame = sub_denoise_frame[:, :, 0]
# sub_denoise_frame shape: [b, fft_size]
sub_denoise_frame[:, :overlap_size] += denoise_cache
denoise_out = sub_denoise_frame[:, :self.hop_size]
denoise_cache = sub_denoise_frame[:, self.hop_size:]
# denoise_cache shape: [b, hop_size]
denoise_list.append(denoise_out)
        denoise = torch.cat(denoise_list, dim=-1)
        # denoise shape: [b, t * hop_size]
denoise = torch.unsqueeze(denoise, dim=1)
# denoise shape: [b, 1, num_samples]
return denoise
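
    # Note: the chunked output contains t * hop_size samples, i.e. it is
    # (fft_size - hop_size) samples shorter than the batched forward() output;
    # the tail still sitting in denoise_cache at loop exit is dropped.
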
def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
# overlap and add
# denoise_frame shape: [b, fft_size, t]
denoise = torch.nn.functional.fold(
denoise_frame,
output_size=(num_samples, 1),
kernel_size=(self.fft_size, 1),
padding=(0, 0),
stride=(self.hop_size, 1),
)
# denoise shape: [b, 1, num_samples, 1]
denoise = denoise.reshape(batch_size, -1)
# denoise shape: [b, num_samples]
return denoise
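
    # F.fold with kernel (fft_size, 1) and stride (hop_size, 1) sums the
    # overlapping frame regions, which is the same overlap-add performed by
    # the loop in forward_chunk_by_chunk, done in one batched call.
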
class DTLNPretrainedModel(DTLNModel):
def __init__(self,
config: DTLNConfig,
):
super(DTLNPretrainedModel, self).__init__(
fft_size=config.fft_size,
hop_size=config.hop_size,
win_type=config.win_type,
encoder_size=config.encoder_size,
)
self.config = config
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DTLNConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model
def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory
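
# A hedged save/load round trip (the directory path is illustrative):
#   model = DTLNPretrainedModel(DTLNConfig())
#   model.save_pretrained("./dtln-checkpoint")  # writes model.pt and the config
#   model = DTLNPretrainedModel.from_pretrained("./dtln-checkpoint")
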
def main():
config = DTLNConfig()
model = DTLNPretrainedModel(config)
model.eval()
noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
with torch.no_grad():
denoise = model.forward(noisy)
print(f"denoise.shape: {denoise.shape}")
print(denoise[:, :, 300: 302])
print(denoise[:, :, 15680: 15682])
print(denoise[:, :, 15760: 15762])
print(denoise[:, :, 15840: 15842])
denoise = model.forward_chunk_by_chunk(noisy)
print(f"denoise.shape: {denoise.shape}")
# denoise = denoise[:, :, (config.fft_size - config.hop_size):]
print(denoise[:, :, 300: 302])
print(denoise[:, :, 15680: 15682])
print(denoise[:, :, 15760: 15762])
print(denoise[:, :, 15840: 15842])
return
if __name__ == "__main__":
main()