Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig
class MPNetConfig(PretrainedConfig):
    """
    Hyperparameter container for the MP-SENet model.

    Default values mirror the reference configuration at
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json

    Args:
        num_gpus: Number of GPUs requested for training.
        batch_size: Training batch size.
        learning_rate: Optimizer learning rate.
        adam_b1: Presumably the Adam ``beta1`` coefficient (confirm
            against the trainer).
        adam_b2: Presumably the Adam ``beta2`` coefficient (confirm
            against the trainer).
        lr_decay: Learning-rate decay factor (schedule granularity is not
            visible here; TODO confirm).
        seed: Random seed.
        dense_channel: Channel width of the model's dense blocks
            (inferred from the name; confirm).
        compress_factor: Compression factor; presumably an exponent
            applied to the spectral magnitude (TODO confirm).
        num_tsconformers: Number of TS-Conformer blocks (inferred from
            the name).
        beta: Model hyperparameter ``beta``; its role is not visible in
            this file.
        sample_rate: Audio sample rate (presumably Hz).
        segment_size: Training segment length, presumably in samples.
        n_fft: FFT size, presumably for the STFT.
        hop_size: STFT hop length, presumably in samples.
        win_size: STFT window length, presumably in samples.
        num_workers: Number of data-loading workers (inferred from the
            name).
        dist_config: Distributed-training settings. When ``None``, a fresh
            default ``{"dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321", "world_size": 1}`` is
            used. A dict passed explicitly (even an empty one) is stored
            as given.
        discriminator_dim: Base width of the discriminator (inferred from
            the name).
        discriminator_in_channel: Number of input channels to the
            discriminator.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    def __init__(self,
                 num_gpus: int = 0,
                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,
                 dense_channel: int = 64,
                 compress_factor: float = 0.3,
                 num_tsconformers: int = 4,
                 beta: float = 2.0,
                 sample_rate: int = 16000,
                 segment_size: int = 32000,
                 n_fft: int = 400,
                 hop_size: int = 100,
                 win_size: int = 400,
                 num_workers: int = 4,
                 dist_config: Optional[dict] = None,
                 discriminator_dim: int = 32,
                 discriminator_in_channel: int = 2,
                 **kwargs
                 ):
        super().__init__(**kwargs)

        # Training / optimizer settings.
        self.num_gpus = num_gpus
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed

        # Model architecture settings.
        self.dense_channel = dense_channel
        self.compress_factor = compress_factor
        self.num_tsconformers = num_tsconformers
        self.beta = beta

        # Audio / framing settings.
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size

        self.num_workers = num_workers

        # Test against the None sentinel rather than truthiness, so an
        # explicitly supplied empty dict is not silently replaced by the
        # default. The default is built per call, so instances never
        # share one mutable dict.
        if dist_config is None:
            dist_config = {
                "dist_backend": "nccl",
                "dist_url": "tcp://localhost:54321",
                "world_size": 1
            }
        self.dist_config = dist_config

        # Discriminator settings.
        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel
if __name__ == "__main__":
    # This module only defines MPNetConfig; running it directly does nothing.
    ...