#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNetConfig(PretrainedConfig):
    """Configuration object holding the DfNet hyperparameters.

    Each named constructor argument is stored, unmodified, as an instance
    attribute with the same name. No validation or derived values are
    computed here. Any additional keyword arguments are forwarded to
    ``PretrainedConfig.__init__``.

    The arguments are organised in the following groups:

    * transform: ``sample_rate``, ``nfft``, ``win_size``, ``hop_size``,
      ``win_type``
    * spectrum: ``spec_bins``
    * conv: ``conv_channels``, ``conv_kernel_size_input``,
      ``conv_kernel_size_inner``, ``conv_lookahead``,
      ``convt_kernel_size_inner``, ``embedding_hidden_size``
    * encoder: ``encoder_combine_op``, ``encoder_emb_skip_op``,
      ``encoder_emb_linear_groups``, ``encoder_emb_hidden_size``,
      ``encoder_linear_groups``, ``lsnr_max``, ``lsnr_min``, ``norm_tau``
    * decoder: ``decoder_emb_num_layers``, ``decoder_emb_skip_op``,
      ``decoder_emb_linear_groups``, ``decoder_emb_hidden_size``
    * df decoder: ``df_decoder_hidden_size``, ``df_num_layers``,
      ``df_order``, ``df_bins``, ``df_gru_skip``,
      ``df_decoder_linear_groups``, ``df_pathway_kernel_size_t``,
      ``df_lookahead``
    * runtime: ``use_post_filter``

    NOTE(review): the units and valid ranges of these values are defined by
    the model code that consumes this config, not by this class -- confirm
    there before changing a default.
    """

    def __init__(
        self,
        sample_rate: int = 8000,
        nfft: int = 512,
        win_size: int = 200,
        hop_size: int = 80,
        win_type: str = "hann",

        spec_bins: int = 256,

        conv_channels: int = 64,
        conv_kernel_size_input: Tuple[int, int] = (3, 3),
        conv_kernel_size_inner: Tuple[int, int] = (1, 3),
        conv_lookahead: int = 0,
        convt_kernel_size_inner: Tuple[int, int] = (1, 3),
        embedding_hidden_size: int = 256,

        encoder_combine_op: str = "concat",
        encoder_emb_skip_op: str = "none",
        encoder_emb_linear_groups: int = 16,
        encoder_emb_hidden_size: int = 256,
        encoder_linear_groups: int = 32,
        lsnr_max: int = 30,
        lsnr_min: int = -15,
        norm_tau: float = 1.0,

        decoder_emb_num_layers: int = 3,
        decoder_emb_skip_op: str = "none",
        decoder_emb_linear_groups: int = 16,
        decoder_emb_hidden_size: int = 256,

        df_decoder_hidden_size: int = 256,
        df_num_layers: int = 2,
        df_order: int = 5,
        df_bins: int = 96,
        df_gru_skip: str = "grouped_linear",
        df_decoder_linear_groups: int = 16,
        df_pathway_kernel_size_t: int = 5,
        df_lookahead: int = 2,

        use_post_filter: bool = False,
        **kwargs
    ):
        # Let the base config consume whatever extra options it understands
        # before the model-specific attributes are set.
        super().__init__(**kwargs)

        # --- transform ---
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # --- spectrum ---
        self.spec_bins = spec_bins

        # --- conv ---
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.embedding_hidden_size = embedding_hidden_size

        # --- encoder ---
        # Attribute order is kept as-is (combine_op comes after the linear
        # group settings) so that the instance __dict__ order is unchanged.
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min
        self.norm_tau = norm_tau

        # --- decoder ---
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # --- df decoder ---
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # --- runtime ---
        self.use_post_filter = use_post_filter


if __name__ == "__main__":
    pass