# -*- coding: utf-8 -*-
# @Author : xuelun
"""Entry point: train or test a model on the selected benchmark datasets with PyTorch Lightning."""

import cv2
import math
import sys
import uuid
import pytorch_lightning as pl

from pathlib import Path
from os.path import join, exists
from argparse import ArgumentParser
from yacs.config import CfgNode as CN
from pytorch_lightning.plugins import DDPPlugin
from pytorch_lightning.loggers import TensorBoardLogger

import tools as com
from trainer import Trainer
from networks.loftr.configs.outdoor import trainer_cfg, network_cfg
from networks.loftr.config import get_cfg_defaults as get_network_cfg
from trainer.config import get_cfg_defaults as get_trainer_cfg
from trainer.debug import get_cfg_defaults as get_debug_cfg

from datasets.data import MultiSceneDataModule

from datasets import gl3d
from datasets import gtasfm
from datasets import multifov
from datasets import blendedmvs
from datasets import iclnuim
from datasets import scenenet
from datasets import eth3d
from datasets import kitti
from datasets import robotcar

# CLI dataset name -> dataset config object. The selected entries are merged
# into `dcfg`, whose test entry is indexed as ['DATASET']['TESTS'] below.
Benchmarks = dict(
    GL3D=gl3d.cfg,
    GTASfM=gtasfm.cfg,
    MultiFoV=multifov.cfg,
    BlendedMVS=blendedmvs.cfg,
    ICLNUIM=iclnuim.cfg,
    SceneNet=scenenet.cfg,
    ETH3DO=eth3d.cfgO,
    ETH3DI=eth3d.cfgI,
    KITTI=kitti.cfg,
    RobotcarNight=robotcar.night,
    RobotcarSeason=robotcar.season,
    RobotcarWeather=robotcar.weather,
)

# CLI name -> OpenCV robust-estimation flag, stored as TRAINER.POSE_ESTIMATION_METHOD.
RANSACs = dict(
    RANSAC=cv2.RANSAC,
    FAST=cv2.USAC_FAST,
    MAGSAC=cv2.USAC_MAGSAC,
    PROSAC=cv2.USAC_PROSAC,
    DEFAULT=cv2.USAC_DEFAULT,
    ACCURATE=cv2.USAC_ACCURATE,
    PARALLEL=cv2.USAC_PARALLEL,
)

# Allowed values for --weight. The checkpoint is loaded from
# weights/<weight>_<version>.ckpt, except 'root_sift', which uses no checkpoint.
MODEL_ZOO = ['gim_dkm', 'gim_loftr', 'gim_lightglue', 'root_sift']


def _build_parser():
    """Return the ArgumentParser describing every command-line option of this script."""
    parser = ArgumentParser()

    # Project args
    parser.add_argument('--trains', type=str, choices=set(Benchmarks), nargs='+', default=[],
                        help=f'Train Datasets: {set(Benchmarks)}', )
    parser.add_argument('--valids', type=str, choices=set(Benchmarks), nargs='+', default=[],
                        help=f'Valid Datasets: {set(Benchmarks)}', )
    parser.add_argument('--tests', type=str, choices=set(Benchmarks), default=None,
                        help=f'Test Datasets: {set(Benchmarks)}', )
    parser.add_argument('--debug', action='store_true', help='For debug mode')

    # Loader args
    parser.add_argument('--batch_size', type=int, default=12,
                        help='input batch size for training and validation (default: 12)')
    parser.add_argument('--threads', type=int, default=3, help='Number of threads (default: 3)')

    # Trainer args (picked up by pl.Trainer.from_argparse_args where names match)
    parser.add_argument('--gpus', type=int, default=1, help='GPU numbers')
    parser.add_argument('--num_nodes', type=int, default=1, help='Cluster node numbers')
    parser.add_argument('--max_epochs', type=int, default=30, help='Training epochs (default: 30)')
    parser.add_argument('--git', type=str, default='xxxxxx', help='Git ID', )
    parser.add_argument('--weight', type=str, default=None, choices=MODEL_ZOO, required=True,
                        help='Pretrained model weight', )

    # Hyper-parameters
    parser.add_argument('--img_size', type=int, default=9999, help='Image Size')
    parser.add_argument('--lr', type=float, default=8e-3, help='Learning rate')

    # Runtime args
    parser.add_argument('--test', action='store_true', help='Testing')
    parser.add_argument('--viz', action='store_true', help='Testing')
    parser.add_argument('--max_samples', type=int, default=None, help='Max Samples in Testing', )
    parser.add_argument('--min_score', type=float, default=0.0, help='Min Score in Testing', )
    parser.add_argument('--max_score', type=float, default=1.0, help='Max Score in Testing', )
    parser.add_argument('--ransac_threshold', type=float, default=0.5, help='RANSAC Threshold', )
    parser.add_argument('--ransac', type=str, choices=set(RANSACs), default='MAGSAC',
                        help=f'RANSAC Methods: {set(RANSACs)}', )
    parser.add_argument('--version', type=str, default='AUC', help='Model version', )
    return parser


if __name__ == '__main__':
    # ------------
    # Hyperparameters
    # ------------
    parser = _build_parser()
    args = parser.parse_args()
    # Testing indexes the dataset config by args.tests. Without --tests the
    # lookups below would hit a None dataset entry and crash obscurely.
    if args.test and args.tests is None:
        parser.error('--test requires --tests to name the benchmark to evaluate')

    # ------------
    # Project config
    # ------------
    pcfg = CN(vars(args))
    tcfg = get_trainer_cfg()
    ncfg = get_network_cfg()
    # When --tests is omitted, args.tests is None and becomes a None-valued
    # entry. It is kept as-is in case downstream consumers of dcfg look it up.
    dcfg = CN({x: Benchmarks.get(x, None) for x in set(args.trains + args.valids + [args.tests])})

    tcfg.merge_from_other_cfg(trainer_cfg)
    if args.debug:
        tcfg.merge_from_other_cfg(get_debug_cfg())
    ncfg.merge_from_other_cfg(network_cfg)
    dcfg.DF = ncfg.LOFTR.RESOLUTION[0]

    # load weight: weights/<weight>_<version>.ckpt, or none for root_sift
    ncfg.LOFTR.WEIGHT = join('weights', args.weight + '_' + args.version + '.ckpt')
    if args.weight == 'root_sift':
        ncfg.LOFTR.WEIGHT = None

    # ------------
    # Testing setting
    # ------------
    if args.test:
        # Nested config node; writes through to dcfg.
        tests_cfg = dcfg[args.tests]['DATASET']['TESTS']
        if args.max_samples is not None:
            tests_cfg['MAX_SAMPLES'] = args.max_samples
        if args.min_score is not None:
            tests_cfg['MIN_OVERLAP_SCORE'] = args.min_score
        if args.max_score is not None:
            tests_cfg['MAX_OVERLAP_SCORE'] = args.max_score

    # ------------
    # Update Trainer Config
    # ------------
    TRAINER = tcfg.TRAINER
    # Effective batch across all DDP processes: GPUs per node * nodes * per-GPU batch.
    TRAINER.TRUE_BATCH_SIZE = args.gpus * args.num_nodes * args.batch_size
    # Scale LR linearly with the effective batch relative to CANONICAL_BS,
    # and shorten warmup by the same factor.
    TRAINER.SCALING = _scaling = TRAINER.TRUE_BATCH_SIZE / TRAINER.CANONICAL_BS
    TRAINER.CANONICAL_LR = args.lr
    TRAINER.TRUE_LR = TRAINER.CANONICAL_LR * _scaling
    TRAINER.WARMUP_STEP = math.floor(TRAINER.WARMUP_STEP / _scaling)
    TRAINER.RANSAC_PIXEL_THR = args.ransac_threshold
    TRAINER.POSE_ESTIMATION_METHOD = RANSACs[args.ransac]

    # ------------
    # Logger (TensorBoard)
    # ------------
    wid = str(uuid.uuid1()).split('-')[0]
    com.hint('ID = {}'.format(wid))
    logger = TensorBoardLogger('tensorboard', name='test', version='test')

    # ------------
    # reproducible
    # ------------
    pl.seed_everything(TRAINER.SEED, workers=True)

    # ------------
    # data loader
    # ------------
    dm = MultiSceneDataModule(args, dcfg)

    # ------------
    # model
    # ------------
    trainer = Trainer(pcfg, tcfg, dcfg, ncfg)

    # ------------
    # training
    # ------------
    fitter = pl.Trainer.from_argparse_args(
        args,
        # ddp
        sync_batchnorm=True,
        strategy=DDPPlugin(find_unused_parameters=False),
        # reproducible
        benchmark=True,
        deterministic=False,
        # logger
        enable_checkpointing=False,
        logger=logger,
        log_every_n_steps=TRAINER.LOG_INTERVAL,
        # prepare
        weights_summary='top',
        val_check_interval=TRAINER.VAL_CHECK_INTERVAL,
        num_sanity_val_steps=TRAINER.NUM_SANITY_VAL_STEPS,
        limit_train_batches=TRAINER.LIMIT_TRAIN_BATCHES,
        limit_val_batches=TRAINER.LIMIT_VALID_BATCHES,
        # faster training -- NOTE: mixed precision left disabled, see
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/5558
        # amp_level=TRAINER.AMP_LEVEL,
        # amp_backend=TRAINER.AMP_BACKEND,
        # precision=TRAINER.PRECISION,
        # better fine-tune
        gradient_clip_val=TRAINER.GRADIENT_CLIP_VAL,
        gradient_clip_algorithm=TRAINER.GRADIENT_CLIP_ALGORITHM,
    )

    # ------------
    # Fitting
    # ------------
    if args.test:
        scene = Path(dcfg[pcfg["tests"]]['DATASET']['TESTS']['LIST_PATH']).stem.split('_')[0]
        path = f"dump/zeb/[T] {pcfg.weight} {scene:>15} {pcfg.version}.txt"
        # An existing dump file for this (weight, scene, version) short-circuits the run.
        if exists(path):
            print(f"{path} already exists")
            sys.exit(0)
        Path(path).parent.mkdir(parents=True, exist_ok=True)
        fitter.test(trainer, datamodule=dm)
    else:
        fitter.fit(trainer, datamodule=dm)