|
import time |
|
import tqdm |
|
import os |
|
import json |
|
import pickle |
|
import sys |
|
import copy |
|
import numpy as np |
|
import itertools |
|
import random |
|
import torch |
|
import io |
|
from torch.nn.parallel import DistributedDataParallel |
|
from torch.cuda.amp import autocast |
|
from .unified_tester import tester, dict_to_cuda, list_to_cuda, move_to_cuda |
|
from collections import OrderedDict |
|
from uniperceiver.evaluation import build_evaluation |
|
import uniperceiver.utils.comm as comm |
|
from uniperceiver.utils.engine_util import * |
|
from .build import ENGINE_REGISTRY |
|
from uniperceiver.datasets import ( |
|
build_standard_valtest_loader, |
|
build_unified_train_loader, |
|
) |
|
|
|
from uniperceiver.utils.events import get_event_storage |
|
from uniperceiver.utils.events import EventStorage |
|
from omegaconf import DictConfig |
|
from uniperceiver.losses import build_losses |
|
from uniperceiver.optim import build_optimizer |
|
from uniperceiver.modeling import build_model |
|
from uniperceiver.lr_scheduler import build_lr_scheduler |
|
from torch.cuda.amp import autocast |
|
from uniperceiver.checkpoint import TorchCheckpointer |
|
|
|
import logging |
|
import math |
|
import weakref |
|
|
|
from uniperceiver.config import CfgNode |
|
|
|
|
|
from . import hooks |
|
|
|
|
|
from timm.data import Mixup |
|
from timm.utils import ModelEma |
|
from uniperceiver.utils.misc import NativeScalerWithGradNormCount as NativeScaler |
|
from uniperceiver.utils.misc import ApexScalerWithGradNormCount as ApexScaler |
|
|
|
from collections import defaultdict |
|
from .train_loop import TrainerBase |
|
from uniperceiver.utils.logger import setup_logger |
|
|
|
# apex is optional; APEX_INSTALLED gates every apex code path in this module.
try:
    from apex import amp
    APEX_INSTALLED = True
except Exception:
    # Catch Exception (not a bare except) so KeyboardInterrupt/SystemExit still
    # propagate, while a missing or partially broken apex build degrades
    # gracefully instead of aborting the import of this module.
    print('apex has not been installed.')
    APEX_INSTALLED = False
|
|
|
# Public API of this module: only the trainer class is exported via `import *`.
__all__ = ['UnifiedTrainer']
|
|
|
|
|
@ENGINE_REGISTRY.register()
class UnifiedTrainer(TrainerBase):
    """Multi-task trainer driving one model over every task in ``cfg.TASKS``.

    Each ``cfg.TASKS`` entry (which must carry a ``'NAME'`` key) becomes a
    per-task ``CfgNode`` in ``self.task_cfg``. That node is used to build the
    task's val/test evaluators and its optional Mixup/CutMix transform.

    Each training iteration runs ``cfg.SOLVER.ACCUM_ITER`` micro-batches and
    steps the optimizer on the last one. Optional features are native AMP
    (``SOLVER.AMP_FP16``), apex fp16 (``SOLVER.APEX_FP16``), model EMA
    (``MODEL.MODEL_EMA``) and hook-driven checkpointing/evaluation/writing.
    """

    def __init__(self, cfg):
        super().__init__()
        self.logger = logging.getLogger(__name__)
        if not self.logger.isEnabledFor(logging.INFO):
            setup_logger()

        # Per-task configuration, keyed by each task's NAME.
        self.task_cfg = dict()
        self.task_names = []
        for task in cfg.TASKS:
            name = task['NAME']
            self.task_names.append(name)
            self.task_cfg[name] = CfgNode(task)

        self.cfg = cfg

        model = self.build_model(cfg)
        self.logger.info("Model Creation Done")

        # Set by resume_or_load() under apex; while True, run_min_batch skips
        # optimizer.step() and train() reloads the optimizer state.
        self.apex_need_reload = False

        self.optimizer = self.build_optimizer(cfg, model)

        if cfg.SOLVER.APEX_FP16 and APEX_INSTALLED:
            self.apex_fp16 = True
            model, self.optimizer = amp.initialize(
                model,
                self.optimizer,
                opt_level=self.cfg.SOLVER.APEX_OPT_LEVEL,
                master_weights=self.cfg.SOLVER.APEX_MASTER_WEIGHTS,
                min_loss_scale=self.cfg.SOLVER.MIN_LOSS_SCLE,
                loss_scale="dynamic")

        if comm.get_world_size() > 1:
            model = DistributedDataParallel(
                model,
                find_unused_parameters=cfg.find_unused_parameters,
                device_ids=[comm.get_local_rank()],
                broadcast_buffers=False)
        self.model = model
        self.model.train()

        self.train_data_loader = build_train_loader(cfg, self.task_cfg, self.model)
        self.val_data_loader = build_val_loader(cfg, self.task_cfg)
        self.test_data_loader = build_test_loader(cfg, self.task_cfg)

        if isinstance(self.train_data_loader, list):
            self.iters_per_epoch_list = [len(loader) for loader in self.train_data_loader]
            self._train_data_loader_iter_list = [iter(loader) for loader in self.train_data_loader]
            # The first loader defines the epoch length and the default iterator.
            self.iters_per_epoch = len(self.train_data_loader[0])
            self._train_data_loader_iter = iter(self.train_data_loader[0])
        else:
            self.iters_per_epoch = len(self.train_data_loader)
            self._train_data_loader_iter = iter(self.train_data_loader)

        # Placeholders only: the real per-task evaluator dicts are built by
        # _build_task_components() below. The previous top-level
        # build_evaluation() calls here were always overwritten and are dropped.
        self.val_evaluator = None
        self.test_evaluator = None

        self.ss_prob = 0.0

        self.model_ema = None
        if cfg.MODEL.MODEL_EMA:
            self.model_ema = ModelEma(
                self.model,
                decay=cfg.MODEL.MODEL_EMA_DECAY,
                device='cpu' if cfg.MODEL.MODEL_EMA_FORCE_CPU else '',
                resume='')

        self.checkpointer = TorchCheckpointer(
            self.model,
            self.model_ema,
            cfg.OUTPUT_DIR,
            trainer=weakref.proxy(self),
            checkpoint_mapping=cfg.SOLVER.CHECKPOINT_MAPPING,
            mapping=cfg.SOLVER.CHECKPOINT_MAP,
            resume_tau=cfg.SOLVER.RESUME_TAU,
            ceph_save=cfg.SOLVER.CHECKPOINT_CEPH_SAVE,
            ceph_config=cfg.DATALOADER.get("TCS_CONF_PATH", "petreloss.config"),
        )
        self.checkpointer.add_checkpointable('optimizer', self.optimizer)
        if cfg.MODEL.MODEL_EMA:
            self.checkpointer.add_checkpointable('ema_model', self.model_ema.ema)

        self.start_iter = 0
        # NOTE: build_hooks() overwrites this with cfg.SOLVER.MAX_ITER.
        self.max_iter = cfg.SOLVER.EPOCH * self.iters_per_epoch
        self.register_hooks(self.build_hooks())

        if cfg.SOLVER.AMP_FP16:
            self.amp_scaler = NativeScaler(enabled=True, growth_interval=cfg.SOLVER.LOSS_SCALE_WINDOW)
            self.amp_fp16 = True
        else:
            self.amp_scaler = NativeScaler(enabled=False)
            self.amp_fp16 = False

        if cfg.SOLVER.APEX_FP16 and APEX_INSTALLED:
            # When apex is active, use the apex-backed scaler wrapper instead.
            self.amp_scaler = ApexScaler(enabled=True)
        else:
            self.apex_fp16 = False

        self.fp16 = cfg.SOLVER.AMP_FP16 or cfg.SOLVER.APEX_FP16
        self.bf16 = cfg.SOLVER.BF16
        if self.fp16:
            assert not self.bf16

        if self.amp_scaler is not None:
            self.checkpointer.add_checkpointable('amp_scaler', self.amp_scaler)

        self._build_task_components(cfg)

        if cfg.DATALOADER.USE_WEIGHTED_SAMPLER:
            self.iters_per_epoch = 1

        self.scheduler = self.build_lr_scheduler(cfg, self.optimizer, self.iters_per_epoch)
        self.checkpointer.add_checkpointable('scheduler', self.scheduler)

        self.accum_iter = max(1, cfg.SOLVER.ACCUM_ITER)
        self.step_index = 0

        self.grad_print = getattr(cfg.SOLVER, "GRAD_PRINT", False)

        if self.cfg.SOLVER.GradHistogram:
            assert self.cfg.SOLVER.TORCH_OPTIMIZER and self.cfg.SOLVER.PARAMS_SEPERATE

    def _build_task_components(self, cfg):
        """Build per-task val/test evaluators and optional Mixup/CutMix transforms.

        Populates ``self.val_evaluator``, ``self.test_evaluator`` and
        ``self.mixup_fn``, each a dict keyed by task name. A task gets a
        ``None`` evaluator when its loader is missing or falsy. The whole loader
        collection may also be ``None``, which yields ``None`` for every task.
        """
        self.val_evaluator = dict()
        self.test_evaluator = dict()
        self.mixup_fn = dict()
        for name, new_cfg in self.task_cfg.items():
            if self.val_data_loader is not None and self.val_data_loader[name]:
                self.val_evaluator[name] = build_evaluation(
                    new_cfg, new_cfg.INFERENCE.VAL_ANNFILE, cfg.OUTPUT_DIR)
            else:
                self.val_evaluator[name] = None

            if self.test_data_loader is not None and self.test_data_loader[name]:
                self.test_evaluator[name] = build_evaluation(
                    new_cfg, new_cfg.INFERENCE.TEST_ANNFILE, cfg.OUTPUT_DIR)
            else:
                self.test_evaluator[name] = None

            if new_cfg.DATALOADER.MIXUP > 0 or new_cfg.DATALOADER.CUTMIX > 0:
                self.mixup_fn[name] = Mixup(
                    mixup_alpha=new_cfg.DATALOADER.MIXUP,
                    cutmix_alpha=new_cfg.DATALOADER.CUTMIX,
                    cutmix_minmax=None,
                    prob=new_cfg.DATALOADER.MIXUP_PROB,
                    switch_prob=new_cfg.DATALOADER.MIXUP_SWITCH_PROB,
                    mode=new_cfg.DATALOADER.MIXUP_MODE,
                    label_smoothing=new_cfg.DATALOADER.MIXUP_LABEL_SMOOTHING,
                    num_classes=new_cfg.MODEL.LABELS_NUM)
            else:
                self.mixup_fn[name] = None

    def resume_or_load(self, resume=True):
        """Load ``cfg.MODEL.WEIGHTS``, or resume from the last checkpoint.

        If ``resume`` is set and a checkpoint exists, training restarts at
        ``self.iter + 1``. ``self.iter`` is presumably restored by the
        checkpointer through its trainer proxy (TODO confirm). Under apex fp16,
        the optimizer state is also flagged for a reload after the first step
        (see ``train``).
        """
        # `resume_optmizer` is spelled exactly as the checkpointer expects it.
        self.checkpointer.resume_or_load(self.cfg.MODEL.WEIGHTS,
                                         resume=resume,
                                         resume_optmizer=self.cfg.SOLVER.RESUME_OPTIMIZER)
        if resume and self.checkpointer.has_checkpoint():
            self.start_iter = self.iter + 1
            if self.apex_fp16:
                self.apex_need_reload = True

    @classmethod
    def build_losses(cls, cfg):
        """Return ``{task NAME: losses}`` built from each entry of ``cfg.TASKS``."""
        losses = {}
        for task_config in cfg.TASKS:
            task_config = DictConfig(task_config)
            losses[task_config.NAME] = build_losses(task_config)
        return losses

    def build_hooks(self):
        """Build the default hooks: timer, LR scheduler, weight manipulation,
        periodic checkpointing, val/test evaluation (also for the EMA model)
        and metric writing.

        Side effect: sets ``self.max_iter`` to ``cfg.SOLVER.MAX_ITER``.
        Checkpoint and writer hooks are only added on the main process.
        """
        self.max_iter = self.cfg.SOLVER.MAX_ITER
        # Hooks read a defrosted copy of the config with NUM_WORKERS forced to 0.
        cfg = self.cfg.clone()
        cfg.defrost()
        cfg.DATALOADER.NUM_WORKERS = 0

        ret = [
            hooks.IterationTimer(),
            hooks.LRScheduler(),
            hooks.ModelWeightsManipulating(),
        ]

        if comm.is_main_process():
            ret.append(hooks.PeriodicCheckpointer(
                self.checkpointer, cfg.SOLVER.CHECKPOINT_PERIOD,
                max_to_keep=cfg.SOLVER.CHECKPOINT_MAX_SAVE))

        # The closures read loaders/evaluators at call time, so they see the
        # per-task dicts that __init__ builds after hook registration.
        def test_and_save_results(epoch):
            return self.test(self.cfg, self.model, self.test_data_loader, self.test_evaluator, epoch)

        def val_and_save_results(epoch):
            return self.test(self.cfg, self.model, self.val_data_loader, self.val_evaluator, epoch)

        def add_ema_suffix(eval_results):
            # Rename metric keys to `<key>_ema`; non-dict task results pass through.
            return {
                taskname: ({f'{k}_ema': v for k, v in taskresults.items()}
                           if isinstance(taskresults, dict) else taskresults)
                for taskname, taskresults in eval_results.items()
            }

        if self.model_ema is not None:

            def test_and_save_results_ema(epoch):
                return add_ema_suffix(self.test(self.cfg, self.model_ema.ema,
                                                self.test_data_loader,
                                                self.test_evaluator, epoch))

            def val_and_save_results_ema(epoch):
                return add_ema_suffix(self.test(self.cfg, self.model_ema.ema,
                                                self.val_data_loader,
                                                self.val_evaluator, epoch))

        if self.val_data_loader is not None:
            ret.append(hooks.IterEvalHook(
                eval_period=cfg.SOLVER.EVAL_PERIOD,
                eval_start=cfg.INFERENCE.VAL_EVAL_START,
                eval_function=val_and_save_results,
                stage='val',
                multi_gpu_eval=True))
            if self.model_ema is not None:
                ret.append(hooks.IterEvalHook(
                    eval_period=cfg.SOLVER.EVAL_PERIOD,
                    eval_start=cfg.INFERENCE.VAL_EVAL_START,
                    eval_function=val_and_save_results_ema,
                    stage='val',
                    multi_gpu_eval=True))

        if self.test_data_loader is not None:
            ret.append(hooks.IterEvalHook(
                eval_period=cfg.SOLVER.EVAL_PERIOD,
                eval_start=cfg.INFERENCE.TEST_EVAL_START,
                eval_function=test_and_save_results,
                stage='test',
                multi_gpu_eval=True))
            if self.model_ema is not None:
                ret.append(hooks.IterEvalHook(
                    eval_period=cfg.SOLVER.EVAL_PERIOD,
                    eval_start=cfg.INFERENCE.TEST_EVAL_START,
                    eval_function=test_and_save_results_ema,
                    stage='test',
                    multi_gpu_eval=True))

        if comm.is_main_process():
            ret.append(hooks.PeriodicWriter(build_writers(cfg, self.max_iter), period=cfg.SOLVER.WRITE_PERIOD))

        return ret

    def train(self):
        """Run the training loop from ``self.start_iter`` up to ``self.max_iter``.

        Hooks run around every step. Exceptions are logged and re-raised, and
        ``after_train`` always runs.
        """
        start_iter = self.start_iter
        max_iter = self.max_iter
        logger = logging.getLogger(__name__)
        logger.info("Starting training from iteration {}".format(start_iter))

        self.iter = self.start_iter = start_iter
        self.max_iter = max_iter

        with EventStorage(start_iter) as self.storage:
            try:
                self.before_train()
                for self.iter in range(start_iter, max_iter):
                    self.before_step()
                    self.run_step_torch()
                    self.after_step()

                    if self.apex_need_reload:
                        # Reload the checkpointed optimizer state once, after the
                        # first step. Load to CPU: without map_location every rank
                        # would materialize the checkpoint on the device it was
                        # saved from. load_state_dict moves state to the params'
                        # device, and no reference to the full checkpoint is kept.
                        self.optimizer.load_state_dict(
                            torch.load(self.checkpointer.get_checkpoint_file(),
                                       map_location="cpu")['optimizer'])
                        self.apex_need_reload = False
                # Advance past the last completed iteration before after_train().
                self.iter += 1
            except Exception:
                logger.exception("Exception during training:")
                raise
            finally:
                self.after_train()

    @classmethod
    def build_model(cls, cfg):
        """Build the model from ``cfg`` and log its structure."""
        model = build_model(cfg)
        logger = logging.getLogger(__name__)
        logger.info("Model:\n{}".format(model))
        return model

    @classmethod
    def build_optimizer(cls, cfg, model):
        """Build the optimizer for ``model`` from ``cfg``."""
        logger = logging.getLogger(__name__)
        logger.info("building optimizer...")
        return build_optimizer(cfg, model)

    @classmethod
    def build_lr_scheduler(cls, cfg, optimizer, iters_per_epoch):
        """Build the LR scheduler for ``optimizer`` from ``cfg`` and ``iters_per_epoch``."""
        logger = logging.getLogger(__name__)
        logger.info("building lr_scheduler...")
        return build_lr_scheduler(cfg, optimizer, iters_per_epoch)

    def run_step_torch(self):
        """Run one optimizer iteration as ``self.accum_iter`` micro-batches.

        ``self.accum_iter`` is at least 1. The optimizer steps only on the last
        micro-batch (see ``run_min_batch``).
        """
        for micro_step in range(self.accum_iter):
            self.micro_step = micro_step
            self.run_min_batch()

    def _get_temperature_dict(self):
        """Return the ``temperature_dict`` exposed by the model's ``loss_prepare``, or None.

        When ``loss_prepare`` is a ``torch.nn.ModuleList``, its last element is
        inspected. ``hasattr`` on the ModuleList itself is always False, which
        previously made that case unreachable.
        """
        loss_prepare = comm.unwrap_model(self.model).loss_prepare
        if isinstance(loss_prepare, torch.nn.ModuleList):
            if len(loss_prepare) == 0:
                return None
            loss_prepare = loss_prepare[-1]
        return getattr(loss_prepare, 'temperature_dict', None)

    def run_min_batch(self):
        """Run forward/backward for one micro-batch; finish the step on the last one.

        Losses are divided by ``accum_iter`` so the accumulated gradient is their
        average. On the last micro-batch this also writes metrics, steps the
        optimizer (unless an apex optimizer reload is pending), updates the
        native AMP scale, zeroes gradients and updates the EMA model.
        """
        assert self.model.training, "[SimpleTrainer] model was changed to eval mode!"
        torch.cuda.synchronize()

        start = time.perf_counter()
        data = get_batch_data(self.cfg, self._train_data_loader_iter, self.train_data_loader)
        data_time = time.perf_counter() - start

        task = data['task_info']['task_name']
        data = move_to_cuda(data)

        mixup_fn = self.mixup_fn[task]
        if mixup_fn is not None:
            # Mixup/CutMix applies to the first input sample and its target only.
            first_input = data['input_sample_list'][0]
            first_input["data"], data['target_idx_list'][0] = mixup_fn(
                first_input["data"], data["target_idx_list"][0])

        if self.amp_fp16:
            with autocast(self.amp_fp16):
                losses_dict = self.model(data)
        else:
            losses_dict = self.model(data)

        losses = sum(losses_dict.values()) / self.accum_iter

        is_last_micro_step = self.micro_step + 1 == self.accum_iter
        total_grad = self.amp_scaler(
            losses, self.optimizer, clip_grad=self.cfg.SOLVER.GRAD_CLIP,
            parameters=self.model.parameters(), create_graph=False,
            update_grad=is_last_micro_step, fp16=self.fp16, iter=self.iter,
            min_loss_scale=self.cfg.SOLVER.MIN_LOSS_SCLE,
            loss_scale_window=self.cfg.SOLVER.LOSS_SCALE_WINDOW)

        if not is_last_micro_step:
            return

        write_metrics(losses_dict, data_time, task + '/')

        is_main = comm.is_main_process()
        storage = None
        if is_main:
            storage = get_event_storage()
            if not torch.isfinite(total_grad):
                self.logger.info('grad to nan or inf in task {} {}'.format(task, total_grad))
            storage.put_scalar("total_grad", total_grad, smoothing_hint=False)

        # While an apex optimizer reload is pending, skip the step; train()
        # reloads the optimizer state right after this iteration.
        if not self.apex_need_reload:
            self.amp_scaler.step(self.optimizer)

        if is_main:
            storage.put_scalar("amp_scale", self.amp_scaler.get_scale(), smoothing_hint=False)
            temperature_dict = self._get_temperature_dict()
            if temperature_dict is not None:
                storage.put_scalars(**temperature_dict, smoothing_hint=False)

        if self.amp_fp16:
            self.amp_scaler.update()

        self.optimizer.zero_grad()
        if self.model_ema is not None:
            self.model_ema.update(self.model)
        torch.cuda.synchronize()

    def cast_layers(self):
        """Optionally cast layer norms to fp32, and call ``operatedweight()`` at iteration 0.

        Uses ``comm.unwrap_model(self.model)``. This class never defines the
        ``self.model_engine`` attribute the previous code referenced.
        """
        if self.cfg.MODEL.LN_FP32:
            self.logger.info("cast LN to fp32")

            # NOTE(review): CustomLayernorm is not imported explicitly in this
            # file; presumably it arrives via a wildcard import — TODO confirm.
            def cast_ln_fp32(module):
                if isinstance(module, CustomLayernorm):
                    module.float()

            comm.unwrap_model(self.model).apply(cast_ln_fp32)

        if self.iter == 0:
            comm.unwrap_model(self.model).operatedweight()

    def test(self, cfg, model, test_data_loader, evaluator, epoch):
        """Evaluate ``model`` on ``test_data_loader`` with ``evaluator``.

        Delegates to ``tester`` using the per-task configs and the trainer's
        fp16 mode. ``cfg`` is accepted for interface compatibility but unused.
        """
        return tester(self.task_cfg, model, test_data_loader, evaluator, epoch, self.amp_fp16, self.apex_fp16)
|
|