# MINIMA/third_party/XoFTR/src/lightning/lightning_xoftr.py
from collections import defaultdict
import pprint
from loguru import logger
from pathlib import Path
import torch
import numpy as np
import pytorch_lightning as pl
from matplotlib import pyplot as plt
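# use a non-interactive backend so figures can be rendered during headless training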
plt.switch_backend('agg')
from src.xoftr import XoFTR
from src.xoftr.utils.supervision import compute_supervision_coarse, compute_supervision_fine
from src.losses.xoftr_loss import XoFTRLoss
from src.optimizers import build_optimizer, build_scheduler
from src.utils.metrics import (
compute_symmetrical_epipolar_errors,
compute_pose_errors,
aggregate_metrics
)
from src.utils.plotting import make_matching_figures
from src.utils.comm import gather, all_gather
from src.utils.misc import lower_config, flattenList
from src.utils.profiler import PassThroughProfiler


class PL_XoFTR(pl.LightningModule):
def __init__(self, config, pretrained_ckpt=None, profiler=None, dump_dir=None):
"""
TODO:
- use the new version of PL logging API.
"""
super().__init__()
# Misc
self.config = config # full config
_config = lower_config(self.config)
self.xoftr_cfg = lower_config(_config['xoftr'])
self.profiler = profiler or PassThroughProfiler()
self.n_vals_plot = max(config.TRAINER.N_VAL_PAIRS_TO_PLOT // config.TRAINER.WORLD_SIZE, 1)
# Matcher: XoFTR
self.matcher = XoFTR(config=_config['xoftr'])
self.loss = XoFTRLoss(_config)
# Pretrained weights
if pretrained_ckpt:
state_dict = torch.load(pretrained_ckpt, map_location='cpu')['state_dict']
self.matcher.load_state_dict(state_dict, strict=False)
logger.info(f"Load \'{pretrained_ckpt}\' as pretrained checkpoint")
            # report which matcher parameters were (not) covered by the checkpoint,
            # since strict=False silently skips missing / unexpected keys
            for name, _ in self.matcher.named_parameters():
                if name in state_dict:
                    print("in ckpt: ", name)
                else:
                    print("out ckpt: ", name)
# Testing
self.dump_dir = dump_dir

    def configure_optimizers(self):
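        """Build the optimizer and LR scheduler via the repo's `src.optimizers` helpers."""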
        # FIXME: the scheduler does not work properly when resuming via `--resume_from_checkpoint`
optimizer = build_optimizer(self, self.config)
scheduler = build_scheduler(self.config, optimizer)
return [optimizer], [scheduler]

    def optimizer_step(
self, epoch, batch_idx, optimizer, optimizer_idx,
optimizer_closure, on_tpu, using_native_amp, using_lbfgs):
# learning rate warm up
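        # Linear warm-up ramps the LR from WARMUP_RATIO * TRUE_LR up to TRUE_LR over
        # the first WARMUP_STEP optimizer steps. A worked example with hypothetical
        # config values (WARMUP_RATIO=0.1, TRUE_LR=4e-3, WARMUP_STEP=1000):
        #   step 0:     lr = 4e-4
        #   step 500:   lr = 4e-4 + 0.5 * (4e-3 - 4e-4) = 2.2e-3
        #   step 1000+: warm-up ends and the scheduler from configure_optimizers takes over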
warmup_step = self.config.TRAINER.WARMUP_STEP
if self.trainer.global_step < warmup_step:
if self.config.TRAINER.WARMUP_TYPE == 'linear':
base_lr = self.config.TRAINER.WARMUP_RATIO * self.config.TRAINER.TRUE_LR
                lr = base_lr + \
                    (self.trainer.global_step / warmup_step) * \
                    abs(self.config.TRAINER.TRUE_LR - base_lr)
for pg in optimizer.param_groups:
pg['lr'] = lr
elif self.config.TRAINER.WARMUP_TYPE == 'constant':
pass
else:
raise ValueError(f'Unknown lr warm-up strategy: {self.config.TRAINER.WARMUP_TYPE}')
# update params
optimizer.step(closure=optimizer_closure)
optimizer.zero_grad()

    def _trainval_inference(self, batch):
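        """Shared train/val step: coarse supervision -> XoFTR forward -> fine supervision -> loss.

        All intermediate results and losses are written into `batch` in place.
        """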
with self.profiler.profile("Compute coarse supervision"):
compute_supervision_coarse(batch, self.config)
with self.profiler.profile("XoFTR"):
self.matcher(batch)
with self.profiler.profile("Compute fine supervision"):
compute_supervision_fine(batch, self.config)
with self.profiler.profile("Compute losses"):
self.loss(batch)

    def _compute_metrics(self, batch):
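        """Compute per-pair epipolar and relative-pose errors and pack them, with unique pair identifiers, for epoch-end aggregation."""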
with self.profiler.profile("Copmute metrics"):
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
compute_pose_errors(batch, self.config) # compute R_errs, t_errs, pose_errs for each pair
rel_pair_names = list(zip(*batch['pair_names']))
bs = batch['image0'].size(0)
metrics = {
# to filter duplicate pairs caused by DistributedSampler
'identifiers': ['#'.join(rel_pair_names[b]) for b in range(bs)],
'epi_errs': [batch['epi_errs'][batch['m_bids'] == b].cpu().numpy() for b in range(bs)],
'R_errs': batch['R_errs'],
't_errs': batch['t_errs'],
'inliers': batch['inliers']}
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
metrics.update({'scene_id': batch['scene_id']})
ret_dict = {'metrics': metrics}
return ret_dict, rel_pair_names

    def training_step(self, batch, batch_idx):
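        """Run inference and loss for one batch; rank 0 periodically logs scalars (TensorBoard and optionally W&B) and matching figures."""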
self._trainval_inference(batch)
# logging
if self.trainer.global_rank == 0 and self.global_step % self.trainer.log_every_n_steps == 0:
# scalars
for k, v in batch['loss_scalars'].items():
self.logger[0].experiment.add_scalar(f'train/{k}', v, self.global_step)
if self.config.TRAINER.USE_WANDB:
self.logger[1].log_metrics({f'train/{k}': v}, self.global_step)
# figures
if self.config.TRAINER.ENABLE_PLOTTING:
compute_symmetrical_epipolar_errors(batch) # compute epi_errs for each match
figures = make_matching_figures(batch, self.config, self.config.TRAINER.PLOT_MODE)
for k, v in figures.items():
self.logger[0].experiment.add_figure(f'train_match/{k}', v, self.global_step)
return {'loss': batch['loss']}

    def training_epoch_end(self, outputs):
avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
if self.trainer.global_rank == 0:
self.logger[0].experiment.add_scalar(
'train/avg_loss_on_epoch', avg_loss,
global_step=self.current_epoch)
if self.config.TRAINER.USE_WANDB:
self.logger[1].log_metrics(
{'train/avg_loss_on_epoch': avg_loss},
self.current_epoch)

    def validation_step(self, batch, batch_idx):
# no loss calculation for VisTir during val
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
with self.profiler.profile("XoFTR"):
self.matcher(batch)
else:
self._trainval_inference(batch)
ret_dict, _ = self._compute_metrics(batch)
val_plot_interval = max(self.trainer.num_val_batches[0] // self.n_vals_plot, 1)
figures = {self.config.TRAINER.PLOT_MODE: []}
if batch_idx % val_plot_interval == 0:
figures = make_matching_figures(batch, self.config, mode=self.config.TRAINER.PLOT_MODE, ret_dict=ret_dict)
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
return {
**ret_dict,
'figures': figures,
}
else:
return {
**ret_dict,
'loss_scalars': batch['loss_scalars'],
'figures': figures,
}

    def validation_epoch_end(self, outputs):
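        """Aggregate validation results across ranks and validation sets.

        For VisTir, metrics are aggregated per scene and then averaged over scenes,
        so scenes with more images do not dominate the final AUCs.
        """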
# handle multiple validation sets
multi_outputs = [outputs] if not isinstance(outputs[0], (list, tuple)) else outputs
multi_val_metrics = defaultdict(list)
for valset_idx, outputs in enumerate(multi_outputs):
            # pl performs a sanity check at the very beginning of training,
            # so log those results under a dummy epoch (-1)
            cur_epoch = self.trainer.current_epoch
            if not self.trainer.resume_from_checkpoint and self.trainer.running_sanity_check:
                cur_epoch = -1
if self.config.DATASET.VAL_DATA_SOURCE == "VisTir":
metrics_per_scene = {}
                for o in outputs:
                    scene_id = o['metrics']['scene_id'][0]
                    metrics_per_scene.setdefault(scene_id, []).append(o['metrics'])
aucs_per_scene = {}
for scene_id in metrics_per_scene.keys():
                    # val metrics: dict of list, numpy
_metrics = metrics_per_scene[scene_id]
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
                    # NOTE: all ranks need to run `aggregate_metrics`, but only rank 0 logs
val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
aucs_per_scene[scene_id] = val_metrics
                # average the per-scene metrics, since the number of images differs per scene
val_metrics_4tb = {}
for thr in [5, 10, 20]:
temp = []
for scene_id in metrics_per_scene.keys():
temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
val_metrics_4tb[f'auc@{thr}'] = float(np.array(temp, dtype=float).mean())
temp = []
for scene_id in metrics_per_scene.keys():
temp.append(aucs_per_scene[scene_id][f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'])
val_metrics_4tb[f'prec@{self.config.TRAINER.EPI_ERR_THR:.0e}'] = float(np.array(temp, dtype=float).mean())
else:
# 1. loss_scalars: dict of list, on cpu
_loss_scalars = [o['loss_scalars'] for o in outputs]
loss_scalars = {k: flattenList(all_gather([_ls[k] for _ls in _loss_scalars])) for k in _loss_scalars[0]}
# 2. val metrics: dict of list, numpy
_metrics = [o['metrics'] for o in outputs]
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
                # NOTE: all ranks need to run `aggregate_metrics`, but only rank 0 logs
val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
for thr in [5, 10, 20]:
multi_val_metrics[f'auc@{thr}'].append(val_metrics_4tb[f'auc@{thr}'])
# 3. figures
_figures = [o['figures'] for o in outputs]
figures = {k: flattenList(gather(flattenList([_me[k] for _me in _figures]))) for k in _figures[0]}
# tensorboard records only on rank 0
if self.trainer.global_rank == 0:
if self.config.DATASET.VAL_DATA_SOURCE != "VisTir":
for k, v in loss_scalars.items():
mean_v = torch.stack(v).mean()
                        self.logger[0].experiment.add_scalar(f'val_{valset_idx}/avg_{k}', mean_v, global_step=cur_epoch)
for k, v in val_metrics_4tb.items():
self.logger[0].experiment.add_scalar(f"metrics_{valset_idx}/{k}", v, global_step=cur_epoch)
if self.config.TRAINER.USE_WANDB:
self.logger[1].log_metrics({f"metrics_{valset_idx}/{k}": v}, cur_epoch)
                for k, v in figures.items():
                    for plot_idx, fig in enumerate(v):
                        self.logger[0].experiment.add_figure(
                            f'val_match_{valset_idx}/{k}/pair-{plot_idx}', fig, cur_epoch, close=True)
plt.close('all')
for thr in [5, 10, 20]:
# log on all ranks for ModelCheckpoint callback to work properly
self.log(f'auc@{thr}', torch.tensor(np.mean(multi_val_metrics[f'auc@{thr}']))) # ckpt monitors on this

    def test_step(self, batch, batch_idx):
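        """Match one test batch, compute metrics, and (if `dump_dir` is set) collect per-pair predictions for dumping at epoch end."""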
with self.profiler.profile("XoFTR"):
self.matcher(batch)
ret_dict, rel_pair_names = self._compute_metrics(batch)
with self.profiler.profile("dump_results"):
if self.dump_dir is not None:
# dump results for further analysis
keys_to_save = {'mkpts0_f', 'mkpts1_f', 'mconf_f', 'epi_errs'}
pair_names = list(zip(*batch['pair_names']))
bs = batch['image0'].shape[0]
dumps = []
for b_id in range(bs):
item = {}
mask = batch['m_bids'] == b_id
item['pair_names'] = pair_names[b_id]
item['identifier'] = '#'.join(rel_pair_names[b_id])
if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
item['scene_id'] = batch['scene_id']
item['K0'] = batch['K0'][b_id].cpu().numpy()
item['K1'] = batch['K1'][b_id].cpu().numpy()
item['dist0'] = batch['dist0'][b_id].cpu().numpy()
item['dist1'] = batch['dist1'][b_id].cpu().numpy()
for key in keys_to_save:
item[key] = batch[key][mask].cpu().numpy()
for key in ['R_errs', 't_errs', 'inliers']:
item[key] = batch[key][b_id]
dumps.append(item)
ret_dict['dumps'] = dumps
return ret_dict

    def test_epoch_end(self, outputs):
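        """Gather metrics from all ranks, aggregate them (per scene for VisTir), log the summary, and save dumped predictions to `dump_dir`."""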
if self.config.DATASET.TEST_DATA_SOURCE == "VisTir":
# metrics: dict of list, numpy
metrics_per_scene = {}
            for o in outputs:
                scene_id = o['metrics']['scene_id'][0]
                metrics_per_scene.setdefault(scene_id, []).append(o['metrics'])
aucs_per_scene = {}
for scene_id in metrics_per_scene.keys():
                # val metrics: dict of list, numpy
_metrics = metrics_per_scene[scene_id]
metrics = {k: flattenList(all_gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
                # NOTE: all ranks need to run `aggregate_metrics`, but only rank 0 logs
val_metrics = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
aucs_per_scene[scene_id] = val_metrics
            # average the per-scene metrics, since the number of images differs per scene
val_metrics_4tb = {}
for thr in [5, 10, 20]:
temp = []
for scene_id in metrics_per_scene.keys():
temp.append(aucs_per_scene[scene_id][f'auc@{thr}'])
val_metrics_4tb[f'auc@{thr}'] = np.array(temp, dtype=float).mean()
else:
# metrics: dict of list, numpy
_metrics = [o['metrics'] for o in outputs]
metrics = {k: flattenList(gather(flattenList([_me[k] for _me in _metrics]))) for k in _metrics[0]}
# [{key: [{...}, *#bs]}, *#batch]
if self.dump_dir is not None:
Path(self.dump_dir).mkdir(parents=True, exist_ok=True)
_dumps = flattenList([o['dumps'] for o in outputs]) # [{...}, #bs*#batch]
dumps = flattenList(gather(_dumps)) # [{...}, #proc*#bs*#batch]
logger.info(f'Prediction and evaluation results will be saved to: {self.dump_dir}')
if self.trainer.global_rank == 0:
print(self.profiler.summary())
val_metrics_4tb = aggregate_metrics(metrics, self.config.TRAINER.EPI_ERR_THR)
logger.info('\n' + pprint.pformat(val_metrics_4tb))
if self.dump_dir is not None:
np.save(Path(self.dump_dir) / 'XoFTR_pred_eval', dumps)
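

# A minimal usage sketch (for illustration only, not one of this repo's entry
# points; `config` must be a config node shaped like the repo's training configs,
# and the dataloaders and checkpoint path are hypothetical placeholders):
#
#   model = PL_XoFTR(config, pretrained_ckpt="weights/xoftr.ckpt")
#   trainer = pl.Trainer(gpus=1, max_epochs=30)
#   trainer.fit(model, train_dataloader, val_dataloader)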