#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNetConfig(PretrainedConfig):
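    """
    Configuration for DfNet (judging by the deep-filtering, ERB, and LSNR
    parameters, a DeepFilterNet-style speech enhancement model).

    Parameters are grouped as in the assignments below: transform (STFT),
    spectrum, conv encoder, decoders, deep-filtering (df) decoder, local SNR
    (lsnr), data SNR range, and training settings. Defaults target 8 kHz
    audio with a 25 ms window (win_size=200) and a 10 ms hop (hop_size=80).
    """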
    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 200,
                 hop_size: int = 80,
                 win_type: str = "hann",

                 spec_bins: int = 256,
                 erb_bins: int = 32,
                 min_freq_bins_for_erb: int = 2,

                 conv_channels: int = 64,
                 conv_kernel_size_input: Tuple[int, int] = (3, 3),
                 conv_kernel_size_inner: Tuple[int, int] = (1, 3),
                 conv_lookahead: int = 0,

                 convt_kernel_size_inner: Tuple[int, int] = (1, 3),

                 embedding_hidden_size: int = 256,
                 encoder_combine_op: str = "concat",

                 encoder_emb_skip_op: str = "none",
                 encoder_emb_linear_groups: int = 16,
                 encoder_emb_hidden_size: int = 256,

                 encoder_linear_groups: int = 32,

                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
                 decoder_emb_hidden_size: int = 256,

                 df_decoder_hidden_size: int = 256,
                 df_num_layers: int = 2,
                 df_order: int = 5,
                 df_bins: int = 96,
                 df_gru_skip: str = "grouped_linear",
                 df_decoder_linear_groups: int = 16,
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,

                 n_frame: int = 3,
                 max_local_snr: int = 30,
                 min_local_snr: int = -15,
                 norm_tau: float = 1.,

                 min_snr_db: float = -10,
                 max_snr_db: float = 20,

                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 lr_scheduler_kwargs: Optional[dict] = None,

                 max_epochs: int = 100,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,

                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,

                 use_post_filter: bool = False,

                 **kwargs
                 ):
        super().__init__(**kwargs)
        # transform
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # spectrum
        self.spec_bins = spec_bins
        self.erb_bins = erb_bins
        self.min_freq_bins_for_erb = min_freq_bins_for_erb

        # conv
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        self.convt_kernel_size_inner = convt_kernel_size_inner

        self.embedding_hidden_size = embedding_hidden_size

        # encoder
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size

        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op

        # decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # df decoder
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # lsnr
        self.n_frame = n_frame
        self.max_local_snr = max_local_snr
        self.min_local_snr = min_local_snr
        self.norm_tau = norm_tau

        # data snr
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        # train
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

        self.max_epochs = max_epochs
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps

        # runtime
        self.use_post_filter = use_post_filter


if __name__ == "__main__":
    pass
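
    # A minimal usage sketch (editorial addition): construct the config with a
    # couple of overrides and inspect the derived frame timing. Assumes only
    # the defaults defined above and PretrainedConfig's **kwargs handling.
    config = DfNetConfig(sample_rate=8000, batch_size=8)
    # win_size/hop_size are in samples: 200/8000 = 25 ms window, 80/8000 = 10 ms hop.
    print(config.win_size / config.sample_rate, config.hop_size / config.sample_rate)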