import argparse import os import time import sys import glob from pathlib import Path from RDD.RDD_helper import RDD_helper import torch.distributed def parse_arguments(): parser = argparse.ArgumentParser(description="XFeat training script.") parser.add_argument('--megadepth_root_path', type=str, default='./data/megadepth', help='Path to the MegaDepth dataset root directory.') parser.add_argument('--test_data_root', type=str, default='./data/megadepth_test_1500', help='Path to the MegaDepth test dataset root directory.') parser.add_argument('--ckpt_save_path', type=str, required=True, help='Path to save the checkpoints.') parser.add_argument('--model_name', type=str, default='RDD', help='Name of the model to save.') parser.add_argument('--air_ground_root_path', type=str, default='./data/air_ground_data_2/AirGround') parser.add_argument('--batch_size', type=int, default=4, help='Batch size for training. Default is 4.') parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate. Default is 0.0001.') parser.add_argument('--gamma_steplr', type=float, default=0.5, help='Gamma value for StepLR scheduler. Default is 0.5.') parser.add_argument('--training_res', type=int, default=800, help='Training resolution as width,height. Default is 800 for training descriptor.') parser.add_argument('--save_ckpt_every', type=int, default=500, help='Save checkpoints every N steps. Default is 500.') parser.add_argument('--test_every_iter', type=int, default=2000, help='Save checkpoints every N steps. Default is 2000.') parser.add_argument('--weights', type=str, default=None,) parser.add_argument('--num_encoder_layers', type=int, default=4) parser.add_argument('--enc_n_points', type=int, default=8) parser.add_argument('--num_feature_levels', type=int, default=5) parser.add_argument('--train_detector', action='store_true', default=False) parser.add_argument('--epochs', type=int, default=20) parser.add_argument('--distributed', action='store_true', default=False) parser.add_argument('--config_path', type=str, default='./configs/default.yaml') args = parser.parse_args() return args args = parse_arguments() import torch from torch import optim import torch.nn.functional as F from torch.utils.tensorboard import SummaryWriter import numpy as np from RDD.RDD import build from training.utils import * from training.losses import * from benchmarks.mega_1500 import MegaDepthPoseMNNBenchmark from RDD.dataset.megadepth.megadepth import MegaDepthDataset from RDD.dataset.megadepth import megadepth_warper from torch.utils.data import Dataset, DataLoader, DistributedSampler, RandomSampler, WeightedRandomSampler from training.losses.detector_loss import compute_correspondence, DetectorLoss from training.losses.descriptor_loss import DescriptorLoss import tqdm from torch.optim.lr_scheduler import MultiStepLR, StepLR from datetime import timedelta from RDD.utils import read_config torch.autograd.set_detect_anomaly(True) torch.backends.cudnn.benchmark = True torch.backends.cudnn.deterministic = True class Trainer(): """ Class for training XFeat with default params as described in the paper. We use a blend of MegaDepth (labeled) pairs with synthetically warped images (self-supervised). The major bottleneck is to keep loading huge megadepth h5 files from disk, the network training itself is quite fast. """ def __init__(self, rank, args=None): config = read_config(args.config_path) config['num_encoder_layers'] = args.num_encoder_layers config['enc_n_points'] = args.enc_n_points config['num_feature_levels'] = args.num_feature_levels config['train_detector'] = args.train_detector config['weights'] = args.weights # distributed training if args.distributed: print(f"Training in distributed mode with {args.n_gpus} GPUs") assert torch.cuda.is_available() device = rank torch.distributed.init_process_group( backend="nccl", world_size=args.n_gpus, rank=device, init_method="file://" + str(args.lock_file), timeout=timedelta(seconds=2000) ) torch.cuda.set_device(device) # adjust batch size and num of workers since these are per GPU batch_size = int(args.batch_size / args.n_gpus) self.n_gpus = args.n_gpus else: device = "cuda" if torch.cuda.is_available() else "cpu" batch_size = args.batch_size print(f"Using device {device}") self.seed = 0 self.set_seed(self.seed) self.training_res = args.training_res self.dev = device config['device'] = device model = build(config) self.rank = rank if args.weights is not None: print('Loading weights from ', args.weights) model.load_state_dict(torch.load(args.weights, map_location='cpu')) if args.distributed: model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) self.model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[device], find_unused_parameters=True ) else: self.model = model.to(device) self.saved_ckpts = [] self.best = -1.0 self.best_loss = 1e6 self.fine_weight = 1.0 self.dual_softmax_weight = 1.0 self.heatmaps_weight = 1.0 #Setup optimizer self.batch_size = batch_size self.epochs = args.epochs self.opt = optim.AdamW(filter(lambda x: x.requires_grad, self.model.parameters()) , lr = args.lr, weight_decay=1e-4) # losses if args.train_detector: self.DetectorLoss = DetectorLoss(temperature=0.1, scores_th=0.1) else: self.DescriptorLoss = DescriptorLoss(inv_temp=20, dual_softmax_weight=1, heatmap_weight=1) self.benchmark = MegaDepthPoseMNNBenchmark(data_root=args.test_data_root) ##################### MEGADEPTH INIT ########################## TRAIN_BASE_PATH = f"{args.megadepth_root_path}/megadepth_indices" print('Loading MegaDepth dataset from ', TRAIN_BASE_PATH) TRAINVAL_DATA_SOURCE = args.megadepth_root_path self.TRAINVAL_DATA_SOURCE = TRAINVAL_DATA_SOURCE TRAIN_NPZ_ROOT = f"{TRAIN_BASE_PATH}/scene_info_0.1_0.7" self.TRAIN_NPZ_ROOT = TRAIN_NPZ_ROOT npz_paths = glob.glob(TRAIN_NPZ_ROOT + '/*.npz')[:] self.npz_paths = npz_paths self.epoch = 0 self.create_data_loader() ##################### MEGADEPTH INIT END ####################### os.makedirs(args.ckpt_save_path, exist_ok=True) os.makedirs(args.ckpt_save_path / 'logdir', exist_ok=True) self.save_ckpt_every = args.save_ckpt_every self.ckpt_save_path = args.ckpt_save_path if rank == 0: self.writer = SummaryWriter(str(self.ckpt_save_path) + f'/logdir/{args.model_name}_' + time.strftime("%Y_%m_%d-%H_%M_%S")) else: self.writer = None self.model_name = args.model_name if args.distributed: self.scheduler = MultiStepLR(self.opt, milestones=[2, 4, 8, 16], gamma=args.gamma_steplr) else: self.scheduler = StepLR(self.opt, step_size=args.test_every_iter, gamma=args.gamma_steplr) def set_seed(self, seed): np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed(seed) torch.cuda.manual_seed_all(seed) def create_data_loader(self): # Create sampler if not args.train_detector: mega_crop = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, npz_path = path, min_overlap_score=0.01, max_overlap_score=0.7, image_size=self.training_res, num_per_scene=200, gray=False, crop_or_scale='crop') for path in self.npz_paths] ) mega_scale = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, npz_path = path, min_overlap_score=0.01, max_overlap_score=0.7, image_size=self.training_res, num_per_scene=200, gray=False, crop_or_scale='scale') for path in self.npz_paths] ) combined_dataset = torch.utils.data.ConcatDataset([mega_crop, mega_scale]) else: mega_crop = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, npz_path = path, min_overlap_score=0.1, max_overlap_score=0.8, image_size=self.training_res, num_per_scene=100, gray=False, crop_or_scale='crop') for path in self.npz_paths] ) mega_scale = torch.utils.data.ConcatDataset( [MegaDepthDataset(root = self.TRAINVAL_DATA_SOURCE, npz_path = path, min_overlap_score=0.1, max_overlap_score=0.8, image_size=self.training_res, num_per_scene=100, gray=False, crop_or_scale='scale') for path in self.npz_paths] ) combined_dataset = torch.utils.data.ConcatDataset([mega_crop, mega_scale]) # Create sampler if args.distributed: sampler = DistributedSampler(combined_dataset, rank=self.rank, num_replicas=self.n_gpus) else: # Create sampler sampler = RandomSampler(combined_dataset) # Create single DataLoader with combined dataset self.data_loader = DataLoader(combined_dataset, batch_size=self.batch_size, sampler=sampler, num_workers=4, pin_memory=True) def validate(self, total_steps): with torch.no_grad(): if args.train_detector: method = 'sparse' else: method = 'aliked' if args.distributed: self.model.module.eval() model_helper = RDD_helper(self.model.module) test_out = self.benchmark.benchmark(model_helper, model_name='experiment', plot_every_iter=1, plot=False, method=method) else: self.model.eval() model_helper = RDD_helper(self.model) test_out = self.benchmark.benchmark(model_helper, model_name='experiment', plot_every_iter=1, plot=False, method=method) auc5 = test_out['auc_5'] auc10 = test_out['auc_10'] auc20 = test_out['auc_20'] if self.rank == 0: self.writer.add_scalar('Accuracy/auc5', auc5, total_steps) self.writer.add_scalar('Accuracy/auc10', auc10, total_steps) self.writer.add_scalar('Accuracy/auc20', auc20, total_steps) if auc5 > self.best: self.best = auc5 if args.distributed: torch.save(self.model.module.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_best.pth') else: torch.save(self.model.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_best.pth') self.model.train() def _inference(self, d): if d is not None: for k in d.keys(): if isinstance(d[k], torch.Tensor): d[k] = d[k].to(self.dev) p1, p2 = d['image0'], d['image1'] if not args.train_detector: positives_md_coarse = megadepth_warper.spvs_coarse(d, self.stride) with torch.no_grad(): p1 = p1 ; p2 = p2 if not args.train_detector: positives_c = positives_md_coarse # Check if batch is corrupted with too few correspondences is_corrupted = False if not args.train_detector: for p in positives_c: if len(p) < 30: is_corrupted = True if is_corrupted: return None, None, None, None # Forward pass feats1, scores_map1, hmap1 = self.model(p1) feats2, scores_map2, hmap2 = self.model(p2) if args.train_detector: # move all tensors on batch to GPU for k in d.keys(): if isinstance(d[k], torch.Tensor): d[k] = d[k].to(self.dev) elif isinstance(d[k], dict): for k2 in d[k].keys(): if isinstance(d[k][k2], torch.Tensor): d[k][k2] = d[k][k2].to(self.dev) # Get positive correspondencies pred0 = {'descriptor_map': F.interpolate(feats1, size=scores_map1.shape[-2:], mode='bilinear', align_corners=True), 'scores_map': scores_map1 } pred1 = {'descriptor_map': F.interpolate(feats2, size=scores_map2.shape[-2:], mode='bilinear', align_corners=True), 'scores_map': scores_map2 } if args.distributed: correspondences, pred0_with_rand, pred1_with_rand = compute_correspondence(self.model.module, pred0, pred1, d, debug=True) else: correspondences, pred0_with_rand, pred1_with_rand = compute_correspondence(self.model, pred0, pred1, d, debug=False) loss_kp = self.DetectorLoss(correspondences, pred0_with_rand, pred1_with_rand) loss = loss_kp acc_coarse, acc_kp, nb_coarse = 0, 0, 0 else: loss_items = [] acc_coarse_items = [] acc_kp_items = [] for b in range(len(positives_c)): if len(positives_c[b]) > 10000: positives = positives_c[b][torch.randperm(len(positives_c[b]))[:10000]] else: positives = positives_c[b] # Get positive correspondencies pts1, pts2 = positives[:, :2], positives[:, 2:] h1 = hmap1[b, :, :, :] h2 = hmap2[b, :, :, :] m1 = feats1[b, :, pts1[:,1].long(), pts1[:,0].long()].permute(1,0) m2 = feats2[b, :, pts2[:,1].long(), pts2[:,0].long()].permute(1,0) # Compute losses loss_ds, loss_h, acc_kp = self.DescriptorLoss(m1, m2, h1, h2, pts1, pts2) loss_items.append(loss_ds.unsqueeze(0)) acc_coarse = check_accuracy1(m1, m2) acc_kp_items.append(acc_kp) acc_coarse_items.append(acc_coarse) nb_coarse = len(m1) loss = loss_kp if args.train_detector else torch.cat(loss_items, -1).mean() acc_coarse = sum(acc_coarse_items) / len(acc_coarse_items) acc_kp = sum(acc_kp_items) / len(acc_kp_items) return loss, acc_coarse, acc_kp, nb_coarse def train(self): self.model.train() self.stride = 4 if args.num_feature_levels == 5 else 8 total_steps = 0 for epoch in range(self.epochs): if args.distributed: self.data_loader.sampler.set_epoch(epoch) pbar = tqdm.tqdm(total=len(self.data_loader), desc=f"Epoch {epoch+1}/{args.epochs}") if self.rank == 0 else None for i, d in enumerate(self.data_loader): loss, acc_coarse, acc_kp, nb_coarse = self._inference(d) if loss is None: continue # Compute Backward Pass loss.backward() torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.) self.opt.step() self.opt.zero_grad() if (total_steps + 1) % self.save_ckpt_every == 0 and self.rank == 0: print('saving iter ', total_steps + 1) if args.distributed: torch.save(self.model.module.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_{total_steps + 1}.pth') else: torch.save(self.model.state_dict(), str(self.ckpt_save_path) + f'/{self.model_name}_{total_steps + 1}.pth') self.saved_ckpts.append(total_steps + 1) if len(self.saved_ckpts) > 5: os.remove(str(self.ckpt_save_path) + f'/{self.model_name}_{self.saved_ckpts[0]}.pth') self.saved_ckpts = self.saved_ckpts[1:] if args.distributed: torch.distributed.barrier() if (total_steps+1) % args.test_every_iter == 0: self.validate(total_steps) if pbar is not None: if args.train_detector: pbar.set_description( 'Loss: {:.4f} '.format(loss.item()) ) else: pbar.set_description( 'Loss: {:.4f} acc_coarse {:.3f} acc_kp: {:.3f} #matches_c: {:d}'.format( loss.item(), acc_coarse, acc_kp, nb_coarse) ) pbar.update(1) # Log metrics if self.rank == 0: self.writer.add_scalar('Loss/total', loss.item(), total_steps) self.writer.add_scalar('Accuracy/coarse_mdepth', acc_coarse, total_steps) self.writer.add_scalar('Count/matches_coarse', nb_coarse, total_steps) if not args.distributed: self.scheduler.step() total_steps = total_steps + 1 self.validate(total_steps) if self.rank == 0: print('Epoch ', epoch, ' done.') print('Creating new data loader with seed ', self.seed) self.seed = self.seed + 1 self.set_seed(self.seed) self.scheduler.step() self.epoch = self.epoch + 1 self.create_data_loader() def main_worker(rank, args): trainer = Trainer( rank=rank, args=args ) # The most fun part trainer.train() if __name__ == '__main__': if args.distributed: import torch.multiprocessing as mp mp.set_start_method('spawn', force=True) if not Path(args.ckpt_save_path).exists(): os.makedirs(args.ckpt_save_path) args.ckpt_save_path = Path(args.ckpt_save_path).resolve() if args.distributed: args.n_gpus = torch.cuda.device_count() args.lock_file = Path(args.ckpt_save_path) / "distributed_lock" if args.lock_file.exists(): args.lock_file.unlink() # Each process gets its own rank and dataset torch.multiprocessing.spawn( main_worker, nprocs=args.n_gpus, args=(args,) ) else: main_worker(0, args)