#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class DfNetConfig(PretrainedConfig):
    """
    Configuration for a DfNet (DeepFilterNet-style) speech enhancement model:
    STFT/transform settings, encoder/decoder topology, deep-filtering
    parameters, and training hyperparameters.
    """
    def __init__(self,
                 sample_rate: int = 8000,
                 nfft: int = 512,
                 win_size: int = 200,
                 hop_size: int = 80,
                 win_type: str = "hann",

                 spec_bins: int = 256,

                 conv_channels: int = 64,
                 conv_kernel_size_input: Tuple[int, int] = (3, 3),
                 conv_kernel_size_inner: Tuple[int, int] = (1, 3),
                 conv_lookahead: int = 0,

                 convt_kernel_size_inner: Tuple[int, int] = (1, 3),

                 embedding_hidden_size: int = 256,
                 encoder_combine_op: str = "concat",

                 encoder_emb_skip_op: str = "none",
                 encoder_emb_linear_groups: int = 16,
                 encoder_emb_hidden_size: int = 256,

                 encoder_linear_groups: int = 32,

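                 # local SNR (LSNR) estimate, clamped to the [lsnr_min, lsnr_max] dB range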
                 lsnr_max: int = 30,
                 lsnr_min: int = -15,
                 norm_tau: float = 1.,

                 decoder_emb_num_layers: int = 3,
                 decoder_emb_skip_op: str = "none",
                 decoder_emb_linear_groups: int = 16,
                 decoder_emb_hidden_size: int = 256,

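                 # deep filtering (DF) decoder: predicts df_order complex filter taps
                 # per bin for the first df_bins frequency bins, using df_lookahead
                 # frames of lookahead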
                 df_decoder_hidden_size: int = 256,
                 df_num_layers: int = 2,
                 df_order: int = 5,
                 df_bins: int = 96,
                 df_gru_skip: str = "grouped_linear",
                 df_decoder_linear_groups: int = 16,
                 df_pathway_kernel_size_t: int = 5,
                 df_lookahead: int = 2,

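                 # optionally apply a post filter to the enhanced output at inference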
                 use_post_filter: bool = False,

                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 lr_scheduler_kwargs: Optional[dict] = None,

                 max_epochs: int = 100,
                 clip_grad_norm: float = 10.,
                 seed: int = 1234,

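                 # SNR range (dB) assumed for mixing noisy training examples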
                 min_snr_db: float = -10,
                 max_snr_db: float = 20,

                 num_workers: int = 4,
                 batch_size: int = 4,
                 eval_steps: int = 25000,

                 **kwargs
                 ):
        super().__init__(**kwargs)
        # transform
        self.sample_rate = sample_rate
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # spectrum
        self.spec_bins = spec_bins

        # conv
        self.conv_channels = conv_channels
        self.conv_kernel_size_input = conv_kernel_size_input
        self.conv_kernel_size_inner = conv_kernel_size_inner
        self.conv_lookahead = conv_lookahead

        self.convt_kernel_size_inner = convt_kernel_size_inner

        self.embedding_hidden_size = embedding_hidden_size

        # encoder
        self.encoder_emb_skip_op = encoder_emb_skip_op
        self.encoder_emb_linear_groups = encoder_emb_linear_groups
        self.encoder_emb_hidden_size = encoder_emb_hidden_size

        self.encoder_linear_groups = encoder_linear_groups
        self.encoder_combine_op = encoder_combine_op

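        # lsnr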
        self.lsnr_max = lsnr_max
        self.lsnr_min = lsnr_min
        self.norm_tau = norm_tau

        # decoder
        self.decoder_emb_num_layers = decoder_emb_num_layers
        self.decoder_emb_skip_op = decoder_emb_skip_op
        self.decoder_emb_linear_groups = decoder_emb_linear_groups
        self.decoder_emb_hidden_size = decoder_emb_hidden_size

        # df decoder
        self.df_decoder_hidden_size = df_decoder_hidden_size
        self.df_num_layers = df_num_layers
        self.df_order = df_order
        self.df_bins = df_bins
        self.df_gru_skip = df_gru_skip
        self.df_decoder_linear_groups = df_decoder_linear_groups
        self.df_pathway_kernel_size_t = df_pathway_kernel_size_t
        self.df_lookahead = df_lookahead

        # runtime
        self.use_post_filter = use_post_filter

        # training
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        self.lr_scheduler_kwargs = lr_scheduler_kwargs or dict()

        self.max_epochs = max_epochs
        self.clip_grad_norm = clip_grad_norm
        self.seed = seed

        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        self.num_workers = num_workers
        self.batch_size = batch_size
        self.eval_steps = eval_steps


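def main():
    # Minimal usage sketch: instantiate the config with its defaults and
    # inspect a few fields. Illustrative only; real training code would
    # pass task-specific values or load them from a file.
    config = DfNetConfig()
    print(f"sample_rate: {config.sample_rate}, nfft: {config.nfft}")
    print(f"win_size: {config.win_size}, hop_size: {config.hop_size}")
    print(f"df_order: {config.df_order}, df_bins: {config.df_bins}")
    return

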
if __name__ == "__main__":
    main()