#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The native DeepFilterNet implementation does not directly support streaming inference.
Community developers (e.g. Rikorose) provide a Torch-based streaming implementation:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
"""
import os
import math
from typing import Iterable, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
MODEL_FILE = "model.pt"
norm_layer_dict = {
"batch_norm_2d": torch.nn.BatchNorm2d
}
activation_layer_dict = {
"relu": torch.nn.ReLU,
"identity": torch.nn.Identity,
"sigmoid": torch.nn.Sigmoid,
}
class CausalConv2d(nn.Sequential):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
fpad: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
lookahead: int = 0
):
"""
Causal Conv2d by delaying the signal for any lookahead.
Expected input format: [batch_size, channels, time_steps, spec_dim]
:param in_channels:
:param out_channels:
:param kernel_size:
:param fstride:
:param dilation:
:param fpad:
"""
super(CausalConv2d, self).__init__()
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if fpad:
fpad_ = kernel_size[1] // 2 + dilation - 1
else:
fpad_ = 0
# for last 2 dim, pad (left, right, top, bottom).
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
layers = list()
if any(x > 0 for x in pad):
layers.append(nn.ConstantPad2d(pad, 0.0))
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
if max(kernel_size) == 1:
separable = False
layers.append(
nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(0, fpad_),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
)
if separable:
layers.append(
nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
)
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
layers.append(norm_layer(out_channels))
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
layers.append(activation_layer())
super().__init__(*layers)
def forward(self, inputs):
for module in self:
inputs = module(inputs)
return inputs
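# A minimal causality check (illustrative sketch, not part of the original model):
# with lookahead=0, zeroing future frames must not change earlier outputs.
# eval() matters here: in train mode BatchNorm2d couples frames through batch statistics.
def _demo_causal_conv2d():
    conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3)).eval()
    x = torch.randn(1, 1, 10, 32)  # [b, c, t, f]
    y1 = conv(x)
    x2 = x.clone()
    x2[:, :, 5:, :] = 0.0  # perturb frames t >= 5 only
    y2 = conv(x2)
    # outputs for frames t < 5 are unchanged: the convolution never looks ahead
    assert torch.allclose(y1[:, :, :5, :], y2[:, :, :5, :])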
class CausalConvTranspose2d(nn.Sequential):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
fpad: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
lookahead: int = 0
):
"""
Causal ConvTranspose2d.
Expected input format: [batch_size, channels, time_steps, spec_dim]
"""
super(CausalConvTranspose2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if fpad:
fpad_ = kernel_size[1] // 2
else:
fpad_ = 0
# for last 2 dim, pad (left, right, top, bottom).
pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)
layers = []
if any(x > 0 for x in pad):
layers.append(nn.ConstantPad2d(pad, 0.0))
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
layers.append(
nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
output_padding=(0, fpad_),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
)
if separable:
layers.append(
nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
)
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
layers.append(norm_layer(out_channels))
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
layers.append(activation_layer())
super().__init__(*layers)
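# Shape sketch (illustrative, not part of the original model): with fstride=2 the
# transposed convolution upsamples the frequency axis by 2x, while the padded,
# causal time axis keeps its length.
def _demo_causal_conv_transpose2d():
    convt = CausalConvTranspose2d(in_channels=4, out_channels=4, kernel_size=(2, 3), fstride=2)
    x = torch.randn(1, 4, 10, 8)  # [b, c, t, f]
    y = convt(x)
    assert y.shape == (1, 4, 10, 16)  # f: 8 -> 16, t preserved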
class GroupedLinear(nn.Module):
def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
super().__init__()
# self.weight: Tensor
self.input_size = input_size
self.hidden_size = hidden_size
self.groups = groups
assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
self.ws = input_size // groups
self.register_parameter(
"weight",
torch.nn.Parameter(
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., I]
b, t, f = x.shape
        if f != self.input_size:
            raise AssertionError(f"expected feature dim {self.input_size}, got {f}")
# new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
new_shape = (b, t, self.groups, self.ws)
x = x.view(new_shape)
# The better way, but not supported by torchscript
# x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
x = x.flatten(2, 3)
# x: [b, t, h]
return x
def __repr__(self):
cls = self.__class__.__name__
return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
class SqueezedGRU_S(nn.Module):
"""
SGE net: Video object detection with squeezed GRU and information entropy map
https://arxiv.org/abs/2106.07224
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: Optional[int] = None,
num_layers: int = 1,
linear_groups: int = 8,
batch_first: bool = True,
skip_op: str = "none",
activation_layer: str = "identity",
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.linear_in = nn.Sequential(
GroupedLinear(
input_size=input_size,
hidden_size=hidden_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
# gru skip operator
self.gru_skip_op = None
if skip_op == "none":
self.gru_skip_op = None
elif skip_op == "identity":
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
self.gru_skip_op = nn.Identity()
elif skip_op == "grouped_linear":
self.gru_skip_op = GroupedLinear(
input_size=hidden_size,
hidden_size=hidden_size,
groups=linear_groups,
)
else:
raise NotImplementedError()
self.gru = nn.GRU(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
bidirectional=False,
)
if output_size is not None:
self.linear_out = nn.Sequential(
GroupedLinear(
input_size=hidden_size,
hidden_size=output_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
else:
self.linear_out = nn.Identity()
def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
# inputs: shape: [b, t, h]
x = self.linear_in.forward(inputs)
x, h = self.gru.forward(x, h)
x = self.linear_out(x)
if self.gru_skip_op is not None:
x = x + self.gru_skip_op(inputs)
return x, h
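# Streaming sketch (illustrative): feeding the sequence frame by frame while
# carrying the hidden state reproduces the full-sequence output, which is the
# property that stateful/streaming inference relies on.
def _demo_squeezed_gru_streaming():
    gru = SqueezedGRU_S(input_size=64, hidden_size=32, output_size=64, linear_groups=8)
    x = torch.randn(1, 10, 64)  # [b, t, input_size]
    y_full, _ = gru.forward(x)
    h = None
    outs = []
    for t in range(x.shape[1]):
        y_t, h = gru.forward(x[:, t:t + 1, :], h)
        outs.append(y_t)
    assert torch.allclose(y_full, torch.cat(outs, dim=1), atol=1e-5)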
class Add(nn.Module):
def forward(self, a, b):
return a + b
class Concat(nn.Module):
def forward(self, a, b):
return torch.cat((a, b), dim=-1)
class Encoder(nn.Module):
def __init__(self, config: DfNetConfig):
super(Encoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.erb_bins // 4
self.embedding_output_size = config.conv_channels * config.erb_bins // 4
self.embedding_hidden_size = config.embedding_hidden_size
self.spec_conv0 = CausalConv2d(
in_channels=1,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.spec_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.spec_conv2 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.spec_conv3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.df_conv0 = CausalConv2d(
in_channels=2,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
)
self.df_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.df_fc_emb = nn.Sequential(
GroupedLinear(
config.conv_channels * config.df_bins // 2,
self.embedding_input_size,
groups=config.encoder_linear_groups
),
nn.ReLU(inplace=True)
)
if config.encoder_combine_op == "concat":
self.embedding_input_size *= 2
self.combine = Concat()
else:
self.combine = Add()
# emb_gru
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_gru = SqueezedGRU_S(
self.embedding_input_size,
self.embedding_hidden_size,
output_size=self.embedding_output_size,
num_layers=1,
batch_first=True,
skip_op=config.encoder_emb_skip_op,
linear_groups=config.encoder_emb_linear_groups,
activation_layer="relu",
)
# lsnr
self.lsnr_fc = nn.Sequential(
nn.Linear(self.embedding_output_size, 1),
nn.Sigmoid()
)
self.lsnr_scale = config.max_local_snr - config.min_local_snr
self.lsnr_offset = config.min_local_snr
def forward(self,
feat_erb: torch.Tensor,
feat_spec: torch.Tensor,
hidden_state: torch.Tensor = None,
):
# feat_erb shape: (b, 1, t, erb_bins)
e0 = self.spec_conv0.forward(feat_erb)
e1 = self.spec_conv1.forward(e0)
e2 = self.spec_conv2.forward(e1)
e3 = self.spec_conv3.forward(e2)
# e0 shape: [b, c, t, erb_bins]
# e1 shape: [b, c, t, erb_bins // 2]
# e2 shape: [b, c, t, erb_bins // 4]
# e3 shape: [b, c, t, erb_bins // 4]
# e3 shape: [b, 64, t, 32/4=8]
# feat_spec, shape: (b, 2, t, df_bins)
c0 = self.df_conv0(feat_spec)
c1 = self.df_conv1(c0)
# c0 shape: [b, c, t, df_bins]
# c1 shape: [b, c, t, df_bins // 2]
# c1 shape: [b, 64, t, 96/2=48]
cemb = c1.permute(0, 2, 3, 1)
# cemb shape: [b, t, df_bins // 2, c]
cemb = cemb.flatten(2)
# cemb shape: [b, t, df_bins // 2 * c]
# cemb shape: [b, t, 96/2*64=3072]
cemb = self.df_fc_emb.forward(cemb)
# cemb shape: [b, t, erb_bins // 4 * c]
# cemb shape: [b, t, 32/4*64=512]
# e3 shape: [b, c, t, erb_bins // 4]
emb = e3.permute(0, 2, 3, 1)
# emb shape: [b, t, erb_bins // 4, c]
emb = emb.flatten(2)
# emb shape: [b, t, erb_bins // 4 * c]
# emb shape: [b, t, 32/4*64=512]
emb = self.combine(emb, cemb)
# if concat; emb shape: [b, t, spec_bins // 4 * c * 2]
# if add; emb shape: [b, t, spec_bins // 4 * c]
emb, h = self.emb_gru.forward(emb, hidden_state)
# emb shape: [b, t, spec_dim // 4 * c]
        # h shape: [num_layers, b, hidden_size]
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
# lsnr shape: [b, t, 1]
return e0, e1, e2, e3, emb, c0, lsnr, h
class Decoder(nn.Module):
"""ErbDecoder"""
def __init__(self, config: DfNetConfig):
super(Decoder, self).__init__()
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_in_dim = config.conv_channels * config.erb_bins // 4
self.emb_out_dim = config.conv_channels * config.erb_bins // 4
self.emb_hidden_dim = config.decoder_emb_hidden_size
self.emb_gru = SqueezedGRU_S(
self.emb_in_dim,
self.emb_hidden_dim,
output_size=self.emb_out_dim,
num_layers=config.decoder_emb_num_layers - 1,
batch_first=True,
skip_op=config.decoder_emb_skip_op,
linear_groups=config.decoder_emb_linear_groups,
activation_layer="relu",
)
self.conv3p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.conv2p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt2 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.conv1p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.convt1 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
lookahead=config.conv_lookahead,
)
self.conv0p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
self.conv0_out = CausalConv2d(
in_channels=config.conv_channels,
out_channels=1,
kernel_size=config.conv_kernel_size_inner,
activation_layer="sigmoid",
bias=False,
separable=True,
fstride=1,
lookahead=config.conv_lookahead,
)
def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
# Estimates erb mask
b, _, t, f8 = e3.shape
# emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
emb, _ = self.emb_gru.forward(emb)
# emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
e3 = self.convt3(self.conv3p(e3) + emb)
# e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
e2 = self.convt2(self.conv2p(e2) + e3)
# e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
e1 = self.convt1(self.conv1p(e1) + e2)
# e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
mask = self.conv0_out(self.conv0p(e0) + e1)
# mask shape: [batch_size, 1, time_steps, freq_dim]
return mask
class DfDecoder(nn.Module):
def __init__(self, config: DfNetConfig):
super(DfDecoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.erb_bins // 4
self.df_decoder_hidden_size = config.df_decoder_hidden_size
self.df_num_layers = config.df_num_layers
self.df_order = config.df_order
self.df_bins = config.df_bins
self.df_out_ch = config.df_order * 2
self.df_convp = CausalConv2d(
config.conv_channels,
self.df_out_ch,
fstride=1,
kernel_size=(config.df_pathway_kernel_size_t, 1),
separable=True,
bias=False,
)
self.df_gru = SqueezedGRU_S(
self.embedding_input_size,
self.df_decoder_hidden_size,
num_layers=self.df_num_layers,
batch_first=True,
skip_op="none",
activation_layer="relu",
)
if config.df_gru_skip == "none":
self.df_skip = None
elif config.df_gru_skip == "identity":
if config.embedding_hidden_size != config.df_decoder_hidden_size:
raise AssertionError("Dimensions do not match")
self.df_skip = nn.Identity()
elif config.df_gru_skip == "grouped_linear":
self.df_skip = GroupedLinear(
self.embedding_input_size,
self.df_decoder_hidden_size,
groups=config.df_decoder_linear_groups
)
else:
raise NotImplementedError()
self.df_out: nn.Module
out_dim = self.df_bins * self.df_out_ch
self.df_out = nn.Sequential(
GroupedLinear(
input_size=self.df_decoder_hidden_size,
hidden_size=out_dim,
groups=config.df_decoder_linear_groups,
# groups = self.df_bins // 5,
),
nn.Tanh()
)
self.df_fc_a = nn.Sequential(
nn.Linear(self.df_decoder_hidden_size, 1),
nn.Sigmoid()
)
def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        # emb shape: [batch_size, time_steps, erb_bins // 4 * channels]
b, t, _ = emb.shape
df_coefs, _ = self.df_gru(emb)
if self.df_skip is not None:
df_coefs = df_coefs + self.df_skip(emb)
# df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
# c0 shape: [batch_size, channels, time_steps, df_bins]
c0 = self.df_convp(c0)
# c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
c0 = c0.permute(0, 2, 3, 1)
# c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
# df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = df_coefs + c0
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
return df_coefs
class DfOutputReshapeMF(nn.Module):
"""Coefficients output reshape for multiframe/MultiFrameModule
Requires input of shape B, C, T, F, 2.
"""
def __init__(self, df_order: int, df_bins: int):
super().__init__()
self.df_order = df_order
self.df_bins = df_bins
def forward(self, coefs: torch.Tensor) -> torch.Tensor:
# [B, T, F, O*2] -> [B, O, T, F, 2]
new_shape = list(coefs.shape)
new_shape[-1] = -1
new_shape.append(2)
coefs = coefs.view(new_shape)
coefs = coefs.permute(0, 3, 1, 2, 4)
return coefs
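# Shape sketch (illustrative): the reshape turns flat coefficients
# [B, T, F, O*2] into the [B, O, T, F, 2] layout expected by DeepFiltering.
def _demo_df_output_reshape():
    reshape = DfOutputReshapeMF(df_order=5, df_bins=96)
    coefs = torch.randn(1, 10, 96, 5 * 2)  # [b, t, df_bins, df_order * 2]
    assert reshape.forward(coefs).shape == (1, 5, 10, 96, 2)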
class Mask(nn.Module):
def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
super().__init__()
self.use_post_filter = use_post_filter
self.eps = eps
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
"""
Post-Filter
A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
https://arxiv.org/abs/2008.04259
:param mask: Real valued mask, typically of shape [B, C, T, F].
:param beta: Global gain factor.
:return:
"""
mask_sin = mask * torch.sin(np.pi * mask / 2)
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
return mask_pf
def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# spec shape: [b, 1, t, spec_bins, 2]
if not self.training and self.use_post_filter:
mask = self.post_filter(mask)
# mask shape: [b, 1, t, spec_bins]
mask = mask.unsqueeze(4)
# mask shape: [b, 1, t, spec_bins, 1]
return spec * mask
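# Behavior sketch (illustrative): the post-filter leaves gains near 1 almost
# untouched but pushes small gains further toward 0, slightly over-attenuating noise.
def _demo_post_filter():
    mask_module = Mask(use_post_filter=True)
    mask = torch.tensor([0.1, 0.5, 1.0])
    mask_pf = mask_module.post_filter(mask)
    assert torch.all(mask_pf[:2] < mask[:2])  # small gains are attenuated further
    assert torch.allclose(mask_pf[2], mask[2])  # a gain of 1.0 passes through unchanged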
class DeepFiltering(nn.Module):
def __init__(self,
df_bins: int,
df_order: int,
lookahead: int = 0,
):
super(DeepFiltering, self).__init__()
self.df_bins = df_bins
self.df_order = df_order
self.need_unfold = df_order > 1
self.lookahead = lookahead
self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
def spec_unfold(self, spec: torch.Tensor):
"""
        Pads and unfolds the spectrogram along the time axis according to df_order.
        :param spec: complex Tensor, spectrogram of shape [B, C, T, F].
        :return: Tensor, unfolded spectrogram of shape [B, C, T, F, N], where N = df_order.
        """
        if self.need_unfold:
            # spec shape: [batch_size, 1, time_steps, spec_bins]
            spec_pad = self.pad(spec)
            # spec_pad shape: [batch_size, 1, time_steps + df_order - 1, spec_bins]
            spec_unfold = spec_pad.unfold(2, self.df_order, 1)
            # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order]
return spec_unfold
else:
return spec.unsqueeze(-1)
def forward(self,
spec: torch.Tensor,
coefs: torch.Tensor,
):
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
# spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order]
# coefs shape: [batch_size, df_order, time_steps, df_bins, 2]
coefs = torch.view_as_complex(coefs.contiguous())
# coefs shape: [batch_size, df_order, time_steps, df_bins]
spec_f = spec_u.narrow(-2, 0, self.df_bins)
# spec_f shape: [batch_size, 1, time_steps, df_bins, df_order]
coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
# coefs shape: [batch_size, 1, df_order, time_steps, df_bins]
spec_f = self.df(spec_f, coefs)
# spec_f shape: [batch_size, 1, time_steps, df_bins]
if self.training:
spec = spec.clone()
spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
# spec shape: [batch_size, 1, time_steps, spec_bins, 2]
return spec
@staticmethod
def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
"""
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
:param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
:param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
:return: (complex Tensor). Spectrogram of shape [B, C, T, F].
"""
return torch.einsum("...tfn,...ntf->...tf", spec, coefs)
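# Sanity sketch (illustrative): with df_order=1, deep filtering degenerates to a
# per-bin complex multiplication, i.e. a plain complex ratio mask applied to the
# first df_bins frequency bins.
def _demo_deep_filtering_order1():
    df_op = DeepFiltering(df_bins=96, df_order=1)
    spec = torch.randn(1, 1, 10, 257, 2)   # [b, 1, t, f, 2]
    coefs = torch.randn(1, 1, 10, 96, 2)   # [b, df_order=1, t, df_bins, 2]
    out = df_op.forward(spec, coefs)
    ref = torch.view_as_complex(spec[..., :96, :].contiguous()) \
        * torch.view_as_complex(coefs.contiguous())
    out_c = torch.view_as_complex(out[..., :96, :].contiguous())
    assert torch.allclose(out_c, ref, atol=1e-6)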
class DfNet(nn.Module):
"""
    I suspect this model cannot achieve fully bit-exact streaming inference.
"""
def __init__(self, config: DfNetConfig):
super(DfNet, self).__init__()
self.config = config
self.eps = 1e-12
self.freq_bins = self.config.nfft // 2 + 1
self.nfft = config.nfft
self.win_size = config.win_size
self.hop_size = config.hop_size
self.win_type = config.win_type
self.erb_bands = ErbBands(
sample_rate=config.sample_rate,
nfft=config.nfft,
erb_bins=config.erb_bins,
min_freq_bins_for_erb=config.min_freq_bins_for_erb,
)
self.stft = ConvSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
power=None,
requires_grad=False
)
self.istft = ConviSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
requires_grad=False
)
self.encoder = Encoder(config)
self.decoder = Decoder(config)
self.df_decoder = DfDecoder(config)
self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
self.df_op = DeepFiltering(
df_bins=config.df_bins,
df_order=config.df_order,
lookahead=config.df_lookahead,
)
self.mask = Mask(use_post_filter=config.use_post_filter)
self.lsnr_fn = LocalSnrTarget(
sample_rate=config.sample_rate,
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
n_frame=config.n_frame,
min_local_snr=config.min_local_snr,
max_local_snr=config.max_local_snr,
db=True,
)
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
if signal.dim() == 2:
signal = torch.unsqueeze(signal, dim=1)
_, _, n_samples = signal.shape
remainder = (n_samples - self.win_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
return signal
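    # Worked example of the padding rule above (the numbers are illustrative, not
    # the actual config defaults): with win_size=400 and hop_size=100, a
    # 16050-sample signal has remainder (16050 - 400) % 100 = 50, so 50 zeros are
    # appended and (16100 - 400) % 100 == 0, i.e. the last window is fully covered.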
def feature_prepare(self, signal: torch.Tensor):
# noisy shape: [b, num_samples_pad]
spec_cmp = self.stft.forward(signal)
# spec_complex shape: [b, f, t], torch.complex64
spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
# spec_complex shape: [b, t, f], torch.complex64
spec_cmp_real = torch.view_as_real(spec_cmp)
# spec_cmp_real shape: [b, t, f, 2]
spec_mag = torch.abs(spec_cmp)
spec_pow = torch.square(spec_mag)
# shape: [b, t, f]
spec = torch.unsqueeze(spec_cmp_real, dim=1)
# spec shape: [b, 1, t, f, 2]
feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
# feat_erb shape: [b, t, erb_bins]
feat_erb = torch.unsqueeze(feat_erb, dim=1)
# feat_erb shape: [b, 1, t, erb_bins]
feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
# feat_spec shape: [b, 2, t, f]
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
# feat_spec shape: [b, 2, t, df_bins]
return spec, feat_erb, feat_spec
def forward(self,
noisy: torch.Tensor,
):
"""
:param noisy:
:return:
        est_spec: shape: [b, 257, t], torch.complex64
        est_wav: shape: [b, 1, num_samples]
est_mask: shape: [b, 257, t]
lsnr: shape: [b, 1, t]
"""
n_samples = noisy.shape[-1]
noisy = self.signal_prepare(noisy)
spec, feat_erb, feat_spec = self.feature_prepare(noisy)
e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_erb, feat_spec)
mask = self.decoder.forward(emb, e3, e2, e1, e0)
# mask shape: [b, 1, t, erb_bins]
mask = self.erb_bands.erb_scale_inv(mask)
# mask shape: [b, 1, t, f]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError("mask values must lie in [0, 1]")
spec_m = self.mask.forward(spec, mask)
# spec_m shape: [b, 1, t, f, 2]
spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
# spec_m shape: [b, 1, t, spec_bins, 2]
# lsnr shape: [b, t, 1]
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
# lsnr shape: [b, 1, t]
df_coefs = self.df_decoder.forward(emb, c0)
df_coefs = self.df_out_transform(df_coefs)
# df_coefs shape: [b, df_order, t, df_bins, 2]
spec_ = spec[:, :, :, :self.config.spec_bins, :]
# spec shape: [b, 1, t, spec_bins, 2]
spec_e = self.df_op.forward(spec_, df_coefs)
# spec_e shape: [b, 1, t, spec_bins, 2]
spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]
spec_e = torch.squeeze(spec_e, dim=1)
spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [b, spec_bins, t, 2]
est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1])
# est_spec shape: [b, spec_bins, t], torch.complex64
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
# est_spec shape: [b, f, t], torch.complex64
est_wav = self.istft.forward(est_spec)
est_wav = est_wav[:, :, :n_samples]
# est_wav shape: [b, 1, n_samples]
est_mask = torch.squeeze(mask, dim=1)
est_mask = est_mask.permute(0, 2, 1)
# est_mask shape: [b, f, t]
return est_spec, est_wav, est_mask, lsnr
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
"""
:param est_mask: torch.Tensor, shape: [b, 257, t]
:param clean:
:param noisy:
:return:
"""
if noisy.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
noise = noisy - clean
clean = self.signal_prepare(clean)
noise = self.signal_prepare(noise)
stft_clean = self.stft.forward(clean)
mag_clean = torch.abs(stft_clean)
stft_noise = self.stft.forward(noise)
mag_noise = torch.abs(stft_noise)
gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
return loss
def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
if noisy.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
noise = noisy - clean
clean = self.signal_prepare(clean)
noise = self.signal_prepare(noise)
stft_clean = self.stft.forward(clean)
stft_noise = self.stft.forward(noise)
# shape: [b, f, t]
stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
# shape: [b, t, f]
stft_clean = torch.unsqueeze(stft_clean, dim=1)
stft_noise = torch.unsqueeze(stft_noise, dim=1)
# shape: [b, 1, t, f]
# lsnr shape: [b, 1, t]
lsnr = lsnr.squeeze(1)
# lsnr shape: [b, t]
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
# lsnr_gth shape: [b, t]
loss = F.mse_loss(lsnr, lsnr_gth)
return loss
class DfNetPretrainedModel(DfNet):
def __init__(self,
config: DfNetConfig,
):
super(DfNetPretrainedModel, self).__init__(
config=config,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model
def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory
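# Round-trip usage sketch (the directory path is hypothetical, and it assumes
# DfNetConfig.from_pretrained can read back the yaml written by save_pretrained):
# saving writes model.pt plus the config, and loading restores both.
def _demo_save_and_load(save_dir: str = "/tmp/dfnet_demo"):
    model = DfNetPretrainedModel(config=DfNetConfig())
    model.save_pretrained(save_dir)
    reloaded = DfNetPretrainedModel.from_pretrained(save_dir)
    for p1, p2 in zip(model.parameters(), reloaded.parameters()):
        assert torch.equal(p1, p2)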
def main():
config = DfNetConfig()
model = DfNetPretrainedModel(config=config)
noisy = torch.randn(size=(1, 16000), dtype=torch.float32)
est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
print(f"est_spec.shape: {est_spec.shape}")
print(f"est_wav.shape: {est_wav.shape}")
print(f"est_mask.shape: {est_mask.shape}")
print(f"lsnr.shape: {lsnr.shape}")
return
if __name__ == "__main__":
main()