#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/LXP-Never/TCNN
https://github.com/LXP-Never/TCNN/blob/main/TCNN_model.py
https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement
https://ieeexplore.ieee.org/abstract/document/8683634
Reference sources:
https://github.com/WenzheLiu-Speech/awesome-speech-enhancement
"""
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.nn.common_types import _size_1_t
class Chomp1d(nn.Module):
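    """
    Remove the trailing `chomp_size` timesteps that padding appends.

    With `nn.Conv1d` padding both sides by (kernel_size - 1) * dilation, the
    output is longer than the input; chomping the tail restores the original
    length and makes the convolution causal, so no output frame depends on
    future frames.
    """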
def __init__(self, chomp_size: int):
super(Chomp1d, self).__init__()
self.chomp_size = chomp_size
    def forward(self, x: torch.Tensor):
        # guard the chomp_size == 0 edge case, where x[:, :, :-0] would be empty
        if self.chomp_size == 0:
            return x
        return x[:, :, :-self.chomp_size].contiguous()
class DepthwiseSeparableConv(nn.Module):
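    """
    Depthwise separable 1-D convolution: a depthwise conv (one filter per
    channel, via `groups=in_channels`) followed by a 1x1 pointwise conv that
    mixes channels. This factorization uses far fewer parameters than a full
    Conv1d with the same kernel size. When `causal=True`, the padded tail is
    chomped so no output frame depends on future samples.
    """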
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: _size_1_t,
stride: _size_1_t = 1,
padding: Union[str, _size_1_t] = 0,
dilation: _size_1_t = 1,
causal: bool = False,
):
super(DepthwiseSeparableConv, self).__init__()
# Use `groups` option to implement depthwise convolution
self.depthwise_conv = nn.Conv1d(
in_channels=in_channels, out_channels=in_channels,
kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation,
groups=in_channels,
bias=False,
)
self.chomp1d = Chomp1d(padding) if causal else nn.Identity()
self.prelu = nn.PReLU()
self.norm = nn.BatchNorm1d(in_channels)
self.pointwise_conv = nn.Conv1d(
in_channels=in_channels, out_channels=out_channels,
kernel_size=1,
bias=False,
)
def forward(self, x: torch.Tensor):
# x shape: [b, c, t]
        x = self.depthwise_conv(x)
# x shape: [b, c, t_pad]
x = self.chomp1d(x)
# x shape: [b, c, t]
x = self.prelu(x)
x = self.norm(x)
        x = self.pointwise_conv(x)
return x
class ResBlock(nn.Module):
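    """
    Residual block of the temporal module: a 1x1 bottleneck conv expands
    `in_channels` to `hidden_channels`, then a causal depthwise separable
    conv projects back to `in_channels`, and the block input is added as a
    residual connection. The time dimension is preserved throughout.
    """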
def __init__(self,
in_channels: int,
hidden_channels: int,
kernel_size: _size_1_t,
dilation: _size_1_t = 1,
):
super(ResBlock, self).__init__()
self.conv1d = nn.Conv1d(in_channels=in_channels, out_channels=hidden_channels, kernel_size=1)
self.prelu = nn.PReLU(num_parameters=1)
self.norm = nn.BatchNorm1d(num_features=hidden_channels)
self.sconv = DepthwiseSeparableConv(
in_channels=hidden_channels,
out_channels=in_channels,
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) * dilation,
dilation=dilation,
causal=True,
)
def forward(self, inputs: torch.Tensor):
x = inputs
# x shape: [b, in_channels, t]
        x = self.conv1d(x)
        # x shape: [b, hidden_channels, t]
        x = self.prelu(x)
        x = self.norm(x)
        # x shape: [b, hidden_channels, t]
        x = self.sconv(x)
        # x shape: [b, in_channels, t]
result = x + inputs
return result
class TCNNBlock(nn.Module):
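    """
    Stack of `num_layers` ResBlocks whose dilation grows as
    `init_dilation ** i` (1, 2, 4, ..., 32 with the defaults), so the
    receptive field grows exponentially with depth while the number of
    frames stays unchanged.
    """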
def __init__(self,
in_channels: int,
hidden_channels: int,
kernel_size: int = 3,
init_dilation: int = 2,
num_layers: int = 6
):
super(TCNNBlock, self).__init__()
self.layers = nn.ModuleList(modules=[])
for i in range(num_layers):
dilation_size = init_dilation ** i
self.layers.append(
ResBlock(
in_channels,
hidden_channels,
kernel_size,
dilation=dilation_size,
)
)
def forward(self, x: torch.Tensor):
for layer in self.layers:
# x shape: [b, c, t]
            x = layer(x)
# x shape: [b, c, t]
return x
class TCNN(nn.Module):
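    """
    Time-domain TCNN speech enhancer (Pandey & Wang, ICASSP 2019; see the
    links in the module docstring): a 2-D convolutional encoder over the
    (frame, in-frame sample) axes, three dilated temporal convolution blocks
    on the flattened encoder output, and a transposed-conv decoder with
    encoder skip connections.

    Frames are win_size=320 samples with hop_size=160, i.e. 20 ms windows
    with 10 ms hops at an assumed 16 kHz sample rate.
    """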
def __init__(self):
super(TCNN, self).__init__()
self.win_size = 320
self.hop_size = 160
self.conv2d_1 = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=16, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.conv2d_2 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.conv2d_3 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.conv2d_4 = nn.Sequential(
nn.Conv2d(in_channels=16, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
nn.BatchNorm2d(num_features=32),
nn.PReLU()
)
self.conv2d_5 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
nn.BatchNorm2d(num_features=32),
nn.PReLU()
)
self.conv2d_6 = nn.Sequential(
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
nn.BatchNorm2d(num_features=64),
nn.PReLU()
)
self.conv2d_7 = nn.Sequential(
nn.Conv2d(in_channels=64, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1)),
nn.BatchNorm2d(num_features=64),
nn.PReLU()
)
        # 256 = 64 channels * 4 within-frame positions left after the encoder's stride-2 convs
self.tcnn_block_1 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
self.tcnn_block_2 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
self.tcnn_block_3 = TCNNBlock(in_channels=256, hidden_channels=512, kernel_size=3, init_dilation=2, num_layers=6)
self.dconv2d_7 = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
output_padding=(0, 0)),
nn.BatchNorm2d(num_features=64),
nn.PReLU()
)
self.dconv2d_6 = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
output_padding=(0, 0)),
nn.BatchNorm2d(num_features=32),
nn.PReLU()
)
self.dconv2d_5 = nn.Sequential(
nn.ConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
output_padding=(0, 0)),
nn.BatchNorm2d(num_features=32),
nn.PReLU()
)
self.dconv2d_4 = nn.Sequential(
nn.ConvTranspose2d(in_channels=64, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
output_padding=(0, 0)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.dconv2d_3 = nn.Sequential(
nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 1),
output_padding=(0, 1)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.dconv2d_2 = nn.Sequential(
nn.ConvTranspose2d(in_channels=32, out_channels=16, kernel_size=(3, 5), stride=(1, 2), padding=(1, 2),
output_padding=(0, 1)),
nn.BatchNorm2d(num_features=16),
nn.PReLU()
)
self.dconv2d_1 = nn.Sequential(
nn.ConvTranspose2d(in_channels=32, out_channels=1, kernel_size=(3, 5), stride=(1, 1), padding=(1, 2),
output_padding=(0, 0)),
nn.BatchNorm2d(num_features=1),
nn.PReLU()
)
    def signal_prepare(self, signal: torch.Tensor) -> Tuple[torch.Tensor, int]:
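        """
        Ensure the signal has a channel axis and zero-pad its tail so that
        (n_samples - win_size) is a multiple of hop_size, i.e. the signal
        splits into a whole number of frames. Returns the padded signal and
        the original sample count so the output can be trimmed back.
        """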
if signal.dim() == 2:
signal = torch.unsqueeze(signal, dim=1)
_, _, n_samples = signal.shape
remainder = (n_samples - self.win_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
return signal, n_samples
def forward(self,
noisy: torch.Tensor,
):
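        """
        Enhance a batch of waveforms: frame the signal with `unfold`, denoise
        the framed chunk with the encoder/TCN/decoder, then reconstruct by
        overlap-add and trim to the original length.
        """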
noisy, num_samples = self.signal_prepare(noisy)
batch_size, _, num_samples_pad = noisy.shape
        # n_frame = (num_samples_pad - self.win_size) // self.hop_size + 1
# unfold
# noisy shape: [b, 1, num_samples_pad]
noisy = noisy.unsqueeze(1)
# noisy shape: [b, 1, 1, num_samples_pad]
noisy_frame = torch.nn.functional.unfold(
input=noisy,
kernel_size=(1, self.win_size),
padding=(0, 0),
stride=(1, self.hop_size),
)
# noisy_frame shape: [b, win_size, n_frame]
noisy_frame = noisy_frame.unsqueeze(1)
# noisy_frame shape: [b, 1, win_size, n_frame]
noisy_frame = noisy_frame.permute(0, 1, 3, 2)
# noisy_frame shape: [b, 1, n_frame, win_size]
denoise_frame = self.forward_chunk(noisy_frame)
# denoise_frame shape: [b, c, n_frame, win_size]
denoise_frame = denoise_frame.squeeze(1)
# denoise_frame shape: [b, n_frame, win_size]
denoise = self.denoise_frame_to_denoise(denoise_frame, batch_size, num_samples_pad)
# denoise shape: [b, num_samples_pad]
denoise = denoise[:, :num_samples]
# denoise shape: [b, num_samples]
return denoise
def forward_chunk(self, inputs: torch.Tensor):
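        """
        Run the encoder, the three TCNN blocks, and the skip-connected
        decoder on framed input of shape [b, 1, n_frame, win_size]; the
        output has the same shape as the input.
        """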
# inputs shape: [b, c, t, segment_length]
conv2d_1 = self.conv2d_1(inputs)
conv2d_2 = self.conv2d_2(conv2d_1)
conv2d_3 = self.conv2d_3(conv2d_2)
conv2d_4 = self.conv2d_4(conv2d_3)
conv2d_5 = self.conv2d_5(conv2d_4)
conv2d_6 = self.conv2d_6(conv2d_5)
conv2d_7 = self.conv2d_7(conv2d_6)
# shape: [b, c, t, 4]
reshape_1 = conv2d_7.permute(0, 1, 3, 2)
# shape: [b, c, 4, t]
batch_size, C, frame_len, frame_num = reshape_1.shape
reshape_1 = reshape_1.reshape(batch_size, C * frame_len, frame_num)
# shape: [b, c*4, t]
        tcnn_block_1 = self.tcnn_block_1(reshape_1)
        tcnn_block_2 = self.tcnn_block_2(tcnn_block_1)
        tcnn_block_3 = self.tcnn_block_3(tcnn_block_2)
# shape: [b, c*4, t]
reshape_2 = tcnn_block_3.reshape(batch_size, C, frame_len, frame_num)
reshape_2 = reshape_2.permute(0, 1, 3, 2)
# shape: [b, c, t, 4]
dconv2d_7 = self.dconv2d_7(torch.cat((conv2d_7, reshape_2), dim=1))
dconv2d_6 = self.dconv2d_6(torch.cat((conv2d_6, dconv2d_7), dim=1))
dconv2d_5 = self.dconv2d_5(torch.cat((conv2d_5, dconv2d_6), dim=1))
dconv2d_4 = self.dconv2d_4(torch.cat((conv2d_4, dconv2d_5), dim=1))
dconv2d_3 = self.dconv2d_3(torch.cat((conv2d_3, dconv2d_4), dim=1))
dconv2d_2 = self.dconv2d_2(torch.cat((conv2d_2, dconv2d_3), dim=1))
dconv2d_1 = self.dconv2d_1(torch.cat((conv2d_1, dconv2d_2), dim=1))
return dconv2d_1
def denoise_frame_to_denoise(self, denoise_frame: torch.Tensor, batch_size: int, num_samples: int):
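        """
        Overlap-add reconstruction: accumulate each denoised frame at its
        hop offset and divide by the per-sample overlap count.
        """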
# overlap and add
# https://github.com/HardeyPandya/Temporal-Convolutional-Neural-Network-Single-Channel-Speech-Enhancement/blob/main/TCNN/util/utils.py#L40
        b, t, f = denoise_frame.shape
        if f != self.win_size:
            raise AssertionError(f"expected frame length {self.win_size}, got {f}")
        # allocate buffers on the input's device to avoid a CPU/GPU mismatch
        denoise = torch.zeros(size=(b, num_samples), dtype=denoise_frame.dtype, device=denoise_frame.device)
        count = torch.zeros(size=(b, num_samples), dtype=torch.float32, device=denoise_frame.device)
start = 0
end = start + self.win_size
for i in range(t):
denoise[..., start:end] += denoise_frame[:, i, :]
count[..., start:end] += 1.
start += self.hop_size
end = start + self.win_size
denoise = denoise / count
return denoise
def main():
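    # Smoke test: run one framed chunk through forward_chunk, then a full
    # waveform batch (2 x 16000 samples, i.e. 1 second at the assumed
    # 16 kHz rate) through forward.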
model = TCNN()
x = torch.randn(64, 1, 5, 320)
# x = torch.randn(64, 1, 5, 160)
y = model.forward_chunk(x)
print("output", y.shape)
noisy = torch.randn(size=(2, 16000), dtype=torch.float32)
    denoise = model(noisy)
print(f"denoise.shape: {denoise.shape}")
return
if __name__ == "__main__":
main()