#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/LXP-Never/TCNN
https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py
https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement

https://ieeexplore.ieee.org/abstract/document/8683634

Reference sources:
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
from typing import Tuple, Union

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.common_types import _size_1_t


class Chomp1d(nn.Module):
    """Trim the trailing `chomp_size` frames so a padded convolution stays causal."""
    def __init__(self, chomp_size: int):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x: torch.Tensor):
        if self.chomp_size == 0:
            # `x[:, :, :-0]` would return an empty tensor, so guard explicitly.
            return x
        return x[:, :, :-self.chomp_size].contiguous()


class DepthwiseSeparableConv(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: _size_1_t,
                 stride: _size_1_t = 1,
                 padding: Union[str, _size_1_t] = 0,
                 dilation: _size_1_t = 1,
                 causal: bool = False,
                 ):
        super(DepthwiseSeparableConv, self).__init__()

        # Use the `groups` option to implement depthwise convolution.
        self.depthwise_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        self.chomp1d = Chomp1d(padding) if causal else nn.Identity()
        self.prelu = nn.PReLU()
        self.norm = nn.BatchNorm1d(in_channels)
        self.pointwise_conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x: torch.Tensor):
        # x shape: [b, c, t]
        x = self.depthwise_conv(x)
        # x shape: [b, c, t_pad]
        x = self.chomp1d(x)
        # x shape: [b, c, t]
        x = self.prelu(x)
        x = self.norm(x)
        x = self.pointwise_conv(x)
        return x


class ResBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: _size_1_t,
                 dilation: _size_1_t = 1,
                 ):
        super(ResBlock, self).__init__()
        self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1)
        self.prelu = nn.PReLU(num_parameters=1)
        self.norm = nn.BatchNorm1d(num_features=hidden_channels)
        self.sconv = DepthwiseSeparableConv(
            in_channels=hidden_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=1,
            padding=(kernel_size - 1) * dilation,
            dilation=dilation,
            causal=True,
        )

    def forward(self, inputs: torch.Tensor):
        x = inputs
        # x shape: [b, in_channels, t]
        x = self.conv1d(x)
        # x shape: [b, hidden_channels, t]
        x = self.prelu(x)
        x = self.norm(x)
        # x shape: [b, hidden_channels, t]
        x = self.sconv(x)
        # x shape: [b, in_channels, t]
        result = x + inputs
        return result


class TCNNBlock(nn.Module):
    def __init__(self,
                 in_channels: int,
                 hidden_channels: int,
                 kernel_size: int = 3,
                 init_dilation: int = 2,
                 num_layers: int = 6
                 ):
        super(TCNNBlock, self).__init__()
        self.layers = nn.ModuleList(modules=[])
        for i in range(num_layers):
            dilation_size = init_dilation ** i
            # `in_channels` stays constant across layers because each ResBlock is residual.
            self.layers.append(
                ResBlock(
                    in_channels,
                    hidden_channels,
                    kernel_size,
                    dilation=dilation_size,
                )
            )

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            # x shape: [b, c, t]
            x = layer(x)
            # x shape: [b, c, t]
        return x
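# Note (an added observation, not from the referenced repositories): with
# kernel_size=3 and dilations 1, 2, 4, 8, 16, 32, one TCNNBlock has a causal
# receptive field of 1 + 2 * (1 + 2 + 4 + 8 + 16 + 32) = 127 frames; the three
# stacked blocks used below reach 1 + 3 * 126 = 379 frames, roughly 3.8 s of
# left context at a 160-sample hop, assuming 16 kHz audio.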
class TCNN(nn.Module):
    def __init__(self):
        super(TCNN, self).__init__()
        self.win_size = 320
        self.hop_size = 160

        self.conv2d_1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=16,
                      kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_3 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=16,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.conv2d_4 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.conv2d_5 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=32,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.conv2d_6 = nn.Sequential(
            nn.Conv2d(in_channels=32, out_channels=64,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )
        self.conv2d_7 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=64,
                      kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )

        # 256 = 64 channels * 4 remaining frequency bins after the encoder.
        self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
        self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
        self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)

        self.dconv2d_7 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=64,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=64),
            nn.PReLU()
        )
        self.dconv2d_6 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=32,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.dconv2d_5 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=32,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=32),
            nn.PReLU()
        )
        self.dconv2d_4 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=64, out_channels=16,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_3 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 1), output_padding=(0, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_2 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=16,
                               kernel_size=(3, 5), stride=(1, 2), padding=(1, 2), output_padding=(0, 1)),
            nn.BatchNorm2d(num_features=16),
            nn.PReLU()
        )
        self.dconv2d_1 = nn.Sequential(
            nn.ConvTranspose2d(in_channels=32, out_channels=1,
                               kernel_size=(3, 5), stride=(1, 1), padding=(1, 2), output_padding=(0, 0)),
            nn.BatchNorm2d(num_features=1),
            nn.PReLU()
        )

    def signal_prepare(self, signal: torch.Tensor) -> Tuple[torch.Tensor, int]:
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        # Zero-pad so that (n_samples - win_size) is a multiple of hop_size, e.g.
        # 16001 samples with win_size=320 / hop_size=160 get 159 samples of padding.
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal, n_samples
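    def denoise_frame_to_denoise_fold(self, denoise_frame: torch.Tensor, num_samples: int) -> torch.Tensor:
        # A vectorized overlap-add (a sketch added here, not part of the referenced
        # repositories): F.fold is the exact inverse of the F.unfold framing used in
        # `forward`, so the Python loop in `denoise_frame_to_denoise` collapses into
        # two fold calls, one summing the frames and one counting the overlaps.
        # denoise_frame shape: [b, n_frame, win_size]
        frames = denoise_frame.permute(0, 2, 1)
        # frames shape: [b, win_size, n_frame]
        denoise = F.fold(frames, output_size=(1, num_samples),
                         kernel_size=(1, self.win_size), stride=(1, self.hop_size))
        count = F.fold(torch.ones_like(frames), output_size=(1, num_samples),
                       kernel_size=(1, self.win_size), stride=(1, self.hop_size))
        # denoise, count shape: [b, 1, 1, num_samples]
        return (denoise / count).reshape(frames.shape[0], num_samples)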
    def forward(self,
                noisy: torch.Tensor,
                ):
        noisy, num_samples = self.signal_prepare(noisy)
        batch_size, _, num_samples_pad = noisy.shape
        # n_frame = (num_samples_pad - self.win_size) // self.hop_size + 1

        # unfold
        # noisy shape: [b, 1, num_samples_pad]
        noisy = noisy.unsqueeze(1)
        # noisy shape: [b, 1, 1, num_samples_pad]
        noisy_frame = F.unfold(
            input=noisy,
            kernel_size=(1, self.win_size),
            padding=(0, 0),
            stride=(1, self.hop_size),
        )
        # noisy_frame shape: [b, win_size, n_frame]
        noisy_frame = noisy_frame.unsqueeze(1)
        # noisy_frame shape: [b, 1, win_size, n_frame]
        noisy_frame = noisy_frame.permute(0, 1, 3, 2)
        # noisy_frame shape: [b, 1, n_frame, win_size]

        denoise_frame = self.forward_chunk(noisy_frame)
        # denoise_frame shape: [b, 1, n_frame, win_size]
        denoise_frame = denoise_frame.squeeze(1)
        # denoise_frame shape: [b, n_frame, win_size]

        denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
        # denoise shape: [b, num_samples_pad]
        denoise = denoise[:, :num_samples]
        # denoise shape: [b, num_samples]
        return denoise

    def forward_chunk(self, inputs: torch.Tensor):
        # inputs shape: [b, c, t, segment_length]
        conv2d_1 = self.conv2d_1(inputs)
        conv2d_2 = self.conv2d_2(conv2d_1)
        conv2d_3 = self.conv2d_3(conv2d_2)
        conv2d_4 = self.conv2d_4(conv2d_3)
        conv2d_5 = self.conv2d_5(conv2d_4)
        conv2d_6 = self.conv2d_6(conv2d_5)
        conv2d_7 = self.conv2d_7(conv2d_6)
        # shape: [b, c, t, 4]
        reshape_1 = conv2d_7.permute(0, 1, 3, 2)
        # shape: [b, c, 4, t]
        batch_size, C, frame_len, frame_num = reshape_1.shape
        reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num)
        # shape: [b, c*4, t]
        tcnn_block_1 = self.tcnn_block_1(reshape_1)
        tcnn_block_2 = self.tcnn_block_2(tcnn_block_1)
        tcnn_block_3 = self.tcnn_block_3(tcnn_block_2)
        # shape: [b, c*4, t]
        reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num)
        reshape_2 = reshape_2.permute(0, 1, 3, 2)
        # shape: [b, c, t, 4]
        dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1))
        dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, dconv2d_7), dim=1))
        dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1))
        dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1))
        dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1))
        dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1))
        dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1))
        return dconv2d_1

    def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
        # overlap-add
        # https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40
        b, t, f = denoise_frame.shape
        if f != self.win_size:
            raise AssertionError(f"expected frame length {self.win_size}, got {f}")
        denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype, device=denoise_frame.device)
        count = torch.zeros(size=(b, num_samples), dtype=torch.float32, device=denoise_frame.device)

        start = 0
        end = start + self.win_size
        for i in range(t):
            denoise[..., start:end] += denoise_frame[:, i, :]
            count[..., start:end] += 1.
            start += self.hop_size
            end = start + self.win_size

        denoise = denoise / count
        return denoise


def main():
    model = TCNN()

    x = torch.randn(64, 1, 5, 320)
    # x = torch.randn(64, 1, 5, 160)
    y = model.forward_chunk(x)
    print("output", y.shape)

    noisy = torch.randn(size=(2, 16000), dtype=torch.float32)
    denoise = model(noisy)
    print(f"denoise.shape: {denoise.shape}")
    return


if __name__ == "__main__":
    main()
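# Expected shapes when running this file directly (derived by tracing the layer
# arithmetic above, not captured output): `forward_chunk` returns the input shape
# [64, 1, 5, 320] unchanged because the decoder mirrors the encoder, and `forward`
# returns [2, 16000], since (16000 - 320) // 160 + 1 = 99 frames overlap-add back
# to exactly 16000 samples.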