HoneyTian's picture
first commit
bd94e77
#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://arxiv.org/abs/1805.00579
https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement
"""
import torch
import torch.nn as nn
class CausalConvBlock(nn.Module):
    """One encoder stage: Conv2d -> trim last frame -> BatchNorm2d -> ELU.

    The kernel spans 3 frequency bins and 2 time frames. The stride halves the
    frequency axis (F -> (F - 3) // 2 + 1) and leaves the time axis unchanged.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # One frame of zero padding is added on each side of the time axis. The
        # extra trailing output frame is trimmed in forward(), so each remaining
        # output frame t is computed from input frames t-1 and t only.
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            padding=(0, 1),
        )
        self.norm = nn.BatchNorm2d(out_channels)
        self.activation = nn.ELU()

    def forward(self, x):
        """Apply the causal convolution stage.

        Args:
            x: [B, C_in, F, T]
        Returns:
            [B, C_out, (F - 3) // 2 + 1, T]
        """
        conv_out = self.conv(x)
        # Drop the last time frame; it only exists because of the right-side padding.
        trimmed = conv_out[..., :-1]
        return self.activation(self.norm(trimmed))
class CausalTransConvBlock(nn.Module):
    """One decoder stage: ConvTranspose2d -> trim last frame -> BatchNorm2d -> ReLU/ELU.

    Upsamples the frequency axis as F -> (F - 1) * 2 + 3 + output_padding[0].
    The time axis length is unchanged.
    """

    def __init__(self, in_channels, out_channels, is_last=False, output_padding=(0, 0)):
        super().__init__()
        # With a 2-frame time kernel and no padding, output frame t mixes input
        # frames t-1 and t. The single extra trailing frame is trimmed in forward().
        self.conv = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(out_channels)
        # CRN sets is_last on its final decoder stage. ReLU there makes the
        # block's output non-negative; inner stages use ELU.
        self.activation = nn.ReLU() if is_last else nn.ELU()

    def forward(self, x):
        """Apply the causal transposed-convolution stage.

        Args:
            x: [B, C_in, F, T]
        Returns:
            [B, C_out, (F - 1) * 2 + 3 + output_padding[0], T]
        """
        upsampled = self.conv(x)
        # Drop the trailing frame produced by the transposed conv's time kernel.
        upsampled = upsampled[..., :-1]
        normalized = self.norm(upsampled)
        return self.activation(normalized)
class CRN(nn.Module):
    """Convolutional Recurrent Network (CRN) for real-time speech enhancement.

    Structure:
    - Five causal convolutional encoder stages.
    - A 2-layer LSTM running over time.
    - Five causal transposed-convolutional decoder stages. Each decoder stage
      receives its input concatenated (along channels) with the matching
      encoder output.

    Input:  [batch size, 1, n_freq_bins, n_frames]
    Output: [batch size, 1, n_freq_bins, n_frames]

    n_freq_bins is expected to be 161 (e.g. an STFT with n_fft=320).
    - The encoder shrinks the frequency axis 161 -> 80 -> 39 -> 19 -> 9 -> 4.
    - 256 channels * 4 bins = 1024 is the LSTM's fixed input size.
    - The decoder grows the axis back to 161. Stage 4 uses output_padding so
      that 39 -> 80 matches the skip connection.

    With most other frequency sizes, the LSTM input size or a skip-connection
    concatenation will not match. n_frames may be any positive length.
    """

    def __init__(self):
        super().__init__()
        # Encoder: each stage roughly halves the frequency axis.
        self.conv_block_1 = CausalConvBlock(1, 16)
        self.conv_block_2 = CausalConvBlock(16, 32)
        self.conv_block_3 = CausalConvBlock(32, 64)
        self.conv_block_4 = CausalConvBlock(64, 128)
        self.conv_block_5 = CausalConvBlock(128, 256)
        # LSTM over time; each frame's feature vector is the flattened
        # (channels x frequency) encoder output: 256 * 4 = 1024.
        self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)
        # Decoder: input channels are doubled by the skip-connection concatenation.
        self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128)
        self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64)
        self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32)
        self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0))
        self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True)

    def forward(self, x):
        """Run the encoder, LSTM bottleneck, and decoder.

        Args:
            x: [B, 1, 161, T]
        Returns:
            [B, 1, 161, T]. Non-negative, because the last decoder stage applies ReLU.
        Raises:
            RuntimeError: if the encoder output does not flatten to the LSTM's
                input size (the input's frequency axis is not the expected size).
        """
        # Reset the LSTM weight pointers into one contiguous buffer so the
        # faster (cuDNN) code path can be used, e.g. after DataParallel replication.
        self.lstm_layer.flatten_parameters()

        e_1 = self.conv_block_1(x)
        e_2 = self.conv_block_2(e_1)
        e_3 = self.conv_block_3(e_2)
        e_4 = self.conv_block_4(e_3)
        e_5 = self.conv_block_5(e_4)  # [B, 256, 4, T] for a 161-bin input

        batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape
        n_features = n_channels * n_f_bins
        if n_features != self.lstm_layer.input_size:
            # Fail early with an actionable message instead of an opaque LSTM size error.
            raise RuntimeError(
                f"CRN: encoder output [{n_channels} channels x {n_f_bins} bins] flattens to "
                f"{n_features} features, but the LSTM expects {self.lstm_layer.input_size}; "
                f"the input should have 161 frequency bins, got {x.shape[2]}."
            )

        # [B, C, F, T] -> [B, C*F, T] -> [B, T, C*F]: one LSTM step per frame.
        lstm_in = e_5.reshape(batch_size, n_features, n_frame_size).permute(0, 2, 1)
        lstm_out, _ = self.lstm_layer(lstm_in)  # [B, T, 1024]
        # Undo the flattening: [B, T, C*F] -> [B, C*F, T] -> [B, C, F, T].
        lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size)

        d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1))
        d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1))
        d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1))
        d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1))
        d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1))
        return d_5
def main():
    """Smoke test: build a CRN, run a random [2, 1, 161, 200] batch through it, print the output shape."""
    model = CRN()
    # [batch=2, channels=1, freq_bins=161, frames=200]
    dummy_input = torch.rand(2, 1, 161, 200)
    output = model(dummy_input)
    print(output.shape)
if __name__ == '__main__':
    # Run the shape smoke test only when executed as a script, not on import.
    main()