# HoneyTian's picture
# add frcrn model
# cba47e4
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Union, Tuple
import torch
import torch.nn as nn
from toolbox.torchaudio.models.frcrn.uni_deep_fsmn import UniDeepFsmn
class ComplexUniDeepFsmn(nn.Module):
    """
    Two stacked complex-valued layers built from real-valued UniDeepFsmn modules.

    Each layer owns a "real" and an "imaginary" UniDeepFsmn and combines them
    with the complex multiplication rule:
        real      = F_re(x_re) - F_im(x_im)
        imaginary = F_re(x_im) + F_im(x_re)
    """

    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        """
        :param input_dim: int, forwarded to every UniDeepFsmn.
            NOTE(review): presumably must equal c * h of the forward input - confirm against UniDeepFsmn.
        :param hidden_size: int, forwarded to every UniDeepFsmn.
        :param lorder: int, forwarded to every UniDeepFsmn as ``lorder``.
        """
        super().__init__()
        # Keep the creation order (re, im per layer) so seeded initialisation is unchanged.
        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_re_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l2 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    @staticmethod
    def _complex_apply(fsmn_re, fsmn_im, real: torch.Tensor, imaginary: torch.Tensor):
        """Apply one complex layer and return its (real, imaginary) outputs."""
        out_real = fsmn_re(real) - fsmn_im(imaginary)
        out_imaginary = fsmn_re(imaginary) + fsmn_im(real)
        return out_real, out_imaginary

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]; the last axis holds (real, imaginary).
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        # Merge channel and frequency axes, then move time in front of the features.
        x = torch.reshape(x, shape=(b, c * h, t, d))
        # x shape: [b, h', t, 2]
        x = torch.transpose(x, dim0=1, dim1=2)
        # x shape: [b, t, h', 2]

        real_l1, imaginary_l1 = self._complex_apply(
            self.fsmn_re_l1, self.fsmn_im_l1, x[..., 0], x[..., 1]
        )
        # real_l1, imaginary_l1 shape: [b, t, h']
        real, imaginary = self._complex_apply(
            self.fsmn_re_l2, self.fsmn_im_l2, real_l1, imaginary_l1
        )
        # real, imaginary shape: [b, t, h']

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b, t, h', 2]
        output = torch.transpose(output, dim0=1, dim1=2)
        # output shape: [b, h', t, 2]

        # The reshape only succeeds when the FSMN layers preserved the element count.
        output = torch.reshape(output, shape=(b, c, h, t, d))
        # output shape: [b, c, h, t, 2]
        return output
class ComplexUniDeepFsmnL1(nn.Module):
    """
    Single complex-valued layer built from two real-valued UniDeepFsmn modules.

    Unlike ComplexUniDeepFsmn, the time axis is folded into the batch axis, so the
    UniDeepFsmn modules receive tensors of shape [b * t, h, c].
    """

    def __init__(self, input_dim: int, hidden_size: int, lorder: int = 20):
        """
        :param input_dim: int, forwarded to both UniDeepFsmn modules.
            NOTE(review): presumably must equal c of the forward input - confirm against UniDeepFsmn.
        :param hidden_size: int, forwarded to both UniDeepFsmn modules.
        :param lorder: int, forwarded to both UniDeepFsmn modules as ``lorder``.
        """
        super().__init__()
        self.fsmn_re_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)
        self.fsmn_im_l1 = UniDeepFsmn(input_dim, hidden_size, lorder=lorder)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, t, 2]; the last axis holds (real, imaginary).
        :return: torch.Tensor, shape: [b, c, h, t, 2]
        """
        b, c, h, t, d = x.size()

        # Swap channel and time so that time can be folded into the batch axis.
        x = torch.transpose(x, dim0=1, dim1=3)
        # x shape: [b, t, h, c, 2]
        x = torch.reshape(x, shape=(b * t, h, c, d))
        # x shape: [b*t, h, c, 2]

        x_real = x[..., 0]
        x_imaginary = x[..., 1]

        # Complex multiplication rule.
        real = self.fsmn_re_l1(x_real) - self.fsmn_im_l1(x_imaginary)
        imaginary = self.fsmn_re_l1(x_imaginary) + self.fsmn_im_l1(x_real)
        # real, imaginary shape: [b*t, h, c]

        output = torch.stack(tensors=(real, imaginary), dim=-1)
        # output shape: [b*t, h, c, 2]

        # Undo the batch/time folding and the axis swap.
        output = torch.reshape(output, shape=(b, t, h, c, d))
        # output shape: [b, t, h, c, 2]
        output = torch.transpose(output, dim0=1, dim1=3)
        # output shape: [b, c, h, t, 2]
        return output
class ComplexConv2d(nn.Module):
    """
    Complex-valued 2D convolution implemented with two real-valued nn.Conv2d layers.

    Reference: https://github.com/litcoderr/ComplexCNN/blob/master/complexcnn/modules.py
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        """
        All arguments have the same meaning as for nn.Conv2d; extra keyword
        arguments are forwarded unchanged to both underlying convolutions.
        """
        super().__init__()

        # Both convolutions share one configuration; only their weights differ.
        conv_kwargs = dict(
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            **kwargs
        )

        # Model components (real first, then imaginary, to keep seeded init stable).
        self.conv_re = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)
        self.conv_im = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]; the last axis holds (real, imaginary).
        :return: torch.Tensor, shape: [b, out_channels, h_out, w_out, 2]
        """
        x_real = x[..., 0]
        x_imaginary = x[..., 1]

        # Complex multiplication rule: (W_re + i W_im) * (x_re + i x_im).
        real = self.conv_re(x_real) - self.conv_im(x_imaginary)
        imaginary = self.conv_re(x_imaginary) + self.conv_im(x_real)

        output = torch.stack((real, imaginary), dim=-1)
        return output
class ComplexConvTranspose2d(nn.Module):
    """
    Complex-valued 2D transposed convolution implemented with two real-valued
    nn.ConvTranspose2d layers.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]] = 1,
                 padding: Union[int, Tuple[int, int]] = 0,
                 output_padding: Union[int, Tuple[int, int]] = 0,
                 dilation: Union[int, Tuple[int, int]] = 1,
                 groups: int = 1,
                 bias: bool = True,
                 **kwargs
                 ):
        """
        All arguments have the same meaning as for nn.ConvTranspose2d; extra keyword
        arguments are forwarded unchanged to both underlying layers.
        """
        super().__init__()

        # Both layers share one configuration; only their weights differ.
        tconv_kwargs = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
            dilation=dilation,
            **kwargs
        )

        # Model components (real first, then imaginary, to keep seeded init stable).
        self.tconv_re = nn.ConvTranspose2d(in_channels, out_channels, **tconv_kwargs)
        self.tconv_im = nn.ConvTranspose2d(in_channels, out_channels, **tconv_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]; the last axis holds (real, imaginary).
        :return: torch.Tensor, shape: [b, out_channels, h_out, w_out, 2]
        """
        x_real = x[..., 0]
        x_imaginary = x[..., 1]

        # Complex multiplication rule: (W_re + i W_im) * (x_re + i x_im).
        real = self.tconv_re(x_real) - self.tconv_im(x_imaginary)
        imaginary = self.tconv_re(x_imaginary) + self.tconv_im(x_real)

        output = torch.stack((real, imaginary), dim=-1)
        return output
class ComplexBatchNorm2d(nn.Module):
    """
    Batch normalisation for complex tensors stored as [..., 2].

    The real and imaginary parts are normalised independently, each by its own
    nn.BatchNorm2d; no statistics are shared between the two parts.
    """

    def __init__(self,
                 num_features: int,
                 eps: float = 1e-5,
                 momentum: float = 0.1,
                 affine: bool = True,
                 track_running_stats: bool = True,
                 **kwargs
                 ):
        """
        All arguments have the same meaning as for nn.BatchNorm2d; extra keyword
        arguments are forwarded unchanged to both underlying layers.
        """
        super().__init__()

        # Both layers share one configuration.
        bn_kwargs = dict(
            num_features=num_features,
            momentum=momentum,
            affine=affine,
            eps=eps,
            track_running_stats=track_running_stats,
            **kwargs
        )
        self.bn_re = nn.BatchNorm2d(**bn_kwargs)
        self.bn_im = nn.BatchNorm2d(**bn_kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: torch.Tensor, shape: [b, c, h, w, 2]; the last axis holds (real, imaginary).
        :return: torch.Tensor, shape: [b, c, h, w, 2]
        """
        real = self.bn_re(x[..., 0])
        imag = self.bn_im(x[..., 1])
        output = torch.stack((real, imag), dim=-1)
        return output
def main():
    """
    Smoke test: run ComplexConv2d on a random input and print the output shape.

    The other modules of this file can be exercised the same way, e.g.:
      - ComplexUniDeepFsmn(input_dim=32, hidden_size=64) on an input of size (1, 1, 32, 200, 2)
      - ComplexUniDeepFsmnL1(input_dim=32, hidden_size=64) on an input of size (1, 32, 32, 200, 2)
      - ComplexConvTranspose2d(in_channels=64, out_channels=32, kernel_size=(3, 3),
        stride=(2, 1), padding=(0, 1)) on an input of size (1, 64, 15, 2000, 2)
    """
    x = torch.rand(size=(1, 1, 320, 200, 2))

    conv2d = ComplexConv2d(
        in_channels=1,
        out_channels=128,
        kernel_size=(5, 2),
        stride=(2, 1),
        padding=(0, 1),
    )

    # Call the module itself instead of .forward() so that registered hooks run.
    result = conv2d(x)
    print(result.shape)
    return
# Run the smoke test only when this file is executed as a script.
if __name__ == "__main__":
    main()