File size: 2,142 Bytes
f74ae8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f69c753
 
 
f74ae8e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f69c753
 
 
f74ae8e
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class MPNetConfig(PretrainedConfig):
    """
    Hyper-parameter container for MPNet, built on ``PretrainedConfig``.

    Reference configuration:
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json

    Each named constructor argument is stored unchanged as an instance
    attribute of the same name. Any additional keyword arguments are
    forwarded to ``PretrainedConfig.__init__``.

    ``dist_config`` is the one exception: a falsy value (``None``, which is
    the default, or an empty dict) is replaced by a built-in fallback dict
    holding the keys ``dist_backend``, ``dist_url`` and ``world_size``.
    """

    def __init__(
        self,
        num_gpus: int = 0,
        batch_size: int = 4,
        learning_rate: float = 0.0005,
        adam_b1: float = 0.8,
        adam_b2: float = 0.99,
        lr_decay: float = 0.99,
        seed: int = 1234,
        dense_channel: int = 64,
        compress_factor: float = 0.3,
        num_tsconformers: int = 4,
        beta: float = 2.0,
        sample_rate: int = 16000,
        segment_size: int = 32000,
        n_fft: int = 400,
        hop_size: int = 100,
        win_size: int = 400,
        num_workers: int = 4,
        dist_config: dict = None,
        discriminator_dim: int = 32,
        discriminator_in_channel: int = 2,
        **kwargs
    ):
        # Unrecognised keyword arguments are handled by the base class.
        super().__init__(**kwargs)

        # Attributes are assigned in signature order; keep that order stable
        # in case the base class relies on attribute insertion order.
        self.num_gpus = num_gpus
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed

        self.dense_channel = dense_channel
        self.compress_factor = compress_factor
        self.num_tsconformers = num_tsconformers
        self.beta = beta

        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size

        self.num_workers = num_workers

        # A fresh fallback dict is built per instance whenever the caller
        # supplied nothing (or something falsy, such as an empty dict).
        if not dist_config:
            dist_config = {
                "dist_backend": "nccl",
                "dist_url": "tcp://localhost:54321",
                "world_size": 1
            }
        self.dist_config = dist_config

        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel


# Script entry point: currently a no-op, so running this module directly
# does nothing beyond defining MPNetConfig.
if __name__ == "__main__":
    pass