#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from toolbox.torchaudio.models.frcrn import complex_nn


class SELayer(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_r = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
        self.fc_i = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )

    def forward(self, x: torch.Tensor):
        b, c, _, _, _ = x.size()
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)
        y = torch.cat(tensors=[y_r, y_i], dim=4)
        return x * y


class Encoder(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Optional[Union[int, Tuple[int, int]]] = None,
                 use_complex_networks: bool = False,
                 padding_mode: str = "zeros"
                 ):
        super().__init__()
        if padding is None:
            padding = [(k - 1) // 2 for k in kernel_size]   # 'SAME' padding

        if use_complex_networks:
            conv = complex_nn.ComplexConv2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d

        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode
        )
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class Decoder(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]] = (0, 0),
                 use_complex_networks: bool = False,
                 ):
        super().__init__()
        if use_complex_networks:
            tconv = complex_nn.ComplexConvTranspose2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            tconv = nn.ConvTranspose2d
            bn = nn.BatchNorm2d

        self.transconv = tconv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)

    def forward(self, x):
        x = self.transconv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x


class UNetConfig14(object):
    """
    inputs x shape: [1, 1, 321, 2000, 2]
    sample rate: 16000
    nfft: 640
    win_size: 640
    hop_size: 320 (20ms)
    """
    def __init__(self, in_channels: int):
        self.enc_channels = [in_channels, 128, 128, 128, 128, 128, 128, 128]
        self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (2, 2)]
        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]

        self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
        self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), (5, 2), (5, 2)]
        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]


class UNetConfig10(object):
    """
    inputs x shape: [1, 1, 65, 200, 2]
    sample rate: 8000
    nfft: 128
    win_size: 128
    hop_size: 64 (8ms)
    """
    def __init__(self, in_channels: int):
        self.enc_channels = [in_channels, 16, 32, 64, 128, 256]
        self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]

        self.dec_channels = [128, 128, 64, 32, 16, 1]
        self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
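

# Illustrative sketch (hypothetical helper, not part of the original model code):
# the comments in UNet.forward state that the tensor entering the FSMN has a single
# frequency bin ([b, c, 1, t', 2]). The function below applies the standard
# Conv2d / ConvTranspose2d output-size formulas to the frequency axis of UNetConfig10
# to show how the 65 bins collapse to 1 through the encoder and are restored by the
# decoder.
def _frequency_axis_walkthrough_config10():
    config = UNetConfig10(in_channels=1)

    def conv_out(size: int, kernel: int, stride: int, padding: int) -> int:
        # standard Conv2d output-size formula (dilation = 1)
        return (size + 2 * padding - kernel) // stride + 1

    def tconv_out(size: int, kernel: int, stride: int, padding: int) -> int:
        # standard ConvTranspose2d output-size formula (dilation = 1, output_padding = 0)
        return (size - 1) * stride - 2 * padding + kernel

    f = 65   # nfft // 2 + 1 for nfft = 128
    trace = [f]
    for k, s, p in zip(config.enc_kernel_sizes, config.enc_strides, config.enc_paddings):
        f = conv_out(f, k[0], s[0], p[0])
        trace.append(f)
    # encoder: 65 -> 32 -> 15 -> 7 -> 3 -> 1
    for k, s, p in zip(config.dec_kernel_sizes, config.dec_strides, config.dec_paddings):
        f = tconv_out(f, k[0], s[0], p[0])
        trace.append(f)
    # decoder: 1 -> 3 -> 7 -> 15 -> 32 -> 65
    print(trace)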
class UNetConfig20(object):
    """
    inputs x shape: [1, 1, 257, 2000, 2]
    sample rate: 8000
    nfft: 512
    win_size: 512
    hop_size: 256 (32ms)
    """
    def __init__(self, in_channels: int, model_complexity: int):
        self.enc_channels = [
            in_channels,
            model_complexity,
            model_complexity,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            128
        ]
        self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3), (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
        self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]
        self.enc_paddings = [
            (3, 0),
            (0, 3),
            None,   # (0, 2),
            None,
            None,   # (3,1),
            None,   # (3,1),
            None,   # (1,2),
            None,
            None,
            None
        ]

        self.dec_channels = [
            64,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity * 2,
            model_complexity,
            model_complexity,
            1
        ]
        self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3), (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
        self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]
        self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1), (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
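

# Sketch only (an assumption about the intended front end, not code from this project):
# one way to produce the [b, 1, 257, t, 2] input described in the UNetConfig20
# docstring from a raw 8 kHz waveform, using the listed STFT parameters
# (nfft = 512, win_size = 512, hop_size = 256). The function name is hypothetical.
def _waveform_to_complex_spectrum(signal: torch.Tensor,
                                  nfft: int = 512,
                                  win_size: int = 512,
                                  hop_size: int = 256) -> torch.Tensor:
    # signal shape: [b, num_samples]
    spec = torch.stft(
        signal,
        n_fft=nfft,
        hop_length=hop_size,
        win_length=win_size,
        window=torch.hann_window(win_size, device=signal.device),
        return_complex=True,
    )
    # spec shape: [b, nfft // 2 + 1, t], complex-valued
    spec = torch.view_as_real(spec)
    # spec shape: [b, nfft // 2 + 1, t, 2]; add the channel axis expected by UNet
    return spec.unsqueeze(1)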
class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 use_complex_networks: bool = False,
                 model_complexity: int = 45,
                 model_depth: int = 20,
                 padding_mode: str = "zeros"
                 ):
        super().__init__()
        if use_complex_networks:
            model_complexity = int(model_complexity // 1.414)

        # config
        if model_depth == 14:
            config = UNetConfig14(in_channels)
        elif model_depth == 10:
            config = UNetConfig10(in_channels)
        elif model_depth == 20:
            config = UNetConfig20(in_channels, model_complexity)
        else:
            raise AssertionError(f"Unknown model depth : {model_depth}")

        self.model_length = model_depth // 2

        self.fsmn = complex_nn.ComplexUniDeepFsmn(
            config.enc_channels[-1],
            config.enc_channels[-1]
        )

        # go down
        self.encoder_layers = nn.ModuleList(modules=[])
        for i in range(self.model_length):
            encoder_layer = nn.Sequential(
                complex_nn.ComplexUniDeepFsmnL1(
                    config.enc_channels[i],
                    config.enc_channels[i]
                ) if i != 0 else nn.Identity(),
                Encoder(
                    config.enc_channels[i],
                    config.enc_channels[i + 1],
                    kernel_size=config.enc_kernel_sizes[i],
                    stride=config.enc_strides[i],
                    padding=config.enc_paddings[i],
                    use_complex_networks=use_complex_networks,
                    padding_mode=padding_mode
                ),
                SELayer(config.enc_channels[i + 1], reduction=8)
            )
            self.encoder_layers.append(encoder_layer)

        self.decoder_layers = nn.ModuleList(modules=[])
        for i in range(self.model_length):
            decoder_layer = nn.Sequential(
                Decoder(
                    config.dec_channels[i] * 2,
                    config.dec_channels[i + 1],
                    kernel_size=config.dec_kernel_sizes[i],
                    stride=config.dec_strides[i],
                    padding=config.dec_paddings[i],
                    use_complex_networks=use_complex_networks
                ),
                complex_nn.ComplexUniDeepFsmnL1(
                    config.dec_channels[i + 1],
                    config.dec_channels[i + 1]
                ) if i < (self.model_length - 1) else nn.Identity(),
                SELayer(
                    config.dec_channels[i + 1],
                    reduction=8
                ) if i < (self.model_length - 2) else nn.Identity()
            )
            self.decoder_layers.append(decoder_layer)

        if use_complex_networks:
            conv = complex_nn.ComplexConv2d
        else:
            conv = nn.Conv2d
        self.linear = conv(
            in_channels=config.dec_channels[-1],
            out_channels=1,
            kernel_size=1,
        )

    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: torch.Tensor, shape: [b, c, f, t, 2]
        :return:
        """
        x = inputs
        # print(f"inputs: {x.shape}")

        # go down
        xs = list()
        xs_se = list()
        xs_se.append(x)
        for encoder_layer in self.encoder_layers:
            xs.append(x)
            # print(f"x: {x.shape}")
            x = encoder_layer.forward(x)
            # print(f"x: {x.shape}")
            xs_se.append(x)

        # x shape: [b, c, 1, t', 2]
        x = self.fsmn.forward(x)
        # x shape: [b, c, 1, t', 2]
        # print(f"fsmn")

        p = x
        for i, decoder_layers in enumerate(self.decoder_layers):
            p = decoder_layers.forward(p)
            # print(f"p: {p.shape}")
            if i == self.model_length - 1:
                break
            p = torch.cat(tensors=[p, xs_se[self.model_length - 1 - i]], dim=1)

        # cmp_spec: [1, 1, 321, 200, 2]
        # cmp_spec: [1, 1, 513, 200, 2]
        cmp_spec = self.linear.forward(p)
        return cmp_spec


def main():
    # [batch_size, 1, freq_bins, time_steps, 2]
    # x = torch.rand(size=(1, 1, 257, 2000, 2))
    # unet = UNet(
    #     in_channels=1,
    #     model_complexity=45,
    #     model_depth=20,
    #     use_complex_networks=True
    # )
    # print(unet)
    # result = unet.forward(x)
    # print(result.shape)

    x = torch.rand(size=(1, 1, 65, 2000, 2))
    unet = UNet(
        in_channels=1,
        model_complexity=-1,
        model_depth=10,
        use_complex_networks=True
    )
    print(unet)
    result = unet.forward(x)
    print(result.shape)
    return


if __name__ == "__main__":
    main()
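

# Additional sketch (assumption: mirrors the depth-10 example in main(), using the
# input shape from the UNetConfig14 docstring). The expected output shape matches the
# "cmp_spec: [1, 1, 321, 200, 2]" comment in UNet.forward. Not called anywhere; kept
# here for reference only.
def _demo_unet14():
    x = torch.rand(size=(1, 1, 321, 200, 2))
    unet = UNet(
        in_channels=1,
        model_depth=14,
        use_complex_networks=True
    )
    result = unet.forward(x)
    print(result.shape)   # expected: torch.Size([1, 1, 321, 200, 2])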