#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The native implementation of DeepFilterNet does not directly support streaming inference.
Community developers (e.g. Rikorose) have provided a Torch-based streaming inference implementation:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF

This file attempts to implement a DfNet that supports streaming inference.
"""
import os | |
import math | |
from collections import defaultdict | |
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from toolbox.torchaudio.configuration_utils import CONFIG_FILE | |
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config | |
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT | |
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget | |
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands | |
from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA | |
MODEL_FILE = "model.pt" | |
norm_layer_dict = { | |
"batch_norm_2d": torch.nn.BatchNorm2d | |
} | |
activation_layer_dict = { | |
"relu": torch.nn.ReLU, | |
"identity": torch.nn.Identity, | |
"sigmoid": torch.nn.Sigmoid, | |
} | |
class CausalConv2d(nn.Module): | |
def __init__(self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: Union[int, Iterable[int]], | |
fstride: int = 1, | |
dilation: int = 1, | |
pad_f_dim: bool = True, | |
bias: bool = True, | |
separable: bool = False, | |
norm_layer: str = "batch_norm_2d", | |
activation_layer: str = "relu", | |
): | |
super(CausalConv2d, self).__init__() | |
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) | |
if pad_f_dim: | |
fpad = kernel_size[1] // 2 + dilation - 1 | |
else: | |
fpad = 0 | |
# for last 2 dim, pad (left, right, top, bottom). | |
self.lookback = kernel_size[0] - 1 | |
if self.lookback > 0: | |
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0) | |
else: | |
self.tpad = nn.Identity() | |
groups = math.gcd(in_channels, out_channels) if separable else 1 | |
if groups == 1: | |
separable = False | |
if max(kernel_size) == 1: | |
separable = False | |
self.conv = nn.Conv2d( | |
in_channels, | |
out_channels, | |
kernel_size=kernel_size, | |
padding=(0, fpad), | |
stride=(1, fstride), # stride over time is always 1 | |
dilation=(1, dilation), # dilation over time is always 1 | |
groups=groups, | |
bias=bias, | |
) | |
if separable: | |
self.convp = nn.Conv2d( | |
out_channels, | |
out_channels, | |
kernel_size=1, | |
bias=False, | |
) | |
else: | |
self.convp = nn.Identity() | |
if norm_layer is not None: | |
norm_layer = norm_layer_dict[norm_layer] | |
self.norm = norm_layer(out_channels) | |
else: | |
self.norm = nn.Identity() | |
if activation_layer is not None: | |
activation_layer = activation_layer_dict[activation_layer] | |
self.activation = activation_layer() | |
else: | |
self.activation = nn.Identity() | |
def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None): | |
""" | |
:param inputs: shape: [b, c, t, f] | |
:param cache: shape: [b, c, lookback, f]; | |
:return: | |
""" | |
x = inputs | |
if cache is None: | |
x = self.tpad(x) | |
else: | |
x = torch.concat(tensors=[cache, x], dim=2) | |
new_cache = None | |
if self.lookback > 0: | |
new_cache = x[:, :, -self.lookback:, :] | |
x = self.conv(x) | |
x = self.convp(x) | |
x = self.norm(x) | |
x = self.activation(x) | |
return x, new_cache | |
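

# Illustrative sketch (added; hypothetical helper, never called by the model): the cache
# returned by CausalConv2d.forward carries the last `lookback` input frames between
# calls, so feeding one frame at a time while threading the cache through is intended to
# reproduce the offline output of the same module.
def _causal_conv2d_streaming_check():
    conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3))
    conv.eval()  # keep BatchNorm deterministic
    x = torch.randn(1, 1, 10, 32)
    with torch.no_grad():
        # offline: the whole sequence at once, left zero-padding handled internally
        y_offline, _ = conv.forward(x)
        # online: one frame at a time, threading the cache through
        cache = None
        chunks = []
        for t in range(x.shape[2]):
            y_t, cache = conv.forward(x[:, :, t:t + 1, :], cache=cache)
            chunks.append(y_t)
        y_online = torch.cat(chunks, dim=2)
    assert torch.allclose(y_offline, y_online, atol=1e-6)
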
class CausalConvTranspose2dErrorCase(nn.Module):
    """
    An incorrect caching approach, kept only for reference.
    """
def __init__(self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: Union[int, Iterable[int]], | |
fstride: int = 1, | |
dilation: int = 1, | |
pad_f_dim: bool = True, | |
bias: bool = True, | |
separable: bool = False, | |
norm_layer: str = "batch_norm_2d", | |
activation_layer: str = "relu", | |
): | |
super(CausalConvTranspose2dErrorCase, self).__init__() | |
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size | |
if pad_f_dim: | |
fpad = kernel_size[1] // 2 | |
else: | |
fpad = 0 | |
# for last 2 dim, pad (left, right, top, bottom). | |
self.lookback = kernel_size[0] - 1 | |
groups = math.gcd(in_channels, out_channels) if separable else 1 | |
if groups == 1: | |
separable = False | |
self.convt = nn.ConvTranspose2d( | |
in_channels, | |
out_channels, | |
kernel_size=kernel_size, | |
padding=(0, fpad), | |
output_padding=(0, fpad), | |
stride=(1, fstride), # stride over time is always 1 | |
dilation=(1, dilation), # dilation over time is always 1 | |
groups=groups, | |
bias=bias, | |
) | |
if separable: | |
self.convp = nn.Conv2d( | |
out_channels, | |
out_channels, | |
kernel_size=1, | |
bias=False, | |
) | |
else: | |
self.convp = nn.Identity() | |
if norm_layer is not None: | |
norm_layer = norm_layer_dict[norm_layer] | |
self.norm = norm_layer(out_channels) | |
else: | |
self.norm = nn.Identity() | |
if activation_layer is not None: | |
activation_layer = activation_layer_dict[activation_layer] | |
self.activation = activation_layer() | |
else: | |
self.activation = nn.Identity() | |
def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): | |
""" | |
:param inputs: shape: [b, c, t, f] | |
:param cache: shape: [b, c, lookback, f]; | |
:return: | |
""" | |
x = inputs | |
# x shape: [b, c, t, f] | |
x = self.convt(x) | |
# x shape: [b, c, t+lookback, f] | |
new_cache = None | |
if self.lookback > 0: | |
if cache is not None: | |
x = torch.concat(tensors=[ | |
x[:, :, :self.lookback, :] + cache, | |
x[:, :, self.lookback:, :] | |
], dim=2) | |
x = x[:, :, :-self.lookback, :] | |
new_cache = x[:, :, -self.lookback:, :] | |
x = self.convp(x) | |
x = self.norm(x) | |
x = self.activation(x) | |
return x, new_cache | |
class CausalConvTranspose2d(nn.Module): | |
def __init__(self, | |
in_channels: int, | |
out_channels: int, | |
kernel_size: Union[int, Iterable[int]], | |
fstride: int = 1, | |
dilation: int = 1, | |
pad_f_dim: bool = True, | |
bias: bool = True, | |
separable: bool = False, | |
norm_layer: str = "batch_norm_2d", | |
activation_layer: str = "relu", | |
): | |
super(CausalConvTranspose2d, self).__init__() | |
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size | |
if pad_f_dim: | |
fpad = kernel_size[1] // 2 | |
else: | |
fpad = 0 | |
# for last 2 dim, pad (left, right, top, bottom). | |
self.lookback = kernel_size[0] - 1 | |
if self.lookback > 0: | |
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0) | |
else: | |
self.tpad = nn.Identity() | |
groups = math.gcd(in_channels, out_channels) if separable else 1 | |
if groups == 1: | |
separable = False | |
self.convt = nn.ConvTranspose2d( | |
in_channels, | |
out_channels, | |
kernel_size=kernel_size, | |
padding=(kernel_size[0] - 1, fpad + dilation - 1), | |
output_padding=(0, fpad), | |
stride=(1, fstride), # stride over time is always 1 | |
dilation=(1, dilation), # dilation over time is always 1 | |
groups=groups, | |
bias=bias, | |
) | |
if separable: | |
self.convp = nn.Conv2d( | |
out_channels, | |
out_channels, | |
kernel_size=1, | |
bias=False, | |
) | |
else: | |
self.convp = nn.Identity() | |
if norm_layer is not None: | |
norm_layer = norm_layer_dict[norm_layer] | |
self.norm = norm_layer(out_channels) | |
else: | |
self.norm = nn.Identity() | |
if activation_layer is not None: | |
activation_layer = activation_layer_dict[activation_layer] | |
self.activation = activation_layer() | |
else: | |
self.activation = nn.Identity() | |
def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): | |
""" | |
:param inputs: shape: [b, c, t, f] | |
:param cache: shape: [b, c, lookback, f]; | |
:return: | |
""" | |
x = inputs | |
# x shape: [b, c, t, f] | |
x = self.convt(x) | |
# x shape: [b, c, t+lookback, f] | |
if cache is None: | |
x = self.tpad(x) | |
else: | |
x = torch.concat(tensors=[cache, x], dim=2) | |
new_cache = None | |
if self.lookback > 0: | |
new_cache = x[:, :, -self.lookback:, :] | |
x = self.convp(x) | |
x = self.norm(x) | |
x = self.activation(x) | |
return x, new_cache | |
class GroupedLinear(nn.Module): | |
def __init__(self, input_size: int, hidden_size: int, groups: int = 1): | |
super().__init__() | |
# self.weight: Tensor | |
self.input_size = input_size | |
self.hidden_size = hidden_size | |
self.groups = groups | |
assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" | |
assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" | |
self.ws = input_size // groups | |
self.register_parameter( | |
"weight", | |
torch.nn.Parameter( | |
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True | |
), | |
) | |
self.reset_parameters() | |
def reset_parameters(self): | |
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore | |
def forward(self, x: torch.Tensor) -> torch.Tensor: | |
# x: [..., I] | |
b, t, f = x.shape | |
if f != self.input_size: | |
raise AssertionError | |
# new_shape = list(x.shape)[:-1] + [self.groups, self.ws] | |
new_shape = (b, t, self.groups, self.ws) | |
x = x.view(new_shape) | |
# The better way, but not supported by torchscript | |
# x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] | |
x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] | |
x = x.flatten(2, 3) | |
# x: [b, t, h] | |
return x | |
def __repr__(self): | |
cls = self.__class__.__name__ | |
return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" | |
class SqueezedGRU_S(nn.Module): | |
""" | |
SGE net: Video object detection with squeezed GRU and information entropy map | |
https://arxiv.org/abs/2106.07224 | |
""" | |
def __init__( | |
self, | |
input_size: int, | |
hidden_size: int, | |
output_size: Optional[int] = None, | |
num_layers: int = 1, | |
linear_groups: int = 8, | |
batch_first: bool = True, | |
skip_op: str = "none", | |
activation_layer: str = "identity", | |
): | |
super().__init__() | |
self.input_size = input_size | |
self.hidden_size = hidden_size | |
self.linear_in = nn.Sequential( | |
GroupedLinear( | |
input_size=input_size, | |
hidden_size=hidden_size, | |
groups=linear_groups, | |
), | |
activation_layer_dict[activation_layer](), | |
) | |
# gru skip operator | |
self.gru_skip_op = None | |
if skip_op == "none": | |
self.gru_skip_op = None | |
elif skip_op == "identity": | |
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
self.gru_skip_op = nn.Identity() | |
elif skip_op == "grouped_linear": | |
self.gru_skip_op = GroupedLinear( | |
input_size=hidden_size, | |
hidden_size=hidden_size, | |
groups=linear_groups, | |
) | |
else: | |
raise NotImplementedError() | |
self.gru = nn.GRU( | |
input_size=hidden_size, | |
hidden_size=hidden_size, | |
num_layers=num_layers, | |
batch_first=batch_first, | |
bidirectional=False, | |
) | |
if output_size is not None: | |
self.linear_out = nn.Sequential( | |
GroupedLinear( | |
input_size=hidden_size, | |
hidden_size=output_size, | |
groups=linear_groups, | |
), | |
activation_layer_dict[activation_layer](), | |
) | |
else: | |
self.linear_out = nn.Identity() | |
def forward(self, inputs: torch.Tensor, hx: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: | |
# inputs: shape: [b, t, h] | |
x = self.linear_in.forward(inputs) | |
x, hx = self.gru.forward(x, hx) | |
x = self.linear_out(x) | |
if self.gru_skip_op is not None: | |
x = x + self.gru_skip_op(inputs) | |
return x, hx | |
class Add(nn.Module): | |
def forward(self, a, b): | |
return a + b | |
class Concat(nn.Module): | |
def forward(self, a, b): | |
return torch.cat((a, b), dim=-1) | |
class Encoder(nn.Module): | |
def __init__(self, config: DfNet2Config): | |
super(Encoder, self).__init__() | |
self.embedding_input_size = config.conv_channels * config.erb_bins // 4 | |
self.embedding_output_size = config.conv_channels * config.erb_bins // 4 | |
self.embedding_hidden_size = config.embedding_hidden_size | |
self.spec_conv0 = CausalConv2d( | |
in_channels=1, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_input, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.spec_conv1 = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=2, | |
) | |
self.spec_conv2 = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=2, | |
) | |
self.spec_conv3 = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.df_conv0 = CausalConv2d( | |
in_channels=2, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_input, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.df_conv1 = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=2, | |
) | |
self.df_fc_emb = nn.Sequential( | |
GroupedLinear( | |
config.conv_channels * config.df_bins // 2, | |
self.embedding_input_size, | |
groups=config.encoder_linear_groups | |
), | |
nn.ReLU(inplace=True) | |
) | |
if config.encoder_combine_op == "concat": | |
self.embedding_input_size *= 2 | |
self.combine = Concat() | |
else: | |
self.combine = Add() | |
# emb_gru | |
if config.spec_bins % 8 != 0: | |
raise AssertionError("spec_bins should be divisible by 8") | |
self.emb_gru = SqueezedGRU_S( | |
self.embedding_input_size, | |
self.embedding_hidden_size, | |
output_size=self.embedding_output_size, | |
num_layers=1, | |
batch_first=True, | |
skip_op=config.encoder_emb_skip_op, | |
linear_groups=config.encoder_emb_linear_groups, | |
activation_layer="relu", | |
) | |
# lsnr | |
self.lsnr_fc = nn.Sequential( | |
nn.Linear(self.embedding_output_size, 1), | |
nn.Sigmoid() | |
) | |
self.lsnr_scale = config.max_local_snr - config.min_local_snr | |
self.lsnr_offset = config.min_local_snr | |
def forward(self, | |
feat_erb: torch.Tensor, | |
feat_spec: torch.Tensor, | |
cache_dict: dict = None, | |
): | |
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache0 = cache_dict["cache0"] | |
cache1 = cache_dict["cache1"] | |
cache2 = cache_dict["cache2"] | |
cache3 = cache_dict["cache3"] | |
cache4 = cache_dict["cache4"] | |
cache5 = cache_dict["cache5"] | |
cache6 = cache_dict["cache6"] | |
# feat_erb shape: (b, 1, t, erb_bins) | |
e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0) | |
e1, new_cache1 = self.spec_conv1.forward(e0, cache=cache1) | |
e2, new_cache2 = self.spec_conv2.forward(e1, cache=cache2) | |
e3, new_cache3 = self.spec_conv3.forward(e2, cache=cache3) | |
# e0 shape: [b, c, t, erb_bins] | |
# e1 shape: [b, c, t, erb_bins // 2] | |
# e2 shape: [b, c, t, erb_bins // 4] | |
# e3 shape: [b, c, t, erb_bins // 4] | |
# e3 shape: [b, 64, t, 32/4=8] | |
# feat_spec, shape: (b, 2, t, df_bins) | |
c0, new_cache4 = self.df_conv0.forward(feat_spec, cache=cache4) | |
c1, new_cache5 = self.df_conv1.forward(c0, cache=cache5) | |
# c0 shape: [b, c, t, df_bins] | |
# c1 shape: [b, c, t, df_bins // 2] | |
# c1 shape: [b, 64, t, 96/2=48] | |
cemb = c1.permute(0, 2, 3, 1) | |
# cemb shape: [b, t, df_bins // 2, c] | |
cemb = cemb.flatten(2) | |
# cemb shape: [b, t, df_bins // 2 * c] | |
# cemb shape: [b, t, 96/2*64=3072] | |
cemb = self.df_fc_emb.forward(cemb) | |
# cemb shape: [b, t, erb_bins // 4 * c] | |
# cemb shape: [b, t, 32/4*64=512] | |
# e3 shape: [b, c, t, erb_bins // 4] | |
emb = e3.permute(0, 2, 3, 1) | |
# emb shape: [b, t, erb_bins // 4, c] | |
emb = emb.flatten(2) | |
# emb shape: [b, t, erb_bins // 4 * c] | |
# emb shape: [b, t, 32/4*64=512] | |
emb = self.combine(emb, cemb) | |
# if concat; emb shape: [b, t, spec_bins // 4 * c * 2] | |
# if add; emb shape: [b, t, spec_bins // 4 * c] | |
emb, new_cache6 = self.emb_gru.forward(emb, hx=cache6) | |
# emb shape: [b, t, spec_dim // 4 * c] | |
# h shape: [b, 1, spec_dim] | |
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset | |
# lsnr shape: [b, t, 1] | |
new_cache_dict = { | |
"cache0": new_cache0, | |
"cache1": new_cache1, | |
"cache2": new_cache2, | |
"cache3": new_cache3, | |
"cache4": new_cache4, | |
"cache5": new_cache5, | |
"cache6": new_cache6, | |
} | |
return e0, e1, e2, e3, emb, c0, lsnr, new_cache_dict | |
class ErbDecoder(nn.Module): | |
def __init__(self, config: DfNet2Config): | |
super(ErbDecoder, self).__init__() | |
if config.spec_bins % 8 != 0: | |
raise AssertionError("spec_bins should be divisible by 8") | |
self.emb_in_dim = config.conv_channels * config.erb_bins // 4 | |
self.emb_out_dim = config.conv_channels * config.erb_bins // 4 | |
self.emb_hidden_dim = config.decoder_emb_hidden_size | |
self.emb_gru = SqueezedGRU_S( | |
self.emb_in_dim, | |
self.emb_hidden_dim, | |
output_size=self.emb_out_dim, | |
num_layers=config.decoder_emb_num_layers - 1, | |
batch_first=True, | |
skip_op=config.decoder_emb_skip_op, | |
linear_groups=config.decoder_emb_linear_groups, | |
activation_layer="relu", | |
) | |
self.conv3p = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=1, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.convt3 = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.conv_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.conv2p = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=1, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.convt2 = CausalConvTranspose2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.convt_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=2, | |
) | |
self.conv1p = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=1, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.convt1 = CausalConvTranspose2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=config.convt_kernel_size_inner, | |
bias=False, | |
separable=True, | |
fstride=2, | |
) | |
self.conv0p = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=config.conv_channels, | |
kernel_size=1, | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
self.conv0_out = CausalConv2d( | |
in_channels=config.conv_channels, | |
out_channels=1, | |
kernel_size=config.conv_kernel_size_inner, | |
activation_layer="sigmoid", | |
bias=False, | |
separable=True, | |
fstride=1, | |
) | |
    def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None):
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache0 = cache_dict["cache0"] | |
cache1 = cache_dict["cache1"] | |
cache2 = cache_dict["cache2"] | |
cache3 = cache_dict["cache3"] | |
cache4 = cache_dict["cache4"] | |
# Estimates erb mask | |
b, _, t, f8 = e3.shape | |
# emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] | |
emb, new_cache0 = self.emb_gru.forward(emb, hx=cache0) | |
# emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] | |
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) | |
e3, new_cache1 = self.convt3.forward(self.conv3p(e3)[0] + emb, cache=cache1) | |
# e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] | |
e2, new_cache2 = self.convt2.forward(self.conv2p(e2)[0] + e3, cache=cache2) | |
# e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] | |
e1, new_cache3 = self.convt1.forward(self.conv1p(e1)[0] + e2, cache=cache3) | |
# e1 shape: [batch_size, conv_channels, time_steps, freq_dim] | |
mask, new_cache4 = self.conv0_out.forward(self.conv0p(e0)[0] + e1, cache=cache4) | |
# mask shape: [batch_size, 1, time_steps, freq_dim] | |
new_cache_dict = { | |
"cache0": new_cache0, | |
"cache1": new_cache1, | |
"cache2": new_cache2, | |
"cache3": new_cache3, | |
"cache4": new_cache4, | |
} | |
return mask, new_cache_dict | |
class DfDecoder(nn.Module): | |
def __init__(self, config: DfNet2Config): | |
super(DfDecoder, self).__init__() | |
self.embedding_input_size = config.conv_channels * config.erb_bins // 4 | |
self.df_decoder_hidden_size = config.df_decoder_hidden_size | |
self.df_num_layers = config.df_num_layers | |
self.df_order = config.df_order | |
self.df_bins = config.df_bins | |
self.df_out_ch = config.df_order * 2 | |
self.df_convp = CausalConv2d( | |
config.conv_channels, | |
self.df_out_ch, | |
fstride=1, | |
kernel_size=(config.df_pathway_kernel_size_t, 1), | |
separable=True, | |
bias=False, | |
) | |
self.df_gru = SqueezedGRU_S( | |
self.embedding_input_size, | |
self.df_decoder_hidden_size, | |
num_layers=self.df_num_layers, | |
batch_first=True, | |
skip_op="none", | |
activation_layer="relu", | |
) | |
if config.df_gru_skip == "none": | |
self.df_skip = None | |
elif config.df_gru_skip == "identity": | |
if config.embedding_hidden_size != config.df_decoder_hidden_size: | |
raise AssertionError("Dimensions do not match") | |
self.df_skip = nn.Identity() | |
elif config.df_gru_skip == "grouped_linear": | |
self.df_skip = GroupedLinear( | |
self.embedding_input_size, | |
self.df_decoder_hidden_size, | |
groups=config.df_decoder_linear_groups | |
) | |
else: | |
raise NotImplementedError() | |
self.df_out: nn.Module | |
out_dim = self.df_bins * self.df_out_ch | |
self.df_out = nn.Sequential( | |
GroupedLinear( | |
input_size=self.df_decoder_hidden_size, | |
hidden_size=out_dim, | |
groups=config.df_decoder_linear_groups, | |
# groups = self.df_bins // 5, | |
), | |
nn.Tanh() | |
) | |
self.df_fc_a = nn.Sequential( | |
nn.Linear(self.df_decoder_hidden_size, 1), | |
nn.Sigmoid() | |
) | |
    def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None):
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache0 = cache_dict["cache0"] | |
cache1 = cache_dict["cache1"] | |
# emb shape: [batch_size, time_steps, df_bins // 4 * channels] | |
b, t, _ = emb.shape | |
df_coefs, new_cache0 = self.df_gru.forward(emb, hx=cache0) | |
if self.df_skip is not None: | |
df_coefs = df_coefs + self.df_skip(emb) | |
# df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] | |
# c0 shape: [batch_size, channels, time_steps, df_bins] | |
c0, new_cache1 = self.df_convp.forward(c0, cache=cache1) | |
# c0 shape: [batch_size, df_order * 2, time_steps, df_bins] | |
c0 = c0.permute(0, 2, 3, 1) | |
# c0 shape: [batch_size, time_steps, df_bins, df_order * 2] | |
df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order | |
# df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] | |
df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) | |
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] | |
df_coefs = df_coefs + c0 | |
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] | |
new_cache_dict = { | |
"cache0": new_cache0, | |
"cache1": new_cache1, | |
} | |
return df_coefs, new_cache_dict | |
class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule.

    Reshapes coefficients of shape [B, T, F, O*2] to [B, O, T, F, 2], as required by
    the multiframe (deep filtering) module.
    """
def __init__(self, df_order: int, df_bins: int): | |
super().__init__() | |
self.df_order = df_order | |
self.df_bins = df_bins | |
def forward(self, coefs: torch.Tensor) -> torch.Tensor: | |
# [B, T, F, O*2] -> [B, O, T, F, 2] | |
new_shape = list(coefs.shape) | |
new_shape[-1] = -1 | |
new_shape.append(2) | |
coefs = coefs.view(new_shape) | |
coefs = coefs.permute(0, 3, 1, 2, 4) | |
return coefs | |
class Mask(nn.Module): | |
def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): | |
super().__init__() | |
self.use_post_filter = use_post_filter | |
self.eps = eps | |
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: | |
""" | |
Post-Filter | |
A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. | |
https://arxiv.org/abs/2008.04259 | |
:param mask: Real valued mask, typically of shape [B, C, T, F]. | |
:param beta: Global gain factor. | |
:return: | |
""" | |
mask_sin = mask * torch.sin(np.pi * mask / 2) | |
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) | |
return mask_pf | |
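
    # Added note: with m the estimated mask and beta the global gain factor, the
    # post-filter above computes
    #     m_pf = (1 + beta) * m / (1 + beta / sin(pi * m / 2)^2),
    # leaving bins with m close to 1 almost unchanged while further attenuating
    # low-gain (noise-dominated) bins, as proposed in arXiv:2008.04259.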
def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: | |
# spec shape: [b, 1, t, spec_bins, 2] | |
if not self.training and self.use_post_filter: | |
mask = self.post_filter(mask) | |
# mask shape: [b, 1, t, spec_bins] | |
mask = mask.unsqueeze(4) | |
# mask shape: [b, 1, t, spec_bins, 1] | |
return spec * mask | |
class DeepFiltering(nn.Module): | |
def __init__(self, | |
df_bins: int, | |
df_order: int, | |
lookahead: int = 0, | |
): | |
super(DeepFiltering, self).__init__() | |
self.df_bins = df_bins | |
self.df_order = df_order | |
self.lookahead = lookahead | |
self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) | |
def forward(self, *args, **kwargs): | |
raise AssertionError("use `forward_offline` or `forward_online` stead.") | |
def spec_unfold_offline(self, spec: torch.Tensor) -> torch.Tensor: | |
""" | |
Pads and unfolds the spectrogram according to frame_size. | |
:param spec: shape: [b, c, t, f], dtype: torch.complex64 | |
:return: shape: [b, c, t, f, df_order] | |
""" | |
if self.df_order <= 1: | |
return spec.unsqueeze(-1) | |
# spec shape: [b, 1, t, f], dtype: torch.complex64 | |
spec = self.pad(spec) | |
# spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 | |
spec_unfold = spec.unfold(dimension=2, size=self.df_order, step=1) | |
# spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 | |
return spec_unfold | |
def forward_offline(self, | |
spec: torch.Tensor, | |
coefs: torch.Tensor, | |
): | |
# spec shape: [b, 1, t, spec_bins, 2] | |
spec_c = torch.view_as_complex(spec.contiguous()) | |
# spec_c shape: [b, 1, t, spec_bins] | |
spec_u = self.spec_unfold_offline(spec_c) | |
# spec_u shape: [b, 1, t, spec_bins, df_order] | |
spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) | |
# spec_f shape: [b, 1, t, df_bins, df_order] | |
# coefs shape: [b, df_order, t, df_bins, 2] | |
coefs = torch.view_as_complex(coefs.contiguous()) | |
# coefs shape: [b, df_order, t, df_bins] | |
coefs = coefs.unsqueeze(dim=1) | |
# coefs shape: [b, 1, df_order, t, df_bins] | |
spec_f = self.df_offline(spec_f, coefs) | |
# spec_f shape: [b, 1, t, df_bins] | |
spec_f = torch.view_as_real(spec_f) | |
# spec_f shape: [b, 1, t, df_bins, 2] | |
return spec_f | |
def df_offline(self, spec: torch.Tensor, coefs: torch.Tensor): | |
""" | |
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. | |
:param spec: [b, 1, t, df_bins, df_order] complex. | |
:param coefs: [b, 1, df_order, t, df_bins] complex. | |
:return: [b, 1, t, df_bins] complex. | |
""" | |
spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) | |
return spec_f | |
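
    # Added note: the einsum in `df_offline` applies a complex FIR filter of length
    # df_order to every time-frequency bin of the unfolded spectrogram, i.e. roughly
    #     Y[t, f] = sum_n coefs[..., n, t, f] * spec[..., t, f, n],
    # where the unfold axis n covers the current frame, the preceding
    # df_order - 1 - lookahead frames and, when lookahead > 0, lookahead future frames.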
def spec_unfold_online(self, spec: torch.Tensor, cache_spec: torch.Tensor = None): | |
""" | |
Pads and unfolds the spectrogram according to frame_size. | |
:param spec: shape: [b, c, t, f], dtype: torch.complex64 | |
:param cache_spec: shape: [b, c, df_order-1, f], dtype: torch.complex64 | |
:return: shape: [b, c, t, f, df_order] | |
""" | |
if self.df_order <= 1: | |
return spec.unsqueeze(-1) | |
if cache_spec is None: | |
b, c, _, f = spec.shape | |
cache_spec = spec.new_zeros(size=(b, c, self.df_order-1, f)) | |
spec_pad = torch.concat(tensors=[ | |
cache_spec, spec | |
], dim=2) | |
new_cache_spec = spec_pad[:, :, -(self.df_order-1):, :] | |
# spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 | |
spec_unfold = spec_pad.unfold(dimension=2, size=self.df_order, step=1) | |
# spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 | |
return spec_unfold, new_cache_spec | |
def forward_online(self, | |
spec: torch.Tensor, | |
coefs: torch.Tensor, | |
cache_dict: dict = None, | |
): | |
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache0 = cache_dict["cache0"] | |
cache1 = cache_dict["cache1"] | |
# spec shape: [b, 1, t, spec_bins, 2] | |
spec_c = torch.view_as_complex(spec.contiguous()) | |
# spec_c shape: [b, 1, t, spec_bins] | |
spec_u, new_cache0 = self.spec_unfold_online(spec_c, cache_spec=cache0) | |
# spec_u shape: [b, 1, t, spec_bins, df_order] | |
spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) | |
# spec_f shape: [b, 1, t, df_bins, df_order] | |
# coefs shape: [b, df_order, t, df_bins, 2] | |
coefs = torch.view_as_complex(coefs.contiguous()) | |
# coefs shape: [b, df_order, t, df_bins] | |
coefs = coefs.unsqueeze(dim=1) | |
# coefs shape: [b, 1, df_order, t, df_bins] | |
spec_f, new_cache1 = self.df_online(spec_f, coefs, cache_coefs=cache1) | |
# spec_f shape: [b, 1, t, df_bins] | |
spec_f = torch.view_as_real(spec_f) | |
# spec_f shape: [b, 1, t, df_bins, 2] | |
new_cache_dict = { | |
"cache0": new_cache0, | |
"cache1": new_cache1, | |
} | |
return spec_f, new_cache_dict | |
def df_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_coefs: torch.Tensor = None) -> torch.Tensor: | |
""" | |
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. | |
:param spec: [b, 1, 1, df_bins, df_order] complex. | |
:param coefs: [b, 1, df_order, 1, df_bins] complex. | |
:param cache_coefs: [b, 1, df_order, lookahead, df_bins] complex. | |
:return: [b, 1, 1, df_bins] complex. | |
""" | |
if cache_coefs is None: | |
b, c, _, _, f = coefs.shape | |
cache_coefs = coefs.new_zeros(size=(b, c, self.df_order, self.lookahead, f)) | |
coefs_pad = torch.concat(tensors=[ | |
cache_coefs, coefs | |
], dim=3) | |
# coefs_pad shape: [b, 1, df_order, 1+lookahead, df_bins], torch.complex64. | |
coefs = coefs_pad[:, :, :, :-self.lookahead, :] | |
# coefs shape: [b, 1, df_order, 1, df_bins], torch.complex64. | |
new_cache_coefs = coefs_pad[:, :, :, -self.lookahead:, :] | |
# new_cache_coefs shape: [b, 1, df_order, lookahead, df_bins], torch.complex64. | |
spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) | |
return spec_f, new_cache_coefs | |
class DfNet2(nn.Module): | |
def __init__(self, config: DfNet2Config): | |
super(DfNet2, self).__init__() | |
self.config = config | |
self.eps = 1e-12 | |
self.freq_bins = self.config.nfft // 2 + 1 | |
self.nfft = config.nfft | |
self.win_size = config.win_size | |
self.hop_size = config.hop_size | |
self.win_type = config.win_type | |
self.stft = ConvSTFT( | |
nfft=config.nfft, | |
win_size=config.win_size, | |
hop_size=config.hop_size, | |
win_type=config.win_type, | |
power=None, | |
requires_grad=False | |
) | |
self.istft = ConviSTFT( | |
nfft=config.nfft, | |
win_size=config.win_size, | |
hop_size=config.hop_size, | |
win_type=config.win_type, | |
requires_grad=False | |
) | |
self.erb_bands = ErbBands( | |
sample_rate=config.sample_rate, | |
nfft=config.nfft, | |
erb_bins=config.erb_bins, | |
min_freq_bins_for_erb=config.min_freq_bins_for_erb, | |
) | |
self.erb_ema = ErbEMA( | |
sample_rate=config.sample_rate, | |
hop_size=config.hop_size, | |
erb_bins=config.erb_bins, | |
) | |
self.spec_ema = SpecEMA( | |
sample_rate=config.sample_rate, | |
hop_size=config.hop_size, | |
df_bins=config.df_bins, | |
) | |
self.encoder = Encoder(config) | |
self.erb_decoder = ErbDecoder(config) | |
self.df_decoder = DfDecoder(config) | |
self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) | |
self.df_op = DeepFiltering( | |
df_bins=config.df_bins, | |
df_order=config.df_order, | |
lookahead=config.df_lookahead, | |
) | |
self.mask = Mask(use_post_filter=config.use_post_filter) | |
self.lsnr_fn = LocalSnrTarget( | |
sample_rate=config.sample_rate, | |
nfft=config.nfft, | |
win_size=config.win_size, | |
hop_size=config.hop_size, | |
n_frame=config.n_frame, | |
min_local_snr=config.min_local_snr, | |
max_local_snr=config.max_local_snr, | |
db=True, | |
) | |
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: | |
if signal.dim() == 2: | |
signal = torch.unsqueeze(signal, dim=1) | |
_, _, n_samples = signal.shape | |
remainder = (n_samples - self.win_size) % self.hop_size | |
if remainder > 0: | |
n_samples_pad = self.hop_size - remainder | |
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) | |
return signal | |
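
    # Added example of the padding arithmetic above (win_size=400, hop_size=100 are
    # illustrative assumptions): with n_samples=16050,
    # remainder = (16050 - 400) % 100 = 50, so hop_size - remainder = 50 zero samples
    # are appended and the padded signal splits into a whole number of frames.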
def feature_prepare(self, signal: torch.Tensor): | |
# noisy shape: [b, num_samples_pad] | |
spec_cmp = self.stft.forward(signal) | |
# spec_complex shape: [b, f, t], torch.complex64 | |
spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2) | |
# spec_complex shape: [b, t, f], torch.complex64 | |
spec_cmp_real = torch.view_as_real(spec_cmp) | |
# spec_cmp_real shape: [b, t, f, 2] | |
spec_mag = torch.abs(spec_cmp) | |
spec_pow = torch.square(spec_mag) | |
# shape: [b, t, f] | |
spec = torch.unsqueeze(spec_cmp_real, dim=1) | |
# spec shape: [b, 1, t, f, 2] | |
feat_erb = self.erb_bands.erb_scale(spec_pow, db=True) | |
# feat_erb shape: [b, t, erb_bins] | |
feat_erb = torch.unsqueeze(feat_erb, dim=1) | |
# feat_erb shape: [b, 1, t, erb_bins] | |
feat_spec = spec_cmp_real.permute(0, 3, 1, 2) | |
# feat_spec shape: [b, 2, t, f] | |
feat_spec = feat_spec[..., :self.df_decoder.df_bins] | |
# feat_spec shape: [b, 2, t, df_bins] | |
spec = spec.detach() | |
feat_erb = feat_erb.detach() | |
feat_spec = feat_spec.detach() | |
return spec, feat_erb, feat_spec | |
def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None): | |
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache0 = cache_dict["cache0"] | |
cache1 = cache_dict["cache1"] | |
feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0) | |
feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1) | |
new_cache_dict = { | |
"cache0": new_cache0, | |
"cache1": new_cache1, | |
} | |
feat_erb = feat_erb.detach() | |
feat_spec = feat_spec.detach() | |
return feat_erb, feat_spec, new_cache_dict | |
def forward(self, | |
noisy: torch.Tensor, | |
): | |
""" | |
:param noisy: | |
:return: | |
est_spec: shape: [b, 257*2, t] | |
est_wav: shape: [b, num_samples] | |
est_mask: shape: [b, 257, t] | |
lsnr: shape: [b, 1, t] | |
""" | |
n_samples = noisy.shape[-1] | |
noisy = self.signal_prepare(noisy) | |
spec, feat_erb, feat_spec = self.feature_prepare(noisy) | |
if self.config.use_ema_norm: | |
feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec) | |
e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec) | |
mask, _ = self.erb_decoder.forward(emb, e3, e2, e1, e0) | |
# mask shape: [b, 1, t, erb_bins] | |
mask = self.erb_bands.erb_scale_inv(mask) | |
# mask shape: [b, 1, t, f] | |
if torch.any(mask > 1) or torch.any(mask < 0): | |
raise AssertionError | |
spec_m = self.mask.forward(spec, mask) | |
# spec_m shape: [b, 1, t, f, 2] | |
spec_m = spec_m[:, :, :, :self.config.spec_bins, :] | |
# spec_m shape: [b, 1, t, spec_bins, 2] | |
# lsnr shape: [b, t, 1] | |
lsnr = torch.transpose(lsnr, dim0=2, dim1=1) | |
# lsnr shape: [b, 1, t] | |
df_coefs, _ = self.df_decoder.forward(emb, c0) | |
df_coefs = self.df_out_transform(df_coefs) | |
# df_coefs shape: [b, df_order, t, df_bins, 2] | |
spec_ = spec[:, :, :, :self.config.spec_bins, :] | |
# spec shape: [b, 1, t, spec_bins, 2] | |
spec_f = self.df_op.forward_offline(spec_, df_coefs) | |
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32 | |
spec_e = torch.concat(tensors=[ | |
spec_f, spec_m[..., self.df_decoder.df_bins:, :] | |
], dim=3) | |
spec_e = torch.squeeze(spec_e, dim=1) | |
spec_e = spec_e.permute(0, 2, 1, 3) | |
# spec_e shape: [b, spec_bins, t, 2] | |
est_spec = torch.view_as_complex(spec_e.contiguous()) | |
# est_spec shape: [b, spec_bins, t], torch.complex64 | |
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) | |
# est_spec shape: [b, f, t], torch.complex64 | |
est_wav = self.istft.forward(est_spec) | |
est_wav = est_wav[:, :, :n_samples] | |
# est_wav shape: [b, 1, n_samples] | |
est_mask = torch.squeeze(mask, dim=1) | |
est_mask = est_mask.permute(0, 2, 1) | |
# est_mask shape: [b, f, t] | |
return est_spec, est_wav, est_mask, lsnr | |
def forward_chunk_by_chunk(self, | |
noisy: torch.Tensor, | |
): | |
noisy = self.signal_prepare(noisy) | |
b, _, _ = noisy.shape | |
noisy = torch.concat(tensors=[ | |
noisy, noisy.new_zeros(size=(b, 1, (self.config.df_lookahead+1)*self.hop_size)) | |
], dim=2) | |
b, _, num_samples = noisy.shape | |
t = (num_samples - self.win_size) // self.hop_size + 1 | |
cache_dict0 = None | |
cache_dict1 = None | |
cache_dict2 = None | |
cache_dict3 = None | |
cache_dict4 = None | |
cache_dict5 = None | |
cache_dict6 = None | |
waveform_list = list() | |
for i in range(int(t)): | |
begin = i * self.hop_size | |
end = begin + self.win_size | |
sub_noisy = noisy[:, :, begin: end] | |
spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy) | |
# spec shape: [b, 1, t, f, 2] | |
# feat_erb shape: [b, 1, t, erb_bins] | |
# feat_spec shape: [b, 2, t, df_bins] | |
if self.config.use_ema_norm: | |
feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0) | |
e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1) | |
mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2) | |
# mask shape: [b, 1, t, erb_bins] | |
mask = self.erb_bands.erb_scale_inv(mask) | |
# mask shape: [b, 1, t, f] | |
spec_m = self.mask.forward(spec, mask) | |
# spec_m shape: [b, 1, t, f, 2] | |
spec_m = spec_m[:, :, :, :self.config.spec_bins, :] | |
# spec_m shape: [b, 1, t, spec_bins, 2] | |
# lsnr shape: [b, t, 1] | |
lsnr = torch.transpose(lsnr, dim0=2, dim1=1) | |
# lsnr shape: [b, 1, t] | |
df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3) | |
df_coefs = self.df_out_transform(df_coefs) | |
# df_coefs shape: [b, df_order, t, df_bins, 2] | |
spec_ = spec[:, :, :, :self.config.spec_bins, :] | |
# spec shape: [b, 1, t, spec_bins, 2] | |
spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4) | |
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32 | |
spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5) | |
spec_e = torch.squeeze(spec_e, dim=1) | |
spec_e = spec_e.permute(0, 2, 1, 3) | |
# spec_e shape: [b, spec_bins, t, 2] | |
est_spec = torch.view_as_complex(spec_e.contiguous()) | |
# est_spec shape: [b, spec_bins, t], torch.complex64 | |
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) | |
# est_spec shape: [b, f, t], torch.complex64 | |
est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6) | |
# est_wav shape: [b, 1, hop_size] | |
waveform_list.append(est_wav) | |
waveform = torch.concat(tensors=waveform_list, dim=-1) | |
# waveform shape: [b, 1, n] | |
return waveform | |
def spec_e_m_combine_online(self, spec_f: torch.Tensor, spec_m: torch.Tensor, cache_dict: dict = None): | |
""" | |
:param spec_f: shape: [b, 1, t, df_bins, 2], torch.float32 | |
:param spec_m: shape: [b, 1, t, spec_bins, 2] | |
:param cache_dict: | |
:return: | |
""" | |
if cache_dict is None: | |
cache_dict = defaultdict(lambda: None) | |
cache_spec_m = cache_dict["cache_spec_m"] | |
if cache_spec_m is None: | |
b, c, t, f, _ = spec_m.shape | |
cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2)) | |
# cache0 shape: [b, 1, lookahead, f, 2] | |
spec_m_cat = torch.concat(tensors=[ | |
cache_spec_m, spec_m, | |
], dim=2) | |
spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :] | |
new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :] | |
spec_e = torch.concat(tensors=[ | |
spec_f, spec_m[..., self.df_decoder.df_bins:, :] | |
], dim=3) | |
new_cache_dict = { | |
"cache_spec_m": new_cache_spec_m, | |
} | |
return spec_e, new_cache_dict | |
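
    # Added note: in streaming mode the deep-filter branch (`forward_online`) is delayed
    # by `df_lookahead` frames, so `spec_e_m_combine_online` delays the masked spectrum
    # `spec_m` by the same number of frames through its cache before concatenating the
    # two branches along the frequency axis. The chunk-by-chunk output therefore lags
    # the input by df_lookahead * hop_size samples, which `main` below compensates for
    # by cropping the start of the waveform.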
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): | |
""" | |
:param est_mask: torch.Tensor, shape: [b, 257, t] | |
:param clean: | |
:param noisy: | |
:return: | |
""" | |
if noisy.shape != clean.shape: | |
raise AssertionError("Input signals must have the same shape") | |
noise = noisy - clean | |
clean = self.signal_prepare(clean) | |
noise = self.signal_prepare(noise) | |
stft_clean = self.stft.forward(clean) | |
mag_clean = torch.abs(stft_clean) | |
stft_noise = self.stft.forward(noise) | |
mag_noise = torch.abs(stft_noise) | |
gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1) | |
loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean") | |
return loss | |
def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): | |
if noisy.shape != clean.shape: | |
raise AssertionError("Input signals must have the same shape") | |
noise = noisy - clean | |
clean = self.signal_prepare(clean) | |
noise = self.signal_prepare(noise) | |
stft_clean = self.stft.forward(clean) | |
stft_noise = self.stft.forward(noise) | |
# shape: [b, f, t] | |
stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2) | |
stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2) | |
# shape: [b, t, f] | |
stft_clean = torch.unsqueeze(stft_clean, dim=1) | |
stft_noise = torch.unsqueeze(stft_noise, dim=1) | |
# shape: [b, 1, t, f] | |
# lsnr shape: [b, 1, t] | |
lsnr = lsnr.squeeze(1) | |
# lsnr shape: [b, t] | |
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise) | |
# lsnr_gth shape: [b, t] | |
loss = F.mse_loss(lsnr, lsnr_gth) | |
return loss | |
class DfNet2PretrainedModel(DfNet2): | |
def __init__(self, | |
config: DfNet2Config, | |
): | |
super(DfNet2PretrainedModel, self).__init__( | |
config=config, | |
) | |
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DfNet2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) | |
model = cls(config) | |
if os.path.isdir(pretrained_model_name_or_path): | |
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) | |
else: | |
ckpt_file = pretrained_model_name_or_path | |
with open(ckpt_file, "rb") as f: | |
state_dict = torch.load(f, map_location="cpu", weights_only=True) | |
model.load_state_dict(state_dict, strict=True) | |
return model | |
def save_pretrained(self, | |
save_directory: Union[str, os.PathLike], | |
state_dict: Optional[dict] = None, | |
): | |
model = self | |
if state_dict is None: | |
state_dict = model.state_dict() | |
os.makedirs(save_directory, exist_ok=True) | |
# save state dict | |
model_file = os.path.join(save_directory, MODEL_FILE) | |
torch.save(state_dict, model_file) | |
# save config | |
config_file = os.path.join(save_directory, CONFIG_FILE) | |
self.config.to_yaml_file(config_file) | |
return save_directory | |
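
# Hypothetical usage sketch (added; the checkpoint path is a placeholder, not a shipped asset):
#
#   model = DfNet2PretrainedModel.from_pretrained("path/to/checkpoint_dir")
#   model.eval()
#   with torch.no_grad():
#       enhanced = model.forward_chunk_by_chunk(noisy_waveform)  # shape: [b, 1, n]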
def main(): | |
import time | |
# torch.set_num_threads(1) | |
config = DfNet2Config() | |
model = DfNet2PretrainedModel(config=config) | |
model.eval() | |
num_samples = 16000 | |
noisy = torch.randn(size=(1, num_samples), dtype=torch.float32) | |
duration = num_samples / config.sample_rate | |
begin = time.time() | |
est_spec, est_wav, est_mask, lsnr = model.forward(noisy) | |
time_cost = time.time() - begin | |
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") | |
# print(f"est_spec.shape: {est_spec.shape}") | |
# print(f"est_wav.shape: {est_wav.shape}") | |
# print(f"est_mask.shape: {est_mask.shape}") | |
# print(f"lsnr.shape: {lsnr.shape}") | |
waveform = est_wav | |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") | |
print(waveform[:, :, 300: 302]) | |
print(waveform[:, :, 15680: 15682]) | |
print(waveform[:, :, 15760: 15762]) | |
print(waveform[:, :, 15840: 15842]) | |
begin = time.time() | |
waveform = model.forward_chunk_by_chunk(noisy) | |
time_cost = time.time() - begin | |
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") | |
waveform = waveform[:, :, (config.df_lookahead*config.hop_size):] | |
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") | |
print(waveform[:, :, 300: 302]) | |
print(waveform[:, :, 15680: 15682]) | |
print(waveform[:, :, 15760: 15762]) | |
print(waveform[:, :, 15840: 15842]) | |
return | |
if __name__ == "__main__": | |
main() | |