#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
import torchaudio
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT
MODEL_FILE = "model.pt"
norm_layer_dict = {
"batch_norm_2d": torch.nn.BatchNorm2d
}
activation_layer_dict = {
"relu": torch.nn.ReLU,
"identity": torch.nn.Identity,
"sigmoid": torch.nn.Sigmoid,
}
class CausalConv2d(nn.Sequential):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
fpad: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
lookahead: int = 0
):
"""
Causal Conv2d by delaying the signal for any lookahead.
Expected input format: [batch_size, channels, time_steps, spec_dim]
:param in_channels:
:param out_channels:
:param kernel_size:
:param fstride:
:param dilation:
:param fpad:
"""
super(CausalConv2d, self).__init__()
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if fpad:
fpad_ = kernel_size[1] // 2 + dilation - 1
else:
fpad_ = 0
# for last 2 dim, pad (left, right, top, bottom).
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
layers = list()
if any(x > 0 for x in pad):
layers.append(nn.ConstantPad2d(pad, 0.0))
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
if max(kernel_size) == 1:
separable = False
layers.append(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(0, fpad_),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
)
if separable:
layers.append(
nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
)
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
layers.append(norm_layer(out_channels))
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
layers.append(activation_layer())
super().__init__(*layers)
def forward(self, inputs):
for module in self:
inputs = module(inputs)
return inputs
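
# Usage sketch for CausalConv2d (illustrative only): the channel counts, kernel size
# and input shape below are arbitrary assumptions, not values taken from DfNetConfig.
# With fstride=2 the frequency axis is halved, while the causal (left-only) padding
# keeps the time axis unchanged.
def _causal_conv2d_shape_sketch():
    conv = CausalConv2d(in_channels=1, out_channels=16, kernel_size=3, fstride=2)
    x = torch.randn(2, 1, 100, 96)  # [batch_size, channels, time_steps, spec_dim]
    y = conv(x)
    print(y.shape)  # torch.Size([2, 16, 100, 48])
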
class CausalConvTranspose2d(nn.Sequential):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
fpad: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
lookahead: int = 0
):
"""
Causal ConvTranspose2d.
Expected input format: [batch_size, channels, time_steps, spec_dim]
"""
super(CausalConvTranspose2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if fpad:
fpad_ = kernel_size[1] // 2
else:
fpad_ = 0
# for last 2 dim, pad (left, right, top, bottom).
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
layers = []
if any(x > 0 for x in pad):
layers.append(nn.ConstantPad2d(pad, 0.0))
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
layers.append(
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
output_padding=(0, fpad_),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
)
if separable:
layers.append(
nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
)
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
layers.append(norm_layer(out_channels))
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
layers.append(activation_layer())
super().__init__(*layers)
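
# Usage sketch for CausalConvTranspose2d (illustrative only): sizes below are arbitrary
# assumptions, not DfNetConfig values. With fstride=2 the transposed convolution doubles
# the frequency axis while the time axis stays causal and unchanged.
def _causal_conv_transpose2d_shape_sketch():
    convt = CausalConvTranspose2d(in_channels=16, out_channels=16, kernel_size=3, fstride=2)
    x = torch.randn(2, 16, 100, 48)  # [batch_size, channels, time_steps, spec_dim]
    y = convt(x)
    print(y.shape)  # torch.Size([2, 16, 100, 96])
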
class GroupedLinear(nn.Module):
def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
super().__init__()
# self.weight: Tensor
self.input_size = input_size
self.hidden_size = hidden_size
self.groups = groups
assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
self.ws = input_size // groups
self.register_parameter(
"weight",
torch.nn.Parameter(
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., I]
b, t, _ = x.shape
# new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
new_shape = (b, t, self.groups, self.ws)
x = x.view(new_shape)
# The better way, but not supported by torchscript
# x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
x = x.flatten(2, 3) # [B, T, H]
return x
def __repr__(self):
cls = self.__class__.__name__
return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
class SqueezedGRU_S(nn.Module):
"""
SGE net: Video object detection with squeezed GRU and information entropy map
https://arxiv.org/abs/2106.07224
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: Optional[int] = None,
num_layers: int = 1,
linear_groups: int = 8,
batch_first: bool = True,
skip_op: str = "none",
activation_layer: str = "identity",
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.linear_in = nn.Sequential(
GroupedLinear(
input_size=input_size,
hidden_size=hidden_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
# gru skip operator
self.gru_skip_op = None
if skip_op == "none":
self.gru_skip_op = None
elif skip_op == "identity":
            if input_size != output_size:
raise AssertionError("Dimensions do not match")
self.gru_skip_op = nn.Identity()
elif skip_op == "grouped_linear":
self.gru_skip_op = GroupedLinear(
input_size=hidden_size,
hidden_size=hidden_size,
groups=linear_groups,
)
else:
raise NotImplementedError()
self.gru = nn.GRU(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
bidirectional=False,
)
if output_size is not None:
self.linear_out = nn.Sequential(
GroupedLinear(
input_size=hidden_size,
hidden_size=output_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
else:
self.linear_out = nn.Identity()
def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
x = self.linear_in(inputs)
x, h = self.gru.forward(x, h)
x = self.linear_out(x)
if self.gru_skip_op is not None:
x = x + self.gru_skip_op(inputs)
return x, h
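
# Usage sketch for SqueezedGRU_S (illustrative only): sizes below are arbitrary
# assumptions. GroupedLinear layers squeeze the features before and after a plain
# GRU, which keeps the recurrent weight matrices small.
def _squeezed_gru_shape_sketch():
    gru = SqueezedGRU_S(input_size=64, hidden_size=32, output_size=64, linear_groups=8)
    x = torch.randn(2, 10, 64)  # [batch_size, time_steps, input_size]
    y, h = gru(x)
    print(y.shape)  # torch.Size([2, 10, 64])
    print(h.shape)  # torch.Size([1, 2, 32]); [num_layers, batch_size, hidden_size]
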
class Add(nn.Module):
def forward(self, a, b):
return a + b
class Concat(nn.Module):
def forward(self, a, b):
return torch.cat((a, b), dim=-1)
class Encoder(nn.Module):
def __init__(self, config: DfNetConfig):
super(Encoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.spec_bins // 4
self.embedding_output_size = config.conv_channels * config.spec_bins // 4
self.embedding_hidden_size = config.embedding_hidden_size
self.spec_conv0 = CausalConv2d(
in_channels=1,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.spec_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.spec_conv2 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.spec_conv3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.df_conv0 = CausalConv2d(
in_channels=2,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
)
self.df_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.df_fc_emb = nn.Sequential(
GroupedLinear(
config.conv_channels * config.df_bins // 2,
self.embedding_input_size,
groups=config.encoder_linear_groups
),
nn.ReLU(inplace=True)
)
if config.encoder_combine_op == "concat":
self.embedding_input_size *= 2
self.combine = Concat()
else:
self.combine = Add()
# emb_gru
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_gru = SqueezedGRU_S(
self.embedding_input_size,
self.embedding_hidden_size,
output_size=self.embedding_output_size,
num_layers=1,
batch_first=True,
skip_op=config.encoder_emb_skip_op,
linear_groups=config.encoder_emb_linear_groups,
activation_layer="relu",
)
# lsnr
self.lsnr_fc = nn.Sequential(
nn.Linear(self.embedding_output_size, 1),
nn.Sigmoid()
)
self.lsnr_scale = config.lsnr_max - config.lsnr_min
self.lsnr_offset = config.lsnr_min
def forward(self,
feat_power: torch.Tensor,
feat_spec: torch.Tensor,
hidden_state: torch.Tensor = None,
):
# feat_power shape: (batch_size, 1, time_steps, spec_dim)
e0 = self.spec_conv0.forward(feat_power)
e1 = self.spec_conv1.forward(e0)
e2 = self.spec_conv2.forward(e1)
e3 = self.spec_conv3.forward(e2)
# e0 shape: [batch_size, channels, time_steps, spec_dim]
# e1 shape: [batch_size, channels, time_steps, spec_dim // 2]
# e2 shape: [batch_size, channels, time_steps, spec_dim // 4]
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
# feat_spec, shape: (batch_size, 2, time_steps, df_bins)
c0 = self.df_conv0(feat_spec)
c1 = self.df_conv1(c0)
# c0 shape: [batch_size, channels, time_steps, df_bins]
# c1 shape: [batch_size, channels, time_steps, df_bins // 2]
cemb = c1.permute(0, 2, 3, 1)
# cemb shape: [batch_size, time_steps, df_bins // 2, channels]
cemb = cemb.flatten(2)
# cemb shape: [batch_size, time_steps, df_bins // 2 * channels]
cemb = self.df_fc_emb(cemb)
# cemb shape: [batch_size, time_steps, spec_dim // 4 * channels]
# e3 shape: [batch_size, channels, time_steps, spec_dim // 4]
emb = e3.permute(0, 2, 3, 1)
# emb shape: [batch_size, time_steps, spec_dim // 4, channels]
emb = emb.flatten(2)
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
emb = self.combine(emb, cemb)
# if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2]
# if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
emb, h = self.emb_gru.forward(emb, hidden_state)
# emb shape: [batch_size, time_steps, spec_dim // 4 * channels]
        # h shape: [num_layers, batch_size, embedding_hidden_size]
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
# lsnr shape: [batch_size, time_steps, 1]
return e0, e1, e2, e3, emb, c0, lsnr, h
class Decoder(nn.Module):
def __init__(self, config: DfNetConfig):
super(Decoder, self).__init__()
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_in_dim = config.conv_channels * config.spec_bins // 4
self.emb_out_dim = config.conv_channels * config.spec_bins // 4
self.emb_hidden_dim = config.decoder_emb_hidden_size
self.emb_gru = SqueezedGRU_S(
self.emb_in_dim,
self.emb_hidden_dim,
output_size=self.emb_out_dim,
num_layers=config.decoder_emb_num_layers - 1,
batch_first=True,
skip_op=config.decoder_emb_skip_op,
linear_groups=config.decoder_emb_linear_groups,
activation_layer="relu",
)
self.conv3p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.conv2p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt2 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.conv1p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt1 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.conv0p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.conv0_out = CausalConv2d(
in_channels=config.conv_channels,
out_channels=1,
kernel_size=config.conv_kernel_size_inner,
activation_layer="sigmoid",
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
# Estimates erb mask
b, _, t, f8 = e3.shape
# emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
emb, _ = self.emb_gru(emb)
# emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
e3 = self.convt3(self.conv3p(e3) + emb)
# e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
e2 = self.convt2(self.conv2p(e2) + e3)
# e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
e1 = self.convt1(self.conv1p(e1) + e2)
# e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
mask = self.conv0_out(self.conv0p(e0) + e1)
# mask shape: [batch_size, 1, time_steps, freq_dim]
return mask
class DfDecoder(nn.Module):
def __init__(self, config: DfNetConfig):
super(DfDecoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.spec_bins // 4
self.df_decoder_hidden_size = config.df_decoder_hidden_size
self.df_num_layers = config.df_num_layers
self.df_order = config.df_order
self.df_bins = config.df_bins
self.df_out_ch = config.df_order * 2
self.df_convp = CausalConv2d(
config.conv_channels,
self.df_out_ch,
fstride=1,
kernel_size=(config.df_pathway_kernel_size_t, 1),
separable=True,
bias=False,
)
self.df_gru = SqueezedGRU_S(
self.embedding_input_size,
self.df_decoder_hidden_size,
num_layers=self.df_num_layers,
batch_first=True,
skip_op="none",
activation_layer="relu",
)
if config.df_gru_skip == "none":
self.df_skip = None
elif config.df_gru_skip == "identity":
if config.embedding_hidden_size != config.df_decoder_hidden_size:
raise AssertionError("Dimensions do not match")
self.df_skip = nn.Identity()
elif config.df_gru_skip == "grouped_linear":
self.df_skip = GroupedLinear(
self.embedding_input_size,
self.df_decoder_hidden_size,
groups=config.df_decoder_linear_groups
)
else:
raise NotImplementedError()
self.df_out: nn.Module
out_dim = self.df_bins * self.df_out_ch
self.df_out = nn.Sequential(
GroupedLinear(
input_size=self.df_decoder_hidden_size,
hidden_size=out_dim,
groups=config.df_decoder_linear_groups
),
nn.Tanh()
)
self.df_fc_a = nn.Sequential(
nn.Linear(self.df_decoder_hidden_size, 1),
nn.Sigmoid()
)
def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        # emb shape: [batch_size, time_steps, spec_bins // 4 * channels]
b, t, _ = emb.shape
df_coefs, _ = self.df_gru(emb)
if self.df_skip is not None:
df_coefs = df_coefs + self.df_skip(emb)
# df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
# c0 shape: [batch_size, channels, time_steps, df_bins]
c0 = self.df_convp(c0)
# c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
c0 = c0.permute(0, 2, 3, 1)
# c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
# df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = df_coefs + c0
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
return df_coefs
class DfOutputReshapeMF(nn.Module):
"""Coefficients output reshape for multiframe/MultiFrameModule
Requires input of shape B, C, T, F, 2.
"""
def __init__(self, df_order: int, df_bins: int):
super().__init__()
self.df_order = df_order
self.df_bins = df_bins
def forward(self, coefs: torch.Tensor) -> torch.Tensor:
# [B, T, F, O*2] -> [B, O, T, F, 2]
new_shape = list(coefs.shape)
new_shape[-1] = -1
new_shape.append(2)
coefs = coefs.view(new_shape)
coefs = coefs.permute(0, 3, 1, 2, 4)
return coefs
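
# Usage sketch for DfOutputReshapeMF (illustrative only): df_order and df_bins below
# are arbitrary assumptions, not DfNetConfig defaults.
def _df_output_reshape_sketch():
    reshape = DfOutputReshapeMF(df_order=5, df_bins=96)
    coefs = torch.randn(2, 10, 96, 5 * 2)  # [B, T, F, O*2]
    out = reshape(coefs)
    print(out.shape)  # torch.Size([2, 5, 10, 96, 2]); [B, O, T, F, 2]
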
class Mask(nn.Module):
def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
super().__init__()
self.use_post_filter = use_post_filter
self.eps = eps
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
"""
Post-Filter
A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
https://arxiv.org/abs/2008.04259
:param mask: Real valued mask, typically of shape [B, C, T, F].
:param beta: Global gain factor.
:return:
"""
mask_sin = mask * torch.sin(np.pi * mask / 2)
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
return mask_pf
def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
if not self.training and self.use_post_filter:
mask = self.post_filter(mask)
# mask shape: [batch_size, 1, time_steps, spec_bins]
mask = mask.unsqueeze(4)
# mask shape: [batch_size, 1, time_steps, spec_bins, 1]
return spec * mask
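
# Usage sketch for Mask (illustrative only): shapes below are arbitrary assumptions.
# The real-valued mask is broadcast over the last (real/imag) axis of the spectrogram;
# the post-filter is only applied in eval mode.
def _mask_sketch():
    mask_module = Mask(use_post_filter=True).eval()
    spec = torch.randn(2, 1, 10, 256, 2)  # [batch_size, 1, time_steps, spec_bins, 2]
    mask = torch.rand(2, 1, 10, 256)      # [batch_size, 1, time_steps, spec_bins]
    out = mask_module(spec, mask)
    print(out.shape)  # torch.Size([2, 1, 10, 256, 2])
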
class DeepFiltering(nn.Module):
def __init__(self,
df_bins: int,
df_order: int,
lookahead: int = 0,
):
super(DeepFiltering, self).__init__()
self.df_bins = df_bins
self.df_order = df_order
self.need_unfold = df_order > 1
self.lookahead = lookahead
self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
def spec_unfold(self, spec: torch.Tensor):
"""
Pads and unfolds the spectrogram according to frame_size.
:param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
:return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size.
"""
if self.need_unfold:
            # spec shape: [batch_size, 1, time_steps, spec_bins]
spec_pad = self.pad(spec)
# spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins]
spec_unfold = spec_pad.unfold(2, self.df_order, 1)
# spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order]
return spec_unfold
else:
return spec.unsqueeze(-1)
def forward(self,
spec: torch.Tensor,
coefs: torch.Tensor,
):
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
# spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]
# coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
coefs = torch.view_as_complex(coefs.contiguous())
# coefs shape: [batch_size, df_order, time_steps, df_bins]
spec_f = spec_u.narrow(-2, 0, self.df_bins)
# spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
# coefs shape: [batch_size, 1, df_order, time_steps, df_bins]
spec_f = self.df(spec_f, coefs)
# spec_f shape: [batch_size, 1, time_steps, df_bins]
if self.training:
spec = spec.clone()
spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
return spec
@staticmethod
def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
"""
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
:param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
:param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
:return: (complex Tensor). Spectrogram of shape [B, C, T, F].
"""
return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
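
# Usage sketch for DeepFiltering (illustrative only): df_bins, df_order and the
# spectrogram size below are arbitrary assumptions. The lowest df_bins frequency
# bins are filtered with df_order complex taps per bin; the remaining bins pass
# through unchanged.
def _deep_filtering_sketch():
    df_op = DeepFiltering(df_bins=96, df_order=5)
    spec = torch.randn(2, 1, 10, 257, 2)   # [batch_size, 1, time_steps, spec_bins, 2]
    coefs = torch.randn(2, 5, 10, 96, 2)   # [batch_size, df_order, time_steps, df_bins, 2]
    out = df_op(spec, coefs)
    print(out.shape)  # torch.Size([2, 1, 10, 257, 2])
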
class DfNet(nn.Module):
def __init__(self, config: DfNetConfig):
super(DfNet, self).__init__()
self.config = config
self.freq_bins = self.config.nfft // 2 + 1
self.nfft = config.nfft
self.win_size = config.win_size
self.hop_size = config.hop_size
self.win_type = config.win_type
self.stft = ConvSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
feature_type="complex",
requires_grad=False
)
self.istft = ConviSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
feature_type="complex",
requires_grad=False
)
self.encoder = Encoder(config)
self.decoder = Decoder(config)
self.df_decoder = DfDecoder(config)
self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
self.df_op = DeepFiltering(
df_bins=config.df_bins,
df_order=config.df_order,
lookahead=config.df_lookahead,
)
        self.mask = Mask(use_post_filter=config.use_post_filter)
        # epsilon used in mask_loss_fn to avoid division by zero (its comments assume 1e-8)
        self.eps = 1.0e-8
def forward(self,
noisy: torch.Tensor,
):
if noisy.dim() == 2:
noisy = torch.unsqueeze(noisy, dim=1)
_, _, n_samples = noisy.shape
remainder = (n_samples - self.win_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0)
# [batch_size, freq_bins * 2, time_steps]
cmp_spec = self.stft.forward(noisy)
# [batch_size, 1, freq_bins * 2, time_steps]
cmp_spec = torch.unsqueeze(cmp_spec, 1)
# [batch_size, 2, freq_bins, time_steps]
cmp_spec = torch.cat([
cmp_spec[:, :, :self.freq_bins, :],
cmp_spec[:, :, self.freq_bins:, :],
], dim=1)
# n//2+1 -> n//2; 257 -> 256
cmp_spec = cmp_spec[:, :, :-1, :]
spec = torch.unsqueeze(cmp_spec, dim=4)
# [batch_size, 2, freq_bins, time_steps, 1]
spec = spec.permute(0, 4, 3, 2, 1)
# spec shape: [batch_size, 1, time_steps, freq_bins, 2]
feat_power = torch.sum(torch.square(spec), dim=-1)
# feat_power shape: [batch_size, 1, time_steps, spec_bins]
feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3)
# feat_spec shape: [batch_size, 2, time_steps, freq_bins]
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
# feat_spec shape: [batch_size, 2, time_steps, df_bins]
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec)
mask = self.decoder.forward(emb, e3, e2, e1, e0)
# mask shape: [batch_size, 1, time_steps, spec_bins]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError("mask values must lie in [0, 1]")
spec_m = self.mask.forward(spec, mask)
# lsnr shape: [batch_size, time_steps, 1]
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
# lsnr shape: [batch_size, 1, time_steps]
df_coefs = self.df_decoder.forward(emb, c0)
df_coefs = self.df_out_transform(df_coefs)
# df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
spec_e = self.df_op.forward(spec.clone(), df_coefs)
        # spec_e shape: [batch_size, 1, time_steps, spec_bins, 2]
spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
spec_e = torch.squeeze(spec_e, dim=1)
spec_e = spec_e.permute(0, 2, 1, 3)
# spec_e shape: [batch_size, spec_bins, time_steps, 2]
mask = torch.squeeze(mask, dim=1)
est_mask = mask.permute(0, 2, 1)
# mask shape: [batch_size, spec_bins, time_steps]
b, _, t, _ = spec_e.shape
est_spec = torch.cat(tensors=[
torch.concat(tensors=[
spec_e[..., 0],
torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
], dim=1),
torch.concat(tensors=[
spec_e[..., 1],
torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device)
], dim=1),
], dim=1)
# est_spec shape: [b, n+2, t]
est_wav = self.istft.forward(est_spec)
est_wav = torch.squeeze(est_wav, dim=1)
est_wav = est_wav[:, :n_samples]
# est_wav shape: [b, n_samples]
return est_spec, est_wav, est_mask, lsnr
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
"""
:param est_mask: torch.Tensor, shape: [b, n+2, t]
:param clean:
:param noisy:
:return:
"""
clean_stft = self.stft(clean)
clean_re = clean_stft[:, :self.freq_bins, :]
clean_im = clean_stft[:, self.freq_bins:, :]
noisy_stft = self.stft(noisy)
noisy_re = noisy_stft[:, :self.freq_bins, :]
noisy_im = noisy_stft[:, self.freq_bins:, :]
noisy_power = noisy_re ** 2 + noisy_im ** 2
sr = clean_re
yr = noisy_re
si = clean_im
yi = noisy_im
y_pow = noisy_power
# (Sr * Yr + Si * Yi) / (Y_pow + 1e-8)
gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps)
# (Si * Yr - Sr * Yi) / (Y_pow + 1e-8)
        gth_mask_im = (si * yr - sr * yi) / (y_pow + self.eps)
gth_mask_re[gth_mask_re > 2] = 1
gth_mask_re[gth_mask_re < -2] = -1
gth_mask_im[gth_mask_im > 2] = 1
gth_mask_im[gth_mask_im < -2] = -1
mask_re = est_mask[:, :self.freq_bins, :]
mask_im = est_mask[:, self.freq_bins:, :]
loss_re = F.mse_loss(gth_mask_re, mask_re)
loss_im = F.mse_loss(gth_mask_im, mask_im)
loss = loss_re + loss_im
return loss
class DfNetPretrainedModel(DfNet):
def __init__(self,
config: DfNetConfig,
):
super(DfNetPretrainedModel, self).__init__(
config=config,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model
def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory
def main():
config = DfNetConfig()
model = DfNetPretrainedModel(config=config)
noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
output = model.forward(noisy)
print(output[1].shape)
return
if __name__ == "__main__":
main()