nx_denoise / toolbox /torchaudio /models /dfnet3 /configuration_dfnet3.py
HoneyTian's picture
first commit
bd94e77
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Tuple, Union
from toolbox.torchaudio.configuration_utils import PretrainedConfig
class DfNetConfig(PretrainedConfig):
    """Hyper-parameter container for the DfNet model.

    Every named constructor argument is stored unchanged on the instance
    under an attribute of the same name. Any additional keyword arguments
    are forwarded to ``PretrainedConfig.__init__``.

    The only validation performed here is on ``df_gru_skip``, which must be
    one of ``"none"``, ``"identity"`` or ``"grouped_linear"``.

    Raises:
        AssertionError: if ``df_gru_skip`` is not one of the accepted values.
    """

    def __init__(self,
                 sample_rate: int,
                 fft_size: int,
                 hop_size: int,
                 df_bins: int,
                 erb_bins: int,
                 min_freq_bins_for_erb: int,
                 df_order: int,
                 df_lookahead: int,
                 norm_tau: int,
                 lsnr_max: int,
                 lsnr_min: int,
                 conv_channels: int,
                 conv_kernel_size_input: Tuple[int, int],
                 conv_kernel_size_inner: Tuple[int, int],
                 convt_kernel_size_inner: Tuple[int, int],
                 conv_lookahead: int,
                 emb_hidden_dim: int,
                 mask_post_filter: bool,
                 df_hidden_dim: int,
                 df_num_layers: int,
                 df_pathway_kernel_size_t: int,
                 df_gru_skip: str,
                 post_filter_beta: float,
                 # NOTE(review): annotated as float although the name suggests
                 # an iteration count - confirm against the model code.
                 df_n_iter: float,
                 lsnr_dropout: bool,
                 encoder_gru_skip_op: str,
                 encoder_linear_groups: int,
                 encoder_squeezed_gru_linear_groups: int,
                 encoder_concat: bool,
                 erb_decoder_gru_skip_op: str,
                 erb_decoder_linear_groups: int,
                 erb_decoder_emb_num_layers: int,
                 df_decoder_linear_groups: int,
                 **kwargs
                 ):
        # Base-class initialisation runs first, exactly as before, so any
        # handling of the extra keyword arguments happens ahead of validation.
        super().__init__(**kwargs)

        # The exception type stays AssertionError for backward compatibility;
        # the message is new so a bad config value can be diagnosed directly.
        valid_df_gru_skip = ("none", "identity", "grouped_linear")
        if df_gru_skip not in valid_df_gru_skip:
            raise AssertionError(
                f"df_gru_skip must be one of {valid_df_gru_skip}, "
                f"got {df_gru_skip!r}"
            )

        # signal / transform settings
        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.df_bins = df_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.df_order = df_order
        self.df_lookahead = df_lookahead
        self.norm_tau = norm_tau
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min

        # convolution settings
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        # embedding / df settings
        self.emb_hidden_dim = emb_hidden_dim
        self.mask_post_filter = mask_post_filter
        self.df_hidden_dim = df_hidden_dim
        self.df_num_layers = df_num_layers
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_gru_skip = df_gru_skip
        self.post_filter_beta = post_filter_beta
        self.df_n_iter = df_n_iter
        self.lsnr_dropout = lsnr_dropout

        # encoder settings
        self.encoder_gru_skip_op = encoder_gru_skip_op
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups
        self.encoder_concat = encoder_concat

        # decoder settings
        self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op
        self.erb_decoder_linear_groups = erb_decoder_linear_groups
        self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers
        self.df_decoder_linear_groups = df_decoder_linear_groups
# Script entry-point guard; running this module directly does nothing.
if __name__ == "__main__":
    pass