#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT


MODEL_FILE = "model.pt"


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}

activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


class CausalConv2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [batch_size, channels, time_steps, spec_dim]

        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: kernel size over (time, frequency).
        :param fstride: stride over the frequency axis (the time stride is always 1).
        :param dilation: dilation over the frequency axis.
        :param fpad: if True, pad the frequency axis so its size is preserved (up to striding).
        """
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0

        # for the last 2 dims, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = list()
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)

    def forward(self, inputs):
        for module in self:
            inputs = module(inputs)
        return inputs
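
# Usage sketch for CausalConv2d (illustrative only; the kernel size and shapes below are
# assumptions, not values taken from DfNetConfig):
#
#   conv = CausalConv2d(in_channels=1, out_channels=16, kernel_size=(2, 3), fstride=2)
#   x = torch.randn(4, 1, 100, 96)      # [batch, channels, time, freq]
#   y = conv(x)                         # -> [4, 16, 100, 48]; time length is preserved,
#                                       #    frequency is halved by fstride=2.
#
# The ConstantPad2d layer pads kernel_size[0] - 1 - lookahead frames at the start of the
# time axis, so each output frame depends only on current and past frames (plus `lookahead`
# future frames when lookahead > 0).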


class CausalConvTranspose2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 lookahead: int = 0
                 ):
        """
        Causal ConvTranspose2d.

        Expected input format: [batch_size, channels, time_steps, spec_dim]
        """
        super(CausalConvTranspose2d, self).__init__()

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0

        # for the last 2 dims, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)
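
# Usage sketch for CausalConvTranspose2d (illustrative; kernel size and shapes are
# assumptions): with kernel_size=(1, 3) and fstride=2, the transposed convolution
# upsamples the frequency axis by 2 while keeping the time axis causal and unchanged:
#
#   convt = CausalConvTranspose2d(in_channels=16, out_channels=16, kernel_size=(1, 3), fstride=2)
#   x = torch.randn(4, 16, 100, 48)     # [batch, channels, time, freq]
#   y = convt(x)                        # -> [4, 16, 100, 96]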


class GroupedLinear(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
        assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., I]
        b, t, _ = x.shape
        # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
        new_shape = (b, t, self.groups, self.ws)
        x = x.view(new_shape)
        # The better way, but not supported by torchscript
        # x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
        x = torch.einsum("btgi,gih->btgh", x, self.weight)  # [..., G, H/G]
        x = x.flatten(2, 3)  # [B, T, H]
        return x

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"


class SqueezedGRU_S(nn.Module):
    """
    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            output_size: Optional[int] = None,
            num_layers: int = 1,
            linear_groups: int = 8,
            batch_first: bool = True,
            skip_op: str = "none",
            activation_layer: str = "identity",
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # gru skip operator
        self.gru_skip_op = None

        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=hidden_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=False,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.linear_in(inputs)

        x, h = self.gru.forward(x, h)

        x = self.linear_out(x)

        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)

        return x, h
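
# SqueezedGRU_S sketch (illustrative sizes): the grouped "squeeze" projection reduces the
# feature dimension before the GRU and expands it again afterwards, which keeps the
# recurrent weights small:
#
#   gru = SqueezedGRU_S(input_size=256, hidden_size=64, output_size=256,
#                       linear_groups=8, skip_op="none", activation_layer="relu")
#   x = torch.randn(4, 100, 256)        # [batch, time, features]
#   y, h = gru(x)                       # y: [4, 100, 256], h: [1, 4, 64]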


class Add(nn.Module):
    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)


class Encoder(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(Encoder, self).__init__()
        self.embedding_input_size = config.conv_channels * config.spec_bins // 4
        self.embedding_output_size = config.conv_channels * config.spec_bins // 4
        self.embedding_hidden_size = config.embedding_hidden_size

        self.spec_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.embedding_input_size,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_combine_op == "concat":
            self.embedding_input_size *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        # emb_gru
        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")

        self.emb_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.embedding_hidden_size,
            output_size=self.embedding_output_size,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_emb_skip_op,
            linear_groups=config.encoder_emb_linear_groups,
            activation_layer="relu",
        )

        # lsnr
        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.embedding_output_size, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                feat_power: torch.Tensor,
                feat_spec: torch.Tensor,
                hidden_state: torch.Tensor = None,
                ):
        # feat_power shape: [batch_size, 1, time_steps, spec_dim]
        e0 = self.spec_conv0.forward(feat_power)
        e1 = self.spec_conv1.forward(e0)
        e2 = self.spec_conv2.forward(e1)
        e3 = self.spec_conv3.forward(e2)
        # e0 shape: [batch_size, channels, time_steps, spec_dim]
        # e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
        # e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]

        # feat_spec shape: [batch_size, 2, time_steps, df_bins]
        c0 = self.df_conv0(feat_spec)
        c1 = self.df_conv1(c0)
        # c0 shape: [batch_size, channels, time_steps, df_bins]
        # c1 shape: [batch_size, channels, time_steps, df_bins // 2]

        cemb = c1.permute(0, 2, 3, 1)
        # cemb shape: [batch_size, time_steps, df_bins // 2, channels]
        cemb = cemb.flatten(2)
        # cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
        cemb = self.df_fc_emb(cemb)
        # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        # e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
        emb = e3.permute(0, 2, 3, 1)
        # emb shape: [batch_size, time_steps, spec_dim // 4, channels]
        emb = emb.flatten(2)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb = self.combine(emb, cemb)
        # if concat: emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
        # if add:    emb shape: [batch_size, time_steps, spec_dim // 4 * channels]

        emb, h = self.emb_gru.forward(emb, hidden_state)
        # emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
        # h shape: [num_gru_layers, batch_size, embedding_hidden_size]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        # lsnr shape: [batch_size, time_steps, 1]

        return e0, e1, e2, e3, emb, c0, lsnr, h


class Decoder(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(Decoder, self).__init__()

        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")

        self.emb_in_dim = config.conv_channels * config.spec_bins // 4
        self.emb_out_dim = config.conv_channels * config.spec_bins // 4
        self.emb_hidden_dim = config.decoder_emb_hidden_size

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.decoder_emb_skip_op,
            linear_groups=config.decoder_emb_linear_groups,
            activation_layer="relu",
        )
        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        # Estimates the ERB mask.
        b, _, t, f8 = e3.shape

        # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
        emb, _ = self.emb_gru(emb)
        emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
        # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]

        e3 = self.convt3(self.conv3p(e3) + emb)
        # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
        e2 = self.convt2(self.conv2p(e2) + e3)
        # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
        e1 = self.convt1(self.conv1p(e1) + e2)
        # e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
        mask = self.conv0_out(self.conv0p(e0) + e1)
        # mask shape: [batch_size, 1, time_steps, freq_dim]
        return mask


class DfDecoder(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(DfDecoder, self).__init__()

        self.embedding_input_size = config.conv_channels * config.spec_bins // 4
        self.df_decoder_hidden_size = config.df_decoder_hidden_size
        self.df_num_layers = config.df_num_layers

        self.df_order = config.df_order

        self.df_bins = config.df_bins
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            config.conv_channels,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.df_decoder_hidden_size,
            num_layers=self.df_num_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            if config.embedding_hidden_size != config.df_decoder_hidden_size:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(
                self.embedding_input_size,
                self.df_decoder_hidden_size,
                groups=config.df_decoder_linear_groups
            )
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch

        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_decoder_hidden_size,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_decoder_hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        # emb shape: [batch_size, time_steps, spec_bins // 4 * channels]
        b, t, _ = emb.shape
        df_coefs, _ = self.df_gru(emb)
        if self.df_skip is not None:
            df_coefs = df_coefs + self.df_skip(emb)
        # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]

        # c0 shape: [batch_size, channels, time_steps, df_bins]
        c0 = self.df_convp(c0)
        # c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
        c0 = c0.permute(0, 2, 3, 1)
        # c0 shape: [batch_size, time_steps, df_bins, df_order * 2]

        df_coefs = self.df_out(df_coefs)  # [B, T, F*O*2], O: df_order
        # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
        df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        df_coefs = df_coefs + c0
        # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
        return df_coefs


class DfOutputReshapeMF(nn.Module):
    """
    Coefficients output reshape for multiframe/MultiFrameModule.

    Reshapes deep-filter coefficients of shape [B, T, F, O*2] into the
    [B, O, T, F, 2] layout required by the multi-frame (deep filtering) module.
    """
    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        # [B, T, F, O*2] -> [B, O, T, F, 2]
        new_shape = list(coefs.shape)
        new_shape[-1] = -1
        new_shape.append(2)
        coefs = coefs.view(new_shape)
        coefs = coefs.permute(0, 3, 1, 2, 4)
        return coefs
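
# Reshape sketch (shapes are illustrative): with df_order=5 and df_bins=96,
#
#   reshape = DfOutputReshapeMF(df_order=5, df_bins=96)
#   coefs = torch.randn(4, 100, 96, 10)       # [B, T, F, O*2]
#   out = reshape(coefs)                      # -> [4, 5, 100, 96, 2] = [B, O, T, F, 2]
#
# The trailing axis of size 2 holds the real and imaginary parts of each complex coefficient.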


class Mask(nn.Module):
    def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.use_post_filter = use_post_filter
        self.eps = eps

    def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return: the post-filtered mask.
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]

        if not self.training and self.use_post_filter:
            mask = self.post_filter(mask)

        # mask shape: [batch_size, 1, time_steps, spec_bins]
        mask = mask.unsqueeze(4)
        # mask shape: [batch_size, 1, time_steps, spec_bins, 1]
        return spec * mask
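
# Post-filter behaviour (restating `Mask.post_filter` above as a formula, with m a mask
# value in [0, 1] and beta the global gain factor):
#
#   m_pf = (1 + beta) * m / (1 + beta / sin(pi * m / 2)**2)
#
# For m near 1, sin(pi * m / 2) is ~1 and m_pf is ~m, so speech-dominant bins are left
# nearly unchanged; for small m the denominator grows, so already-attenuated
# (noise-dominant) bins are suppressed further.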


class DeepFiltering(nn.Module):
    def __init__(self,
                 df_bins: int,
                 df_order: int,
                 lookahead: int = 0,
                 ):
        super(DeepFiltering, self).__init__()
        self.df_bins = df_bins
        self.df_order = df_order
        self.need_unfold = df_order > 1
        self.lookahead = lookahead

        self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)

    def spec_unfold(self, spec: torch.Tensor):
        """
        Pads and unfolds the spectrogram along the time axis according to df_order.
        :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
        :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: df_order.
        """
        if self.need_unfold:
            # spec shape: [batch_size, 1, time_steps, spec_bins]
            spec_pad = self.pad(spec)
            # spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins]
            spec_unfold = spec_pad.unfold(2, self.df_order, 1)
            # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order]
            return spec_unfold
        else:
            return spec.unsqueeze(-1)

    def forward(self,
                spec: torch.Tensor,
                coefs: torch.Tensor,
                ):
        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
        spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
        # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]

        # coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
        coefs = torch.view_as_complex(coefs.contiguous())
        # coefs shape: [batch_size, df_order, time_steps, df_bins]
        spec_f = spec_u.narrow(-2, 0, self.df_bins)
        # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]

        coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
        # coefs shape: [batch_size, 1, df_order, time_steps, df_bins]

        spec_f = self.df(spec_f, coefs)
        # spec_f shape: [batch_size, 1, time_steps, df_bins]

        if self.training:
            spec = spec.clone()
        spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
        # spec shape: [batch_size, 1, time_steps, spec_bins, 2]
        return spec

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires an unfolded spectrogram.
        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
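
# What the deep-filter einsum computes, spelled out as an equivalent loop (a sketch for
# illustration; not used by the model):
#
#   def df_loop(spec_u, coefs):
#       # spec_u: complex [B, C, T, F, N], coefs: complex [B, C, N, T, F]
#       out = torch.zeros(spec_u.shape[:-1], dtype=spec_u.dtype, device=spec_u.device)
#       for n in range(spec_u.shape[-1]):
#           out = out + coefs[:, :, n] * spec_u[..., n]   # complex multiply-accumulate per tap
#       return out
#
# i.e. each enhanced bin is a complex FIR filter over the last N (= df_order) time frames
# of that frequency bin.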


class DfNet(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(DfNet, self).__init__()
        self.config = config
        self.eps = 1e-8  # numerical floor used in mask_loss_fn

        self.freq_bins = self.config.nfft // 2 + 1

        self.nfft = config.nfft
        self.win_size = config.win_size
        self.hop_size = config.hop_size
        self.win_type = config.win_type

        self.stft = ConvSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            feature_type="complex",
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            feature_type="complex",
            requires_grad=False
        )

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.df_decoder = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
        self.df_op = DeepFiltering(
            df_bins=config.df_bins,
            df_order=config.df_order,
            lookahead=config.df_lookahead,
        )

        self.mask = Mask(use_post_filter=config.use_post_filter)

    def forward(self,
                noisy: torch.Tensor,
                ):
        if noisy.dim() == 2:
            noisy = torch.unsqueeze(noisy, dim=1)
        _, _, n_samples = noisy.shape
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)

        # [batch_size, freq_bins * 2, time_steps]
        cmp_spec = self.stft.forward(noisy)
        # [batch_size, 1, freq_bins * 2, time_steps]
        cmp_spec = torch.unsqueeze(cmp_spec, 1)

        # [batch_size, 2, freq_bins, time_steps]
        cmp_spec = torch.cat([
            cmp_spec[:, :, :self.freq_bins, :],
            cmp_spec[:, :, self.freq_bins:, :],
        ], dim=1)

        # n//2+1 -> n//2; e.g. 257 -> 256
        cmp_spec = cmp_spec[:, :, :-1, :]

        spec = torch.unsqueeze(cmp_spec, dim=4)
        # [batch_size, 2, freq_bins, time_steps, 1]
        spec = spec.permute(0, 4, 3, 2, 1)
        # spec shape: [batch_size, 1, time_steps, freq_bins, 2]

        feat_power = torch.sum(torch.square(spec), dim=-1)
        # feat_power shape: [batch_size, 1, time_steps, spec_bins]

        feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
        # feat_spec shape: [batch_size, 2, time_steps, freq_bins]
        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
        # feat_spec shape: [batch_size, 2, time_steps, df_bins]

        e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)

        mask = self.decoder.forward(emb, e3, e2, e1, e0)
        # mask shape: [batch_size, 1, time_steps, spec_bins]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError("mask values should lie in [0, 1]")

        spec_m = self.mask.forward(spec, mask)

        # lsnr shape: [batch_size, time_steps, 1]
        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
        # lsnr shape: [batch_size, 1, time_steps]

        df_coefs = self.df_decoder.forward(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]

        spec_e = self.df_op.forward(spec.clone(), df_coefs)
        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

        spec_e = torch.squeeze(spec_e, dim=1)
        spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [batch_size, spec_bins, time_steps, 2]

        mask = torch.squeeze(mask, dim=1)
        est_mask = mask.permute(0, 2, 1)
        # est_mask shape: [batch_size, spec_bins, time_steps]

        b, _, t, _ = spec_e.shape
        est_spec = torch.cat(tensors=[
            torch.concat(tensors=[
                spec_e[..., 0],
                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype, device=spec_e.device)
            ], dim=1),
            torch.concat(tensors=[
                spec_e[..., 1],
                torch.zeros(size=(b, 1, t), dtype=spec_e.dtype, device=spec_e.device)
            ], dim=1),
        ], dim=1)
        # est_spec shape: [b, freq_bins * 2, t], i.e. [b, nfft + 2, t]

        est_wav = self.istft.forward(est_spec)
        est_wav = torch.squeeze(est_wav, dim=1)
        est_wav = est_wav[:, :n_samples]
        # est_wav shape: [b, n_samples]

        return est_spec, est_wav, est_mask, lsnr

    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        :param est_mask: torch.Tensor, shape: [b, n+2, t]; real and imaginary parts of a
                         complex mask stacked along dim 1, freq_bins rows each.
        :param clean: torch.Tensor, clean (target) waveform.
        :param noisy: torch.Tensor, noisy (input) waveform.
        :return: scalar MSE loss between the estimated mask and the complex ratio mask target.
        """
        clean_stft = self.stft(clean)
        clean_re = clean_stft[:, :self.freq_bins, :]
        clean_im = clean_stft[:, self.freq_bins:, :]

        noisy_stft = self.stft(noisy)
        noisy_re = noisy_stft[:, :self.freq_bins, :]
        noisy_im = noisy_stft[:, self.freq_bins:, :]
        noisy_power = noisy_re ** 2 + noisy_im ** 2

        sr = clean_re
        yr = noisy_re
        si = clean_im
        yi = noisy_im
        y_pow = noisy_power

        # (Sr * Yr + Si * Yi) / (Y_pow + eps)
        gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
        # (Si * Yr - Sr * Yi) / (Y_pow + eps)
        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)

        gth_mask_re[gth_mask_re > 2] = 1
        gth_mask_re[gth_mask_re < -2] = -1
        gth_mask_im[gth_mask_im > 2] = 1
        gth_mask_im[gth_mask_im < -2] = -1

        mask_re = est_mask[:, :self.freq_bins, :]
        mask_im = est_mask[:, self.freq_bins:, :]

        loss_re = F.mse_loss(gth_mask_re, mask_re)
        loss_im = F.mse_loss(gth_mask_im, mask_im)
        loss = loss_re + loss_im
        return loss
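
    # Target mask used above, written out: with S the clean STFT, Y the noisy STFT and
    # M = S * conj(Y) / |Y|^2 (the complex ratio mask), the real and imaginary targets are
    #
    #   M_re = (Sr*Yr + Si*Yi) / (Yr^2 + Yi^2)
    #   M_im = (Si*Yr - Sr*Yi) / (Yr^2 + Yi^2)
    #
    # Sanity check: if clean == noisy (S = Y), then M_re = 1 and M_im = 0 for every bin,
    # i.e. the ideal mask passes the signal through unchanged.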


class DfNetPretrainedModel(DfNet):
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super(DfNetPretrainedModel, self).__init__(
            config=config,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path
        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
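
# Save/load round trip (illustrative; the directory path is a placeholder):
#
#   model = DfNetPretrainedModel(config=DfNetConfig())
#   model.save_pretrained("trained_models/dfnet")          # writes model.pt and the yaml config
#   restored = DfNetPretrainedModel.from_pretrained("trained_models/dfnet")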


def main():
    config = DfNetConfig()
    model = DfNetPretrainedModel(config=config)

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    output = model.forward(noisy)
    print(output[1].shape)
    return


if __name__ == "__main__":
    main()