#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Configuration class for an LSTM-based denoising / speech-enhancement model.

Bundles STFT, model, noise-mixing and training hyper-parameters into a single
``PretrainedConfig`` subclass so they can be serialized (e.g. to YAML) and
reloaded. Run as a script, it writes a default ``config.yaml``.
"""
from typing import Optional

from toolbox.torchaudio.configuration_utils import PretrainedConfig


class LstmConfig(PretrainedConfig):
    """Hyper-parameter container for the LSTM model.

    Parameters
    ----------
    sample_rate:
        Audio sample rate in Hz.
    segment_size:
        Number of samples per training segment (32000 = 4 s at 8 kHz).
    nfft, win_size, hop_size, win_type:
        STFT analysis parameters (FFT size, window length, hop length,
        window function name).
    hidden_size, num_layers, dropout:
        LSTM architecture parameters.
    min_snr_db, max_snr_db:
        SNR range (dB) used when mixing noise into clean speech.
    max_epochs, batch_size, num_workers, seed:
        Training-loop parameters.
    lr, lr_scheduler, lr_scheduler_kwargs, weight_decay, clip_grad_norm:
        Optimizer / scheduler parameters. ``lr_scheduler_kwargs`` defaults
        to an empty dict when ``None`` is given.
    eval_steps:
        Run evaluation every this many training steps.
    **kwargs:
        Forwarded to ``PretrainedConfig``.
    """

    def __init__(self,
                 sample_rate: int = 8000,
                 segment_size: int = 32000,
                 nfft: int = 512,
                 win_size: int = 512,
                 hop_size: int = 256,
                 win_type: str = "hann",

                 hidden_size: int = 1024,
                 num_layers: int = 2,
                 dropout: float = 0.2,

                 min_snr_db: float = -10,
                 max_snr_db: float = 20,

                 max_epochs: int = 100,
                 batch_size: int = 4,
                 num_workers: int = 4,
                 seed: int = 1234,

                 lr: float = 0.001,
                 lr_scheduler: str = "CosineAnnealingLR",
                 # Annotated Optional: the default really is None, not a dict.
                 lr_scheduler_kwargs: Optional[dict] = None,

                 weight_decay: float = 0.00001,
                 clip_grad_norm: float = 10.,
                 eval_steps: int = 25000,

                 **kwargs
                 ):
        # Python-3 zero-argument super() — file is python3-only (see shebang).
        super().__init__(**kwargs)

        # --- signal / STFT parameters ---
        self.sample_rate = sample_rate
        self.segment_size = segment_size
        self.nfft = nfft
        self.win_size = win_size
        self.hop_size = hop_size
        self.win_type = win_type

        # --- model architecture ---
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout

        # --- noise-mixing SNR range (dB) ---
        self.min_snr_db = min_snr_db
        self.max_snr_db = max_snr_db

        # --- training loop ---
        self.max_epochs = max_epochs
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

        # --- optimizer / scheduler ---
        self.lr = lr
        self.lr_scheduler = lr_scheduler
        # Explicit None check (not truthiness) for the mutable-default pattern.
        self.lr_scheduler_kwargs = dict() if lr_scheduler_kwargs is None else lr_scheduler_kwargs
        self.weight_decay = weight_decay
        self.clip_grad_norm = clip_grad_norm

        self.eval_steps = eval_steps


def main():
    """Write a default configuration to ``config.yaml``."""
    config = LstmConfig()
    config.to_yaml_file("config.yaml")
    return


if __name__ == "__main__":
    main()