#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The native DeepFilterNet implementation does not directly support streaming inference.

Community developers (such as Rikorose) provide a Torch-based streaming inference implementation:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF

This file attempts to implement a DfNet that supports streaming inference.
"""
import os
import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA


MODEL_FILE = "model.pt"


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}

activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


class CausalConv2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if pad_f_dim:
            fpad = kernel_size[1] // 2 + dilation - 1
        else:
            fpad = 0

        # for last 2 dim, pad (left, right, top, bottom).
        self.lookback = kernel_size[0] - 1
        if self.lookback > 0:
            self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
        else:
            self.tpad = nn.Identity()

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
        """
        :param inputs: shape: [b, c, t, f]
        :param cache: shape: [b, c, lookback, f];
        :return:
        """
        x = inputs

        if cache is None:
            x = self.tpad(x)
        else:
            x = torch.concat(tensors=[cache, x], dim=2)

        new_cache = None
        if self.lookback > 0:
            new_cache = x[:, :, -self.lookback:, :]

        x = self.conv(x)

        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)

        return x, new_cache


class CausalConvTranspose2dErrorCase(nn.Module):
    """
    An incorrect caching approach.
    """
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        super(CausalConvTranspose2dErrorCase, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

        if pad_f_dim:
            fpad = kernel_size[1] // 2
        else:
            fpad = 0

        # for last 2 dim, pad (left, right, top, bottom).
        self.lookback = kernel_size[0] - 1

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        self.convt = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            output_padding=(0, fpad),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
        """
        :param inputs: shape: [b, c, t, f]
        :param cache: shape: [b, c, lookback, f];
        :return:
        """
        x = inputs
        # x shape: [b, c, t, f]
        x = self.convt(x)
        # x shape: [b, c, t+lookback, f]

        new_cache = None
        if self.lookback > 0:
            if cache is not None:
                x = torch.concat(tensors=[
                    x[:, :, :self.lookback, :] + cache,
                    x[:, :, self.lookback:, :]
                ], dim=2)

            x = x[:, :, :-self.lookback, :]
            new_cache = x[:, :, -self.lookback:, :]

        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)

        return x, new_cache


class CausalConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        super(CausalConvTranspose2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size

        if pad_f_dim:
            fpad = kernel_size[1] // 2
        else:
            fpad = 0

        # for last 2 dim, pad (left, right, top, bottom).
self.lookback = kernel_size[0] - 1 if self.lookback > 0: self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0) else: self.tpad = nn.Identity() groups = math.gcd(in_channels, out_channels) if separable else 1 if groups == 1: separable = False self.convt = nn.ConvTranspose2d( in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size[0] - 1, fpad + dilation - 1), output_padding=(0, fpad), stride=(1, fstride), # stride over time is always 1 dilation=(1, dilation), # dilation over time is always 1 groups=groups, bias=bias, ) if separable: self.convp = nn.Conv2d( out_channels, out_channels, kernel_size=1, bias=False, ) else: self.convp = nn.Identity() if norm_layer is not None: norm_layer = norm_layer_dict[norm_layer] self.norm = norm_layer(out_channels) else: self.norm = nn.Identity() if activation_layer is not None: activation_layer = activation_layer_dict[activation_layer] self.activation = activation_layer() else: self.activation = nn.Identity() def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None): """ :param inputs: shape: [b, c, t, f] :param cache: shape: [b, c, lookback, f]; :return: """ x = inputs # x shape: [b, c, t, f] x = self.convt(x) # x shape: [b, c, t+lookback, f] if cache is None: x = self.tpad(x) else: x = torch.concat(tensors=[cache, x], dim=2) new_cache = None if self.lookback > 0: new_cache = x[:, :, -self.lookback:, :] x = self.convp(x) x = self.norm(x) x = self.activation(x) return x, new_cache class GroupedLinear(nn.Module): def __init__(self, input_size: int, hidden_size: int, groups: int = 1): super().__init__() # self.weight: Tensor self.input_size = input_size self.hidden_size = hidden_size self.groups = groups assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" self.ws = input_size // groups self.register_parameter( "weight", torch.nn.Parameter( torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True ), ) self.reset_parameters() def reset_parameters(self): nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [..., I] b, t, f = x.shape if f != self.input_size: raise AssertionError # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] new_shape = (b, t, self.groups, self.ws) x = x.view(new_shape) # The better way, but not supported by torchscript # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] x = x.flatten(2, 3) # x: [b, t, h] return x def __repr__(self): cls = self.__class__.__name__ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" class SqueezedGRU_S(nn.Module): """ SGE net: Video object detection with squeezed GRU and information entropy map https://arxiv.org/abs/2106.07224 """ def __init__( self, input_size: int, hidden_size: int, output_size: Optional[int] = None, num_layers: int = 1, linear_groups: int = 8, batch_first: bool = True, skip_op: str = "none", activation_layer: str = "identity", ): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.linear_in = nn.Sequential( GroupedLinear( input_size=input_size, hidden_size=hidden_size, groups=linear_groups, ), activation_layer_dict[activation_layer](), ) # gru skip operator self.gru_skip_op = None if skip_op == "none": self.gru_skip_op = None elif skip_op == "identity": if not 
input_size != output_size: raise AssertionError("Dimensions do not match") self.gru_skip_op = nn.Identity() elif skip_op == "grouped_linear": self.gru_skip_op = GroupedLinear( input_size=hidden_size, hidden_size=hidden_size, groups=linear_groups, ) else: raise NotImplementedError() self.gru = nn.GRU( input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first, bidirectional=False, ) if output_size is not None: self.linear_out = nn.Sequential( GroupedLinear( input_size=hidden_size, hidden_size=output_size, groups=linear_groups, ), activation_layer_dict[activation_layer](), ) else: self.linear_out = nn.Identity() def forward(self, inputs: torch.Tensor, hx: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]: # inputs: shape: [b, t, h] x = self.linear_in.forward(inputs) x, hx = self.gru.forward(x, hx) x = self.linear_out(x) if self.gru_skip_op is not None: x = x + self.gru_skip_op(inputs) return x, hx class Add(nn.Module): def forward(self, a, b): return a + b class Concat(nn.Module): def forward(self, a, b): return torch.cat((a, b), dim=-1) class Encoder(nn.Module): def __init__(self, config: DfNet2Config): super(Encoder, self).__init__() self.embedding_input_size = config.conv_channels * config.erb_bins // 4 self.embedding_output_size = config.conv_channels * config.erb_bins // 4 self.embedding_hidden_size = config.embedding_hidden_size self.spec_conv0 = CausalConv2d( in_channels=1, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_input, bias=False, separable=True, fstride=1, ) self.spec_conv1 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, ) self.spec_conv2 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, ) self.spec_conv3 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=1, ) self.df_conv0 = CausalConv2d( in_channels=2, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_input, bias=False, separable=True, fstride=1, ) self.df_conv1 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, ) self.df_fc_emb = nn.Sequential( GroupedLinear( config.conv_channels * config.df_bins // 2, self.embedding_input_size, groups=config.encoder_linear_groups ), nn.ReLU(inplace=True) ) if config.encoder_combine_op == "concat": self.embedding_input_size *= 2 self.combine = Concat() else: self.combine = Add() # emb_gru if config.spec_bins % 8 != 0: raise AssertionError("spec_bins should be divisible by 8") self.emb_gru = SqueezedGRU_S( self.embedding_input_size, self.embedding_hidden_size, output_size=self.embedding_output_size, num_layers=1, batch_first=True, skip_op=config.encoder_emb_skip_op, linear_groups=config.encoder_emb_linear_groups, activation_layer="relu", ) # lsnr self.lsnr_fc = nn.Sequential( nn.Linear(self.embedding_output_size, 1), nn.Sigmoid() ) self.lsnr_scale = config.max_local_snr - config.min_local_snr self.lsnr_offset = config.min_local_snr def forward(self, feat_erb: torch.Tensor, feat_spec: torch.Tensor, cache_dict: dict = None, ): if cache_dict is None: cache_dict = defaultdict(lambda: None) cache0 = cache_dict["cache0"] cache1 = cache_dict["cache1"] 
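        # Streaming note: each entry of `cache_dict` holds the state of one layer
        # (the last `lookback` input frames of a CausalConv2d, or the GRU hidden
        # state), so that chunk-by-chunk calls reproduce the offline outputs
        # (exactly so when BatchNorm layers are in eval mode).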
cache2 = cache_dict["cache2"] cache3 = cache_dict["cache3"] cache4 = cache_dict["cache4"] cache5 = cache_dict["cache5"] cache6 = cache_dict["cache6"] # feat_erb shape: (b, 1, t, erb_bins) e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0) e1, new_cache1 = self.spec_conv1.forward(e0, cache=cache1) e2, new_cache2 = self.spec_conv2.forward(e1, cache=cache2) e3, new_cache3 = self.spec_conv3.forward(e2, cache=cache3) # e0 shape: [b, c, t, erb_bins] # e1 shape: [b, c, t, erb_bins // 2] # e2 shape: [b, c, t, erb_bins // 4] # e3 shape: [b, c, t, erb_bins // 4] # e3 shape: [b, 64, t, 32/4=8] # feat_spec, shape: (b, 2, t, df_bins) c0, new_cache4 = self.df_conv0.forward(feat_spec, cache=cache4) c1, new_cache5 = self.df_conv1.forward(c0, cache=cache5) # c0 shape: [b, c, t, df_bins] # c1 shape: [b, c, t, df_bins // 2] # c1 shape: [b, 64, t, 96/2=48] cemb = c1.permute(0, 2, 3, 1) # cemb shape: [b, t, df_bins // 2, c] cemb = cemb.flatten(2) # cemb shape: [b, t, df_bins // 2 * c] # cemb shape: [b, t, 96/2*64=3072] cemb = self.df_fc_emb.forward(cemb) # cemb shape: [b, t, erb_bins // 4 * c] # cemb shape: [b, t, 32/4*64=512] # e3 shape: [b, c, t, erb_bins // 4] emb = e3.permute(0, 2, 3, 1) # emb shape: [b, t, erb_bins // 4, c] emb = emb.flatten(2) # emb shape: [b, t, erb_bins // 4 * c] # emb shape: [b, t, 32/4*64=512] emb = self.combine(emb, cemb) # if concat; emb shape: [b, t, spec_bins // 4 * c * 2] # if add; emb shape: [b, t, spec_bins // 4 * c] emb, new_cache6 = self.emb_gru.forward(emb, hx=cache6) # emb shape: [b, t, spec_dim // 4 * c] # h shape: [b, 1, spec_dim] lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset # lsnr shape: [b, t, 1] new_cache_dict = { "cache0": new_cache0, "cache1": new_cache1, "cache2": new_cache2, "cache3": new_cache3, "cache4": new_cache4, "cache5": new_cache5, "cache6": new_cache6, } return e0, e1, e2, e3, emb, c0, lsnr, new_cache_dict class ErbDecoder(nn.Module): def __init__(self, config: DfNet2Config): super(ErbDecoder, self).__init__() if config.spec_bins % 8 != 0: raise AssertionError("spec_bins should be divisible by 8") self.emb_in_dim = config.conv_channels * config.erb_bins // 4 self.emb_out_dim = config.conv_channels * config.erb_bins // 4 self.emb_hidden_dim = config.decoder_emb_hidden_size self.emb_gru = SqueezedGRU_S( self.emb_in_dim, self.emb_hidden_dim, output_size=self.emb_out_dim, num_layers=config.decoder_emb_num_layers - 1, batch_first=True, skip_op=config.decoder_emb_skip_op, linear_groups=config.decoder_emb_linear_groups, activation_layer="relu", ) self.conv3p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, ) self.convt3 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=1, ) self.conv2p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, ) self.convt2 = CausalConvTranspose2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.convt_kernel_size_inner, bias=False, separable=True, fstride=2, ) self.conv1p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, ) self.convt1 = CausalConvTranspose2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.convt_kernel_size_inner, bias=False, 
separable=True, fstride=2, ) self.conv0p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, ) self.conv0_out = CausalConv2d( in_channels=config.conv_channels, out_channels=1, kernel_size=config.conv_kernel_size_inner, activation_layer="sigmoid", bias=False, separable=True, fstride=1, ) def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> torch.Tensor: if cache_dict is None: cache_dict = defaultdict(lambda: None) cache0 = cache_dict["cache0"] cache1 = cache_dict["cache1"] cache2 = cache_dict["cache2"] cache3 = cache_dict["cache3"] cache4 = cache_dict["cache4"] # Estimates erb mask b, _, t, f8 = e3.shape # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] emb, new_cache0 = self.emb_gru.forward(emb, hx=cache0) # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) e3, new_cache1 = self.convt3.forward(self.conv3p(e3)[0] + emb, cache=cache1) # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] e2, new_cache2 = self.convt2.forward(self.conv2p(e2)[0] + e3, cache=cache2) # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] e1, new_cache3 = self.convt1.forward(self.conv1p(e1)[0] + e2, cache=cache3) # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] mask, new_cache4 = self.conv0_out.forward(self.conv0p(e0)[0] + e1, cache=cache4) # mask shape: [batch_size, 1, time_steps, freq_dim] new_cache_dict = { "cache0": new_cache0, "cache1": new_cache1, "cache2": new_cache2, "cache3": new_cache3, "cache4": new_cache4, } return mask, new_cache_dict class DfDecoder(nn.Module): def __init__(self, config: DfNet2Config): super(DfDecoder, self).__init__() self.embedding_input_size = config.conv_channels * config.erb_bins // 4 self.df_decoder_hidden_size = config.df_decoder_hidden_size self.df_num_layers = config.df_num_layers self.df_order = config.df_order self.df_bins = config.df_bins self.df_out_ch = config.df_order * 2 self.df_convp = CausalConv2d( config.conv_channels, self.df_out_ch, fstride=1, kernel_size=(config.df_pathway_kernel_size_t, 1), separable=True, bias=False, ) self.df_gru = SqueezedGRU_S( self.embedding_input_size, self.df_decoder_hidden_size, num_layers=self.df_num_layers, batch_first=True, skip_op="none", activation_layer="relu", ) if config.df_gru_skip == "none": self.df_skip = None elif config.df_gru_skip == "identity": if config.embedding_hidden_size != config.df_decoder_hidden_size: raise AssertionError("Dimensions do not match") self.df_skip = nn.Identity() elif config.df_gru_skip == "grouped_linear": self.df_skip = GroupedLinear( self.embedding_input_size, self.df_decoder_hidden_size, groups=config.df_decoder_linear_groups ) else: raise NotImplementedError() self.df_out: nn.Module out_dim = self.df_bins * self.df_out_ch self.df_out = nn.Sequential( GroupedLinear( input_size=self.df_decoder_hidden_size, hidden_size=out_dim, groups=config.df_decoder_linear_groups, # groups = self.df_bins // 5, ), nn.Tanh() ) self.df_fc_a = nn.Sequential( nn.Linear(self.df_decoder_hidden_size, 1), nn.Sigmoid() ) def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> torch.Tensor: if cache_dict is None: cache_dict = defaultdict(lambda: None) cache0 = cache_dict["cache0"] cache1 = cache_dict["cache1"] # emb shape: [batch_size, time_steps, df_bins // 4 * channels] b, t, _ = emb.shape df_coefs, new_cache0 = self.df_gru.forward(emb, hx=cache0) if self.df_skip is not 
None: df_coefs = df_coefs + self.df_skip(emb) # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] # c0 shape: [batch_size, channels, time_steps, df_bins] c0, new_cache1 = self.df_convp.forward(c0, cache=cache1) # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] c0 = c0.permute(0, 2, 3, 1) # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] df_coefs = df_coefs + c0 # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] new_cache_dict = { "cache0": new_cache0, "cache1": new_cache1, } return df_coefs, new_cache_dict class DfOutputReshapeMF(nn.Module): """Coefficients output reshape for multiframe/MultiFrameModule Requires input of shape B, C, T, F, 2. """ def __init__(self, df_order: int, df_bins: int): super().__init__() self.df_order = df_order self.df_bins = df_bins def forward(self, coefs: torch.Tensor) -> torch.Tensor: # [B, T, F, O*2] -> [B, O, T, F, 2] new_shape = list(coefs.shape) new_shape[-1] = -1 new_shape.append(2) coefs = coefs.view(new_shape) coefs = coefs.permute(0, 3, 1, 2, 4) return coefs class Mask(nn.Module): def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): super().__init__() self.use_post_filter = use_post_filter self.eps = eps def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: """ Post-Filter A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. https://arxiv.org/abs/2008.04259 :param mask: Real valued mask, typically of shape [B, C, T, F]. :param beta: Global gain factor. :return: """ mask_sin = mask * torch.sin(np.pi * mask / 2) mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) return mask_pf def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # spec shape: [b, 1, t, spec_bins, 2] if not self.training and self.use_post_filter: mask = self.post_filter(mask) # mask shape: [b, 1, t, spec_bins] mask = mask.unsqueeze(4) # mask shape: [b, 1, t, spec_bins, 1] return spec * mask class DeepFiltering(nn.Module): def __init__(self, df_bins: int, df_order: int, lookahead: int = 0, ): super(DeepFiltering, self).__init__() self.df_bins = df_bins self.df_order = df_order self.lookahead = lookahead self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) def forward(self, *args, **kwargs): raise AssertionError("use `forward_offline` or `forward_online` stead.") def spec_unfold_offline(self, spec: torch.Tensor) -> torch.Tensor: """ Pads and unfolds the spectrogram according to frame_size. 
:param spec: shape: [b, c, t, f], dtype: torch.complex64 :return: shape: [b, c, t, f, df_order] """ if self.df_order <= 1: return spec.unsqueeze(-1) # spec shape: [b, 1, t, f], dtype: torch.complex64 spec = self.pad(spec) # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 spec_unfold = spec.unfold(dimension=2, size=self.df_order, step=1) # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 return spec_unfold def forward_offline(self, spec: torch.Tensor, coefs: torch.Tensor, ): # spec shape: [b, 1, t, spec_bins, 2] spec_c = torch.view_as_complex(spec.contiguous()) # spec_c shape: [b, 1, t, spec_bins] spec_u = self.spec_unfold_offline(spec_c) # spec_u shape: [b, 1, t, spec_bins, df_order] spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) # spec_f shape: [b, 1, t, df_bins, df_order] # coefs shape: [b, df_order, t, df_bins, 2] coefs = torch.view_as_complex(coefs.contiguous()) # coefs shape: [b, df_order, t, df_bins] coefs = coefs.unsqueeze(dim=1) # coefs shape: [b, 1, df_order, t, df_bins] spec_f = self.df_offline(spec_f, coefs) # spec_f shape: [b, 1, t, df_bins] spec_f = torch.view_as_real(spec_f) # spec_f shape: [b, 1, t, df_bins, 2] return spec_f def df_offline(self, spec: torch.Tensor, coefs: torch.Tensor): """ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. :param spec: [b, 1, t, df_bins, df_order] complex. :param coefs: [b, 1, df_order, t, df_bins] complex. :return: [b, 1, t, df_bins] complex. """ spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) return spec_f def spec_unfold_online(self, spec: torch.Tensor, cache_spec: torch.Tensor = None): """ Pads and unfolds the spectrogram according to frame_size. :param spec: shape: [b, c, t, f], dtype: torch.complex64 :param cache_spec: shape: [b, c, df_order-1, f], dtype: torch.complex64 :return: shape: [b, c, t, f, df_order] """ if self.df_order <= 1: return spec.unsqueeze(-1) if cache_spec is None: b, c, _, f = spec.shape cache_spec = spec.new_zeros(size=(b, c, self.df_order-1, f)) spec_pad = torch.concat(tensors=[ cache_spec, spec ], dim=2) new_cache_spec = spec_pad[:, :, -(self.df_order-1):, :] # spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64 spec_unfold = spec_pad.unfold(dimension=2, size=self.df_order, step=1) # spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64 return spec_unfold, new_cache_spec def forward_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_dict: dict = None, ): if cache_dict is None: cache_dict = defaultdict(lambda: None) cache0 = cache_dict["cache0"] cache1 = cache_dict["cache1"] # spec shape: [b, 1, t, spec_bins, 2] spec_c = torch.view_as_complex(spec.contiguous()) # spec_c shape: [b, 1, t, spec_bins] spec_u, new_cache0 = self.spec_unfold_online(spec_c, cache_spec=cache0) # spec_u shape: [b, 1, t, spec_bins, df_order] spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins) # spec_f shape: [b, 1, t, df_bins, df_order] # coefs shape: [b, df_order, t, df_bins, 2] coefs = torch.view_as_complex(coefs.contiguous()) # coefs shape: [b, df_order, t, df_bins] coefs = coefs.unsqueeze(dim=1) # coefs shape: [b, 1, df_order, t, df_bins] spec_f, new_cache1 = self.df_online(spec_f, coefs, cache_coefs=cache1) # spec_f shape: [b, 1, t, df_bins] spec_f = torch.view_as_real(spec_f) # spec_f shape: [b, 1, t, df_bins, 2] new_cache_dict = { "cache0": new_cache0, "cache1": new_cache1, } return spec_f, new_cache_dict def df_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_coefs: torch.Tensor = None) -> 
torch.Tensor: """ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. :param spec: [b, 1, 1, df_bins, df_order] complex. :param coefs: [b, 1, df_order, 1, df_bins] complex. :param cache_coefs: [b, 1, df_order, lookahead, df_bins] complex. :return: [b, 1, 1, df_bins] complex. """ if cache_coefs is None: b, c, _, _, f = coefs.shape cache_coefs = coefs.new_zeros(size=(b, c, self.df_order, self.lookahead, f)) coefs_pad = torch.concat(tensors=[ cache_coefs, coefs ], dim=3) # coefs_pad shape: [b, 1, df_order, 1+lookahead, df_bins], torch.complex64. coefs = coefs_pad[:, :, :, :-self.lookahead, :] # coefs shape: [b, 1, df_order, 1, df_bins], torch.complex64. new_cache_coefs = coefs_pad[:, :, :, -self.lookahead:, :] # new_cache_coefs shape: [b, 1, df_order, lookahead, df_bins], torch.complex64. spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs) return spec_f, new_cache_coefs class DfNet2(nn.Module): def __init__(self, config: DfNet2Config): super(DfNet2, self).__init__() self.config = config self.eps = 1e-12 self.freq_bins = self.config.nfft // 2 + 1 self.nfft = config.nfft self.win_size = config.win_size self.hop_size = config.hop_size self.win_type = config.win_type self.stft = ConvSTFT( nfft=config.nfft, win_size=config.win_size, hop_size=config.hop_size, win_type=config.win_type, power=None, requires_grad=False ) self.istft = ConviSTFT( nfft=config.nfft, win_size=config.win_size, hop_size=config.hop_size, win_type=config.win_type, requires_grad=False ) self.erb_bands = ErbBands( sample_rate=config.sample_rate, nfft=config.nfft, erb_bins=config.erb_bins, min_freq_bins_for_erb=config.min_freq_bins_for_erb, ) self.erb_ema = ErbEMA( sample_rate=config.sample_rate, hop_size=config.hop_size, erb_bins=config.erb_bins, ) self.spec_ema = SpecEMA( sample_rate=config.sample_rate, hop_size=config.hop_size, df_bins=config.df_bins, ) self.encoder = Encoder(config) self.erb_decoder = ErbDecoder(config) self.df_decoder = DfDecoder(config) self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) self.df_op = DeepFiltering( df_bins=config.df_bins, df_order=config.df_order, lookahead=config.df_lookahead, ) self.mask = Mask(use_post_filter=config.use_post_filter) self.lsnr_fn = LocalSnrTarget( sample_rate=config.sample_rate, nfft=config.nfft, win_size=config.win_size, hop_size=config.hop_size, n_frame=config.n_frame, min_local_snr=config.min_local_snr, max_local_snr=config.max_local_snr, db=True, ) def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor: if signal.dim() == 2: signal = torch.unsqueeze(signal, dim=1) _, _, n_samples = signal.shape remainder = (n_samples - self.win_size) % self.hop_size if remainder > 0: n_samples_pad = self.hop_size - remainder signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0) return signal def feature_prepare(self, signal: torch.Tensor): # noisy shape: [b, num_samples_pad] spec_cmp = self.stft.forward(signal) # spec_complex shape: [b, f, t], torch.complex64 spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2) # spec_complex shape: [b, t, f], torch.complex64 spec_cmp_real = torch.view_as_real(spec_cmp) # spec_cmp_real shape: [b, t, f, 2] spec_mag = torch.abs(spec_cmp) spec_pow = torch.square(spec_mag) # shape: [b, t, f] spec = torch.unsqueeze(spec_cmp_real, dim=1) # spec shape: [b, 1, t, f, 2] feat_erb = self.erb_bands.erb_scale(spec_pow, db=True) # feat_erb shape: [b, t, erb_bins] feat_erb = torch.unsqueeze(feat_erb, dim=1) # feat_erb shape: [b, 1, t, erb_bins] feat_spec = spec_cmp_real.permute(0, 
3, 1, 2) # feat_spec shape: [b, 2, t, f] feat_spec = feat_spec[..., :self.df_decoder.df_bins] # feat_spec shape: [b, 2, t, df_bins] spec = spec.detach() feat_erb = feat_erb.detach() feat_spec = feat_spec.detach() return spec, feat_erb, feat_spec def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None): if cache_dict is None: cache_dict = defaultdict(lambda: None) cache0 = cache_dict["cache0"] cache1 = cache_dict["cache1"] feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0) feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1) new_cache_dict = { "cache0": new_cache0, "cache1": new_cache1, } feat_erb = feat_erb.detach() feat_spec = feat_spec.detach() return feat_erb, feat_spec, new_cache_dict def forward(self, noisy: torch.Tensor, ): """ :param noisy: :return: est_spec: shape: [b, 257*2, t] est_wav: shape: [b, num_samples] est_mask: shape: [b, 257, t] lsnr: shape: [b, 1, t] """ n_samples = noisy.shape[-1] noisy = self.signal_prepare(noisy) spec, feat_erb, feat_spec = self.feature_prepare(noisy) if self.config.use_ema_norm: feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec) e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec) mask, _ = self.erb_decoder.forward(emb, e3, e2, e1, e0) # mask shape: [b, 1, t, erb_bins] mask = self.erb_bands.erb_scale_inv(mask) # mask shape: [b, 1, t, f] if torch.any(mask > 1) or torch.any(mask < 0): raise AssertionError spec_m = self.mask.forward(spec, mask) # spec_m shape: [b, 1, t, f, 2] spec_m = spec_m[:, :, :, :self.config.spec_bins, :] # spec_m shape: [b, 1, t, spec_bins, 2] # lsnr shape: [b, t, 1] lsnr = torch.transpose(lsnr, dim0=2, dim1=1) # lsnr shape: [b, 1, t] df_coefs, _ = self.df_decoder.forward(emb, c0) df_coefs = self.df_out_transform(df_coefs) # df_coefs shape: [b, df_order, t, df_bins, 2] spec_ = spec[:, :, :, :self.config.spec_bins, :] # spec shape: [b, 1, t, spec_bins, 2] spec_f = self.df_op.forward_offline(spec_, df_coefs) # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 spec_e = torch.concat(tensors=[ spec_f, spec_m[..., self.df_decoder.df_bins:, :] ], dim=3) spec_e = torch.squeeze(spec_e, dim=1) spec_e = spec_e.permute(0, 2, 1, 3) # spec_e shape: [b, spec_bins, t, 2] # spec_e shape: [b, spec_bins, t, 2] est_spec = torch.view_as_complex(spec_e.contiguous()) # est_spec shape: [b, spec_bins, t], torch.complex64 est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) # est_spec shape: [b, f, t], torch.complex64 est_wav = self.istft.forward(est_spec) est_wav = est_wav[:, :, :n_samples] # est_wav shape: [b, 1, n_samples] est_mask = torch.squeeze(mask, dim=1) est_mask = est_mask.permute(0, 2, 1) # est_mask shape: [b, f, t] return est_spec, est_wav, est_mask, lsnr def forward_chunk_by_chunk(self, noisy: torch.Tensor, ): noisy = self.signal_prepare(noisy) b, _, _ = noisy.shape noisy = torch.concat(tensors=[ noisy, noisy.new_zeros(size=(b, 1, (self.config.df_lookahead+1)*self.hop_size)) ], dim=2) b, _, num_samples = noisy.shape t = (num_samples - self.win_size) // self.hop_size + 1 cache_dict0 = None cache_dict1 = None cache_dict2 = None cache_dict3 = None cache_dict4 = None cache_dict5 = None cache_dict6 = None waveform_list = list() for i in range(int(t)): begin = i * self.hop_size end = begin + self.win_size sub_noisy = noisy[:, :, begin: end] spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy) # spec shape: [b, 1, t, f, 2] # feat_erb shape: [b, 1, t, erb_bins] # feat_spec shape: [b, 2, t, df_bins] if self.config.use_ema_norm: feat_erb, feat_spec, 
cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0) e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1) mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2) # mask shape: [b, 1, t, erb_bins] mask = self.erb_bands.erb_scale_inv(mask) # mask shape: [b, 1, t, f] spec_m = self.mask.forward(spec, mask) # spec_m shape: [b, 1, t, f, 2] spec_m = spec_m[:, :, :, :self.config.spec_bins, :] # spec_m shape: [b, 1, t, spec_bins, 2] # lsnr shape: [b, t, 1] lsnr = torch.transpose(lsnr, dim0=2, dim1=1) # lsnr shape: [b, 1, t] df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3) df_coefs = self.df_out_transform(df_coefs) # df_coefs shape: [b, df_order, t, df_bins, 2] spec_ = spec[:, :, :, :self.config.spec_bins, :] # spec shape: [b, 1, t, spec_bins, 2] spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4) # spec_f shape: [b, 1, t, df_bins, 2], torch.float32 spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5) spec_e = torch.squeeze(spec_e, dim=1) spec_e = spec_e.permute(0, 2, 1, 3) # spec_e shape: [b, spec_bins, t, 2] # spec_e shape: [b, spec_bins, t, 2] est_spec = torch.view_as_complex(spec_e.contiguous()) # est_spec shape: [b, spec_bins, t], torch.complex64 est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1) # est_spec shape: [b, f, t], torch.complex64 est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6) # est_wav shape: [b, 1, hop_size] waveform_list.append(est_wav) waveform = torch.concat(tensors=waveform_list, dim=-1) # waveform shape: [b, 1, n] return waveform def spec_e_m_combine_online(self, spec_f: torch.Tensor, spec_m: torch.Tensor, cache_dict: dict = None): """ :param spec_f: shape: [b, 1, t, df_bins, 2], torch.float32 :param spec_m: shape: [b, 1, t, spec_bins, 2] :param cache_dict: :return: """ if cache_dict is None: cache_dict = defaultdict(lambda: None) cache_spec_m = cache_dict["cache_spec_m"] if cache_spec_m is None: b, c, t, f, _ = spec_m.shape cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2)) # cache0 shape: [b, 1, lookahead, f, 2] spec_m_cat = torch.concat(tensors=[ cache_spec_m, spec_m, ], dim=2) spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :] new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :] spec_e = torch.concat(tensors=[ spec_f, spec_m[..., self.df_decoder.df_bins:, :] ], dim=3) new_cache_dict = { "cache_spec_m": new_cache_spec_m, } return spec_e, new_cache_dict def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): """ :param est_mask: torch.Tensor, shape: [b, 257, t] :param clean: :param noisy: :return: """ if noisy.shape != clean.shape: raise AssertionError("Input signals must have the same shape") noise = noisy - clean clean = self.signal_prepare(clean) noise = self.signal_prepare(noise) stft_clean = self.stft.forward(clean) mag_clean = torch.abs(stft_clean) stft_noise = self.stft.forward(noise) mag_noise = torch.abs(stft_noise) gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1) loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean") return loss def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): if noisy.shape != clean.shape: raise AssertionError("Input signals must have the same shape") noise = noisy - clean clean = self.signal_prepare(clean) noise = 
self.signal_prepare(noise) stft_clean = self.stft.forward(clean) stft_noise = self.stft.forward(noise) # shape: [b, f, t] stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2) stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2) # shape: [b, t, f] stft_clean = torch.unsqueeze(stft_clean, dim=1) stft_noise = torch.unsqueeze(stft_noise, dim=1) # shape: [b, 1, t, f] # lsnr shape: [b, 1, t] lsnr = lsnr.squeeze(1) # lsnr shape: [b, t] lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise) # lsnr_gth shape: [b, t] loss = F.mse_loss(lsnr, lsnr_gth) return loss class DfNet2PretrainedModel(DfNet2): def __init__(self, config: DfNet2Config, ): super(DfNet2PretrainedModel, self).__init__( config=config, ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): config = DfNet2Config.from_pretrained(pretrained_model_name_or_path, **kwargs) model = cls(config) if os.path.isdir(pretrained_model_name_or_path): ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) else: ckpt_file = pretrained_model_name_or_path with open(ckpt_file, "rb") as f: state_dict = torch.load(f, map_location="cpu", weights_only=True) model.load_state_dict(state_dict, strict=True) return model def save_pretrained(self, save_directory: Union[str, os.PathLike], state_dict: Optional[dict] = None, ): model = self if state_dict is None: state_dict = model.state_dict() os.makedirs(save_directory, exist_ok=True) # save state dict model_file = os.path.join(save_directory, MODEL_FILE) torch.save(state_dict, model_file) # save config config_file = os.path.join(save_directory, CONFIG_FILE) self.config.to_yaml_file(config_file) return save_directory def main(): import time # torch.set_num_threads(1) config = DfNet2Config() model = DfNet2PretrainedModel(config=config) model.eval() num_samples = 16000 noisy = torch.randn(size=(1, num_samples), dtype=torch.float32) duration = num_samples / config.sample_rate begin = time.time() est_spec, est_wav, est_mask, lsnr = model.forward(noisy) time_cost = time.time() - begin print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") # print(f"est_spec.shape: {est_spec.shape}") # print(f"est_wav.shape: {est_wav.shape}") # print(f"est_mask.shape: {est_mask.shape}") # print(f"lsnr.shape: {lsnr.shape}") waveform = est_wav print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") print(waveform[:, :, 300: 302]) print(waveform[:, :, 15680: 15682]) print(waveform[:, :, 15760: 15762]) print(waveform[:, :, 15840: 15842]) begin = time.time() waveform = model.forward_chunk_by_chunk(noisy) time_cost = time.time() - begin print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}") waveform = waveform[:, :, (config.df_lookahead*config.hop_size):] print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}") print(waveform[:, :, 300: 302]) print(waveform[:, :, 15680: 15682]) print(waveform[:, :, 15760: 15762]) print(waveform[:, :, 15840: 15842]) return if __name__ == "__main__": main()
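

# --------------------------------------------------------------------------- #
# The sketches below are illustrative additions, not part of the original
# module, and are not invoked by main().  They are minimal, hedged examples of
# how the streaming building blocks above behave; all sizes are assumptions.
# --------------------------------------------------------------------------- #


def check_causal_conv_streaming():
    """
    Minimal sketch: CausalConv2d should produce the same output whether the
    whole sequence is processed at once or frame by frame with the `cache`
    carried between calls.  BatchNorm must be in eval mode for the comparison
    to be exact.
    """
    torch.manual_seed(0)
    conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3), fstride=1)
    conv.eval()

    b, c, t, f = 1, 1, 10, 32
    x = torch.randn(size=(b, c, t, f))

    with torch.no_grad():
        # offline: one call over the full sequence (zero padding in time).
        y_offline, _ = conv.forward(x)

        # online: one frame at a time, carrying the lookback cache.
        cache = None
        frames = list()
        for i in range(t):
            yi, cache = conv.forward(x[:, :, i:i+1, :], cache=cache)
            frames.append(yi)
        y_online = torch.concat(tensors=frames, dim=2)

    print(torch.allclose(y_offline, y_online, atol=1e-6))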
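

def check_grouped_linear():
    """
    Minimal sketch: GroupedLinear is equivalent to a single linear map with a
    block-diagonal weight matrix; group g maps the g-th input slice to the
    g-th output slice.  Sizes below are arbitrary assumptions.
    """
    torch.manual_seed(0)
    gl = GroupedLinear(input_size=8, hidden_size=12, groups=4)

    x = torch.randn(size=(2, 5, 8))
    # x shape: [b, t, input_size]

    with torch.no_grad():
        y = gl.forward(x)
        # y shape: [b, t, hidden_size]

        # equivalent dense weight: block diagonal of the per-group weights.
        w_block = torch.block_diag(*[w for w in gl.weight])
        # w_block shape: [input_size, hidden_size]
        y_ref = torch.matmul(x, w_block)

    print(torch.allclose(y, y_ref, atol=1e-6))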
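

def check_deep_filtering_offline():
    """
    Minimal sketch of what the deep-filtering einsum computes with lookahead=0:
    each enhanced bin is a complex FIR filter over the current frame and the
    previous (df_order - 1) frames,
        out[t, f] = sum_n coefs[n, t, f] * spec[t - df_order + 1 + n, f].
    Sizes below are arbitrary assumptions.
    """
    torch.manual_seed(0)
    df_bins, df_order, t = 6, 3, 5
    df_op = DeepFiltering(df_bins=df_bins, df_order=df_order, lookahead=0)

    spec = torch.randn(size=(1, 1, t, df_bins, 2))
    # spec shape: [b, 1, t, spec_bins, 2]
    coefs = torch.randn(size=(1, df_order, t, df_bins, 2))
    # coefs shape: [b, df_order, t, df_bins, 2]

    out = df_op.forward_offline(spec, coefs)
    out_c = torch.view_as_complex(out.contiguous())[0, 0]
    # out_c shape: [t, df_bins]

    # reference: explicit causal FIR filtering over the past df_order frames.
    spec_c = torch.view_as_complex(spec.contiguous())[0, 0]
    coefs_c = torch.view_as_complex(coefs.contiguous())[0]
    ref = torch.zeros_like(spec_c)
    for ti in range(t):
        for n in range(df_order):
            tau = ti - df_order + 1 + n
            if tau >= 0:
                ref[ti] += coefs_c[n, ti] * spec_c[tau]

    print(torch.allclose(out_c, ref, atol=1e-6))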
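

def check_save_and_load_round_trip():
    """
    Minimal sketch of a save / reload round trip with DfNet2PretrainedModel.
    It assumes DfNet2Config.from_pretrained can read the config file written
    by save_pretrained from the same directory; the temporary directory is
    only for illustration.
    """
    import tempfile

    config = DfNet2Config()
    model = DfNet2PretrainedModel(config=config)
    model.eval()

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    with tempfile.TemporaryDirectory() as save_directory:
        model.save_pretrained(save_directory)
        reloaded = DfNet2PretrainedModel.from_pretrained(save_directory)
        reloaded.eval()

        with torch.no_grad():
            _, est_wav1, _, _ = model.forward(noisy)
            _, est_wav2, _, _ = reloaded.forward(noisy)

    print(torch.allclose(est_wav1, est_wav2, atol=1e-6))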