# Copyright 2024 The YourMT3 Authors. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Please see the details in the LICENSE file. """init_train.py""" from typing import Tuple, Literal, Any from copy import deepcopy import os import argparse import pytorch_lightning as pl from pytorch_lightning.loggers import WandbLogger from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.callbacks import LearningRateMonitor from pytorch_lightning.utilities import rank_zero_only from config.config import shared_cfg as default_shared_cfg from config.config import audio_cfg as default_audio_cfg from config.config import model_cfg as default_model_cfg from config.config import DEEPSPEED_CFG def initialize_trainer(args: argparse.Namespace, stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict]: """Initialize trainer and logger""" shared_cfg = deepcopy(default_shared_cfg) # create save dir os.makedirs(shared_cfg["WANDB"]["save_dir"], exist_ok=True) # collecting specific checkpoint from exp_id with extension (@xxx where xxx is checkpoint name) if "@" in args.exp_id: args.exp_id, checkpoint_name = args.exp_id.split("@") else: checkpoint_name = "last.ckpt" # checkpoint dir lightning_dir = os.path.join(shared_cfg["WANDB"]["save_dir"], args.project, args.exp_id) # create logger if args.wandb_mode is not None: shared_cfg["WANDB"]["mode"] = str(args.wandb_mode) if shared_cfg["WANDB"].get("cache_dir", None) is not None: os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir") del shared_cfg["WANDB"]["cache_dir"] # remove cache_dir from shared_cfg wandb_logger = WandbLogger(log_model="all", project=args.project, id=args.exp_id, allow_val_change=True, **shared_cfg['WANDB']) # check if any checkpoint exists last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name) if os.path.exists(os.path.join(last_ckpt_path)): print(f'Resuming from {last_ckpt_path}') elif stage == 'train': print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch') last_ckpt_path = None else: raise ValueError(f'No checkpoint found in {last_ckpt_path}. Quit...') # add info dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path) # define checkpoint callback checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"],) # define lr scheduler monitor callback lr_monitor = LearningRateMonitor(logging_interval='step') # deepspeed strategy if args.strategy == 'deepspeed': strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG) # validation interval if stage == 'train' and args.val_interval is not None: shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval) # define trainer sync_batchnorm = False if stage == 'train': # train batch size if args.train_batch_size is not None: train_sub_bsz = int(args.train_batch_size[0]) train_local_bsz = int(args.train_batch_size[1]) if train_local_bsz % train_sub_bsz == 0: shared_cfg["BSZ"]["train_sub"] = train_sub_bsz shared_cfg["BSZ"]["train_local"] = train_local_bsz else: raise ValueError( f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}') # ddp strategy if args.strategy == 'ddp': args.strategy = 'ddp_find_unused_parameters_true' # fix for conformer or pitchshifter having unused parameter issue # sync-batchnorm if args.sync_batchnorm is True: sync_batchnorm = True train_params = dict(**shared_cfg["TRAINER"], devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus), num_nodes=int(args.num_nodes), strategy=strategy if args.strategy == 'deepspeed' else args.strategy, precision=args.precision, max_epochs=args.max_epochs if stage == 'train' else None, max_steps=args.max_steps if stage == 'train' else -1, logger=wandb_logger, callbacks=[checkpoint_callback, lr_monitor], sync_batchnorm=sync_batchnorm) trainer = pl.trainer.trainer.Trainer(**train_params) # Update wandb logger (for DDP) if trainer.global_rank == 0: wandb_logger.experiment.config.update(args, allow_val_change=True) return trainer, wandb_logger, dir_info, shared_cfg def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'): """Update audio/model/shared configurations with args""" audio_cfg = default_audio_cfg model_cfg = default_model_cfg # Only update config when training if stage == 'train': # Augmentation parameters if args.random_amp_range is not None: shared_cfg["AUGMENTATION"]["train_random_amp_range"] = list( (float(args.random_amp_range[0]), float(args.random_amp_range[1]))) if args.stem_iaug_prob is not None: shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] = float(args.stem_iaug_prob) if args.xaug_max_k is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["max_k"] = int(args.xaug_max_k) if args.xaug_tau is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["tau"] = float(args.xaug_tau) if args.xaug_alpha is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["alpha"] = float(args.xaug_alpha) if args.xaug_no_instr_overlap is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_instr_overlap"] = bool(args.xaug_no_instr_overlap) if args.xaug_no_drum_overlap is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_drum_overlap"] = bool(args.xaug_no_drum_overlap) if args.uhat_intra_stem_augment is not None: shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["uhat_intra_stem_augment"] = bool( args.uhat_intra_stem_augment) if args.pitch_shift_range is not None: if args.pitch_shift_range in [["0", "0"], [0, 0]]: shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = None else: shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = list( (int(args.pitch_shift_range[0]), int(args.pitch_shift_range[1]))) train_stem_iaug_prob = shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] random_amp_range = shared_cfg["AUGMENTATION"]["train_random_amp_range"] train_stem_xaug_policy = shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"] print(f'Random amp range: {random_amp_range}\n' + f'Intra-stem augmentation probability: {train_stem_iaug_prob}\n' + f'Stem augmentation policy: {train_stem_xaug_policy}\n' + f'Pitch shift range: {shared_cfg["AUGMENTATION"]["train_pitch_shift_range"]}\n') # Update audio config if args.audio_codec != None: assert args.audio_codec in ['spec', 'melspec'] audio_cfg["codec"] = str(args.audio_codec) if args.hop_length != None: audio_cfg["hop_length"] = int(args.hop_length) if args.n_mels != None: audio_cfg["n_mels"] = int(args.n_mels) if args.input_frames != None: audio_cfg["input_frames"] = int(args.input_frames) # Update shared config if shared_cfg["TOKENIZER"]["max_shift_steps"] == "auto": shift_steps_ms = shared_cfg["TOKENIZER"]["shift_step_ms"] input_frames = audio_cfg["input_frames"] fs = audio_cfg["sample_rate"] max_shift_steps = (input_frames / fs) // (shift_steps_ms / 1000) + 2 # 206 by default shared_cfg["TOKENIZER"]["max_shift_steps"] = int(max_shift_steps) # Update model config if args.encoder_type != None: model_cfg["encoder_type"] = str(args.encoder_type) if args.decoder_type != None: model_cfg["decoder_type"] = str(args.decoder_type) if args.pre_encoder_type != "default": model_cfg["pre_encoder_type"] = str(args.pre_encoder_type) if args.pre_decoder_type != 'default': model_cfg["pre_decoder_type"] = str(args.pre_decoder_type) if args.conv_out_channels != None: model_cfg["conv_out_channels"] = int(args.conv_out_channels) assert isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool) model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder if args.encoder_position_encoding_type != 'default': if args.encoder_position_encoding_type in ['None', 'none', '0']: model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None elif args.encoder_position_encoding_type in [ 'sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt' ]: model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str( args.encoder_position_encoding_type) else: raise ValueError(f'Encoder PE type {args.encoder_position_encoding_type} not supported') if args.decoder_position_encoding_type != 'default': if args.decoder_position_encoding_type in ['None', 'none', '0']: raise ValueError('Decoder PE type cannot be None') elif args.decoder_position_encoding_type in ['sinusoidal', 'trainable']: model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str( args.decoder_position_encoding_type) else: raise ValueError(f'Decoder PE {args.decoder_position_encoding_type} not supported') if args.tie_word_embedding is not None: model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding) if args.d_feat != None: model_cfg["d_feat"] = int(args.d_feat) if args.d_latent != None: model_cfg['encoder']['perceiver-tf']["d_latent"] = int(args.d_latent) if args.num_latents != None: model_cfg['encoder']['perceiver-tf']['num_latents'] = int(args.num_latents) if args.perceiver_tf_d_model != None: model_cfg['encoder']['perceiver-tf']['d_model'] = int(args.perceiver_tf_d_model) if args.num_perceiver_tf_blocks != None: model_cfg["encoder"]["perceiver-tf"]["num_blocks"] = int(args.num_perceiver_tf_blocks) if args.num_perceiver_tf_local_transformers_per_block != None: model_cfg["encoder"]["perceiver-tf"]["num_local_transformers_per_block"] = int( args.num_perceiver_tf_local_transformers_per_block) if args.num_perceiver_tf_temporal_transformers_per_block != None: model_cfg["encoder"]["perceiver-tf"]["num_temporal_transformers_per_block"] = int( args.num_perceiver_tf_temporal_transformers_per_block) if args.attention_to_channel != None: model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = bool(args.attention_to_channel) if args.sca_use_query_residual != None: model_cfg["encoder"]["perceiver-tf"]["sca_use_query_residual"] = bool(args.sca_use_query_residual) if args.layer_norm_type != None: model_cfg["encoder"]["perceiver-tf"]["layer_norm"] = str(args.layer_norm_type) if args.ff_layer_type != None: model_cfg["encoder"]["perceiver-tf"]["ff_layer_type"] = str(args.ff_layer_type) if args.ff_widening_factor != None: model_cfg["encoder"]["perceiver-tf"]["ff_widening_factor"] = int(args.ff_widening_factor) if args.moe_num_experts != None: model_cfg["encoder"]["perceiver-tf"]["moe_num_experts"] = int(args.moe_num_experts) if args.moe_topk != None: model_cfg["encoder"]["perceiver-tf"]["moe_topk"] = int(args.moe_topk) if args.hidden_act != None: model_cfg["encoder"]["perceiver-tf"]["hidden_act"] = str(args.hidden_act) if args.rotary_type != None: assert len( args.rotary_type ) == 3, "rotary_type must be a 3-letter string (e.g. 'ppl': 'pixel' for SCA, 'pixel' for latent, 'lang' for temporal transformer)" model_cfg["encoder"]["perceiver-tf"]["rotary_type_sca"] = str(args.rotary_type)[0] model_cfg["encoder"]["perceiver-tf"]["rotary_type_latent"] = str(args.rotary_type)[1] model_cfg["encoder"]["perceiver-tf"]["rotary_type_temporal"] = str(args.rotary_type)[2] if args.rope_apply_to_keys != None: model_cfg["encoder"]["perceiver-tf"]["rope_apply_to_keys"] = bool(args.rope_apply_to_keys) if args.rope_partial_pe != None: model_cfg["encoder"]["perceiver-tf"]["rope_partial_pe"] = bool(args.rope_partial_pe) if args.decoder_ff_layer_type != None: model_cfg["decoder"][model_cfg["decoder_type"]]["ff_layer_type"] = str(args.decoder_ff_layer_type) if args.decoder_ff_widening_factor != None: model_cfg["decoder"][model_cfg["decoder_type"]]["ff_widening_factor"] = int(args.decoder_ff_widening_factor) if args.event_length != None: model_cfg["event_length"] = int(args.event_length) if stage == 'train': if args.encoder_dropout_rate != None: model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate) if args.decoder_dropout_rate != None: model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate) return shared_cfg, audio_cfg, model_cfg # return updated configs