#!/usr/bin/python3
# -*- coding: utf-8 -*-
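"""
Configuration for DfNet2, a DeepFilterNet2-style speech-enhancement model
(so named here by analogy with the parameter set; the description below
summarizes this file only). The hyperparameters cover the STFT front-end,
linear/ERB spectral features, the convolutional encoder and decoder, the
deep-filtering (DF) decoder, local SNR (LSNR) estimation, and training.
"""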
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNet2Config(PretrainedConfig):
    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 200,
                 hop_size: int = 80,
                 win_type: str = "hann",

                 spec_bins: int = 256,
                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,
                 use_ema_norm: bool = True,

                 conv_channels: int = 64,
                 conv_kernel_size_input: Tuple[int, int] = (3, 3),
                 conv_kernel_size_inner: Tuple[int, int] = (1, 3),

                 convt_kernel_size_inner: Tuple[int, int] = (1, 3),

                 embedding_hidden_size: int = 256,
                 encoder_combine_op: str = "concat",

                 encoder_emb_skip_op: str = "none",
                 encoder_emb_linear_groups: int = 16,
                 encoder_emb_hidden_size: int = 256,

                 encoder_linear_groups: int = 32,

                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
                 decoder_emb_hidden_size: int = 256,

                 df_decoder_hidden_size: int = 256,
                 df_num_layers: int = 2,
                 df_order: int = 5,
                 df_bins: int = 96,
                 df_gru_skip: str = "grouped_linear",
                 df_decoder_linear_groups: int = 16,
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,

                 n_frame: int = 3,
                 max_local_snr: int = 30,
                 min_local_snr: int = -15,
                 norm_tau: float = 1.,

                 min_snr_db: float = -10,
                 max_snr_db: float = 20,

                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 lr_scheduler_kwargs: Optional[dict] = None,

                 max_epochs: int = 100,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,

                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,

                 use_post_filter: bool = False,

                 **kwargs
                 ):
        super().__init__(**kwargs)
        # transform: STFT front-end (defaults: 25 ms window, 10 ms hop at 8 kHz)
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # spectrum: linear bins (spec_bins = nfft // 2 with the defaults) and ERB bands
        self.spec_bins = spec_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb

        self.use_ema_norm = use_ema_norm

        # conv
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner

        self.convt_kernel_size_inner = convt_kernel_size_inner

        self.embedding_hidden_size = embedding_hidden_size

        # encoder
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size

        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op

        # decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # df decoder: deep filtering (df_order taps over the df_bins lowest bins, df_lookahead frames of lookahead)
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # lsnr: local SNR estimation head
        self.n_frame = n_frame
        self.max_local_snr = max_local_snr
        self.min_local_snr = min_local_snr
        self.norm_tau = norm_tau

        # data snr: dB range for training mixtures
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        # train
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

        self.max_epochs = max_epochs
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps

        # runtime
        self.use_post_filter = use_post_filter


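# A minimal usage sketch (assumption: only this file and the imported
# PretrainedConfig base class are needed). It instantiates the config with
# one override and prints a few quantities derived from the defaults.
def main():
    config = DfNet2Config(use_post_filter=True)
    print(f"window:    {config.win_size / config.sample_rate * 1000:.1f} ms")
    print(f"hop:       {config.hop_size / config.sample_rate * 1000:.1f} ms")
    print(f"df filter: {config.df_order} taps x {config.df_bins} bins, "
          f"lookahead {config.df_lookahead} frames")
    return

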
if __name__ == "__main__":
    main()