#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
The native DeepFilterNet implementation does not directly support streaming inference.
Community developers (such as Rikorose) have provided a Torch-based streaming implementation:
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF

This file attempts to implement a DfNet that supports streaming inference.
"""
import os
import math
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet.configuration_dfnet import DfNetConfig
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands


MODEL_FILE = "model.pt"


norm_layer_dict = {
    "batch_norm_2d": torch.nn.BatchNorm2d
}

activation_layer_dict = {
    "relu": torch.nn.ReLU,
    "identity": torch.nn.Identity,
    "sigmoid": torch.nn.Sigmoid,
}


class CausalConv2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        super(CausalConv2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if pad_f_dim:
            fpad = kernel_size[1] // 2 + dilation - 1
        else:
            fpad = 0

        # for the last 2 dims, pad (left, right, top, bottom).
        self.lookback = kernel_size[0] - 1
        if self.lookback > 0:
            self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
        else:
            self.tpad = nn.Identity()

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False
        if max(kernel_size) == 1:
            separable = False

        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            # pointwise convolution that mixes channels after the grouped convolution
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        :param inputs: shape: [b, c, t, f]
        :param cache: shape: [b, c, lookback, f]; the last `lookback` input frames of the previous chunk.
        :return: (outputs, new_cache)
        """
        x = inputs
        if cache is None:
            # no history yet: zero-pad `lookback` frames at the start of the time axis
            x = self.tpad(x)
        else:
            # prepend the cached frames so the convolution stays causal across chunk boundaries
            x = torch.concat(tensors=[cache, x], dim=2)
        new_cache = x[:, :, -self.lookback:, :] if self.lookback > 0 else None
        x = self.conv(x)
        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)
        return x, new_cache
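

# The following check is an illustrative sketch added for documentation purposes; it is not
# part of the original module. It exercises the (inputs, cache) protocol of CausalConv2d:
# feeding the spectrogram one frame at a time, threading the returned cache through, should
# reproduce the offline full-sequence output. All layer sizes below are assumptions.
def _check_causal_conv2d_streaming() -> float:
    torch.manual_seed(0)
    conv = CausalConv2d(in_channels=2, out_channels=4, kernel_size=(2, 3), fstride=2)
    conv.eval()

    x = torch.randn(1, 2, 10, 32)  # [b, c, t, f]
    with torch.no_grad():
        # offline pass: cache=None, the layer zero-pads `lookback` frames itself
        y_full, _ = conv(x)

        # streaming pass: start from a zero cache (equivalent to the zero padding above)
        cache = torch.zeros(1, 2, conv.lookback, 32)
        frames = []
        for t in range(x.shape[2]):
            y_t, cache = conv(x[:, :, t:t + 1, :], cache=cache)
            frames.append(y_t)
        y_stream = torch.cat(frames, dim=2)

    diff = (y_full - y_stream).abs().max().item()
    print(f"CausalConv2d streaming max abs diff: {diff}")
    return diff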


class CausalConvTranspose2d(nn.Module):
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: Union[int, Iterable[int]],
                 fstride: int = 1,
                 dilation: int = 1,
                 pad_f_dim: bool = True,
                 bias: bool = True,
                 separable: bool = False,
                 norm_layer: str = "batch_norm_2d",
                 activation_layer: str = "relu",
                 ):
        super(CausalConvTranspose2d, self).__init__()
        kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)

        if pad_f_dim:
            fpad = kernel_size[1] // 2
        else:
            fpad = 0

        # for the last 2 dims, pad (left, right, top, bottom).
        self.lookback = kernel_size[0] - 1

        groups = math.gcd(in_channels, out_channels) if separable else 1
        if groups == 1:
            separable = False

        self.convt = nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=(0, fpad),
            output_padding=(0, 0),
            stride=(1, fstride),  # stride over time is always 1
            dilation=(1, dilation),  # dilation over time is always 1
            groups=groups,
            bias=bias,
        )

        if separable:
            self.convp = nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=1,
                bias=False,
            )
        else:
            self.convp = nn.Identity()

        if norm_layer is not None:
            norm_layer = norm_layer_dict[norm_layer]
            self.norm = norm_layer(out_channels)
        else:
            self.norm = nn.Identity()

        if activation_layer is not None:
            activation_layer = activation_layer_dict[activation_layer]
            self.activation = activation_layer()
        else:
            self.activation = nn.Identity()

    def forward(self, inputs: torch.Tensor, cache: Optional[torch.Tensor] = None):
        """
        :param inputs: shape: [b, c, t, f]
        :param cache: shape: [b, c, lookback, f]; the overlapping tail of the previous chunk's transposed-conv output.
        :return: (outputs, new_cache)
        """
        x = inputs
        # x shape: [b, c, t, f]
        x = self.convt(x)
        # x shape: [b, c, t+lookback, f]
        if cache is not None:
            # overlap-add: the first `lookback` frames overlap with the cached tail of the previous chunk
            x = torch.concat(tensors=[
                x[:, :, :self.lookback, :] + cache,
                x[:, :, self.lookback:, :]
            ], dim=2)
        if self.lookback > 0:
            # keep the tail (which still needs contributions from the next chunk) as the new cache,
            # then drop it from the current output
            new_cache = x[:, :, -self.lookback:, :]
            x = x[:, :, :-self.lookback, :]
        else:
            new_cache = None
        x = self.convp(x)
        x = self.norm(x)
        x = self.activation(x)
        return x, new_cache
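

# Another illustrative sketch added for documentation purposes (not part of the original
# module): the same streaming-vs-offline check for CausalConvTranspose2d. Here the cache
# lives in the output domain of the transposed convolution, so bias=False is used; with a
# bias, the cached tail would carry an extra bias term and the overlap-add would not match
# the offline pass exactly. All layer sizes below are assumptions.
def _check_causal_conv_transpose2d_streaming() -> float:
    torch.manual_seed(0)
    deconv = CausalConvTranspose2d(in_channels=4, out_channels=2, kernel_size=(2, 3), fstride=1, bias=False)
    deconv.eval()

    x = torch.randn(1, 4, 10, 32)  # [b, c, t, f]
    with torch.no_grad():
        # offline pass
        y_full, _ = deconv(x)

        # streaming pass: the cache is shaped [b, out_channels, lookback, f_out]
        cache = torch.zeros(1, 2, deconv.lookback, 32)
        frames = []
        for t in range(x.shape[2]):
            y_t, cache = deconv(x[:, :, t:t + 1, :], cache=cache)
            frames.append(y_t)
        y_stream = torch.cat(frames, dim=2)

    diff = (y_full - y_stream).abs().max().item()
    print(f"CausalConvTranspose2d streaming max abs diff: {diff}")
    return diff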


if __name__ == "__main__":
    pass
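
    # Run the illustrative streaming-equivalence sketches defined above (added for
    # documentation purposes; not part of the original file).
    _check_causal_conv2d_streaming()
    _check_causal_conv_transpose2d_streaming()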