#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig
from toolbox.torchaudio.models.dfnet3 import multiframes as MF
from toolbox.torchaudio.models.dfnet3 import utils

logger = logging.getLogger("toolbox")

PI = 3.1415926535897932384626433


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}

activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}
class CausalConv2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [B, C, T, F]

        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: kernel size over (time, frequency).
        :param fstride: stride over the frequency axis (the time stride is always 1).
        :param dilation: dilation over the frequency axis (the time dilation is always 1).
        :param fpad: whether to pad the frequency axis so that its size is preserved (up to ``fstride``).
        """
        super(CausalConv2d, self).__init__()
        lookahead = 0

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0

        # for the last 2 dims, pad is (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)
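
# A minimal usage sketch: CausalConv2d pads only on the past side of the time axis,
# so T is preserved while F is downsampled by `fstride`, e.g.
#     conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3), fstride=2)
#     x = torch.randn(2, 1, 100, 32)   # [B, C, T, F]
#     y = conv(x)                      # [2, 4, 100, 16]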
class CausalConvTranspose2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        Causal ConvTranspose2d.

        Expected input format: [B, C, T, F]
        """
        super(CausalConvTranspose2d, self).__init__()
        lookahead = 0

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0

        # for the last 2 dims, pad is (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)
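
# A minimal usage sketch: CausalConvTranspose2d mirrors CausalConv2d and upsamples the
# frequency axis by `fstride` while keeping T unchanged, e.g.
#     convt = CausalConvTranspose2d(in_channels=4, out_channels=4, kernel_size=(2, 3), fstride=2)
#     x = torch.randn(2, 4, 100, 16)   # [B, C, T, F]
#     y = convt(x)                     # [2, 4, 100, 32]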
class GroupedLinear(nn.Module):

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
        assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [B, T, I]
        b, t, _ = x.shape
        # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
        new_shape = (b, t, self.groups, self.ws)
        x = x.view(new_shape)
        # The better way, but not supported by torchscript:
        # x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
        x = torch.einsum("btgi,gih->btgh", x, self.weight)  # [..., G, H/G]
        x = x.flatten(2, 3)  # [B, T, H]
        return x

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
class SqueezedGRU_S(nn.Module):
    """
    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            output_size: Optional[int] = None,
            num_layers: int = 1,
            linear_groups: int = 8,
            batch_first: bool = True,
            skip_op: str = "none",
            activation_layer: str = "identity",
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # gru skip operator
        self.gru_skip_op = None

        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=hidden_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.linear_in(inputs)

        x, h = self.gru(x, h)

        x = self.linear_out(x)

        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)

        return x, h
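
# A minimal usage sketch: the GRU runs in a smaller "squeezed" hidden space, with grouped
# linear layers projecting in and out of it, e.g.
#     gru = SqueezedGRU_S(input_size=256, hidden_size=64, output_size=256, linear_groups=8)
#     x = torch.randn(2, 100, 256)     # [B, T, input_size]
#     y, h = gru(x)                    # y: [2, 100, 256], h: [1, 2, 64]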
class Add(nn.Module):
    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)
class Encoder(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(Encoder, self).__init__()
        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        self.erb_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.erb_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
        )
        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.emb_in_dim,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_concat:
            self.emb_in_dim *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_gru_skip_op,
            linear_groups=config.encoder_squeezed_gru_linear_groups,
            activation_layer="relu",
        )

        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.emb_out_dim, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                feat_erb: torch.Tensor,
                feat_spec: torch.Tensor,
                h: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                           torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Encodes erb; erb should be in dB scale + normalized; Fe is the number of erb bands.
        # erb: [B, 1, T, Fe]
        # spec: [B, 2, T, Fc]
        # b, _, t, _ = feat_erb.shape
        e0 = self.erb_conv0(feat_erb)  # [B, C, T, Fe]
        e1 = self.erb_conv1(e0)  # [B, C, T, Fe/2]
        e2 = self.erb_conv2(e1)  # [B, C, T, Fe/4]
        e3 = self.erb_conv3(e2)  # [B, C, T, Fe/4]

        c0 = self.df_conv0(feat_spec)  # [B, C, T, Fc]
        c1 = self.df_conv1(c0)  # [B, C, T, Fc/2]
        cemb = c1.permute(0, 2, 3, 1).flatten(2)  # [B, T, C * Fc/2]
        cemb = self.df_fc_emb(cemb)  # [B, T, C * Fe/4]

        emb = e3.permute(0, 2, 3, 1).flatten(2)  # [B, T, C * Fe/4]
        emb = self.combine(emb, cemb)
        emb, h = self.emb_gru(emb, h)  # [B, T, -1]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        return e0, e1, e2, e3, emb, c0, lsnr, h
class ErbDecoder(nn.Module):
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super(ErbDecoder, self).__init__()
        if config.erb_bins % 8 != 0:
            raise AssertionError("erb_bins should be divisible by 8")

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.erb_decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.erb_decoder_gru_skip_op,
            linear_groups=config.erb_decoder_linear_groups,
            activation_layer="relu",
        )

        # convt: transposed convolutions, convp: pathway (encoder to decoder) convolutions
        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            fstride=2,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            fstride=2,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
        )
    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        # Estimates the erb mask
        b, _, t, f4 = e3.shape

        emb, _ = self.emb_gru(emb)

        emb = emb.view(b, t, f4, -1).permute(0, 3, 1, 2)  # [B, C, T, Fe/4]
        e3 = self.convt3(self.conv3p(e3) + emb)  # [B, C, T, Fe/4]
        e2 = self.convt2(self.conv2p(e2) + e3)  # [B, C, T, Fe/2]
        e1 = self.convt1(self.conv1p(e1) + e2)  # [B, C, T, Fe]
        m = self.conv0_out(self.conv0p(e0) + e1)  # [B, 1, T, Fe]
        return m
class Mask(nn.Module):
    def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.erb_inv_fb: torch.FloatTensor
        self.register_buffer("erb_inv_fb", erb_inv_fb.float())
        self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
        self.post_filter = post_filter
        self.eps = eps

    def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        The post-filtered mask, as implemented below, is
            m_pf = (1 + beta) * m / (1 + beta * (m / (m * sin(pi * m / 2)))**2)

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return: Post-filtered mask.
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor:
        # spec (real) [B, 1, T, F, 2], F: freq_bins
        # mask (real): [B, 1, T, Fe], Fe: erb_bins
        # atten_lim: [B]
        if not self.training and self.post_filter:
            mask = self.pf(mask)

        if atten_lim is not None:
            # dB to amplitude
            atten_lim = 10 ** (-atten_lim / 20)
            # Greater-equal (__ge__) is not implemented for TorchVersion, hence the two
            # comparisons used to compute self.clamp_tensor in __init__.
            if self.clamp_tensor:
                # Supported by torch >= 1.9
                mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
            else:
                m_out = []
                for i in range(atten_lim.shape[0]):
                    m_out.append(mask[i].clamp_min(atten_lim[i].item()))
                mask = torch.stack(m_out, dim=0)

        mask = mask.matmul(self.erb_inv_fb)  # [B, 1, T, F]

        if not spec.is_complex():
            mask = mask.unsqueeze(4)
        return spec * mask
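
# A minimal usage sketch (with a random placeholder inverse ERB filter bank; the real one
# comes from the feature pipeline):
#     erb_inv_fb = torch.rand(32, 481)            # [erb_bins, freq_bins]
#     mask_module = Mask(erb_inv_fb)
#     spec = torch.randn(2, 1, 100, 481, 2)       # [B, 1, T, F, 2]
#     m = torch.rand(2, 1, 100, 32)               # [B, 1, T, Fe]
#     out = mask_module(spec, m)                  # [2, 1, 100, 481, 2]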
class DfDecoder(nn.Module):
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super().__init__()
        layer_width = config.conv_channels

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_dim = config.df_hidden_dim

        self.df_n_hidden = config.df_hidden_dim
        self.df_n_layers = config.df_num_layers
        self.df_order = config.df_order
        self.df_bins = config.df_bins
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            layer_width,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_dim,
            num_layers=self.df_n_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            if config.emb_hidden_dim != config.df_hidden_dim:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(self.emb_in_dim, self.emb_dim, groups=config.df_decoder_linear_groups)
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch

        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_n_hidden,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_n_hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        b, t, _ = emb.shape
        c, _ = self.df_gru(emb)  # [B, T, H], H: df_n_hidden

        if self.df_skip is not None:
            c = c + self.df_skip(emb)

        c0 = self.df_convp(c0).permute(0, 2, 3, 1)  # [B, T, F, O*2], channels_last
        c = self.df_out(c)  # [B, T, F*O*2], O: df_order
        c = c.view(b, t, self.df_bins, self.df_out_ch) + c0  # [B, T, F, O*2]
        return c
class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule.

    Reshapes the decoder output of shape [B, T, F, O*2] into [B, O, T, F, 2].
    """

    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        # [B, T, F, O*2] -> [B, O, T, F, 2]
        new_shape = list(coefs.shape)
        new_shape[-1] = -1
        new_shape.append(2)
        coefs = coefs.view(new_shape)
        coefs = coefs.permute(0, 3, 1, 2, 4)
        return coefs
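
# A minimal usage sketch: the decoder emits [B, T, F, O*2], which is reshaped into one
# complex coefficient per deep-filtering order, e.g.
#     reshape = DfOutputReshapeMF(df_order=5, df_bins=96)
#     coefs = torch.randn(2, 100, 96, 10)         # [B, T, F, O*2]
#     coefs = reshape(coefs)                      # [2, 5, 100, 96, 2], i.e. [B, O, T, F, 2]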
class DfNet(nn.Module):
    """
    DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement
    https://arxiv.org/abs/2305.08227

    [email protected]
    """
    def __init__(self,
                 config: DfNetConfig,
                 erb_fb: torch.FloatTensor,
                 erb_inv_fb: torch.FloatTensor,
                 run_df: bool = True,
                 train_mask: bool = True,
                 ):
        """
        :param erb_fb: erb filter bank.
        """
        super(DfNet, self).__init__()
        if config.erb_bins % 8 != 0:
            raise AssertionError("erb_bins should be divisible by 8")

        self.df_lookahead = config.df_lookahead
        self.df_bins = config.df_bins
        self.freq_bins: int = config.fft_size // 2 + 1
        self.emb_dim: int = config.conv_channels * config.erb_bins
        self.erb_bins: int = config.erb_bins

        if config.conv_lookahead > 0:
            if config.conv_lookahead < config.df_lookahead:
                raise AssertionError
            # for the last 2 dims, pad is (left, right, top, bottom).
            self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0)
        else:
            self.pad_feat = nn.Identity()

        if config.df_lookahead > 0:
            # for the last 3 dims, pad is (left, right, top, bottom, front, back).
            self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0)
        else:
            self.pad_spec = nn.Identity()

        self.register_buffer("erb_fb", erb_fb)

        self.enc = Encoder(config)
        self.erb_dec = ErbDecoder(config)
        self.mask = Mask(erb_inv_fb)

        self.erb_inv_fb = erb_inv_fb
        self.post_filter = config.mask_post_filter
        self.post_filter_beta = config.post_filter_beta

        self.df_order = config.df_order
        self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead)
        self.df_dec = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins)

        self.run_erb = config.df_bins + 1 < self.freq_bins
        if not self.run_erb:
            logger.warning("Running without ERB stage")

        self.run_df = run_df
        if not run_df:
            logger.warning("Running without DF stage")

        self.train_mask = train_mask

        self.lsnr_dropout = config.lsnr_dropout
        if config.df_n_iter != 1:
            raise AssertionError
    def forward1(
            self,
            spec: torch.Tensor,
            feat_erb: torch.Tensor,
            feat_spec: torch.Tensor,  # Not used, take spec modified by mask instead
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward method of DeepFilterNet2.

        Args:
            spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
            feat_erb (Tensor): ERB features of shape [B, 1, T, E]
            feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]

        Returns:
            spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
            m (Tensor): ERB mask estimate of shape [B, 1, T, E]
            lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
            df_coefs (Tensor): Deep filtering coefficients
        """
        # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
        # feat_spec shape: [batch_size, 2, time_steps, freq_dim]

        # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
        # assert time_steps >= conv_lookahead.
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec)

        if self.lsnr_dropout:
            idcs = lsnr.squeeze() > -10.0
            b, t = (spec.shape[0], spec.shape[2])
            m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device)
            df_coefs = torch.zeros((b, t, self.df_bins, self.df_order * 2), device=spec.device)
            spec_m = spec.clone()
            emb = emb[:, idcs]
            e0 = e0[:, :, idcs]
            e1 = e1[:, :, idcs]
            e2 = e2[:, :, idcs]
            e3 = e3[:, :, idcs]
            c0 = c0[:, :, idcs]

        if self.run_erb:
            if self.lsnr_dropout:
                m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
            else:
                m = self.erb_dec(emb, e3, e2, e1, e0)
            spec_m = self.mask(spec, m)
        else:
            m = torch.zeros((), device=spec.device)
            spec_m = torch.zeros_like(spec)

        if self.run_df:
            if self.lsnr_dropout:
                df_coefs[:, idcs] = self.df_dec(emb, c0)
            else:
                df_coefs = self.df_dec(emb, c0)
            df_coefs = self.df_out_transform(df_coefs)
            spec_e = self.df_op(spec.clone(), df_coefs)
            spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
        else:
            df_coefs = torch.zeros((), device=spec.device)
            spec_e = spec_m

        if self.post_filter:
            beta = self.post_filter_beta
            eps = 1e-12
            mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1)
            mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps)
            pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2))
            spec_e = spec_e * pf.unsqueeze(-1)

        return spec_e, m, lsnr, df_coefs
    def forward(
            self,
            spec: torch.Tensor,
            feat_erb: torch.Tensor,
            feat_spec: torch.Tensor,  # Not used, take spec modified by mask instead
            erb_encoder_h: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
        # feat_spec shape: [batch_size, 2, time_steps, freq_dim]

        # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
        # assert time_steps >= conv_lookahead.
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h)

        m = self.erb_dec(emb, e3, e2, e1, e0)
        spec_m = self.mask(spec, m)
        # spec_e = spec_m

        df_coefs = self.df_dec(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        spec_e = self.df_op(spec.clone(), df_coefs)
        spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]

        return spec_e, m, lsnr, df_coefs, erb_encoder_h
if __name__ == "__main__":
    pass
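
    # A small smoke test for the config-independent building blocks (arbitrary shapes;
    # DfNet itself additionally needs a DfNetConfig and ERB filter banks).
    conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3), fstride=2)
    print(conv(torch.randn(2, 1, 100, 32)).shape)    # torch.Size([2, 4, 100, 16])

    gru = SqueezedGRU_S(input_size=256, hidden_size=64, output_size=256, linear_groups=8)
    y, h = gru(torch.randn(2, 100, 256))
    print(y.shape, h.shape)                          # [2, 100, 256], [1, 2, 64]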