"""Train a video classification model.""" |
|
|
|
import numpy as np |
|
import pprint |
|
import torch |
|
from fvcore.nn.precise_bn import get_bn_modules, update_bn_stats |
|
|
|
import timesformer.models.losses as losses |
|
import timesformer.models.optimizer as optim |
|
import timesformer.utils.checkpoint as cu |
|
import timesformer.utils.distributed as du |
|
import timesformer.utils.logging as logging |
|
import timesformer.utils.metrics as metrics |
|
import timesformer.utils.misc as misc |
|
import timesformer.visualization.tensorboard_vis as tb |
|
from timesformer.datasets import loader |
|
from timesformer.models import build_model |
|
from timesformer.utils.meters import TrainMeter, ValMeter |
|
from timesformer.utils.multigrid import MultigridSchedule |
|
|
|
from timm.data import Mixup |
|
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
def train_epoch( |
|
train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer=None |
|
): |
|
""" |
|
Perform the video training for one epoch. |
|
Args: |
|
train_loader (loader): video training loader. |
|
model (model): the video model to train. |
|
optimizer (optim): the optimizer to perform optimization on the model's |
|
parameters. |
|
train_meter (TrainMeter): training meters to log the training performance. |
|
cur_epoch (int): current epoch of training. |
|
cfg (CfgNode): configs. Details can be found in |
|
slowfast/config/defaults.py |
|
writer (TensorboardWriter, optional): TensorboardWriter object |
|
to writer Tensorboard log. |
|
""" |
|
|
|
model.train() |
|
train_meter.iter_tic() |
|
data_size = len(train_loader) |
|
|
|
cur_global_batch_size = cfg.NUM_SHARDS * cfg.TRAIN.BATCH_SIZE |
|
num_iters = cfg.GLOBAL_BATCH_SIZE // cur_global_batch_size |
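    # The global batch seen per step is NUM_SHARDS * TRAIN.BATCH_SIZE; when it
    # is smaller than cfg.GLOBAL_BATCH_SIZE, gradients are accumulated over
    # num_iters steps so the effective batch matches the target.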
|
|
|
for cur_iter, (inputs, labels, _, meta) in enumerate(train_loader): |
|
|
|
if cfg.NUM_GPUS: |
|
if isinstance(inputs, (list,)): |
|
for i in range(len(inputs)): |
|
inputs[i] = inputs[i].cuda(non_blocking=True) |
|
else: |
|
inputs = inputs.cuda(non_blocking=True) |
|
labels = labels.cuda() |
|
for key, val in meta.items(): |
|
if isinstance(val, (list,)): |
|
for i in range(len(val)): |
|
val[i] = val[i].cuda(non_blocking=True) |
|
else: |
|
meta[key] = val.cuda(non_blocking=True) |
|
|
|
|
|
lr = optim.get_epoch_lr(cur_epoch + float(cur_iter) / data_size, cfg) |
|
optim.set_lr(optimizer, lr) |
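        # The learning rate is updated every iteration from the fractional
        # epoch (cur_epoch + cur_iter / data_size), so schedules change
        # smoothly within an epoch rather than once per epoch.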
|
|
|
train_meter.data_toc() |
|
|
|
|
|
        if not cfg.MIXUP.ENABLED:
            loss_fun = losses.get_loss_func(cfg.MODEL.LOSS_FUNC)(reduction="mean")
        else:
            # Mixup/CutMix blends pairs of clips and their labels, so targets
            # become soft distributions and need a soft-target loss.
            mixup_fn = Mixup(
                mixup_alpha=cfg.MIXUP.ALPHA,
                cutmix_alpha=cfg.MIXUP.CUTMIX_ALPHA,
                cutmix_minmax=cfg.MIXUP.CUTMIX_MINMAX,
                prob=cfg.MIXUP.PROB,
                switch_prob=cfg.MIXUP.SWITCH_PROB,
                mode=cfg.MIXUP.MODE,
                label_smoothing=0.1,
                num_classes=cfg.MODEL.NUM_CLASSES,
            )
            # Keep the hard labels: the top-1/top-5 errors below are computed
            # against them, not the mixed soft targets.
            hard_labels = labels
            inputs, labels = mixup_fn(inputs, labels)
            loss_fun = SoftTargetCrossEntropy()
|
|
|
if cfg.DETECTION.ENABLE: |
|
preds = model(inputs, meta["boxes"]) |
|
else: |
|
preds = model(inputs) |
|
|
|
|
|
loss = loss_fun(preds, labels) |
|
|
|
if cfg.MIXUP.ENABLED: |
|
labels = hard_labels |
|
|
|
|
|
misc.check_nan_losses(loss) |
|
|
|
|
|
if cur_global_batch_size >= cfg.GLOBAL_BATCH_SIZE: |
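            # A single step already covers the target global batch, so do a
            # standard zero_grad / backward / step.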
|
|
|
optimizer.zero_grad() |
|
loss.backward() |
|
|
|
optimizer.step() |
|
        else:
            # Accumulate gradients: zero them only at the very start, then sum
            # across iterations until num_iters micro-batches have been seen.
            if cur_iter == 0:
                optimizer.zero_grad()
            loss.backward()
            if (cur_iter + 1) % num_iters == 0:
                # Average the accumulated gradients before stepping; guard
                # against parameters that received no gradient.
                for p in model.parameters():
                    if p.grad is not None:
                        p.grad /= num_iters
                optimizer.step()
                optimizer.zero_grad()
|
|
|
if cfg.DETECTION.ENABLE: |
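            # Detection logs only the loss; top-k errors are not defined for
            # box-level predictions.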
|
if cfg.NUM_GPUS > 1: |
|
loss = du.all_reduce([loss])[0] |
|
loss = loss.item() |
|
|
|
|
|
train_meter.update_stats(None, None, None, loss, lr) |
|
|
|
if writer is not None: |
|
writer.add_scalars( |
|
{"Train/loss": loss, "Train/lr": lr}, |
|
global_step=data_size * cur_epoch + cur_iter, |
|
) |
|
|
|
else: |
|
top1_err, top5_err = None, None |
|
if cfg.DATA.MULTI_LABEL: |
|
|
|
if cfg.NUM_GPUS > 1: |
|
[loss] = du.all_reduce([loss]) |
|
loss = loss.item() |
|
else: |
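                # Count top-1/top-5 correct predictions and convert the counts
                # into error percentages over the local batch.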
|
|
|
num_topks_correct = metrics.topks_correct(preds, labels, (1, 5)) |
|
top1_err, top5_err = [ |
|
(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct |
|
] |
|
|
|
if cfg.NUM_GPUS > 1: |
|
loss, top1_err, top5_err = du.all_reduce( |
|
[loss, top1_err, top5_err] |
|
) |
|
|
|
|
|
loss, top1_err, top5_err = ( |
|
loss.item(), |
|
top1_err.item(), |
|
top5_err.item(), |
|
) |
|
|
|
|
|
            train_meter.update_stats(
                top1_err,
                top5_err,
                loss,
                lr,
                inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
            )
|
|
|
if writer is not None: |
|
writer.add_scalars( |
|
{ |
|
"Train/loss": loss, |
|
"Train/lr": lr, |
|
"Train/Top1_err": top1_err, |
|
"Train/Top5_err": top5_err, |
|
}, |
|
global_step=data_size * cur_epoch + cur_iter, |
|
) |
|
|
|
train_meter.iter_toc() |
|
train_meter.log_iter_stats(cur_epoch, cur_iter) |
|
train_meter.iter_tic() |
|
|
|
|
|
train_meter.log_epoch_stats(cur_epoch) |
|
train_meter.reset() |
|
|
|
|
|
@torch.no_grad() |
|
def eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer=None): |
|
""" |
|
Evaluate the model on the val set. |
|
Args: |
|
val_loader (loader): data loader to provide validation data. |
|
model (model): model to evaluate the performance. |
|
val_meter (ValMeter): meter instance to record and calculate the metrics. |
|
cur_epoch (int): number of the current epoch of training. |
|
cfg (CfgNode): configs. Details can be found in |
|
slowfast/config/defaults.py |
|
writer (TensorboardWriter, optional): TensorboardWriter object |
|
to writer Tensorboard log. |
|
""" |
|
|
|
|
|
model.eval() |
|
val_meter.iter_tic() |
|
|
|
for cur_iter, (inputs, labels, _, meta) in enumerate(val_loader): |
|
if cfg.NUM_GPUS: |
|
|
|
if isinstance(inputs, (list,)): |
|
for i in range(len(inputs)): |
|
inputs[i] = inputs[i].cuda(non_blocking=True) |
|
else: |
|
inputs = inputs.cuda(non_blocking=True) |
|
labels = labels.cuda() |
|
for key, val in meta.items(): |
|
if isinstance(val, (list,)): |
|
for i in range(len(val)): |
|
val[i] = val[i].cuda(non_blocking=True) |
|
else: |
|
meta[key] = val.cuda(non_blocking=True) |
|
val_meter.data_toc() |
|
|
|
if cfg.DETECTION.ENABLE: |
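            # Detection: run the model with RoI boxes, then gather predictions,
            # boxes, and metadata from every process before updating the meter.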
|
|
|
preds = model(inputs, meta["boxes"]) |
|
ori_boxes = meta["ori_boxes"] |
|
metadata = meta["metadata"] |
|
|
|
if cfg.NUM_GPUS: |
|
preds = preds.cpu() |
|
ori_boxes = ori_boxes.cpu() |
|
metadata = metadata.cpu() |
|
|
|
if cfg.NUM_GPUS > 1: |
|
preds = torch.cat(du.all_gather_unaligned(preds), dim=0) |
|
ori_boxes = torch.cat(du.all_gather_unaligned(ori_boxes), dim=0) |
|
metadata = torch.cat(du.all_gather_unaligned(metadata), dim=0) |
|
|
|
val_meter.iter_toc() |
|
|
|
val_meter.update_stats(preds, ori_boxes, metadata) |
|
|
|
else: |
|
preds = model(inputs) |
|
|
|
if cfg.DATA.MULTI_LABEL: |
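                # Multi-label: gather predictions and labels across processes;
                # the meter computes epoch-level metrics from the gathered
                # outputs.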
|
if cfg.NUM_GPUS > 1: |
|
preds, labels = du.all_gather([preds, labels]) |
|
else: |
|
|
|
num_topks_correct = metrics.topks_correct(preds, labels, (1, 5)) |
|
|
|
|
|
top1_err, top5_err = [ |
|
(1.0 - x / preds.size(0)) * 100.0 for x in num_topks_correct |
|
] |
|
if cfg.NUM_GPUS > 1: |
|
top1_err, top5_err = du.all_reduce([top1_err, top5_err]) |
|
|
|
|
|
top1_err, top5_err = top1_err.item(), top5_err.item() |
|
|
|
val_meter.iter_toc() |
|
|
|
            val_meter.update_stats(
                top1_err,
                top5_err,
                inputs[0].size(0) * max(cfg.NUM_GPUS, 1),
            )
|
|
|
if writer is not None: |
|
writer.add_scalars( |
|
{"Val/Top1_err": top1_err, "Val/Top5_err": top5_err}, |
|
global_step=len(val_loader) * cur_epoch + cur_iter, |
|
) |
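            # Save predictions and labels so the meter can produce epoch-level
            # plots (see writer.plot_eval below).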
|
|
|
val_meter.update_predictions(preds, labels) |
|
|
|
val_meter.log_iter_stats(cur_epoch, cur_iter) |
|
val_meter.iter_tic() |
|
|
|
|
|
val_meter.log_epoch_stats(cur_epoch) |
|
|
|
if writer is not None: |
|
if cfg.DETECTION.ENABLE: |
|
writer.add_scalars( |
|
{"Val/mAP": val_meter.full_map}, global_step=cur_epoch |
|
) |
|
else: |
|
all_preds = [pred.clone().detach() for pred in val_meter.all_preds] |
|
all_labels = [ |
|
label.clone().detach() for label in val_meter.all_labels |
|
] |
|
if cfg.NUM_GPUS: |
|
all_preds = [pred.cpu() for pred in all_preds] |
|
all_labels = [label.cpu() for label in all_labels] |
|
writer.plot_eval( |
|
preds=all_preds, labels=all_labels, global_step=cur_epoch |
|
) |
|
|
|
val_meter.reset() |
|
|
|
|
|
def calculate_and_update_precise_bn(loader, model, num_iters=200, use_gpu=True): |
|
""" |
|
Update the stats in bn layers by calculate the precise stats. |
|
Args: |
|
loader (loader): data loader to provide training data. |
|
model (model): model to update the bn stats. |
|
num_iters (int): number of iterations to compute and update the bn stats. |
|
use_gpu (bool): whether to use GPU or not. |
|
""" |
|
|
|
def _gen_loader(): |
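        # Strip everything but the inputs from each batch; update_bn_stats only
        # needs forward passes to re-estimate the BN running mean/variance.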
|
for inputs, *_ in loader: |
|
if use_gpu: |
|
if isinstance(inputs, (list,)): |
|
for i in range(len(inputs)): |
|
inputs[i] = inputs[i].cuda(non_blocking=True) |
|
else: |
|
inputs = inputs.cuda(non_blocking=True) |
|
yield inputs |
|
|
|
|
|
update_bn_stats(model, _gen_loader(), num_iters) |
|
|
|
|
|
def build_trainer(cfg): |
|
""" |
|
Build training model and its associated tools, including optimizer, |
|
dataloaders and meters. |
|
Args: |
|
cfg (CfgNode): configs. Details can be found in |
|
slowfast/config/defaults.py |
|
Returns: |
|
model (nn.Module): training model. |
|
optimizer (Optimizer): optimizer. |
|
train_loader (DataLoader): training data loader. |
|
val_loader (DataLoader): validatoin data loader. |
|
precise_bn_loader (DataLoader): training data loader for computing |
|
precise BN. |
|
train_meter (TrainMeter): tool for measuring training stats. |
|
val_meter (ValMeter): tool for measuring validation stats. |
|
""" |
|
|
|
model = build_model(cfg) |
|
if du.is_master_proc() and cfg.LOG_MODEL_INFO: |
|
misc.log_model_info(model, cfg, use_train_input=True) |
|
|
|
|
|
optimizer = optim.construct_optimizer(model, cfg) |
|
|
|
|
|
train_loader = loader.construct_loader(cfg, "train") |
|
val_loader = loader.construct_loader(cfg, "val") |
|
|
|
precise_bn_loader = loader.construct_loader( |
|
cfg, "train", is_precise_bn=True |
|
) |
|
|
|
train_meter = TrainMeter(len(train_loader), cfg) |
|
val_meter = ValMeter(len(val_loader), cfg) |
|
|
|
return ( |
|
model, |
|
optimizer, |
|
train_loader, |
|
val_loader, |
|
precise_bn_loader, |
|
train_meter, |
|
val_meter, |
|
) |
|
|
|
|
|
def train(cfg): |
|
""" |
|
Train a video model for many epochs on train set and evaluate it on val set. |
|
Args: |
|
cfg (CfgNode): configs. Details can be found in |
|
slowfast/config/defaults.py |
|
""" |
|
|
|
du.init_distributed_training(cfg) |
|
|
|
np.random.seed(cfg.RNG_SEED) |
|
torch.manual_seed(cfg.RNG_SEED) |
|
|
|
|
|
logging.setup_logging(cfg.OUTPUT_DIR) |
|
|
|
|
|
multigrid = None |
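    # Multigrid training (if enabled) varies the batch size and input
    # resolution on a schedule; it may rewrite parts of the config.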
|
if cfg.MULTIGRID.LONG_CYCLE or cfg.MULTIGRID.SHORT_CYCLE: |
|
multigrid = MultigridSchedule() |
|
cfg = multigrid.init_multigrid(cfg) |
|
if cfg.MULTIGRID.LONG_CYCLE: |
|
cfg, _ = multigrid.update_long_cycle(cfg, cur_epoch=0) |
|
|
|
logger.info("Train with config:") |
|
logger.info(pprint.pformat(cfg)) |
|
|
|
|
|
model = build_model(cfg) |
|
if du.is_master_proc() and cfg.LOG_MODEL_INFO: |
|
misc.log_model_info(model, cfg, use_train_input=True) |
|
|
|
|
|
optimizer = optim.construct_optimizer(model, cfg) |
|
|
|
|
|
    if not cfg.TRAIN.FINETUNE:
        # Resume: restore model and optimizer state and continue from the
        # epoch stored in the last checkpoint.
        start_epoch = cu.load_train_checkpoint(cfg, model, optimizer)
    else:
        # Finetune: load pretrained weights only and start from epoch 0.
        start_epoch = 0
        cu.load_checkpoint(cfg.TRAIN.CHECKPOINT_FILE_PATH, model)
|
|
|
|
|
train_loader = loader.construct_loader(cfg, "train") |
|
val_loader = loader.construct_loader(cfg, "val") |
|
|
|
precise_bn_loader = ( |
|
loader.construct_loader(cfg, "train", is_precise_bn=True) |
|
if cfg.BN.USE_PRECISE_STATS |
|
else None |
|
) |
|
|
|
train_meter = TrainMeter(len(train_loader), cfg) |
|
val_meter = ValMeter(len(val_loader), cfg) |
|
|
|
|
|
if cfg.TENSORBOARD.ENABLE and du.is_master_proc( |
|
cfg.NUM_GPUS * cfg.NUM_SHARDS |
|
): |
|
writer = tb.TensorboardWriter(cfg) |
|
else: |
|
writer = None |
|
|
|
|
|
logger.info("Start epoch: {}".format(start_epoch + 1)) |
|
|
|
for cur_epoch in range(start_epoch, cfg.SOLVER.MAX_EPOCH): |
|
if cfg.MULTIGRID.LONG_CYCLE: |
|
cfg, changed = multigrid.update_long_cycle(cfg, cur_epoch) |
|
if changed: |
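                # The long cycle changed the input shape/batch size: rebuild
                # the model, loaders, and meters, then reload weights from the
                # last checkpoint so training continues seamlessly.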
|
( |
|
model, |
|
optimizer, |
|
train_loader, |
|
val_loader, |
|
precise_bn_loader, |
|
train_meter, |
|
val_meter, |
|
) = build_trainer(cfg) |
|
|
|
|
|
if cu.has_checkpoint(cfg.OUTPUT_DIR): |
|
last_checkpoint = cu.get_last_checkpoint(cfg.OUTPUT_DIR) |
|
assert "{:05d}.pyth".format(cur_epoch) in last_checkpoint |
|
else: |
|
last_checkpoint = cfg.TRAIN.CHECKPOINT_FILE_PATH |
|
logger.info("Load from {}".format(last_checkpoint)) |
|
cu.load_checkpoint( |
|
last_checkpoint, model, cfg.NUM_GPUS > 1, optimizer |
|
) |
|
|
|
|
|
loader.shuffle_dataset(train_loader, cur_epoch) |
|
|
|
|
|
train_epoch( |
|
train_loader, model, optimizer, train_meter, cur_epoch, cfg, writer |
|
) |
|
|
|
is_checkp_epoch = cu.is_checkpoint_epoch( |
|
cfg, |
|
cur_epoch, |
|
None if multigrid is None else multigrid.schedule, |
|
) |
|
is_eval_epoch = misc.is_eval_epoch( |
|
cfg, cur_epoch, None if multigrid is None else multigrid.schedule |
|
) |
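        # Re-estimate BN running statistics over a fixed number of batches
        # (precise BN) before checkpointing or evaluating.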
|
|
|
|
|
if ( |
|
(is_checkp_epoch or is_eval_epoch) |
|
and cfg.BN.USE_PRECISE_STATS |
|
and len(get_bn_modules(model)) > 0 |
|
): |
|
calculate_and_update_precise_bn( |
|
precise_bn_loader, |
|
model, |
|
min(cfg.BN.NUM_BATCHES_PRECISE, len(precise_bn_loader)), |
|
cfg.NUM_GPUS > 0, |
|
) |
|
_ = misc.aggregate_sub_bn_stats(model) |
|
|
|
|
|
if is_checkp_epoch: |
|
cu.save_checkpoint(cfg.OUTPUT_DIR, model, optimizer, cur_epoch, cfg) |
|
|
|
if is_eval_epoch: |
|
eval_epoch(val_loader, model, val_meter, cur_epoch, cfg, writer) |
|
|
|
if writer is not None: |
|
writer.close() |
|
|