#!/usr/bin/python3 # -*- coding: utf-8 -*- import os import math from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union import numpy as np import torch import torch.nn as nn from torch.nn import functional as F import torchaudio from toolbox.torchaudio.configuration_utils import CONFIG_FILE from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig from toolbox.torchaudio.models.dfnet.conv_stft import ConvSTFT, ConviSTFT MODEL_FILE = "model.pt" norm_layer_dict = { "batch_norm_2d": torch.nn.BatchNorm2d } activation_layer_dict = { "relu": torch.nn.ReLU, "identity": torch.nn.Identity, "sigmoid": torch.nn.Sigmoid, } class CausalConv2d(nn.Sequential): def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Iterable[int]], fstride: int = 1, dilation: int = 1, fpad: bool = True, bias: bool = True, separable: bool = False, norm_layer: str = "batch_norm_2d", activation_layer: str = "relu", lookahead: int = 0 ): """ Causal Conv2d by delaying the signal for any lookahead. Expected input format: [batch_size, channels, time_steps, spec_dim] :param in_channels: :param out_channels: :param kernel_size: :param fstride: :param dilation: :param fpad: """ super(CausalConv2d, self).__init__() kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size) if fpad: fpad_ = kernel_size[1] // 2 + dilation - 1 else: fpad_ = 0 # for last 2 dim, pad (left, right, top, bottom). pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) layers = list() if any(x > 0 for x in pad): layers.append(nn.ConstantPad2d(pad, 0.0)) groups = math.gcd(in_channels, out_channels) if separable else 1 if groups == 1: separable = False if max(kernel_size) == 1: separable = False layers.append( nn.Conv2d( in_channels, out_channels, kernel_size=kernel_size, padding=(0, fpad_), stride=(1, fstride), # stride over time is always 1 dilation=(1, dilation), # dilation over time is always 1 groups=groups, bias=bias, ) ) if separable: layers.append( nn.Conv2d( out_channels, out_channels, kernel_size=1, bias=False, ) ) if norm_layer is not None: norm_layer = norm_layer_dict[norm_layer] layers.append(norm_layer(out_channels)) if activation_layer is not None: activation_layer = activation_layer_dict[activation_layer] layers.append(activation_layer()) super().__init__(*layers) def forward(self, inputs): for module in self: inputs = module(inputs) return inputs class CausalConvTranspose2d(nn.Sequential): def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Iterable[int]], fstride: int = 1, dilation: int = 1, fpad: bool = True, bias: bool = True, separable: bool = False, norm_layer: str = "batch_norm_2d", activation_layer: str = "relu", lookahead: int = 0 ): """ Causal ConvTranspose2d. Expected input format: [batch_size, channels, time_steps, spec_dim] """ super(CausalConvTranspose2d, self).__init__() kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size if fpad: fpad_ = kernel_size[1] // 2 else: fpad_ = 0 # for last 2 dim, pad (left, right, top, bottom). pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead) layers = [] if any(x > 0 for x in pad): layers.append(nn.ConstantPad2d(pad, 0.0)) groups = math.gcd(in_channels, out_channels) if separable else 1 if groups == 1: separable = False layers.append( nn.ConvTranspose2d( in_channels, out_channels, kernel_size=kernel_size, padding=(kernel_size[0] - 1, fpad_ + dilation - 1), output_padding=(0, fpad_), stride=(1, fstride), # stride over time is always 1 dilation=(1, dilation), # dilation over time is always 1 groups=groups, bias=bias, ) ) if separable: layers.append( nn.Conv2d( out_channels, out_channels, kernel_size=1, bias=False, ) ) if norm_layer is not None: norm_layer = norm_layer_dict[norm_layer] layers.append(norm_layer(out_channels)) if activation_layer is not None: activation_layer = activation_layer_dict[activation_layer] layers.append(activation_layer()) super().__init__(*layers) class GroupedLinear(nn.Module): def __init__(self, input_size: int, hidden_size: int, groups: int = 1): super().__init__() # self.weight: Tensor self.input_size = input_size self.hidden_size = hidden_size self.groups = groups assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}" assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}" self.ws = input_size // groups self.register_parameter( "weight", torch.nn.Parameter( torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True ), ) self.reset_parameters() def reset_parameters(self): nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore def forward(self, x: torch.Tensor) -> torch.Tensor: # x: [..., I] b, t, _ = x.shape # new_shape = list(x.shape)[:-1] + [self.groups, self.ws] new_shape = (b, t, self.groups, self.ws) x = x.view(new_shape) # The better way, but not supported by torchscript # x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G] x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G] x = x.flatten(2, 3) # [B, T, H] return x def __repr__(self): cls = self.__class__.__name__ return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})" class SqueezedGRU_S(nn.Module): """ SGE net: Video object detection with squeezed GRU and information entropy map https://arxiv.org/abs/2106.07224 """ def __init__( self, input_size: int, hidden_size: int, output_size: Optional[int] = None, num_layers: int = 1, linear_groups: int = 8, batch_first: bool = True, skip_op: str = "none", activation_layer: str = "identity", ): super().__init__() self.input_size = input_size self.hidden_size = hidden_size self.linear_in = nn.Sequential( GroupedLinear( input_size=input_size, hidden_size=hidden_size, groups=linear_groups, ), activation_layer_dict[activation_layer](), ) # gru skip operator self.gru_skip_op = None if skip_op == "none": self.gru_skip_op = None elif skip_op == "identity": if not input_size != output_size: raise AssertionError("Dimensions do not match") self.gru_skip_op = nn.Identity() elif skip_op == "grouped_linear": self.gru_skip_op = GroupedLinear( input_size=hidden_size, hidden_size=hidden_size, groups=linear_groups, ) else: raise NotImplementedError() self.gru = nn.GRU( input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=batch_first, bidirectional=False, ) if output_size is not None: self.linear_out = nn.Sequential( GroupedLinear( input_size=hidden_size, hidden_size=output_size, groups=linear_groups, ), activation_layer_dict[activation_layer](), ) else: self.linear_out = nn.Identity() def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]: x = self.linear_in(inputs) x, h = self.gru.forward(x, h) x = self.linear_out(x) if self.gru_skip_op is not None: x = x + self.gru_skip_op(inputs) return x, h class Add(nn.Module): def forward(self, a, b): return a + b class Concat(nn.Module): def forward(self, a, b): return torch.cat((a, b), dim=-1) class Encoder(nn.Module): def __init__(self, config: DfNetConfig): super(Encoder, self).__init__() self.embedding_input_size = config.conv_channels * config.spec_bins // 4 self.embedding_output_size = config.conv_channels * config.spec_bins // 4 self.embedding_hidden_size = config.embedding_hidden_size self.spec_conv0 = CausalConv2d( in_channels=1, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_input, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.spec_conv1 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, lookahead=config.conv_lookahead, ) self.spec_conv2 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, lookahead=config.conv_lookahead, ) self.spec_conv3 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.df_conv0 = CausalConv2d( in_channels=2, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_input, bias=False, separable=True, fstride=1, ) self.df_conv1 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=2, ) self.df_fc_emb = nn.Sequential( GroupedLinear( config.conv_channels * config.df_bins // 2, self.embedding_input_size, groups=config.encoder_linear_groups ), nn.ReLU(inplace=True) ) if config.encoder_combine_op == "concat": self.embedding_input_size *= 2 self.combine = Concat() else: self.combine = Add() # emb_gru if config.spec_bins % 8 != 0: raise AssertionError("spec_bins should be divisible by 8") self.emb_gru = SqueezedGRU_S( self.embedding_input_size, self.embedding_hidden_size, output_size=self.embedding_output_size, num_layers=1, batch_first=True, skip_op=config.encoder_emb_skip_op, linear_groups=config.encoder_emb_linear_groups, activation_layer="relu", ) # lsnr self.lsnr_fc = nn.Sequential( nn.Linear(self.embedding_output_size, 1), nn.Sigmoid() ) self.lsnr_scale = config.lsnr_max - config.lsnr_min self.lsnr_offset = config.lsnr_min def forward(self, feat_power: torch.Tensor, feat_spec: torch.Tensor, hidden_state: torch.Tensor = None, ): # feat_power shape: (batch_size, 1, time_steps, spec_dim) e0 = self.spec_conv0.forward(feat_power) e1 = self.spec_conv1.forward(e0) e2 = self.spec_conv2.forward(e1) e3 = self.spec_conv3.forward(e2) # e0 shape: [batch_size, channels, time_steps, spec_dim] # e1 shape: [batch_size, channels, time_steps, spec_dim // 2] # e2 shape: [batch_size, channels, time_steps, spec_dim // 4] # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] # feat_spec, shape: (batch_size, 2, time_steps, df_bins) c0 = self.df_conv0(feat_spec) c1 = self.df_conv1(c0) # c0 shape: [batch_size, channels, time_steps, df_bins] # c1 shape: [batch_size, channels, time_steps, df_bins // 2] cemb = c1.permute(0, 2, 3, 1) # cemb shape: [batch_size, time_steps, df_bins // 2, channels] cemb = cemb.flatten(2) # cemb shape: [batch_size, time_steps, df_bins // 2 * channels] cemb = self.df_fc_emb(cemb) # cemb shape: [batch_size, time_steps, spec_dim // 4 * channels] # e3 shape: [batch_size, channels, time_steps, spec_dim // 4] emb = e3.permute(0, 2, 3, 1) # emb shape: [batch_size, time_steps, spec_dim // 4, channels] emb = emb.flatten(2) # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] emb = self.combine(emb, cemb) # if concat; emb shape: [batch_size, time_steps, spec_dim // 4 * channels * 2] # if add; emb shape: [batch_size, time_steps, spec_dim // 4 * channels] emb, h = self.emb_gru.forward(emb, hidden_state) # emb shape: [batch_size, time_steps, spec_dim // 4 * channels] # h shape: [batch_size, 1, spec_dim] lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset # lsnr shape: [batch_size, time_steps, 1] return e0, e1, e2, e3, emb, c0, lsnr, h class Decoder(nn.Module): def __init__(self, config: DfNetConfig): super(Decoder, self).__init__() if config.spec_bins % 8 != 0: raise AssertionError("spec_bins should be divisible by 8") self.emb_in_dim = config.conv_channels * config.spec_bins // 4 self.emb_out_dim = config.conv_channels * config.spec_bins // 4 self.emb_hidden_dim = config.decoder_emb_hidden_size self.emb_gru = SqueezedGRU_S( self.emb_in_dim, self.emb_hidden_dim, output_size=self.emb_out_dim, num_layers=config.decoder_emb_num_layers - 1, batch_first=True, skip_op=config.decoder_emb_skip_op, linear_groups=config.decoder_emb_linear_groups, activation_layer="relu", ) self.conv3p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.convt3 = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.conv_kernel_size_inner, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.conv2p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.convt2 = CausalConvTranspose2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.convt_kernel_size_inner, bias=False, separable=True, fstride=2, lookahead=config.conv_lookahead, ) self.conv1p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.convt1 = CausalConvTranspose2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=config.convt_kernel_size_inner, bias=False, separable=True, fstride=2, lookahead=config.conv_lookahead, ) self.conv0p = CausalConv2d( in_channels=config.conv_channels, out_channels=config.conv_channels, kernel_size=1, bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) self.conv0_out = CausalConv2d( in_channels=config.conv_channels, out_channels=1, kernel_size=config.conv_kernel_size_inner, activation_layer="sigmoid", bias=False, separable=True, fstride=1, lookahead=config.conv_lookahead, ) def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor: # Estimates erb mask b, _, t, f8 = e3.shape # emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels] emb, _ = self.emb_gru(emb) # emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4] emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2) e3 = self.convt3(self.conv3p(e3) + emb) # e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4] e2 = self.convt2(self.conv2p(e2) + e3) # e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2] e1 = self.convt1(self.conv1p(e1) + e2) # e1 shape: [batch_size, conv_channels, time_steps, freq_dim] mask = self.conv0_out(self.conv0p(e0) + e1) # mask shape: [batch_size, 1, time_steps, freq_dim] return mask class DfDecoder(nn.Module): def __init__(self, config: DfNetConfig): super(DfDecoder, self).__init__() self.embedding_input_size = config.conv_channels * config.spec_bins // 4 self.df_decoder_hidden_size = config.df_decoder_hidden_size self.df_num_layers = config.df_num_layers self.df_order = config.df_order self.df_bins = config.df_bins self.df_out_ch = config.df_order * 2 self.df_convp = CausalConv2d( config.conv_channels, self.df_out_ch, fstride=1, kernel_size=(config.df_pathway_kernel_size_t, 1), separable=True, bias=False, ) self.df_gru = SqueezedGRU_S( self.embedding_input_size, self.df_decoder_hidden_size, num_layers=self.df_num_layers, batch_first=True, skip_op="none", activation_layer="relu", ) if config.df_gru_skip == "none": self.df_skip = None elif config.df_gru_skip == "identity": if config.embedding_hidden_size != config.df_decoder_hidden_size: raise AssertionError("Dimensions do not match") self.df_skip = nn.Identity() elif config.df_gru_skip == "grouped_linear": self.df_skip = GroupedLinear( self.embedding_input_size, self.df_decoder_hidden_size, groups=config.df_decoder_linear_groups ) else: raise NotImplementedError() self.df_out: nn.Module out_dim = self.df_bins * self.df_out_ch self.df_out = nn.Sequential( GroupedLinear( input_size=self.df_decoder_hidden_size, hidden_size=out_dim, groups=config.df_decoder_linear_groups ), nn.Tanh() ) self.df_fc_a = nn.Sequential( nn.Linear(self.df_decoder_hidden_size, 1), nn.Sigmoid() ) def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor: # emb shape: [batch_size, time_steps, df_bins // 4 * channels] b, t, _ = emb.shape df_coefs, _ = self.df_gru(emb) if self.df_skip is not None: df_coefs = df_coefs + self.df_skip(emb) # df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size] # c0 shape: [batch_size, channels, time_steps, df_bins] c0 = self.df_convp(c0) # c0 shape: [batch_size, df_order * 2, time_steps, df_bins] c0 = c0.permute(0, 2, 3, 1) # c0 shape: [batch_size, time_steps, df_bins, df_order * 2] df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order # df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2] df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch) # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] df_coefs = df_coefs + c0 # df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2] return df_coefs class DfOutputReshapeMF(nn.Module): """Coefficients output reshape for multiframe/MultiFrameModule Requires input of shape B, C, T, F, 2. """ def __init__(self, df_order: int, df_bins: int): super().__init__() self.df_order = df_order self.df_bins = df_bins def forward(self, coefs: torch.Tensor) -> torch.Tensor: # [B, T, F, O*2] -> [B, O, T, F, 2] new_shape = list(coefs.shape) new_shape[-1] = -1 new_shape.append(2) coefs = coefs.view(new_shape) coefs = coefs.permute(0, 3, 1, 2, 4) return coefs class Mask(nn.Module): def __init__(self, use_post_filter: bool = False, eps: float = 1e-12): super().__init__() self.use_post_filter = use_post_filter self.eps = eps def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor: """ Post-Filter A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech. https://arxiv.org/abs/2008.04259 :param mask: Real valued mask, typically of shape [B, C, T, F]. :param beta: Global gain factor. :return: """ mask_sin = mask * torch.sin(np.pi * mask / 2) mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2)) return mask_pf def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: # spec shape: [batch_size, 1, time_steps, spec_bins, 2] if not self.training and self.use_post_filter: mask = self.post_filter(mask) # mask shape: [batch_size, 1, time_steps, spec_bins] mask = mask.unsqueeze(4) # mask shape: [batch_size, 1, time_steps, spec_bins, 1] return spec * mask class DeepFiltering(nn.Module): def __init__(self, df_bins: int, df_order: int, lookahead: int = 0, ): super(DeepFiltering, self).__init__() self.df_bins = df_bins self.df_order = df_order self.need_unfold = df_order > 1 self.lookahead = lookahead self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0) def spec_unfold(self, spec: torch.Tensor): """ Pads and unfolds the spectrogram according to frame_size. :param spec: complex Tensor, Spectrogram of shape [B, C, T, F]. :return: Tensor, Unfolded spectrogram of shape [B, C, T, F, N], where N: frame_size. """ if self.need_unfold: # spec shape: [batch_size, spec_bins, time_steps] spec_pad = self.pad(spec) # spec_pad shape: [batch_size, 1, time_steps_pad, spec_bins] spec_unfold = spec_pad.unfold(2, self.df_order, 1) # spec_unfold shape: [batch_size, 1, time_steps, spec_bins, df_order] return spec_unfold else: return spec.unsqueeze(-1) def forward(self, spec: torch.Tensor, coefs: torch.Tensor, ): # spec shape: [batch_size, 1, time_steps, spec_bins, 2] spec_u = self.spec_unfold(torch.view_as_complex(spec.contiguous())) # spec_u shape: [batch_size, 1, time_steps, spec_bins, df_order] # coefs shape: [batch_size, df_order, time_steps, df_bins, 2] coefs = torch.view_as_complex(coefs.contiguous()) # coefs shape: [batch_size, df_order, time_steps, df_bins] spec_f = spec_u.narrow(-2, 0, self.df_bins) # spec_f shape: [batch_size, 1, time_steps, df_bins, df_order] coefs = coefs.view(coefs.shape[0], -1, self.df_order, *coefs.shape[2:]) # coefs shape: [batch_size, 1, df_order, time_steps, df_bins] spec_f = self.df(spec_f, coefs) # spec_f shape: [batch_size, 1, time_steps, df_bins] if self.training: spec = spec.clone() spec[..., :self.df_bins, :] = torch.view_as_real(spec_f) # spec shape: [batch_size, 1, time_steps, spec_bins, 2] return spec @staticmethod def df(spec: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor: """ Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram. :param spec: (complex Tensor). Spectrogram of shape [B, C, T, F, N]. :param coefs: (complex Tensor). Coefficients of shape [B, C, N, T, F]. :return: (complex Tensor). Spectrogram of shape [B, C, T, F]. """ return torch.einsum("...tfn,...ntf->...tf", spec, coefs) class DfNet(nn.Module): def __init__(self, config: DfNetConfig): super(DfNet, self).__init__() self.config = config self.freq_bins = self.config.nfft // 2 + 1 self.nfft = config.nfft self.win_size = config.win_size self.hop_size = config.hop_size self.win_type = config.win_type self.stft = ConvSTFT( nfft=config.nfft, win_size=config.win_size, hop_size=config.hop_size, win_type=config.win_type, feature_type="complex", requires_grad=False ) self.istft = ConviSTFT( nfft=config.nfft, win_size=config.win_size, hop_size=config.hop_size, win_type=config.win_type, feature_type="complex", requires_grad=False ) self.encoder = Encoder(config) self.decoder = Decoder(config) self.df_decoder = DfDecoder(config) self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins) self.df_op = DeepFiltering( df_bins=config.df_bins, df_order=config.df_order, lookahead=config.df_lookahead, ) self.mask = Mask(use_post_filter=config.use_post_filter) def forward(self, noisy: torch.Tensor, ): if noisy.dim() == 2: noisy = torch.unsqueeze(noisy, dim=1) _, _, n_samples = noisy.shape remainder = (n_samples - self.win_size) % self.hop_size if remainder > 0: n_samples_pad = self.hop_size - remainder noisy = F.pad(noisy, pad=(0, n_samples_pad), mode="constant", value=0) # [batch_size, freq_bins * 2, time_steps] cmp_spec = self.stft.forward(noisy) # [batch_size, 1, freq_bins * 2, time_steps] cmp_spec = torch.unsqueeze(cmp_spec, 1) # [batch_size, 2, freq_bins, time_steps] cmp_spec = torch.cat([ cmp_spec[:, :, :self.freq_bins, :], cmp_spec[:, :, self.freq_bins:, :], ], dim=1) # n//2+1 -> n//2; 257 -> 256 cmp_spec = cmp_spec[:, :, :-1, :] spec = torch.unsqueeze(cmp_spec, dim=4) # [batch_size, 2, freq_bins, time_steps, 1] spec = spec.permute(0, 4, 3, 2, 1) # spec shape: [batch_size, 1, time_steps, freq_bins, 2] feat_power = torch.sum(torch.square(spec), dim=-1) # feat_power shape: [batch_size, 1, time_steps, spec_bins] feat_spec = torch.transpose(cmp_spec, dim0=2, dim1=3) # feat_spec shape: [batch_size, 2, time_steps, freq_bins] feat_spec = feat_spec[..., :self.df_decoder.df_bins] # feat_spec shape: [batch_size, 2, time_steps, df_bins] e0, e1, e2, e3, emb, c0, lsnr, h = self.encoder.forward(feat_power, feat_spec) mask = self.decoder.forward(emb, e3, e2, e1, e0) # mask shape: [batch_size, 1, time_steps, spec_bins] if torch.any(mask > 1) or torch.any(mask < 0): raise AssertionError spec_m = self.mask.forward(spec, mask) # lsnr shape: [batch_size, time_steps, 1] lsnr = torch.transpose(lsnr, dim0=2, dim1=1) # lsnr shape: [batch_size, 1, time_steps] df_coefs = self.df_decoder.forward(emb, c0) df_coefs = self.df_out_transform(df_coefs) # df_coefs shape: [batch_size, df_order, time_steps, df_bins, 2] spec_e = self.df_op.forward(spec.clone(), df_coefs) # est_spec shape: [batch_size, 1, time_steps, spec_bins, 2] spec_e[..., self.df_decoder.df_bins:, :] = spec_m[..., self.df_decoder.df_bins:, :] spec_e = torch.squeeze(spec_e, dim=1) spec_e = spec_e.permute(0, 2, 1, 3) # spec_e shape: [batch_size, spec_bins, time_steps, 2] mask = torch.squeeze(mask, dim=1) est_mask = mask.permute(0, 2, 1) # mask shape: [batch_size, spec_bins, time_steps] b, _, t, _ = spec_e.shape est_spec = torch.cat(tensors=[ torch.concat(tensors=[ spec_e[..., 0], torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device) ], dim=1), torch.concat(tensors=[ spec_e[..., 1], torch.zeros(size=(b, 1, t), dtype=spec_e.dtype).to(spec_e.device) ], dim=1), ], dim=1) # est_spec shape: [b, n+2, t] est_wav = self.istft.forward(est_spec) est_wav = torch.squeeze(est_wav, dim=1) est_wav = est_wav[:, :n_samples] # est_wav shape: [b, n_samples] return est_spec, est_wav, est_mask, lsnr def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor): """ :param est_mask: torch.Tensor, shape: [b, n+2, t] :param clean: :param noisy: :return: """ clean_stft = self.stft(clean) clean_re = clean_stft[:, :self.freq_bins, :] clean_im = clean_stft[:, self.freq_bins:, :] noisy_stft = self.stft(noisy) noisy_re = noisy_stft[:, :self.freq_bins, :] noisy_im = noisy_stft[:, self.freq_bins:, :] noisy_power = noisy_re ** 2 + noisy_im ** 2 sr = clean_re yr = noisy_re si = clean_im yi = noisy_im y_pow = noisy_power # (Sr * Yr + Si * Yi) / (Y_pow + 1e-8) gth_mask_re = (sr * yr + si * yi) / (y_pow + self.eps) # (Si * Yr - Sr * Yi) / (Y_pow + 1e-8) gth_mask_im = (sr * yr - si * yi) / (y_pow + self.eps) gth_mask_re[gth_mask_re > 2] = 1 gth_mask_re[gth_mask_re < -2] = -1 gth_mask_im[gth_mask_im > 2] = 1 gth_mask_im[gth_mask_im < -2] = -1 mask_re = est_mask[:, :self.freq_bins, :] mask_im = est_mask[:, self.freq_bins:, :] loss_re = F.mse_loss(gth_mask_re, mask_re) loss_im = F.mse_loss(gth_mask_im, mask_im) loss = loss_re + loss_im return loss class DfNetPretrainedModel(DfNet): def __init__(self, config: DfNetConfig, ): super(DfNetPretrainedModel, self).__init__( config=config, ) @classmethod def from_pretrained(cls, pretrained_model_name_or_path, **kwargs): config = DfNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs) model = cls(config) if os.path.isdir(pretrained_model_name_or_path): ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE) else: ckpt_file = pretrained_model_name_or_path with open(ckpt_file, "rb") as f: state_dict = torch.load(f, map_location="cpu", weights_only=True) model.load_state_dict(state_dict, strict=True) return model def save_pretrained(self, save_directory: Union[str, os.PathLike], state_dict: Optional[dict] = None, ): model = self if state_dict is None: state_dict = model.state_dict() os.makedirs(save_directory, exist_ok=True) # save state dict model_file = os.path.join(save_directory, MODEL_FILE) torch.save(state_dict, model_file) # save config config_file = os.path.join(save_directory, CONFIG_FILE) self.config.to_yaml_file(config_file) return save_directory def main(): config = DfNetConfig() model = DfNetPretrainedModel(config=config) noisy = torch.randn(size=(1, 16000), dtype=torch.float32) output = model.forward(noisy) print(output[1].shape) return if __name__ == "__main__": main()