# OSUM/wenet/utils/executor.py
# tomxxie: adapt to ZeroGPU (commit 568e264)
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import logging
import sys
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext
import torch
from wenet.utils.common import StepTimer
from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward,
update_parameter_and_lr, log_per_step,
save_model)
class Executor:
    """Run one epoch of training or one pass of cross validation.

    The heavy lifting (forward, backward, optimizer update, logging) is
    delegated to the helpers imported from ``wenet.utils.train_utils``;
    this class owns the loop structure, the step counter and the timers.
    """

    def __init__(self,
                 global_step: int = 0,
                 device: torch.device = torch.device("cpu")):
        """Set up the step counter, the lazy timers and the target device.

        Args:
            global_step: Starting step offset; ``self.step`` begins at
                ``global_step + 1``.
            device: Device handed to ``batch_forward`` for every batch.
        """
        self.step = global_step + 1
        # Timers are created lazily on the first train() / cv() call.
        self.train_step_timer = None
        self.cv_step_timer = None
        self.device = device

    def train(self, model, optimizer, scheduler, train_data_loader,
              cv_data_loader, writer, configs, scaler, group_join):
        """Train one epoch.

        Args:
            model: Module to train. When it is a
                ``DistributedDataParallel`` instance its ``join`` context
                wraps the whole epoch.
            optimizer: Forwarded to ``update_parameter_and_lr``.
            scheduler: Forwarded to ``update_parameter_and_lr``.
            train_data_loader: Iterable of batch dicts; each must hold a
                ``"target_lengths"`` tensor.
            cv_data_loader: Not used by this method (see the NOTE in the
                loop); kept so the signature stays compatible.
            writer: Forwarded to ``log_per_step``.
            configs: Configuration dict. It is deep-copied, must contain
                ``"accum_grad"`` and may contain ``"train_engine"``.
            scaler: Forwarded to the forward/backward/update helpers.
            group_join: Forwarded to ``wenet_join``.
        """
        if self.train_step_timer is None:
            self.train_step_timer = StepTimer(self.step)
        model.train()
        info_dict = copy.deepcopy(configs)
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext

        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                if wenet_join(group_join, info_dict):
                    # fix by zhaoyi: facilitates multi-node training
                    break

                # Skip batches that carry no utterances.
                if batch_dict["target_lengths"].size(0) == 0:
                    continue

                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on
                # module variables, which will later be synchronized.
                if info_dict.get("train_engine", "torch_ddp") in [
                        "torch_ddp", "torch_fsdp"
                ] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    # A model that is not wrapped (plain nn.Module) has no
                    # ``no_sync``; there is nothing to synchronize then, so
                    # fall back to a no-op context instead of raising.
                    context = getattr(model, "no_sync", nullcontext)
                # Used for single gpu training and DDP gradient
                # synchronization processes.
                else:
                    context = nullcontext

                with context():
                    info_dict = batch_forward(model, batch_dict, scaler,
                                              info_dict, self.device)
                    info_dict = batch_backward(model, scaler, info_dict)

                info_dict = update_parameter_and_lr(model, optimizer,
                                                    scheduler, scaler,
                                                    info_dict)
                # write training: tensorboard && log
                log_per_step(writer, info_dict, timer=self.train_step_timer)

                # NOTE: step-interval checkpointing (``save_interval``) with
                # an intermediate CV pass is disabled in this version, which
                # is why ``cv_data_loader`` is not consumed here.

                # The step counter only advances once a full accumulation
                # window of ``accum_grad`` batches has been processed.
                if (batch_idx + 1) % info_dict["accum_grad"] == 0:
                    self.step += 1

    def cv(self, model, cv_data_loader, configs):
        """Run cross validation over ``cv_data_loader``.

        Args:
            model: Module to evaluate; it is switched to eval mode.
            cv_data_loader: Iterable of batch dicts; each must hold a
                ``"target_lengths"`` tensor.
            configs: Configuration dict; it is deep-copied before use.

        Returns:
            A dict mapping every finite entry of the per-batch
            ``loss_dict`` whose name contains ``"loss"`` to its
            utterance-weighted mean, plus ``"acc"``: the mean of the
            per-batch ``th_accuracy`` values (0.0 for a batch without one).
            When no non-empty batch was seen the result is
            ``{"acc": 0.0}``.
        """
        if self.cv_step_timer is None:
            self.cv_step_timer = StepTimer(0.0)
        else:
            self.cv_step_timer.last_iteration = 0.0
        model.eval()
        info_dict = copy.deepcopy(configs)
        num_seen_utts, loss_dict, total_acc = 0, {}, []

        with torch.no_grad():
            for batch_idx, batch_dict in enumerate(cv_data_loader):
                info_dict["tag"] = "CV"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                info_dict["cv_step"] = batch_idx

                num_utts = batch_dict["target_lengths"].size(0)
                if num_utts == 0:
                    continue

                info_dict = batch_forward(model, batch_dict, None, info_dict,
                                          self.device)
                _dict = info_dict["loss_dict"]

                num_seen_utts += num_utts
                batch_acc = _dict.get('th_accuracy', None)
                total_acc.append(
                    batch_acc.item() if batch_acc is not None else 0.0)

                # Accumulate utterance-weighted sums; non-finite or missing
                # losses are left out of the running totals.
                for loss_name, loss_value in _dict.items():
                    if loss_value is not None and "loss" in loss_name \
                            and torch.isfinite(loss_value):
                        loss_value = loss_value.item()
                        loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                            loss_value * num_utts
                # write cv: log
                log_per_step(writer=None,
                             info_dict=info_dict,
                             timer=self.cv_step_timer)

        if not total_acc:
            # Empty loader, or every batch had zero utterances: nothing to
            # average, so report a neutral accuracy instead of dividing by 0.
            logging.warning('cv: no non-empty batch was evaluated')
            loss_dict["acc"] = 0.0
            return loss_dict

        # total_acc is non-empty, hence num_seen_utts >= 1 here.
        loss_dict = {
            loss_name: loss_sum / num_seen_utts
            for loss_name, loss_sum in loss_dict.items()
        }
        loss_dict["acc"] = sum(total_acc) / len(total_acc)
        return loss_dict