nx_denoise / toolbox /torchaudio /models /mpnet /configuration_mpnet.py
HoneyTian's picture
update
8ed9309
#!/usr/bin/python3
# -*- coding: utf-8 -*-
from typing import Optional, Tuple

from toolbox.torchaudio.configuration_utils import PretrainedConfig
class MPNetConfig(PretrainedConfig):
    """
    Configuration for the MP-SENet based model and its training run.

    Defaults mirror the upstream reference configuration:
    https://github.com/yxlu-0102/MP-SENet/blob/main/config.json

    Every constructor argument is stored as an attribute of the same name.
    ``dist_config`` is the one exception, because its defaults are merged in
    (see below). Any extra keyword arguments are forwarded unchanged to
    ``PretrainedConfig.__init__``.

    Training:
        num_gpus, batch_size, learning_rate, adam_b1, adam_b2, lr_decay,
        seed, num_workers.
        NOTE(review): adam_b1/adam_b2 look like the Adam beta coefficients and
        lr_decay like a learning-rate decay factor -- confirm against the
        training script that consumes this config.

    Model:
        dense_channel, compress_factor, num_tsconformers, beta,
        discriminator_dim, discriminator_in_channel.

    Audio / STFT:
        sample_rate, segment_size, n_fft, hop_size, win_size.
        NOTE(review): segment_size, n_fft, hop_size and win_size are presumably
        measured in samples -- TODO confirm.

    Distributed:
        dist_config: Optional mapping of distributed-training settings. The
            keys ``dist_backend``, ``dist_url`` and ``world_size`` default to
            ``"nccl"``, ``"tcp://localhost:54321"`` and ``1``. Keys supplied
            by the caller override those defaults, and extra keys are kept.
            ``None`` or an empty mapping yields exactly the defaults. The
            stored value is a new dict, so later changes to the caller's
            mapping do not leak into the config.
    """

    def __init__(self,
                 num_gpus: int = 0,
                 batch_size: int = 4,
                 learning_rate: float = 0.0005,
                 adam_b1: float = 0.8,
                 adam_b2: float = 0.99,
                 lr_decay: float = 0.99,
                 seed: int = 1234,
                 dense_channel: int = 64,
                 compress_factor: float = 0.3,
                 num_tsconformers: int = 4,
                 beta: float = 2.0,
                 sample_rate: int = 16000,
                 segment_size: int = 32000,
                 n_fft: int = 400,
                 hop_size: int = 100,
                 win_size: int = 400,
                 num_workers: int = 4,
                 dist_config: Optional[dict] = None,
                 discriminator_dim: int = 32,
                 discriminator_in_channel: int = 2,
                 **kwargs
                 ):
        super().__init__(**kwargs)

        # Training hyper-parameters.
        self.num_gpus = num_gpus
        self.batch_size = batch_size
        self.learning_rate = learning_rate
        self.adam_b1 = adam_b1
        self.adam_b2 = adam_b2
        self.lr_decay = lr_decay
        self.seed = seed

        # Model hyper-parameters.
        self.dense_channel = dense_channel
        self.compress_factor = compress_factor
        self.num_tsconformers = num_tsconformers
        self.beta = beta

        # Audio / STFT settings.
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.n_fft = n_fft
        self.hop_size = hop_size
        self.win_size = win_size

        self.num_workers = num_workers

        # Fill in any distributed settings the caller left out. A partial
        # override (e.g. only ``world_size``) then still yields a complete
        # mapping instead of silently losing the backend/url. The defaults
        # dict is rebuilt on every call, so instances never share it.
        default_dist_config = {
            "dist_backend": "nccl",
            "dist_url": "tcp://localhost:54321",
            "world_size": 1,
        }
        self.dist_config = {**default_dist_config, **(dist_config or {})}

        # Discriminator settings.
        self.discriminator_dim = discriminator_dim
        self.discriminator_in_channel = discriminator_in_channel
if __name__ == "__main__":
    # Running this file directly performs no action; it is meant to be imported.
    ...