#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Union, Tuple

import torch
import torch.nn as nn

from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn
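
# All modules below implement complex-valued layers with the same pattern: a
# complex weight W = W_re + i*W_im applied to a complex input x = x_re + i*x_im
# expands, by the complex product, to
#     W x = (W_re x_re - W_im x_im) + i * (W_re x_im + W_im x_re)
# so each complex layer is built from two real-valued sub-layers, and the last
# tensor dimension of size 2 holds the (real, imaginary) parts.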


class ComplexUniDeepFsmn(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super(ComplexUniDeepFsmn, self).__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        x = torch.reshape(x, shape=(b, c * h, t, d))
        # x shape: [b, h', t, 2], where h' = c * h
        x = torch.transpose(x, dim0=1, dim1=2)
        # x shape: [b, t, h', 2]

        real_l1 = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary_l1 = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real_l1, imaginary_l1 shape: [b, t, h']

        real = self.fsmn_re_l2(real_l1) - self.fsmn_im_l2(imaginary_l1)
        imaginary = self.fsmn_re_l2(imaginary_l1) + self.fsmn_im_l2(real_l1)
        # real, imaginary shape: [b, t, h']

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b, t, h', 2]
        output = torch.transpose(output, dim0=1, dim1=2)
        # output shape: [b, h', t, 2]
        output = torch.reshape(output, shape=(b, c, h, t, d))
        # output shape: [b, c, h, t, 2]
        return output
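
# A minimal shape check for ComplexUniDeepFsmn (a sketch; it assumes UniDeepFsmn
# maps [b, t, input_dim] to [b, t, input_dim], which the reshape back to
# (b, c, h, t, d) above and the commented demo in main() both suggest). Note
# that input_dim must equal c * h of the input:
#
#     fsmn = ComplexUniDeepFsmn(input_dim=32, hidden_size=64)
#     x = torch.rand(size=(1, 1, 32, 200, 2))    # [b, c, h, t, 2], c * h = 32
#     assert fsmn.forward(x).shape == (1, 1, 32, 200, 2)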


class ComplexUniDeepFsmnL1(nn.Module):
    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        super(ComplexUniDeepFsmnL1, self).__init__()

        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        x = torch.transpose(x, dim0=1, dim1=3)
        # x shape: [b, t, h, c, 2]
        x = torch.reshape(x, shape=(b * t, h, c, d))
        # x shape: [b*t, h, c, 2]

        real = self.fsmn_re_l1(x[..., 0]) - self.fsmn_im_l1(x[..., 1])
        imaginary = self.fsmn_re_l1(x[..., 1]) + self.fsmn_im_l1(x[..., 0])
        # real, imaginary shape: [b*t, h, c]

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b*t, h, c, 2]
        output = torch.reshape(output, shape=(b, t, h, c, d))
        # output shape: [b, t, h, c, 2]
        output = torch.transpose(output, dim0=1, dim1=3)
        # output shape: [b, c, h, t, 2]
        return output
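
# In contrast to ComplexUniDeepFsmn, this single-layer variant folds time into
# the batch and runs the FSMN along the frequency axis h, with channels as the
# feature dimension, so input_dim must equal c (a reading inferred from the
# reshapes above and from the commented demo in main(), where c == input_dim == 32):
#
#     fsmn = ComplexUniDeepFsmnL1(input_dim=32, hidden_size=64)
#     x = torch.rand(size=(1, 32, 32, 200, 2))    # [b, c, h, t, 2], c = 32
#     assert fsmn.forward(x).shape == (1, 32, 32, 200, 2)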


class ComplexConv2d(nn.Module):
    # https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.conv_re = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )
        self.conv_im = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c', h', w', 2]
        """
        real = self.conv_re(x[..., 0]) - self.conv_im(x[..., 1])
        imaginary = self.conv_re(x[..., 1]) + self.conv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output
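
# A quick equivalence sketch: since convolution is linear, the stacked-real
# formulation above should agree with PyTorch's native complex convolution
# (assuming PyTorch >= 1.12, where conv2d accepts complex tensors):
#
#     import torch.nn.functional as F
#     conv = ComplexConv2d(in_channels=1, out_channels=4, kernel_size=3, padding=1)
#     x = torch.rand(size=(1, 1, 8, 8, 2))
#     y = conv.forward(x)
#     x_c = torch.complex(x[..., 0], x[..., 1])
#     w_c = torch.complex(conv.conv_re.weight, conv.conv_im.weight)
#     b_c = torch.complex(conv.conv_re.bias, conv.conv_im.bias)
#     y_c = F.conv2d(x_c, w_c, b_c, padding=1)
#     assert torch.allclose(torch.complex(y[..., 0], y[..., 1]), y_c, atol=1e-5)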


class ComplexConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 output_padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        super().__init__()

        # Model components
        self.tconv_re = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )
        self.tconv_im = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c', h', w', 2]
        """
        real = self.tconv_re(x[..., 0]) - self.tconv_im(x[..., 1])
        imaginary = self.tconv_re(x[..., 1]) + self.tconv_im(x[..., 0])
        output = torch.stack((real, imaginary), dim=-1)
        return output
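
# Output spatial sizes follow the standard transposed-convolution formula,
# applied identically to the real and imaginary planes:
#
#     H_out = (H_in - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1
#
# e.g. for the commented demo in main(): height (15 - 1) * 2 - 0 + 2 + 0 + 1 = 31
# and width (2000 - 1) * 1 - 2 + 2 + 0 + 1 = 2000, so an input of shape
# [1, 64, 15, 2000, 2] maps to [1, 32, 31, 2000, 2].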


class ComplexBatchNorm2d(nn.Module):
    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 **kwargs
                 ):
        super().__init__()
        self.bn_re = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )
        self.bn_im = nn.BatchNorm2d(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )

    def forward(self, x: torch.Tensor):
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]
        :return: torch.Tensor, shape: [b, c, h, w, 2]
        """
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output
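
# Note: this is the "naive" complex batch norm that normalizes the real and
# imaginary parts independently. It does not perform the 2x2 covariance
# whitening of Trabelsi et al., "Deep Complex Networks"
# (https://arxiv.org/abs/1705.09792), which decorrelates the two components;
# the independent variant is simpler and cheaper, at the cost of not treating
# the value as a single complex distribution.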


def main():
    # x = torch.rand(size=(1, 1, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmn(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 32, 200, 2))
    # fsmn = ComplexUniDeepFsmnL1(
    #     input_dim=32,
    #     hidden_size=64,
    # )
    # result = fsmn.forward(x)
    # print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    x = torch.rand(size=(1, 1, 320, 200, 2))
    conv2d = ComplexConv2d(
        in_channels=1,
        out_channels=128,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(0, 1),
    )
    result = conv2d.forward(x)
    print(result.shape)

    # x = torch.rand(size=(1, 32, 200, 200, 2))
    # x = torch.rand(size=(1, 64, 15, 2000, 2))
    # tconv = ComplexConvTranspose2d(
    #     in_channels=64,
    #     out_channels=32,
    #     kernel_size=(3, 3),
    #     stride=(2, 1),
    #     padding=(0, 1),
    # )
    # result = tconv.forward(x)
    # print(result.shape)
    return


if __name__ == "__main__":
    main()