#!/usr/bin/python3
# -*- coding: utf-8 -*-
import logging
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn

from toolbox.torchaudio.models.dfnet3.configuration_dfnet3 import DfNetConfig
from toolbox.torchaudio.models.dfnet3 import multiframes as MF
from toolbox.torchaudio.models.dfnet3 import utils


logger = logging.getLogger("toolbox")

PI = 3.1415926535897932384626433


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}

activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


class CausalConv2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        Causal Conv2d by delaying the signal for any lookahead.

        Expected input format: [B, C, T, F]

        :param in_channels: number of input channels.
        :param out_channels: number of output channels.
        :param kernel_size: kernel size over (time, freq).
        :param fstride: stride over the frequency axis.
        :param dilation: dilation over the frequency axis.
        :param fpad: if True, pad the frequency axis so its size is preserved.
        """
        super(CausalConv2d, self).__init__()
        lookahead = 0

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2 + dilation - 1
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        layers.append(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)


class CausalConvTranspose2d(nn.Sequential):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 fpad: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        """
        Causal ConvTranspose2d.

        Expected input format: [B, C, T, F]
        """
        super(CausalConvTranspose2d, self).__init__()
        lookahead = 0

        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if fpad:
            fpad_ = kernel_size[1] // 2
        else:
            fpad_ = 0

        # for last 2 dim, pad (left, right, top, bottom).
        pad = (0, 0, kernel_size[0] - 1 - lookahead, lookahead)

        layers = []
        if any(x > 0 for x in pad):
            layers.append(nn.ConstantPad2d(pad, 0.0))

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        layers.append(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size=kernel_size,
                padding=(kernel_size[0] - 1, fpad_ + dilation - 1),
                output_padding=(0, fpad_),
                stride=(1, fstride),  # stride over time is always 1
                dilation=(1, dilation),  # dilation over time is always 1
                groups=groups,
                bias=bias,
            )
        )

        if separable:
            layers.append(
                nn.Conv2d(
                    out_channels,
                    out_channels,
                    kernel_size=1,
                    bias=False,
                )
            )

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            layers.append(norm_layer(out_channels))

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            layers.append(activation_layer())

        super().__init__(*layers)


class GroupedLinear(nn.Module):
    def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
        super().__init__()
        # self.weight: Tensor
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.groups = groups
        assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
        assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
        self.ws = input_size // groups
        self.register_parameter(
            "weight",
            torch.nn.Parameter(
                torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
            ),
        )
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))  # type: ignore

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [..., I]
        b, t, _ = x.shape
        # new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
        new_shape = (b, t, self.groups, self.ws)
        x = x.view(new_shape)
        # The better way, but not supported by torchscript
        # x = x.unflatten(-1, (self.groups, self.ws))  # [..., G, I/G]
        x = torch.einsum("btgi,gih->btgh", x, self.weight)  # [..., G, H/G]
        x = x.flatten(2, 3)  # [B, T, H]
        return x

    def __repr__(self):
        cls = self.__class__.__name__
        return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"


class SqueezedGRU_S(nn.Module):
    """
    SGE net: Video object detection with squeezed GRU and information entropy map
    https://arxiv.org/abs/2106.07224
    """
    def __init__(
            self,
            input_size: int,
            hidden_size: int,
            output_size: Optional[int] = None,
            num_layers: int = 1,
            linear_groups: int = 8,
            batch_first: bool = True,
            skip_op: str = "none",
            activation_layer: str = "identity",
    ):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        self.linear_in = nn.Sequential(
            GroupedLinear(
                input_size=input_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            ),
            activation_layer_dict[activation_layer](),
        )

        # gru skip operator
        self.gru_skip_op = None

        if skip_op == "none":
            self.gru_skip_op = None
        elif skip_op == "identity":
            if input_size != output_size:
                raise AssertionError("Dimensions do not match")
            self.gru_skip_op = nn.Identity()
        elif skip_op == "grouped_linear":
            self.gru_skip_op = GroupedLinear(
                input_size=hidden_size,
                hidden_size=hidden_size,
                groups=linear_groups,
            )
        else:
            raise NotImplementedError()

        self.gru = nn.GRU(
            input_size=hidden_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=batch_first,
        )

        if output_size is not None:
            self.linear_out = nn.Sequential(
                GroupedLinear(
                    input_size=hidden_size,
                    hidden_size=output_size,
                    groups=linear_groups,
                ),
                activation_layer_dict[activation_layer](),
            )
        else:
            self.linear_out = nn.Identity()
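
    # Data flow: inputs [B, T, input_size] -> linear_in (GroupedLinear + activation)
    # -> GRU over time -> linear_out -> [B, T, output_size]; if gru_skip_op is set,
    # a skip connection computed from the raw inputs is added to the output.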
    def forward(self, inputs: torch.Tensor, h=None) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self.linear_in(inputs)
        x, h = self.gru(x, h)
        x = self.linear_out(x)
        if self.gru_skip_op is not None:
            x = x + self.gru_skip_op(inputs)
        return x, h


class Add(nn.Module):
    def forward(self, a, b):
        return a + b


class Concat(nn.Module):
    def forward(self, a, b):
        return torch.cat((a, b), dim=-1)


class Encoder(nn.Module):
    def __init__(self, config: DfNetConfig):
        super(Encoder, self).__init__()
        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        self.erb_conv0 = CausalConv2d(
            in_channels=1,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.erb_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv2 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )
        self.erb_conv3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=1,
        )

        self.df_conv0 = CausalConv2d(
            in_channels=2,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_input,
            bias=False,
            separable=True,
        )
        self.df_conv1 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
            fstride=2,
        )

        self.df_fc_emb = nn.Sequential(
            GroupedLinear(
                config.conv_channels * config.df_bins // 2,
                self.emb_in_dim,
                groups=config.encoder_linear_groups
            ),
            nn.ReLU(inplace=True)
        )

        if config.encoder_concat:
            self.emb_in_dim *= 2
            self.combine = Concat()
        else:
            self.combine = Add()

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=1,
            batch_first=True,
            skip_op=config.encoder_gru_skip_op,
            linear_groups=config.encoder_squeezed_gru_linear_groups,
            activation_layer="relu",
        )

        self.lsnr_fc = nn.Sequential(
            nn.Linear(self.emb_out_dim, 1),
            nn.Sigmoid()
        )
        self.lsnr_scale = config.lsnr_max - config.lsnr_min
        self.lsnr_offset = config.lsnr_min

    def forward(self,
                feat_erb: torch.Tensor,
                feat_spec: torch.Tensor,
                h: Optional[torch.Tensor] = None,
                ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
                           torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # Encodes erb; erb should be in dB scale + normalized; Fe are number of erb bands.
        # erb: [B, 1, T, Fe]
        # spec: [B, 2, T, Fc]
        # b, _, t, _ = feat_erb.shape
        e0 = self.erb_conv0(feat_erb)  # [B, C, T, Fe]
        e1 = self.erb_conv1(e0)  # [B, C, T, Fe/2]
        e2 = self.erb_conv2(e1)  # [B, C, T, Fe/4]
        e3 = self.erb_conv3(e2)  # [B, C, T, Fe/4]
        c0 = self.df_conv0(feat_spec)  # [B, C, T, Fc]
        c1 = self.df_conv1(c0)  # [B, C, T, Fc/2]
        cemb = c1.permute(0, 2, 3, 1).flatten(2)  # [B, T, -1]
        cemb = self.df_fc_emb(cemb)  # [B, T, emb_in_dim]
        emb = e3.permute(0, 2, 3, 1).flatten(2)  # [B, T, C * Fe/4]
        emb = self.combine(emb, cemb)
        emb, h = self.emb_gru(emb, h)  # [B, T, -1]

        lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
        return e0, e1, e2, e3, emb, c0, lsnr, h


class ErbDecoder(nn.Module):
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super(ErbDecoder, self).__init__()
        if config.erb_bins % 8 != 0:
            raise AssertionError("erb_bins should be divisible by 8")

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_out_dim = config.conv_channels * config.erb_bins // 4
        self.emb_hidden_dim = config.emb_hidden_dim

        self.emb_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_hidden_dim,
            output_size=self.emb_out_dim,
            num_layers=config.erb_decoder_emb_num_layers - 1,
            batch_first=True,
            skip_op=config.erb_decoder_gru_skip_op,
            linear_groups=config.erb_decoder_linear_groups,
            activation_layer="relu",
        )

        # convt: TransposedConvolution, convp: Pathway (encoder to decoder) convolutions
        self.conv3p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt3 = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=config.conv_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv2p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt2 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            fstride=2,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv1p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.convt1 = CausalConvTranspose2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            fstride=2,
            kernel_size=config.convt_kernel_size_inner,
            bias=False,
            separable=True,
        )
        self.conv0p = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=config.conv_channels,
            kernel_size=1,
            bias=False,
            separable=True,
        )
        self.conv0_out = CausalConv2d(
            in_channels=config.conv_channels,
            out_channels=1,
            kernel_size=config.conv_kernel_size_inner,
            activation_layer="sigmoid",
            bias=False,
            separable=True,
        )

    def forward(self, emb, e3, e2, e1, e0) -> torch.Tensor:
        # Estimates erb mask
        b, _, t, f8 = e3.shape

        emb, _ = self.emb_gru(emb)
        emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)  # [B, C, T, Fe/4]
        e3 = self.convt3(self.conv3p(e3) + emb)  # [B, C, T, Fe/4]
        e2 = self.convt2(self.conv2p(e2) + e3)  # [B, C, T, Fe/2]
        e1 = self.convt1(self.conv1p(e1) + e2)  # [B, C, T, Fe]
        m = self.conv0_out(self.conv0p(e0) + e1)  # [B, 1, T, Fe]
        return m


class Mask(nn.Module):
    def __init__(self, erb_inv_fb: torch.FloatTensor, post_filter: bool = False, eps: float = 1e-12):
        super().__init__()
        self.erb_inv_fb: torch.FloatTensor
        self.register_buffer("erb_inv_fb", erb_inv_fb.float())
        self.clamp_tensor = torch.__version__ > "1.9.0" or torch.__version__ == "1.9.0"
        self.post_filter = post_filter
        self.eps = eps

    def pf(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
        """
        Post-Filter

        A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
        https://arxiv.org/abs/2008.04259

        :param mask: Real valued mask, typically of shape [B, C, T, F].
        :param beta: Global gain factor.
        :return:
        """
        mask_sin = mask * torch.sin(np.pi * mask / 2)
        mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
        return mask_pf

    def forward(self, spec: torch.Tensor, mask: torch.Tensor, atten_lim: Optional[torch.Tensor] = None) -> torch.Tensor:
        # spec (real) [B, 1, T, F, 2], F: freq_bins
        # mask (real): [B, 1, T, Fe], Fe: erb_bins
        # atten_lim: [B]
        if not self.training and self.post_filter:
            mask = self.pf(mask)

        if atten_lim is not None:
            # dB to amplitude
            atten_lim = 10 ** (-atten_lim / 20)
            # Greater equal (__ge__) not implemented for TorchVersion.
            if self.clamp_tensor:
                # Supported by torch >= 1.9
                mask = mask.clamp(min=atten_lim.view(-1, 1, 1, 1))
            else:
                m_out = []
                for i in range(atten_lim.shape[0]):
                    m_out.append(mask[i].clamp_min(atten_lim[i].item()))
                mask = torch.stack(m_out, dim=0)

        mask = mask.matmul(self.erb_inv_fb)  # [B, 1, T, F]
        if not spec.is_complex():
            mask = mask.unsqueeze(4)
        return spec * mask


class DfDecoder(nn.Module):
    def __init__(self,
                 config: DfNetConfig,
                 ):
        super().__init__()
        layer_width = config.conv_channels

        self.emb_in_dim = config.conv_channels * config.erb_bins // 4
        self.emb_dim = config.df_hidden_dim

        self.df_n_hidden = config.df_hidden_dim
        self.df_n_layers = config.df_num_layers
        self.df_order = config.df_order
        self.df_bins = config.df_bins
        self.df_out_ch = config.df_order * 2

        self.df_convp = CausalConv2d(
            layer_width,
            self.df_out_ch,
            fstride=1,
            kernel_size=(config.df_pathway_kernel_size_t, 1),
            separable=True,
            bias=False,
        )
        self.df_gru = SqueezedGRU_S(
            self.emb_in_dim,
            self.emb_dim,
            num_layers=self.df_n_layers,
            batch_first=True,
            skip_op="none",
            activation_layer="relu",
        )

        if config.df_gru_skip == "none":
            self.df_skip = None
        elif config.df_gru_skip == "identity":
            if config.emb_hidden_dim != config.df_hidden_dim:
                raise AssertionError("Dimensions do not match")
            self.df_skip = nn.Identity()
        elif config.df_gru_skip == "grouped_linear":
            self.df_skip = GroupedLinear(self.emb_in_dim, self.emb_dim, groups=config.df_decoder_linear_groups)
        else:
            raise NotImplementedError()

        self.df_out: nn.Module
        out_dim = self.df_bins * self.df_out_ch

        self.df_out = nn.Sequential(
            GroupedLinear(
                input_size=self.df_n_hidden,
                hidden_size=out_dim,
                groups=config.df_decoder_linear_groups
            ),
            nn.Tanh()
        )
        self.df_fc_a = nn.Sequential(
            nn.Linear(self.df_n_hidden, 1),
            nn.Sigmoid()
        )

    def forward(self, emb: torch.Tensor, c0: torch.Tensor) -> torch.Tensor:
        b, t, _ = emb.shape
        c, _ = self.df_gru(emb)  # [B, T, H], H: df_n_hidden

        if self.df_skip is not None:
            c = c + self.df_skip(emb)

        c0 = self.df_convp(c0).permute(0, 2, 3, 1)  # [B, T, F, O*2], channels_last

        c = self.df_out(c)  # [B, T, F*O*2], O: df_order
        c = c.view(b, t, self.df_bins, self.df_out_ch) + c0  # [B, T, F, O*2]
        return c


class DfOutputReshapeMF(nn.Module):
    """Coefficients output reshape for multiframe/MultiFrameModule

    Requires input of shape B, C, T, F, 2.
""" def __init__(self, df_order: int, df_bins: int): super().__init__() self.df_order = df_order self.df_bins = df_bins def forward(self, coefs: torch.Tensor) -> torch.Tensor: # [B, T, F, O*2] -> [B, O, T, F, 2] new_shape = list(coefs.shape) new_shape[-1] = -1 new_shape.append(2) coefs = coefs.view(new_shape) coefs = coefs.permute(0, 3, 1, 2, 4) return coefs class DfNet(nn.Module): """ DeepFilterNet: Perceptually Motivated Real-Time Speech Enhancement https://arxiv.org/abs/2305.08227 hendrik.m.schroeter@fau.de """ def __init__(self, config: DfNetConfig, erb_fb: torch.FloatTensor, erb_inv_fb: torch.FloatTensor, run_df: bool = True, train_mask: bool = True, ): """ :param erb_fb: erb filter bank. """ super(DfNet, self).__init__() if config.erb_bins % 8 != 0: raise AssertionError("erb_bins should be divisible by 8") self.df_lookahead = config.df_lookahead self.df_bins = config.df_bins self.freq_bins: int = config.fft_size // 2 + 1 self.emb_dim: int = config.conv_channels * config.erb_bins self.erb_bins: int = config.erb_bins if config.conv_lookahead > 0: if config.conv_lookahead < config.df_lookahead: raise AssertionError # for last 2 dim, pad (left, right, top, bottom). self.pad_feat = nn.ConstantPad2d((0, 0, -config.conv_lookahead, config.conv_lookahead), 0.0) else: self.pad_feat = nn.Identity() if config.df_lookahead > 0: # for last 3 dim, pad (left, right, top, bottom, front, back). self.pad_spec = nn.ConstantPad3d((0, 0, 0, 0, -config.df_lookahead, config.df_lookahead), 0.0) else: self.pad_spec = nn.Identity() self.register_buffer("erb_fb", erb_fb) self.enc = Encoder(config) self.erb_dec = ErbDecoder(config) self.mask = Mask(erb_inv_fb) self.erb_inv_fb = erb_inv_fb self.post_filter = config.mask_post_filter self.post_filter_beta = config.post_filter_beta self.df_order = config.df_order self.df_op = MF.DF(num_freqs=config.df_bins, frame_size=config.df_order, lookahead=self.df_lookahead) self.df_dec = DfDecoder(config) self.df_out_transform = DfOutputReshapeMF(self.df_order, config.df_bins) self.run_erb = config.df_bins + 1 < self.freq_bins if not self.run_erb: logger.warning("Running without ERB stage") self.run_df = run_df if not run_df: logger.warning("Running without DF stage") self.train_mask = train_mask self.lsnr_dropout = config.lsnr_dropout if config.df_n_iter != 1: raise AssertionError def forward1( self, spec: torch.Tensor, feat_erb: torch.Tensor, feat_spec: torch.Tensor, # Not used, take spec modified by mask instead ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Forward method of DeepFilterNet2. Args: spec (Tensor): Spectrum of shape [B, 1, T, F, 2] feat_erb (Tensor): ERB features of shape [B, 1, T, E] feat_spec (Tensor): Complex spectrogram features of shape [B, 1, T, F', 2] Returns: spec (Tensor): Enhanced spectrum of shape [B, 1, T, F, 2] m (Tensor): ERB mask estimate of shape [B, 1, T, E] lsnr (Tensor): Local SNR estimate of shape [B, T, 1] """ # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2] feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2) # feat_spec shape: [batch_size, 2, time_steps, freq_dim] # feat_erb shape: [batch_size, 1, time_steps, erb_bins] # assert time_steps >= conv_lookahead. 
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, h = self.enc(feat_erb, feat_spec)

        if self.lsnr_dropout:
            idcs = lsnr.squeeze() > -10.0
            b, t = (spec.shape[0], spec.shape[2])
            m = torch.zeros((b, 1, t, self.erb_bins), device=spec.device)
            df_coefs = torch.zeros((b, t, self.df_bins, self.df_order * 2), device=spec.device)
            spec_m = spec.clone()
            emb = emb[:, idcs]
            e0 = e0[:, :, idcs]
            e1 = e1[:, :, idcs]
            e2 = e2[:, :, idcs]
            e3 = e3[:, :, idcs]
            c0 = c0[:, :, idcs]

        if self.run_erb:
            if self.lsnr_dropout:
                m[:, :, idcs] = self.erb_dec(emb, e3, e2, e1, e0)
            else:
                m = self.erb_dec(emb, e3, e2, e1, e0)
            spec_m = self.mask(spec, m)
        else:
            m = torch.zeros((), device=spec.device)
            spec_m = torch.zeros_like(spec)

        if self.run_df:
            if self.lsnr_dropout:
                df_coefs[:, idcs] = self.df_dec(emb, c0)
            else:
                df_coefs = self.df_dec(emb, c0)
            df_coefs = self.df_out_transform(df_coefs)
            spec_e = self.df_op(spec.clone(), df_coefs)
            spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]
        else:
            df_coefs = torch.zeros((), device=spec.device)
            spec_e = spec_m

        if self.post_filter:
            beta = self.post_filter_beta
            eps = 1e-12
            mask = (utils.as_complex(spec_e).abs() / utils.as_complex(spec).abs().add(eps)).clamp(eps, 1)
            mask_sin = mask * torch.sin(PI * mask / 2).clamp_min(eps)
            pf = (1 + beta) / (1 + beta * mask.div(mask_sin).pow(2))
            spec_e = spec_e * pf.unsqueeze(-1)

        return spec_e, m, lsnr, df_coefs

    def forward(
            self,
            spec: torch.Tensor,
            feat_erb: torch.Tensor,
            feat_spec: torch.Tensor,  # Not used, take spec modified by mask instead
            erb_encoder_h: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        # feat_spec shape: [batch_size, 1, time_steps, freq_dim, 2]
        feat_spec = feat_spec.squeeze(1).permute(0, 3, 1, 2)
        # feat_spec shape: [batch_size, 2, time_steps, freq_dim]

        # feat_erb shape: [batch_size, 1, time_steps, erb_bins]
        # assert time_steps >= conv_lookahead.
        feat_erb = self.pad_feat(feat_erb)
        feat_spec = self.pad_feat(feat_spec)
        e0, e1, e2, e3, emb, c0, lsnr, erb_encoder_h = self.enc(feat_erb, feat_spec, erb_encoder_h)

        m = self.erb_dec(emb, e3, e2, e1, e0)
        spec_m = self.mask(spec, m)
        # spec_e = spec_m

        df_coefs = self.df_dec(emb, c0)
        df_coefs = self.df_out_transform(df_coefs)
        spec_e = self.df_op(spec.clone(), df_coefs)
        spec_e[..., self.df_bins:, :] = spec_m[..., self.df_bins:, :]

        return spec_e, m, lsnr, df_coefs, erb_encoder_h


if __name__ == "__main__":
    pass
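    # Minimal smoke test for the standalone building blocks (GroupedLinear, CausalConv2d).
    # The sizes below are illustrative only; DfNet itself needs a DfNetConfig plus ERB
    # filter banks, so it is not instantiated here.
    gl = GroupedLinear(input_size=64, hidden_size=64, groups=8)
    print(gl(torch.randn(2, 10, 64)).shape)        # torch.Size([2, 10, 64])

    conv = CausalConv2d(in_channels=1, out_channels=16, kernel_size=3)
    print(conv(torch.randn(2, 1, 100, 32)).shape)  # torch.Size([2, 16, 100, 32])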