#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Tuple, Union

from toolbox.torchaudio.configuration_utils import PretrainedConfig


# Accepted values for ``DfNetConfig.df_gru_skip``; anything else is rejected
# in ``DfNetConfig.__init__``.
_DF_GRU_SKIP_CHOICES = ("none", "identity", "grouped_linear")


class DfNetConfig(PretrainedConfig):
    """Hyperparameter container for the DfNet model.

    Every constructor argument is stored unchanged as an instance attribute of
    the same name. Any extra keyword arguments are forwarded to
    ``PretrainedConfig.__init__``. The only validation done here is that
    ``df_gru_skip`` is one of ``"none"``, ``"identity"`` or
    ``"grouped_linear"``.

    Argument groups below are inferred from the parameter names only and are
    not checked here. NOTE(review): confirm meanings and units against the
    model code that consumes this config.

    * Signal / framing: ``sample_rate``, ``fft_size``, ``hop_size``.
    * Frequency-band layout: ``df_bins``, ``erb_bins``,
      ``min_freq_bins_for_erb``.
    * Normalisation / local-SNR: ``norm_tau``, ``lsnr_max``, ``lsnr_min``,
      ``lsnr_dropout``.
    * Convolutional front end: ``conv_channels``, ``conv_kernel_size_input``,
      ``conv_kernel_size_inner``, ``convt_kernel_size_inner``,
      ``conv_lookahead``.
    * Embedding / mask: ``emb_hidden_dim``, ``mask_post_filter``,
      ``post_filter_beta``.
    * Deep-filter pathway: ``df_order``, ``df_lookahead``, ``df_hidden_dim``,
      ``df_num_layers``, ``df_pathway_kernel_size_t``, ``df_gru_skip``,
      ``df_n_iter``, ``df_decoder_linear_groups``.
    * Encoder: ``encoder_gru_skip_op``, ``encoder_linear_groups``,
      ``encoder_squeezed_gru_linear_groups``, ``encoder_concat``.
    * ERB decoder: ``erb_decoder_gru_skip_op``, ``erb_decoder_linear_groups``,
      ``erb_decoder_emb_num_layers``.

    Raises:
        AssertionError: if ``df_gru_skip`` is not one of the accepted values.
    """

    def __init__(self,
                 sample_rate: int,
                 fft_size: int,
                 hop_size: int,

                 df_bins: int,
                 erb_bins: int,
                 min_freq_bins_for_erb: int,
                 df_order: int,
                 df_lookahead: int,
                 norm_tau: int,
                 lsnr_max: int,
                 lsnr_min: int,

                 conv_channels: int,
                 conv_kernel_size_input: Tuple[int, int],
                 conv_kernel_size_inner: Tuple[int, int],
                 convt_kernel_size_inner: Tuple[int, int],
                 conv_lookahead: int,

                 emb_hidden_dim: int,
                 mask_post_filter: bool,
                 df_hidden_dim: int,
                 df_num_layers: int,
                 df_pathway_kernel_size_t: int,
                 df_gru_skip: str,
                 post_filter_beta: float,
                 # NOTE(review): annotated float, but the name suggests an
                 # iteration count (int) -- confirm with the consumer.
                 df_n_iter: float,
                 lsnr_dropout: bool,

                 encoder_gru_skip_op: str,
                 encoder_linear_groups: int,
                 encoder_squeezed_gru_linear_groups: int,
                 encoder_concat: bool,

                 erb_decoder_gru_skip_op: str,
                 erb_decoder_linear_groups: int,
                 erb_decoder_emb_num_layers: int,

                 df_decoder_linear_groups: int,

                 **kwargs
                 ):
        super().__init__(**kwargs)

        # Kept as AssertionError (rather than ValueError) so existing callers
        # that catch it keep working; the message is new.
        if df_gru_skip not in _DF_GRU_SKIP_CHOICES:
            raise AssertionError(
                f"df_gru_skip must be one of {_DF_GRU_SKIP_CHOICES}, "
                f"got {df_gru_skip!r}"
            )

        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.hop_size = hop_size

        self.df_bins = df_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.df_order = df_order
        self.df_lookahead = df_lookahead
        self.norm_tau = norm_tau
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min

        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        self.emb_hidden_dim = emb_hidden_dim
        self.mask_post_filter = mask_post_filter
        self.df_hidden_dim = df_hidden_dim
        self.df_num_layers = df_num_layers
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_gru_skip = df_gru_skip
        self.post_filter_beta = post_filter_beta
        self.df_n_iter = df_n_iter
        self.lsnr_dropout = lsnr_dropout

        self.encoder_gru_skip_op = encoder_gru_skip_op
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups
        self.encoder_concat = encoder_concat

        self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op
        self.erb_decoder_linear_groups = erb_decoder_linear_groups
        self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers

        self.df_decoder_linear_groups = df_decoder_linear_groups


if __name__ == "__main__":
    # No-op script entry point; the module only defines DfNetConfig.
    pass