# HoneyTian's picture
# first commit
# bd94e77
# raw
# history blame
# 28.6 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig
from toolbox.torchaudio.models.dfnet3 import multiframes as MF
from toolbox.torchaudio.models.dfnet3 import utils
logger = logging.getLogger("toolbox")

# Pi as a long decimal literal; used by the post-filter in DfNet.forward1.
PI = 3.1415926535897932384626433

# Lookup from the string name accepted by the conv blocks to a norm layer class.
norm_layer_dict = dict(
    batch_norm_2d=nn.BatchNorm2d,
)

# Lookup from the string name accepted by the conv / GRU blocks to an
# activation layer class.
activation_layer_dict = dict(
    relu=nn.ReLU,
    identity=nn.Identity,
    sigmoid=nn.Sigmoid,
)
class CausalConv2d(nn.Sequential):
    """
    Conv2d block that only looks at past frames on the time axis.

    The block is an ``nn.Sequential`` made of, in order: zero padding at the
    start of the time axis (only when the time kernel is larger than 1), the
    convolution itself, an optional 1x1 point-wise convolution, an optional
    norm layer and an optional activation layer.
    Expected input format: [B, C, T, F]
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: an int (used for both axes) or a (time, freq) pair.
        :param fstride: stride on the frequency axis; the time stride is always 1.
        :param dilation: dilation on the frequency axis; the time dilation is always 1.
        :param fpad: when True, pad the frequency axis inside the convolution.
        :param bias: whether the main convolution has a bias.
        :param separable: request a grouped convolution followed by a 1x1 convolution.
        :param norm_layer: key of ``norm_layer_dict``; ``None`` skips the norm layer.
        :param activation_layer: key of ``activation_layer_dict``; ``None`` skips it.
        """
        # No future frames are consumed by this block.
        lookahead = 0

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            kernel_size = tuple(kernel_size)
        time_kernel, freq_kernel = kernel_size[0], kernel_size[1]

        freq_padding = freq_kernel // 2 + dilation - 1 if fpad else 0

        # ConstantPad2d order for the last two dims: (left, right, top, bottom),
        # i.e. (freq_before, freq_after, time_before, time_after).
        time_pad = (0, 0, time_kernel - 1 - lookahead, lookahead)

        modules = []
        if max(time_pad) > 0:
            modules.append(nn.ConstantPad2d(time_pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        # The point-wise convolution is only added when grouping is really in
        # effect and the main kernel is not already 1x1.
        add_pointwise = separable and groups != 1 and max(kernel_size) != 1

        main_conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, freq_padding),
            stride=(1, fstride),
            dilation=(1, dilation),
            groups=groups,
            bias=bias,
        )
        modules.append(main_conv)

        if add_pointwise:
            modules.append(nn.Conv2d(out_channels, out_channels, kernel_size=1, bias=False))

        if norm_layer is not None:
            norm_cls = norm_layer_dict[norm_layer]
            modules.append(norm_cls(out_channels))

        if activation_layer is not None:
            activation_cls = activation_layer_dict[activation_layer]
            modules.append(activation_cls())

        super().__init__(*modules)
class CausalConvTranspose2d(nn.Sequential):
    """
    ConvTranspose2d block that only looks at past frames on the time axis.

    The block is an ``nn.Sequential`` made of, in order: zero padding at the
    start of the time axis (only when the time kernel is larger than 1), the
    transposed convolution, an optional 1x1 point-wise convolution, an
    optional norm layer and an optional activation layer.
    Expected input format: [B, C, T, F]
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: an int (used for both axes) or any iterable yielding (time, freq).
        :param fstride: stride on the frequency axis; the time stride is always 1.
        :param dilation: dilation on the frequency axis; the time dilation is always 1.
        :param fpad: when True, pad the frequency axis inside the transposed convolution.
        :param bias: whether the transposed convolution has a bias.
        :param separable: request a grouped convolution followed by a 1x1 convolution.
        :param norm_layer: key of ``norm_layer_dict``; ``None`` skips the norm layer.
        :param activation_layer: key of ``activation_layer_dict``; ``None`` skips it.
        """
        # No future frames are consumed by this block.
        lookahead = 0

        # Materialize the kernel size as a tuple (as CausalConv2d does) so that
        # any iterable, not only indexable sequences, is accepted.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        else:
            kernel_size = tuple(kernel_size)

        fpad_ = kernel_size[1] // 2 if fpad else 0

        # ConstantPad2d order for the last two dims: (left, right, top, bottom),
        # i.e. (freq_before, freq_after, time_before, time_after).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            # Without real grouping there is nothing for a point-wise conv to mix.
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                # The time padding removes the frames that the explicit zero
                # padding above would otherwise add to the output length.
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_cls = norm_layer_dict[norm_layer]
            layers.append(norm_cls(out_channels))

        if activation_layer is not None:
            activation_cls = activation_layer_dict[activation_layer]
            layers.append(activation_cls())

        super().__init__(*layers)
class GroupedLinear(nn.Module):
    """
    Bias-free linear layer that is applied independently to ``groups`` equal
    slices of the last dimension.

    The weight has shape [groups, input_size / groups, hidden_size / groups].
    """

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        """
        :param input_size: size of the last input dimension.
        :param hidden_size: size of the last output dimension.
        :param groups: number of independent groups; must divide both sizes.
        :raises AssertionError: if a size is not divisible by ``groups``.
        """
        super().__init__()
        # Raised explicitly (instead of `assert`) so the checks are not
        # stripped when Python runs with -O.
        if input_size % groups != 0:
            raise AssertionError(f"Input size {input_size} not divisible by {groups}")
        if hidden_size % groups != 0:
            raise AssertionError(f"Hidden size {hidden_size} not divisible by {groups}")

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        # Width of one input group.
        self.ws = input_size // groups

        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize the weight with Kaiming-uniform initialization."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: tensor of shape [B, T, input_size].
        :return: tensor of shape [B, T, hidden_size].
        """
        b, t, _ = x.shape
        # Split the feature axis into groups: [B, T, G, I/G].
        x = x.view((b, t, self.groups, self.ws))
        # Per-group matrix multiplication: [B, T, G, H/G].
        x = torch.einsum("btgi,gih->btgh", x, self.weight)
        # Merge the groups back together: [B, T, H].
        x = x.flatten(2, 3)
        return x

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
class SqueezedGRU_S(nn.Module):
    """
    GRU wrapped between a grouped input projection and an optional grouped
    output projection, with an optional skip connection around the block.

    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """

    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            output_size: Optional[int] = None,
            num_layers: int = 1,
            linear_groups: int = 8,
            batch_first: bool = True,
            skip_op: str = "none",
            activation_layer: str = "identity",
    ):
        """
        :param input_size: size of the last input dimension.
        :param hidden_size: width of the GRU (input and hidden state).
        :param output_size: size of the output projection; ``None`` returns the raw GRU output.
        :param num_layers: number of stacked GRU layers.
        :param linear_groups: number of groups for every GroupedLinear in the block.
        :param batch_first: forwarded to ``nn.GRU``.
        :param skip_op: one of "none", "identity", "grouped_linear".
        :param activation_layer: key of ``activation_layer_dict`` used after both projections.
        :raises AssertionError: if ``skip_op == "identity"`` but input and output widths differ.
        :raises NotImplementedError: for an unknown ``skip_op``.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # Width of the tensor the skip path is added to in forward():
        # the output projection width, or the GRU width when there is none.
        skip_target_size = hidden_size if output_size is None else output_size

        # gru skip operator
        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            # An identity skip adds the raw input to the output, so the two
            # widths have to be equal.
            if input_size != skip_target_size:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            # The skip path is fed with the block input and summed with the
            # block output, so it maps input width -> output width.
            self.gru_skip_op = GroupedLinear(
                input_size=input_size,
                hidden_size=skip_target_size,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param inputs: tensor of shape [B, T, input_size].
        :param h: optional initial GRU hidden state, passed to ``nn.GRU`` unchanged.
        :return: (output, final GRU hidden state).
        """
        x = self.linear_in(inputs)
        x, h = self.gru(x, h)
        x = self.linear_out(x)
        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)
        return x, h
class Add(nn.Module):
    """Combine two inputs by element-wise addition."""

    def forward(self, a, b):
        """Return ``a + b``."""
        total = a + b
        return total
class Concat(nn.Module):
    """Combine two inputs by concatenating them along the last dimension."""

    def forward(self, a, b):
        """Return ``a`` and ``b`` joined on the last axis."""
        pieces = (a, b)
        return torch.cat(pieces, dim=-1)
class Encoder(nn.Module):
    """
    Encoder that processes ERB features and complex spectrogram features with
    two stacks of causal convolutions, merges both into one embedding, runs
    the embedding through a squeezed GRU and predicts a local SNR value.
    """

    def __init__(self, config: DfNetConfig):
        """
        :param config: model configuration; see the attribute reads below for the fields used.
        """
        super(Encoder, self).__init__()
        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        # ERB pathway: one input conv, then two convs with frequency stride 2
        # and a final conv with stride 1.
        self.erb_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.erb_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
        )

        # Complex-spectrogram pathway: 2 input channels, one conv with
        # frequency stride 2.
        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )

        # Projects the flattened spectrogram pathway to the ERB embedding width
        # (emb_in_dim is read here before the optional doubling below).
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.emb_in_dim,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_concat:
            # Concatenating both embeddings doubles the GRU input width.
            self.emb_in_dim *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_gru_skip_op,
            linear_groups=config.encoder_squeezed_gru_linear_groups,
            activation_layer="relu",
        )

        # The sigmoid output in [0, 1] is mapped to [lsnr_min, lsnr_max] in forward().
        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.emb_out_dim, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                feat_erb: torch.Tensor,
                feat_spec: torch.Tensor,
                h: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                           torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        :param feat_erb: ERB features, [B, 1, T, Fe] (Fe: number of erb bands).
        :param feat_spec: spectrogram features, [B, 2, T, Fc].
        :param h: optional initial hidden state for the embedding GRU.
        :return: (e0, e1, e2, e3, emb, c0, lsnr, h) - the four ERB pathway
            outputs, the GRU embedding, the first spectrogram pathway output,
            the local SNR estimate and the final GRU hidden state.
        """
        # Every conv keeps conv_channels (C) output channels; the fstride=2
        # layers reduce the frequency axis (presumably halving it - this
        # depends on the configured kernel sizes).
        e0 = self.erb_conv0(feat_erb)  # [B, C, T, Fe]
        e1 = self.erb_conv1(e0)  # [B, C, T, ~Fe/2]
        e2 = self.erb_conv2(e1)  # [B, C, T, ~Fe/4]
        e3 = self.erb_conv3(e2)  # [B, C, T, ~Fe/4]

        c0 = self.df_conv0(feat_spec)  # [B, C, T, Fc]
        c1 = self.df_conv1(c0)  # [B, C, T, ~Fc/2]

        # Move channels last and merge (F, C) into a single feature axis.
        cemb = c1.permute(0, 2, 3, 1).flatten(2)  # [B, T, F*C]
        cemb = self.df_fc_emb(cemb)  # [B, T, C * Fe // 4]

        emb = e3.permute(0, 2, 3, 1).flatten(2)  # [B, T, F*C]
        emb = self.combine(emb, cemb)
        emb, h = self.emb_gru(emb, h)  # [B, T, emb_out_dim]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset  # [B, T, 1]
        return e0, e1, e2, e3, emb, c0, lsnr, h
class ErbDecoder(nn.Module):
    """
    Decoder that turns the encoder embedding and the four ERB pathway outputs
    into an ERB mask whose last layer is a sigmoid.
    """

    def __init__(self,
                 config: DfNetConfig,
                 ):
        """
        :param config: model configuration.
        :raises AssertionError: if ``config.erb_bins`` is not divisible by 8.
        """
        super(ErbDecoder, self).__init__()

        if config.erb_bins % 8 != 0:
            raise AssertionError("erb_bins should be divisible by 8")

        channels = config.conv_channels

        self.emb_in_dim = channels * config.erb_bins // 4
        self.emb_out_dim = channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        def make_pathway() -> CausalConv2d:
            # 1x1 conv applied to an encoder output before it is summed with
            # the decoder signal.
            return CausalConv2d(
                in_channels=channels,
                out_channels=channels,
                kernel_size=1,
                bias=False,
                separable=True,
            )

        def make_upsample() -> CausalConvTranspose2d:
            # Transposed conv with frequency stride 2.
            return CausalConvTranspose2d(
                in_channels=channels,
                out_channels=channels,
                fstride=2,
                kernel_size=config.convt_kernel_size_inner,
                bias=False,
                separable=True,
            )

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.erb_decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.erb_decoder_gru_skip_op,
            linear_groups=config.erb_decoder_linear_groups,
            activation_layer="relu",
        )

        # Naming: convNp = pathway conv (encoder to decoder), convtN = decoder
        # stage. Modules are created in the same order as they are used.
        self.conv3p = make_pathway()
        # Despite its name this stage is a regular (non-transposed) conv.
        self.convt3 = CausalConv2d(
            in_channels=channels,
            out_channels=channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv2p = make_pathway()
        self.convt2 = make_upsample()
        self.conv1p = make_pathway()
        self.convt1 = make_upsample()
        self.conv0p = make_pathway()
        self.conv0_out = CausalConv2d(
            in_channels=channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        """
        Estimate the ERB mask.

        :param emb: encoder embedding, [B, T, features].
        :param e3: deepest ERB pathway output; its batch, time and frequency
            sizes are used to reshape the embedding.
        :param e2: ERB pathway output of the stage before ``e3``.
        :param e1: ERB pathway output of the stage before ``e2``.
        :param e0: first ERB pathway output.
        :return: mask produced by the final sigmoid conv (1 output channel).
        """
        batch, _, frames, freqs = e3.shape

        hidden, _ = self.emb_gru(emb)
        # [B, T, F*C] -> [B, T, F, C] -> [B, C, T, F], matching e3's layout.
        hidden = hidden.view(batch, frames, freqs, -1).permute(0, 3, 1, 2)

        d3 = self.convt3(self.conv3p(e3) + hidden)
        d2 = self.convt2(self.conv2p(e2) + d3)
        d1 = self.convt1(self.conv1p(e1) + d2)
        return self.conv0_out(self.conv0p(e0) + d1)
class Mask(nn.Module):
    """
    Applies an ERB-domain gain mask to a spectrum.

    The mask is optionally post-filtered (eval mode only), optionally clamped
    from below by an attenuation limit, multiplied with the inverse ERB
    filter bank and finally multiplied with the spectrum.
    """

    def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12):
        """
        :param erb_inv_fb: inverse ERB filter bank the mask is right-multiplied with.
        :param post_filter: apply ``pf`` to the mask when the module is not training.
        :param eps: lower clamp used inside ``pf`` to avoid dividing by zero.
        """
        super().__init__()
        self.post_filter = post_filter
        self.eps = eps

        self.erb_inv_fb: torch.FloatTensor
        self.register_buffer("erb_inv_fb", erb_inv_fb.float())

        # Tensor-valued clamp bounds need torch >= 1.9. The check is written
        # as "greater or equal" from two separate comparisons.
        installed = torch.__version__
        self.clamp_tensor = installed > "1.9.0" or installed == "1.9.0"

    def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter
        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return: the post-filtered mask, same shape as ``mask``.
        """
        sin_weighted = mask * torch.sin(np.pi * mask / 2)
        ratio = mask.div(sin_weighted.clamp_min(self.eps))
        return (1 + beta) * mask / (1 + beta * ratio.pow(2))

    def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        :param spec: spectrum; real [B, 1, T, F, 2] (F: freq_bins) or complex.
        :param mask: real mask, [B, 1, T, Fe] (Fe: erb_bins).
        :param atten_lim: optional attenuation limit in dB, [B].
        :return: the masked spectrum.
        """
        if self.post_filter and not self.training:
            mask = self.pf(mask)

        if atten_lim is not None:
            # dB to amplitude
            floor = 10 ** (-atten_lim / 20)
            if self.clamp_tensor:
                # One lower bound per batch item, broadcast over the rest.
                mask = mask.clamp(min=floor.view(-1, 1, 1, 1))
            else:
                # Older torch: clamp each batch item with a scalar bound.
                clamped = [
                    mask[i].clamp_min(floor[i].item())
                    for i in range(floor.shape[0])
                ]
                mask = torch.stack(clamped, dim=0)

        # Expand the mask from ERB bands to frequency bins: [B, 1, T, F].
        mask = mask.matmul(self.erb_inv_fb)
        if not spec.is_complex():
            # Add a trailing axis so the gain broadcasts over real/imag parts.
            mask = mask.unsqueeze(4)
        return spec * mask
class DfDecoder(nn.Module):
    """
    Decoder that predicts deep-filter coefficients from the encoder embedding
    and the first spectrogram pathway output.
    """

    def __init__(self,
                 config: DfNetConfig,
                 ):
        """
        :param config: model configuration.
        :raises AssertionError: if ``df_gru_skip == "identity"`` but
            ``emb_hidden_dim`` and ``df_hidden_dim`` differ.
        :raises NotImplementedError: for an unknown ``df_gru_skip``.
        """
        super().__init__()
        layer_width = config.conv_channels

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_dim = config.df_hidden_dim

        self.df_n_hidden = config.df_hidden_dim
        self.df_n_layers = config.df_num_layers
        self.df_order = config.df_order
        self.df_bins = config.df_bins
        # Two values per filter tap (presumably real and imaginary part).
        self.df_out_ch = config.df_order * 2

        # Pathway conv applied to the encoder's c0 output.
        self.df_convp = CausalConv2d(
            layer_width,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )

        self.df_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_dim,
            num_layers=self.df_n_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        # Optional skip connection around the GRU.
        skip_kind = config.df_gru_skip
        if skip_kind == "none":
            self.df_skip = None
        elif skip_kind == "identity":
            if config.emb_hidden_dim != config.df_hidden_dim:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif skip_kind == "grouped_linear":
            self.df_skip = GroupedLinear(
                self.emb_in_dim,
                self.emb_dim,
                groups=config.df_decoder_linear_groups,
            )
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch
        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_n_hidden,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        # Not used by forward() in this class.
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_n_hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        """
        :param emb: encoder embedding, [B, T, features].
        :param c0: first spectrogram pathway output of the encoder, channels first.
        :return: coefficients of shape [B, T, df_bins, df_order * 2].
        """
        batch, frames, _ = emb.shape

        hidden, _ = self.df_gru(emb)  # [B, T, H], H: df_n_hidden
        if self.df_skip is not None:
            hidden = hidden + self.df_skip(emb)

        # Pathway signal moved to channels-last: [B, T, F, O*2].
        pathway = self.df_convp(c0).permute(0, 2, 3, 1)

        coefs = self.df_out(hidden)  # [B, T, F*O*2], O: df_order
        coefs = coefs.view(batch, frames, self.df_bins, self.df_out_ch)
        return coefs + pathway
class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule

    Splits the last axis into (order, 2) and moves the order axis to
    position 1: [B, T, F, O*2] -> [B, O, T, F, 2].
    """

    def __init__(self, df_order: int, df_bins: int):
        """
        :param df_order: stored on the module; not read by ``forward``.
        :param df_bins: stored on the module; not read by ``forward``.
        """
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        """
        :param coefs: tensor of shape [B, T, F, O*2].
        :return: tensor of shape [B, O, T, F, 2].
        """
        # Keep all leading axes, split the last one into pairs.
        leading = list(coefs.shape[:-1])
        paired = coefs.view(leading + [-1, 2])
        return paired.permute(0, 3, 1, 2, 4)
class DfNet(nn.Module):
    """
    DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement
    https://arxiv.org/abs/2305.08227

    Combines an encoder, an ERB mask decoder and a deep-filter decoder.
    (The author contact e-mail in the original listing was obfuscated and
    could not be recovered.)
    """

    def __init__(self,
                 config: DfNetConfig,
                 erb_fb: torch.FloatTensor,
                 erb_inv_fb: torch.FloatTensor,
                 run_df: bool = True,
                 train_mask: bool = True,
                 ):
        """
        :param config: model configuration.
        :param erb_fb: erb filter bank.
        :param erb_inv_fb: inverse erb filter bank, handed to ``Mask``.
        :param run_df: when False, ``forward1`` skips the deep-filter stage.
        :param train_mask: stored on the module; not read inside this class.
        :raises AssertionError: if ``erb_bins`` is not divisible by 8, if
            ``0 < conv_lookahead < df_lookahead``, or if ``df_n_iter != 1``.
        """
        super(DfNet, self).__init__()
        if config.erb_bins % 8 != 0:
            raise AssertionError("erb_bins should be divisible by 8")

        self.df_lookahead = config.df_lookahead
        self.df_bins = config.df_bins
        self.freq_bins: int = config.fft_size // 2 + 1
        self.emb_dim: int = config.conv_channels * config.erb_bins
        self.erb_bins: int = config.erb_bins

        if config.conv_lookahead > 0:
            if config.conv_lookahead < config.df_lookahead:
                raise AssertionError
            # for last 2 dim, pad (left, right, top, bottom).
            # Negative padding at the start plus positive padding at the end
            # shifts the features along the time axis by conv_lookahead frames.
            self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0)
        else:
            self.pad_feat = nn.Identity()

        if config.df_lookahead > 0:
            # for last 3 dim, pad (left, right, top, bottom, front, back).
            self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0)
        else:
            self.pad_spec = nn.Identity()

        self.register_buffer("erb_fb", erb_fb)
        self.enc = Encoder(config)
        self.erb_dec = ErbDecoder(config)
        self.mask = Mask(erb_inv_fb)
        self.erb_inv_fb = erb_inv_fb
        self.post_filter = config.mask_post_filter
        self.post_filter_beta = config.post_filter_beta

        self.df_order = config.df_order
        self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead)
        self.df_dec = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins)

        # The ERB stage is only needed when deep filtering does not already
        # cover (almost) all frequency bins.
        self.run_erb = config.df_bins + 1 < self.freq_bins
        if not self.run_erb:
            logger.warning("Running without ERB stage")
        self.run_df = run_df
        if not run_df:
            logger.warning("Running without DF stage")
        self.train_mask = train_mask
        self.lsnr_dropout = config.lsnr_dropout
        if config.df_n_iter != 1:
            raise AssertionError

    def forward1(
            self,
            spec: torch.Tensor,
            feat_erb: torch.Tensor,
            feat_spec: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass with optional LSNR dropout, stage switches and post-filter.

        Args:
            spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
            feat_erb (Tensor): ERB features of shape [B, 1, T, E]
            feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]

        Returns:
            spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2]
            m (Tensor): ERB mask estimate of shape [B, 1, T, E]
            lsnr (Tensor): Local SNR estimate of shape [B, T, 1]
            df_coefs (Tensor): Deep-filter coefficients (a 0-dim zero tensor when
                the DF stage is disabled)
        """
        # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
        # feat_spec shape: [batch_size, 2, time_steps, freq_dim]

        # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
        # assert time_steps >= conv_lookahead.
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec)

        if self.lsnr_dropout:
            # Only frames whose estimated local SNR exceeds -10 dB are decoded;
            # all other frames keep a zero mask / zero coefficients.
            # NOTE(review): the boolean index is applied on the time axis, which
            # looks like it assumes batch size 1 - confirm against callers.
            idcs = lsnr.squeeze() > -10.0
            b, t = (spec.shape[0], spec.shape[2])
            m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device)
            df_coefs = torch.zeros((b, t, self.df_bins, self.df_order * 2), device=spec.device)
            spec_m = spec.clone()
            emb = emb[:, idcs]
            e0 = e0[:, :, idcs]
            e1 = e1[:, :, idcs]
            e2 = e2[:, :, idcs]
            e3 = e3[:, :, idcs]
            c0 = c0[:, :, idcs]

        if self.run_erb:
            if self.lsnr_dropout:
                m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
            else:
                m = self.erb_dec(emb, e3, e2, e1, e0)
            spec_m = self.mask(spec, m)
        else:
            m = torch.zeros((), device=spec.device)
            spec_m = torch.zeros_like(spec)

        if self.run_df:
            if self.lsnr_dropout:
                df_coefs[:, idcs] = self.df_dec(emb, c0)
            else:
                df_coefs = self.df_dec(emb, c0)
            df_coefs = self.df_out_transform(df_coefs)
            spec_e = self.df_op(spec.clone(), df_coefs)
            # Bins at and above df_bins are taken from the ERB-masked spectrum.
            spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
        else:
            df_coefs = torch.zeros((), device=spec.device)
            spec_e = spec_m

        if self.post_filter:
            beta = self.post_filter_beta
            eps = 1e-12
            # Effective gain of the enhancement, clamped to [eps, 1].
            mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1)
            mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps)
            pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2))
            spec_e = spec_e * pf.unsqueeze(-1)

        return spec_e, m, lsnr, df_coefs

    def forward(
            self,
            spec: torch.Tensor,
            feat_erb: torch.Tensor,
            feat_spec: torch.Tensor,
            erb_encoder_h: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass that always runs both the ERB and the DF stage.

        Args:
            spec (Tensor): Spectrum of shape [B, 1, T, F, 2]
            feat_erb (Tensor): ERB features of shape [B, 1, T, E]
            feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2]
            erb_encoder_h (Tensor, optional): initial hidden state of the encoder GRU

        Returns:
            (spec_e, m, lsnr, df_coefs, erb_encoder_h): enhanced spectrum, ERB mask,
            local SNR estimate, reshaped deep-filter coefficients and the final
            hidden state of the encoder GRU.
        """
        # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
        # feat_spec shape: [batch_size, 2, time_steps, freq_dim]

        # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
        # assert time_steps >= conv_lookahead.
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h)

        m = self.erb_dec(emb, e3, e2, e1, e0)
        spec_m = self.mask(spec, m)

        df_coefs = self.df_dec(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        spec_e = self.df_op(spec.clone(), df_coefs)
        # Bins at and above df_bins are taken from the ERB-masked spectrum.
        spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]

        return spec_e, m, lsnr, df_coefs, erb_encoder_h
# Script entry point: intentionally a no-op, this module only provides definitions.
if __name__ == "__main__":
    pass