#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Streaming-capable building blocks for a DeepFilterNet-style model (dfnet).

The reference DeepFilterNet implementation does not support streaming
inference directly. A community fork provides a Torch-based streaming
implementation:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF

This file attempts to implement a dfnet that supports streaming inference.
"""
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands


MODEL_FILE = "model.pt"


# Config string -> normalization layer class (constructed with the channel count).
norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}


# Config string -> activation layer class (constructed with no arguments).
activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


class CausalConv2d(nn.Module):
    """Conv2d over [b, c, t, f] that is causal along time and supports chunked streaming.

    The time axis is left-padded with ``kernel_size[0] - 1`` frames, either
    zeros (``cache is None``) or the cache returned by the previous call. Each
    output frame therefore depends only on the current and past input frames,
    and the output has the same number of frames as the input. Stride and
    dilation apply to the frequency axis only. The conv is followed by an
    optional 1x1 pointwise conv (separable mode), a norm layer and an
    activation.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 ):
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if pad_f_dim:
            fpad = kernel_size[1] // 2 + dilation - 1
        else:
            fpad = 0

        # Causal time padding: `lookback` zero frames on the top (past) side only.
        # ConstantPad2d padding for the last 2 dims is (left, right, top, bottom).
        self.lookback = kernel_size[0] - 1
        if self.lookback > 0:
            self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
        else:
            self.tpad = nn.Identity()

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        # NOTE(review): with a 1x1 kernel and groups > 1, no pointwise conv is
        # added, so channels are mixed only within groups. Kept as-is because
        # changing `groups` would change the weight shapes.
        if max(kernel_size) == 1:
            separable = False

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()
        # Do not call nn.Module.__init__ again here: it resets the registered
        # submodules/parameters (conv, norm, ...) created above.

    def forward(self, inputs: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        :param inputs: shape: [b, c_in, t, f]
        :param cache: shape: [b, c_in, lookback, f]; the `new_cache` returned by the
            previous call, or None for the first chunk (zero padding is used instead).
        :return: (output, new_cache). output has t frames. new_cache holds the last
            `lookback` frames of the (padded) input and is meant for the next call;
            it has zero frames when lookback == 0.
        """
        x = inputs
        if cache is None:
            x = self.tpad(x)
        else:
            x = torch.concat(tensors=[cache, x], dim=2)

        if self.lookback > 0:
            new_cache = x[:, :, -self.lookback:, :]
        else:
            # `x[:, :, -0:, :]` would be the whole tensor; return an empty time slice instead.
            new_cache = x[:, :, :0, :]

        x = self.conv(x)
        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)
        return x, new_cache


class CausalConvTranspose2d(nn.Module):
    """ConvTranspose2d over [b, c, t, f] that is causal along time and supports chunked streaming.

    With time stride 1, the transposed conv turns t input frames into
    ``t + lookback`` output frames (``lookback = kernel_size[0] - 1``). The
    first t frames are returned. The trailing ``lookback`` frames are partial
    sums belonging to the next chunk's first frames, and are returned as the
    cache (overlap-add). Stride and dilation apply to the frequency axis only.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: Optional[str] = "batch_norm_2d",
                 activation_layer: Optional[str] = "relu",
                 ):
        super(CausalConvTranspose2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if pad_f_dim:
            fpad = kernel_size[1] // 2
        else:
            fpad = 0

        # Number of extra time frames produced by the transposed conv (time stride 1);
        # they are trimmed in `forward` and carried over as the streaming cache.
        self.lookback = kernel_size[0] - 1

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        self.convt = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            output_padding=(0, 0),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        :param inputs: shape: [b, c_in, t, f_in]
        :param cache: shape: [b, c_out, lookback, f_out]; the `new_cache` returned by
            the previous call, or None for the first chunk.
        :return: (output, new_cache). output has t frames. new_cache holds the
            bias-free partial sums that overlap the next chunk; it has zero frames
            when lookback == 0.
        """
        # x shape: [b, c, t + lookback, f]
        x = self.convt(inputs)

        if self.lookback > 0:
            if cache is not None:
                # Overlap-add the previous chunk's tail onto the first `lookback` frames.
                x = torch.concat(tensors=[
                    x[:, :, :self.lookback, :] + cache,
                    x[:, :, self.lookback:, :]
                ], dim=2)
            # Take the overflow tail *before* it is trimmed away; these frames are the
            # contributions of this chunk's inputs to the next chunk's outputs.
            new_cache = x[:, :, -self.lookback:, :]
            if self.convt.bias is not None:
                # The next chunk's convt output adds the bias again; remove it from the
                # cache so it is counted once per output frame.
                new_cache = new_cache - self.convt.bias.view(1, -1, 1, 1)
            x = x[:, :, :-self.lookback, :]
        else:
            # No overflow frames; `x[:, :, :-0, :]` would wrongly be empty.
            new_cache = x[:, :, :0, :]

        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)
        return x, new_cache


if __name__ == "__main__":
    pass