# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from collections import defaultdict
from datetime import datetime
import os
import sys
import importlib
import json
import random
import numpy as np
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler

from model.third_party.HMNet.Models.Trainers.DistributedTrainer import (
    DistributedTrainer,
)
from model.third_party.HMNet.Models.Trainers.Tasks import Task
from model.third_party.HMNet.Utils.GeneralUtils import (
    AverageMeter,
    BaseBatchGen,
    bcolors,
)
from model.third_party.HMNet.DataLoader import iterators

class ObjectView(object):
    def __init__(self, d):
        self.__dict__ = d

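# WrappedModel bundles the model and its criterion into a single nn.Module so the
# loss is computed inside forward(); this lets the amp / DistributedDataParallel
# wrappers applied in HMNetTrainer.initialize_fp16_DDP() cover both in one pass.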
class WrappedModel(nn.Module):
    def __init__(self, model, criterion):
        super(WrappedModel, self).__init__()
        self.add_module("model", model)
        self.add_module("criterion", criterion)

    def forward(self, batch):
        output = self.model(batch)
        loss = self.criterion(output, batch)
        return loss

class HMNetTrainer(DistributedTrainer):
    """
    The trainer class for HMNet model training (pre-train and fine-tune).
    Its train() and eval() methods are intended to be called directly to
    start training and evaluation, respectively.
    Before running, the trainer must contain proper Task, Criterion, and Optimizer
    instances.
    """
    def __init__(self, opt):
        super().__init__(opt)
        self.task = Task.setup_task(self.opt["TASK"], self.opt, self.saveFolder)

    def is_gradient_accumulation_boundary(self):
        return (self.updates + 1) % self.grad_acc_steps == 0

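    # Example: with self.grad_acc_steps == 4, updates 3, 7, 11, ... are boundaries,
    # i.e. every fourth call to update() steps the optimizer and zeroes the gradients;
    # the other calls only accumulate gradients.
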
    def get_batch_generator(self, dataset_label):
        batch_generator = self.task.batch_gen(
            self.opt,
            dataset_label=dataset_label,
            model_config=self.module.config,
            tokenizer=self.module.tokenizer,
            world_size=self.opt["world_size"],
            rank=self.opt["rank"],
            seed=self.seed,
        )
        if isinstance(batch_generator, BaseBatchGen):
            # If it is a wrapper class of an infinibatch iterator,
            # get the internal infinibatch iterator.
            batch_generator = batch_generator.iterator
        self.log(f"Loaded data on rank {self.opt['rank']}.")
        return batch_generator

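    # Note: set_up_model() and set_up_optimizer_and_lr_scheduler() below resolve classes
    # by name: opt["MODEL"], opt["CRITERION"], opt["OPTIMIZER"], and opt["LR_SCHEDULER"]
    # must each match both a module file and a class of the same name under the
    # corresponding Models.* package (or a native torch.optim / lr_scheduler name).
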
    def set_up_model(self):
        # instantiate module (tokenizer should be contained in module as self.module.tokenizer)
        try:
            model_module = importlib.import_module(
                "model.third_party.HMNet.Models.Networks." + self.opt["MODEL"]
            )
            model_class = getattr(model_module, self.opt["MODEL"])
            self.module = model_class(self.opt)
        except Exception as e:
            self.log(e)
            self.log("ERROR: Model {} is unknown".format(self.opt["MODEL"]))
            assert False

        # calculate total trainable parameters
        pytorch_total_params = sum(
            p.numel() for p in self.module.parameters() if p.requires_grad
        )
        self.log("Total trainable parameters: {}".format(pytorch_total_params))

        # instantiate criterion
        try:
            criterion_module = importlib.import_module(
                "model.third_party.HMNet.Models.Criteria." + self.opt["CRITERION"]
            )
            criterion_class = getattr(criterion_module, self.opt["CRITERION"])
            self.criterion = criterion_class(self.opt, self.module)
        except Exception as e:
            self.log(e)
            self.log("ERROR: Criterion {} is unknown".format(self.opt["CRITERION"]))
            assert False

        self.module.to(self.opt["device"])

    def get_optimizer_params_config(self, optimizer_class):
        optimizer_parameters = {}
        sig = inspect.signature(optimizer_class)
        for param_name in sig.parameters.keys():
            if param_name == "lr":
                optimizer_parameters[param_name] = self.opt["START_LEARNING_RATE"]
            if param_name not in ["params", "lr"] and param_name.upper() in self.opt:
                optimizer_parameters[param_name] = self.opt[param_name.upper()]
        return optimizer_parameters

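    # Example: for torch.optim.Adam (params, lr, betas, eps, weight_decay, amsgrad),
    # "lr" is filled from opt["START_LEARNING_RATE"], and any remaining parameter whose
    # upper-cased name appears in the config (e.g. opt["WEIGHT_DECAY"], opt["EPS"]) is
    # passed through as a keyword argument. The same upper-case convention is used for
    # the LR scheduler parameters below.
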
    def get_lr_scheduler_params_config(self, lr_scheduler_class):
        lr_scheduler_parameters = {}
        sig = inspect.signature(lr_scheduler_class)
        for param_name in sig.parameters.keys():
            if param_name not in ["optimizer"] and param_name.upper() in self.opt:
                lr_scheduler_parameters[param_name] = self.opt[param_name.upper()]
        return lr_scheduler_parameters

    def set_up_optimizer_and_lr_scheduler(self):
        parameters = self.module.get_training_parameters()

        # instantiate optimizer
        try:  # first try pytorch native optimizer
            optimizer_class = getattr(optim, self.opt["OPTIMIZER"])
            self.log(
                "Using pytorch native optimizer: {}".format(self.opt["OPTIMIZER"])
            )
        except:
            try:  # then try custom optimizer inside Models.Optimizers
                optimizer_module = importlib.import_module(
                    "model.third_party.HMNet.Models.Optimizers." + self.opt["OPTIMIZER"]
                )
                optimizer_class = getattr(optimizer_module, self.opt["OPTIMIZER"])
                self.log("Using custom optimizer: {}".format(self.opt["OPTIMIZER"]))
            except Exception as e:
                self.log(e)
                self.log("ERROR: Optimizer {} is unknown".format(self.opt["OPTIMIZER"]))
                assert False

        optimizer_parameters = self.get_optimizer_params_config(optimizer_class)
        self.log(f"Optimizer parameters: {optimizer_parameters}")
        self.optimizer = optimizer_class(parameters, **optimizer_parameters)
        self.optimizer.zero_grad()

        # instantiate lr scheduler
        try:  # first look for pytorch native lr scheduler
            lr_scheduler_class = getattr(lr_scheduler, self.opt["LR_SCHEDULER"])
            self.log(
                "Using pytorch native lr scheduler: {}".format(self.opt["LR_SCHEDULER"])
            )
        except:
            try:  # then look for custom lr scheduler inside Models.Optimizers
                lr_scheduler_module = importlib.import_module(
                    "model.third_party.HMNet.Models.Optimizers."
                    + self.opt["LR_SCHEDULER"]
                )
                lr_scheduler_class = getattr(
                    lr_scheduler_module, self.opt["LR_SCHEDULER"]
                )
                self.log(
                    "Using custom lr scheduler: {}".format(self.opt["LR_SCHEDULER"])
                )
            except Exception as e:
                self.log(e)
                self.log(
                    "ERROR: LR Scheduler {} is unknown".format(self.opt["LR_SCHEDULER"])
                )
                assert False

        lr_scheduler_parameters = self.get_lr_scheduler_params_config(
            lr_scheduler_class
        )
        self.log(f"Lr scheduler parameters: {lr_scheduler_parameters}")
        self.lr_scheduler = lr_scheduler_class(
            self.optimizer, **lr_scheduler_parameters
        )

    def initialize_fp16_DDP(self):
        """
        Wrap the module and criterion into a single network; then, depending on the
        settings, wrap the network with the apex amp module for fp16 training, and
        wrap it with the pytorch DDP module for distributed data parallel training.
        """
        self.network = WrappedModel(self.module, self.criterion)
        self.network.to(self.opt["device"])

        if self.opt["fp16"]:
            from apex import amp

            self.network, self.optimizer = amp.initialize(
                self.network, self.optimizer, opt_level=self.opt["fp16_opt_level"]
            )

        if self.opt["world_size"] > 1:
            self.network = torch.nn.parallel.DistributedDataParallel(
                self.network,
                device_ids=[self.opt["local_rank"]],
                output_device=self.opt["local_rank"],
                find_unused_parameters=True,
            )
            self.log(f"Wrapped model with DDP on rank {self.opt['rank']}.")
            assert self.module is self.network.module.model
        else:
            assert self.module is self.network.model

    def eval(self):
        if self.opt["rank"] == 0:
            self.log("-----------------------------------------------")
            self.log("Evaluating model ... ")

        self.set_up_model()

        for eval_dataset in ["dev", "test"]:
            batch_generator_eval = self.get_batch_generator(eval_dataset)
            self.task.evaluator.reset_best_score(set_high=True)
            result, score, got_better_score = self.task.evaluator.eval_batches(
                self.module, batch_generator_eval, self.saveFolder, eval_dataset
            )
            if self.opt["rank"] == 0:
                self.log("{0} results breakdown\n{1}".format(eval_dataset, result))

    def eval_return_results(self):
        if self.opt["rank"] == 0:
            self.log("-----------------------------------------------")
            self.log("Evaluating model ... ")

        self.set_up_model()

        for eval_dataset in ["test"]:
            batch_generator_eval = self.get_batch_generator(eval_dataset)
            self.task.evaluator.reset_best_score(set_high=True)
            result, score, got_better_score = self.task.evaluator.eval_batches(
                self.module, batch_generator_eval, self.saveFolder, eval_dataset
            )
            if self.opt["rank"] == 0:
                self.log("{0} results breakdown\n{1}".format(eval_dataset, result))
        return result

    def train(self):
        self.log(f"train on rank {self.opt['rank']}")

        if self.opt["rank"] == 0:
            self.log("-----------------------------------------------")
            self.log("Initializing model...")
        self.set_up_model()  # set up self.module as the original model

        self.network = None
        self.train_batch_generator = self.get_batch_generator("train")
        if isinstance(self.train_batch_generator, iterators.CheckpointableIterator):
            # training batch generator is infinite
            self.updates_per_epoch = self.opt["UPDATES_PER_EPOCH"]
        else:
            self.updates_per_epoch = len(self.train_batch_generator)
        self.updates = 0
        self.optim_steps = 0
        self.start_epoch_idx = 0
        self.start_batch_idx = 0

        self.set_up_optimizer_and_lr_scheduler()
        self.initialize_fp16_DDP()

        if "RESUME" in self.opt:
            # Resume complete training states, including optimizer, lr_scheduler,
            # train batch generator, and updates count, from the checkpoint location
            # indicated in a .json file
            self.load_checkpoint()

        ######################
        # Start the main loop
        ######################
        numEpochs = self.opt["MAX_NUM_EPOCHS"]
        self.train_loss = AverageMeter()  # track the average training loss
        self.acc_loss = 0.0

        # after every 'SAVE_PER_UPDATE_NUM' updates, it will save a checkpoint by
        # setting save_a_checkpoint to True temporarily
        save_a_checkpoint = False

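        # Main loop: one pass below is one epoch. Each batch goes through update();
        # every 'SAVE_PER_UPDATE_NUM' updates (if configured) the model is evaluated on
        # "dev" (when the task provides an evaluator) and a full checkpoint is written.
        # Progress is logged every 10 batches, for the first 50 batches of epoch 0,
        # or on every batch in DEBUG mode.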
        for epoch in range(self.start_epoch_idx, numEpochs):
            self.current_epoch_idx = epoch
            self.log("Epoch {}".format(epoch))
            startTime = datetime.now()

            for batch_idx, batch in enumerate(self.train_batch_generator):
                if self.current_epoch_idx == self.start_epoch_idx:
                    if isinstance(
                        self.train_batch_generator, iterators.CheckpointableIterator
                    ):
                        batch_idx += self.start_batch_idx
                    elif batch_idx < self.start_batch_idx:
                        continue
                self.current_batch_idx = batch_idx

                # after every 'SAVE_PER_UPDATE_NUM' updates, save a checkpoint
                if ("SAVE_PER_UPDATE_NUM" in self.opt) and (
                    self.updates + 1
                ) % self.opt["SAVE_PER_UPDATE_NUM"] == 0:
                    # Make sure the next update is going to update the weights and zero
                    # the gradients, then we can checkpoint
                    assert self.is_gradient_accumulation_boundary()
                    save_a_checkpoint = True

                # update
                self.update(batch)

                if save_a_checkpoint:
                    # evaluate at the checkpointed moment, and log the results
                    if self.task.evaluator is not None:
                        evaluate_label = "update_" + str(self.updates)
                        eval_dataset = "dev"
                        batches = self.get_batch_generator(eval_dataset)
                        (
                            result,
                            score,
                            got_better_score,
                        ) = self.task.evaluator.eval_batches(
                            self.module, batches, self.saveFolder, evaluate_label
                        )
                        self.tb_log_scalar("Eval/score", score, self.updates)
                        if got_better_score:
                            self.log(
                                "Got new better score on rank-{0} evaluator, at updates {1}".format(
                                    self.opt["rank"], self.updates
                                )
                            )
                        self.log(
                            "Updates {0} - {1}: Current Score: {2:.3f} (best Score: {3:.3f})".format(
                                self.updates,
                                eval_dataset,
                                score,
                                self.task.evaluator.best_score,
                            )
                        )
                        self.log("Current results breakdown\n{0}".format(result))
                        self.log(
                            "Best results breakdown\n{0}".format(
                                self.task.evaluator.best_res
                            )
                        )
                    # save complete training states, including model weights, optimizer,
                    # lr_scheduler, batch generator, and updates count
                    self.save_checkpoint(self.updates)
                    save_a_checkpoint = False

                # logging
                if (
                    (batch_idx % 10 == 0)
                    or (epoch == 0 and batch_idx <= 50)
                    or "DEBUG" in self.opt
                ):
                    if self.opt["rank"] == 0:
                        batch_size = batch["encoder_input_ids"].shape[0]
                        self.log(
                            "epochs[{0:6}] updates[{1:6}] bsz[{2:d}] train loss[{3:.5f}] avg train loss[{4:.5f}] learning rate[{5:.5e}] remaining[{6}]".format(
                                epoch,
                                self.updates,
                                batch_size,
                                self.train_loss.val,
                                self.train_loss.avg,
                                self.lr_scheduler.get_lr()[0],
                                str(
                                    (datetime.now() - startTime)
                                    / (batch_idx + 1)
                                    * (self.updates_per_epoch - batch_idx - 1)
                                ).split(".")[0],
                            )
                        )
                        self.tb_log_scalar(
                            "Loss/train_val", self.train_loss.val, self.updates
                        )
                        self.tb_log_scalar(
                            "Loss/train_avg", self.train_loss.avg, self.updates
                        )
                        self.tb_log_scalar(
                            "Learning Rate/lr",
                            self.lr_scheduler.get_lr()[0],
                            self.updates,
                        )

                # if "DEBUG" in self.opt and batch_idx > 200:  # exit early for DEBUG mode
                #     break

                if (
                    isinstance(
                        self.train_batch_generator, iterators.CheckpointableIterator
                    )
                    and batch_idx + 1 == self.updates_per_epoch
                ):
                    break

            self.log("This epoch takes " + str(datetime.now() - startTime))
            self.log("PROGRESS: {0:.2f}%".format(100.0 * (epoch + 1) / numEpochs))
            self.log("Config file is at " + self.opt["confFile"])

            if "DEBUG" in self.opt:  # exit early for DEBUG mode
                break

    def update(self, batch):
        # forward loss, backward propagation, model update, and one step of optimization and lr scheduler
        self.network.train()

        # put the batch to the device
        # @TODO make this more general, maybe have a self.task.move_batch(batch, device)
        # so the trainer decides when and where to move batches, and the task tells how
        if isinstance(batch, tuple):
            batch = tuple(t.to(self.opt["device"]) for t in batch)
        elif isinstance(batch, list):
            batch = [t.to(self.opt["device"]) for t in batch]
        elif isinstance(batch, dict):
            for k in batch:
                if torch.is_tensor(batch[k]):
                    batch[k] = batch[k].to(self.opt["device"])
        else:
            assert torch.is_tensor(batch)
            batch = batch.to(self.opt["device"])

        # determine whether gradient sync can be skipped or not for this update
        skip_gradient_sync = False
        if self.opt["world_size"] > 1 and not self.is_gradient_accumulation_boundary():
            if not self.opt["fp16"]:
                # https://krishansubudhi.github.io/deeplearning/2020/02/06/apex-gradient-accumulation.html
                # When using fp16, if we skip grad sync during grad accumulation, the grad sync at the
                # grad accumulation boundary cannot properly sync the whole accumulated grad.
                # So with fp16 on, we have to sync even if it's not a grad accumulation boundary.
                if self.high_pytorch_version:
                    skip_gradient_sync = True

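        # DDP's no_sync() context disables the gradient all-reduce for this
        # forward/backward pass, so gradients only accumulate locally; they are
        # synchronized on the next backward pass outside the context, i.e. at the
        # gradient accumulation boundary. (no_sync() only exists on sufficiently
        # recent PyTorch versions, hence the high_pytorch_version check above.)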
        # forward
        if skip_gradient_sync:
            with self.network.no_sync():
                loss = self.network(batch)
        else:
            loss = self.network(batch)
        if self.grad_acc_steps > 1:
            loss = loss / self.grad_acc_steps
        self.acc_loss += loss
        # self.log(f"forward() done on rank {self.opt['rank']}")
        # print(loss.item())

        # backward
        def backward(loss_tensor):
            if self.opt["fp16"]:
                from apex import amp

                with amp.scale_loss(loss_tensor, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss_tensor.backward()

        if skip_gradient_sync:
            with self.network.no_sync():
                backward(loss)
        else:
            if "DEBUG" in self.opt and self.opt["rank"] == 0:
                self.log(
                    "Performing synchronized backward at step {0}".format(
                        self.optim_steps
                    )
                )
            backward(loss)
        # self.log(f"backward() done on rank {self.opt['rank']}")

if self.is_gradient_accumulation_boundary(): | |
if self.opt["world_size"] > 1: | |
# ddp: use all_reduce to sum up values of self.acc_loss over all processes | |
# the operations happens in place (i.e., the value of self.acc_loss is replaced) and all processes received the updated value | |
torch.distributed.all_reduce( | |
self.acc_loss, torch.distributed.ReduceOp.SUM | |
) | |
self.acc_loss /= self.opt["world_size"] | |
self.train_loss.update(self.acc_loss.data, 1) | |
self.acc_loss = 0.0 | |
if "GRAD_CLIPPING" in self.opt: | |
if self.opt["fp16"]: | |
from apex import amp | |
torch.nn.utils.clip_grad_norm_( | |
amp.master_params(self.optimizer), self.opt["GRAD_CLIPPING"] | |
) | |
else: | |
torch.nn.utils.clip_grad_norm_( | |
self.network.parameters(), self.opt["GRAD_CLIPPING"] | |
) | |
self.optim_steps += 1 | |
self.optimizer.step() | |
self.optimizer.zero_grad() | |
self.lr_scheduler.step() | |
self.updates += 1 | |
# self.log(f"step() done on rank {self.opt['rank']}") | |
    def save_checkpoint(self, tag):
        """
        Save complete training states, including model weights, optimizer, lr_scheduler,
        fp16 loss scaler, random state, batch generator, and updates count.
        Also save a model with the save_pretrained API for model transfer.
        """
self.log("Saving checkpoint...") | |
resume_epoch_idx = self.current_epoch_idx | |
resume_batch_idx = self.current_batch_idx + 1 | |
if resume_batch_idx == self.updates_per_epoch: | |
resume_batch_idx = 0 | |
resume_epoch_idx += 1 | |
if self.opt["fp16"]: | |
from apex import amp | |
if self.opt["rank"] == 0: | |
save_dir = os.path.join(self.saveFolder, str(tag)) | |
os.makedirs(save_dir) | |
save_path = os.path.join(save_dir, "training_states.pt") | |
state = { | |
"network": self.network.state_dict(), | |
"optimizer": self.optimizer.state_dict(), | |
"lr_scheduler": self.lr_scheduler.state_dict(), | |
"amp": amp.state_dict() if self.opt["fp16"] else None, | |
"optim_steps": self.optim_steps, | |
"updates": self.updates, | |
"updates_per_epoch": self.updates_per_epoch, | |
"start_epoch_idx": resume_epoch_idx, | |
"start_batch_idx": resume_batch_idx, | |
} | |
torch.save(state, save_path) | |
if self.opt["world_size"] > 1: | |
torch.distributed.barrier() | |
save_dir = os.path.join(self.saveFolder, str(tag)) | |
assert os.path.isdir(save_dir) | |
random_state_path = os.path.join( | |
save_dir, "random_state_rank_{:04d}".format(self.opt["rank"]) | |
) | |
random_state = { | |
"random": random.getstate(), | |
"numpy_random": np.random.get_state(), | |
"torch_random": torch.get_rng_state(), | |
"torch_cuda_random": torch.cuda.get_rng_state(device=self.opt["device"]) | |
if self.use_cuda | |
else None, | |
} | |
torch.save(random_state, random_state_path) | |
if isinstance(self.train_batch_generator, iterators.CheckpointableIterator): | |
# save batch generators for all ranks | |
batch_generator_file_path = os.path.join( | |
save_dir, | |
"batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]), | |
) | |
batch_generator_state = self.train_batch_generator.getstate() | |
torch.save(batch_generator_state, batch_generator_file_path) | |
else: | |
self.log( | |
"Batch generator is not checkpointable. Cannot save to checkpoint." | |
) | |
if self.opt["rank"] == 0: | |
self.module.save_pretrained(save_dir) | |
if self.opt["rank"] == 0: | |
# save the latest checkpoint location to json file | |
checkpoint_location = { | |
"checkpoint_tag": str(tag), | |
"checkpoint_path": os.path.relpath( | |
self.saveFolder, start=self.opt["datadir"] | |
), | |
} | |
json.dump( | |
checkpoint_location, | |
open( | |
os.path.join( | |
self.opt["datadir"], | |
self.opt["basename"] + "_resume_checkpoint.json", | |
), | |
"w", | |
encoding="utf-8", | |
), | |
) | |
self.log(f"Finished saving checkpoint and model to {save_dir}.") | |
    def load_model(self, model_path):
        # Load the model only, without any training states, using the from_pretrained API
        self.module = self.module.from_pretrained(model_path)
        self.module.to(self.opt["device"])

    def load_checkpoint(self):
        """
        Load complete training states, including model weights, optimizer, lr_scheduler,
        fp16 loss scaler, random state, batch generator, and updates count
        """
        try:
            # load the checkpoint location from json file
            checkpoint_location = json.load(
                open(
                    os.path.join(
                        self.opt["datadir"],
                        self.opt["basename"] + "_resume_checkpoint.json",
                    ),
                    encoding="utf-8",
                )
            )
            checkpoint_path = os.path.join(
                self.opt["datadir"],
                checkpoint_location["checkpoint_path"],
                checkpoint_location["checkpoint_tag"],
            )
            tag = checkpoint_location["checkpoint_tag"]
            if not os.path.isdir(checkpoint_path):
                if self.opt["rank"] == 0:
                    self.log(
                        "Checkpoint path {} does not exist. Continue without loading checkpoint".format(
                            checkpoint_path
                        )
                    )
                return
        except:
            if self.opt["rank"] == 0:
                self.log(
                    f"Cannot find checkpoint path from {self.opt['basename']+'_resume_checkpoint.json'}.\n"
                    f"Make sure {os.path.join(self.opt['datadir'], self.opt['basename']+'_resume_checkpoint.json')} exists.\n"
                    f"Continue without loading checkpoint"
                )
            return

        # save a copy of the resumed checkpoint location in the save folder of the current run
        if self.opt["rank"] == 0:
            json.dump(
                checkpoint_location,
                open(
                    os.path.join(self.saveFolder, "resumed_checkpoint.json"),
                    "w",
                    encoding="utf-8",
                ),
            )

        self.log(f"Loading checkpoint from {checkpoint_path}...")

        load_path = os.path.join(checkpoint_path, "training_states.pt")
        state = torch.load(load_path, map_location=self.opt["device"])
        self.network.load_state_dict(state["network"])
        self.optimizer.load_state_dict(state["optimizer"])
        self.lr_scheduler.load_state_dict(state["lr_scheduler"])
        if self.opt["fp16"]:
            from apex import amp

            amp.load_state_dict(state["amp"])
        self.optim_steps = state["optim_steps"]
        self.updates = state["updates"]
        self.start_epoch_idx = state["start_epoch_idx"]
        self.start_batch_idx = state["start_batch_idx"]
        assert self.updates_per_epoch == state["updates_per_epoch"]
        assert self.start_batch_idx < self.updates_per_epoch

        random_state_path = os.path.join(
            checkpoint_path, "random_state_rank_{:04d}".format(self.opt["rank"])
        )
        random_state = torch.load(random_state_path, map_location="cpu")
        random.setstate(random_state["random"])
        np.random.set_state(random_state["numpy_random"])
        torch.set_rng_state(random_state["torch_random"])
        if self.use_cuda:
            torch.cuda.set_rng_state(
                random_state["torch_cuda_random"], device=self.opt["device"]
            )

        if "RESET_DATA_LOADER" not in self.opt and isinstance(
            self.train_batch_generator, iterators.CheckpointableIterator
        ):
            batch_generator_file_path = os.path.join(
                checkpoint_path,
                "batch_generator_checkpoint_rank_{:04d}".format(self.opt["rank"]),
            )
            batch_generator_state = torch.load(
                batch_generator_file_path, map_location="cpu"
            )
            self.train_batch_generator.setstate(batch_generator_state)
        else:
            self.log(
                "No need to resume batch generator or batch generator is not checkpointable. Didn't load from checkpoint."
            )

        self.log(f"Finished loading checkpoint from {checkpoint_path}.")