# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2023 Tsinghua Univ. (authors: Xingchen Song)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from contextlib import nullcontext
import copy
from typing import List, Optional
import json
import logging
import os
import torch
import yaml
import torch.optim as optim
import torch.distributed as dist
from tensorboardX import SummaryWriter
from torch.utils.data import DataLoader
from torch.nn.utils import clip_grad_norm_
from torch.distributed.fsdp import (FullyShardedDataParallel as FSDP,
CPUOffload, MixedPrecision,
sharded_grad_scaler, ShardingStrategy)
try:
import deepspeed
from deepspeed.runtime.zero.stage_1_and_2 import (
estimate_zero2_model_states_mem_needs_all_live)
from deepspeed.runtime.zero.stage3 import (
estimate_zero3_model_states_mem_needs_all_live)
from deepspeed.utils.zero_to_fp32 import (
convert_zero_checkpoint_to_fp32_state_dict)
except ImportError:
pass
from wenet.utils.checkpoint import save_checkpoint
from wenet.utils.common import (StepTimer, get_nested_attribute, lrs_to_str,
tensor_to_scalar)
from wenet.utils.fsdp_utils import (check_gradient_checkpoint, fsdp_save_model,
apply_fsdp_checkpointing,
wenet_fsdp_wrap_policy)
from wenet.utils.scheduler import WarmupLR, NoamHoldAnnealing
from wenet.utils.ctc_utils import get_blank_id
from wenet.utils.common import TORCH_NPU_AVAILABLE
from wenet.utils.init_dataset import init_dataset
def add_model_args(parser):
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--override_config',
action='append',
default=[],
help="override yaml config")
parser.add_argument("--enc_init",
default=None,
type=str,
help="Pre-trained model to initialize encoder")
parser.add_argument(
'--enc_init_mods',
default="encoder.",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="List of encoder modules \
to initialize ,separated by a comma")
parser.add_argument(
'--freeze_modules',
default="",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of modules to freeze',
)
return parser
def add_trace_args(parser):
parser.add_argument('--jit',
action='store_true',
default=False,
                        help='whether to use jit to trace the model during the training stage')
parser.add_argument('--print_model',
action='store_true',
default=False,
help='print model')
return parser
def add_dataset_args(parser):
parser.add_argument('--data_type',
default='raw',
# choices=['raw', 'shard'],
help='train and cv data type')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
                        help='Use pinned memory buffers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
return parser
def add_lora_args(parser):
'''Configure parameters for LoRA fine-tuning. Set use_lora and
only_optimize_lora to true to enable LoRA functionality.
    LoRA will be injected into the model through (lora_modules, lora_attn_attr,
lora_list).
LoRA weights will be merged after calling model.eval()
(or model.train(mode=False)).
LoRA weights need to be loaded after fine-tuning with DeepSpeed.
'''
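    # Illustrative invocation only (not taken from any recipe; the entry script
    # and exact flag handling depend on how this parser is wired up):
    #   python train.py --use_lora True --only_optimize_lora True \
    #       --lora_rank 8 --lora_alpha 8 \
    #       --lora_list linear_q,linear_k,linear_v,linear_out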
parser.add_argument("--use_lora",
default=False,
type=bool,
help="whether use the lora finetune.")
parser.add_argument("--only_optimize_lora",
default=False,
type=bool,
help="freeze all other paramters and only optimize \
LoRA-related prameters.")
parser.add_argument(
'--lora_modules',
default="encoder.encoders",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
        help='names of the modules to inject LoRA into',
)
parser.add_argument(
"--lora_attn_attr",
default="self_attn,src_attn",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="lora_attn_attr.")
parser.add_argument(
"--lora_list",
default="linear_out,linear_q,linear_k,linear_v",
type=lambda s: [str(mod) for mod in s.split(",") if s != ""],
help="lora module list.")
parser.add_argument("--lora_rank",
default=8,
type=int,
help="lora rank num.")
parser.add_argument("--lora_alpha",
default=8,
type=int,
help="lora scale param, scale=lora_alpha/lora_rank.")
parser.add_argument("--lora_dropout",
default=0,
type=float,
help="lora dropout param.")
parser.add_argument("--lora_ckpt_path",
default=None,
type=str,
help="lora checkpoint path.")
parser.add_argument("--lora_reinit",
default=False,
type=bool,
help="whether use the lora init, default is zero init.")
parser.add_argument('--lora_init_yaml',
default="wenet/finetune/lora/config.yaml",
type=str,
help='Path to the configuration YAML file')
return parser
def add_ddp_args(parser):
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo', "hccl"],
help='distributed backend')
parser.add_argument('--use_amp',
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--fp16_grad_sync',
action='store_true',
default=False,
help='Use fp16 gradient sync for ddp')
return parser
def add_deepspeed_args(parser):
parser.add_argument('--timeout',
default=30,
type=int,
help='timeout (in seconds) of wenet_join. ' +
'30s for aishell & 300s for wenetspeech')
parser.add_argument('--local_rank',
type=int,
default=-1,
help='local rank passed from distributed launcher')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
    # DeepSpeed automatically adds '--deepspeed' and '--deepspeed_config' to the parser
try:
parser = deepspeed.add_config_arguments(parser)
except Exception as e:
print(e)
return parser
def add_fsdp_args(parser):
parser.add_argument(
'--dtype',
default='fp32',
choices=['fp32', 'fp16', 'bf16'],
help='when amp is used, dtype is automatically set to fp16.\
this arg has no effect when deepspeed is enabled.')
parser.add_argument(
'--fsdp_cpu_offload',
default=False,
type=bool,
help='whether to offload parameters to CPU',
)
parser.add_argument(
'--fsdp_sync_module_states',
type=bool,
default=True,
help='\
each FSDP module will broadcast module parameters and buffers from \
rank 0 to ensure that they are replicated across ranks',
)
parser.add_argument(
'--fsdp_sharding_strategy',
default='zero2',
# TODO(Mddct): pipeline and model parallel (3-D parallelism)
choices=['no_shard', 'model', 'zero2', 'zero3'],
help='Sharding strategy for FSDP. Choose from the following options:\n'
' - "no_shard": Equivalent to DistributedDataParallel (DDP).\n'
' - "model": WENET_ENC_DEC strategy, equivalent to DeepSpeed zero1.\n'
' - "zero2": SHARD_GRAD_OP strategy, equivalent to DeepSpeed zero2.\n'
' - "zero3": FULL_SHARD strategy, equivalent to DeepSpeed zero3.\n'
'For more information, refer to the FSDP API documentation.')
return parser
def init_distributed(args):
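    # WORLD_SIZE / LOCAL_RANK / RANK are expected to be set by the distributed
    # launcher. A minimal sketch (assuming a single node with 8 GPUs and that
    # the entry script is train.py):
    #   torchrun --nnodes=1 --nproc_per_node=8 train.py --config ... \
    #       --train_data ... --cv_data ... --model_dir ...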
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
', rank {}, world_size {}'.format(rank, world_size))
if args.train_engine in ["torch_ddp", "torch_fsdp"]:
if "cuda" in args.device:
torch.cuda.set_device(local_rank)
elif "npu" in args.device and TORCH_NPU_AVAILABLE:
torch.npu.set_device(local_rank)
else:
logging.error("not supported device: {}".format(args.device))
dist.init_process_group(args.dist_backend)
elif args.train_engine == "deepspeed":
deepspeed.init_distributed(dist_backend=args.dist_backend)
else:
logging.error("not supported engine: {}".format(args.train_engine))
return world_size, local_rank, rank
def check_modify_and_save_config(args, configs, symbol_table):
if args.train_engine in ["torch_ddp", "torch_fsdp"]:
if args.use_amp:
configs["dtype"] = "fp16"
args.dtype = 'fp16'
else:
configs["dtype"] = args.dtype
elif args.train_engine == "deepspeed":
        # NOTE(xcsong): DeepSpeed does not support uneven data. When using a custom
        #   dataset, we need to manually ensure that the data is evenly distributed
        #   across all processes. We implement `train_utils.py::wenet_join` for this.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/2223
        #
        # NOTE(xcsong): We also need to keep:
        #   1. `train_micro_batch_size_per_gpu == 1`
        #   2. `accum_grad (in train_conformer.yaml)
        #       == gradient_accumulation_steps (in ds_config.json)`
        #   3. `grad_clip (in train_conformer.yaml)
        #       == gradient_clipping (in ds_config.json)`
        # The reason for this consistency checking is that DeepSpeed's native
        # dataloader uses PyTorch's torch.utils.data.DistributedSampler, which does
        # not support IterableDataset. IterableDataset is extremely useful in
        # large-scale training because it lets you stream the data without having to
        # download the complete dataset.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/1371
        #        https://github.com/microsoft/DeepSpeed/issues/285
        # To make DeepSpeed training compatible with IterableDataset, we have to
        # use a custom dataloader instead of DeepSpeed's native loader, and thus we
        # should configure the batch size in train_conformer.yaml instead of
        # ds_config.json. On the contrary, gradient accumulation / clipping should be
        # configured in ds_config.json since they are handled by DeepSpeed automatically.
        #   ref: https://github.com/microsoft/DeepSpeed/issues/62
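        # A minimal ds_config.json sketch consistent with the asserts below
        # (values are illustrative, not prescriptive):
        #   {
        #     "train_micro_batch_size_per_gpu": 1,
        #     "gradient_accumulation_steps": 4,   # must equal accum_grad
        #     "gradient_clipping": 5.0,           # must equal grad_clip
        #     "steps_per_print": 100,             # must equal log_interval
        #     "bf16": { "enabled": true }
        #   }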
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
configs["dtype"] = "fp16"
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
configs["dtype"] = "bf16"
else:
configs["dtype"] = "fp32"
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
assert ds_configs["gradient_accumulation_steps"] == configs[
'accum_grad']
assert ds_configs["gradient_clipping"] == configs['grad_clip']
assert ds_configs["steps_per_print"] == configs['log_interval']
if args.use_lora:
configs['lora_conf'] = {}
configs['lora_conf']['lora_modules'] = args.lora_modules
configs['lora_conf']['lora_attn_attr'] = args.lora_attn_attr
configs['lora_conf']['lora_list'] = args.lora_list
configs['lora_conf']['lora_rank'] = args.lora_rank
configs['lora_conf']['lora_alpha'] = args.lora_alpha
configs['lora_conf']['lora_dropout'] = args.lora_dropout
if configs["model"] == 'asr_model':
if 'input_dim' not in configs:
if 'fbank_conf' in configs['dataset_conf']:
input_dim = configs['dataset_conf']['fbank_conf'][
'num_mel_bins']
elif 'log_mel_spectrogram_conf' in configs['dataset_conf']:
input_dim = configs['dataset_conf'][
'log_mel_spectrogram_conf']['num_mel_bins']
else:
input_dim = configs['dataset_conf']['mfcc_conf'][
'num_mel_bins']
else:
input_dim = configs['input_dim']
configs['input_dim'] = input_dim
configs, _ = get_blank_id(configs, symbol_table)
configs['output_dim'] = configs['vocab_size']
configs['train_engine'] = args.train_engine
configs['use_amp'] = args.use_amp
configs['model_dir'] = args.model_dir
configs['save_states'] = args.save_states
# Save configs to model_dir/train.yaml for inference and export
if int(os.environ.get('RANK', 0)) == 0:
saved_config_path = os.path.join(args.model_dir, 'train.yaml')
with open(saved_config_path, 'w') as fout:
data = yaml.dump(configs)
fout.write(data)
if configs["model_conf"].get("apply_non_blank_embedding", False):
        logging.warning('You had better load a well-trained model '
                        'if apply_non_blank_embedding is true!')
return configs
def init_dataset_and_dataloader(args, configs, tokenizer, seed=777):
generator = torch.Generator()
generator.manual_seed(seed)
    # If `save_interval` is present in configs, run in step mode; otherwise, epoch mode.
if "save_interval" in configs:
configs['dataset_conf']['cycle'] = configs.get('max_epoch', 100)
conf = configs['dataset_conf']
dataset_type = configs.get('dataset', 'asr')
configs['vocab_size'] = tokenizer.vocab_size()
train_dataset = init_dataset(dataset_type,
args.data_type,
args.train_data,
tokenizer,
conf,
True,
split='train')
tag = configs["init_infos"].get("tag", "init")
    train_dataset.set_epoch(configs["init_infos"].get('epoch', 0) +
                            int("epoch_" in tag) - 1)
cv_conf = copy.deepcopy(conf)
cv_conf['split_num'] = 1
cv_dataset = init_dataset(dataset_type,
args.data_type,
args.cv_data,
tokenizer,
cv_conf,
partition=False,
split='cv')
# NOTE(xcsong): Why we prefer persistent_workers=True ?
# https://discuss.pytorch.org/t/what-are-the-dis-advantages-of-persistent-workers/102110
train_data_loader = DataLoader(train_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
persistent_workers=True,
generator=generator,
prefetch_factor=args.prefetch)
cv_data_loader = DataLoader(cv_dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
persistent_workers=True,
generator=generator,
prefetch_factor=args.prefetch)
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
def wrap_cuda_model(args, model, configs=None):
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
world_size = int(os.environ.get('WORLD_SIZE', 1))
if hasattr(model, 'encoder'):
grad_ckpt = getattr(model.encoder, 'gradient_checkpointing', False)
else:
grad_ckpt = False
if args.train_engine == "torch_ddp": # native pytorch ddp
device = torch.device(args.device)
model.to(device)
# model = torch.nn.parallel.DistributedDataParallel(
# model, find_unused_parameters=not grad_ckpt)
model = torch.nn.parallel.DistributedDataParallel(
model, find_unused_parameters=True)
elif args.train_engine == "deepspeed": # deepspeed
# NOTE(xcsong): look in detail how the memory estimator API works:
# https://deepspeed.readthedocs.io/en/latest/memory.html#discussion
if int(os.environ.get('RANK', 0)) == 0:
logging.info("Estimating model states memory needs (zero2)...")
estimate_zero2_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size)
logging.info("Estimating model states memory needs (zero3)...")
estimate_zero3_model_states_mem_needs_all_live(
model,
num_gpus_per_node=local_world_size,
num_nodes=world_size // local_world_size)
        device = torch.device(args.device)
        # The DeepSpeed engine itself is initialized later, in
        # init_optimizer_and_scheduler().
elif args.train_engine == 'torch_fsdp':
assert configs is not None
mixed_precision_dtype = {
'fp32': torch.float32,
"fp16": torch.float16,
"bf16": torch.bfloat16,
}[configs['dtype']]
sharding_strategy = {
'model': ShardingStrategy.SHARD_GRAD_OP,
'zero2': ShardingStrategy.SHARD_GRAD_OP,
'zero3': ShardingStrategy.FULL_SHARD,
'no_shard': ShardingStrategy.NO_SHARD,
}[args.fsdp_sharding_strategy]
wrap_policy = wenet_fsdp_wrap_policy(mode=args.fsdp_sharding_strategy)
layer_types = check_gradient_checkpoint(model)
if "cuda" in args.device:
device_id = torch.cuda.current_device()
elif "npu" in args.device and TORCH_NPU_AVAILABLE:
device_id = torch.npu.current_device()
else:
logging.error("not supported device: {}".format(args.device))
model = FSDP(
model,
auto_wrap_policy=wrap_policy,
cpu_offload=CPUOffload(offload_params=True)
if args.fsdp_cpu_offload is True else None,
mixed_precision=MixedPrecision(
param_dtype=mixed_precision_dtype,
reduce_dtype=mixed_precision_dtype,
buffer_dtype=mixed_precision_dtype,
),
sharding_strategy=sharding_strategy,
limit_all_gathers=True,
use_orig_params=True,
sync_module_states=args.fsdp_sync_module_states,
            # init_distributed has already called torch.cuda.set_device /
            # torch.npu.set_device, so we pass device_id explicitly; see the FSDP API.
device_id=device_id)
apply_fsdp_checkpointing(model, layer_types)
device = torch.device(args.device)
else:
logging.error("not supported engine: {}".format(args.train_engine))
if args.train_engine in ["torch_fsdp", "torch_ddp"]:
if args.fp16_grad_sync:
from torch.distributed.algorithms.ddp_comm_hooks import (
default as comm_hooks, )
model.register_comm_hook(state=None,
hook=comm_hooks.fp16_compress_hook)
return model, device
def init_optimizer_and_scheduler(args, configs, model):
groups = []
lr = configs['optim_conf'].get('lr')
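    # Per-module learning rates: when `lr` is a list, optim_conf is expected to
    # look roughly like the (illustrative) YAML below, with one lr per entry in
    # `modules` plus a trailing lr for all remaining parameters:
    #   optim_conf:
    #     lr: [0.0005, 0.001]
    #     modules: ['encoder']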
if isinstance(lr, List):
assert configs['scheduler'] == 'warmuplr'
modules_m = configs['optim_conf']['modules']
assert isinstance(modules_m, List)
assert len(modules_m) + 1 == len(lr)
special_param_ids = set()
rest_params = []
for (i, m_str) in enumerate(modules_m):
sub_module = get_nested_attribute(model, m_str)
subs_params = []
for _, sub_params in sub_module.named_parameters():
subs_params.append(sub_params)
special_param_ids.add(id(sub_params))
groups.append({'params': subs_params, 'lr': lr[i]})
        # the model's remaining parameters
for _, param in model.named_parameters():
if id(param) not in special_param_ids:
rest_params.append(param)
groups.append({'params': rest_params, 'lr': lr[-1]})
params = groups if len(groups) > 0 else model.parameters()
optim_conf = copy.deepcopy(configs['optim_conf'])
if 'modules' in optim_conf:
del optim_conf['modules']
if isinstance(lr, List):
optim_conf['lr'] = lr[-1]
if configs['optim'] == 'adam':
optimizer = optim.Adam(params, **optim_conf)
elif configs['optim'] == 'adamw':
optimizer = optim.AdamW(params, **optim_conf)
else:
raise ValueError("unknown optimizer: " + configs['optim'])
scheduler_type = None
if configs['scheduler'] == 'warmuplr':
scheduler_type = WarmupLR
scheduler = WarmupLR(optimizer, **configs['scheduler_conf'])
elif configs['scheduler'] == 'NoamHoldAnnealing':
scheduler_type = NoamHoldAnnealing
scheduler = NoamHoldAnnealing(optimizer, **configs['scheduler_conf'])
else:
raise ValueError("unknown scheduler: " + configs['scheduler'])
    # NOTE(xcsong): A custom optimizer might yield poor performance when
    #   zero-offload is enabled. If you do want to offload the optimizer to CPU,
    #   please set the optimizer in ds_config.json, see:
    #   (https://www.deepspeed.ai/docs/config-json/#optimizer-parameters)
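    # Illustrative "optimizer" entry for ds_config.json (example values only);
    # when such an entry is present, the torch optimizer built above is dropped
    # and DeepSpeed constructs its own (e.g. DeepSpeedCPUAdam when offload is
    # configured):
    #   "optimizer": {
    #     "type": "AdamW",
    #     "params": { "lr": 0.001, "betas": [0.9, 0.98], "eps": 1e-8 }
    #   }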
if args.train_engine == "deepspeed":
with open(args.deepspeed_config, 'r') as fin:
ds_configs = json.load(fin)
if "optimizer" in ds_configs:
            # NOTE(xcsong): Disable the custom optimizer if one is set in ds_config;
            #   this is especially useful when cpu_offload is enabled, since
            #   DeepSpeedCPUAdam can be 4~5x faster than torch's native Adam.
optimizer = None
if "scheduler" in ds_configs:
scheduler = None
else:
def scheduler(opt):
return scheduler_type(opt, **configs['scheduler_conf'])
model, optimizer, _, scheduler = deepspeed.initialize(
args=args,
model=model,
optimizer=optimizer,
lr_scheduler=scheduler,
model_parameters=model.parameters())
step = configs.get("init_infos", {}).get("step", -1)
scheduler.set_step(step)
return model, optimizer, scheduler
def trace_and_print_model(args, model):
# !!!IMPORTANT!!!
    # Try to export the model by script; if it fails, we should refine
    # the code to satisfy the script-export requirements.
if int(os.environ.get('RANK', 0)) == 0:
if args.jit:
script_model = torch.jit.script(model)
script_model.save(os.path.join(args.model_dir, 'init.zip'))
if args.print_model:
print(model)
num_params = sum(p.numel() for p in model.parameters())
print('the number of model params: {:,d}'.format(num_params))
def init_summarywriter(args):
writer = None
if int(os.environ.get('RANK', 0)) == 0:
os.makedirs(args.model_dir, exist_ok=True)
exp_id = os.path.basename(args.model_dir)
writer = SummaryWriter(os.path.join(args.tensorboard_dir, exp_id))
return writer
def init_scaler(args):
scaler = None
if args.use_amp:
if "cuda" in args.device:
scaler = torch.cuda.amp.GradScaler()
elif "npu" in args.device and TORCH_NPU_AVAILABLE:
scaler = torch.npu.amp.GradScaler()
else:
logging.error("not supported device: {}".format(args.device))
elif args.train_engine == 'torch_fsdp':
        # why bf16 doesn't need a loss scaler:
# https://discuss.pytorch.org/t/why-bf16-do-not-need-loss-scaling/176596
if args.dtype in ['fp16']:
scaler = sharded_grad_scaler.ShardedGradScaler(enabled=True)
return scaler
def save_model(model, info_dict):
rank = int(os.environ.get('RANK', 0))
tag = info_dict["tag"]
model_dir = info_dict["model_dir"]
save_model_path = os.path.join(model_dir, '{}.pt'.format(tag))
# save ckpt
if info_dict["train_engine"] == "deepspeed":
        # NOTE(xcsong): All ranks should call this API, but only rank 0
        #   saves the general model params. see:
# https://github.com/microsoft/DeepSpeed/issues/2993
with torch.no_grad():
model.save_checkpoint(save_dir=model_dir,
tag=tag,
client_state=info_dict)
if info_dict["save_states"] == "model_only" and rank == 0:
convert_zero_checkpoint_to_fp32_state_dict(model_dir,
save_model_path,
tag=tag)
os.system("rm -rf {}/{}".format(model_dir, tag))
elif info_dict['train_engine'] == "torch_fsdp":
fsdp_save_model(model, save_model_path, info_dict)
elif rank == 0:
# NOTE(xcsong): For torch_ddp, only rank-0 should call this.
save_checkpoint(model, save_model_path, info_dict)
# save yaml
if rank == 0:
with open("{}/{}.yaml".format(model_dir, tag), 'w') as fout:
data = yaml.dump(info_dict)
fout.write(data)
def wenet_join(group_join, info_dict):
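    # `group_join` is expected to be a small auxiliary process group created by
    # the caller with its own timeout, e.g. (illustrative sketch):
    #   group_join = dist.new_group(
    #       backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))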
world_size = int(os.environ.get('WORLD_SIZE', 1))
local_rank = int(os.environ.get('LOCAL_RANK', 0))
rank = int(os.environ.get('RANK', 0))
train_engine = info_dict.get('train_engine', "torch_ddp")
if info_dict["batch_idx"] == 0 or train_engine == "torch_ddp":
# NOTE(xcsong): skip first batch because its processing time includes
# dataloader initialization time, which may exceed 30 seconds
return False
try:
        # NOTE(xcsong): Why do we need a new group?
# Because Deepspeed has its own group where all the relevant communication
# operations are executed. If we add a communication operation that is not
# managed by Deepspeed in this group, it's highly likely to cause
# communication chaos, resulting in hard-to-troubleshoot hangs.
dist.monitored_barrier(group=group_join,
timeout=group_join.options._timeout)
except RuntimeError as e:
logging.info("Detected uneven workload distribution: {}\n".format(e) +
"Break current worker to manually join all workers, " +
"world_size {}, current rank {}, current local_rank {}\n".
format(world_size, rank, local_rank))
return True
return False
def batch_forward(model, batch, scaler, info_dict, device):
train_engine = info_dict.get('train_engine', "torch_ddp")
accum_grad = info_dict.get('accum_grad', 1)
dtype = info_dict.get("dtype", "fp32")
if dtype == "fp16":
dtype = torch.float16
elif dtype == "bf16":
dtype = torch.bfloat16
else: # fp32
dtype = None
# autocast context
# The more details about amp can be found in
# https://pytorch.org/docs/stable/notes/amp_examples.html
amp_autocast = torch.cuda.amp.autocast
if "npu" in device.__str__() and TORCH_NPU_AVAILABLE:
amp_autocast = torch.npu.amp.autocast
autocast = {
"deepspeed":
amp_autocast(enabled=dtype is not None,
dtype=dtype,
cache_enabled=False),
"torch_ddp":
amp_autocast(enabled=scaler is not None),
"torch_fsdp":
amp_autocast(enabled=True, dtype=dtype)
if dtype is not None else nullcontext()
}[train_engine]
with autocast:
loss_dict = model(batch, device)
info_dict['loss_dict'] = loss_dict
return info_dict
def batch_backward(model, scaler, info_dict):
train_engine = info_dict.get("train_engine", "torch_ddp")
accum_grad = info_dict.get('accum_grad', 1)
use_amp = info_dict.get('use_amp', False)
if use_amp:
assert scaler is not None
loss = info_dict['loss_dict']['loss']
if train_engine == "deepspeed":
# NOTE(xcsong): `model.backward(loss)` is equivalent to
# `scale_loss_wrt_accum_grad + loss.backward()`
# ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
scaled_loss = model.backward(loss)
else:
assert train_engine in ["torch_ddp", "torch_fsdp"]
scaled_loss = loss / accum_grad
if scaler is not None:
# fp16 (amp and fsdp)
scaler.scale(scaled_loss).backward()
else:
# float32 (ddp and fsdp)
# bf16 (fsdp)
scaled_loss.backward()
info_dict['loss_dict']['loss'] = scaled_loss
for loss_name, loss_value in info_dict['loss_dict'].items():
if loss_value is not None:
info_dict['loss_dict'][loss_name] = tensor_to_scalar(loss_value)
return info_dict
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
rank = int(os.environ.get('RANK', 0))
train_engine = info_dict.get("train_engine", "torch_ddp")
accum_grad = info_dict.get('accum_grad', 1)
use_amp = info_dict.get('use_amp', False)
clip = info_dict.get('grad_clip', 50.0)
batch_idx = info_dict["batch_idx"]
if use_amp:
assert scaler is not None
grad_norm = 0.0
if train_engine == "deepspeed":
# NOTE(xcsong): The step() function in DeepSpeed engine updates the
# model parameters as well as the learning rate.
# Zeroing the gradients is handled automatically by
# DeepSpeed after the weights have been updated using a mini-batch.
# DeepSpeed also performs gradient averaging automatically at the
# gradient accumulation boundaries and addresses clip_grad_norm internally.
# `ds_model.step() = clip_grad_norm_() + optimizer.step()
# + optimizer.zero_grad() + scheduler.step()`
# ref: https://www.deepspeed.ai/tutorials/megatron/#using-the-training-api
info_dict["is_gradient_accumulation_boundary"] = \
model.is_gradient_accumulation_boundary()
model.step()
grad_norm = model.get_global_grad_norm()
if grad_norm is None:
grad_norm = 0.0
elif (batch_idx + 1) % accum_grad == 0:
# Use mixed precision training
# fp16 (ddp fsdp)
if scaler is not None:
scaler.unscale_(optimizer)
if train_engine == "torch_ddp":
grad_norm = clip_grad_norm_(model.parameters(), clip)
else:
# fsdp
grad_norm = model.clip_grad_norm_(clip)
# Must invoke scaler.update() if unscale_() is used in
# the iteration to avoid the following error:
# RuntimeError: unscale_() has already been called
# on this optimizer since the last update().
            # We don't check the gradient here because, if it
            # has inf/nan values, scaler.step() will skip
            # optimizer.step().
scaler.step(optimizer)
scaler.update()
else:
if train_engine == "torch_ddp":
grad_norm = clip_grad_norm_(model.parameters(), clip)
else:
grad_norm = model.clip_grad_norm_(clip)
if torch.isfinite(grad_norm):
optimizer.step()
optimizer.zero_grad()
scheduler.step()
info_dict["lrs"] = [group['lr'] for group in optimizer.param_groups]
info_dict["grad_norm"] = tensor_to_scalar(grad_norm)
return info_dict
def log_per_step(writer, info_dict, timer: Optional[StepTimer] = None):
tag = info_dict["tag"]
step = info_dict["step"]
batch_idx = info_dict["batch_idx"]
loss_dict = info_dict['loss_dict']
epoch = info_dict.get('epoch', 0)
train_engine = info_dict.get("train_engine", "torch_ddp")
accum_grad = info_dict.get('accum_grad', 1) if tag != "CV" else 1
log_interval = info_dict.get('log_interval', 10)
lrs = info_dict.get("lrs", [0.0])
is_gradient_accumulation_boundary = info_dict.get(
"is_gradient_accumulation_boundary", False)
rank = int(os.environ.get('RANK', 0))
# TRAIN Tensorboard
if tag == "TRAIN" and rank == 0 and writer is not None:
if (train_engine == "deepspeed" and is_gradient_accumulation_boundary
) or (train_engine in ["torch_ddp", "torch_fsdp"] and
(batch_idx + 1) % accum_grad == 0):
writer.add_scalar('train/train_loss',
tensor_to_scalar(loss_dict['loss']) * accum_grad,
step)
writer.add_scalar('train/grad_norm', info_dict['grad_norm'], step)
for name, value in loss_dict.items():
if name != 'loss' and value is not None:
writer.add_scalar('train/{}'.format(name),
tensor_to_scalar(value), step)
# lr
for i, lr in enumerate(lrs):
writer.add_scalar('train/lr_{}'.format(i), lr, step)
# CV Tensorboard
elif "step_" in tag and rank == 0 and writer is not None:
for name, value in loss_dict.items():
writer.add_scalar('cv/{}'.format(name), tensor_to_scalar(value),
step)
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step + 1, lrs_to_str(lrs),
tensor_to_scalar(loss_dict["loss"]), rank,
tensor_to_scalar(loss_dict["acc"])))
return
# TRAIN & CV, Shell log (stdout)
if (batch_idx + 1) % log_interval == 0:
log_str = '{} | '.format(tag)
if timer is not None:
timer_step = step
if info_dict.get("cv_step", None) is not None:
timer_step = info_dict['cv_step']
steps_per_second = timer.steps_per_second(timer_step)
log_str += 'steps/sec {:.3f}| '.format(steps_per_second)
log_str += 'Batch {}/{} loss {:.6f} '.format(
epoch, batch_idx + 1 if 'save_interval' not in info_dict else
(step + 1) * accum_grad,
tensor_to_scalar(loss_dict['loss']) * accum_grad)
for name, value in loss_dict.items():
if name != 'loss' and value is not None:
log_str += '{} {:.6f} '.format(name, tensor_to_scalar(value))
if tag == "TRAIN":
log_str += 'lr {} grad_norm {:.6f} rank {}'.format(
lrs_to_str(lrs), info_dict['grad_norm'], rank)
logging.debug(log_str)
def log_per_epoch(writer, info_dict):
epoch = info_dict["epoch"]
loss_dict = info_dict["loss_dict"]
lrs = info_dict['lrs']
rank = int(os.environ.get('RANK', 0))
step = info_dict["step"]
logging.info(
'Epoch {} Step {} CV info lr {} cv_loss {} rank {} acc {}'.format(
epoch, step, lrs_to_str(lrs), tensor_to_scalar(loss_dict["loss"]),
rank, tensor_to_scalar(loss_dict["acc"])))
if int(os.environ.get('RANK', 0)) == 0:
for i, lr in enumerate(info_dict["lrs"]):
writer.add_scalar('epoch/lr_{}'.format(i), lr, epoch)
for name, value in loss_dict.items():
writer.add_scalar('epoch/{}'.format(name), tensor_to_scalar(value),
epoch)
def freeze_modules(model, args):
for name, param in model.named_parameters():
for module_name in args.freeze_modules:
if module_name in name:
param.requires_grad = False
logging.debug("{} module is freezed".format(name))
def reinit_lora(model, args, configs, tokenizer, seed=777):
from tqdm import tqdm
from wenet.finetune.lora.utils import estimate_gradient, reinit_lora_modules
from wenet.finetune.lora.layers import LoRALayer
from types import SimpleNamespace
logging.info("reinit lora modules.")
with open(args.lora_init_yaml, 'r') as file:
lora_config = yaml.safe_load(file)
generator = torch.Generator()
generator.manual_seed(seed)
dataset_conf = copy.deepcopy(configs['dataset_conf'])
dataset_conf['batch_conf']['batch_size'] = lora_config['init_batch_size']
dataset_type = configs.get('dataset', 'asr')
dataset = init_dataset(dataset_type, args.data_type, args.train_data,
tokenizer, dataset_conf, True)
dataloader = DataLoader(dataset,
batch_size=None,
pin_memory=args.pin_memory,
num_workers=args.num_workers,
persistent_workers=True,
generator=generator,
prefetch_factor=args.prefetch)
additional_kwargs = {}
if lora_config["init_config"]["mode"] == "gradient":
named_grads = estimate_gradient(model, dataloader,
lora_config['init_iters'])
additional_kwargs["named_grads"] = named_grads
lora_config = SimpleNamespace(**lora_config["init_config"])
for name, module in tqdm(
model.named_modules(),
desc="Reinitializing Lora",
total=len(list(model.named_modules())),
):
if isinstance(module, LoRALayer):
reinit_lora_modules(name, module, lora_config, **additional_kwargs)
# lora_init_model needs to be saved, w0 = w0 - A0 * B0
save_checkpoint(model, os.path.join(args.model_dir, "lora_init.pt"),
infos={"tag": "lora_init", **configs})