#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Complex-valued layers built from pairs of real-valued layers.

Complex tensors are represented as real tensors whose last axis has size 2:
index 0 holds the real part and index 1 holds the imaginary part.

Except for ``ComplexBatchNorm2d``, every layer here combines a "real" and an
"imaginary" sub-layer following the complex multiplication pattern::

    out_re = f_re(x_re) - f_im(x_im)
    out_im = f_re(x_im) + f_im(x_re)
"""
from typing import Union, Tuple

import torch
import torch.nn as nn

from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn


class ComplexUniDeepFsmn(nn.Module):
    """Two stacked complex FSMN layers applied along the time axis.

    The channel and frequency axes are flattened together and passed to
    ``UniDeepFsmn`` as the last axis, so ``input_dim`` must equal ``c * h`` of
    the input. The final reshape back to the input shape also requires that
    ``UniDeepFsmn`` returns a tensor with the same shape it was given.
    """

    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        """
        :param input_dim: int, size of the last axis fed to each ``UniDeepFsmn`` (``c * h``).
        :param hidden_size: int, forwarded to ``UniDeepFsmn``.
        :param lorder: int, forwarded to ``UniDeepFsmn`` as ``lorder``.
        """
        super().__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        # Merge channel and frequency axes: h' = c * h.
        x = torch.reshape(x, shape=(b, c * h, t, d))
        # x shape: [b, h', t, 2]
        x = torch.transpose(x, dim0=1, dim1=2)
        # x shape: [b, t, h', 2]

        # First complex layer.
        real_l1 = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary_l1 = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real_l1, imaginary_l1 shape: [b, t, h']

        # Second complex layer, fed with the output of the first one.
        real = self.fsmn_re_l2(real_l1) - self.fsmn_im_l2(imaginary_l1)
        imaginary = self.fsmn_re_l2(imaginary_l1) + self.fsmn_im_l2(real_l1)
        # real, imaginary shape: [b, t, h']

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b, t, h', 2]

        # Undo the transpose and the channel/frequency merge.
        output = torch.transpose(output, dim0=1, dim1=2)
        # output shape: [b, h', t, 2]
        output = torch.reshape(output, shape=(b, c, h, t, d))
        # output shape: [b, c, h, t, 2]
        return output


class ComplexUniDeepFsmnL1(nn.Module):
    """Single complex FSMN layer applied independently to every time frame.

    Batch and time are folded into one leading axis, so each ``UniDeepFsmn``
    receives tensors of shape ``[b * t, h, c]``; ``input_dim`` must therefore
    equal the channel count ``c``.
    NOTE(review): presumably ``UniDeepFsmn`` treats axis 1 (here ``h``) as the
    sequence axis - confirm against its implementation.
    """

    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        """
        :param input_dim: int, size of the last axis fed to each ``UniDeepFsmn`` (``c``).
        :param hidden_size: int, forwarded to ``UniDeepFsmn``.
        :param lorder: int, forwarded to ``UniDeepFsmn`` as ``lorder``.
        """
        super().__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        # Swap channel and time so that time sits next to the batch axis.
        x = torch.transpose(x, dim0=1, dim1=3)
        # x shape: [b, t, h, c, 2]
        x = torch.reshape(x, shape=(b * t, h, c, d))
        # x shape: [b*t, h, c, 2]

        real = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, imaginary shape: [b*t, h, c]

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b*t, h, c, 2]

        # Restore the original axis layout.
        output = torch.reshape(output, shape=(b, t, h, c, d))
        # output shape: [b, t, h, c, 2]
        output = torch.transpose(output, dim0=1, dim1=3)
        # output shape: [b, c, h, t, 2]
        return output


class ComplexConv2d(nn.Module):
    """Complex 2-D convolution built from two real ``nn.Conv2d`` layers.

    All constructor arguments are forwarded unchanged to both the "real" and
    the "imaginary" ``nn.Conv2d``.
    """
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.conv_re = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
            **kwargs
        )
        self.conv_im = nn.Conv2d(
            in_channels, out_channels, kernel_size,
            stride=stride, padding=padding, dilation=dilation,
            groups=groups, bias=bias,
            **kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c_out, h_out, w_out, 2];
            the spatial size is whatever ``nn.Conv2d`` produces for the given arguments.
        """
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexConvTranspose2d(nn.Module):
    """Complex 2-D transposed convolution built from two real ``nn.ConvTranspose2d`` layers.

    All constructor arguments are forwarded unchanged to both the "real" and
    the "imaginary" ``nn.ConvTranspose2d``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 output_padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )
        self.tconv_im = nn.ConvTranspose2d(
            in_channels, out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c_out, h_out, w_out, 2];
            the spatial size is whatever ``nn.ConvTranspose2d`` produces for the given arguments.
        """
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output


class ComplexBatchNorm2d(nn.Module):
    """Batch normalisation applied separately to the real and imaginary parts.

    Unlike the other layers in this module there is no cross term: the real
    part only goes through ``bn_re`` and the imaginary part only through
    ``bn_im``.
    """

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 **kwargs
                 ):
        super().__init__()

        self.bn_re = nn.BatchNorm2d(
            num_features=num_features, momentum=momentum, affine=affine,
            eps=eps, track_running_stats=track_running_stats,
            **kwargs
        )
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features, momentum=momentum, affine=affine,
            eps=eps, track_running_stats=track_running_stats,
            **kwargs
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c, h, w, 2]
        """
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output


def main():
    """Manual smoke check: run ``ComplexConv2d`` on random input and print the output shape.

    The commented-out sections are alternative smoke checks for the other
    layers; uncomment one to run it.
    """
    # x = torch.rand(size=(1, 1, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmn(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmnL1(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    x = torch.rand(size=(1, 1, 320, 200, 2))
    conv2d = ComplexConv2d(
        in_channels=1,
        out_channels=128,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(0, 1),
    )
    result = conv2d.forward(x)
    print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    # x = torch.rand(size=(1, 64, 15, 2000, 2))
    # tconv = ComplexConvTranspose2d(
    #     in_channels=64,
    #     out_channels=32,
    #     kernel_size=(3, 3),
    #     stride=(2, 1),
    #     padding=(0, 1),
    # )
    # result = tconv.forward(x)
    # print(result.shape)


if __name__ == "__main__":
    main()