akhaliq3
spaces demo
546a9ba
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import defaultdict
from datetime import datetime
import os
import sys
import importlib
import json
import random
import numpy as np
import inspect
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from model.third_party.HMNet.Models.Trainers.DistributedTrainer import (
DistributedTrainer,
)
from model.third_party.HMNet.Models.Trainers.Tasks import Task
from model.third_party.HMNet.Utils.GeneralUtils import (
AverageMeter,
BaseBatchGen,
bcolors,
)
from model.third_party.HMNet.DataLoader import iterators
class ObjectView(object):
def __init__(self, d):
self.__dict__ = d
class WrappedModel(nn.Module):
def __init__(self, model, criterion):
super(WrappedModel, self).__init__()
self.add_module("model", model)
self.add_module("criterion", criterion)
def forward(self, batch):
output = self.model(batch)
loss = self.criterion(output, batch)
return loss
class HMNetTrainer(DistributedTrainer):
"""
The trainer class for HMNet model training (pre-train and fine-tune.)
Its train() and eval() methods are intended to directly called to
start training and evaluation respectively.
Before running, the trainer must contain proper Task, Criterion, and Optimizer
instances.
"""
def __init__(self, opt):
super().__init__(opt)
self.task = Task.setup_task(self.opt["TASK"], self.opt, self.saveFolder)
def is_gradient_accumulation_boundary(self):
return (self.updates + 1) % self.grad_acc_steps == 0
def get_batch_generator(self, dataset_label):
batch_generator = self.task.batch_gen(
self.opt,
dataset_label=dataset_label,
model_config=self.module.config,
tokenizer=self.module.tokenizer,
world_size=self.opt["world_size"],
rank=self.opt["rank"],
seed=self.seed,
)
if isinstance(batch_generator, BaseBatchGen):
# If it is a wrapper class of an infinibatch iterator,
# get the internal infnitibatch iterator.
batch_generator = batch_generator.iterator
self.log(f"Loaded data on rank {self.opt['rank']}.")
return batch_generator
def set_up_model(self):
# instantiate module (tokenizer should be contained in module as self.module.tokenizer)
try:
model_module = importlib.import_module(
"model.third_party.HMNet.Models.Networks." + self.opt["MODEL"]
)
model_class = getattr(model_module, self.opt["MODEL"])
self.module = model_class(self.opt)
except Exception as e:
self.log(e)
self.log("ERROR: Model {} is unknown".format(self.opt["MODEL"]))
assert False
# calculate total trainable parameters
pytorch_total_params = sum(
p.numel() for p in self.module.parameters() if p.requires_grad
)
self.log("Total trainable parameters: {}".format(pytorch_total_params))
# instantiate criterion
try:
criterion_module = importlib.import_module(
"model.third_party.HMNet.Models.Criteria." + self.opt["CRITERION"]
)
criterion_class = getattr(criterion_module, self.opt["CRITERION"])
self.criterion = criterion_class(self.opt, self.module)
except Exception as e:
self.log(e)
self.log("ERROR: Criterion {} is unknown".format(self.opt["CRITERION"]))
assert False
self.module.to(self.opt["device"])
def get_optimizer_params_config(self, optimizer_class):
optimizer_parameters = {}
sig = inspect.signature(optimizer_class)
for param_name in sig.parameters.keys():
if param_name == "lr":
optimizer_parameters[param_name] = self.opt["START_LEARNING_RATE"]
if param_name not in ["params", "lr"] and param_name.upper() in self.opt:
optimizer_parameters[param_name] = self.opt[param_name.upper()]
return optimizer_parameters
def get_lr_scheduler_params_config(self, lr_scheduler_class):
lr_scheduler_parameters = {}
sig = inspect.signature(lr_scheduler_class)
for param_name in sig.parameters.keys():
if param_name not in ["optimizer"] and param_name.upper() in self.opt:
lr_scheduler_parameters[param_name] = self.opt[param_name.upper()]
return lr_scheduler_parameters
def set_up_optimizer_and_lr_scheduler(self):
parameters = self.module.get_training_parameters()
# instantiate optimizer
try: # first try pytorch native optimizer
optimizer_class = getattr(optim, self.opt["OPTIMIZER"])
self.log(
"Using pytorch native optimizier: {}".format(self.opt["OPTIMIZER"])
)
except:
try: # then try custom optimizer inside Models.Optimizers
optimizer_module = importlib.import_module(
"model.third_party.HMNet.Models.Optimizers." + self.opt["OPTIMIZER"]
)
optimizer_class = getattr(optimizer_module, self.opt["OPTIMIZER"])
self.log("Using custom optimizer: {}".format(self.opt["OPTIMIZER"]))
except Exception as e:
self.log(e)
self.log("ERROR: Optimizer {} is unknown".format(self.opt["OPTIMIZER"]))
assert False
optimizer_parameters = self.get_optimizer_params_config(optimizer_class)
self.log(f"Optimizer parameters: {optimizer_parameters}")
self.optimizer = optimizer_class(parameters, **optimizer_parameters)
self.optimizer.zero_grad()
# instantiate lr scheduler
try: # first look for pytorch native lr scheduler
lr_scheduler_class = getattr(lr_scheduler, self.opt["LR_SCHEDULER"])
self.log(
"Using pytorch native lr scheduler: {}".format(self.opt["LR_SCHEDULER"])
)
except:
try: # then look for custom lr scheduler inside Models.Optimizers
lr_scheduler_module = importlib.import_module(
"model.third_party.HMNet.Models.Optimizers."
+ self.opt["LR_SCHEDULER"]
)
lr_scheduler_class = getattr(
lr_scheduler_module, self.opt["LR_SCHEDULER"]
)
self.log(
"Using custom lr scheduler: {}".format(self.opt["LR_SCHEDULER"])
)
except Exception as e:
self.log(e)
self.log(
"ERROR: LR Scheduler {} is unknown".format(self.opt["LR_SCHEDULER"])
)
assert False
lr_scheduler_parameters = self.get_lr_scheduler_params_config(
lr_scheduler_class
)
self.log(f"Lr scheduler parameters: {lr_scheduler_parameters}")
self.lr_scheduler = lr_scheduler_class(
self.optimizer, **lr_scheduler_parameters
)
def initialize_fp16_DDP(self):
"""
Wrap the module and criterion to a single network, then depending on the settings,
wrap the network with apex amp module for fp16 training, and wrap the network with
pytorch DDP module for distributed data parallel training
"""
self.network = WrappedModel(self.module, self.criterion)
self.network.to(self.opt["device"])
if self.opt["fp16"]:
from apex import amp
self.network, self.optimizer = amp.initialize(
self.network, self.optimizer, opt_level=self.opt["fp16_opt_level"]
)
if self.opt["world_size"] > 1:
self.network = torch.nn.parallel.DistributedDataParallel(
self.network,
device_ids=[self.opt["local_rank"]],
output_device=self.opt["local_rank"],
find_unused_parameters=True,
)
self.log(f"Wrapped model with DDP on rank {self.opt['rank']}.")
assert self.module is self.network.module.model
else:
assert self.module is self.network.model
def eval(self):
if self.opt["rank"] == 0:
self.log("-----------------------------------------------")
self.log("Evaluating model ... ")
self.set_up_model()
for eval_dataset in ["dev", "test"]:
batch_generator_eval = self.get_batch_generator(eval_dataset)
self.task.evaluator.reset_best_score(set_high=True)
result, score, got_better_score = self.task.evaluator.eval_batches(
self.module, batch_generator_eval, self.saveFolder, eval_dataset
)
if self.opt["rank"] == 0:
self.log("{0} results breakdown\n{1}".format(eval_dataset, result))
def eval_return_results(self):
if self.opt["rank"] == 0:
self.log("-----------------------------------------------")
self.log("Evaluating model ... ")
self.set_up_model()
for eval_dataset in ["test"]:
batch_generator_eval = self.get_batch_generator(eval_dataset)
self.task.evaluator.reset_best_score(set_high=True)
result, score, got_better_score = self.task.evaluator.eval_batches(
self.module, batch_generator_eval, self.saveFolder, eval_dataset
)
if self.opt["rank"] == 0:
self.log("{0} results breakdown\n{1}".format(eval_dataset, result))
return result
def train(self):
self.log(f"train on rank {self.opt['rank']}")
if self.opt["rank"] == 0:
self.log("-----------------------------------------------")
self.log("Initializing model...")
self.set_up_model() # setup self.module as original model
self.network = None
self.train_batch_generator = self.get_batch_generator("train")
if isinstance(self.train_batch_generator, iterators.CheckpointableIterator):
# training batch generator is infinite
self.updates_per_epoch = self.opt["UPDATES_PER_EPOCH"]
else:
self.updates_per_epoch = len(self.train_batch_generator)
self.updates = 0
self.optim_steps = 0
self.start_epoch_idx = 0
self.start_batch_idx = 0
self.set_up_optimizer_and_lr_scheduler()
self.initialize_fp16_DDP()
if "RESUME" in self.opt:
# Resume complete training states, including optimizer, lr_scheduler, train batch generator, and updates count
# from the checkpoint location indicated in a .json file
self.load_checkpoint()
######################
# Start the main loop
######################
numEpochs = self.opt["MAX_NUM_EPOCHS"]
self.train_loss = AverageMeter() # track the average training loss
self.acc_loss = 0.0
# after every 'SAVE_PER_UPDATE_NUM' updates, it will save a checkpoint by setting save_a_checkpoint to True temporarily
save_a_checkpoint = False
for epoch in range(self.start_epoch_idx, numEpochs):
self.current_epoch_idx = epoch
self.log("Epoch {}".format(epoch))
startTime = datetime.now()
for batch_idx, batch in enumerate(self.train_batch_generator):
if self.current_epoch_idx == self.start_epoch_idx:
if isinstance(
self.train_batch_generator, iterators.CheckpointableIterator
):
batch_idx += self.start_batch_idx
elif batch_idx < self.start_batch_idx:
continue
self.current_batch_idx = batch_idx
# after every 'SAVE_PER_UPDATE_NUM' updates, save a checkpoint
if ("SAVE_PER_UPDATE_NUM" in self.opt) and (
self.updates + 1
) % self.opt["SAVE_PER_UPDATE_NUM"] == 0:
# Make sure the next update is going to update the weights and zero the gradients, then we can checkpoint
assert self.is_gradient_accumulation_boundary()
save_a_checkpoint = True
# update
self.update(batch)
if save_a_checkpoint:
# evaluate at the checkpointed moment, and log the results
if self.task.evaluator is not None:
evaluate_label = "update_" + str(self.updates)
eval_dataset = "dev"
batches = self.get_batch_generator(eval_dataset)
(
result,
score,
got_better_score,
) = self.task.evaluator.eval_batches(
self.module, batches, self.saveFolder, evaluate_label
)
self.tb_log_scalar("Eval/score", score, self.updates)
if got_better_score:
self.log(
"Got new better score on rank-{0} evaluator, at updates {1}".format(
self.opt["rank"], self.updates
)
)
self.log(
"Updates {0} - {1}: Current Score: {2:.3f} (best Score: {3:.3f})".format(
self.updates,
eval_dataset,
score,
self.task.evaluator.best_score,
)
)
self.log("Current results breakdown\n{0}".format(result))
self.log(
"Best results breakdown\n{0}".format(
self.task.evaluator.best_res
)
)
# save complete training states, including model weights, optimizer, lr_scheduler, batch generator, and updates count
self.save_checkpoint(self.updates)
save_a_checkpoint = False
# logging
if (
(batch_idx % 10 == 0)
or (epoch == 0 and batch_idx <= 50)
or "DEBUG" in self.opt
):
if self.opt["rank"] == 0:
batch_size = batch["encoder_input_ids"].shape[0]
self.log(
"epochs[{0:6}] updates[{1:6}] bsz[{2:d}] train loss[{3:.5f}] avg train loss[{4:.5f}] learning rate[{5:.5e}] remaining[{6}]".format(
epoch,
self.updates,
batch_size,
self.train_loss.val,
self.train_loss.avg,
self.lr_scheduler.get_lr()[0],
str(
(datetime.now() - startTime)
/ (batch_idx + 1)
* (self.updates_per_epoch - batch_idx - 1)
).split(".")[0],
)
)
self.tb_log_scalar(
"Loss/train_val", self.train_loss.val, self.updates
)
self.tb_log_scalar(
"Loss/train_avg", self.train_loss.avg, self.updates
)
self.tb_log_scalar(
"Learning Rate/lr",
self.lr_scheduler.get_lr()[0],
self.updates,
)
# if "DEBUG" in self.opt and batch_idx > 200: # exist early for DEBUG mode
# break
if (
isinstance(
self.train_batch_generator, iterators.CheckpointableIterator
)
and batch_idx + 1 == self.updates_per_epoch
):
break
self.log("This epoch takes" + str(datetime.now() - startTime))
self.log("PROGRESS: {0:.2f}%".format(100.0 * (epoch + 1) / numEpochs))
self.log("Config file is at " + self.opt["confFile"])
if "DEBUG" in self.opt: # exist early for DEBUG mode
break
def update(self, batch):
# forward loss, backward propagation, model update, and one step of optimization and lr scheduler
self.network.train()
# put the batch to the device
# @TODO make this more general, maybe have a self.task.move_batch(batch, device)
# so the trainer decides when and where to move batches, and task tells how
if isinstance(batch, tuple):
batch = tuple(t.to(self.opt["device"]) for t in batch)
elif isinstance(batch, list):
batch = [t.to(self.opt["device"]) for t in batch]
elif isinstance(batch, dict):
for k in batch:
if torch.is_tensor(batch[k]):
batch[k] = batch[k].to(self.opt["device"])
else:
assert torch.is_tensor(batch)
batch = batch.to(self.opt["device"])
# determine whether gradient sync can be skiped or not for this update
skip_gradient_sync = False
if self.opt["world_size"] > 1 and not self.is_gradient_accumulation_boundary():
if not self.opt["fp16"]:
# https://krishansubudhi.github.io/deeplearning/2020/02/06/apex-gradient-accumulation.html
# When using fp16, if we skip grad sync during grad accumulation, the grad sync at the
# grad accumulation boundary cannot properly sync the whole accumulated grad.
# So with fp16 on, we have to sync even if it's not grad accumulation boundary.
if self.high_pytorch_version:
skip_gradient_sync = True
# forward
if skip_gradient_sync:
with self.network.no_sync():
loss = self.network(batch)
else:
loss = self.network(batch)
if self.grad_acc_steps > 1:
loss = loss / self.grad_acc_steps
self.acc_loss += loss
# self.log(f"forward() done on rank {self.opt['rank']}")
# print(loss.item())
# backward
def backward(loss_tensor):
if self.opt["fp16"]:
from apex import amp
with amp.scale_loss(loss_tensor, self.optimizer) as scaled_loss:
scaled_loss.backward()
else:
loss_tensor.backward()
if skip_gradient_sync:
with self.network.no_sync():
backward(loss)
else:
if "DEBUG" in self.opt and self.opt["rank"] == 0:
self.log(
"Performing synchronized backward at step {0}".format(
self.optim_steps
)
)
backward(loss)
# self.log(f"backward() done on rank {self.opt['rank']}")
# step
if self.is_gradient_accumulation_boundary():
if self.opt["world_size"] > 1:
# ddp: use all_reduce to sum up values of self.acc_loss over all processes
# the operations happens in place (i.e., the value of self.acc_loss is replaced) and all processes received the updated value
torch.distributed.all_reduce(
self.acc_loss, torch.distributed.ReduceOp.SUM
)
self.acc_loss /= self.opt["world_size"]
self.train_loss.update(self.acc_loss.data, 1)
self.acc_loss = 0.0
if "GRAD_CLIPPING" in self.opt:
if self.opt["fp16"]:
from apex import amp
torch.nn.utils.clip_grad_norm_(
amp.master_params(self.optimizer), self.opt["GRAD_CLIPPING"]
)
else:
torch.nn.utils.clip_grad_norm_(
self.network.parameters(), self.opt["GRAD_CLIPPING"]
)
self.optim_steps += 1
self.optimizer.step()
self.optimizer.zero_grad()
self.lr_scheduler.step()
self.updates += 1
# self.log(f"step() done on rank {self.opt['rank']}")
def save_checkpoint(self, tag):
"""
Save complete training states, including model weights, optimizer, lr_scheduler,
fp16 loss scaler, random state, batch generator, and updates count
Also save a model with save_pretrained API for model transfer
"""
self.log("Saving checkpoint...")
resume_epoch_idx = self.current_epoch_idx
resume_batch_idx = self.current_batch_idx + 1
if resume_batch_idx == self.updates_per_epoch:
resume_batch_idx = 0
resume_epoch_idx += 1
if self.opt["fp16"]:
from apex import amp
if self.opt["rank"] == 0:
save_dir = os.path.join(self.saveFolder, str(tag))
os.makedirs(save_dir)
save_path = os.path.join(save_dir, "training_states.pt")
state = {
"network": self.network.state_dict(),
"optimizer": self.optimizer.state_dict(),
"lr_scheduler": self.lr_scheduler.state_dict(),
"amp": amp.state_dict() if self.opt["fp16"] else None,
"optim_steps": self.optim_steps,
"updates": self.updates,
"updates_per_epoch": self.updates_per_epoch,
"start_epoch_idx": resume_epoch_idx,
"start_batch_idx": resume_batch_idx,
}
torch.save(state, save_path)
if self.opt["world_size"] > 1:
torch.distributed.barrier()
save_dir = os.path.join(self.saveFolder, str(tag))
assert os.path.isdir(save_dir)
random_state_path = os.path.join(
save_dir, "random_state_rank_{:04d}".format(self.opt["rank"])
)
random_state = {
"random": random.getstate(),
"numpy_random": np.random.get_state(),
"torch_random": torch.get_rng_state(),
"torch_cuda_random": torch.cuda.get_rng_state(device=self.opt["device"])
if self.use_cuda
else None,
}
torch.save(random_state, random_state_path)
if isinstance(self.train_batch_generator, iterators.CheckpointableIterator):
# save batch generators for all ranks
batch_generator_file_path = os.path.join(
save_dir,
"batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]),
)
batch_generator_state = self.train_batch_generator.getstate()
torch.save(batch_generator_state, batch_generator_file_path)
else:
self.log(
"Batch generator is not checkpointable. Cannot save to checkpoint."
)
if self.opt["rank"] == 0:
self.module.save_pretrained(save_dir)
if self.opt["rank"] == 0:
# save the latest checkpoint location to json file
checkpoint_location = {
"checkpoint_tag": str(tag),
"checkpoint_path": os.path.relpath(
self.saveFolder, start=self.opt["datadir"]
),
}
json.dump(
checkpoint_location,
open(
os.path.join(
self.opt["datadir"],
self.opt["basename"] + "_resume_checkpoint.json",
),
"w",
encoding="utf-8",
),
)
self.log(f"Finished saving checkpoint and model to {save_dir}.")
def load_model(self, model_path):
# Load the model only, without any training states, using the from_pretrained API
self.module = self.module.from_pretrained(model_path)
self.module.to(self.opt["device"])
def load_checkpoint(self):
"""
Load complete training states, including model weights, optimizer, lr_scheduler,
fp16 loss scaler, random state, batch generator, and updates count
"""
try:
# load the checkpoint location from json file
checkpoint_location = json.load(
open(
os.path.join(
self.opt["datadir"],
self.opt["basename"] + "_resume_checkpoint.json",
),
encoding="utf-8",
)
)
checkpoint_path = os.path.join(
self.opt["datadir"],
checkpoint_location["checkpoint_path"],
checkpoint_location["checkpoint_tag"],
)
tag = checkpoint_location["checkpoint_tag"]
if not os.path.isdir(checkpoint_path):
if self.opt["rank"] == 0:
self.log(
"Checkpoint path {} not exist. Continue without loading checkpoint".format(
checkpoint_path
)
)
return
except:
if self.opt["rank"] == 0:
self.log(
f"Cannot find checkpoint path from {self.opt['basename']+'_resume_checkpoint.json'}.\n"
f"Make sure {os.path.join(self.opt['datadir'], self.opt['basename']+'_resume_checkpoint.json')} exists.\n"
f"Continue without loading checkpoint"
)
return
# save a copy of the resumed checkpoint location in the save folder of current run
if self.opt["rank"] == 0:
json.dump(
checkpoint_location,
open(
os.path.join(self.saveFolder, "resumed_checkpoint.json"),
"w",
encoding="utf-8",
),
)
self.log(f"Loading checkpoint from {checkpoint_path}...")
load_path = os.path.join(checkpoint_path, "training_states.pt")
state = torch.load(load_path, map_location=self.opt["device"])
self.network.load_state_dict(state["network"])
self.optimizer.load_state_dict(state["optimizer"])
self.lr_scheduler.load_state_dict(state["lr_scheduler"])
if self.opt["fp16"]:
from apex import amp
amp.load_state_dict(state["amp"])
self.optim_steps = state["optim_steps"]
self.updates = state["updates"]
self.start_epoch_idx = state["start_epoch_idx"]
self.start_batch_idx = state["start_batch_idx"]
assert self.updates_per_epoch == state["updates_per_epoch"]
assert self.start_batch_idx < self.updates_per_epoch
random_state_path = os.path.join(
checkpoint_path, "random_state_rank_{:04d}".format(self.opt["rank"])
)
random_state = torch.load(random_state_path, map_location="cpu")
random.setstate(random_state["random"])
np.random.set_state(random_state["numpy_random"])
torch.set_rng_state(random_state["torch_random"])
if self.use_cuda:
torch.cuda.set_rng_state(
random_state["torch_cuda_random"], device=self.opt["device"]
)
if "RESET_DATA_LOADER" not in self.opt and isinstance(
self.train_batch_generator, iterators.CheckpointableIterator
):
batch_generator_file_path = os.path.join(
checkpoint_path,
"batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]),
)
batch_generator_state = torch.load(
batch_generator_file_path, map_location="cpu"
)
self.train_batch_generator.setstate(batch_generator_state)
else:
self.log(
"No need to resume batch generator or batch generator is not checkpointable. Didn't load from checkpoint."
)
self.log(f"Finished loading checkpoint from {checkpoint_path}.")