#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from toolbox.torchaudio.models.frcrn import complex_nn
class SELayer(nn.Module):
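    """
    Squeeze-and-Excitation gating for complex-valued feature maps of shape
    [b, c, f, t, 2]: the real and imaginary channels are average-pooled and
    passed through separate gate branches, whose outputs are combined like a
    complex product before rescaling the input.
    """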
def __init__(self, channels: int, reduction: int = 16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc_r = nn.Sequential(
nn.Linear(channels, channels // reduction),
nn.ReLU(inplace=True),
nn.Linear(channels // reduction, channels),
nn.Sigmoid()
)
self.fc_i = nn.Sequential(
nn.Linear(channels, channels // reduction),
nn.ReLU(inplace=True),
nn.Linear(channels // reduction, channels),
nn.Sigmoid()
)
def forward(self, x: torch.Tensor):
b, c, _, _, _ = x.size()
x_r = self.avg_pool(x[:, :, :, :, 0]).view(b, c)
x_i = self.avg_pool(x[:, :, :, :, 1]).view(b, c)
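        # combine the two gate branches like a complex product:
        # real part = ac - bd, imaginary part = ad + bc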
y_r = self.fc_r(x_r).view(b, c, 1, 1, 1) - self.fc_i(x_i).view(b, c, 1, 1, 1)
y_i = self.fc_r(x_i).view(b, c, 1, 1, 1) + self.fc_i(x_r).view(b, c, 1, 1, 1)
y = torch.cat(tensors=[y_r, y_i], dim=4)
return x * y
class Encoder(nn.Module):
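    """
    Down-sampling block of the U-Net encoder:
    (Complex)Conv2d -> (Complex)BatchNorm2d -> LeakyReLU.
    """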
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
                 padding: Optional[Union[int, Tuple[int, int]]] = None,
use_complex_networks: bool = False,
padding_mode: str = "zeros"
):
super().__init__()
        if padding is None:
            # 'SAME'-style padding derived from the kernel size; accept int or tuple kernels
            ks = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
            padding = tuple((k - 1) // 2 for k in ks)
if use_complex_networks:
conv = complex_nn.ComplexConv2d
bn = complex_nn.ComplexBatchNorm2d
else:
conv = nn.Conv2d
bn = nn.BatchNorm2d
self.conv = conv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding,
padding_mode=padding_mode
)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x: torch.Tensor):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class Decoder(nn.Module):
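    """
    Up-sampling block of the U-Net decoder:
    (Complex)ConvTranspose2d -> (Complex)BatchNorm2d -> LeakyReLU.
    """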
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Tuple[int, int]],
stride: Union[int, Tuple[int, int]],
padding: Union[int, Tuple[int, int]] = (0, 0),
use_complex_networks: bool = False,
):
super().__init__()
if use_complex_networks:
tconv = complex_nn.ComplexConvTranspose2d
bn = complex_nn.ComplexBatchNorm2d
else:
tconv = nn.ConvTranspose2d
bn = nn.BatchNorm2d
self.transconv = tconv(
in_channels,
out_channels,
kernel_size=kernel_size,
stride=stride,
padding=padding
)
self.bn = bn(out_channels)
self.relu = nn.LeakyReLU(inplace=True)
def forward(self, x):
x = self.transconv(x)
x = self.bn(x)
x = self.relu(x)
return x
class UNetConfig14(object):
"""
inputs x shape: [1, 1, 321, 2000, 2]
sample rate: 16000
nfft: 640
win_size: 640
    hop_size: 320 (20ms)
"""
def __init__(self, in_channels: int):
self.enc_channels = [in_channels, 128, 128, 128, 128, 128, 128, 128]
self.enc_kernel_sizes = [(5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (5, 2), (2, 2)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
self.dec_channels = [64, 128, 128, 128, 128, 128, 128, 1]
self.dec_kernel_sizes = [(2, 2), (5, 2), (5, 2), (5, 2), (6, 2), (5, 2), (5, 2)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
class UNetConfig10(object):
"""
inputs x shape: [1, 1, 65, 200, 2]
sample rate: 8000
nfft: 128
win_size: 128
hop_size: 64 (8ms)
"""
def __init__(self, in_channels: int):
self.enc_channels = [in_channels, 16, 32, 64, 128, 256]
self.enc_kernel_sizes = [(3, 3), (3, 3), (3, 3), (3, 3), (3, 3)]
self.enc_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.enc_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
self.dec_channels = [128, 128, 64, 32, 16, 1]
self.dec_kernel_sizes = [(3, 3), (3, 3), (3, 3), (4, 3), (3, 3)]
self.dec_strides = [(2, 1), (2, 1), (2, 1), (2, 1), (2, 1)]
self.dec_paddings = [(0, 1), (0, 1), (0, 1), (0, 1), (0, 1)]
class UNetConfig20(object):
"""
inputs x shape: [1, 1, 257, 2000, 2]
sample rate: 8000
nfft: 512
win_size: 512
hop_size: 256 (32ms)
"""
def __init__(self, in_channels: int, model_complexity: int):
self.enc_channels = [
in_channels,
model_complexity, model_complexity,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2,
128
]
self.enc_kernel_sizes = [(7, 1), (1, 7), (6, 4), (7, 5), (5, 3),
(5, 3), (5, 3), (5, 3), (5, 3), (5, 3)]
self.enc_strides = [(1, 1), (1, 1), (2, 2), (2, 1), (2, 2),
(2, 1), (2, 2), (2, 1), (2, 2), (2, 1)]
self.enc_paddings = [
(3, 0),
(0, 3),
None, # (0, 2),
None,
None, # (3,1),
None, # (3,1),
None, # (1,2),
None,
None,
None
]
self.dec_channels = [
64,
model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity * 2, model_complexity * 2,
model_complexity, model_complexity,
1
]
self.dec_kernel_sizes = [(4, 3), (4, 2), (4, 3), (4, 2), (4, 3),
(4, 2), (6, 3), (7, 4), (1, 7), (7, 1)]
self.dec_strides = [(2, 1), (2, 2), (2, 1), (2, 2), (2, 1),
(2, 2), (2, 1), (2, 2), (1, 1), (1, 1)]
self.dec_paddings = [(1, 1), (1, 0), (1, 1), (1, 0), (1, 1),
(1, 0), (2, 1), (2, 1), (0, 3), (3, 0)]
class UNet(nn.Module):
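    """
    Complex U-Net used by the FRCRN model: a convolutional encoder/decoder with
    skip connections, squeeze-and-excitation gating and a complex FSMN bottleneck,
    operating on [b, c, f, t, 2] real/imaginary spectrograms.
    """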
def __init__(self,
in_channels: int = 1,
use_complex_networks: bool = False,
model_complexity: int = 45,
model_depth: int = 20,
padding_mode: str = "zeros"
):
super().__init__()
if use_complex_networks:
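            # presumably shrinks the channel width by ~sqrt(2): complex layers carry separate
            # real and imaginary weights, which keeps the parameter budget roughly comparable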
model_complexity = int(model_complexity // 1.414)
# config
if model_depth == 14:
config = UNetConfig14(in_channels)
elif model_depth == 10:
config = UNetConfig10(in_channels)
elif model_depth == 20:
config = UNetConfig20(in_channels, model_complexity)
else:
            raise ValueError(f"unknown model depth: {model_depth}")
self.model_length = model_depth // 2
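        # complex FSMN bottleneck, applied between the encoder and decoder stacks in forward()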
self.fsmn = complex_nn.ComplexUniDeepFsmn(
config.enc_channels[-1],
config.enc_channels[-1]
)
# go down
self.encoder_layers = nn.ModuleList(modules=[])
for i in range(self.model_length):
encoder_layer = nn.Sequential(
complex_nn.ComplexUniDeepFsmnL1(
config.enc_channels[i],
config.enc_channels[i]
)
if i != 0 else nn.Identity(),
Encoder(
config.enc_channels[i],
config.enc_channels[i + 1],
kernel_size=config.enc_kernel_sizes[i],
stride=config.enc_strides[i],
padding=config.enc_paddings[i],
use_complex_networks=use_complex_networks,
padding_mode=padding_mode
),
SELayer(config.enc_channels[i + 1], reduction=8)
)
self.encoder_layers.append(encoder_layer)
self.decoder_layers = nn.ModuleList(modules=[])
for i in range(self.model_length):
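            # the doubled in_channels match the FSMN bottleneck output (first stage) and the
            # skip-connection concatenations performed in forward() (later stages)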
decoder_layer = nn.Sequential(
Decoder(
config.dec_channels[i] * 2,
config.dec_channels[i + 1],
kernel_size=config.dec_kernel_sizes[i],
stride=config.dec_strides[i],
padding=config.dec_paddings[i],
use_complex_networks=use_complex_networks
),
complex_nn.ComplexUniDeepFsmnL1(
config.dec_channels[i + 1],
config.dec_channels[i + 1]
)
if i < (self.model_length - 1) else nn.Identity(),
SELayer(
config.dec_channels[i + 1],
reduction=8
)
if i < (self.model_length - 2) else nn.Identity()
)
self.decoder_layers.append(decoder_layer)
if use_complex_networks:
conv = complex_nn.ComplexConv2d
else:
conv = nn.Conv2d
self.linear = conv(
in_channels=config.dec_channels[-1],
out_channels=1,
kernel_size=1,
)
def forward(self, inputs: torch.Tensor):
"""
:param inputs: torch.Tensor, shape: [b, c, f, t, 2]
:return:
"""
x = inputs
# print(f"inputs: {x.shape}")
# go down
xs = list()
xs_se = list()
xs_se.append(x)
for encoder_layer in self.encoder_layers:
xs.append(x)
# print(f"x: {x.shape}")
x = encoder_layer.forward(x)
# print(f"x: {x.shape}")
xs_se.append(x)
# x shape: [b, c, 1, t', 2]
x = self.fsmn.forward(x)
# x shape: [b, c, 1, t', 2]
# print(f"fsmn")
p = x
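        # go up: after each decoder stage (except the last), concatenate the matching
        # encoder output along the channel dimension as a skip connection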
        for i, decoder_layer in enumerate(self.decoder_layers):
            p = decoder_layer.forward(p)
# print(f"p: {p.shape}")
if i == self.model_length - 1:
break
p = torch.cat(tensors=[p, xs_se[self.model_length - 1 - i]], dim=1)
        # cmp_spec shape: [b, 1, f, t, 2]
cmp_spec = self.linear.forward(p)
return cmp_spec
def main():
# [batch_size, 1, freq_bins, time_steps, 2]
# x = torch.rand(size=(1, 1, 257, 2000, 2))
# unet = UNet(
# in_channels=1,
# model_complexity=45,
# model_depth=20,
# use_complex_networks=True
# )
# print(unet)
# result = unet.forward(x)
# print(result.shape)
x = torch.rand(size=(1, 1, 65, 2000, 2))
unet = UNet(
in_channels=1,
        model_complexity=-1,  # unused by the depth-10 config (UNetConfig10 has fixed channel widths)
model_depth=10,
use_complex_networks=True
)
print(unet)
result = unet.forward(x)
print(result.shape)
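    # expected to print torch.Size([1, 1, 65, 2000, 2]): the decoder strides mirror
    # the encoder, so the frequency and time resolution are restored at the output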
return
if __name__ == "__main__":
main()