|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.backends.cudnn as cudnn |
|
|
|
|
|
import torch.distributed as dist |
|
import torch.multiprocessing as mp |
|
|
|
import os |
|
import os.path as osp |
|
import sys |
|
import numpy as np |
|
import pprint |
|
import timeit |
|
import time |
|
import copy |
|
import matplotlib.pyplot as plt |
|
|
|
from .cfg_holder import cfg_unique_holder as cfguh |
|
|
|
from .data_factory import \ |
|
get_dataset, collate, \ |
|
get_loader, \ |
|
get_transform, \ |
|
get_estimator, \ |
|
get_formatter, \ |
|
get_sampler |
|
|
|
from .model_zoo import \ |
|
get_model, get_optimizer, get_scheduler |
|
|
|
from .log_service import print_log, distributed_log_manager |
|
|
|
from .evaluator import get_evaluator |
|
from . import sync |
|
|
|
class train_stage(object):
    """
    Template for a train stage (which can be train, test, or anything else
    that iterates over a dataloader).

    Usually it takes one dataloader, one model, one optimizer and one
    scheduler, passed as the keyword parameters ``trainloader``, ``net``,
    ``optimizer`` and ``scheduler``, but it is not limited to these.
    Subclasses implement ``main`` to process a single batch.
    """

    def __init__(self):
        # Optional callable invoked periodically for evaluation; it is called
        # with ``eval_cnt=...`` plus all stage parameters and its returned dict
        # may carry the result under 'eval_rv'.
        self.nested_eval_stage = None
        # Best evaluation result seen so far; None until the first evaluation.
        self.rv_keep = None

    def is_better(self, x):
        """Return True if ``x`` is greater than the kept result (or nothing is kept yet)."""
        return (self.rv_keep is None) or (x > self.rv_keep)

    def set_model(self, net, mode):
        """Put ``net`` into 'train' or 'eval' mode and return it.

        Raises:
            ValueError: if ``mode`` is neither 'train' nor 'eval'.
        """
        if mode == 'train':
            return net.train()
        elif mode == 'eval':
            return net.eval()
        else:
            raise ValueError('Unknown mode: {}'.format(mode))

    @staticmethod
    def _crossed(before, after, every):
        """Return True if moving a counter from ``before`` to ``after`` crosses a multiple of ``every``."""
        return (before // every) != (after // every)

    def __call__(self, **paras):
        """Run the training loop.

        The schedule is read from ``cfg.train``: ``step_type`` ('epoch',
        'iter' or 'sample', default 'iter'), ``step_num`` (stop once the
        counter of that type reaches it; no limit when None),
        ``gradacc_every``, ``log_every``, ``ckpt_every``, ``eval_start`` and
        ``eval_every``.

        Args:
            **paras: must contain ``trainloader``, ``optimizer``,
                ``scheduler`` and ``net``; may contain ``resume_step`` (a dict
                with keys 'type', 'epochn', 'itern', 'samplen') to resume the
                counters. All parameters are forwarded to ``main`` and to the
                nested eval stage.

        Returns:
            An empty dict.
        """
        cfg = cfguh().cfg
        cfgt = cfg.train
        logm = distributed_log_manager()
        epochn, itern, samplen = 0, 0, 0

        step_type = cfgt.get('step_type', 'iter')
        assert step_type in ['epoch', 'iter', 'sample'], \
            'Step type must be in [epoch, iter, sample]'

        step_num = cfgt.get('step_num', None)
        gradacc_every = cfgt.get('gradacc_every', 1)
        log_every = cfgt.get('log_every', None)
        ckpt_every = cfgt.get('ckpt_every', None)
        eval_start = cfgt.get('eval_start', 0)
        eval_every = cfgt.get('eval_every', None)

        if paras.get('resume_step', None) is not None:
            resume_step = paras['resume_step']
            assert step_type == resume_step['type']
            epochn = resume_step['epochn']
            itern = resume_step['itern']
            samplen = resume_step['samplen']
            # Not forwarded to main / the eval stage.
            del paras['resume_step']

        trainloader = paras['trainloader']
        optimizer = paras['optimizer']
        scheduler = paras['scheduler']
        net = paras['net']

        GRANK, LRANK, NRANK = sync.get_rank('all')
        GWSIZE, LWSIZE, NODES = sync.get_world_size('all')

        weight_path = osp.join(cfgt.log_dir, 'weight')
        if (GRANK == 0) and (not osp.isdir(weight_path)):
            os.makedirs(weight_path, exist_ok=True)
        if (GRANK == 0) and (cfgt.save_init_model):
            self.save(net, is_init=True, step=0, optimizer=optimizer)

        epoch_time = timeit.default_timer()
        end_flag = False
        net.train()

        while True:
            if step_type == 'epoch':
                lr = scheduler[epochn] if scheduler is not None else None

            for batch in trainloader:
                # The batch size is taken from the first element of the batch.
                if not isinstance(batch[0], list):
                    bs = batch[0].shape[0]
                else:
                    bs = len(batch[0])
                if cfgt.skip_partial_batch and (bs != cfgt.batch_size_per_gpu):
                    continue

                itern_next = itern + 1
                # Samples are counted across all processes.
                samplen_next = samplen + bs * GWSIZE

                if step_type == 'iter':
                    lr = scheduler[itern // gradacc_every] if scheduler is not None else None
                elif step_type == 'sample':
                    lr = scheduler[samplen] if scheduler is not None else None
                # Computed for every step type: main() always receives it.
                # True on the last iteration of each accumulation window.
                grad_update = itern % gradacc_every == (gradacc_every - 1)

                paras_new = self.main(
                    batch=batch,
                    lr=lr,
                    itern=itern,
                    epochn=epochn,
                    samplen=samplen,
                    isinit=False,
                    grad_update=grad_update,
                    **paras)

                paras.update(paras_new)
                logm.accumulate(bs, **paras['log_info'])

                # ---- step-level logging ----
                display_flag = False
                if log_every is not None:
                    if step_type == 'iter':
                        display_flag = self._crossed(itern, itern_next, log_every)
                    elif step_type == 'sample':
                        display_flag = self._crossed(samplen, samplen_next, log_every)

                if display_flag:
                    tbstep = itern_next if step_type == 'iter' else samplen_next
                    console_info = logm.train_summary(
                        itern_next, epochn, samplen_next, lr, tbstep=tbstep)
                    logm.clear()
                    print_log(console_info)

                # ---- step-level evaluation (only on node 0) ----
                eval_flag = False
                if (self.nested_eval_stage is not None) and (eval_every is not None) \
                        and (NRANK == 0):
                    # Evaluate when a period boundary is crossed past
                    # eval_start, and also on the very first step.
                    if step_type == 'iter':
                        eval_flag = (self._crossed(itern, itern_next, eval_every)
                                     and (itern_next >= eval_start)) or (itern == 0)
                    elif step_type == 'sample':
                        eval_flag = (self._crossed(samplen, samplen_next, eval_every)
                                     and (samplen_next >= eval_start)) or (samplen == 0)

                if eval_flag:
                    eval_cnt = itern_next if step_type == 'iter' else samplen_next
                    net = self.set_model(net, 'eval')
                    rv = self.nested_eval_stage(eval_cnt=eval_cnt, **paras)
                    rv = rv.get('eval_rv', None)
                    if rv is not None:
                        logm.tensorboard_log(eval_cnt, rv, mode='eval')
                        if self.is_better(rv):
                            self.rv_keep = rv
                            if GRANK == 0:
                                step = {'epochn': epochn, 'itern': itern_next,
                                        'samplen': samplen_next, 'type': step_type, }
                                self.save(net, is_best=True, step=step, optimizer=optimizer)
                    net = self.set_model(net, 'train')

                # ---- step-level checkpointing (global rank 0 only) ----
                ckpt_flag = False
                if (GRANK == 0) and (ckpt_every is not None):
                    if step_type == 'iter':
                        ckpt_flag = self._crossed(itern, itern_next, ckpt_every)
                    elif step_type == 'sample':
                        ckpt_flag = self._crossed(samplen, samplen_next, ckpt_every)

                if ckpt_flag:
                    step = {'epochn': epochn, 'itern': itern_next,
                            'samplen': samplen_next, 'type': step_type, }
                    if step_type == 'iter':
                        print_log('Checkpoint... {}'.format(itern_next))
                        self.save(net, itern=itern_next, step=step, optimizer=optimizer)
                    else:
                        print_log('Checkpoint... {}'.format(samplen_next))
                        self.save(net, samplen=samplen_next, step=step, optimizer=optimizer)

                itern = itern_next
                samplen = samplen_next

                # Stop once the iteration/sample budget is reached; when
                # step_num is None there is no budget.
                if step_num is not None:
                    end_flag = ((step_type == 'iter') and (itern >= step_num)) \
                        or ((step_type == 'sample') and (samplen >= step_num))
                    if end_flag:
                        break

            epochn += 1
            print_log('Epoch {} time:{:.2f}s.'.format(
                epochn, timeit.default_timer() - epoch_time))
            epoch_time = timeit.default_timer()

            if end_flag:
                break
            elif step_type != 'epoch':
                # Iteration/sample based training simply keeps cycling the data.
                trainloader = self.trick_update_trainloader(trainloader)
                continue

            # From here on step_type == 'epoch'.

            # ---- epoch-level logging ----
            display_flag = False
            if (log_every is not None) and (step_type == 'epoch'):
                display_flag = (epochn == 1) or (epochn % log_every == 0)

            if display_flag:
                console_info = logm.train_summary(
                    itern, epochn, samplen, lr, tbstep=epochn)
                logm.clear()
                print_log(console_info)

            # ---- epoch-level evaluation (only on node 0) ----
            eval_flag = False
            if (self.nested_eval_stage is not None) and (eval_every is not None) \
                    and (step_type == 'epoch') and (NRANK == 0):
                # ``itern`` equals the last processed ``itern_next`` here and,
                # unlike it, is always bound. eval_start is compared against
                # the iteration counter.
                eval_flag = (epochn % eval_every == 0) and (itern >= eval_start)
                eval_flag = (epochn == 1) or eval_flag

            if eval_flag:
                net = self.set_model(net, 'eval')
                rv = self.nested_eval_stage(eval_cnt=epochn, **paras)
                rv = rv.get('eval_rv', None)
                if rv is not None:
                    logm.tensorboard_log(epochn, rv, mode='eval')
                    if self.is_better(rv):
                        self.rv_keep = rv
                        if GRANK == 0:
                            step = {'epochn': epochn, 'itern': itern,
                                    'samplen': samplen, 'type': step_type, }
                            self.save(net, is_best=True, step=step, optimizer=optimizer)
                net = self.set_model(net, 'train')

            # ---- epoch-level checkpointing (global rank 0 only) ----
            ckpt_flag = False
            if (ckpt_every is not None) and (GRANK == 0) and (step_type == 'epoch'):
                ckpt_flag = epochn % ckpt_every == 0

            if ckpt_flag:
                # Log the epoch number: that is what names the checkpoint file.
                print_log('Checkpoint... {}'.format(epochn))
                step = {'epochn': epochn, 'itern': itern,
                        'samplen': samplen, 'type': step_type, }
                self.save(net, epochn=epochn, step=step, optimizer=optimizer)

            if (step_type == 'epoch') and (step_num is not None) and (epochn >= step_num):
                break

            trainloader = self.trick_update_trainloader(trainloader)

        logm.tensorboard_close()
        return {}

    def main(self, **paras):
        """Process one batch; implemented by subclasses.

        Receives ``batch``, ``lr``, ``itern``, ``epochn``, ``samplen``,
        ``isinit`` and ``grad_update`` plus all stage parameters. The returned
        dict is merged into the stage parameters, which must then contain
        'log_info' (a dict forwarded to the log manager).
        """
        raise NotImplementedError

    def trick_update_trainloader(self, trainloader):
        """Hook to replace the dataloader between passes; returns it unchanged by default."""
        return trainloader

    def save_model(self, net, path_noext, **paras):
        """Save the state_dict of ``net`` (unwrapped from (Distributed)DataParallel) to ``path_noext + '.pth'``.

        Extra keyword parameters (e.g. ``step``, ``optimizer``) are accepted so
        subclasses can persist them; this implementation ignores them.
        """
        path = path_noext + '.pth'
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))

    def save(self, net, itern=None, epochn=None, samplen=None,
             is_init=False, is_best=False, is_last=False, **paras):
        """Save ``net`` under ``<log_dir>/weight`` with a name derived from the tag given.

        At most one of ``itern``, ``samplen``, ``epochn``, ``is_init``,
        ``is_best`` and ``is_last`` may be set; with none set the '_default'
        suffix is used. Remaining keyword parameters go to ``save_model``.
        """
        cfg = cfguh().cfg
        exid = cfg.env.experiment_id
        net_symbol = cfg.model.symbol

        check = sum([
            itern is not None, samplen is not None, epochn is not None,
            is_init, is_best, is_last])
        assert check < 2, \
            'At most one of itern/samplen/epochn/is_init/is_best/is_last may be set'

        if itern is not None:
            path_noexp = '{}_{}_iter_{}'.format(exid, net_symbol, itern)
        elif samplen is not None:
            path_noexp = '{}_{}_samplen_{}'.format(exid, net_symbol, samplen)
        elif epochn is not None:
            path_noexp = '{}_{}_epoch_{}'.format(exid, net_symbol, epochn)
        elif is_init:
            path_noexp = '{}_{}_init'.format(exid, net_symbol)
        elif is_best:
            path_noexp = '{}_{}_best'.format(exid, net_symbol)
        elif is_last:
            path_noexp = '{}_{}_last'.format(exid, net_symbol)
        else:
            path_noexp = '{}_{}_default'.format(exid, net_symbol)

        path_noexp = osp.join(cfg.train.log_dir, 'weight', path_noexp)
        self.save_model(net, path_noexp, **paras)
|
|
|
class eval_stage(object):
    """
    Template for an evaluation stage.

    Subclasses implement ``main(batch, net)``, which returns a dict passed as
    keyword arguments to the evaluator's ``add_batch`` (and to ``output_f``
    when ``cfg.eval.output_result`` is set).
    """

    def __init__(self):
        # Created lazily on the first call from cfg.eval.evaluator, then reused.
        self.evaluator = None

    def create_dir(self, path):
        """Create ``path`` from local rank 0 if missing, then hit the node-wise barrier."""
        local_rank = sync.get_rank('local')
        if (not osp.isdir(path)) and (local_rank == 0):
            os.makedirs(path)
        sync.nodewise_sync().barrier()

    def main(self, batch, net):
        """Run ``net`` on one batch and return the evaluator inputs; implemented by subclasses."""
        raise NotImplementedError

    def _output(self, rv, paras):
        """Call ``self.output_f`` with ``rv``, adding ``cnt=paras['eval_cnt']`` when possible.

        Falls back to calling without ``cnt`` when no 'eval_cnt' was given or
        when ``output_f`` rejects the argument (TypeError). Other errors
        raised by ``output_f`` propagate instead of being silently retried.
        """
        if 'eval_cnt' not in paras:
            self.output_f(**rv)
            return
        try:
            self.output_f(**rv, cnt=paras['eval_cnt'])
        except TypeError:
            # output_f implementations that do not accept ``cnt``.
            self.output_f(**rv)

    def __call__(self,
                 evalloader,
                 net,
                 **paras):
        """Evaluate ``net`` over ``evalloader`` and return ``{'eval_rv': <evaluator result>}``.

        Progress is logged every ``cfg.eval.log_display`` batches; local
        rank 0 prints a summary and saves the evaluator output to
        ``cfg.eval.log_dir``. Evaluator data is cleared afterwards.
        """
        cfgt = cfguh().cfg.eval
        local_rank = sync.get_rank('local')
        if self.evaluator is None:
            self.evaluator = get_evaluator()(cfgt.evaluator)
        evaluator = self.evaluator

        time_check = timeit.default_timer()

        for idx, batch in enumerate(evalloader):
            rv = self.main(batch, net)
            evaluator.add_batch(**rv)
            if cfgt.output_result:
                self._output(rv, paras)
            if idx % cfgt.log_display == cfgt.log_display - 1:
                print_log('processed.. {}, Time:{:.2f}s'.format(
                    idx + 1, timeit.default_timer() - time_check))
                time_check = timeit.default_timer()

        evaluator.set_sample_n(len(evalloader.dataset))
        eval_rv = evaluator.compute()
        if local_rank == 0:
            evaluator.one_line_summary()
            evaluator.save(cfgt.log_dir)
        evaluator.clear_data()
        return {
            'eval_rv': eval_rv
        }
|
|
|
class exec_container(object):
    """
    Base functor for all types of executions.

    One execution can have multiple stages, but they all share the same
    config, network and dataloaders. In most cases one exec_container is one
    training / evaluation / demo run. If DDP (DistributedDataParallel) is in
    use, this functor is meant to be spawned once per process; ``__call__``
    receives the process's local rank as its first argument.
    """

    def __init__(self,
                 cfg,
                 **kwargs):
        self.cfg = cfg
        self.registered_stages = []
        self.node_rank = None
        self.local_rank = None
        self.global_rank = None
        self.nodes = None
        self.local_world_size = None
        self.global_world_size = None
        self.nodewise_sync_global_obj = sync.nodewise_sync_global()

    def register_stage(self, stage):
        """Append ``stage``; stages run in registration order."""
        self.registered_stages.append(stage)

    @staticmethod
    def _global_rank(local_rank, node_rank, local_world_size):
        """Return the global rank when each node runs ``local_world_size`` processes.

        Ranks are laid out node by node: node ``k`` owns the contiguous range
        ``[k * local_world_size, (k + 1) * local_world_size)``.
        """
        return local_rank + node_rank * local_world_size

    def __call__(self,
                 local_rank,
                 **kwargs):
        """Run every registered stage in the process for ``local_rank``.

        Initialises the process group, selects the CUDA device, seeds the
        numpy/torch RNGs (when ``cfg.env.rnd_seed`` is an int), builds the
        dataloaders and model, runs the stages in order (merging each stage's
        returned dict into the shared parameters), saves the last model on
        global rank 0 and finally destroys the process group.
        """
        cfg = self.cfg
        cfguh().save_cfg(cfg)

        self.node_rank = cfg.env.node_rank
        self.local_rank = local_rank
        self.nodes = cfg.env.nodes
        self.local_world_size = cfg.env.gpu_count

        # Offset by the number of processes per node (gpu_count), not by the
        # node count; otherwise ranks collide whenever nodes != gpu_count.
        self.global_rank = self._global_rank(
            self.local_rank, self.node_rank, self.local_world_size)
        self.global_world_size = self.nodes * self.local_world_size

        dist.init_process_group(
            backend=cfg.env.dist_backend,
            init_method=cfg.env.dist_url,
            rank=self.global_rank,
            world_size=self.global_world_size,)
        torch.cuda.set_device(local_rank)
        sync.nodewise_sync().copy_global(self.nodewise_sync_global_obj).local_init()

        if isinstance(cfg.env.rnd_seed, int):
            # Different but reproducible seed for each process.
            np.random.seed(cfg.env.rnd_seed + self.global_rank)
            torch.manual_seed(cfg.env.rnd_seed + self.global_rank)

        time_start = timeit.default_timer()

        para = {'itern_total': 0, }
        dl_para = self.prepare_dataloader()
        assert isinstance(dl_para, dict)
        para.update(dl_para)

        md_para = self.prepare_model()
        assert isinstance(md_para, dict)
        para.update(md_para)

        for stage in self.registered_stages:
            stage_para = stage(**para)
            if stage_para is not None:
                para.update(stage_para)

        if self.global_rank == 0:
            self.save_last_model(**para)

        print_log(
            'Total {:.2f} seconds'.format(timeit.default_timer() - time_start))
        dist.destroy_process_group()

    def prepare_dataloader(self):
        """
        Prepare the dataloaders from config; must return a dict.
        """
        return {
            'trainloader': None,
            'evalloader': None}

    def prepare_model(self):
        """
        Prepare the model from config; must return a dict.
        """
        return {'net': None}

    def save_last_model(self, **para):
        """Hook called on global rank 0 after all stages; does nothing by default."""
        return

    def destroy(self):
        """Release the shared node-wise sync object created in ``__init__``."""
        self.nodewise_sync_global_obj.destroy()
|
|
|
class train(exec_container):
    """Execution container for training: builds the train (and optional eval)
    dataloaders, the model with its optimizer and scheduler, and saves the
    final weights."""

    def prepare_dataloader(self):
        """Build the train loader from ``cfg.train`` and, when configured, an eval loader from ``cfg.eval``."""
        cfg = cfguh().cfg
        trainset = get_dataset()(cfg.train.dataset)
        sampler = get_sampler()(
            dataset=trainset, cfg=cfg.train.dataset.get('sampler', 'default_train'))
        trainloader = torch.utils.data.DataLoader(
            trainset,
            batch_size=cfg.train.batch_size_per_gpu,
            sampler=sampler,
            num_workers=cfg.train.dataset_num_workers_per_gpu,
            drop_last=False,
            pin_memory=cfg.train.dataset.get('pin_memory', False),
            collate_fn=collate(),)

        evalloader = None
        # An eval section without a dataset means "no eval loader", matching
        # the eval container's behaviour, instead of failing on a missing key.
        if ('eval' in cfg) and (cfg.eval.get('dataset', None) is not None):
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset, cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(),)

        return {
            'trainloader': trainloader,
            'evalloader': evalloader, }

    def prepare_model(self):
        """Build the model, scheduler and optimizer.

        When ``cfg.env.cuda`` is set, the model is moved to this process's GPU
        and wrapped in DistributedDataParallel (find_unused_parameters=True).
        """
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.train()
        scheduler = get_scheduler()(cfg.train.scheduler)
        optimizer = get_optimizer()(net, cfg.train.optimizer)
        return {
            'net': net,
            'optimizer': optimizer,
            'scheduler': scheduler, }

    def save_last_model(self, **para):
        """Save ``para['net']`` (unwrapped from (D)DP) as ``<log_dir>/<exid>_<symbol>_last.pth``."""
        cfg = cfguh().cfg
        cfgt = cfg.train
        net = para['net']
        net_symbol = cfg.model.symbol
        if isinstance(net, (torch.nn.DataParallel,
                            torch.nn.parallel.DistributedDataParallel)):
            netm = net.module
        else:
            netm = net
        # train_stage.save reads the id from cfg.env; use cfg.train's only
        # when it is explicitly set there.
        exid = cfgt.get('experiment_id', None)
        if exid is None:
            exid = cfg.env.experiment_id
        path = osp.join(cfgt.log_dir, '{}_{}_last.pth'.format(
            exid, net_symbol))
        torch.save(netm.state_dict(), path)
        print_log('Saving model file {0}'.format(path))
|
|
|
class eval(exec_container):
    """Execution container for evaluation only: builds an eval loader and the model."""

    def prepare_dataloader(self):
        """Build the eval loader from ``cfg.eval``.

        Always returns a dict (``exec_container.__call__`` asserts this);
        ``evalloader`` is None when no eval dataset is configured or the
        dataset factory returns None.
        """
        cfg = cfguh().cfg
        evalloader = None
        if cfg.eval.get('dataset', None) is not None:
            evalset = get_dataset()(cfg.eval.dataset)
            if evalset is not None:
                sampler = get_sampler()(
                    dataset=evalset, cfg=cfg.eval.dataset.get('sampler', 'default_eval'))
                evalloader = torch.utils.data.DataLoader(
                    evalset,
                    batch_size=cfg.eval.batch_size_per_gpu,
                    sampler=sampler,
                    num_workers=cfg.eval.dataset_num_workers_per_gpu,
                    drop_last=False,
                    pin_memory=cfg.eval.dataset.get('pin_memory', False),
                    collate_fn=collate(), )
        return {
            'trainloader': None,
            'evalloader': evalloader, }

    def prepare_model(self):
        """Build the model in eval mode, DDP-wrapped on this process's GPU when ``cfg.env.cuda`` is set."""
        cfg = cfguh().cfg
        net = get_model()(cfg.model)
        if cfg.env.cuda:
            net.to(self.local_rank)
            net = torch.nn.parallel.DistributedDataParallel(
                net, device_ids=[self.local_rank],
                find_unused_parameters=True)
        net.eval()
        return {'net': net, }

    def save_last_model(self, **para):
        """Evaluation produces no weights to save."""
        return
|
|
|
|
|
|
|
|
|
|
|
def torch_to_numpy(*argv):
    """Recursively convert torch tensors into numpy arrays.

    Several positional arguments are handled as one list. Tensors are moved
    to CPU and detached before conversion; lists and tuples come back as
    lists, dicts keep their keys, and every other value is returned as-is.
    """
    data = list(argv) if len(argv) > 1 else argv[0]

    if isinstance(data, torch.Tensor):
        return data.to('cpu').detach().numpy()
    if isinstance(data, (list, tuple)):
        return [torch_to_numpy(item) for item in data]
    if isinstance(data, dict):
        return {key: torch_to_numpy(value) for key, value in data.items()}
    return data
|
|
|
import importlib |
|
|
|
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``'package.module.Name'`` to the object it names.

    Everything before the last dot is imported as a module and the final
    component is looked up on it. With ``reload=True`` the module is
    re-executed via ``importlib.reload`` before the lookup.
    """
    module_name, attr_name = string.rsplit(".", 1)
    module = importlib.import_module(module_name, package=None)
    if reload:
        module = importlib.reload(module)
    return getattr(module, attr_name)
|
|