File size: 3,408 Bytes
bd94e77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Any, Dict, List, Tuple, Union

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNetConfig(PretrainedConfig):
    """
    Configuration container for a DfNet (DeepFilterNet-style) speech-enhancement model.

    Groups the STFT / ERB front-end settings, deep-filtering (DF) parameters,
    convolutional encoder settings, and encoder/decoder architecture
    hyper-parameters. All constructor arguments are stored verbatim as
    instance attributes; extra keyword arguments are forwarded to
    ``PretrainedConfig``.

    :param sample_rate: audio sample rate in Hz.
    :param fft_size: STFT window / FFT length in samples.
    :param hop_size: STFT hop length in samples.
    :param df_bins: number of frequency bins processed by the deep-filter pathway.
    :param erb_bins: number of ERB bands for the magnitude pathway.
    :param min_freq_bins_for_erb: minimum number of raw frequency bins per ERB band.
    :param df_order: deep-filter order (taps per frequency bin).
    :param df_lookahead: lookahead frames for the deep filter.
    :param norm_tau: time constant for exponential normalization
        (annotated ``int`` here; presumably seconds — a float may also be
        valid, confirm against callers).
    :param lsnr_max: upper clamp for the local SNR estimate (dB).
    :param lsnr_min: lower clamp for the local SNR estimate (dB).
    :param conv_channels: channel count of the convolutional encoder/decoder.
    :param conv_kernel_size_input: kernel size of the input convolution.
    :param conv_kernel_size_inner: kernel size of the inner convolutions.
    :param convt_kernel_size_inner: kernel size of the inner transposed convolutions.
    :param conv_lookahead: lookahead frames introduced by the convolutions.
    :param emb_hidden_dim: hidden dimension of the embedding GRU.
    :param mask_post_filter: whether to apply a post-filter to the ERB mask.
    :param df_hidden_dim: hidden dimension of the DF decoder GRU.
    :param df_num_layers: number of layers in the DF decoder GRU.
    :param df_pathway_kernel_size_t: temporal kernel size in the DF pathway.
    :param df_gru_skip: skip-connection type around the DF GRU; one of
        ``"none"``, ``"identity"``, ``"grouped_linear"``.
    :param post_filter_beta: beta coefficient of the mask post-filter.
    :param df_n_iter: number of deep-filter iterations (annotated ``float``
        here; an integer count would be expected — confirm against callers).
    :param lsnr_dropout: whether to apply LSNR-driven dropout.
    :param encoder_gru_skip_op: skip-connection type in the encoder GRU.
    :param encoder_linear_groups: group count for the encoder grouped linear layers.
    :param encoder_squeezed_gru_linear_groups: group count inside the encoder squeezed GRU.
    :param encoder_concat: whether the encoder concatenates (instead of adds) pathway features.
    :param erb_decoder_gru_skip_op: skip-connection type in the ERB decoder GRU.
    :param erb_decoder_linear_groups: group count for the ERB decoder grouped linear layers.
    :param erb_decoder_emb_num_layers: number of embedding GRU layers in the ERB decoder.
    :param df_decoder_linear_groups: group count for the DF decoder grouped linear layers.
    :raises AssertionError: if ``df_gru_skip`` is not one of the accepted values.
    """
    def __init__(self,
                 sample_rate: int,
                 fft_size: int,
                 hop_size: int,
                 df_bins: int,
                 erb_bins: int,
                 min_freq_bins_for_erb: int,
                 df_order: int,
                 df_lookahead: int,
                 norm_tau: int,
                 lsnr_max: int,
                 lsnr_min: int,
                 conv_channels: int,
                 conv_kernel_size_input: Tuple[int, int],
                 conv_kernel_size_inner: Tuple[int, int],
                 convt_kernel_size_inner: Tuple[int, int],
                 conv_lookahead: int,
                 emb_hidden_dim: int,
                 mask_post_filter: bool,
                 df_hidden_dim: int,
                 df_num_layers: int,
                 df_pathway_kernel_size_t: int,
                 df_gru_skip: str,
                 post_filter_beta: float,
                 df_n_iter: float,
                 lsnr_dropout: bool,
                 encoder_gru_skip_op: str,
                 encoder_linear_groups: int,
                 encoder_squeezed_gru_linear_groups: int,
                 encoder_concat: bool,
                 erb_decoder_gru_skip_op: str,
                 erb_decoder_linear_groups: int,
                 erb_decoder_emb_num_layers: int,
                 df_decoder_linear_groups: int,
                 **kwargs
                 ):
        super(DfNetConfig, self).__init__(**kwargs)
        # Validate early so a typo fails at construction time, not deep inside
        # model building. AssertionError is kept (not ValueError) so existing
        # callers that catch it keep working, but now with a diagnostic message.
        if df_gru_skip not in ("none", "identity", "grouped_linear"):
            raise AssertionError(
                "df_gru_skip must be one of 'none', 'identity', 'grouped_linear', "
                "got: {!r}".format(df_gru_skip)
            )

        # --- STFT / ERB front-end and LSNR settings ---
        self.sample_rate = sample_rate
        self.fft_size = fft_size
        self.hop_size = hop_size
        self.df_bins = df_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb
        self.df_order = df_order
        self.df_lookahead = df_lookahead
        self.norm_tau = norm_tau
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min

        # --- Convolutional encoder/decoder geometry ---
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.convt_kernel_size_inner = convt_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        # --- Embedding, deep-filter, and encoder hyper-parameters ---
        self.emb_hidden_dim = emb_hidden_dim
        self.mask_post_filter = mask_post_filter
        self.df_hidden_dim = df_hidden_dim
        self.df_num_layers = df_num_layers
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_gru_skip = df_gru_skip
        self.post_filter_beta = post_filter_beta
        self.df_n_iter = df_n_iter
        self.lsnr_dropout = lsnr_dropout
        self.encoder_gru_skip_op = encoder_gru_skip_op
        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_squeezed_gru_linear_groups = encoder_squeezed_gru_linear_groups
        self.encoder_concat = encoder_concat

        # --- ERB decoder ---
        self.erb_decoder_gru_skip_op = erb_decoder_gru_skip_op
        self.erb_decoder_linear_groups = erb_decoder_linear_groups
        self.erb_decoder_emb_num_layers = erb_decoder_emb_num_layers

        # --- DF decoder ---
        self.df_decoder_linear_groups = df_decoder_linear_groups


if __name__ == "__main__":
    # This module only defines the config class; no standalone demo is provided.
    pass