""" Utility file for trainers """ | |
import os | |
import shutil | |
from glob import glob | |
import torch | |
import torch.distributed as dist | |

''' checkpoint functions '''

# saves checkpoint
def save_checkpoint(model,
                    optimizer,
                    scheduler,
                    epoch,
                    checkpoint_dir,
                    name,
                    model_name):
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_state = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
        "epoch": epoch
    }
    checkpoint_path = os.path.join(checkpoint_dir, '{}_{}_{}.pt'.format(name, model_name, epoch))
    torch.save(checkpoint_state, checkpoint_path)
    print("Saved checkpoint: {}".format(checkpoint_path))

# reload model weights from checkpoint file
def reload_ckpt(args,
                network,
                optimizer,
                scheduler,
                gpu,
                model_name,
                manual_reload_name=None,
                manual_reload=False,
                manual_reload_dir=None,
                epoch=None,
                fit_sefa=False):
    if manual_reload:
        reload_name = manual_reload_name
    else:
        reload_name = args.name
    if manual_reload_dir:
        ckpt_dir = manual_reload_dir + reload_name + "/ckpt/"
    else:
        ckpt_dir = args.output_dir + reload_name + "/ckpt/"
    temp_ckpt_dir = f'{args.output_dir}{reload_name}/ckpt_temp/'
    reload_epoch = epoch
    # find best or latest epoch
    if epoch is None:
        reload_epoch_temp = 0
        reload_epoch_ckpt = 0
        # guard against directories that have not been created yet
        if os.path.isdir(temp_ckpt_dir) and len(os.listdir(temp_ckpt_dir)) != 0:
            reload_epoch_temp = find_best_epoch(temp_ckpt_dir)
        if os.path.isdir(ckpt_dir) and len(os.listdir(ckpt_dir)) != 0:
            reload_epoch_ckpt = find_best_epoch(ckpt_dir)
        if reload_epoch_ckpt >= reload_epoch_temp:
            reload_epoch = reload_epoch_ckpt
        else:
            reload_epoch = reload_epoch_temp
            ckpt_dir = temp_ckpt_dir
    else:
        if os.path.isfile(f"{temp_ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"):
            ckpt_dir = temp_ckpt_dir
    # reloading weight
    if model_name is None:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{reload_epoch}.pt"
    else:
        resuming_path = f"{ckpt_dir}{reload_epoch}/{reload_name}_{model_name}_{reload_epoch}.pt"
    if gpu == 0:
        print("===Resume checkpoint from: {}===".format(resuming_path))
    loc = 'cuda:{}'.format(gpu)
    checkpoint = torch.load(resuming_path, map_location=loc)
    start_epoch = 0 if manual_reload and not fit_sefa else checkpoint["epoch"]
    if manual_reload_dir is not None and 'parameter_estimation' in manual_reload_dir:
        from collections import OrderedDict
        new_state_dict = OrderedDict()
        for k, v in checkpoint["model"].items():
            name = 'module.' + k   # add `module.` prefix expected by DDP-wrapped networks
            new_state_dict[name] = v
        network.load_state_dict(new_state_dict)
    else:
        network.load_state_dict(checkpoint["model"])
    if not manual_reload:
        optimizer.load_state_dict(checkpoint["optimizer"])
        scheduler.load_state_dict(checkpoint["scheduler"])
    if gpu == 0:
        print("=> loaded checkpoint '{}' (epoch {})".format(resuming_path, reload_epoch))
    return start_epoch
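
# A minimal usage sketch (hypothetical values; `args` is assumed to be an
# argparse.Namespace providing at least `name` and `output_dir`):
#
#   start_epoch = reload_ckpt(args, network, optimizer, scheduler,
#                             gpu=0, model_name="mixing_console")
#
# With epoch=None, the latest epoch found under args.output_dir + args.name + "/ckpt/"
# (or ".../ckpt_temp/") is reloaded; pass epoch explicitly to resume from a
# specific checkpoint.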

# find best epoch for reloading current model
def find_best_epoch(input_dir):
    cur_epochs = glob("{}*".format(input_dir))
    return find_by_name(cur_epochs)

# sort string epoch names by integers
def find_by_name(epochs):
    int_epochs = []
    for e in epochs:
        int_epochs.append(int(os.path.basename(e)))
    int_epochs.sort()
    return int_epochs[-1]

# remove old checkpoint directories, keeping only the `leave` most recent epochs
def remove_ckpt(cur_ckpt_path_dir, leave=2):
    ckpt_nums = [int(i) for i in os.listdir(cur_ckpt_path_dir)]
    ckpt_nums.sort()
    del_num = len(ckpt_nums) - leave
    cur_del_num = 0
    while del_num > 0:
        shutil.rmtree("{}{}".format(cur_ckpt_path_dir, ckpt_nums[cur_del_num]))
        del_num -= 1
        cur_del_num += 1
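
# A minimal usage sketch (hypothetical path): keep only the two most recent
# epoch directories of a run. Note that the path is concatenated directly with
# the epoch number, so it should end with a trailing slash.
#
#   remove_ckpt(args.output_dir + args.name + "/ckpt/", leave=2)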


''' multi-GPU functions '''

# gather function implemented from DirectCLR
class GatherLayer_Direct(torch.autograd.Function):
    """
    Gather tensors from all workers with support for backward propagation:
    This implementation does not cut the gradients as torch.distributed.all_gather does.
    """
    @staticmethod
    def forward(ctx, x):
        output = [torch.zeros_like(x) for _ in range(dist.get_world_size())]
        dist.all_gather(output, x)
        return tuple(output)

    @staticmethod
    def backward(ctx, *grads):
        all_gradients = torch.stack(grads)
        dist.all_reduce(all_gradients)
        return all_gradients[dist.get_rank()]

from classy_vision.generic.distributed_util import (
    convert_to_distributed_tensor,
    convert_to_normal_tensor,
    is_distributed_training_run,
)

def gather_from_all(tensor: torch.Tensor) -> torch.Tensor:
    """
    Similar to classy_vision.generic.distributed_util.gather_from_all
    except that it does not cut the gradients
    """
    if tensor.ndim == 0:
        # 0-dim tensors cannot be gathered, so unsqueeze first
        tensor = tensor.unsqueeze(0)

    if is_distributed_training_run():
        tensor, orig_device = convert_to_distributed_tensor(tensor)
        gathered_tensors = GatherLayer_Direct.apply(tensor)
        gathered_tensors = [
            convert_to_normal_tensor(_tensor, orig_device)
            for _tensor in gathered_tensors
        ]
    else:
        gathered_tensors = [tensor]
    gathered_tensor = torch.cat(gathered_tensors, 0)
    return gathered_tensor
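
# A minimal usage sketch (assumes torch.distributed is initialized and
# `embeddings` is this rank's batch of features with shape [B, D]): gathering
# before a contrastive loss lets negatives from every worker contribute while
# gradients still flow back to the local rank through GatherLayer_Direct.
#
#   all_embeddings = gather_from_all(embeddings)   # shape: [world_size * B, D]
#   loss = contrastive_loss(all_embeddings)        # `contrastive_loss` is hypothetical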