# MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr_pretrain.py
from loguru import logger
import torch
import pytorch_lightning as pl
from matplotlib import pyplot as plt
plt.switch_backend('agg')

from src.xoftr import XoFTR_Pretrain
from src.losses.xoftr_loss_pretrain import XoFTRLossPretrain
from src.optimizers import build_optimizer, build_scheduler
from src.utils.plotting import make_mae_figures
from src.utils.comm import all_gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler
from src.utils.pretrain_utils import generate_random_masks, get_target


class PL_XoFTR_Pretrain(pl.LightningModule):
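    """Lightning module for MAE-style XoFTR pretraining (masked patch reconstruction)."""
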
    def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
        """
        TODO:
            - use the new version of PL logging API.
        """
        super().__init__()
        # Misc
        self.config = config  # full config
        _config = lower_config(self.config)
        self.xoftr_cfg = lower_config(_config['xoftr'])
        self.profiler = profiler or PassThroughProfiler()
        self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)

        # generator to create the same masks for validation
        self.val_seed = self.config.PRETRAIN.VAL_SEED
        self.val_generator = torch.Generator(device="cuda").manual_seed(self.val_seed)
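        # margins used by generate_random_masks when sampling random patch masks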
        self.mae_margins = config.PRETRAIN.MAE_MARGINS

        # Matcher: XoFTR
        self.matcher = XoFTR_Pretrain(config=_config['xoftr'])
        self.loss = XoFTRLossPretrain(_config)

        # Pretrained weights
        if pretrained_ckpt:
            state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
            self.matcher.load_state_dict(state_dict, strict=False)
            logger.info(f"Loaded '{pretrained_ckpt}' as pretrained checkpoint")

        # Testing
        self.dump_dir = dump_dir

    def configure_optimizers(self):
        # FIXME: The scheduler does not work properly when resuming with `--resume_from_checkpoint`
        optimizer = build_optimizer(self, self.config)
        scheduler = build_scheduler(self.config, optimizer)
        return [optimizer], [scheduler]

    def optimizer_step(
            self, epoch, batch_idx, optimizer, optimizer_idx,
            optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
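        """Manually step the optimizer, applying learning-rate warm-up first."""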
        # learning rate warm up
        warmup_step = self.config.TRAINER.WARMUP_STEP
        if self.trainer.global_step < warmup_step:
            if self.config.TRAINER.WARMUP_TYPE == 'linear':
                base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + \
                    (self.trainer.global_step / warmup_step) * \
                    abs(self.config.TRAINER.TRUE_LR - base_lr)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr
            elif self.config.TRAINER.WARMUP_TYPE == 'constant':
                pass
            else:
                raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')

        # update params
        optimizer.step(closure=optimizer_closure)
        optimizer.zero_grad()

    def _trainval_inference(self, batch, generator=None):
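        """Mask random patches, run the matcher, and compute the reconstruction loss."""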
        generate_random_masks(batch,
                              patch_size=self.config.PRETRAIN.PATCH_SIZE,
                              mask_ratio=self.config.PRETRAIN.MASK_RATIO,
                              generator=generator,
                              margins=self.mae_margins)
with self.profiler.profile("XoFTR"):
self.matcher(batch)
with self.profiler.profile("Compute losses"):
# Create target pacthes to reconstruct
get_target(batch)
self.loss(batch)

    def training_step(self, batch, batch_idx):
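        """One forward/loss pass; scalars and MAE figures are logged on rank 0 only."""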
        self._trainval_inference(batch)

        # logging
        if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
            # scalars
            for k, v in batch['loss_scalars'].items():
                self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
                if self.config.TRAINER.USE_WANDB:
                    self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)

            if self.config.TRAINER.ENABLE_PLOTTING:
                figures = make_mae_figures(batch)
                for i, figure in enumerate(figures):
                    self.logger[0].experiment.add_figure(
                        f'train_mae/node_{self.trainer.global_rank}-device_{self.device.index}-batch_{i}',
                        figure, self.global_step)

        return {'loss': batch['loss']}

    def training_epoch_end(self, outputs):
        avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
        if self.trainer.global_rank == 0:
            self.logger[0].experiment.add_scalar(
                'train/avg_loss_on_epoch', avg_loss,
                global_step=self.current_epoch)
            if self.config.TRAINER.USE_WANDB:
                self.logger[1].log_metrics(
                    {'train/avg_loss_on_epoch': avg_loss},
                    self.current_epoch)

    def validation_step(self, batch, batch_idx):
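        # use the seeded generator so validation masks are identical across runs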
        self._trainval_inference(batch, self.val_generator)

        val_plot_interval = max(self.trainer.num_val_batches[0]
                                // (self.trainer.num_gpus * self.n_vals_plot), 1)
        figures = []
        if batch_idx % val_plot_interval == 0:
            figures = make_mae_figures(batch)

        return {
            'loss_scalars': batch['loss_scalars'],
            'figures': figures,
        }

    def validation_epoch_end(self, outputs):
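        # re-seed so each validation epoch reproduces identical masks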
        self.val_generator.manual_seed(self.val_seed)

        # handle multiple validation sets
        multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
        for valset_idx, outputs in enumerate(multi_outputs):
            # since pl performs the sanity check at the very beginning of training
            cur_epoch = self.trainer.current_epoch
            if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
                cur_epoch = -1

            # 1. loss_scalars: dict of list, on cpu
            _loss_scalars = [o['loss_scalars'] for o in outputs]
            loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
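
            # 2. figures: flatten the per-batch figure lists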
            _figures = [o['figures'] for o in outputs]
            figures = [item for sublist in _figures for item in sublist]

            # tensorboard records only on rank 0
            if self.trainer.global_rank == 0:
                for k, v in loss_scalars.items():
                    mean_v = torch.stack(v).mean()
                    self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
                    if self.config.TRAINER.USE_WANDB:
                        self.logger[1].log_metrics({f'val_{valset_idx}/avg_{k}': mean_v}, cur_epoch)

                for plot_idx, fig in enumerate(figures):
                    self.logger[0].experiment.add_figure(
                        f'val_mae_{valset_idx}/pair-{plot_idx}', fig, cur_epoch, close=True)
        plt.close('all')