# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import datetime
import logging
import sys
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext

import torch

from wenet.utils.common import StepTimer
from wenet.utils.train_utils import (wenet_join, batch_forward, batch_backward,
                                     update_parameter_and_lr, log_per_step,
                                     save_model)
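
# Executor drives the per-epoch training and cross-validation loops; the
# forward/backward passes, optimizer updates and logging are delegated to the
# helpers imported from wenet.utils.train_utils above.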
class Executor:

    def __init__(self,
                 global_step: int = 0,
                 device: torch.device = torch.device("cpu")):
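        # global_step is typically the step restored from a checkpoint
        # (-1 or 0 for a fresh run), so training resumes at the next step.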
        self.step = global_step + 1
        self.train_step_timer = None
        self.cv_step_timer = None
        self.device = device

    def train(self, model, optimizer, scheduler, train_data_loader,
              cv_data_loader, writer, configs, scaler, group_join):
        ''' Train one epoch
        '''
        if self.train_step_timer is None:
            self.train_step_timer = StepTimer(self.step)
        model.train()
        info_dict = copy.deepcopy(configs)
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(info_dict['accum_grad']))
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext
        with model_context():
            for batch_idx, batch_dict in enumerate(train_data_loader):
                info_dict["tag"] = "TRAIN"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                if wenet_join(group_join, info_dict):
                    break  # fix by zhaoyi, helps multi-node training
                if batch_dict["target_lengths"].size(0) == 0:
                    continue
                context = None
                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if info_dict.get("train_engine", "torch_ddp") in [
                        "torch_ddp", "torch_fsdp"
                ] and (batch_idx + 1) % info_dict["accum_grad"] != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext
                with context():
                    info_dict = batch_forward(model, batch_dict, scaler,
                                              info_dict, self.device)
                    info_dict = batch_backward(model, scaler, info_dict)
                info_dict = update_parameter_and_lr(model, optimizer,
                                                    scheduler, scaler,
                                                    info_dict)
                # write training: tensorboard && log
                log_per_step(writer, info_dict, timer=self.train_step_timer)
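                # The commented-out block below would save a checkpoint (with
                # placeholder loss values) every `save_interval` steps; it is
                # disabled here.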
                # save_interval = info_dict.get('save_interval', sys.maxsize)
                # if (self.step +
                #         1) % save_interval == 0 and self.step != 0 and (
                #             batch_idx + 1) % info_dict["accum_grad"] == 0:
                #     import torch.distributed as dist
                #     # Ensure all ranks start CV at the same time in step mode
                #     dist.barrier()
                #     # loss_dict = self.cv(model, cv_data_loader, configs)
                #     model.train()
                #     info_dict.update({
                #         "tag": "step_{}".format(self.step),
                #         "loss_dict": {'loss': 999, 'acc': 999},
                #         "save_time":
                #         datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S'),
                #         "lrs":
                #         [group['lr'] for group in optimizer.param_groups]
                #     })
                #     save_model(model, info_dict)
                #     # write final cv: tensorboard
                #     log_per_step(writer, info_dict)
                #     # Ensure all ranks start Train at the same time in step mode
                #     dist.barrier()
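                # self.step counts optimizer updates, not micro-batches: it
                # only advances when the accumulated gradients are applied.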
                self.step += 1 if (batch_idx +
                                   1) % info_dict["accum_grad"] == 0 else 0

    def cv(self, model, cv_data_loader, configs):
        ''' Cross validation on the cv dataset
        '''
        if self.cv_step_timer is None:
            self.cv_step_timer = StepTimer(0.0)
        else:
            self.cv_step_timer.last_iteration = 0.0
        model.eval()
        info_dict = copy.deepcopy(configs)
        num_seen_utts, loss_dict, total_acc = 1, {}, []  # avoid division by 0
        with torch.no_grad():
            for batch_idx, batch_dict in enumerate(cv_data_loader):
                info_dict["tag"] = "CV"
                info_dict["step"] = self.step
                info_dict["batch_idx"] = batch_idx
                info_dict["cv_step"] = batch_idx
                num_utts = batch_dict["target_lengths"].size(0)
                if num_utts == 0:
                    continue
                info_dict = batch_forward(model, batch_dict, None, info_dict,
                                          self.device)
                _dict = info_dict["loss_dict"]
                num_seen_utts += num_utts
                total_acc.append(_dict['th_accuracy'].item(
                ) if _dict.get('th_accuracy', None) is not None else 0.0)
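                # Accumulate each finite loss entry weighted by the number of
                # utterances in the batch; the sums are normalized by
                # num_seen_utts after the loop.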
                for loss_name, loss_value in _dict.items():
                    if loss_value is not None and "loss" in loss_name \
                            and torch.isfinite(loss_value):
                        loss_value = loss_value.item()
                        loss_dict[loss_name] = loss_dict.get(loss_name, 0) + \
                            loss_value * num_utts
                # write cv: log
                log_per_step(writer=None,
                             info_dict=info_dict,
                             timer=self.cv_step_timer)
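        # Utterance-weighted average of each loss; accuracy is the plain mean
        # of the per-batch accuracies.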
        for loss_name, loss_value in loss_dict.items():
            loss_dict[loss_name] = loss_dict[loss_name] / num_seen_utts
        # max(1, ...) guards against an empty cv set, like num_seen_utts above
        loss_dict["acc"] = sum(total_acc) / max(1, len(total_acc))
        return loss_dict
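
# Rough usage sketch (illustrative only; the real driver is wenet/bin/train.py
# and its exact argument names may differ):
#
#     executor = Executor(global_step=restored_step, device=device)
#     for epoch in range(start_epoch, max_epoch):
#         executor.train(model, optimizer, scheduler, train_data_loader,
#                        cv_data_loader, writer, configs, scaler, group_join)
#         loss_dict = executor.cv(model, cv_data_loader, configs)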