# File: YourMT3/amt/src/model/init_train.py
# Commit a03c9b4 by mimbres (commit message: "."); file size 14.2 kB
# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""init_train.py"""
from typing import Tuple, Literal, Any
from copy import deepcopy
import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities import rank_zero_only
from config.config import shared_cfg as default_shared_cfg
from config.config import audio_cfg as default_audio_cfg
from config.config import model_cfg as default_model_cfg
from config.config import DEEPSPEED_CFG
def initialize_trainer(args: argparse.Namespace,
                       stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict, dict]:
    """Initialize the Lightning trainer and the W&B logger, and locate the resume checkpoint.

    Args:
        args: Parsed command-line arguments. This function reads ``exp_id``, ``project``,
            ``wandb_mode``, ``strategy``, ``val_interval``, ``train_batch_size``,
            ``sync_batchnorm``, ``num_gpus``, ``num_nodes``, ``precision``,
            ``max_epochs`` and ``max_steps``. Two attributes are rewritten in place:
            an ``exp_id`` of the form ``"<exp_id>@<checkpoint_name>"`` is reduced to
            ``<exp_id>``, and during training ``strategy='ddp'`` becomes
            ``'ddp_find_unused_parameters_true'``.
        stage: ``'train'`` or ``'test'``. Training starts from scratch when the
            checkpoint is missing; any other stage requires it to exist.

    Returns:
        ``(trainer, wandb_logger, dir_info, shared_cfg)``: ``dir_info`` holds
        ``lightning_dir`` and ``last_ckpt_path`` (``None`` when training from scratch);
        ``shared_cfg`` is a deep copy of the default shared config with the CLI
        overrides applied.

    Raises:
        ValueError: if the checkpoint is missing and ``stage != 'train'``, or if the
            sub batch size is not positive or does not divide the local batch size.
    """
    shared_cfg = deepcopy(default_shared_cfg)

    # create save dir
    os.makedirs(shared_cfg["WANDB"]["save_dir"], exist_ok=True)

    # "exp_id@checkpoint_name" selects a specific checkpoint; the default is last.ckpt.
    # partition() splits at the first '@' only (extra '@' stay in the checkpoint name),
    # and an empty suffix ("exp_id@") falls back to last.ckpt instead of resolving to
    # the checkpoints directory itself.
    exp_id, _, checkpoint_name = args.exp_id.partition("@")
    args.exp_id = exp_id
    checkpoint_name = checkpoint_name or "last.ckpt"

    # checkpoint dir
    lightning_dir = os.path.join(shared_cfg["WANDB"]["save_dir"], args.project, args.exp_id)

    # create logger
    if args.wandb_mode is not None:
        shared_cfg["WANDB"]["mode"] = str(args.wandb_mode)
    # cache_dir is exported through the WANDB_CACHE_DIR environment variable instead of
    # being forwarded to WandbLogger; pop() removes the key even when its value is None.
    cache_dir = shared_cfg["WANDB"].pop("cache_dir", None)
    if cache_dir is not None:
        os.environ["WANDB_CACHE_DIR"] = str(cache_dir)
    wandb_logger = WandbLogger(log_model="all",
                               project=args.project,
                               id=args.exp_id,
                               allow_val_change=True,
                               **shared_cfg["WANDB"])

    # check if the requested checkpoint exists (exists() accepts files and directories)
    last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
    if os.path.exists(last_ckpt_path):
        print(f'Resuming from {last_ckpt_path}')
    elif stage == 'train':
        print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch')
        last_ckpt_path = None
    else:
        raise ValueError(f'No checkpoint found in {last_ckpt_path}. Quit...')

    # add info
    dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path)

    # callbacks: checkpointing and per-step learning-rate logging
    checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"])
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # DeepSpeed needs a strategy object; any other strategy is passed to the Trainer by name
    deepspeed_strategy = None
    if args.strategy == 'deepspeed':
        deepspeed_strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG)

    # validation interval: validate every N steps instead of every N epochs
    if stage == 'train' and args.val_interval is not None:
        shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None
        shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval)

    sync_batchnorm = False
    if stage == 'train':
        # train batch size, given as (sub batch size, local batch size)
        if args.train_batch_size is not None:
            train_sub_bsz = int(args.train_batch_size[0])
            train_local_bsz = int(args.train_batch_size[1])
            if train_sub_bsz <= 0:
                raise ValueError(f'Sub batch size must be positive, got {train_sub_bsz}')
            if train_local_bsz % train_sub_bsz != 0:
                raise ValueError(
                    f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}')
            shared_cfg["BSZ"]["train_sub"] = train_sub_bsz
            shared_cfg["BSZ"]["train_local"] = train_local_bsz
        # ddp strategy
        if args.strategy == 'ddp':
            # fix for conformer or pitchshifter having unused parameter issue
            args.strategy = 'ddp_find_unused_parameters_true'
        # sync-batchnorm
        if args.sync_batchnorm is True:
            sync_batchnorm = True

    # define trainer
    train_params = dict(**shared_cfg["TRAINER"],
                        devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus),
                        num_nodes=int(args.num_nodes),
                        strategy=deepspeed_strategy if args.strategy == 'deepspeed' else args.strategy,
                        precision=args.precision,
                        max_epochs=args.max_epochs if stage == 'train' else None,
                        max_steps=args.max_steps if stage == 'train' else -1,
                        logger=wandb_logger,
                        callbacks=[checkpoint_callback, lr_monitor],
                        sync_batchnorm=sync_batchnorm)
    trainer = pl.Trainer(**train_params)

    # Update wandb logger (for DDP): only rank 0 pushes the CLI arguments to the run config
    if trainer.global_rank == 0:
        wandb_logger.experiment.config.update(args, allow_val_change=True)
    return trainer, wandb_logger, dir_info, shared_cfg
def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'):
    """Update audio/model/shared configurations with command-line overrides.

    Args:
        args: Parsed command-line arguments (``argparse.Namespace``). An attribute left
            at ``None`` (or ``'default'`` for the pre-encoder/pre-decoder and
            position-encoding types) keeps the configured value.
        shared_cfg: Shared config dict; updated in place and returned.
        stage: ``'train'`` or ``'test'``. Augmentation and dropout overrides are only
            applied when training; audio, tokenizer and model overrides are applied in
            both stages.

    Returns:
        ``(shared_cfg, audio_cfg, model_cfg)`` with the overrides applied.

    Raises:
        ValueError: for an unsupported audio codec or position-encoding type, or a
            ``rotary_type`` whose length is not 3.
        TypeError: if ``args.task_cond_encoder`` / ``args.task_cond_decoder`` is not a bool.

    NOTE(review): ``audio_cfg`` and ``model_cfg`` are the module-level defaults imported
    from ``config.config`` (not copies), so these overrides mutate those objects for the
    whole process, unlike ``shared_cfg`` which ``initialize_trainer`` deep-copies. Kept
    as-is in case other modules rely on reading the updated globals.
    """
    audio_cfg = default_audio_cfg
    model_cfg = default_model_cfg

    # Augmentation overrides only apply when training
    if stage == 'train':
        _update_augmentation_cfg(args, shared_cfg)
    _update_audio_cfg(args, audio_cfg)
    # Depends on the (possibly overridden) audio input length, so it runs after the audio update
    _resolve_max_shift_steps(shared_cfg, audio_cfg)
    _update_model_cfg(args, model_cfg, stage)
    return shared_cfg, audio_cfg, model_cfg  # return updated configs


def _update_augmentation_cfg(args, shared_cfg):
    """Apply training-time augmentation overrides to ``shared_cfg["AUGMENTATION"]`` and print the result."""
    aug_cfg = shared_cfg["AUGMENTATION"]
    xaug_policy = aug_cfg["train_stem_xaug_policy"]
    if args.random_amp_range is not None:
        aug_cfg["train_random_amp_range"] = [float(args.random_amp_range[0]), float(args.random_amp_range[1])]
    if args.stem_iaug_prob is not None:
        aug_cfg["train_stem_iaug_prob"] = float(args.stem_iaug_prob)
    # NOTE(review): bool() of any non-empty string is True; the flags below assume argparse
    # already delivers real booleans (or None) — confirm against the argument parser.
    if args.xaug_max_k is not None:
        xaug_policy["max_k"] = int(args.xaug_max_k)
    if args.xaug_tau is not None:
        xaug_policy["tau"] = float(args.xaug_tau)
    if args.xaug_alpha is not None:
        xaug_policy["alpha"] = float(args.xaug_alpha)
    if args.xaug_no_instr_overlap is not None:
        xaug_policy["no_instr_overlap"] = bool(args.xaug_no_instr_overlap)
    if args.xaug_no_drum_overlap is not None:
        xaug_policy["no_drum_overlap"] = bool(args.xaug_no_drum_overlap)
    if args.uhat_intra_stem_augment is not None:
        xaug_policy["uhat_intra_stem_augment"] = bool(args.uhat_intra_stem_augment)
    if args.pitch_shift_range is not None:
        low, high = int(args.pitch_shift_range[0]), int(args.pitch_shift_range[1])
        # a (0, 0) range — given as ints or as strings — disables pitch shifting
        aug_cfg["train_pitch_shift_range"] = None if low == 0 and high == 0 else [low, high]

    print(f'Random amp range: {aug_cfg["train_random_amp_range"]}\n' +
          f'Intra-stem augmentation probability: {aug_cfg["train_stem_iaug_prob"]}\n' +
          f'Stem augmentation policy: {aug_cfg["train_stem_xaug_policy"]}\n' +
          f'Pitch shift range: {aug_cfg["train_pitch_shift_range"]}\n')


def _update_audio_cfg(args, audio_cfg):
    """Apply codec / hop length / mel-bin / input-frame overrides to ``audio_cfg`` in place."""
    if args.audio_codec is not None:
        if args.audio_codec not in ('spec', 'melspec'):
            raise ValueError(f"Unsupported audio codec {args.audio_codec!r}; expected 'spec' or 'melspec'")
        audio_cfg["codec"] = str(args.audio_codec)
    if args.hop_length is not None:
        audio_cfg["hop_length"] = int(args.hop_length)
    if args.n_mels is not None:
        audio_cfg["n_mels"] = int(args.n_mels)
    if args.input_frames is not None:
        audio_cfg["input_frames"] = int(args.input_frames)


def _resolve_max_shift_steps(shared_cfg, audio_cfg):
    """Replace an ``"auto"`` ``max_shift_steps`` with a value derived from the input segment length."""
    tokenizer_cfg = shared_cfg["TOKENIZER"]
    if tokenizer_cfg["max_shift_steps"] != "auto":
        return
    segment_sec = audio_cfg["input_frames"] / audio_cfg["sample_rate"]
    shift_step_sec = tokenizer_cfg["shift_step_ms"] / 1000
    # whole shift steps that fit in one segment, plus 2 (206 by default).
    # NOTE(review): float floor division can be one step short for some lengths due to
    # rounding; left unchanged because a different max_shift_steps would presumably
    # alter the tokenizer and break compatibility with existing checkpoints.
    tokenizer_cfg["max_shift_steps"] = int(segment_sec // shift_step_sec + 2)


def _update_model_cfg(args, model_cfg, stage):
    """Apply encoder/decoder architecture overrides (and training-only dropout) to ``model_cfg`` in place."""
    if args.encoder_type is not None:
        model_cfg["encoder_type"] = str(args.encoder_type)
    if args.decoder_type is not None:
        model_cfg["decoder_type"] = str(args.decoder_type)
    if args.pre_encoder_type != "default":
        model_cfg["pre_encoder_type"] = str(args.pre_encoder_type)
    if args.pre_decoder_type != 'default':
        model_cfg["pre_decoder_type"] = str(args.pre_decoder_type)
    if args.conv_out_channels is not None:
        model_cfg["conv_out_channels"] = int(args.conv_out_channels)
    if not (isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool)):
        raise TypeError('task_cond_encoder and task_cond_decoder must be bool')
    model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder
    model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder

    # Position encodings. The encoder/decoder sub-config is looked up lazily, after the
    # type overrides above, so it targets the selected encoder/decoder.
    enc_pe = args.encoder_position_encoding_type
    if enc_pe != 'default':
        if enc_pe in ('None', 'none', '0'):
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None
        elif enc_pe in ('sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt'):
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str(enc_pe)
        else:
            raise ValueError(f'Encoder PE type {enc_pe} not supported')
    dec_pe = args.decoder_position_encoding_type
    if dec_pe != 'default':
        if dec_pe in ('None', 'none', '0'):
            raise ValueError('Decoder PE type cannot be None')
        if dec_pe not in ('sinusoidal', 'trainable'):
            raise ValueError(f'Decoder PE {dec_pe} not supported')
        model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str(dec_pe)

    if args.tie_word_embedding is not None:
        model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding)
    if args.d_feat is not None:
        model_cfg["d_feat"] = int(args.d_feat)

    # Perceiver-TF encoder overrides: (args attribute, config key, cast), applied in this order
    perceiver_overrides = (
        ('d_latent', 'd_latent', int),
        ('num_latents', 'num_latents', int),
        ('perceiver_tf_d_model', 'd_model', int),
        ('num_perceiver_tf_blocks', 'num_blocks', int),
        ('num_perceiver_tf_local_transformers_per_block', 'num_local_transformers_per_block', int),
        ('num_perceiver_tf_temporal_transformers_per_block', 'num_temporal_transformers_per_block', int),
        ('attention_to_channel', 'attention_to_channel', bool),
        ('sca_use_query_residual', 'sca_use_query_residual', bool),
        ('layer_norm_type', 'layer_norm', str),
        ('ff_layer_type', 'ff_layer_type', str),
        ('ff_widening_factor', 'ff_widening_factor', int),
        ('moe_num_experts', 'moe_num_experts', int),
        ('moe_topk', 'moe_topk', int),
        ('hidden_act', 'hidden_act', str),
    )
    for arg_name, cfg_key, cast in perceiver_overrides:
        value = getattr(args, arg_name)
        if value is not None:
            model_cfg["encoder"]["perceiver-tf"][cfg_key] = cast(value)

    if args.rotary_type is not None:
        # one character each for the SCA, latent and temporal transformer rotary types
        if len(args.rotary_type) != 3:
            raise ValueError(
                "rotary_type must be a 3-letter string (e.g. 'ppl': 'pixel' for SCA, 'pixel' for latent, 'lang' for temporal transformer)"
            )
        rotary_type = str(args.rotary_type)
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_sca"] = rotary_type[0]
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_latent"] = rotary_type[1]
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_temporal"] = rotary_type[2]
    if args.rope_apply_to_keys is not None:
        model_cfg["encoder"]["perceiver-tf"]["rope_apply_to_keys"] = bool(args.rope_apply_to_keys)
    if args.rope_partial_pe is not None:
        model_cfg["encoder"]["perceiver-tf"]["rope_partial_pe"] = bool(args.rope_partial_pe)

    if args.decoder_ff_layer_type is not None:
        model_cfg["decoder"][model_cfg["decoder_type"]]["ff_layer_type"] = str(args.decoder_ff_layer_type)
    if args.decoder_ff_widening_factor is not None:
        model_cfg["decoder"][model_cfg["decoder_type"]]["ff_widening_factor"] = int(args.decoder_ff_widening_factor)
    if args.event_length is not None:
        model_cfg["event_length"] = int(args.event_length)

    # Dropout overrides only apply when training
    if stage == 'train':
        if args.encoder_dropout_rate is not None:
            model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate)
        if args.decoder_dropout_rate is not None:
            model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate)