# Spaces: Running on Zero
# Environment Settings
# These settings specify the hardware and distributed setup for the model training.
# Adjust `num_gpus` and `dist_cfg` according to your distributed training environment.
env_setting:
  num_gpus: 2  # Number of GPUs. CPU mode is not supported yet.
  num_workers: 20  # Number of worker threads for data loading.
  seed: 1234  # Seed for random number generators to ensure reproducibility.
  # Reporting/saving cadences. Presumably all counted in training steps, like
  # checkpoint_interval -- TODO confirm against the training script.
  stdout_interval: 10
  checkpoint_interval: 1000  # Save model to ckpt every N steps.
  validation_interval: 1000
  summary_interval: 100
  # NOTE(review): nesting reconstructed after indentation was lost; confirm the
  # loader reads these keys as env_setting.dist_cfg.*.
  dist_cfg:
    dist_backend: nccl  # Distributed training backend, 'nccl' for NVIDIA GPUs.
    dist_url: 'tcp://localhost:19477'  # URL for initializing distributed training.
    world_size: 1  # Total number of processes in the distributed training.
# Datapath Configuration
# JSON files for the clean and noisy sides of the train, validation and test
# splits. Paths are relative (presumably to the working directory -- confirm).
data_cfg:
  train_clean_json: data/train_clean.json
  train_noisy_json: data/train_noisy.json
  valid_clean_json: data/valid_clean.json
  valid_noisy_json: data/valid_noisy.json
  test_clean_json: data/test_clean.json
  test_noisy_json: data/test_noisy.json
# Training Configuration
# This section details parameters that directly influence the training process,
# including batch sizes, learning rates, and optimizer specifics.
training_cfg:
  training_epochs: 200  # Number of training epochs.
  batch_size: 4  # Training batch size.
  learning_rate: 0.0005  # Initial learning rate.
  adam_b1: 0.8  # Beta1 hyperparameter for the AdamW optimizer.
  adam_b2: 0.99  # Beta2 hyperparameter for the AdamW optimizer.
  lr_decay: 0.99  # Learning rate decay per epoch.
  segment_size: 32000  # Audio segment size used during training, dependent on sampling rate.
  # Presumably the weight of each loss term in the total objective -- TODO confirm.
  # NOTE(review): nesting reconstructed after indentation was lost; confirm the
  # loader reads these keys as training_cfg.loss.*.
  loss:
    metric: 0.05
    magnitude: 0.9
    phase: 0.3
    complex: 0.1
    time: 0.2
    # Spelling kept verbatim: code reading this file may look the key up by
    # this exact name.
    consistancy: 0.1
  use_PCS400: false  # Use PCS or not.
# STFT Configuration
# Configuration for Short-Time Fourier Transform (STFT), crucial for audio processing models.
stft_cfg:
  sampling_rate: 16000  # Audio sampling rate in Hz.
  n_fft: 400  # FFT components for transforming audio signals.
  hop_size: 100  # Samples between successive frames.
  win_size: 400  # Window size used in FFT.
# Model Configuration
# Defines the architecture specifics of the model, including layer configurations
# and feature compression.
model_cfg:
  hid_feature: 64  # Channels in dense layers.
  compress_factor: 0.3  # Compression factor applied to extracted features.
  num_tfmamba: 4  # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
  d_state: 16  # Dimensionality of the state vector in Mamba blocks.
  d_conv: 4  # Convolutional layer dimensionality within Mamba blocks.
  expand: 4  # Expansion factor for the layers within the Mamba blocks.
  # Kept in plain decimal notation: some YAML parsers do not read exponent
  # forms such as 1e-5 as floats.
  norm_epsilon: 0.00001  # Numerical stability in normalization layers within the Mamba blocks.
  beta: 2.0  # Hyperparameter for the Learnable Sigmoid function.
  input_channel: 2  # Magnitude and Phase.
  output_channel: 1  # Single Channel Speech Enhancement.