|
import timeit |
|
import numpy as np |
|
import os |
|
import os.path as osp |
|
import shutil |
|
import copy |
|
import torch |
|
import torch.nn as nn |
|
import torch.distributed as dist |
|
from .cfg_holder import cfg_unique_holder as cfguh |
|
from . import sync |
|
|
|
# When True, print_log() only echoes to the console on local rank 0; every
# other rank returns immediately without printing or writing anything.
print_console_local_rank0_only = True


def print_log(*console_info):
    """Print a message to the console and append it to the configured log file.

    All positional arguments are converted with ``str`` and joined with a
    single space.

    Console output is limited to local rank 0 while the module-level flag
    ``print_console_local_rank0_only`` is True. Regardless of that flag, only
    local rank 0 appends to the log file.

    The log file path is read from ``cfguh().cfg.train.log_file``. If that
    lookup raises, ``cfguh().cfg.eval.log_file`` is tried instead. If both
    lookups raise, or the resolved path is None, nothing is written to disk.
    """
    local_rank = sync.get_rank('local')
    if print_console_local_rank0_only and local_rank != 0:
        return

    message = ' '.join(str(i) for i in console_info)
    print(message)

    # Non-zero ranks may print (when the flag above is False) but never
    # write to the shared log file.
    if local_rank != 0:
        return

    log_file = None
    try:
        log_file = cfguh().cfg.train.log_file
    except Exception:
        # The train section / log_file entry is unavailable; fall back to
        # the eval section. ``Exception`` (not a bare except) so Ctrl-C and
        # SystemExit are never swallowed here.
        try:
            log_file = cfguh().cfg.eval.log_file
        except Exception:
            return

    if log_file is not None:
        with open(log_file, 'a') as f:
            f.write(message + '\n')
|
|
|
class distributed_log_manager(object):
    """Accumulate weighted training metrics and report their means.

    Each metric is tracked as a running weighted sum (``self.sum``) and total
    weight (``self.cnt``). ``get_mean_value_dict`` turns these into per-metric
    means, averaged across processes with ``dist.all_reduce`` under DDP.
    When ``cfg.train.log_tensorboard`` is truthy, local rank 0 also mirrors
    metrics to TensorBoard through ``tensorboardX``.
    """

    def __init__(self):
        self.sum = {}  # metric name -> sum of (value * weight)
        self.cnt = {}  # metric name -> sum of weights
        # Start of the current reporting window; reset by clear().
        self.time_check = timeit.default_timer()

        cfgt = cfguh().cfg.train
        use_tensorboard = getattr(cfgt, 'log_tensorboard', False)

        self.ddp = sync.is_ddp()
        self.rank = sync.get_rank('local')
        self.world_size = sync.get_world_size('local')

        self.tb = None
        if use_tensorboard and (self.rank == 0):
            # Imported here so tensorboardX is only needed when enabled.
            import tensorboardX
            monitoring_dir = osp.join(cfguh().cfg.train.log_dir, 'tensorboard')
            self.tb = tensorboardX.SummaryWriter(monitoring_dir)

    def accumulate(self, n, **data):
        """Add ``n``-weighted observations for each keyword metric.

        Args:
            n: non-negative weight applied to every value in ``data``.
            **data: metric name -> value; the value must support ``* n``
                and ``+``.

        Raises:
            ValueError: if ``n`` is negative.
        """
        if n < 0:
            raise ValueError(
                'accumulate() weight n must be non-negative, got {}'.format(n))

        for itemn, di in data.items():
            weighted = di * n
            if itemn in self.sum:
                self.sum[itemn] += weighted
                self.cnt[itemn] += n
            else:
                self.sum[itemn] = weighted
                self.cnt[itemn] = n

    def get_mean_value_dict(self):
        """Return ``{metric name: mean}`` for every accumulated metric.

        Under DDP the per-rank means are summed with ``all_reduce`` and then
        divided by the size of the default process group, so every rank gets
        the mean over all ranks. Values are exchanged positionally in
        sorted-name order, so every rank must have accumulated the same set
        of metric names.
        """
        names = sorted(self.sum.keys())
        value_gather = [self.sum[itemn] / self.cnt[itemn] for itemn in names]
        value_gather_tensor = torch.FloatTensor(value_gather)

        if self.ddp:
            # NCCL collectives require the tensor on this rank's CUDA device.
            # The single-process path stays on the CPU so it also works
            # without a GPU.
            value_gather_tensor = value_gather_tensor.to(self.rank)
            dist.all_reduce(value_gather_tensor, op=dist.ReduceOp.SUM)
            # all_reduce without an explicit group spans the default (global)
            # group, so average over its size, not the per-node size.
            value_gather_tensor /= dist.get_world_size()

        return {
            itemn: value_gather_tensor[idx].item()
            for idx, itemn in enumerate(names)
        }

    def tensorboard_log(self, step, data, mode='train', **extra):
        """Write ``data`` to TensorBoard at ``step``.

        This is a no-op when no writer exists.

        mode='train': ``data`` is a dict of scalars.
            - Names starting with 'loss' are logged under 'loss/'.
            - The key 'Loss' is logged as 'Loss'.
            - All other names are logged under 'other/'.
            - ``extra['epochn']`` and ``extra['lr']`` are logged under
              'other/' when present and not None.
        mode='eval': ``data`` is either a dict of scalars, logged under
            'eval/', or a single scalar, logged as 'eval'.
        """
        if self.tb is None:
            return

        if mode == 'train':
            epochn = extra.get('epochn')
            if epochn is not None:
                self.tb.add_scalar('other/epochn', epochn, step)
            # train_summary always passes lr=..., possibly None; a None
            # scalar cannot be logged.
            lr = extra.get('lr')
            if lr is not None:
                self.tb.add_scalar('other/lr', lr, step)
            for itemn, di in data.items():
                if itemn.startswith('loss'):
                    self.tb.add_scalar('loss/' + itemn, di, step)
                elif itemn == 'Loss':
                    self.tb.add_scalar('Loss', di, step)
                else:
                    self.tb.add_scalar('other/' + itemn, di, step)
        elif mode == 'eval':
            if isinstance(data, dict):
                for itemn, di in data.items():
                    self.tb.add_scalar('eval/' + itemn, di, step)
            else:
                self.tb.add_scalar('eval', data, step)

    def train_summary(self, itern, epochn, samplen, lr, tbstep=None):
        """Build a one-line training summary and mirror the metrics to TensorBoard.

        Args:
            itern: iteration number; also the TensorBoard step unless
                ``tbstep`` is given.
            epochn: epoch number.
            samplen: sample count shown in the summary.
            lr: learning rate, or None to omit it.
            tbstep: optional TensorBoard step overriding ``itern``.

        Returns:
            str: fields joined by ' , ', in this order: Iter, Epoch, Sample,
            optional LR, optional Loss, each metric whose name starts with
            'loss' (sorted), and the elapsed Time since the last clear().
        """
        console_info = [
            'Iter:{}'.format(itern),
            'Epoch:{}'.format(epochn),
            'Sample:{}'.format(samplen),
        ]
        if lr is not None:
            console_info.append('LR:{:.4E}'.format(lr))

        mean = self.get_mean_value_dict()

        tbstep = itern if tbstep is None else tbstep
        self.tensorboard_log(
            tbstep, mean, mode='train',
            itern=itern, epochn=epochn, lr=lr)

        # 'Loss' is listed first when it was accumulated; a missing 'Loss'
        # no longer raises KeyError.
        loss = mean.pop('Loss', None)
        if loss is not None:
            console_info.append('Loss:{:.4f}'.format(loss))
        console_info += [
            '{}:{:.4f}'.format(itemn, mean[itemn])
            for itemn in sorted(mean)
            if itemn.startswith('loss')
        ]
        console_info.append('Time:{:.2f}s'.format(
            timeit.default_timer() - self.time_check))
        return ' , '.join(console_info)

    def clear(self):
        """Drop all accumulated metrics and restart the reporting timer."""
        self.sum = {}
        self.cnt = {}
        self.time_check = timeit.default_timer()

    def tensorboard_close(self):
        """Close the TensorBoard writer, if one was created."""
        if self.tb is not None:
            self.tb.close()
|
|
|
|
|
|
|
def torch_to_numpy(*argv):
    """Recursively convert torch tensors into numpy arrays.

    With several positional arguments they are treated as one list; with a
    single argument that argument is converted directly.

    Conversion rules:
        - torch.Tensor: copied to the CPU, detached, and returned as an
          ndarray.
        - list or tuple: converted element-wise; the result is always a list.
        - dict: converted value-wise; keys are kept unchanged.
        - anything else: returned unchanged.
    """
    payload = list(argv) if len(argv) > 1 else argv[0]

    if isinstance(payload, torch.Tensor):
        return payload.to('cpu').detach().numpy()
    if isinstance(payload, (list, tuple)):
        return [torch_to_numpy(element) for element in payload]
    if isinstance(payload, dict):
        return {key: torch_to_numpy(value) for key, value in payload.items()}
    return payload
|
|