#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
DeepFilterNet 的原生实现不直接支持流式推理
社区开发者(如 Rikorose)提供了基于 Torch 的流式推理实现
https://github.com/grazder/DeepFilterNet/tree/1097015d53ced78fb234e7d7071a5dd4446e3952/torchDF
此文件试图实现一个支持流式推理的 dfnet
"""
import os
import math
from collections import defaultdict
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.dfnet2.configuration_dfnet2 import DfNet2Config
from toolbox.torchaudio.modules.conv_stft import ConvSTFT, ConviSTFT
from toolbox.torchaudio.modules.local_snr_target import LocalSnrTarget
from toolbox.torchaudio.modules.freq_bands.erb_bands import ErbBands
from toolbox.torchaudio.modules.utils.ema import ErbEMA, SpecEMA
MODEL_FILE = "model.pt"
norm_layer_dict = {
"batch_norm_2d": torch.nn.BatchNorm2d
}
activation_layer_dict = {
"relu": torch.nn.ReLU,
"identity": torch.nn.Identity,
"sigmoid": torch.nn.Sigmoid,
}
class CausalConv2d(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
pad_f_dim: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
):
super(CausalConv2d, self).__init__()
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
if pad_f_dim:
fpad = kernel_size[1] // 2 + dilation - 1
else:
fpad = 0
# for last 2 dim, pad (left, right, top, bottom).
self.lookback = kernel_size[0] - 1
if self.lookback > 0:
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
else:
self.tpad = nn.Identity()
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
if max(kernel_size) == 1:
separable = False
self.conv = nn.Conv2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(0, fpad),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
if separable:
self.convp = nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
else:
self.convp = nn.Identity()
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
self.norm = norm_layer(out_channels)
else:
self.norm = nn.Identity()
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
self.activation = activation_layer()
else:
self.activation = nn.Identity()
def forward(self, inputs: torch.Tensor, cache: Tuple[torch.Tensor, torch.Tensor] = None):
"""
:param inputs: shape: [b, c, t, f]
:param cache: shape: [b, c, lookback, f];
:return:
"""
x = inputs
if cache is None:
x = self.tpad(x)
else:
x = torch.concat(tensors=[cache, x], dim=2)
new_cache = None
if self.lookback > 0:
new_cache = x[:, :, -self.lookback:, :]
x = self.conv(x)
x = self.convp(x)
x = self.norm(x)
x = self.activation(x)
return x, new_cache
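# Streaming sketch (illustrative only, not part of the model): feeding a long input
# frame-by-frame through CausalConv2d while threading the returned cache should
# reproduce the offline output, because the cache simply replaces the zero padding
# along time with the last `lookback` frames of the previous chunk. The shapes below
# are made up for the example.
#
#   conv = CausalConv2d(in_channels=1, out_channels=4, kernel_size=(2, 3)).eval()
#   x = torch.randn(1, 1, 10, 32)
#   y_offline, _ = conv(x)
#   cache, chunks = None, []
#   for t in range(10):
#       y_t, cache = conv(x[:, :, t:t + 1, :], cache=cache)
#       chunks.append(y_t)
#   y_online = torch.cat(chunks, dim=2)
#   # torch.allclose(y_offline, y_online, atol=1e-6) is expected to hold in eval mode
#   # (BatchNorm uses its running statistics; in train mode per-chunk statistics differ).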
class CausalConvTranspose2dErrorCase(nn.Module):
"""
错误的缓存方法。
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
pad_f_dim: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
):
super(CausalConvTranspose2dErrorCase, self).__init__()
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
if pad_f_dim:
fpad = kernel_size[1] // 2
else:
fpad = 0
# for last 2 dim, pad (left, right, top, bottom).
self.lookback = kernel_size[0] - 1
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
self.convt = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(0, fpad),
output_padding=(0, fpad),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
if separable:
self.convp = nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
else:
self.convp = nn.Identity()
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
self.norm = norm_layer(out_channels)
else:
self.norm = nn.Identity()
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
self.activation = activation_layer()
else:
self.activation = nn.Identity()
def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
"""
:param inputs: shape: [b, c, t, f]
:param cache: shape: [b, c, lookback, f];
:return:
"""
x = inputs
# x shape: [b, c, t, f]
x = self.convt(x)
# x shape: [b, c, t+lookback, f]
new_cache = None
if self.lookback > 0:
if cache is not None:
x = torch.concat(tensors=[
x[:, :, :self.lookback, :] + cache,
x[:, :, self.lookback:, :]
], dim=2)
x = x[:, :, :-self.lookback, :]
new_cache = x[:, :, -self.lookback:, :]
x = self.convp(x)
x = self.norm(x)
x = self.activation(x)
return x, new_cache
class CausalConvTranspose2d(nn.Module):
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: Union[int, Iterable[int]],
fstride: int = 1,
dilation: int = 1,
pad_f_dim: bool = True,
bias: bool = True,
separable: bool = False,
norm_layer: str = "batch_norm_2d",
activation_layer: str = "relu",
):
super(CausalConvTranspose2d, self).__init__()
kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else kernel_size
if pad_f_dim:
fpad = kernel_size[1] // 2
else:
fpad = 0
# for last 2 dim, pad (left, right, top, bottom).
self.lookback = kernel_size[0] - 1
if self.lookback > 0:
self.tpad = nn.ConstantPad2d(padding=(0, 0, self.lookback, 0), value=0.0)
else:
self.tpad = nn.Identity()
groups = math.gcd(in_channels, out_channels) if separable else 1
if groups == 1:
separable = False
self.convt = nn.ConvTranspose2d(
in_channels,
out_channels,
kernel_size=kernel_size,
padding=(kernel_size[0] - 1, fpad + dilation - 1),
output_padding=(0, fpad),
stride=(1, fstride), # stride over time is always 1
dilation=(1, dilation), # dilation over time is always 1
groups=groups,
bias=bias,
)
if separable:
self.convp = nn.Conv2d(
out_channels,
out_channels,
kernel_size=1,
bias=False,
)
else:
self.convp = nn.Identity()
if norm_layer is not None:
norm_layer = norm_layer_dict[norm_layer]
self.norm = norm_layer(out_channels)
else:
self.norm = nn.Identity()
if activation_layer is not None:
activation_layer = activation_layer_dict[activation_layer]
self.activation = activation_layer()
else:
self.activation = nn.Identity()
def forward(self, inputs: torch.Tensor, cache: torch.Tensor = None):
"""
:param inputs: shape: [b, c, t, f]
:param cache: shape: [b, c, lookback, f];
:return:
"""
x = inputs
# x shape: [b, c, t, f]
x = self.convt(x)
# x shape: [b, c, t+lookback, f]
if cache is None:
x = self.tpad(x)
else:
x = torch.concat(tensors=[cache, x], dim=2)
new_cache = None
if self.lookback > 0:
new_cache = x[:, :, -self.lookback:, :]
x = self.convp(x)
x = self.norm(x)
x = self.activation(x)
return x, new_cache
class GroupedLinear(nn.Module):
def __init__(self, input_size: int, hidden_size: int, groups: int = 1):
super().__init__()
# self.weight: Tensor
self.input_size = input_size
self.hidden_size = hidden_size
self.groups = groups
assert input_size % groups == 0, f"Input size {input_size} not divisible by {groups}"
assert hidden_size % groups == 0, f"Hidden size {hidden_size} not divisible by {groups}"
self.ws = input_size // groups
self.register_parameter(
"weight",
torch.nn.Parameter(
torch.zeros(groups, input_size // groups, hidden_size // groups), requires_grad=True
),
)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) # type: ignore
def forward(self, x: torch.Tensor) -> torch.Tensor:
# x: [..., I]
b, t, f = x.shape
if f != self.input_size:
raise AssertionError(f"expected last dimension {self.input_size}, got {f}")
# new_shape = list(x.shape)[:-1] + [self.groups, self.ws]
new_shape = (b, t, self.groups, self.ws)
x = x.view(new_shape)
# The better way, but not supported by torchscript
# x = x.unflatten(-1, (self.groups, self.ws)) # [..., G, I/G]
x = torch.einsum("btgi,gih->btgh", x, self.weight) # [..., G, H/G]
x = x.flatten(2, 3)
# x: [b, t, h]
return x
def __repr__(self):
cls = self.__class__.__name__
return f"{cls}(input_size: {self.input_size}, hidden_size: {self.hidden_size}, groups: {self.groups})"
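# Note: GroupedLinear is equivalent to a block-diagonal linear layer. The last
# dimension is split into `groups` chunks of size input_size // groups, and each
# chunk is mapped by its own (input_size // groups, hidden_size // groups) matrix
# via the einsum above. For input_size = hidden_size = 512 and groups = 8 this uses
# 8 independent 64x64 matrices (~32k parameters) instead of one dense 512x512
# matrix (~262k parameters).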
class SqueezedGRU_S(nn.Module):
"""
SGE net: Video object detection with squeezed GRU and information entropy map
https://arxiv.org/abs/2106.07224
"""
def __init__(
self,
input_size: int,
hidden_size: int,
output_size: Optional[int] = None,
num_layers: int = 1,
linear_groups: int = 8,
batch_first: bool = True,
skip_op: str = "none",
activation_layer: str = "identity",
):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.linear_in = nn.Sequential(
GroupedLinear(
input_size=input_size,
hidden_size=hidden_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
# gru skip operator
self.gru_skip_op = None
if skip_op == "none":
self.gru_skip_op = None
elif skip_op == "identity":
if input_size != output_size:
raise AssertionError("Dimensions do not match")
self.gru_skip_op = nn.Identity()
elif skip_op == "grouped_linear":
self.gru_skip_op = GroupedLinear(
input_size=hidden_size,
hidden_size=hidden_size,
groups=linear_groups,
)
else:
raise NotImplementedError()
self.gru = nn.GRU(
input_size=hidden_size,
hidden_size=hidden_size,
num_layers=num_layers,
batch_first=batch_first,
bidirectional=False,
)
if output_size is not None:
self.linear_out = nn.Sequential(
GroupedLinear(
input_size=hidden_size,
hidden_size=output_size,
groups=linear_groups,
),
activation_layer_dict[activation_layer](),
)
else:
self.linear_out = nn.Identity()
def forward(self, inputs: torch.Tensor, hx: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
# inputs: shape: [b, t, h]
x = self.linear_in.forward(inputs)
x, hx = self.gru.forward(x, hx)
x = self.linear_out(x)
if self.gru_skip_op is not None:
x = x + self.gru_skip_op(inputs)
return x, hx
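# Streaming note: for chunk-by-chunk inference, pass the `hx` returned for one chunk
# as the `hx` argument of the next call; nn.GRU packs the hidden state of all layers
# into a single tensor of shape [num_layers, b, hidden_size], so no extra bookkeeping
# is needed.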
class Add(nn.Module):
def forward(self, a, b):
return a + b
class Concat(nn.Module):
def forward(self, a, b):
return torch.cat((a, b), dim=-1)
class Encoder(nn.Module):
def __init__(self, config: DfNet2Config):
super(Encoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.erb_bins // 4
self.embedding_output_size = config.conv_channels * config.erb_bins // 4
self.embedding_hidden_size = config.embedding_hidden_size
self.spec_conv0 = CausalConv2d(
in_channels=1,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
)
self.spec_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.spec_conv2 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.spec_conv3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
)
self.df_conv0 = CausalConv2d(
in_channels=2,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_input,
bias=False,
separable=True,
fstride=1,
)
self.df_conv1 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.df_fc_emb = nn.Sequential(
GroupedLinear(
config.conv_channels * config.df_bins // 2,
self.embedding_input_size,
groups=config.encoder_linear_groups
),
nn.ReLU(inplace=True)
)
if config.encoder_combine_op == "concat":
self.embedding_input_size *= 2
self.combine = Concat()
else:
self.combine = Add()
# emb_gru
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_gru = SqueezedGRU_S(
self.embedding_input_size,
self.embedding_hidden_size,
output_size=self.embedding_output_size,
num_layers=1,
batch_first=True,
skip_op=config.encoder_emb_skip_op,
linear_groups=config.encoder_emb_linear_groups,
activation_layer="relu",
)
# lsnr
self.lsnr_fc = nn.Sequential(
nn.Linear(self.embedding_output_size, 1),
nn.Sigmoid()
)
self.lsnr_scale = config.max_local_snr - config.min_local_snr
self.lsnr_offset = config.min_local_snr
def forward(self,
feat_erb: torch.Tensor,
feat_spec: torch.Tensor,
cache_dict: dict = None,
):
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache0 = cache_dict["cache0"]
cache1 = cache_dict["cache1"]
cache2 = cache_dict["cache2"]
cache3 = cache_dict["cache3"]
cache4 = cache_dict["cache4"]
cache5 = cache_dict["cache5"]
cache6 = cache_dict["cache6"]
# feat_erb shape: (b, 1, t, erb_bins)
e0, new_cache0 = self.spec_conv0.forward(feat_erb, cache=cache0)
e1, new_cache1 = self.spec_conv1.forward(e0, cache=cache1)
e2, new_cache2 = self.spec_conv2.forward(e1, cache=cache2)
e3, new_cache3 = self.spec_conv3.forward(e2, cache=cache3)
# e0 shape: [b, c, t, erb_bins]
# e1 shape: [b, c, t, erb_bins // 2]
# e2 shape: [b, c, t, erb_bins // 4]
# e3 shape: [b, c, t, erb_bins // 4]
# e3 shape: [b, 64, t, 32/4=8]
# feat_spec, shape: (b, 2, t, df_bins)
c0, new_cache4 = self.df_conv0.forward(feat_spec, cache=cache4)
c1, new_cache5 = self.df_conv1.forward(c0, cache=cache5)
# c0 shape: [b, c, t, df_bins]
# c1 shape: [b, c, t, df_bins // 2]
# c1 shape: [b, 64, t, 96/2=48]
cemb = c1.permute(0, 2, 3, 1)
# cemb shape: [b, t, df_bins // 2, c]
cemb = cemb.flatten(2)
# cemb shape: [b, t, df_bins // 2 * c]
# cemb shape: [b, t, 96/2*64=3072]
cemb = self.df_fc_emb.forward(cemb)
# cemb shape: [b, t, erb_bins // 4 * c]
# cemb shape: [b, t, 32/4*64=512]
# e3 shape: [b, c, t, erb_bins // 4]
emb = e3.permute(0, 2, 3, 1)
# emb shape: [b, t, erb_bins // 4, c]
emb = emb.flatten(2)
# emb shape: [b, t, erb_bins // 4 * c]
# emb shape: [b, t, 32/4*64=512]
emb = self.combine(emb, cemb)
# if concat; emb shape: [b, t, spec_bins // 4 * c * 2]
# if add; emb shape: [b, t, spec_bins // 4 * c]
emb, new_cache6 = self.emb_gru.forward(emb, hx=cache6)
# emb shape: [b, t, spec_dim // 4 * c]
# new_cache6 (GRU hidden state) shape: [num_layers, b, hidden_size]
lsnr = self.lsnr_fc(emb) * self.lsnr_scale + self.lsnr_offset
# lsnr shape: [b, t, 1]
new_cache_dict = {
"cache0": new_cache0,
"cache1": new_cache1,
"cache2": new_cache2,
"cache3": new_cache3,
"cache4": new_cache4,
"cache5": new_cache5,
"cache6": new_cache6,
}
return e0, e1, e2, e3, emb, c0, lsnr, new_cache_dict
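# Cache convention (applies to Encoder, ErbDecoder, DfDecoder and DeepFiltering):
# each forward pass accepts an optional cache_dict produced by the previous chunk and
# returns a new one. Here "cache0".."cache5" hold the time-lookback frames of the
# causal convolutions and "cache6" holds the GRU hidden state. A minimal streaming
# loop therefore looks like (sketch, shapes omitted):
#
#   cache = None
#   for feat_erb_t, feat_spec_t in chunks:
#       e0, e1, e2, e3, emb, c0, lsnr, cache = encoder.forward(
#           feat_erb_t, feat_spec_t, cache_dict=cache)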
class ErbDecoder(nn.Module):
def __init__(self, config: DfNet2Config):
super(ErbDecoder, self).__init__()
if config.spec_bins % 8 != 0:
raise AssertionError("spec_bins should be divisible by 8")
self.emb_in_dim = config.conv_channels * config.erb_bins // 4
self.emb_out_dim = config.conv_channels * config.erb_bins // 4
self.emb_hidden_dim = config.decoder_emb_hidden_size
self.emb_gru = SqueezedGRU_S(
self.emb_in_dim,
self.emb_hidden_dim,
output_size=self.emb_out_dim,
num_layers=config.decoder_emb_num_layers - 1,
batch_first=True,
skip_op=config.decoder_emb_skip_op,
linear_groups=config.decoder_emb_linear_groups,
activation_layer="relu",
)
self.conv3p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
)
self.convt3 = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.conv_kernel_size_inner,
bias=False,
separable=True,
fstride=1,
)
self.conv2p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
)
self.convt2 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.conv1p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
)
self.convt1 = CausalConvTranspose2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=config.convt_kernel_size_inner,
bias=False,
separable=True,
fstride=2,
)
self.conv0p = CausalConv2d(
in_channels=config.conv_channels,
out_channels=config.conv_channels,
kernel_size=1,
bias=False,
separable=True,
fstride=1,
)
self.conv0_out = CausalConv2d(
in_channels=config.conv_channels,
out_channels=1,
kernel_size=config.conv_kernel_size_inner,
activation_layer="sigmoid",
bias=False,
separable=True,
fstride=1,
)
def forward(self, emb, e3, e2, e1, e0, cache_dict: dict = None) -> Tuple[torch.Tensor, dict]:
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache0 = cache_dict["cache0"]
cache1 = cache_dict["cache1"]
cache2 = cache_dict["cache2"]
cache3 = cache_dict["cache3"]
cache4 = cache_dict["cache4"]
# Estimates erb mask
b, _, t, f8 = e3.shape
# emb shape: [batch_size, time_steps, (freq_dim // 4) * conv_channels]
emb, new_cache0 = self.emb_gru.forward(emb, hx=cache0)
# emb shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
emb = emb.view(b, t, f8, -1).permute(0, 3, 1, 2)
e3, new_cache1 = self.convt3.forward(self.conv3p(e3)[0] + emb, cache=cache1)
# e3 shape: [batch_size, conv_channels, time_steps, freq_dim // 4]
e2, new_cache2 = self.convt2.forward(self.conv2p(e2)[0] + e3, cache=cache2)
# e2 shape: [batch_size, conv_channels, time_steps, freq_dim // 2]
e1, new_cache3 = self.convt1.forward(self.conv1p(e1)[0] + e2, cache=cache3)
# e1 shape: [batch_size, conv_channels, time_steps, freq_dim]
mask, new_cache4 = self.conv0_out.forward(self.conv0p(e0)[0] + e1, cache=cache4)
# mask shape: [batch_size, 1, time_steps, freq_dim]
new_cache_dict = {
"cache0": new_cache0,
"cache1": new_cache1,
"cache2": new_cache2,
"cache3": new_cache3,
"cache4": new_cache4,
}
return mask, new_cache_dict
class DfDecoder(nn.Module):
def __init__(self, config: DfNet2Config):
super(DfDecoder, self).__init__()
self.embedding_input_size = config.conv_channels * config.erb_bins // 4
self.df_decoder_hidden_size = config.df_decoder_hidden_size
self.df_num_layers = config.df_num_layers
self.df_order = config.df_order
self.df_bins = config.df_bins
self.df_out_ch = config.df_order * 2
self.df_convp = CausalConv2d(
config.conv_channels,
self.df_out_ch,
fstride=1,
kernel_size=(config.df_pathway_kernel_size_t, 1),
separable=True,
bias=False,
)
self.df_gru = SqueezedGRU_S(
self.embedding_input_size,
self.df_decoder_hidden_size,
num_layers=self.df_num_layers,
batch_first=True,
skip_op="none",
activation_layer="relu",
)
if config.df_gru_skip == "none":
self.df_skip = None
elif config.df_gru_skip == "identity":
if config.embedding_hidden_size != config.df_decoder_hidden_size:
raise AssertionError("Dimensions do not match")
self.df_skip = nn.Identity()
elif config.df_gru_skip == "grouped_linear":
self.df_skip = GroupedLinear(
self.embedding_input_size,
self.df_decoder_hidden_size,
groups=config.df_decoder_linear_groups
)
else:
raise NotImplementedError()
self.df_out: nn.Module
out_dim = self.df_bins * self.df_out_ch
self.df_out = nn.Sequential(
GroupedLinear(
input_size=self.df_decoder_hidden_size,
hidden_size=out_dim,
groups=config.df_decoder_linear_groups,
# groups = self.df_bins // 5,
),
nn.Tanh()
)
self.df_fc_a = nn.Sequential(
nn.Linear(self.df_decoder_hidden_size, 1),
nn.Sigmoid()
)
def forward(self, emb: torch.Tensor, c0: torch.Tensor, cache_dict: dict = None) -> Tuple[torch.Tensor, dict]:
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache0 = cache_dict["cache0"]
cache1 = cache_dict["cache1"]
# emb shape: [batch_size, time_steps, df_bins // 4 * channels]
b, t, _ = emb.shape
df_coefs, new_cache0 = self.df_gru.forward(emb, hx=cache0)
if self.df_skip is not None:
df_coefs = df_coefs + self.df_skip(emb)
# df_coefs shape: [batch_size, time_steps, df_decoder_hidden_size]
# c0 shape: [batch_size, channels, time_steps, df_bins]
c0, new_cache1 = self.df_convp.forward(c0, cache=cache1)
# c0 shape: [batch_size, df_order * 2, time_steps, df_bins]
c0 = c0.permute(0, 2, 3, 1)
# c0 shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = self.df_out(df_coefs) # [B, T, F*O*2], O: df_order
# df_coefs shape: [batch_size, time_steps, df_bins * df_order * 2]
df_coefs = df_coefs.view(b, t, self.df_bins, self.df_out_ch)
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
df_coefs = df_coefs + c0
# df_coefs shape: [batch_size, time_steps, df_bins, df_order * 2]
new_cache_dict = {
"cache0": new_cache0,
"cache1": new_cache1,
}
return df_coefs, new_cache_dict
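# The df_out_ch = df_order * 2 channels returned here are interpreted downstream as
# the real and imaginary parts of df_order complex filter taps per frequency bin;
# DfOutputReshapeMF rearranges them into [b, df_order, t, df_bins, 2] before they are
# consumed by DeepFiltering.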
class DfOutputReshapeMF(nn.Module):
"""Coefficients output reshape for multiframe/MultiFrameModule
Requires input of shape B, C, T, F, 2.
"""
def __init__(self, df_order: int, df_bins: int):
super().__init__()
self.df_order = df_order
self.df_bins = df_bins
def forward(self, coefs: torch.Tensor) -> torch.Tensor:
# [B, T, F, O*2] -> [B, O, T, F, 2]
new_shape = list(coefs.shape)
new_shape[-1] = -1
new_shape.append(2)
coefs = coefs.view(new_shape)
coefs = coefs.permute(0, 3, 1, 2, 4)
return coefs
class Mask(nn.Module):
def __init__(self, use_post_filter: bool = False, eps: float = 1e-12):
super().__init__()
self.use_post_filter = use_post_filter
self.eps = eps
def post_filter(self, mask: torch.Tensor, beta: float = 0.02) -> torch.Tensor:
"""
Post-Filter
A Perceptually-Motivated Approach for Low-Complexity, Real-Time Enhancement of Fullband Speech.
https://arxiv.org/abs/2008.04259
:param mask: Real valued mask, typically of shape [B, C, T, F].
:param beta: Global gain factor.
:return:
"""
mask_sin = mask * torch.sin(np.pi * mask / 2)
mask_pf = (1 + beta) * mask / (1 + beta * mask.div(mask_sin.clamp_min(self.eps)).pow(2))
return mask_pf
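# Note: the post-filter above slightly sharpens the mask. For mask values near 1 it is
# almost the identity, while low-gain (mostly-noise) time-frequency bins are attenuated
# a bit further; `beta` controls how aggressive this is.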
def forward(self, spec: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
# spec shape: [b, 1, t, spec_bins, 2]
if not self.training and self.use_post_filter:
mask = self.post_filter(mask)
# mask shape: [b, 1, t, spec_bins]
mask = mask.unsqueeze(4)
# mask shape: [b, 1, t, spec_bins, 1]
return spec * mask
class DeepFiltering(nn.Module):
def __init__(self,
df_bins: int,
df_order: int,
lookahead: int = 0,
):
super(DeepFiltering, self).__init__()
self.df_bins = df_bins
self.df_order = df_order
self.lookahead = lookahead
self.pad = nn.ConstantPad2d((0, 0, df_order - 1 - lookahead, lookahead), 0.0)
def forward(self, *args, **kwargs):
raise AssertionError("use `forward_offline` or `forward_online` stead.")
def spec_unfold_offline(self, spec: torch.Tensor) -> torch.Tensor:
"""
Pads and unfolds the spectrogram according to frame_size.
:param spec: shape: [b, c, t, f], dtype: torch.complex64
:return: shape: [b, c, t, f, df_order]
"""
if self.df_order <= 1:
return spec.unsqueeze(-1)
# spec shape: [b, 1, t, f], dtype: torch.complex64
spec = self.pad(spec)
# spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64
spec_unfold = spec.unfold(dimension=2, size=self.df_order, step=1)
# spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64
return spec_unfold
def forward_offline(self,
spec: torch.Tensor,
coefs: torch.Tensor,
):
# spec shape: [b, 1, t, spec_bins, 2]
spec_c = torch.view_as_complex(spec.contiguous())
# spec_c shape: [b, 1, t, spec_bins]
spec_u = self.spec_unfold_offline(spec_c)
# spec_u shape: [b, 1, t, spec_bins, df_order]
spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins)
# spec_f shape: [b, 1, t, df_bins, df_order]
# coefs shape: [b, df_order, t, df_bins, 2]
coefs = torch.view_as_complex(coefs.contiguous())
# coefs shape: [b, df_order, t, df_bins]
coefs = coefs.unsqueeze(dim=1)
# coefs shape: [b, 1, df_order, t, df_bins]
spec_f = self.df_offline(spec_f, coefs)
# spec_f shape: [b, 1, t, df_bins]
spec_f = torch.view_as_real(spec_f)
# spec_f shape: [b, 1, t, df_bins, 2]
return spec_f
def df_offline(self, spec: torch.Tensor, coefs: torch.Tensor):
"""
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
:param spec: [b, 1, t, df_bins, df_order] complex.
:param coefs: [b, 1, df_order, t, df_bins] complex.
:return: [b, 1, t, df_bins] complex.
"""
spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs)
return spec_f
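# In other words, per batch/channel the einsum computes, for every frame t and
# frequency bin f,
#   spec_f[t, f] = sum_{n=0}^{df_order-1} spec[t, f, n] * coefs[n, t, f]
# i.e. a complex FIR filter of length df_order applied across the unfolded (past and
# lookahead) frames of each frequency bin.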
def spec_unfold_online(self, spec: torch.Tensor, cache_spec: torch.Tensor = None):
"""
Pads and unfolds the spectrogram according to frame_size.
:param spec: shape: [b, c, t, f], dtype: torch.complex64
:param cache_spec: shape: [b, c, df_order-1, f], dtype: torch.complex64
:return: shape: [b, c, t, f, df_order]
"""
if self.df_order <= 1:
return spec.unsqueeze(-1), cache_spec
if cache_spec is None:
b, c, _, f = spec.shape
cache_spec = spec.new_zeros(size=(b, c, self.df_order-1, f))
spec_pad = torch.concat(tensors=[
cache_spec, spec
], dim=2)
new_cache_spec = spec_pad[:, :, -(self.df_order-1):, :]
# spec_pad shape: [b, 1, t+df_order-1, f], dtype: torch.complex64
spec_unfold = spec_pad.unfold(dimension=2, size=self.df_order, step=1)
# spec_unfold shape: [b, 1, t, f, df_order], dtype: torch.complex64
return spec_unfold, new_cache_spec
def forward_online(self,
spec: torch.Tensor,
coefs: torch.Tensor,
cache_dict: dict = None,
):
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache0 = cache_dict["cache0"]
cache1 = cache_dict["cache1"]
# spec shape: [b, 1, t, spec_bins, 2]
spec_c = torch.view_as_complex(spec.contiguous())
# spec_c shape: [b, 1, t, spec_bins]
spec_u, new_cache0 = self.spec_unfold_online(spec_c, cache_spec=cache0)
# spec_u shape: [b, 1, t, spec_bins, df_order]
spec_f = spec_u.narrow(dim=-2, start=0, length=self.df_bins)
# spec_f shape: [b, 1, t, df_bins, df_order]
# coefs shape: [b, df_order, t, df_bins, 2]
coefs = torch.view_as_complex(coefs.contiguous())
# coefs shape: [b, df_order, t, df_bins]
coefs = coefs.unsqueeze(dim=1)
# coefs shape: [b, 1, df_order, t, df_bins]
spec_f, new_cache1 = self.df_online(spec_f, coefs, cache_coefs=cache1)
# spec_f shape: [b, 1, t, df_bins]
spec_f = torch.view_as_real(spec_f)
# spec_f shape: [b, 1, t, df_bins, 2]
new_cache_dict = {
"cache0": new_cache0,
"cache1": new_cache1,
}
return spec_f, new_cache_dict
def df_online(self, spec: torch.Tensor, coefs: torch.Tensor, cache_coefs: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Deep filter implementation using `torch.einsum`. Requires unfolded spectrogram.
:param spec: [b, 1, 1, df_bins, df_order] complex.
:param coefs: [b, 1, df_order, 1, df_bins] complex.
:param cache_coefs: [b, 1, df_order, lookahead, df_bins] complex.
:return: [b, 1, 1, df_bins] complex.
"""
if cache_coefs is None:
b, c, _, _, f = coefs.shape
cache_coefs = coefs.new_zeros(size=(b, c, self.df_order, self.lookahead, f))
coefs_pad = torch.concat(tensors=[
cache_coefs, coefs
], dim=3)
# coefs_pad shape: [b, 1, df_order, 1+lookahead, df_bins], torch.complex64.
coefs = coefs_pad[:, :, :, :-self.lookahead, :]
# coefs shape: [b, 1, df_order, 1, df_bins], torch.complex64.
new_cache_coefs = coefs_pad[:, :, :, -self.lookahead:, :]
# new_cache_coefs shape: [b, 1, df_order, lookahead, df_bins], torch.complex64.
spec_f = torch.einsum("...tfn,...ntf->...tf", spec, coefs)
return spec_f, new_cache_coefs
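# Note: df_online assumes self.lookahead >= 1. With lookahead == 0 the slices
# `[..., :-self.lookahead, :]` / `[..., -self.lookahead:, :]` above would select an
# empty / full range respectively, so a zero-lookahead streaming path would need a
# separate branch (not implemented here).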
class DfNet2(nn.Module):
def __init__(self, config: DfNet2Config):
super(DfNet2, self).__init__()
self.config = config
self.eps = 1e-12
self.freq_bins = self.config.nfft // 2 + 1
self.nfft = config.nfft
self.win_size = config.win_size
self.hop_size = config.hop_size
self.win_type = config.win_type
self.stft = ConvSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
power=None,
requires_grad=False
)
self.istft = ConviSTFT(
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
win_type=config.win_type,
requires_grad=False
)
self.erb_bands = ErbBands(
sample_rate=config.sample_rate,
nfft=config.nfft,
erb_bins=config.erb_bins,
min_freq_bins_for_erb=config.min_freq_bins_for_erb,
)
self.erb_ema = ErbEMA(
sample_rate=config.sample_rate,
hop_size=config.hop_size,
erb_bins=config.erb_bins,
)
self.spec_ema = SpecEMA(
sample_rate=config.sample_rate,
hop_size=config.hop_size,
df_bins=config.df_bins,
)
self.encoder = Encoder(config)
self.erb_decoder = ErbDecoder(config)
self.df_decoder = DfDecoder(config)
self.df_out_transform = DfOutputReshapeMF(config.df_order, config.df_bins)
self.df_op = DeepFiltering(
df_bins=config.df_bins,
df_order=config.df_order,
lookahead=config.df_lookahead,
)
self.mask = Mask(use_post_filter=config.use_post_filter)
self.lsnr_fn = LocalSnrTarget(
sample_rate=config.sample_rate,
nfft=config.nfft,
win_size=config.win_size,
hop_size=config.hop_size,
n_frame=config.n_frame,
min_local_snr=config.min_local_snr,
max_local_snr=config.max_local_snr,
db=True,
)
def signal_prepare(self, signal: torch.Tensor) -> torch.Tensor:
if signal.dim() == 2:
signal = torch.unsqueeze(signal, dim=1)
_, _, n_samples = signal.shape
remainder = (n_samples - self.win_size) % self.hop_size
if remainder > 0:
n_samples_pad = self.hop_size - remainder
signal = F.pad(signal, pad=(0, n_samples_pad), mode="constant", value=0)
return signal
def feature_prepare(self, signal: torch.Tensor):
# noisy shape: [b, num_samples_pad]
spec_cmp = self.stft.forward(signal)
# spec_complex shape: [b, f, t], torch.complex64
spec_cmp = torch.transpose(spec_cmp, dim0=1, dim1=2)
# spec_complex shape: [b, t, f], torch.complex64
spec_cmp_real = torch.view_as_real(spec_cmp)
# spec_cmp_real shape: [b, t, f, 2]
spec_mag = torch.abs(spec_cmp)
spec_pow = torch.square(spec_mag)
# shape: [b, t, f]
spec = torch.unsqueeze(spec_cmp_real, dim=1)
# spec shape: [b, 1, t, f, 2]
feat_erb = self.erb_bands.erb_scale(spec_pow, db=True)
# feat_erb shape: [b, t, erb_bins]
feat_erb = torch.unsqueeze(feat_erb, dim=1)
# feat_erb shape: [b, 1, t, erb_bins]
feat_spec = spec_cmp_real.permute(0, 3, 1, 2)
# feat_spec shape: [b, 2, t, f]
feat_spec = feat_spec[..., :self.df_decoder.df_bins]
# feat_spec shape: [b, 2, t, df_bins]
spec = spec.detach()
feat_erb = feat_erb.detach()
feat_spec = feat_spec.detach()
return spec, feat_erb, feat_spec
def feature_norm(self, feat_erb, feat_spec, cache_dict: dict = None):
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache0 = cache_dict["cache0"]
cache1 = cache_dict["cache1"]
feat_erb, new_cache0 = self.erb_ema.norm(feat_erb, state=cache0)
feat_spec, new_cache1 = self.spec_ema.norm(feat_spec, state=cache1)
new_cache_dict = {
"cache0": new_cache0,
"cache1": new_cache1,
}
feat_erb = feat_erb.detach()
feat_spec = feat_spec.detach()
return feat_erb, feat_spec, new_cache_dict
def forward(self,
noisy: torch.Tensor,
):
"""
:param noisy:
:return:
est_spec: shape: [b, 257*2, t]
est_wav: shape: [b, num_samples]
est_mask: shape: [b, 257, t]
lsnr: shape: [b, 1, t]
"""
n_samples = noisy.shape[-1]
noisy = self.signal_prepare(noisy)
spec, feat_erb, feat_spec = self.feature_prepare(noisy)
if self.config.use_ema_norm:
feat_erb, feat_spec, _ = self.feature_norm(feat_erb, feat_spec)
e0, e1, e2, e3, emb, c0, lsnr, _ = self.encoder.forward(feat_erb, feat_spec)
mask, _ = self.erb_decoder.forward(emb, e3, e2, e1, e0)
# mask shape: [b, 1, t, erb_bins]
mask = self.erb_bands.erb_scale_inv(mask)
# mask shape: [b, 1, t, f]
if torch.any(mask > 1) or torch.any(mask < 0):
raise AssertionError("mask values must lie in [0, 1]")
spec_m = self.mask.forward(spec, mask)
# spec_m shape: [b, 1, t, f, 2]
spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
# spec_m shape: [b, 1, t, spec_bins, 2]
# lsnr shape: [b, t, 1]
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
# lsnr shape: [b, 1, t]
df_coefs, _ = self.df_decoder.forward(emb, c0)
df_coefs = self.df_out_transform(df_coefs)
# df_coefs shape: [b, df_order, t, df_bins, 2]
spec_ = spec[:, :, :, :self.config.spec_bins, :]
# spec shape: [b, 1, t, spec_bins, 2]
spec_f = self.df_op.forward_offline(spec_, df_coefs)
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32
spec_e = torch.concat(tensors=[
spec_f, spec_m[..., self.df_decoder.df_bins:, :]
], dim=3)
spec_e = torch.squeeze(spec_e, dim=1)
spec_e = spec_e.permute(0, 2, 1, 3)
# spec_e shape: [b, spec_bins, t, 2]
est_spec = torch.view_as_complex(spec_e.contiguous())
# est_spec shape: [b, spec_bins, t], torch.complex64
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
# est_spec shape: [b, f, t], torch.complex64
est_wav = self.istft.forward(est_spec)
est_wav = est_wav[:, :, :n_samples]
# est_wav shape: [b, 1, n_samples]
est_mask = torch.squeeze(mask, dim=1)
est_mask = est_mask.permute(0, 2, 1)
# est_mask shape: [b, f, t]
return est_spec, est_wav, est_mask, lsnr
def forward_chunk_by_chunk(self,
noisy: torch.Tensor,
):
noisy = self.signal_prepare(noisy)
b, _, _ = noisy.shape
noisy = torch.concat(tensors=[
noisy, noisy.new_zeros(size=(b, 1, (self.config.df_lookahead+1)*self.hop_size))
], dim=2)
b, _, num_samples = noisy.shape
t = (num_samples - self.win_size) // self.hop_size + 1
cache_dict0 = None
cache_dict1 = None
cache_dict2 = None
cache_dict3 = None
cache_dict4 = None
cache_dict5 = None
cache_dict6 = None
waveform_list = list()
for i in range(int(t)):
begin = i * self.hop_size
end = begin + self.win_size
sub_noisy = noisy[:, :, begin: end]
spec, feat_erb, feat_spec = self.feature_prepare(sub_noisy)
# spec shape: [b, 1, t, f, 2]
# feat_erb shape: [b, 1, t, erb_bins]
# feat_spec shape: [b, 2, t, df_bins]
if self.config.use_ema_norm:
feat_erb, feat_spec, cache_dict0 = self.feature_norm(feat_erb, feat_spec, cache_dict=cache_dict0)
e0, e1, e2, e3, emb, c0, lsnr, cache_dict1 = self.encoder.forward(feat_erb, feat_spec, cache_dict=cache_dict1)
mask, cache_dict2 = self.erb_decoder.forward(emb, e3, e2, e1, e0, cache_dict=cache_dict2)
# mask shape: [b, 1, t, erb_bins]
mask = self.erb_bands.erb_scale_inv(mask)
# mask shape: [b, 1, t, f]
spec_m = self.mask.forward(spec, mask)
# spec_m shape: [b, 1, t, f, 2]
spec_m = spec_m[:, :, :, :self.config.spec_bins, :]
# spec_m shape: [b, 1, t, spec_bins, 2]
# lsnr shape: [b, t, 1]
lsnr = torch.transpose(lsnr, dim0=2, dim1=1)
# lsnr shape: [b, 1, t]
df_coefs, cache_dict3 = self.df_decoder.forward(emb, c0, cache_dict=cache_dict3)
df_coefs = self.df_out_transform(df_coefs)
# df_coefs shape: [b, df_order, t, df_bins, 2]
spec_ = spec[:, :, :, :self.config.spec_bins, :]
# spec shape: [b, 1, t, spec_bins, 2]
spec_f, cache_dict4 = self.df_op.forward_online(spec_, df_coefs, cache_dict=cache_dict4)
# spec_f shape: [b, 1, t, df_bins, 2], torch.float32
spec_e, cache_dict5 = self.spec_e_m_combine_online(spec_f, spec_m, cache_dict=cache_dict5)
spec_e = torch.squeeze(spec_e, dim=1)
spec_e = spec_e.permute(0, 2, 1, 3)
# spec_e shape: [b, spec_bins, t, 2]
est_spec = torch.view_as_complex(spec_e.contiguous())
# est_spec shape: [b, spec_bins, t], torch.complex64
est_spec = torch.concat(tensors=[est_spec, est_spec[:, -1:, :]], dim=1)
# est_spec shape: [b, f, t], torch.complex64
est_wav, cache_dict6 = self.istft.forward_chunk(est_spec, cache_dict=cache_dict6)
# est_wav shape: [b, 1, hop_size]
waveform_list.append(est_wav)
waveform = torch.concat(tensors=waveform_list, dim=-1)
# waveform shape: [b, 1, n]
return waveform
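# Note on latency: because the deep-filtering stage uses df_lookahead future frames,
# the chunk-by-chunk output is delayed by df_lookahead * hop_size samples relative to
# the offline forward(); main() below accounts for this by dropping the first
# df_lookahead * hop_size samples before comparing outputs.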
def spec_e_m_combine_online(self, spec_f: torch.Tensor, spec_m: torch.Tensor, cache_dict: dict = None):
"""
:param spec_f: shape: [b, 1, t, df_bins, 2], torch.float32
:param spec_m: shape: [b, 1, t, spec_bins, 2]
:param cache_dict:
:return:
"""
if cache_dict is None:
cache_dict = defaultdict(lambda: None)
cache_spec_m = cache_dict["cache_spec_m"]
if cache_spec_m is None:
b, c, t, f, _ = spec_m.shape
cache_spec_m = spec_m.new_zeros(size=(b, c, self.config.df_lookahead, f, 2))
# cache0 shape: [b, 1, lookahead, f, 2]
spec_m_cat = torch.concat(tensors=[
cache_spec_m, spec_m,
], dim=2)
spec_m = spec_m_cat[:, :, :-self.config.df_lookahead, :, :]
new_cache_spec_m = spec_m_cat[:, :, -self.config.df_lookahead:, :, :]
spec_e = torch.concat(tensors=[
spec_f, spec_m[..., self.df_decoder.df_bins:, :]
], dim=3)
new_cache_dict = {
"cache_spec_m": new_cache_spec_m,
}
return spec_e, new_cache_dict
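# spec_e_m_combine_online delays the masked spectrum spec_m by df_lookahead frames
# (via cache_spec_m) so that it stays aligned with spec_f: in the online deep-filtering
# path the coefficients are buffered for df_lookahead frames instead of looking at
# future spectrum frames, which delays spec_f by the same amount.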
def mask_loss_fn(self, est_mask: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
"""
:param est_mask: torch.Tensor, shape: [b, 257, t]
:param clean:
:param noisy:
:return:
"""
if noisy.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
noise = noisy - clean
clean = self.signal_prepare(clean)
noise = self.signal_prepare(noise)
stft_clean = self.stft.forward(clean)
mag_clean = torch.abs(stft_clean)
stft_noise = self.stft.forward(noise)
mag_noise = torch.abs(stft_noise)
gth_irm_mask = (mag_clean / (mag_clean + mag_noise + self.eps)).clamp(0, 1)
loss = F.l1_loss(gth_irm_mask, est_mask, reduction="mean")
return loss
def lsnr_loss_fn(self, lsnr: torch.Tensor, clean: torch.Tensor, noisy: torch.Tensor):
if noisy.shape != clean.shape:
raise AssertionError("Input signals must have the same shape")
noise = noisy - clean
clean = self.signal_prepare(clean)
noise = self.signal_prepare(noise)
stft_clean = self.stft.forward(clean)
stft_noise = self.stft.forward(noise)
# shape: [b, f, t]
stft_clean = torch.transpose(stft_clean, dim0=1, dim1=2)
stft_noise = torch.transpose(stft_noise, dim0=1, dim1=2)
# shape: [b, t, f]
stft_clean = torch.unsqueeze(stft_clean, dim=1)
stft_noise = torch.unsqueeze(stft_noise, dim=1)
# shape: [b, 1, t, f]
# lsnr shape: [b, 1, t]
lsnr = lsnr.squeeze(1)
# lsnr shape: [b, t]
lsnr_gth = self.lsnr_fn.forward(stft_clean, stft_noise)
# lsnr_gth shape: [b, t]
loss = F.mse_loss(lsnr, lsnr_gth)
return loss
class DfNet2PretrainedModel(DfNet2):
def __init__(self,
config: DfNet2Config,
):
super(DfNet2PretrainedModel, self).__init__(
config=config,
)
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
config = DfNet2Config.from_pretrained(pretrained_model_name_or_path, **kwargs)
model = cls(config)
if os.path.isdir(pretrained_model_name_or_path):
ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
else:
ckpt_file = pretrained_model_name_or_path
with open(ckpt_file, "rb") as f:
state_dict = torch.load(f, map_location="cpu", weights_only=True)
model.load_state_dict(state_dict, strict=True)
return model
def save_pretrained(self,
save_directory: Union[str, os.PathLike],
state_dict: Optional[dict] = None,
):
model = self
if state_dict is None:
state_dict = model.state_dict()
os.makedirs(save_directory, exist_ok=True)
# save state dict
model_file = os.path.join(save_directory, MODEL_FILE)
torch.save(state_dict, model_file)
# save config
config_file = os.path.join(save_directory, CONFIG_FILE)
self.config.to_yaml_file(config_file)
return save_directory
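# Typical round trip with the pretrained-model wrapper (sketch; the directory name is
# only an example):
#
#   model = DfNet2PretrainedModel(DfNet2Config())
#   model.save_pretrained("exp/dfnet2")        # writes model.pt plus the config file
#   model = DfNet2PretrainedModel.from_pretrained("exp/dfnet2")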
def main():
import time
# torch.set_num_threads(1)
config = DfNet2Config()
model = DfNet2PretrainedModel(config=config)
model.eval()
num_samples = 16000
noisy = torch.randn(size=(1, num_samples), dtype=torch.float32)
duration = num_samples / config.sample_rate
begin = time.time()
est_spec, est_wav, est_mask, lsnr = model.forward(noisy)
time_cost = time.time() - begin
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
# print(f"est_spec.shape: {est_spec.shape}")
# print(f"est_wav.shape: {est_wav.shape}")
# print(f"est_mask.shape: {est_mask.shape}")
# print(f"lsnr.shape: {lsnr.shape}")
waveform = est_wav
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
print(waveform[:, :, 300: 302])
print(waveform[:, :, 15680: 15682])
print(waveform[:, :, 15760: 15762])
print(waveform[:, :, 15840: 15842])
begin = time.time()
waveform = model.forward_chunk_by_chunk(noisy)
time_cost = time.time() - begin
print(f"time_cost: {time_cost:.4f}, audio_duration: {duration:.4f}, fpr: {time_cost / duration:.4f}")
waveform = waveform[:, :, (config.df_lookahead*config.hop_size):]
print(f"waveform.shape: {waveform.shape}, waveform.dtype: {waveform.dtype}")
print(waveform[:, :, 300: 302])
print(waveform[:, :, 15680: 15682])
print(waveform[:, :, 15760: 15762])
print(waveform[:, :, 15840: 15842])
return
if __name__ == "__main__":
main()