#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Convolutional recurrent network (CRN) for speech enhancement.

https://arxiv.org/abs/1805.00579
https://github.com/haoxiangsnr/A-Convolutional-Recurrent-Neural-Network-for-Real-Time-Speech-Enhancement
"""
from typing import Tuple

import torch
import torch.nn as nn


class CausalConvBlock(nn.Module):
    """Encoder block: strided causal Conv2d -> BatchNorm2d -> ELU."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Kernel/stride/padding are given as (frequency, time).  The time axis
        # is padded by one frame on both sides; the extra trailing frame is
        # dropped in forward() so the convolution never sees future frames.
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            padding=(0, 1),
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        self.activation = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        2D causal convolution.

        Args:
            x: [B, C_in, F, T]

        Returns:
            [B, C_out, (F - 3) // 2 + 1, T]
        """
        x = self.conv(x)
        # Chomp: the symmetric time padding yields T + 1 frames; the last one
        # is the only frame computed from right-hand padding, so remove it.
        x = x[:, :, :, :-1]
        x = self.norm(x)
        x = self.activation(x)
        return x


class CausalTransConvBlock(nn.Module):
    """Decoder block: strided causal ConvTranspose2d -> BatchNorm2d -> ELU/ReLU."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        is_last: bool = False,
        output_padding: Tuple[int, int] = (0, 0),
    ):
        """
        Args:
            in_channels: number of input channels.
            out_channels: number of output channels.
            is_last: use ReLU instead of ELU as the activation.
            output_padding: extra (frequency, time) size added to the output
                of the transposed convolution.
        """
        super().__init__()
        self.conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(3, 2),
            stride=(2, 1),
            output_padding=output_padding,
        )
        self.norm = nn.BatchNorm2d(num_features=out_channels)
        if is_last:
            self.activation = nn.ReLU()
        else:
            self.activation = nn.ELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        2D causal transposed convolution.

        Args:
            x: [B, C_in, F, T]

        Returns:
            [B, C_out, 2 * (F - 1) + 3 + output_padding[0], T]
            (for the default output_padding[1] == 0)
        """
        x = self.conv(x)
        # Chomp: the time kernel of size 2 yields T + 1 frames; drop the last.
        x = x[:, :, :, :-1]
        x = self.norm(x)
        x = self.activation(x)
        return x


class CRN(nn.Module):
    """
    Encoder / LSTM / decoder network with skip connections.

    Input: [batch size, channels=1, F, T]
    Output: [batch size, channels=1, F, T]

    The LSTM input size (1024) equals 256 encoder channels times 4 frequency
    bins, which is what the five encoder blocks produce for F = 161 (the size
    used in main()).
    """

    def __init__(self):
        super().__init__()
        # Encoder
        self.conv_block_1 = CausalConvBlock(1, 16)
        self.conv_block_2 = CausalConvBlock(16, 32)
        self.conv_block_3 = CausalConvBlock(32, 64)
        self.conv_block_4 = CausalConvBlock(64, 128)
        self.conv_block_5 = CausalConvBlock(128, 256)

        # LSTM
        self.lstm_layer = nn.LSTM(input_size=1024, hidden_size=1024, num_layers=2, batch_first=True)

        # Decoder: input channels are doubled because every block receives the
        # previous output concatenated with the matching encoder output.
        self.tran_conv_block_1 = CausalTransConvBlock(256 + 256, 128)
        self.tran_conv_block_2 = CausalTransConvBlock(128 + 128, 64)
        self.tran_conv_block_3 = CausalTransConvBlock(64 + 64, 32)
        # output_padding restores the even frequency size (80 for F = 161)
        # needed to concatenate with the first encoder output.
        self.tran_conv_block_4 = CausalTransConvBlock(32 + 32, 16, output_padding=(1, 0))
        self.tran_conv_block_5 = CausalTransConvBlock(16 + 16, 1, is_last=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, 1, F, T]

        Returns:
            [B, 1, F, T]

        Shape comments below are for a [2, 1, 161, 200] input.
        """
        self.lstm_layer.flatten_parameters()

        e_1 = self.conv_block_1(x)
        e_2 = self.conv_block_2(e_1)
        e_3 = self.conv_block_3(e_2)
        e_4 = self.conv_block_4(e_3)
        e_5 = self.conv_block_5(e_4)  # [2, 256, 4, 200]

        batch_size, n_channels, n_f_bins, n_frame_size = e_5.shape

        # [2, 256, 4, 200] = [2, 1024, 200] => [2, 200, 1024]
        lstm_in = e_5.reshape(batch_size, n_channels * n_f_bins, n_frame_size).permute(0, 2, 1)
        lstm_out, _ = self.lstm_layer(lstm_in)  # [2, 200, 1024]
        # [2, 200, 1024] => [2, 1024, 200] = [2, 256, 4, 200]
        lstm_out = lstm_out.permute(0, 2, 1).reshape(batch_size, n_channels, n_f_bins, n_frame_size)

        d_1 = self.tran_conv_block_1(torch.cat((lstm_out, e_5), 1))
        d_2 = self.tran_conv_block_2(torch.cat((d_1, e_4), 1))
        d_3 = self.tran_conv_block_3(torch.cat((d_2, e_3), 1))
        d_4 = self.tran_conv_block_4(torch.cat((d_3, e_2), 1))
        d_5 = self.tran_conv_block_5(torch.cat((d_4, e_1), 1))

        return d_5


def main():
    """Smoke test: run a random [2, 1, 161, 200] batch and print the output shape."""
    layer = CRN()
    a = torch.rand(2, 1, 161, 200)
    print(layer(a).shape)


if __name__ == '__main__':
    main()