# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#               2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import nullcontext
import copy
from typing import List, Optional

import json
import logging
import os
import torch
import yaml

import torch.optim as optim
import torch.distributed as dist

from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
                                    CPUOffload, MixedPrecision,
                                    sharded_grad_scaler, ShardingStrategy)
try:
    import deepspeed
    from deepspeed.runtime.zero.stage_1_and_2 import (
        estimate_zero2_model_states_mem_needs_all_live)
    from deepspeed.runtime.zero.stage3 import (
        estimate_zero3_model_states_mem_needs_all_live)
    from deepspeed.utils.zero_to_fp32 import (
        convert_zero_checkpoint_to_fp32_state_dict)
except ImportError:
    pass


from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
                                tensor_to_scalar)
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
                                    apply_fsdp_checkpointing,
                                    wenet_fsdp_wrap_policy)
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset


def add_model_args(parser):
    parser.add_argument('--config', required=True, help='config file')
    parser.add_argument('--model_dir', required=True, help='save model dir')
    parser.add_argument('--checkpoint', help='checkpoint model')
    parser.add_argument('--tensorboard_dir',
                        default='tensorboard',
                        help='tensorboard log dir')
    parser.add_argument('--override_config',
                        action='append',
                        default=[],
                        help="override yaml config")
    parser.add_argument("--enc_init",
                        default=None,
                        type=str,
                        help="Pre-trained model to initialize encoder")
    parser.add_argument(
        '--enc_init_mods',
        default="encoder.",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="List of encoder modules \
                        to initialize ,separated by a comma")
    parser.add_argument(
        '--freeze_modules',
        default="",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of modules to freeze',
    )
    return parser


def add_trace_args(parser):
    parser.add_argument('--jit',
                        action='store_true',
                        default=False,
                        help='whether to export the model via '
                        'torch.jit.script during the training stage')
    parser.add_argument('--print_model',
                        action='store_true',
                        default=False,
                        help='print model')
    return parser


def add_dataset_args(parser):
    parser.add_argument('--data_type',
                        default='raw',
                        # choices=['raw', 'shard'],
                        help='train and cv data type')
    parser.add_argument('--train_data', required=True, help='train data file')
    parser.add_argument('--cv_data', required=True, help='cv data file')
    parser.add_argument('--num_workers',
                        default=0,
                        type=int,
                        help='num of subprocess workers for reading')
    parser.add_argument('--pin_memory',
                        action='store_true',
                        default=False,
                        help='Use pinned memory buffers for reading')
    parser.add_argument('--prefetch',
                        default=100,
                        type=int,
                        help='prefetch number')
    return parser


def add_lora_args(parser):
    '''Configure parameters for LoRA fine-tuning. Set use_lora and
       only_optimize_lora to true to enable LoRA functionality.
       LoRA will be injected into the model through (lora_modules,
       lora_attn_attr, lora_list).
       LoRA weights will be merged after calling model.eval()
       (or model.train(mode=False)).
       When fine-tuning with DeepSpeed, LoRA weights need to be loaded
       after training.
    '''
    parser.add_argument("--use_lora",
                        default=False,
                        type=bool,
                        help="whether use the lora finetune.")
    parser.add_argument("--only_optimize_lora",
                        default=False,
                        type=bool,
                        help="freeze all other paramters and only optimize \
                        LoRA-related prameters.")
    parser.add_argument(
        '--lora_modules',
        default="encoder.encoders",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of modules to inject LoRA into',
    )
    parser.add_argument(
        "--lora_attn_attr",
        default="self_attn,src_attn",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="lora_attn_attr.")
    parser.add_argument(
        "--lora_list",
        default="linear_out,linear_q,linear_k,linear_v",
        type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help="lora module list.")
    parser.add_argument("--lora_rank",
                        default=8,
                        type=int,
                        help="lora rank num.")
    parser.add_argument("--lora_alpha",
                        default=8,
                        type=int,
                        help="lora scale param, scale=lora_alpha/lora_rank.")
    parser.add_argument("--lora_dropout",
                        default=0,
                        type=float,
                        help="lora dropout param.")
    parser.add_argument("--lora_ckpt_path",
                        default=None,
                        type=str,
                        help="lora checkpoint path.")
    parser.add_argument("--lora_reinit",
                        default=False,
                        type=bool,
                        help="whether use the lora init, default is zero init.")
    parser.add_argument('--lora_init_yaml',
                        default="wenet/finetune/lora/config.yaml",
                        type=str,
                        help='Path to the configuration YAML file')
    return parser


def add_ddp_args(parser):
    parser.add_argument('--ddp.dist_backend',
                        dest='dist_backend',
                        default='nccl',
                        choices=['nccl', 'gloo', "hccl"],
                        help='distributed backend')
    parser.add_argument('--use_amp',
                        action='store_true',
                        default=False,
                        help='Use automatic mixed precision training')
    parser.add_argument('--fp16_grad_sync',
                        action='store_true',
                        default=False,
                        help='Use fp16 gradient sync for ddp')
    return parser


def add_deepspeed_args(parser):
    parser.add_argument('--timeout',
                        default=30,
                        type=int,
                        help='timeout (in seconds) of wenet_join. ' +
                        '30s for aishell & 300s for wenetspeech')
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help='local rank passed from distributed launcher')
    parser.add_argument('--deepspeed.save_states',
                        dest='save_states',
                        default='model_only',
                        choices=['model_only', 'model+optimizer'],
                        help='save model/optimizer states')
    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config'
    # to the parser
    try:
        parser = deepspeed.add_config_arguments(parser)
    except Exception as e:
        print(e)
    return parser


def add_fsdp_args(parser):
    parser.add_argument(
        '--dtype',
        default='fp32',
        choices=['fp32', 'fp16', 'bf16'],
        help='when amp is used, dtype is automatically set to fp16.\
        this arg has no effect when deepspeed is enabled.')
    parser.add_argument(
        '--fsdp_cpu_offload',
        default=False,
        type=bool,
        help='whether to offload parameters to CPU',
    )
    parser.add_argument(
        '--fsdp_sync_module_states',
        type=bool,
        default=True,
        help='\
        each FSDP module will broadcast module parameters and buffers from \
        rank 0 to ensure that they are replicated across ranks',
    )
    parser.add_argument(
        '--fsdp_sharding_strategy',
        default='zero2',
        # TODO(Mddct): pipeline and model parallel (3-D parallelism)
        choices=['no_shard', 'model', 'zero2', 'zero3'],
        help='Sharding strategy for FSDP. Choose from the following options:\n'
        '  - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
        '  - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
        '  - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
        '  - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
        'For more information, refer to the FSDP API documentation.')
    return parser


def init_distributed(args):
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
                 ', rank {}, world_size {}'.format(rank, world_size))
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if "cuda" in args.device:
            torch.cuda.set_device(local_rank)
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            torch.npu.set_device(local_rank)
        else:
            logging.error("not supported device: {}".format(args.device))
        dist.init_process_group(args.dist_backend)
    elif args.train_engine == "deepspeed":
        deepspeed.init_distributed(dist_backend=args.dist_backend)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
    return world_size, local_rank, rank


def check_modify_and_save_config(args, configs, symbol_table):
    if args.train_engine in ["torch_ddp", "torch_fsdp"]:
        if args.use_amp:
            configs["dtype"] = "fp16"
            args.dtype = 'fp16'
        else:
            configs["dtype"] = args.dtype
    elif args.train_engine == "deepspeed":
        # NOTE(xcsong): DeepSpeed does not support uneven data. When using a
        #   custom dataset, we need to manually ensure that the data is evenly
        #   distributed across all processes. We implement
        #   `train_utils.py::wenet_join` for this purpose.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
        #
        # NOTE(xcsong): We also need to keep:
        #       1. `train_micro_batch_size_per_gpu == 1`
        #       2. `accum_grad (in train_conformer.yaml)
        #               == gradient_accumulation_steps (in ds_config.json)`
        #       3. `grad_clip (in train_conformer.yaml)
        #               == gradient_clipping (in ds_config.json)`
        #   The reason for such consistency checks is that DeepSpeed's native
        #   dataloader uses PyTorch's torch.utils.data.DistributedSampler,
        #   which does not support IterableDataset. IterableDataset is
        #   extremely useful in large-scale training because it lets you
        #   stream the data without having to download the complete dataset.
        #       ref: https://github.com/microsoft/DeepSpeed/issues/1371
        #           https://github.com/microsoft/DeepSpeed/issues/285
        #   To make DeepSpeed training compatible with IterableDataset, we
        #   have to use a custom dataloader instead of DeepSpeed's native
        #   loader, and thus we should configure the batch size in
        #   train_conformer.yaml instead of ds_config.json. In contrast,
        #   gradient accumulation / clipping should be configured in
        #   ds_config.json since they are handled by DeepSpeed automatically.
        #       ref: https://github.com/microsoft/DeepSpeed/issues/62
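        # For illustration only, a minimal hypothetical ds_config.json
        # fragment that satisfies the assertions below (the angle-bracket
        # values must match the corresponding entries in the train yaml):
        #   {
        #     "train_micro_batch_size_per_gpu": 1,
        #     "gradient_accumulation_steps": <accum_grad>,
        #     "gradient_clipping": <grad_clip>,
        #     "steps_per_print": <log_interval>
        #   }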
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
            configs["dtype"] = "fp16"
        elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
            configs["dtype"] = "bf16"
        else:
            configs["dtype"] = "fp32"
        assert ds_configs["train_micro_batch_size_per_gpu"] == 1
        assert ds_configs["gradient_accumulation_steps"] == configs[
            'accum_grad']
        assert ds_configs["gradient_clipping"] == configs['grad_clip']
        assert ds_configs["steps_per_print"] == configs['log_interval']

    if args.use_lora:
        configs['lora_conf'] = {}
        configs['lora_conf']['lora_modules'] = args.lora_modules
        configs['lora_conf']['lora_attn_attr'] = args.lora_attn_attr
        configs['lora_conf']['lora_list'] = args.lora_list
        configs['lora_conf']['lora_rank'] = args.lora_rank
        configs['lora_conf']['lora_alpha'] = args.lora_alpha
        configs['lora_conf']['lora_dropout'] = args.lora_dropout

    if configs["model"] == 'asr_model':
        if 'input_dim' not in configs:
            if 'fbank_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf']['fbank_conf'][
                    'num_mel_bins']
            elif 'log_mel_spectrogram_conf' in configs['dataset_conf']:
                input_dim = configs['dataset_conf'][
                    'log_mel_spectrogram_conf']['num_mel_bins']
            else:
                input_dim = configs['dataset_conf']['mfcc_conf'][
                    'num_mel_bins']
        else:
            input_dim = configs['input_dim']

        configs['input_dim'] = input_dim

    configs, _ = get_blank_id(configs, symbol_table)
    configs['output_dim'] = configs['vocab_size']

    configs['train_engine'] = args.train_engine
    configs['use_amp'] = args.use_amp
    configs['model_dir'] = args.model_dir
    configs['save_states'] = args.save_states

    # Save configs to model_dir/train.yaml for inference and export
    if int(os.environ.get('RANK', 0)) == 0:
        saved_config_path = os.path.join(args.model_dir, 'train.yaml')
        with open(saved_config_path, 'w') as fout:
            data = yaml.dump(configs)
            fout.write(data)

    if configs["model_conf"].get("apply_non_blank_embedding", False):
        logging.warning('Had better load a well trained model '
                        'if apply_non_blank_embedding is true !!!')

    return configs


def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
    generator = torch.Generator()
    generator.manual_seed(seed)

    # if save_interval in configs, steps mode else epoch mode
    if "save_interval" in configs:
        configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
    conf = configs['dataset_conf']
    dataset_type = configs.get('dataset', 'asr')
    configs['vocab_size'] = tokenizer.vocab_size()
    train_dataset = init_dataset(dataset_type,
                                 args.data_type,
                                 args.train_data,
                                 tokenizer,
                                 conf,
                                 True,
                                 split='train')
    tag = configs["init_infos"].get("tag", "init")
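    # Restore the dataset epoch from the checkpoint info: when resuming from
    # an "epoch_*" checkpoint the recorded epoch is kept as-is, otherwise
    # (e.g. the "init" or "step_*" tags) the dataset starts one epoch earlier.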
    train_dataset.set_epoch(
        configs["init_infos"].get('epoch', 0) + int("epoch_" in tag) - 1)
    cv_conf = copy.deepcopy(conf)
    cv_conf['split_num'] = 1
    cv_dataset = init_dataset(dataset_type,
                              args.data_type,
                              args.cv_data,
                              tokenizer,
                              cv_conf,
                              partition=False,
                              split='cv')

    # NOTE(xcsong): Why we prefer persistent_workers=True ?
    #   https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
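    # batch_size=None: batching is performed inside the dataset pipeline
    # itself (see `batch_conf` in dataset_conf), so the DataLoader only
    # handles worker management and prefetching.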
    train_data_loader = DataLoader(train_dataset,
                                   batch_size=None,
                                   pin_memory=args.pin_memory,
                                   num_workers=args.num_workers,
                                   persistent_workers=True,
                                   generator=generator,
                                   prefetch_factor=args.prefetch)
    cv_data_loader = DataLoader(cv_dataset,
                                batch_size=None,
                                pin_memory=args.pin_memory,
                                num_workers=args.num_workers,
                                persistent_workers=True,
                                generator=generator,
                                prefetch_factor=args.prefetch)
    return train_dataset, cv_dataset, train_data_loader, cv_data_loader


def wrap_cuda_model(args, model, configs=None):
    local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    if hasattr(model, 'encoder'):
        grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
    else:
        grad_ckpt = False
    if args.train_engine == "torch_ddp":  # native pytorch ddp
        device = torch.device(args.device)
        model.to(device)
        # model = torch.nn.parallel.DistributedDataParallel(
        #     model, find_unused_parameters=not grad_ckpt)
        model = torch.nn.parallel.DistributedDataParallel(
            model, find_unused_parameters=True)
    elif args.train_engine == "deepspeed":  # deepspeed
        # NOTE(xcsong): look in detail how the memory estimator API works:
        #   https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
        if int(os.environ.get('RANK', 0)) == 0:
            logging.info("Estimating model states memory needs (zero2)...")
            estimate_zero2_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
            logging.info("Estimating model states memory needs (zero3)...")
            estimate_zero3_model_states_mem_needs_all_live(
                model,
                num_gpus_per_node=local_world_size,
                num_nodes=world_size // local_world_size)
        device = torch.device(args.device)  # Init device later
        pass  # Init DeepSpeed later
    elif args.train_engine == 'torch_fsdp':
        assert configs is not None
        mixed_precision_dtype = {
            'fp32': torch.float32,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
        }[configs['dtype']]

        sharding_strategy = {
            'model': ShardingStrategy.SHARD_GRAD_OP,
            'zero2': ShardingStrategy.SHARD_GRAD_OP,
            'zero3': ShardingStrategy.FULL_SHARD,
            'no_shard': ShardingStrategy.NO_SHARD,
        }[args.fsdp_sharding_strategy]
        wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
        layer_types = check_gradient_checkpoint(model)
        if "cuda" in args.device:
            device_id = torch.cuda.current_device()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            device_id = torch.npu.current_device()
        else:
            logging.error("not supported device: {}".format(args.device))
        model = FSDP(
            model,
            auto_wrap_policy=wrap_policy,
            cpu_offload=CPUOffload(offload_params=True)
            if args.fsdp_cpu_offload is True else None,
            mixed_precision=MixedPrecision(
                param_dtype=mixed_precision_dtype,
                reduce_dtype=mixed_precision_dtype,
                buffer_dtype=mixed_precision_dtype,
            ),
            sharding_strategy=sharding_strategy,
            limit_all_gathers=True,
            use_orig_params=True,
            sync_module_states=args.fsdp_sync_module_states,
            # init_distributed has already called torch.cuda.set_device,
            # so we pass device_id here; see the FSDP API.
            device_id=device_id)
        apply_fsdp_checkpointing(model, layer_types)
        device = torch.device(args.device)
    else:
        logging.error("not supported engine: {}".format(args.train_engine))
    if args.train_engine in ["torch_fsdp", "torch_ddp"]:
        if args.fp16_grad_sync:
            from torch.distributed.algorithms.ddp_comm_hooks import (
                default as comm_hooks, )
            model.register_comm_hook(state=None,
                                     hook=comm_hooks.fp16_compress_hook)

    return model, device


def init_optimizer_and_scheduler(args, configs, model):
    groups = []
    lr = configs['optim_conf'].get('lr')
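    # `lr` may be a list for per-module learning rates: `optim_conf['modules']`
    # names the sub-modules that get their own lr (matched positionally) and
    # the last entry of `lr` applies to all remaining parameters. A hedged,
    # illustrative yaml fragment (module names are examples only):
    #   optim: adam
    #   scheduler: warmuplr
    #   optim_conf:
    #     lr: [0.0005, 0.001]
    #     modules: ['encoder.embed']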
    if isinstance(lr, List):
        assert configs['scheduler'] == 'warmuplr'
        modules_m = configs['optim_conf']['modules']
        assert isinstance(modules_m, List)
        assert len(modules_m) + 1 == len(lr)
        special_param_ids = set()
        rest_params = []
        for (i, m_str) in enumerate(modules_m):
            sub_module = get_nested_attribute(model, m_str)
            subs_params = []
            for _, sub_params in sub_module.named_parameters():
                subs_params.append(sub_params)
                special_param_ids.add(id(sub_params))
            groups.append({'params': subs_params, 'lr': lr[i]})
        # other model's parameters
        for _, param in model.named_parameters():
            if id(param) not in special_param_ids:
                rest_params.append(param)
        groups.append({'params': rest_params, 'lr': lr[-1]})

    params = groups if len(groups) > 0 else model.parameters()
    optim_conf = copy.deepcopy(configs['optim_conf'])
    if 'modules' in optim_conf:
        del optim_conf['modules']
    if isinstance(lr, List):
        optim_conf['lr'] = lr[-1]
    if configs['optim'] == 'adam':
        optimizer = optim.Adam(params, **optim_conf)
    elif configs['optim'] == 'adamw':
        optimizer = optim.AdamW(params, **optim_conf)
    else:
        raise ValueError("unknown optimizer: " + configs['optim'])

    scheduler_type = None
    if configs['scheduler'] == 'warmuplr':
        scheduler_type = WarmupLR
        scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
    elif configs['scheduler'] == 'NoamHoldAnnealing':
        scheduler_type = NoamHoldAnnealing
        scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
    else:
        raise ValueError("unknown scheduler: " + configs['scheduler'])

    # NOTE(xcsong): Custom optimizer might yield poor performance when
    #   zero-offload is enabled, if you do want to offload optimizer to CPU,
    #   please set optimizer in ds_config.json, see:
    #   (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters)
    if args.train_engine == "deepspeed":
        with open(args.deepspeed_config, 'r') as fin:
            ds_configs = json.load(fin)
        if "optimizer" in ds_configs:
            # NOTE(xcsong): Disable the custom optimizer if one is set in
            # ds_config; this is extremely useful when enabling cpu_offload,
            # since DeepSpeedCPUAdam can be 4~5x faster than torch's native
            # Adam.
            optimizer = None
            if "scheduler" in ds_configs:
                scheduler = None
            else:

                def scheduler(opt):
                    return scheduler_type(opt, **configs['scheduler_conf'])

        model, optimizer, _, scheduler = deepspeed.initialize(
            args=args,
            model=model,
            optimizer=optimizer,
            lr_scheduler=scheduler,
            model_parameters=model.parameters())

    step = configs.get("init_infos", {}).get("step", -1)
    scheduler.set_step(step)
    return model, optimizer, scheduler


def trace_and_print_model(args, model):
    # !!!IMPORTANT!!!
    # Try to export the model by script, if fails, we should refine
    # the code to satisfy the script export requirements
    if int(os.environ.get('RANK', 0)) == 0:
        if args.jit:
            script_model = torch.jit.script(model)
            script_model.save(os.path.join(args.model_dir, 'init.zip'))
        if args.print_model:
            print(model)
            num_params = sum(p.numel() for p in model.parameters())
            print('the number of model params: {:,d}'.format(num_params))


def init_summarywriter(args):
    writer = None
    if int(os.environ.get('RANK', 0)) == 0:
        os.makedirs(args.model_dir, exist_ok=True)
        exp_id = os.path.basename(args.model_dir)
        writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
    return writer


def init_scaler(args):
    scaler = None
    if args.use_amp:
        if "cuda" in args.device:
            scaler = torch.cuda.amp.GradScaler()
        elif "npu" in args.device and TORCH_NPU_AVAILABLE:
            scaler = torch.npu.amp.GradScaler()
        else:
            logging.error("not supported device: {}".format(args.device))
    elif args.train_engine == 'torch_fsdp':
        # why bf16 don't need scaler:
        # https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
        if args.dtype in ['fp16']:
            scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
    return scaler


def save_model(model, info_dict):
    rank = int(os.environ.get('RANK', 0))
    tag = info_dict["tag"]
    model_dir = info_dict["model_dir"]
    save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
    # save ckpt
    if info_dict["train_engine"] == "deepspeed":
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   save the general model params. see:
        #   https://github.com/microsoft/DeepSpeed/issues/2993
        with torch.no_grad():
            model.save_checkpoint(save_dir=model_dir,
                                  tag=tag,
                                  client_state=info_dict)
            if info_dict["save_states"] == "model_only" and rank == 0:
                convert_zero_checkpoint_to_fp32_state_dict(model_dir,
                                                           save_model_path,
                                                           tag=tag)
                os.system("rm -rf {}/{}".format(model_dir, tag))

    elif info_dict['train_engine'] == "torch_fsdp":
        fsdp_save_model(model, save_model_path, info_dict)
    elif rank == 0:
        # NOTE(xcsong): For torch_ddp, only rank-0 should call this.
        save_checkpoint(model, save_model_path, info_dict)
    # save yaml
    if rank == 0:
        with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
            data = yaml.dump(info_dict)
            fout.write(data)


def wenet_join(group_join, info_dict):
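    # `group_join` is expected to be a dedicated process group (typically a
    # gloo group created with the --timeout argument), separate from the main
    # training group; see the NOTE below for why a separate group is needed.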
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get('train_engine', "torch_ddp")

    if info_dict["batch_idx"] == 0 or train_engine == "torch_ddp":
        # NOTE(xcsong): Skip the first batch because its processing time
        #   includes dataloader initialization time, which may exceed 30
        #   seconds; torch_ddp does not use wenet_join at all.
        return False

    try:
        # NOTE(xcsong): Why we need a new group?
        #   Because Deepspeed has its own group where all the relevant communication
        #   operations are executed. If we add a communication operation that is not
        #   managed by Deepspeed in this group, it's highly likely to cause
        #   communication chaos, resulting in hard-to-troubleshoot hangs.
        dist.monitored_barrier(group=group_join,
                               timeout=group_join.options._timeout)
    except RuntimeError as e:
        logging.info("Detected uneven workload distribution: {}\n".format(e) +
                     "Break current worker to manually join all workers, " +
                     "world_size {}, current rank {}, current local_rank {}\n".
                     format(world_size, rank, local_rank))
        return True

    return False


def batch_forward(model, batch, scaler, info_dict, device):
    train_engine = info_dict.get('train_engine', "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)

    dtype = info_dict.get("dtype", "fp32")
    if dtype == "fp16":
        dtype = torch.float16
    elif dtype == "bf16":
        dtype = torch.bfloat16
    else:  # fp32
        dtype = None

    # autocast context
    # The more details about amp can be found in
    # https://pytorch.org/docs/stable/notes/amp_examples.html
    amp_autocast = torch.cuda.amp.autocast
    if "npu" in device.__str__() and TORCH_NPU_AVAILABLE:
        amp_autocast = torch.npu.amp.autocast
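    # Per-engine autocast behavior (as implemented below): deepspeed autocasts
    # whenever dtype is fp16/bf16; torch_ddp autocasts only when AMP is active
    # (i.e. a GradScaler was created); torch_fsdp autocasts to the configured
    # dtype, or falls back to a plain nullcontext for fp32.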
    autocast = {
        "deepspeed":
        amp_autocast(enabled=dtype is not None,
                     dtype=dtype,
                     cache_enabled=False),
        "torch_ddp":
        amp_autocast(enabled=scaler is not None),
        "torch_fsdp":
        amp_autocast(enabled=True, dtype=dtype)
        if dtype is not None else nullcontext()
    }[train_engine]
    with autocast:
        loss_dict = model(batch, device)

    info_dict['loss_dict'] = loss_dict
    return info_dict


def batch_backward(model, scaler, info_dict):
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    if use_amp:
        assert scaler is not None
    loss = info_dict['loss_dict']['loss']

    if train_engine == "deepspeed":
        # NOTE(xcsong): `model.backward(loss)` is equivalent to
        #               `scale_loss_wrt_accum_grad + loss.backward()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        scaled_loss = model.backward(loss)
    else:
        assert train_engine in ["torch_ddp", "torch_fsdp"]
        scaled_loss = loss / accum_grad
        if scaler is not None:
            # fp16 (amp and fsdp)
            scaler.scale(scaled_loss).backward()
        else:
            # float32  (ddp and fsdp)
            # bf16 (fsdp)
            scaled_loss.backward()

    info_dict['loss_dict']['loss'] = scaled_loss
    for loss_name, loss_value in info_dict['loss_dict'].items():
        if loss_value is not None:
            info_dict['loss_dict'][loss_name] = tensor_to_scalar(loss_value)

    return info_dict


def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
    rank = int(os.environ.get('RANK', 0))
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1)
    use_amp = info_dict.get('use_amp', False)
    clip = info_dict.get('grad_clip', 50.0)
    batch_idx = info_dict["batch_idx"]
    if use_amp:
        assert scaler is not None

    grad_norm = 0.0
    if train_engine == "deepspeed":
        # NOTE(xcsong): The step() function in DeepSpeed engine updates the
        #   model parameters as well as the learning rate.
        #   Zeroing the gradients is handled automatically by
        #   DeepSpeed after the weights have been updated using a mini-batch.
        #   DeepSpeed also performs gradient averaging automatically at the
        #   gradient accumulation boundaries and addresses clip_grad_norm internally.
        #   `ds_model.step() =  clip_grad_norm_() + optimizer.step()
        #                       + optimizer.zero_grad() + scheduler.step()`
        #   ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
        info_dict["is_gradient_accumulation_boundary"] = \
            model.is_gradient_accumulation_boundary()
        model.step()
        grad_norm = model.get_global_grad_norm()
        if grad_norm is None:
            grad_norm = 0.0
    elif (batch_idx + 1) % accum_grad == 0:
        # Use mixed precision training
        # fp16 (ddp fsdp)
        if scaler is not None:
            scaler.unscale_(optimizer)
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                # fsdp
                grad_norm = model.clip_grad_norm_(clip)
            # Must invoke scaler.update() if unscale_() is used in
            # the iteration to avoid the following error:
            #   RuntimeError: unscale_() has already been called
            #   on this optimizer since the last update().
            # We don't check grad here since that if the gradient
            # has inf/nan values, scaler.step will skip
            # optimizer.step().
            scaler.step(optimizer)
            scaler.update()
        else:
            if train_engine == "torch_ddp":
                grad_norm = clip_grad_norm_(model.parameters(), clip)
            else:
                grad_norm = model.clip_grad_norm_(clip)
            if torch.isfinite(grad_norm):
                optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

    info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
    info_dict["grad_norm"] = tensor_to_scalar(grad_norm)

    return info_dict
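
# A rough sketch of how a caller (e.g. the training executor) is expected to
# chain the helpers above on each batch; the loop below is illustrative only
# and not part of this module:
#
#   for batch_idx, batch in enumerate(train_data_loader):
#       info_dict["batch_idx"] = batch_idx
#       if wenet_join(group_join, info_dict):
#           break
#       info_dict = batch_forward(model, batch, scaler, info_dict, device)
#       info_dict = batch_backward(model, scaler, info_dict)
#       info_dict = update_parameter_and_lr(model, optimizer, scheduler,
#                                           scaler, info_dict)
#       log_per_step(writer, info_dict, timer=None)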


def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
    tag = info_dict["tag"]
    step = info_dict["step"]
    batch_idx = info_dict["batch_idx"]
    loss_dict = info_dict['loss_dict']
    epoch = info_dict.get('epoch', 0)
    train_engine = info_dict.get("train_engine", "torch_ddp")
    accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
    log_interval = info_dict.get('log_interval', 10)
    lrs = info_dict.get("lrs", [0.0])
    is_gradient_accumulation_boundary = info_dict.get(
        "is_gradient_accumulation_boundary", False)

    rank = int(os.environ.get('RANK', 0))
    # TRAIN Tensorboard
    if tag == "TRAIN" and rank == 0 and writer is not None:
        if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
            ) or (train_engine in ["torch_ddp", "torch_fsdp"] and
                  (batch_idx + 1) % accum_grad == 0):
            writer.add_scalar('train/train_loss',
                              tensor_to_scalar(loss_dict['loss']) * accum_grad,
                              step)
            writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
            for name, value in loss_dict.items():
                if name != 'loss' and value is not None:
                    writer.add_scalar('train/{}'.format(name),
                                      tensor_to_scalar(value), step)
            # lr
            for i, lr in enumerate(lrs):
                writer.add_scalar('train/lr_{}'.format(i), lr, step)
    # CV Tensorboard
    elif "step_" in tag and rank == 0 and writer is not None:
        for name, value in loss_dict.items():
            writer.add_scalar('cv/{}'.format(name), tensor_to_scalar(value),
                              step)
        logging.info(
            'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
                epoch, step + 1, lrs_to_str(lrs),
                tensor_to_scalar(loss_dict["loss"]), rank,
                tensor_to_scalar(loss_dict["acc"])))
        return

    # TRAIN & CV, Shell log (stdout)
    if (batch_idx + 1) % log_interval == 0:
        log_str = '{} | '.format(tag)
        if timer is not None:
            timer_step = step
            if info_dict.get("cv_step", None) is not None:
                timer_step = info_dict['cv_step']
            steps_per_second = timer.steps_per_second(timer_step)
            log_str += 'steps/sec {:.3f} | '.format(steps_per_second)
        log_str += 'Batch {}/{} loss {:.6f} '.format(
            epoch, batch_idx + 1 if 'save_interval' not in info_dict else
            (step + 1) * accum_grad,
            tensor_to_scalar(loss_dict['loss']) * accum_grad)
        for name, value in loss_dict.items():
            if name != 'loss' and value is not None:
                log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value))
        if tag == "TRAIN":
            log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
                lrs_to_str(lrs), info_dict['grad_norm'], rank)
        logging.debug(log_str)


def log_per_epoch(writer, info_dict):
    epoch = info_dict["epoch"]
    loss_dict = info_dict["loss_dict"]
    lrs = info_dict['lrs']
    rank = int(os.environ.get('RANK', 0))
    step = info_dict["step"]
    logging.info(
        'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
            epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]),
            rank, tensor_to_scalar(loss_dict["acc"])))

    if int(os.environ.get('RANK', 0)) == 0:
        for i, lr in enumerate(info_dict["lrs"]):
            writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
        for name, value in loss_dict.items():
            writer.add_scalar('epoch/{}'.format(name), tensor_to_scalar(value),
                              epoch)


def freeze_modules(model, args):
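    # Freeze every parameter whose dotted name contains one of the substrings
    # given via --freeze_modules, e.g. a hypothetical
    # `--freeze_modules "encoder.embed,ctc"` would freeze the subsampling
    # module and the CTC head.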
    for name, param in model.named_parameters():
        for module_name in args.freeze_modules:
            if module_name in name:
                param.requires_grad = False
                logging.debug("{} module is freezed".format(name))


def reinit_lora(model, args, configs, tokenizer, seed=777):
    from tqdm import tqdm
    from wenet.finetune.lora.utils import estimate_gradient, reinit_lora_modules
    from wenet.finetune.lora.layers import LoRALayer
    from types import SimpleNamespace

    logging.info("reinit lora modules.")
    with open(args.lora_init_yaml, 'r') as file:
        lora_config = yaml.safe_load(file)

    generator = torch.Generator()
    generator.manual_seed(seed)
    dataset_conf = copy.deepcopy(configs['dataset_conf'])
    dataset_conf['batch_conf']['batch_size'] = lora_config['init_batch_size']
    dataset_type = configs.get('dataset', 'asr')
    dataset = init_dataset(dataset_type, args.data_type, args.train_data,
                           tokenizer, dataset_conf, True)
    dataloader = DataLoader(dataset,
                            batch_size=None,
                            pin_memory=args.pin_memory,
                            num_workers=args.num_workers,
                            persistent_workers=True,
                            generator=generator,
                            prefetch_factor=args.prefetch)
    additional_kwargs = {}
    if lora_config["init_config"]["mode"] == "gradient":
        named_grads = estimate_gradient(model, dataloader,
                                        lora_config['init_iters'])
        additional_kwargs["named_grads"] = named_grads
    lora_config = SimpleNamespace(**lora_config["init_config"])
    for name, module in tqdm(
        model.named_modules(),
        desc="Reinitializing Lora",
        total=len(list(model.named_modules())),
    ):
        if isinstance(module, LoRALayer):
            reinit_lora_modules(name, module, lora_config, **additional_kwargs)
    # lora_init_model needs to be saved, w0 = w0 - A0 * B0
    save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
                    infos={"tag": "lora_init", **configs})