#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Union, Tuple
import torch
import torch.nn as nn
from toolbox.torchaudio.models.frcrn import complex_nn
class SELayer(nn.Module):
    def __init__(self, channels: int, reduction: int = 16):
        super(SELayer, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc_r = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
        self.fc_i = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid()
        )
    def forward(self, x: torch.Tensor):
        b, c, _, _, _ = x.size()
        x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
        x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
        y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
        y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)
        y = torch.cat(tensors=[y_r, y_i], dim=4)
        return x * y
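# Note: SELayer is a squeeze-and-excitation gate for complex-valued features stored
# as [..., 2] (real, imag). The pooled real/imag statistics are mixed in a
# complex-multiplication-like pattern (r*r - i*i, r*i + i*r) to form per-channel
# gates of shape [b, c, 1, 1, 2], which are broadcast over [b, c, f, t, 2] and
# applied element-wise to the input.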
class Encoder(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]] = None,
                 use_complex_networks: bool = False,
                 padding_mode: str = "zeros"
                 ):
        super().__init__()
        if padding is None:
            padding = [(k - 1) // 2 for k in kernel_size]  # 'SAME' padding
        if use_complex_networks:
            conv = complex_nn.ComplexConv2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            conv = nn.Conv2d
            bn = nn.BatchNorm2d
        self.conv = conv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            padding_mode=padding_mode
        )
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
    def forward(self, x: torch.Tensor):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
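# A minimal shape sanity check for Encoder (illustrative only; assumes the
# real-valued path, i.e. use_complex_networks=False, so nn.Conv2d sees a
# plain [b, c, f, t] tensor):
#   enc = Encoder(1, 16, kernel_size=(3, 3), stride=(2, 1), padding=(0, 1))
#   enc(torch.rand(1, 1, 65, 200)).shape   # -> [1, 16, 32, 200]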
class Decoder(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Tuple[int, int]],
                 stride: Union[int, Tuple[int, int]],
                 padding: Union[int, Tuple[int, int]] = (0, 0),
                 use_complex_networks: bool = False,
                 ):
        super().__init__()
        if use_complex_networks:
            tconv = complex_nn.ComplexConvTranspose2d
            bn = complex_nn.ComplexBatchNorm2d
        else:
            tconv = nn.ConvTranspose2d
            bn = nn.BatchNorm2d
        self.transconv = tconv(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding
        )
        self.bn = bn(out_channels)
        self.relu = nn.LeakyReLU(inplace=True)
    def forward(self, x):
        x = self.transconv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
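# A minimal shape sanity check for Decoder (illustrative only; real-valued path):
#   dec = Decoder(16, 1, kernel_size=(3, 3), stride=(2, 1), padding=(0, 1))
#   dec(torch.rand(1, 16, 32, 200)).shape  # -> [1, 1, 65, 200]
# i.e. the transposed convolution inverts the Encoder example above.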
class UNetConfig14(object):
    """
    inputs x shape: [1, 1, 321, 2000, 2]
    sample rate: 16000
    nfft: 640
    win_size: 640
    hop_size: 320 (20ms)
    """
    def __init__(self, in_channels: int):
        self.enc_channels = [in_channels, 128, 128, 128, 128, 128, 128, 128]
        self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (2, 2)]
        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
        self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
        self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), (5, 2), (5, 2)]
        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
class UNetConfig10(object):
    """
    inputs x shape: [1, 1, 65, 200, 2]
    sample rate: 8000
    nfft: 128
    win_size: 128
    hop_size: 64 (8ms)
    """
    def __init__(self, in_channels: int):
        self.enc_channels = [in_channels, 16, 32, 64, 128, 256]
        self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
        self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
        self.dec_channels = [128, 128, 64, 32, 16, 1]
        self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
        self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
        self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
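# With UNetConfig10, every encoder stage halves the frequency axis while keeping
# the time resolution (stride (2, 1)); for 65 input bins the frequency sizes run
# 65 -> 32 -> 15 -> 7 -> 3 -> 1 before the FSMN bottleneck.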
class UNetConfig20(object):
    """
    inputs x shape: [1, 1, 257, 2000, 2]
    sample rate: 8000
    nfft: 512
    win_size: 512
    hop_size: 256 (32ms)
    """
    def __init__(self, in_channels: int, model_complexity: int):
        self.enc_channels = [
            in_channels,
            model_complexity, model_complexity,
            model_complexity * 2, model_complexity * 2,
            model_complexity * 2, model_complexity * 2,
            model_complexity * 2, model_complexity * 2,
            model_complexity * 2,
            128
        ]
        self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
                                 (5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
        self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2),
                            (2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]
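        # None entries below are resolved inside Encoder, which falls back to
        # 'SAME'-style padding of (k - 1) // 2 per kernel dimension.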
        self.enc_paddings = [
            (3, 0),
            (0, 3),
            None,  # (0, 2),
            None,
            None,  # (3, 1),
            None,  # (3, 1),
            None,  # (1, 2),
            None,
            None,
            None
        ]
        self.dec_channels = [
            64,
            model_complexity * 2,
            model_complexity * 2, model_complexity * 2,
            model_complexity * 2, model_complexity * 2,
            model_complexity * 2, model_complexity * 2,
            model_complexity, model_complexity,
            1
        ]
        self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
                                 (4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
        self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1),
                            (2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]
        self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
                             (1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
class UNet(nn.Module):
    def __init__(self,
                 in_channels: int = 1,
                 use_complex_networks: bool = False,
                 model_complexity: int = 45,
                 model_depth: int = 20,
                 padding_mode: str = "zeros"
                 ):
        super().__init__()
        if use_complex_networks:
            # scale channels down by ~sqrt(2) for the complex-valued variant
            model_complexity = int(model_complexity // 1.414)
        # config
        if model_depth == 14:
            config = UNetConfig14(in_channels)
        elif model_depth == 10:
            config = UNetConfig10(in_channels)
        elif model_depth == 20:
            config = UNetConfig20(in_channels, model_complexity)
        else:
            raise ValueError(f"Unknown model depth: {model_depth}")
        self.model_length = model_depth // 2
        self.fsmn = complex_nn.ComplexUniDeepFsmn(
            config.enc_channels[-1],
            config.enc_channels[-1]
        )
        # go down
        self.encoder_layers = nn.ModuleList(modules=[])
        for i in range(self.model_length):
            encoder_layer = nn.Sequential(
                complex_nn.ComplexUniDeepFsmnL1(
                    config.enc_channels[i],
                    config.enc_channels[i]
                )
                if i != 0 else nn.Identity(),
                Encoder(
                    config.enc_channels[i],
                    config.enc_channels[i + 1],
                    kernel_size=config.enc_kernel_sizes[i],
                    stride=config.enc_strides[i],
                    padding=config.enc_paddings[i],
                    use_complex_networks=use_complex_networks,
                    padding_mode=padding_mode
                ),
                SELayer(config.enc_channels[i + 1], reduction=8)
            )
            self.encoder_layers.append(encoder_layer)
        self.decoder_layers = nn.ModuleList(modules=[])
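        # Each decoder stage takes dec_channels[i] * 2 input channels: the FSMN
        # bottleneck output for the first stage, and the previous decoder output
        # concatenated with the matching encoder activation (skip connection in
        # forward()) for the later stages.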
        for i in range(self.model_length):
            decoder_layer = nn.Sequential(
                Decoder(
                    config.dec_channels[i] * 2,
                    config.dec_channels[i + 1],
                    kernel_size=config.dec_kernel_sizes[i],
                    stride=config.dec_strides[i],
                    padding=config.dec_paddings[i],
                    use_complex_networks=use_complex_networks
                ),
                complex_nn.ComplexUniDeepFsmnL1(
                    config.dec_channels[i + 1],
                    config.dec_channels[i + 1]
                )
                if i < (self.model_length - 1) else nn.Identity(),
                SELayer(
                    config.dec_channels[i + 1],
                    reduction=8
                )
                if i < (self.model_length - 2) else nn.Identity()
            )
            self.decoder_layers.append(decoder_layer)
        if use_complex_networks:
            conv = complex_nn.ComplexConv2d
        else:
            conv = nn.Conv2d
        self.linear = conv(
            in_channels=config.dec_channels[-1],
            out_channels=1,
            kernel_size=1,
        )
    def forward(self, inputs: torch.Tensor):
        """
        :param inputs: torch.Tensor, shape: [b, c, f, t, 2]
        :return:
        """
        x = inputs
        # print(f"inputs: {x.shape}")
        # go down
        xs = list()
        xs_se = list()
        xs_se.append(x)
        for encoder_layer in self.encoder_layers:
            xs.append(x)
            # print(f"x: {x.shape}")
            x = encoder_layer.forward(x)
            # print(f"x: {x.shape}")
            xs_se.append(x)
        # x shape: [b, c, 1, t', 2]
        x = self.fsmn.forward(x)
        # x shape: [b, c, 1, t', 2]
        # print(f"fsmn")
        p = x
        for i, decoder_layer in enumerate(self.decoder_layers):
            p = decoder_layer.forward(p)
            # print(f"p: {p.shape}")
            if i == self.model_length - 1:
                break
            p = torch.cat(tensors=[p, xs_se[self.model_length - 1 - i]], dim=1)
        # cmp_spec: [1, 1, 321, 200, 2]
        # cmp_spec: [1, 1, 513, 200, 2]
        cmp_spec = self.linear.forward(p)
        return cmp_spec
def main():
    # [batch_size, 1, freq_bins, time_steps, 2]
    # x = torch.rand(size=(1, 1, 257, 2000, 2))
    # unet = UNet(
    #     in_channels=1,
    #     model_complexity=45,
    #     model_depth=20,
    #     use_complex_networks=True
    # )
    # print(unet)
    # result = unet.forward(x)
    # print(result.shape)
    x = torch.rand(size=(1, 1, 65, 2000, 2))
    unet = UNet(
        in_channels=1,
        model_complexity=-1,  # unused by the depth-10 config
        model_depth=10,
        use_complex_networks=True
    )
    print(unet)
    result = unet.forward(x)
    print(result.shape)
    return


if __name__ == "__main__":
    main()