#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
DfNet: a DeepFilterNet-style speech enhancement model.

An ERB-band mask decoder handles the full spectrum, and a deep-filtering
decoder refines the first `df_bins` STFT bins.

The original DeepFilterNet implementation does not directly support streaming
inference. A community-contributed, Torch-based streaming implementation is
available in this fork:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
"""
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands


MODEL_FILE = "model.pt"


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}


activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


def _append_norm_and_activation(layers: List[nn.Module],
                                out_channels: int,
                                norm_layer: Optional[str],
                                activation_layer: Optional[str]) -> None:
    """Append the optional normalization layer, then the optional activation layer, to `layers`."""
    if norm_layer is not None:
        layers.append(norm_layer_dict[norm_layer](out_channels))
    if activation_layer is not None:
        layers.append(activation_layer_dict[activation_layer]())


class CausalConv2d(nn.Sequential):
    """
    Causal Conv2d by delaying the signal for any lookahead.

    Layer stack, in order:
    1. ConstantPad2d on the time axis (only if any padding is needed).
    2. Conv2d.
    3. 1x1 pointwise Conv2d (separable only).
    4. Norm (optional).
    5. Activation (optional).

    Expected input format: [batch_size, channels, time_steps, spec_dim]
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 lookahead: int = 0
                 ):
        """
        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: int, or (time, freq) kernel size.
        :param fstride: stride over the frequency axis (the time stride is always 1).
        :param dilation: dilation over the frequency axis (the time dilation is always 1).
        :param fpad: if True, zero-pad both sides of the frequency axis by
            `kernel_size[1] // 2 + dilation - 1`; otherwise no frequency padding.
        :param bias: whether the main convolution has a bias.
        :param separable: request a grouped conv (groups = gcd(in, out)) followed by a
            1x1 pointwise conv. The pointwise conv is only added when groups > 1 and
            the kernel is larger than 1x1.
        :param norm_layer: key into `norm_layer_dict`, or None for no normalization.
        :param activation_layer: key into `activation_layer_dict`, or None for no activation.
        :param lookahead: number of future frames each output frame may see.
        """
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        fpad_ = kernel_size[1] // 2 + dilation - 1 if fpad else 0

        # ConstantPad2d pads (left, right, top, bottom) of the last two dims:
        # nothing on frequency, `k_t - 1 - lookahead` past frames and `lookahead` future frames on time.
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = list()
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        # NOTE: with a 1x1 kernel the conv stays grouped (groups = gcd) but no pointwise
        # conv is appended; kept as-is for checkpoint compatibility.
        use_pointwise = separable and groups != 1 and max(kernel_size) != 1

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )
        if use_pointwise:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )
        _append_norm_and_activation(layers, out_channels, norm_layer, activation_layer)

        super().__init__(*layers)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        for module in self:
            inputs = module(inputs)
        return inputs


class CausalConvTranspose2d(nn.Sequential):
    """
    Causal ConvTranspose2d.

    Layer stack, in order:
    1. ConstantPad2d on the time axis (only if any padding is needed).
    2. ConvTranspose2d.
    3. 1x1 pointwise Conv2d (separable only).
    4. Norm (optional).
    5. Activation (optional).

    Expected input format: [batch_size, channels, time_steps, spec_dim]
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 lookahead: int = 0
                 ):
        """
        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: int, or (time, freq) kernel size.
        :param fstride: frequency upsampling stride (the time stride is always 1).
        :param dilation: dilation over the frequency axis (the time dilation is always 1).
        :param fpad: if True, use `kernel_size[1] // 2` as the frequency padding and
            output padding; otherwise 0.
        :param bias: whether the transposed convolution has a bias.
        :param separable: request a grouped transposed conv (groups = gcd(in, out))
            followed by a 1x1 pointwise conv. The pointwise conv is only added when groups > 1.
        :param norm_layer: key into `norm_layer_dict`, or None for no normalization.
        :param activation_layer: key into `activation_layer_dict`, or None for no activation.
        :param lookahead: number of future frames each output frame may see.
        """
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        fpad_ = kernel_size[1] // 2 if fpad else 0

        # ConstantPad2d pads (left, right, top, bottom) of the last two dims.
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        use_pointwise = separable and groups != 1

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                # time padding `k_t - 1` removes the frames added by the explicit pad above
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )
        if use_pointwise:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )
        _append_norm_and_activation(layers, out_channels, norm_layer, activation_layer)

        super().__init__(*layers)


class GroupedLinear(nn.Module):
    """
    Bias-free, block-diagonal linear layer.

    The last dimension is split into `groups` chunks, and each chunk is multiplied
    by its own [input_size // groups, hidden_size // groups] weight matrix.

    Input [b, t, input_size] -> output [b, t, hidden_size].
    """

    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        # explicit raises instead of `assert` so the checks survive `python -O`
        if input_size % groups != 0:
            raise AssertionError(f"Input size {input_size} not divisible by {groups}")
        if hidden_size % groups != 0:
            raise AssertionError(f"Hidden size {hidden_size} not divisible by {groups}")
        # per-group input width
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        :param x: shape [b, t, input_size].
        :return: shape [b, t, hidden_size].
        :raises AssertionError: if the last dimension is not `input_size`.
        """
        b, t, f = x.shape
        if f != self.input_size:
            raise AssertionError(f"Expected last dim {self.input_size}, got {f}")
        # [b, t, I] -> [b, t, G, I/G]
        # (`x.unflatten(-1, (groups, ws))` would read better but is not supported by TorchScript)
        x = x.reshape(b, t, self.groups, self.ws)
        # [b, t, G, I/G] x [G, I/G, H/G] -> [b, t, G, H/G]
        x = torch.einsum("btgi,gih->btgh", x, self.weight)
        # [b, t, G, H/G] -> [b, t, H]
        return x.flatten(2, 3)

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"


class SqueezedGRU_S(nn.Module):
    """
    GroupedLinear (squeeze) -> GRU -> optional GroupedLinear (expand), with an optional skip path.

    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """

    def __init__(
        self,
        input_size: int,
        hidden_size: int,
        output_size: Optional[int] = None,
        num_layers: int = 1,
        linear_groups: int = 8,
        batch_first: bool = True,
        skip_op: str = "none",
        activation_layer: str = "identity",
    ):
        """
        :param input_size: width of the input features.
        :param hidden_size: GRU hidden size.
        :param output_size: width of the output features. If None, the output
            projection is an identity and the output width is `hidden_size`.
        :param num_layers: number of stacked GRU layers.
        :param linear_groups: groups of the GroupedLinear projections.
        :param batch_first: passed through to `nn.GRU`.
        :param skip_op: "none", "identity" (requires input width == output width),
            or "grouped_linear" (a learned input -> output projection).
        :param activation_layer: key into `activation_layer_dict`, applied after the linear projections.
        """
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # The skip path is added to the *output*, so it must map input_size -> output width.
        out_dim = hidden_size if output_size is None else output_size

        # gru skip operator
        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            if input_size != out_dim:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=input_size,
                hidden_size=out_dim,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
            bidirectional=False,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()

    def forward(self, inputs: torch.Tensor, h: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        :param inputs: shape [b, t, input_size] (batch_first=True).
        :param h: optional initial GRU hidden state, [num_layers, b, hidden_size].
        :return: (output [b, t, out_dim], final GRU hidden state [num_layers, b, hidden_size]).
        """
        x = self.linear_in(inputs)
        x, h = self.gru(x, h)
        x = self.linear_out(x)
        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)
        return x, h


class Add(nn.Module):
    """Combine two tensors by element-wise addition."""

    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    """Combine two tensors by concatenation along the last dimension."""

    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)


class Encoder(nn.Module):
    """
    Shared encoder.

    Has an ERB-feature conv stack, a complex-spectrum (df) conv stack, a GRU
    embedding over their combination, and a local-SNR head.
    """

    def __init__(self, config: DfNetConfig):
        super(Encoder, self).__init__()
        self.embedding_input_size = config.conv_channels * config.erb_bins // 4
        self.embedding_output_size = config.conv_channels * config.erb_bins // 4
        self.embedding_hidden_size = config.embedding_hidden_size

        self.spec_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.spec_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

        # NOTE(review): the df convs do not receive `conv_lookahead`, unlike the ERB convs
        # above; with conv_lookahead > 0 the two branches are time-shifted differently.
        # Left unchanged to preserve trained behavior — confirm this is intended.
        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
            fstride=1,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        # Projects the flattened df features down to the ERB embedding width.
        # Must be built before `embedding_input_size` is doubled for "concat" below.
        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.embedding_input_size,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_combine_op == "concat":
            self.embedding_input_size *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        # emb_gru
        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")
        self.emb_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.embedding_hidden_size,
            output_size=self.embedding_output_size,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_emb_skip_op,
            linear_groups=config.encoder_emb_linear_groups,
            activation_layer="relu",
        )

        # lsnr: sigmoid output rescaled into [min_local_snr, max_local_snr]
        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.embedding_output_size, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.max_local_snr - config.min_local_snr
        self.lsnr_offset = config.min_local_snr

    def forward(self,
                feat_erb: torch.Tensor,
                feat_spec: torch.Tensor,
                hidden_state: Optional[torch.Tensor] = None,
                ):
        """
        :param feat_erb: ERB features, shape [b, 1, t, erb_bins].
        :param feat_spec: real/imag parts of the first df_bins STFT bins, shape [b, 2, t, df_bins].
        :param hidden_state: optional initial hidden state of `emb_gru`, [1, b, embedding_hidden_size].
        :return: (e0, e1, e2, e3, emb, c0, lsnr, h), where c = conv_channels:
            e0: [b, c, t, erb_bins]
            e1: [b, c, t, erb_bins // 2]
            e2: [b, c, t, erb_bins // 4]
            e3: [b, c, t, erb_bins // 4]
            emb: [b, t, erb_bins // 4 * c]
            c0: [b, c, t, df_bins]
            lsnr: [b, t, 1]
            h: final hidden state of `emb_gru`, [1, b, embedding_hidden_size]
        """
        # ERB branch
        e0 = self.spec_conv0(feat_erb)
        e1 = self.spec_conv1(e0)
        e2 = self.spec_conv2(e1)
        e3 = self.spec_conv3(e2)

        # df (complex spectrum) branch
        c0 = self.df_conv0(feat_spec)
        c1 = self.df_conv1(c0)
        # c1 shape: [b, c, t, df_bins // 2]

        cemb = c1.permute(0, 2, 3, 1).flatten(2)
        # cemb shape: [b, t, df_bins // 2 * c]
        cemb = self.df_fc_emb(cemb)
        # cemb shape: [b, t, erb_bins // 4 * c]

        emb = e3.permute(0, 2, 3, 1).flatten(2)
        # emb shape: [b, t, erb_bins // 4 * c]

        emb = self.combine(emb, cemb)
        # concat: [b, t, erb_bins // 4 * c * 2]; add: [b, t, erb_bins // 4 * c]

        emb, h = self.emb_gru(emb, hidden_state)
        # emb shape: [b, t, erb_bins // 4 * c]
        # h shape (nn.GRU h_n): [num_layers=1, b, embedding_hidden_size]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        # lsnr shape: [b, t, 1]
        return e0, e1, e2, e3, emb, c0, lsnr, h


class Decoder(nn.Module):
    """ErbDecoder: estimates a [0, 1] mask over ERB bands via a U-Net-style decoder."""

    def __init__(self, config: DfNetConfig):
        super(Decoder, self).__init__()
        if config.spec_bins % 8 != 0:
            raise AssertionError("spec_bins should be divisible by 8")

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.decoder_emb_hidden_size

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.decoder_emb_skip_op,
            linear_groups=config.decoder_emb_linear_groups,
            activation_layer="relu",
        )

        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
            lookahead=config.conv_lookahead,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
            fstride=1,
            lookahead=config.conv_lookahead,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        """
        Estimate the ERB mask from the encoder outputs (skip connections e0..e3).

        :return: mask, shape [b, 1, t, erb_bins], values in [0, 1] (sigmoid output).
        """
        b, _, t, f8 = e3.shape

        # emb shape: [b, t, erb_bins // 4 * conv_channels]
        emb, _ = self.emb_gru(emb)
        # -> [b, conv_channels, t, erb_bins // 4]
        emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)

        e3 = self.convt3(self.conv3p(e3) + emb)
        # e3 shape: [b, conv_channels, t, erb_bins // 4]
        e2 = self.convt2(self.conv2p(e2) + e3)
        # e2 shape: [b, conv_channels, t, erb_bins // 2]
        e1 = self.convt1(self.conv1p(e1) + e2)
        # e1 shape: [b, conv_channels, t, erb_bins]
        mask = self.conv0_out(self.conv0p(e0) + e1)
        # mask shape: [b, 1, t, erb_bins]
        return mask


class DfDecoder(nn.Module):
    """Predicts deep-filter coefficients for the first `df_bins` frequency bins."""

    def __init__(self, config: DfNetConfig):
        super(DfDecoder, self).__init__()

        self.embedding_input_size = config.conv_channels * config.erb_bins // 4
        self.df_decoder_hidden_size = config.df_decoder_hidden_size
        self.df_num_layers = config.df_num_layers

        self.df_order = config.df_order
        self.df_bins = config.df_bins
        # one real + one imag coefficient per filter tap
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            config.conv_channels,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.embedding_input_size,
            self.df_decoder_hidden_size,
            num_layers=self.df_num_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        # The skip adds `emb` (width embedding_input_size) to the GRU output
        # (width df_decoder_hidden_size), so those two widths must agree for "identity".
        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            if self.embedding_input_size != self.df_decoder_hidden_size:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(
                self.embedding_input_size,
                self.df_decoder_hidden_size,
                groups=config.df_decoder_linear_groups
            )
        else:
            raise NotImplementedError()

        out_dim = self.df_bins * self.df_out_ch
        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_decoder_hidden_size,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups,
            ),
            nn.Tanh()
        )
        # NOTE(review): not used in `forward`; kept so existing checkpoints still load with strict=True.
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_decoder_hidden_size, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        """
        :param emb: encoder embedding, shape [b, t, erb_bins // 4 * conv_channels].
        :param c0: first df conv output of the encoder, shape [b, conv_channels, t, df_bins].
        :return: df coefficients, shape [b, t, df_bins, df_order * 2].
        """
        b, t, _ = emb.shape
        df_coefs, _ = self.df_gru(emb)
        if self.df_skip is not None:
            df_coefs = df_coefs + self.df_skip(emb)
        # df_coefs shape: [b, t, df_decoder_hidden_size]

        c0 = self.df_convp(c0)
        # c0 shape: [b, df_order * 2, t, df_bins]
        c0 = c0.permute(0, 2, 3, 1)
        # c0 shape: [b, t, df_bins, df_order * 2]

        df_coefs = self.df_out(df_coefs)
        # df_coefs shape: [b, t, df_bins * df_order * 2]
        df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
        # df_coefs shape: [b, t, df_bins, df_order * 2]
        df_coefs = df_coefs + c0
        return df_coefs


class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule.

    [B, T, F, O*2] -> [B, O, T, F, 2]
    """

    def __init__(self, df_order: int, df_bins: int):
        super().__init__()
        self.df_order = df_order
        self.df_bins = df_bins

    def forward(self, coefs: torch.Tensor) -> torch.Tensor:
        # [B, T, F, O*2] -> [B, T, F, O, 2] -> [B, O, T, F, 2]
        new_shape = list(coefs.shape)
        new_shape[-1] = -1
        new_shape.append(2)
        coefs = coefs.view(new_shape)
        coefs = coefs.permute(0, 3, 1, 2, 4)
        return coefs


class Mask(nn.Module):
    """Applies a real-valued mask to a real/imag spectrum, with an optional eval-time post-filter."""

    def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.use_post_filter = use_post_filter
        self.eps = eps

    def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return: post-filtered mask, same shape as `mask`.
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        :param spec: shape [b, 1, t, f, 2] (real/imag in the last dim).
        :param mask: shape [b, 1, t, f].
        :return: masked spectrum, shape [b, 1, t, f, 2].
        """
        if not self.training and self.use_post_filter:
            mask = self.post_filter(mask)

        # broadcast the mask over the real/imag dim
        mask = mask.unsqueeze(4)
        return spec * mask


class DeepFiltering(nn.Module):
    """Applies a complex multi-frame (df_order taps over time) filter to the first `df_bins` bins."""

    def __init__(self,
                 df_bins: int,
                 df_order: int,
                 lookahead: int = 0,
                 ):
        super(DeepFiltering, self).__init__()
        self.df_bins = df_bins
        self.df_order = df_order
        self.need_unfold = df_order > 1
        self.lookahead = lookahead

        # time-axis pad: `df_order - 1 - lookahead` past frames, `lookahead` future frames
        self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)

    def spec_unfold(self, spec: torch.Tensor):
        """
        Pads and unfolds the spectrogram according to frame_size.

        :param spec: complex Tensor, Spectrogram of shape [B, C, T, F].
        :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size (df_order).
        """
        if not self.need_unfold:
            return spec.unsqueeze(-1)
        # [B, C, T, F] -> [B, C, T_pad, F]
        spec_pad = self.pad(spec)
        # sliding windows of df_order frames over time -> [B, C, T, F, df_order]
        return spec_pad.unfold(2, self.df_order, 1)

    def forward(self,
                spec: torch.Tensor,
                coefs: torch.Tensor,
                ):
        """
        :param spec: shape [b, 1, t, spec_bins, 2] (real/imag).
        :param coefs: shape [b, df_order, t, df_bins, 2] (real/imag).
        :return: spectrum with the first df_bins bins replaced by the filtered
            values, shape [b, 1, t, spec_bins, 2].

        NOTE: outside training mode, `spec` is written in place (no clone).
        """
        spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous()))
        # spec_u shape: [b, 1, t, spec_bins, df_order]

        coefs = torch.view_as_complex(coefs.contiguous())
        # coefs shape: [b, df_order, t, df_bins]

        spec_f = spec_u.narrow(-2, 0, self.df_bins)
        # spec_f shape: [b, 1, t, df_bins, df_order]

        coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:])
        # coefs shape: [b, 1, df_order, t, df_bins]

        spec_f = self.df(spec_f, coefs)
        # spec_f shape: [b, 1, t, df_bins]

        if self.training:
            spec = spec.clone()
        spec[..., :self.df_bins, :] = torch.view_as_real(spec_f)
        return spec

    @staticmethod
    def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """
        Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.

        :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N].
        :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F].
        :return: (complex Tensor). Spectrogram of shape [B, C, T, F].
        """
        return torch.einsum("...tfn,...ntf->...tf", spec, coefs)


class DfNet(nn.Module):
    """
    Full DfNet model: STFT -> encoder -> (ERB mask decoder, deep-filter decoder) -> iSTFT.

    Note from the original author: this model probably cannot achieve streaming
    inference that exactly matches offline inference.
    """

    def __init__(self, config: DfNetConfig):
        super(DfNet, self).__init__()
        self.config = config
        self.eps = 1e-12

        self.freq_bins = self.config.nfft // 2 + 1

        self.nfft = config.nfft
        self.win_size = config.win_size
        self.hop_size = config.hop_size
        self.win_type = config.win_type

        self.erb_bands = ErbBands(
            sample_rate=config.sample_rate,
            nfft=config.nfft,
            erb_bins=config.erb_bins,
            min_freq_bins_for_erb=config.min_freq_bins_for_erb,
        )

        self.stft = ConvSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            power=None,
            requires_grad=False
        )
        self.istft = ConviSTFT(
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            win_type=config.win_type,
            requires_grad=False
        )

        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

        self.df_decoder = DfDecoder(config)
        self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
        self.df_op = DeepFiltering(
            df_bins=config.df_bins,
            df_order=config.df_order,
            lookahead=config.df_lookahead,
        )

        self.mask = Mask(use_post_filter=config.use_post_filter)

        self.lsnr_fn = LocalSnrTarget(
            sample_rate=config.sample_rate,
            nfft=config.nfft,
            win_size=config.win_size,
            hop_size=config.hop_size,
            n_frame=config.n_frame,
            min_local_snr=config.min_local_snr,
            max_local_snr=config.max_local_snr,
            db=True,
        )

    def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
        """
        Add a channel dim to [b, n] input and zero-pad the end so that
        `n_samples - win_size` is a multiple of `hop_size`.

        NOTE(review): when n_samples < win_size the padded signal can still be
        shorter than win_size — confirm how ConvSTFT handles that.

        :param signal: shape [b, n] or [b, 1, n].
        :return: shape [b, 1, n_pad].
        """
        if signal.dim() == 2:
            signal = torch.unsqueeze(signal, dim=1)
        _, _, n_samples = signal.shape
        remainder = (n_samples - self.win_size) % self.hop_size
        if remainder > 0:
            n_samples_pad = self.hop_size - remainder
            signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
        return signal

    def feature_prepare(self, signal: torch.Tensor):
        """
        Compute the STFT and the model input features.

        :param signal: padded waveform, as returned by `signal_prepare`.
        :return: (spec [b, 1, t, f, 2], feat_erb [b, 1, t, erb_bins], feat_spec [b, 2, t, df_bins]).
        """
        spec_cmp = self.stft(signal)
        # spec_cmp shape: [b, f, t], complex
        spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
        # spec_cmp shape: [b, t, f], complex
        spec_cmp_real = torch.view_as_real(spec_cmp)
        # spec_cmp_real shape: [b, t, f, 2]
        spec_pow = torch.square(torch.abs(spec_cmp))
        # spec_pow shape: [b, t, f]

        spec = torch.unsqueeze(spec_cmp_real, dim=1)
        # spec shape: [b, 1, t, f, 2]

        feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
        # feat_erb shape: [b, t, erb_bins]
        feat_erb = torch.unsqueeze(feat_erb, dim=1)
        # feat_erb shape: [b, 1, t, erb_bins]

        feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
        # feat_spec shape: [b, 2, t, f]
        feat_spec = feat_spec[..., :self.df_decoder.df_bins]
        # feat_spec shape: [b, 2, t, df_bins]

        return spec, feat_erb, feat_spec

    def forward(self,
                noisy: torch.Tensor,
                ):
        """
        :param noisy: waveform, shape [b, num_samples] (or [b, 1, num_samples]).
        :return: (est_spec, est_wav, est_mask, lsnr)
            est_spec: complex spectrum, shape [b, spec_bins + 1, t] (last bin duplicated).
            est_wav: shape [b, 1, num_samples].
            est_mask: shape [b, f, t], f = number of bins returned by `erb_scale_inv`.
            lsnr: shape [b, 1, t].
        :raises AssertionError: if the expanded mask leaves [0, 1].
        """
        n_samples = noisy.shape[-1]
        noisy = self.signal_prepare(noisy)
        spec, feat_erb, feat_spec = self.feature_prepare(noisy)

        e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder(feat_erb, feat_spec)

        mask = self.decoder(emb, e3, e2, e1, e0)
        # mask shape: [b, 1, t, erb_bins]
        mask = self.erb_bands.erb_scale_inv(mask)
        # mask shape: [b, 1, t, f]
        if torch.any(mask > 1) or torch.any(mask < 0):
            raise AssertionError

        spec_m = self.mask(spec, mask)
        # spec_m shape: [b, 1, t, f, 2]
        spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
        # spec_m shape: [b, 1, t, spec_bins, 2]

        # lsnr shape: [b, t, 1] -> [b, 1, t]
        lsnr = torch.transpose(lsnr, dim0=2, dim1=1)

        df_coefs = self.df_decoder(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        # df_coefs shape: [b, df_order, t, df_bins, 2]

        # NOTE: `spec_` is a view of `spec`; in eval mode `df_op` writes into it in place.
        spec_ = spec[:, :, :, :self.config.spec_bins, :]
        # spec_ shape: [b, 1, t, spec_bins, 2]
        spec_e = self.df_op(spec_, df_coefs)
        # spec_e shape: [b, 1, t, spec_bins, 2]

        # deep filtering on the low bins, ERB mask on the rest
        spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :]

        spec_e = torch.squeeze(spec_e, dim=1)
        spec_e = spec_e.permute(0, 2, 1, 3)
        # spec_e shape: [b, spec_bins, t, 2]

        est_spec = torch.complex(real=spec_e[..., 0], imag=spec_e[..., 1])
        # est_spec shape: [b, spec_bins, t], complex
        est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
        # est_spec shape: [b, spec_bins + 1, t], complex

        est_wav = self.istft(est_spec)
        est_wav = est_wav[:, :, :n_samples]
        # est_wav shape: [b, 1, n_samples]

        est_mask = torch.squeeze(mask, dim=1)
        est_mask = est_mask.permute(0, 2, 1)
        # est_mask shape: [b, f, t]

        return est_spec, est_wav, est_mask, lsnr

    def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        L1 loss between `est_mask` and the ideal ratio mask |S| / (|S| + |N| + eps).

        :param est_mask: shape [b, f, t], as returned by `forward`.
        :param clean: clean waveform, same shape as `noisy`.
        :param noisy: noisy waveform.
        :return: scalar loss.
        :raises AssertionError: if `clean` and `noisy` shapes differ.
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        noise = noisy - clean

        clean = self.signal_prepare(clean)
        noise = self.signal_prepare(noise)

        mag_clean = torch.abs(self.stft(clean))
        mag_noise = torch.abs(self.stft(noise))

        gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)

        loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
        return loss

    def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
        """
        MSE loss between the predicted local SNR and the `LocalSnrTarget` ground truth.

        :param lsnr: shape [b, 1, t], as returned by `forward`.
        :param clean: clean waveform, same shape as `noisy`.
        :param noisy: noisy waveform.
        :return: scalar loss.
        :raises AssertionError: if `clean` and `noisy` shapes differ.
        """
        if noisy.shape != clean.shape:
            raise AssertionError("Input signals must have the same shape")
        noise = noisy - clean

        clean = self.signal_prepare(clean)
        noise = self.signal_prepare(noise)

        stft_clean = self.stft(clean)
        stft_noise = self.stft(noise)
        # shape: [b, f, t] -> [b, t, f] -> [b, 1, t, f]
        stft_clean = torch.unsqueeze(torch.transpose(stft_clean, dim0=1, dim1=2), dim=1)
        stft_noise = torch.unsqueeze(torch.transpose(stft_noise, dim0=1, dim1=2), dim=1)

        # lsnr shape: [b, 1, t] -> [b, t]
        lsnr = lsnr.squeeze(1)

        lsnr_gth = self.lsnr_fn(stft_clean, stft_noise)
        # lsnr_gth shape: [b, t]

        loss = F.mse_loss(lsnr, lsnr_gth)
        return loss


class DfNetPretrainedModel(DfNet):
    """DfNet with `from_pretrained` / `save_pretrained` helpers (state dict + YAML config)."""

    def __init__(self,
                 config: DfNetConfig,
                 ):
        super(DfNetPretrainedModel, self).__init__(
            config=config,
        )

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build the model from a config and load its weights.

        :param pretrained_model_name_or_path: a directory containing MODEL_FILE
            (and the config), or a path to the checkpoint file itself.
        :param kwargs: forwarded to `DfNetConfig.from_pretrained`.
        :return: the model with weights loaded (strict=True).
        """
        config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """
        Write MODEL_FILE (state dict) and CONFIG_FILE (YAML config) into `save_directory`.

        :param save_directory: output directory (created if missing).
        :param state_dict: state dict to save; defaults to `self.state_dict()`.
        :return: `save_directory`.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    config = DfNetConfig()
    model = DfNetPretrainedModel(config=config)

    noisy = torch.randn(size=(1, 16000), dtype=torch.float32)

    est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
    print(f"est_spec.shape: {est_spec.shape}")
    print(f"est_wav.shape: {est_wav.shape}")
    print(f"est_mask.shape: {est_mask.shape}")
    print(f"lsnr.shape: {lsnr.shape}")

    return


if __name__ == "__main__":
    main()