# Copyright 2024 The YourMT3 Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Please see the details in the LICENSE file.
"""init_train.py"""
from typing import Tuple, Literal, Any
from copy import deepcopy
import os
import argparse
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import LearningRateMonitor
from pytorch_lightning.utilities import rank_zero_only
from config.config import shared_cfg as default_shared_cfg
from config.config import audio_cfg as default_audio_cfg
from config.config import model_cfg as default_model_cfg
from config.config import DEEPSPEED_CFG


def initialize_trainer(args: argparse.Namespace,
                       stage: Literal['train', 'test'] = 'train') -> Tuple[pl.Trainer, WandbLogger, dict, dict]:
    """Initialize the trainer and W&B logger.

    Returns a tuple of (trainer, wandb_logger, dir_info, shared_cfg).
    """
    shared_cfg = deepcopy(default_shared_cfg)

    # create save dir
    os.makedirs(shared_cfg["WANDB"]["save_dir"], exist_ok=True)

    # collecting specific checkpoint from exp_id with extension (@xxx where xxx is checkpoint name)
    if "@" in args.exp_id:
        args.exp_id, checkpoint_name = args.exp_id.split("@")
    else:
        checkpoint_name = "last.ckpt"
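    # Example: exp_id "ymt3_run1@epoch=10.ckpt" resumes checkpoint "epoch=10.ckpt" of run "ymt3_run1",
    # while a plain "ymt3_run1" resumes "last.ckpt" (the run and checkpoint names here are illustrative only).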

    # checkpoint dir
    lightning_dir = os.path.join(shared_cfg["WANDB"]["save_dir"], args.project, args.exp_id)

    # create logger
    if args.wandb_mode is not None:
        shared_cfg["WANDB"]["mode"] = str(args.wandb_mode)
    if shared_cfg["WANDB"].get("cache_dir", None) is not None:
        os.environ["WANDB_CACHE_DIR"] = shared_cfg["WANDB"].get("cache_dir")
        del shared_cfg["WANDB"]["cache_dir"]  # remove cache_dir from shared_cfg
    wandb_logger = WandbLogger(log_model="all",
                               project=args.project,
                               id=args.exp_id,
                               allow_val_change=True,
                               **shared_cfg['WANDB'])

    # check if any checkpoint exists
    last_ckpt_path = os.path.join(lightning_dir, "checkpoints", checkpoint_name)
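    # i.e. <WANDB.save_dir>/<args.project>/<args.exp_id>/checkpoints/<checkpoint_name>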
    if os.path.exists(last_ckpt_path):
        print(f'Resuming from {last_ckpt_path}')
    elif stage == 'train':
        print(f'No checkpoint found in {last_ckpt_path}. Starting from scratch')
        last_ckpt_path = None
    else:
        raise ValueError(f'No checkpoint found in {last_ckpt_path}. Quit...')

    # add info
    dir_info = dict(lightning_dir=lightning_dir, last_ckpt_path=last_ckpt_path)

    # define checkpoint callback
    checkpoint_callback = ModelCheckpoint(**shared_cfg["CHECKPOINT"])

    # define lr scheduler monitor callback
    lr_monitor = LearningRateMonitor(logging_interval='step')

    # deepspeed strategy
    if args.strategy == 'deepspeed':
        strategy = pl.strategies.DeepSpeedStrategy(config=DEEPSPEED_CFG)

    # validation interval
    if stage == 'train' and args.val_interval is not None:
        shared_cfg["TRAINER"]["check_val_every_n_epoch"] = None
        shared_cfg["TRAINER"]["val_check_interval"] = int(args.val_interval)

    # define trainer
    sync_batchnorm = False
    if stage == 'train':
        # train batch size
        if args.train_batch_size is not None:
            train_sub_bsz = int(args.train_batch_size[0])
            train_local_bsz = int(args.train_batch_size[1])
            if train_local_bsz % train_sub_bsz == 0:
                shared_cfg["BSZ"]["train_sub"] = train_sub_bsz
                shared_cfg["BSZ"]["train_local"] = train_local_bsz
            else:
                raise ValueError(
                    f'Local batch size {train_local_bsz} must be divisible by sub batch size {train_sub_bsz}')
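        # Illustrative example (hypothetical values): args.train_batch_size == [8, 32] sets
        # BSZ.train_sub = 8 and BSZ.train_local = 32; [8, 30] would raise the ValueError above.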

        # ddp strategy
        if args.strategy == 'ddp':
            args.strategy = 'ddp_find_unused_parameters_true'  # fix for conformer or pitchshifter having unused parameter issue

            # sync-batchnorm
            if args.sync_batchnorm is True:
                sync_batchnorm = True

    train_params = dict(**shared_cfg["TRAINER"],
                        devices=args.num_gpus if args.num_gpus == 'auto' else int(args.num_gpus),
                        num_nodes=int(args.num_nodes),
                        strategy=strategy if args.strategy == 'deepspeed' else args.strategy,
                        precision=args.precision,
                        max_epochs=args.max_epochs if stage == 'train' else None,
                        max_steps=args.max_steps if stage == 'train' else -1,
                        logger=wandb_logger,
                        callbacks=[checkpoint_callback, lr_monitor],
                        sync_batchnorm=sync_batchnorm)
    trainer = pl.Trainer(**train_params)

    # Update wandb logger (for DDP)
    if trainer.global_rank == 0:
        wandb_logger.experiment.config.update(args, allow_val_change=True)

    return trainer, wandb_logger, dir_info, shared_cfg


def update_config(args, shared_cfg, stage: Literal['train', 'test'] = 'train'):
    """Update audio/model/shared configurations with args"""
    audio_cfg = default_audio_cfg
    model_cfg = default_model_cfg
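    # Note: unlike shared_cfg in initialize_trainer(), these are not deep-copied, so the
    # updates below also mutate the module-level default configs imported from config.config.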

    # Only update config when training
    if stage == 'train':
        # Augmentation parameters
        if args.random_amp_range is not None:
            shared_cfg["AUGMENTATION"]["train_random_amp_range"] = list(
                (float(args.random_amp_range[0]), float(args.random_amp_range[1])))
        if args.stem_iaug_prob is not None:
            shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"] = float(args.stem_iaug_prob)

        if args.xaug_max_k is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["max_k"] = int(args.xaug_max_k)
        if args.xaug_tau is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["tau"] = float(args.xaug_tau)
        if args.xaug_alpha is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["alpha"] = float(args.xaug_alpha)
        if args.xaug_no_instr_overlap is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_instr_overlap"] = bool(args.xaug_no_instr_overlap)
        if args.xaug_no_drum_overlap is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["no_drum_overlap"] = bool(args.xaug_no_drum_overlap)
        if args.uhat_intra_stem_augment is not None:
            shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]["uhat_intra_stem_augment"] = bool(
                args.uhat_intra_stem_augment)

        if args.pitch_shift_range is not None:
            if args.pitch_shift_range in [["0", "0"], [0, 0]]:
                shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = None
            else:
                shared_cfg["AUGMENTATION"]["train_pitch_shift_range"] = list(
                    (int(args.pitch_shift_range[0]), int(args.pitch_shift_range[1])))

        train_stem_iaug_prob = shared_cfg["AUGMENTATION"]["train_stem_iaug_prob"]
        random_amp_range = shared_cfg["AUGMENTATION"]["train_random_amp_range"]
        train_stem_xaug_policy = shared_cfg["AUGMENTATION"]["train_stem_xaug_policy"]
        print(f'Random amp range: {random_amp_range}\n' +
              f'Intra-stem augmentation probability: {train_stem_iaug_prob}\n' +
              f'Stem augmentation policy: {train_stem_xaug_policy}\n' +
              f'Pitch shift range: {shared_cfg["AUGMENTATION"]["train_pitch_shift_range"]}\n')

    # Update audio config
    if args.audio_codec is not None:
        assert args.audio_codec in ['spec', 'melspec']
        audio_cfg["codec"] = str(args.audio_codec)
    if args.hop_length is not None:
        audio_cfg["hop_length"] = int(args.hop_length)
    if args.n_mels is not None:
        audio_cfg["n_mels"] = int(args.n_mels)
    if args.input_frames is not None:
        audio_cfg["input_frames"] = int(args.input_frames)

    # Update shared config
    if shared_cfg["TOKENIZER"]["max_shift_steps"] == "auto":
        shift_steps_ms = shared_cfg["TOKENIZER"]["shift_step_ms"]
        input_frames = audio_cfg["input_frames"]
        fs = audio_cfg["sample_rate"]
        max_shift_steps = (input_frames / fs) // (shift_steps_ms / 1000) + 2  # 206 by default
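        # Worked example (assuming typical defaults of input_frames=32767 samples, sample_rate=16000 Hz,
        # shift_step_ms=10): (32767 / 16000) // (10 / 1000) + 2 = 204 + 2 = 206.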
        shared_cfg["TOKENIZER"]["max_shift_steps"] = int(max_shift_steps)

    # Update model config
    if args.encoder_type is not None:
        model_cfg["encoder_type"] = str(args.encoder_type)
    if args.decoder_type is not None:
        model_cfg["decoder_type"] = str(args.decoder_type)
    if args.pre_encoder_type != "default":
        model_cfg["pre_encoder_type"] = str(args.pre_encoder_type)
    if args.pre_decoder_type != 'default':
        model_cfg["pre_decoder_type"] = str(args.pre_decoder_type)
    if args.conv_out_channels is not None:
        model_cfg["conv_out_channels"] = int(args.conv_out_channels)
    assert isinstance(args.task_cond_decoder, bool) and isinstance(args.task_cond_encoder, bool)
    model_cfg["use_task_conditional_encoder"] = args.task_cond_encoder
    model_cfg["use_task_conditional_decoder"] = args.task_cond_decoder

    if args.encoder_position_encoding_type != 'default':
        if args.encoder_position_encoding_type in ['None', 'none', '0']:
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = None
        elif args.encoder_position_encoding_type in [
                'sinusoidal', 'rope', 'trainable', 'alibi', 'alibit', 'tkd', 'td', 'tk', 'kdt'
        ]:
            model_cfg["encoder"][model_cfg["encoder_type"]]["position_encoding_type"] = str(
                args.encoder_position_encoding_type)
        else:
            raise ValueError(f'Encoder PE type {args.encoder_position_encoding_type} not supported')
    if args.decoder_position_encoding_type != 'default':
        if args.decoder_position_encoding_type in ['None', 'none', '0']:
            raise ValueError('Decoder PE type cannot be None')
        elif args.decoder_position_encoding_type in ['sinusoidal', 'trainable']:
            model_cfg["decoder"][model_cfg["decoder_type"]]["position_encoding_type"] = str(
                args.decoder_position_encoding_type)
        else:
            raise ValueError(f'Decoder PE {args.decoder_position_encoding_type} not supported')

    if args.tie_word_embedding is not None:
        model_cfg["tie_word_embedding"] = bool(args.tie_word_embedding)

    if args.d_feat is not None:
        model_cfg["d_feat"] = int(args.d_feat)
    if args.d_latent is not None:
        model_cfg['encoder']['perceiver-tf']["d_latent"] = int(args.d_latent)
    if args.num_latents is not None:
        model_cfg['encoder']['perceiver-tf']['num_latents'] = int(args.num_latents)
    if args.perceiver_tf_d_model is not None:
        model_cfg['encoder']['perceiver-tf']['d_model'] = int(args.perceiver_tf_d_model)
    if args.num_perceiver_tf_blocks is not None:
        model_cfg["encoder"]["perceiver-tf"]["num_blocks"] = int(args.num_perceiver_tf_blocks)
    if args.num_perceiver_tf_local_transformers_per_block is not None:
        model_cfg["encoder"]["perceiver-tf"]["num_local_transformers_per_block"] = int(
            args.num_perceiver_tf_local_transformers_per_block)
    if args.num_perceiver_tf_temporal_transformers_per_block is not None:
        model_cfg["encoder"]["perceiver-tf"]["num_temporal_transformers_per_block"] = int(
            args.num_perceiver_tf_temporal_transformers_per_block)
    if args.attention_to_channel is not None:
        model_cfg["encoder"]["perceiver-tf"]["attention_to_channel"] = bool(args.attention_to_channel)
    if args.sca_use_query_residual is not None:
        model_cfg["encoder"]["perceiver-tf"]["sca_use_query_residual"] = bool(args.sca_use_query_residual)
    if args.layer_norm_type is not None:
        model_cfg["encoder"]["perceiver-tf"]["layer_norm"] = str(args.layer_norm_type)
    if args.ff_layer_type is not None:
        model_cfg["encoder"]["perceiver-tf"]["ff_layer_type"] = str(args.ff_layer_type)
    if args.ff_widening_factor is not None:
        model_cfg["encoder"]["perceiver-tf"]["ff_widening_factor"] = int(args.ff_widening_factor)
    if args.moe_num_experts is not None:
        model_cfg["encoder"]["perceiver-tf"]["moe_num_experts"] = int(args.moe_num_experts)
    if args.moe_topk is not None:
        model_cfg["encoder"]["perceiver-tf"]["moe_topk"] = int(args.moe_topk)
    if args.hidden_act is not None:
        model_cfg["encoder"]["perceiver-tf"]["hidden_act"] = str(args.hidden_act)
    if args.rotary_type is not None:
        assert len(args.rotary_type) == 3, (
            "rotary_type must be a 3-letter string (e.g. 'ppl': 'pixel' for SCA, 'pixel' for latent, "
            "'lang' for temporal transformer)")
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_sca"] = str(args.rotary_type)[0]
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_latent"] = str(args.rotary_type)[1]
        model_cfg["encoder"]["perceiver-tf"]["rotary_type_temporal"] = str(args.rotary_type)[2]
    if args.rope_apply_to_keys is not None:
        model_cfg["encoder"]["perceiver-tf"]["rope_apply_to_keys"] = bool(args.rope_apply_to_keys)
    if args.rope_partial_pe is not None:
        model_cfg["encoder"]["perceiver-tf"]["rope_partial_pe"] = bool(args.rope_partial_pe)

    if args.decoder_ff_layer_type is not None:
        model_cfg["decoder"][model_cfg["decoder_type"]]["ff_layer_type"] = str(args.decoder_ff_layer_type)
    if args.decoder_ff_widening_factor is not None:
        model_cfg["decoder"][model_cfg["decoder_type"]]["ff_widening_factor"] = int(args.decoder_ff_widening_factor)

    if args.event_length is not None:
        model_cfg["event_length"] = int(args.event_length)

    if stage == 'train':
        if args.encoder_dropout_rate is not None:
            model_cfg["encoder"][model_cfg["encoder_type"]]["dropout_rate"] = float(args.encoder_dropout_rate)
        if args.decoder_dropout_rate is not None:
            model_cfg["decoder"][model_cfg["decoder_type"]]["dropout_rate"] = float(args.decoder_dropout_rate)

    return shared_cfg, audio_cfg, model_cfg  # return updated configs
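
# Typical call sequence (sketch only; the argparse setup lives in the training entry point and the
# exact flags are assumptions, not defined in this module):
#
#   args = parser.parse_args()  # hypothetical parser providing the fields referenced above
#   trainer, wandb_logger, dir_info, shared_cfg = initialize_trainer(args, stage='train')
#   shared_cfg, audio_cfg, model_cfg = update_config(args, shared_cfg, stage='train')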