#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class MPNetConfig(PretrainedConfig):
    """
    Configuration for the MP-SENet model.

    The default values mirror the upstream reference configuration:
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json

    Every constructor argument is stored unchanged as an attribute of the
    same name, except ``dist_config`` (see below). Any additional keyword
    arguments are forwarded to ``PretrainedConfig.__init__``.

    Args:
        num_gpus: Number of GPUs to use.
        batch_size: Training batch size.
        learning_rate: Initial learning rate.
        adam_b1: Adam ``beta1`` coefficient.
        adam_b2: Adam ``beta2`` coefficient.
        lr_decay: Learning-rate decay factor (presumably applied by an
            exponential scheduler -- confirm against the training script).
        seed: Random seed.
        dense_channel: Presumably the channel width of the model's dense
            blocks -- confirm against the model code.
        compress_factor: Presumably the power-law exponent used to compress
            the STFT magnitude -- confirm against the model code.
        num_tsconformers: Presumably the number of TS-Conformer blocks.
        beta: Model hyper-parameter whose meaning is defined by the model
            code (presumably the learnable-sigmoid beta -- confirm).
        sample_rate: Audio sample rate (16000 by default, presumably Hz).
        segment_size: Training segment length (32000 by default, i.e. 2 s
            at the default sample rate if measured in samples).
        n_fft: STFT FFT size.
        hop_size: STFT hop length.
        win_size: STFT window length.
        num_workers: Number of data-loading workers (presumably).
        dist_config: Distributed-training settings. When falsy (``None`` or
            an empty dict) it is replaced by a fresh default dict with keys
            ``dist_backend``, ``dist_url`` and ``world_size``.
        discriminator_dim: Presumably the discriminator's base channel width.
        discriminator_in_channel: Presumably the number of discriminator
            input channels.
        **kwargs: Forwarded to ``PretrainedConfig.__init__``.
    """
    def __init__(self,
                 num_gpus: int = 0,
                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,

                 dense_channel: int = 64,
                 compress_factor: float = 0.3,
                 num_tsconformers: int = 4,
                 beta: float = 2.0,

                 sample_rate: int = 16000,
                 segment_size: int = 32000,
                 n_fft: int = 400,
                 hop_size: int = 100,
                 win_size: int = 400,

                 num_workers: int = 4,

                 dist_config: Optional[dict] = None,

                 discriminator_dim: int = 32,
                 discriminator_in_channel: int = 2,

                 **kwargs
                 ):
        super(MPNetConfig, self).__init__(**kwargs)
        self.num_gpus = num_gpus
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed

        self.dense_channel = dense_channel
        self.compress_factor = compress_factor
        self.num_tsconformers = num_tsconformers
        self.beta = beta

        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size

        self.num_workers = num_workers

        # `or` (not `is None`): an empty dict also falls back to the defaults.
        # The default dict is built per call, so instances never share it.
        self.dist_config = dist_config or {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321",
            "world_size": 1
        }

        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel


if __name__ == "__main__":
    pass