# Speech_Enhancement_Mamba / recipes / SEMamba_advanced.yaml
# roychao19477
# Initial clean commit
# 7001051
# Environment Settings
# Hardware and distributed-training setup used for model training.
# Adjust `num_gpus` and `dist_cfg` to match your distributed training environment.
env_setting:
  num_gpus: 2  # Number of GPUs. CPU mode is currently not supported.
  num_workers: 20  # Number of worker threads for data loading.
  seed: 1234  # Seed for random number generators to ensure reproducibility.
  # NOTE(review): presumably "print progress every N steps" - confirm against the trainer.
  stdout_interval: 10
  checkpoint_interval: 1000  # Save model to ckpt every N steps.
  # NOTE(review): the two intervals below presumably use the same step unit as
  # checkpoint_interval - confirm against the trainer.
  validation_interval: 1000
  summary_interval: 100
  dist_cfg:
    dist_backend: nccl  # Distributed training backend, 'nccl' for NVIDIA GPUs.
    dist_url: tcp://localhost:19477  # URL for initializing distributed training.
    # NOTE(review): this is 1 while num_gpus is 2 - confirm how the trainer
    # combines the two values before changing either.
    world_size: 1  # Total number of processes in the distributed training.
# Datapath Configuration
# JSON files for the clean/noisy train, validation and test splits.
# NOTE(review): the relative paths are presumably resolved from the directory
# the training script is launched in - confirm against the data loader.
data_cfg:
  train_clean_json: data/train_clean.json
  train_noisy_json: data/train_noisy.json
  valid_clean_json: data/valid_clean.json
  valid_noisy_json: data/valid_noisy.json
  test_clean_json: data/test_clean.json
  test_noisy_json: data/test_noisy.json
# Training Configuration
# Parameters that directly influence the training process: epochs, batch size,
# learning rate, optimizer specifics and the loss configuration.
training_cfg:
  training_epochs: 200  # Number of training epochs.
  batch_size: 4  # Training batch size.
  learning_rate: 0.0005  # Initial learning rate.
  adam_b1: 0.8  # Beta1 hyperparameter for the AdamW optimizer.
  adam_b2: 0.99  # Beta2 hyperparameter for the AdamW optimizer.
  lr_decay: 0.99  # Learning rate decay per epoch.
  # Audio segment size used during training, dependent on sampling rate.
  segment_size: 32000
  # NOTE(review): the entries below are presumably per-term loss weights -
  # confirm against the training script.
  loss:
    metric: 0.05
    magnitude: 0.9
    phase: 0.3
    complex: 0.1
    time: 0.2
    # The key spelling "consistancy" is kept exactly as-is: it is looked up by
    # name, so renaming it here requires updating every consumer too.
    consistancy: 0.1
  # NOTE(review): placed at the training_cfg level (sibling of `loss`) because
  # it is a boolean switch, unlike the numeric entries under `loss` - confirm
  # against the code that reads it.
  use_PCS400: false  # Use PCS or not.
# STFT Configuration
# Settings for the Short-Time Fourier Transform (STFT) applied to the audio.
stft_cfg:
  sampling_rate: 16000  # Audio sampling rate in Hz.
  n_fft: 400  # FFT components for transforming audio signals.
  # 100 samples at 16000 Hz is a 6.25 ms step between frames.
  hop_size: 100  # Samples between successive frames.
  # NOTE(review): presumably in samples like hop_size (400 samples = 25 ms at
  # 16000 Hz) - confirm against the STFT call.
  win_size: 400  # Window size used in FFT.
# Model Configuration
# Architecture specifics of the model: layer sizes, Mamba block settings and
# feature compression.
model_cfg:
  hid_feature: 64  # Channels in dense layers.
  compress_factor: 0.3  # Compression factor applied to extracted features.
  num_tfmamba: 4  # Number of Time-Frequency Mamba (TFMamba) blocks in the model.
  d_state: 16  # Dimensionality of the state vector in Mamba blocks.
  d_conv: 4  # Convolutional layer dimensionality within Mamba blocks.
  expand: 4  # Expansion factor for the layers within the Mamba blocks.
  # Numerical stability in normalization layers within the Mamba blocks.
  norm_epsilon: 0.00001
  beta: 2.0  # Hyperparameter for the Learnable Sigmoid function.
  input_channel: 2  # Magnitude and Phase
  output_channel: 1  # Single Channel Speech Enhancement