# -*- coding: utf-8 -*-
# @Author  : xuelun

import os
import cv2
import csv
import math
import torch
import scipy.io
import warnings
import argparse

import numpy as np

from os import mkdir
from tqdm import tqdm
from copy import deepcopy
from os.path import join, exists

from torch.utils.data import DataLoader

from datasets.walk.video_streamer import VideoStreamer
from datasets.walk.video_loader import WALKDataset, collate_fn

from networks.mit_semseg.models import ModelBuilder, SegmentationModule

gray2tensor = lambda x: (torch.from_numpy(x).float() / 255)[None, None]
color2tensor = lambda x: (torch.from_numpy(x).float() / 255).permute(2, 0, 1)[None]

warnings.simplefilter("ignore", category=UserWarning)

methods = {'SIFT', 'GIM_GLUE', 'GIM_LOFTR', 'GIM_DKM'}

PALETTE = scipy.io.loadmat('weights/color150.mat')['colors']

CLS_DICT = {}  # {'person': 13, 'sky': 3}
with open('weights/object150_info.csv') as f:
    reader = csv.reader(f)
    next(reader)
    for row in reader:
        name = row[5].split(";")[0]
        if name == 'screen':
            name = '_'.join(row[5].split(";")[:2])
        CLS_DICT[name] = int(row[0]) - 1

exclude = ['person', 'sky', 'car']


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--debug', action='store_true')
    parser.add_argument("--gpu", type=int, default=0, help='-1 for CPU')
    parser.add_argument("--range", type=int, nargs='+', default=None, help='Video Range for seconds')
    parser.add_argument('--scene_name', type=str, default=None, help='Scene (video) name')
    parser.add_argument('--method', type=str, choices=methods, required=True, help='Method name')
    parser.add_argument('--resize', action='store_true', help='whether resize')
    parser.add_argument('--skip', type=int, required=True, help='Video skip frame: 1, 2, 3, ...')
    parser.add_argument('--watermarker', type=int, nargs='+', default=None, help='Watermarker Rectangle Range')
    opt = parser.parse_args()

    data_root = join('data', 'ZeroMatch')
    video_name = opt.scene_name.strip()
    video_path = join(data_root, 'video_1080p', video_name + '.mp4')

    # get real size of video
    vcap = cv2.VideoCapture(video_path)
    vwidth = vcap.get(3)   # float `width`
    vheight = vcap.get(4)  # float `height`
    fps = vcap.get(5)      # float `fps`
    end_range = math.floor(vcap.get(cv2.CAP_PROP_FRAME_COUNT) / fps - 300)
    vcap.release()
    fps = math.ceil(fps)

    opt.range = [300, end_range] if opt.range is None else opt.range
    opt.range = [0, -1] if video_name == 'Od-rKbC30TM' else opt.range  # for demo

    if fps <= 30:
        skip = [10, 20, 40][opt.skip]
    else:
        skip = [20, 40, 80][opt.skip]

    dump_dir = join(data_root, 'pseudo',
                    'WALK ' + opt.method +
                    ' [R] ' + '{}'.format('T' if opt.resize else 'F') +
                    ' [S] ' + '{:2}'.format(skip))
    if not exists(dump_dir):
        mkdir(dump_dir)

    debug_dir = join('dump', video_name + ' ' + opt.method)
    if opt.resize:
        debug_dir = debug_dir + ' Resize'
    if opt.debug and (not exists(debug_dir)):
        mkdir(debug_dir)

    # start process video
    gap = 10 if fps <= 30 else 20
    vs = VideoStreamer(basedir=video_path, resize=opt.resize, df=8, skip=gap, vrange=opt.range)

    # read the first frame
    rgb = vs[vs.listing[0]]
    width, height = rgb.shape[1], rgb.shape[0]

    # calculate ratio
    vratio = np.array([vwidth / width, vheight / height])[None]

    # set dump name
    scene_name = f'{video_name} '
    scene_name += f'WH {width:4} {height:4} '
    scene_name += f'RG {vs.range[0]:4} {vs.range[1]:4} '
    scene_name += f'SP {skip} '
    scene_name += f'{len(video_name)}'
    save_dir = join(dump_dir, scene_name)

    device = torch.device('cuda:{}'.format(opt.gpu)) if opt.gpu >= 0 else torch.device('cpu')

    # initialize segmentation model
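    # The model built below is the MIT ADE20K scene-parsing network (150 classes, see
    # `object150_info.csv`); its per-pixel labels are used to suppress the classes listed
    # in `exclude` (person, sky, car) so that no matches are sampled on those regions,
    # e.g. pixels where mask == CLS_DICT['person'] are masked out before matching.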
    net_encoder = ModelBuilder.build_encoder(
        arch='resnet50dilated',
        fc_dim=2048,
        weights='weights/encoder_epoch_20.pth')
    net_decoder = ModelBuilder.build_decoder(
        arch='ppm_deepsup',
        fc_dim=2048,
        num_class=150,
        weights='weights/decoder_epoch_20.pth',
        use_softmax=True)
    crit = torch.nn.NLLLoss(ignore_index=-1)
    segmentation_module = SegmentationModule(net_encoder, net_decoder, crit).to(device).eval()

    old_segment_root = join(data_root, 'segment', opt.scene_name)
    new_segment_root = join(data_root, 'segment', opt.scene_name.strip())
    if not os.path.exists(new_segment_root):
        if os.path.exists(old_segment_root):
            os.rename(old_segment_root, new_segment_root)
        else:
            os.makedirs(new_segment_root, exist_ok=True)
    segment_root = new_segment_root

    model, detectAndCompute = None, None
    if opt.method == 'SIFT':
        model = cv2.SIFT_create(nfeatures=32400, contrastThreshold=1e-5)
        detectAndCompute = model.detectAndCompute
    elif opt.method == 'GIM_DKM':
        from networks.dkm.models.model_zoo.DKMv3 import DKMv3
        model = DKMv3(weights=None, h=672, w=896)
        checkpoints_path = join('weights', 'gim_dkm_100h.ckpt')
        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys():
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
            if 'encoder.net.fc' in k:
                state_dict.pop(k)
        model.load_state_dict(state_dict)
        model = model.eval().to(device)
    elif opt.method == 'GIM_LOFTR':
        from networks.loftr.loftr import LoFTR
        from networks.loftr.misc import lower_config
        from networks.loftr.config import get_cfg_defaults
        cfg = get_cfg_defaults()
        cfg.TEMP_BUG_FIX = True
        cfg.LOFTR.WEIGHT = 'weights/gim_loftr_50h.ckpt'
        cfg.LOFTR.FINE_CONCAT_COARSE_FEAT = False
        cfg = lower_config(cfg)
        model = LoFTR(cfg['loftr'])
        model = model.to(device)
        model = model.eval()
    elif opt.method == 'GIM_GLUE':
        from networks.lightglue.matching import Matching
        model = Matching()
        checkpoints_path = join('weights', 'gim_lightglue_100h.ckpt')

        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys():
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('model.'):
                state_dict.pop(k)
            if k.startswith('superpoint.'):
                state_dict[k.replace('superpoint.', '', 1)] = state_dict.pop(k)
        model.detector.load_state_dict(state_dict)

        state_dict = torch.load(checkpoints_path, map_location='cpu')
        if 'state_dict' in state_dict.keys():
            state_dict = state_dict['state_dict']
        for k in list(state_dict.keys()):
            if k.startswith('superpoint.'):
                state_dict.pop(k)
            if k.startswith('model.'):
                state_dict[k.replace('model.', '', 1)] = state_dict.pop(k)
        model.model.load_state_dict(state_dict)
        model = model.to(device)
        model = model.eval()

    cache_dir = None
    if opt.resize:
        cache_dir = join(data_root, 'pseudo',
                         'WALK ' + 'GIM_DKM' + ' [R] F' + ' [S] ' + '{:2}'.format(skip),
                         scene_name)

    _w_ = width if opt.method == 'SIFT' or opt.method == 'GLUE' else 1600  # TODO: confirm DKM
    _h_ = height if opt.method == 'SIFT' or opt.method == 'GLUE' else 900  # TODO: confirm DKM

    ids = list(zip(vs.listing[:-skip // gap], vs.listing[skip // gap:]))
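    # The VideoStreamer samples one frame every `gap` frames; zipping the listing against
    # itself shifted by `skip // gap` positions pairs each sampled frame with the frame
    # `skip` raw frames later, giving the (frame_i, frame_{i+skip}) pairs to be matched.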
    # start matching and make pseudo labels
    nums = None
    idxs = None
    checkpoint = 0
    if not opt.debug:
        if exists(join(save_dir, 'nums.npy')) and exists(join(save_dir, 'idxs.npy')):
            with open(join(save_dir, 'nums.npy'), 'rb') as f:
                nums = np.load(f)
            with open(join(save_dir, 'idxs.npy'), 'rb') as f:
                idxs = np.load(f)
            assert len(nums) == len(idxs) == (len(os.listdir(save_dir)) - 2)
            whole = [str(x) + '.npy' for x in np.array(ids)]
            cache = [str(x) + '.npy' for x in idxs]
            leave = list(set(whole) - set(cache))
            if len(leave):
                leave = list(map(lambda x: int(x.rsplit('[')[-1].strip().split()[0]), leave))
                skip_id = np.array(sorted(leave))
                skip_id = (skip_id[1:] - skip_id[:-1]) // gap
                len_id = len(skip_id)
                if len_id == 0:
                    exit(0)
                skip_id = [i for i in range(len_id) if skip_id[i:].sum() == (len_id - i)]
                if len(skip_id) == 0:
                    exit(0)
                skip_id = skip_id[0]
                checkpoint = np.where(np.array(ids)[:, 0] == sorted(leave)[skip_id])[0][0]
                if len(nums) + skip_id > checkpoint:
                    exit(0)
                assert checkpoint == len(nums) + skip_id
            else:
                exit(0)
        else:
            if not exists(save_dir):
                mkdir(save_dir)
            nums = np.array([])
            idxs = np.array([])

    datasets = WALKDataset(data_root, vs=vs, ids=ids, checkpoint=checkpoint, opt=opt)
    loader_params = {'batch_size': 1, 'shuffle': False, 'num_workers': 5,
                     'pin_memory': True, 'drop_last': False}
    loader = DataLoader(datasets, collate_fn=collate_fn, **loader_params)

    for i, batch in enumerate(tqdm(loader, ncols=120, bar_format="{l_bar}{bar:3}{r_bar}",
                                   desc='{:11} - [{:5}, {:2}{}]'.format(
                                       video_name[:40], opt.method, skip, '*' if opt.resize else ''),
                                   total=len(loader), leave=False)):
        idx = batch['idx'].item()
        assert i == idx
        idx0 = batch['idx0'].item()
        idx1 = batch['idx1'].item()
        assert idx0 == ids[idx + checkpoint][0] and idx1 == ids[idx + checkpoint][1]

        # cache loaded image
        if not batch['rgb0_is_good'].item():
            img_path0 = batch['img_path0'][0]
            if not os.path.exists(img_path0):
                cv2.imwrite(img_path0, batch['rgb0'].squeeze(0).numpy())
        if not batch['rgb1_is_good'].item():
            img_path1 = batch['img_path1'][0]
            if not os.path.exists(img_path1):
                cv2.imwrite(img_path1, batch['rgb1'].squeeze(0).numpy())

        current_id = np.array([idx0, idx1])
        save_name = '{}.npy'.format(str(current_id))
        save_path = join(save_dir, save_name)
        if exists(save_path) and not opt.debug:
            continue

        rgb0 = batch['rgb0'].squeeze(0).numpy()
        rgb1 = batch['rgb1'].squeeze(0).numpy()
        _rgb0_, _rgb1_ = deepcopy(rgb0), deepcopy(rgb1)

        # get correspondences in unresized image
        pt0, pt1 = None, None
        if opt.resize:
            cache_path = join(cache_dir, save_name)
            if not exists(cache_path):
                continue
            with open(cache_path, 'rb') as f:
                pts = np.load(f)
            pt0, pt1 = pts[:, :2], pts[:, 2:]

        # process first frame image
        xA0, xA1, yA0, yA1, hA, wA, wA_new, hA_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb0
            xA0 = math.floor(pt0[:, 0].min())
            xA1 = math.ceil(pt0[:, 0].max())
            yA0 = math.floor(pt0[:, 1].min())
            yA1 = math.ceil(pt0[:, 1].max())
            rgb0 = rgb0[yA0:yA1, xA0:xA1]
            hA, wA = rgb0.shape[:2]
            wA_new, hA_new = get_resized_wh(wA, hA, [_h_, _w_])
            wA_new, hA_new = get_divisible_wh(wA_new, hA_new, 8)
            rgb0 = cv2.resize(rgb0, (wA_new, hA_new), interpolation=cv2.INTER_AREA)

        # go on
        gray0 = cv2.cvtColor(rgb0, cv2.COLOR_RGB2GRAY)

        # semantic segmentation
        with torch.no_grad():
            seg_path0 = join(segment_root, '{}.npy'.format(idx0))
            if not os.path.exists(seg_path0):
                mask0 = segment(_rgb0_, device, segmentation_module)
                np.save(seg_path0, mask0)
            else:
                mask0 = np.load(seg_path0)
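        # Segmentation masks are computed on the full-resolution frame (`_rgb0_` / `_rgb1_`)
        # and cached per frame index under `segment_root`, so overlapping pairs reuse them;
        # they are cropped/resized further below to match the (possibly resized) images.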
        # process next frame image
        xB0, xB1, yB0, yB1, hB, wB, wB_new, hB_new = None, None, None, None, None, None, None, None
        if opt.resize:
            # crop rgb1
            xB0 = math.floor(pt1[:, 0].min())
            xB1 = math.ceil(pt1[:, 0].max())
            yB0 = math.floor(pt1[:, 1].min())
            yB1 = math.ceil(pt1[:, 1].max())
            rgb1 = rgb1[yB0:yB1, xB0:xB1]
            hB, wB = rgb1.shape[:2]
            wB_new, hB_new = get_resized_wh(wB, hB, [_h_, _w_])
            wB_new, hB_new = get_divisible_wh(wB_new, hB_new, 8)
            rgb1 = cv2.resize(rgb1, (wB_new, hB_new), interpolation=cv2.INTER_AREA)

        # go on
        gray1 = cv2.cvtColor(rgb1, cv2.COLOR_RGB2GRAY)

        # semantic segmentation
        with torch.no_grad():
            seg_path1 = join(segment_root, '{}.npy'.format(idx1))
            if not os.path.exists(seg_path1):
                mask1 = segment(_rgb1_, device, segmentation_module)
                np.save(seg_path1, mask1)
            else:
                mask1 = np.load(seg_path1)

        if mask0.shape[:2] != _rgb0_.shape[:2]:
            mask0 = cv2.resize(mask0, _rgb0_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
        if mask1.shape != _rgb1_.shape[:2]:
            mask1 = cv2.resize(mask1, _rgb1_.shape[:2][::-1], interpolation=cv2.INTER_NEAREST)

        if opt.resize:
            # resize mask0
            mask0 = mask0[yA0:yA1, xA0:xA1]
            mask0 = cv2.resize(mask0, (wA_new, hA_new), interpolation=cv2.INTER_NEAREST)
            # resize mask1
            mask1 = mask1[yB0:yB1, xB0:xB1]
            mask1 = cv2.resize(mask1, (wB_new, hB_new), interpolation=cv2.INTER_NEAREST)

        data = None
        if opt.method == 'SIFT':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0:
                continue

            # keypoint detection and description
            kpts0, desc0 = detectAndCompute(rgb0, mask_0)
            if desc0 is None or desc0.shape[0] < 8:
                continue
            kpts0 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts0])
            kpts0, desc0 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts0, desc0])
            desc0 = (desc0 / desc0.sum(dim=1, keepdim=True)).sqrt()

            # keypoint detection and description
            kpts1, desc1 = detectAndCompute(rgb1, mask_1)
            if desc1 is None or desc1.shape[0] < 8:
                continue
            kpts1 = np.array([[kp.pt[0], kp.pt[1]] for kp in kpts1])
            kpts1, desc1 = map(lambda x: torch.from_numpy(x).to(device).float(), [kpts1, desc1])
            desc1 = (desc1 / desc1.sum(dim=1, keepdim=True)).sqrt()

            # mutual nearest matching and ratio filter
            matches = desc0 @ desc1.transpose(0, 1)
            mask = (matches == matches.max(dim=1, keepdim=True).values) & \
                   (matches == matches.max(dim=0, keepdim=True).values)
            # noinspection PyUnresolvedReferences
            valid, indices = mask.max(dim=1)
            ratio = torch.topk(matches, k=2, dim=1).values
            ratio = (-2 * ratio + 2).sqrt()
            # ratio = (ratio[:, 0] / ratio[:, 1]) < opt.mt
            ratio = (ratio[:, 0] / ratio[:, 1]) < 0.8
            valid = valid & ratio

            # get matched keypoints
            mkpts0 = kpts0[valid]
            mkpts1 = kpts1[indices[valid]]
            b_ids = torch.where(valid[None])[0]

            data = dict(
                m_bids=b_ids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
        elif opt.method == 'GIM_DKM':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0:
                continue

            img0 = rgb0 * mask_0[..., None]
            img1 = rgb1 * mask_1[..., None]
            width0, height0 = img0.shape[1], img0.shape[0]
            width1, height1 = img1.shape[1], img1.shape[0]

            with torch.no_grad():
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    img0 = torch.from_numpy(img0).permute(2, 0, 1).to(device)[None] / 255
                    img1 = torch.from_numpy(img1).permute(2, 0, 1).to(device)[None] / 255
                    dense_matches, dense_certainty = model.match(img0, img1)
                    sparse_matches, mconf = model.sample(dense_matches, dense_certainty, 5000)

            mkpts0 = sparse_matches[:, :2]
            mkpts0 = torch.stack((width0 * (mkpts0[:, 0] + 1) / 2,
                                  height0 * (mkpts0[:, 1] + 1) / 2), dim=-1)
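            # `model.sample` returns match coordinates in normalized [-1, 1] space; the
            # torch.stack calls above and below map them to pixel coordinates of the
            # masked inputs via x_px = (x_norm + 1) / 2 * width (and likewise for y).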
            mkpts1 = sparse_matches[:, 2:]
            mkpts1 = torch.stack((width1 * (mkpts1[:, 0] + 1) / 2,
                                  height1 * (mkpts1[:, 1] + 1) / 2), dim=-1)
            m_bids = torch.zeros(sparse_matches.shape[0], dtype=torch.long, device=device)

            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )
        elif opt.method == 'GIM_LOFTR':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0:
                continue

            mask_0 = cv2.resize(mask_0, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)
            mask_1 = cv2.resize(mask_1, None, fx=1/8, fy=1/8, interpolation=cv2.INTER_NEAREST)

            data = dict(
                image0=gray2tensor(gray0),
                image1=gray2tensor(gray1),
                color0=color2tensor(rgb0),
                color1=color2tensor(rgb1),
                mask0=torch.from_numpy(mask_0)[None],
                mask1=torch.from_numpy(mask_1)[None],
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in data.items()}
                model(data)
        elif opt.method == 'GIM_GLUE':
            mask_0 = mask0 != CLS_DICT[exclude[0]]
            mask_1 = mask1 != CLS_DICT[exclude[0]]
            for cls in exclude[1:]:
                mask_0 = mask_0 & (mask0 != CLS_DICT[cls])
                mask_1 = mask_1 & (mask1 != CLS_DICT[cls])
            mask_0 = mask_0.astype(np.uint8)
            mask_1 = mask_1.astype(np.uint8)
            if mask_0.sum() == 0 or mask_1.sum() == 0:
                continue

            size0 = torch.tensor(gray0.shape[-2:][::-1])[None]
            size1 = torch.tensor(gray1.shape[-2:][::-1])[None]
            data = dict(
                gray0=gray2tensor(gray0 * mask_0),
                gray1=gray2tensor(gray1 * mask_1),
                size0=size0,
                size1=size1,
            )
            with torch.no_grad():
                data = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in data.items()}
                pred = model(data)

            kpts0, kpts1 = pred['keypoints0'][0], pred['keypoints1'][0]
            matches = pred['matches'][0]
            if len(matches) == 0:
                continue
            mkpts0 = kpts0[matches[..., 0]]
            mkpts1 = kpts1[matches[..., 1]]
            m_bids = torch.zeros(matches[..., 0].size(), dtype=torch.long, device=device)

            data = dict(
                m_bids=m_bids,
                mkpts0_f=mkpts0,
                mkpts1_f=mkpts1,
            )

        # auto remove watermarker
        kpts0 = data['mkpts0_f'].clone()  # (N, 2)
        kpts1 = data['mkpts1_f'].clone()  # (N, 2)
        moved = ~((kpts0 - kpts1).abs() < 1).min(dim=1).values  # (N)
        data['m_bids'] = data['m_bids'][moved]
        data['mkpts0_f'] = data['mkpts0_f'][moved]
        data['mkpts1_f'] = data['mkpts1_f'][moved]

        robust_fitting(data)
        if (data['inliers'] is None) or (sum(data['inliers'][0]) == 0):
            continue
        inliers = data['inliers'][0]

        if opt.debug:
            data.update(dict(  # for debug visualization
                mask0=mask0,
                mask1=mask1,
                gray0=gray0,
                gray1=gray1,
                color0=rgb0,
                color1=rgb1,
                hw0_i=rgb0.shape[:2],
                hw1_i=rgb1.shape[:2],
                dataset_name=['WALK'],
                scene_id=[video_name],
                pair_id=[[idx0, idx1]],
                imsize0=[[width, height]],
                imsize1=[[width, height]],
            ))
            out = fast_make_matching_robust_fitting_figure(data)
            cv2.imwrite(join(debug_dir, '{} {:8d} {:8d}.png'.format(scene_name, idx0, idx1)),
                        cv2.cvtColor(out, cv2.COLOR_RGB2BGR))
            continue

        if opt.resize:
            mkpts0_f = (data['mkpts0_f'].cpu().numpy()[inliers] * np.array([[wA / wA_new, hA / hA_new]])
                        + np.array([[xA0, yA0]])) * vratio
            mkpts1_f = (data['mkpts1_f'].cpu().numpy()[inliers] * np.array([[wB / wB_new, hB / hB_new]])
                        + np.array([[xB0, yB0]])) * vratio
        else:
            mkpts0_f = data['mkpts0_f'].cpu().numpy()[inliers] * vratio
            mkpts1_f = data['mkpts1_f'].cpu().numpy()[inliers] * vratio

        pts = np.concatenate([mkpts0_f, mkpts1_f], axis=1).astype(np.float32)
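        # Bookkeeping for the pseudo labels: `nums` accumulates per-pair match counts and
        # `idxs` the frame-index pairs, in the same order as the per-pair .npy files written
        # below, so getLabel() can later slice the concatenated matches by cumulative sums.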
        nums = np.concatenate([nums, np.array([len(pts)])], axis=0) if len(nums) else np.array([len(pts)])
        idxs = np.concatenate([idxs, current_id[None]], axis=0) if len(idxs) else current_id[None]

        with open(save_path, 'wb') as f:
            np.save(f, pts)
        with open(join(save_dir, 'nums.npy'), 'wb') as f:
            np.save(f, nums)
        with open(join(save_dir, 'idxs.npy'), 'wb') as f:
            np.save(f, idxs)


def robust_fitting(data, b_id=0):
    m_bids = data['m_bids'].cpu().numpy()
    kpts0 = data['mkpts0_f'].cpu().numpy()
    kpts1 = data['mkpts1_f'].cpu().numpy()

    mask = m_bids == b_id
    # noinspection PyBroadException
    try:
        _, mask = cv2.findFundamentalMat(kpts0[mask], kpts1[mask],
                                         cv2.USAC_MAGSAC, ransacReprojThreshold=0.5,
                                         confidence=0.999999, maxIters=100000)
        mask = (mask.ravel() > 0)[None]
    except:
        mask = None

    data.update(dict(inliers=mask))


def get_resized_wh(w, h, resize):
    nh, nw = resize
    sh, sw = nh / h, nw / w
    scale = min(sh, sw)
    w_new, h_new = int(round(w * scale)), int(round(h * scale))
    return w_new, h_new


def get_divisible_wh(w, h, df=None):
    if df is not None:
        w_new = max((w // df), 1) * df
        h_new = max((h // df), 1) * df
    else:
        w_new, h_new = w, h
    return w_new, h_new


def read_deeplab_image(img, size=1920):
    width, height = img.shape[1], img.shape[0]
    if max(width, height) > size:
        if width > height:
            img = cv2.resize(img, (size, int(size * height / width)), interpolation=cv2.INTER_AREA)
        else:
            img = cv2.resize(img, (int(size * width / height), size), interpolation=cv2.INTER_AREA)
    img = (torch.from_numpy(img).float() / 255).permute(2, 0, 1)[None]
    return img


def read_segmentation_image(img):
    img = read_deeplab_image(img, size=720)[0]
    img = img - torch.tensor([0.485, 0.456, 0.406]).view(-1, 1, 1)
    img = img / torch.tensor([0.229, 0.224, 0.225]).view(-1, 1, 1)
    return img


def segment(rgb, device, segmentation_module):
    img_data = read_segmentation_image(rgb)
    singleton_batch = {'img_data': img_data[None].to(device)}
    output_size = img_data.shape[1:]
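    # `read_segmentation_image` has already downscaled the frame to at most 720 px on the
    # long side and applied ImageNet mean/std normalization, so the returned label map can
    # be smaller than the original frame; main() resizes it back with nearest-neighbor.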
    # Run the segmentation at the highest resolution.
    scores = segmentation_module(singleton_batch, segSize=output_size)

    # Get the predicted scores for each pixel
    _, pred = torch.max(scores, dim=1)

    return pred.cpu()[0].numpy().astype(np.uint8)


def getLabel(pair, idxs, nums, h5py_i, h5py_f):
    """
    Args:
        pair: [6965 6970]
        idxs: (N, 2)
        nums: (N,)
        h5py_i: (M, 2)
        h5py_f: (M, 2)

    Returns:
        pseudo_label (N, 4)
    """
    i, j = np.where(idxs == pair)
    if len(i) == 0:
        return None
    assert (len(i) == len(j) == 2) and (i[0] == i[1]) and (j[0] == 0) and (j[1] == 1)
    i = i[0]
    nums = nums[:i + 1]
    idx0, idx1 = sum(nums[:-1]), sum(nums)
    mkpts0 = h5py_i[idx0:idx1]
    mkpts1 = h5py_f[idx0:idx1]  # (N, 2)
    return mkpts0, mkpts1


def fast_make_matching_robust_fitting_figure(data, b_id=0):
    b_mask = data['m_bids'] == b_id
    gray0 = data['gray0']
    gray1 = data['gray1']
    kpts0 = data['mkpts0_f'][b_mask].cpu().numpy()
    kpts1 = data['mkpts1_f'][b_mask].cpu().numpy()

    margin = 2
    (h0, w0), (h1, w1) = data['hw0_i'], data['hw1_i']
    h, w = max(h0, h1), max(w0, w1)
    H, W = margin * 5 + h * 4, margin * 3 + w * 2

    # canvas
    out = 255 * np.ones((H, W), np.uint8)

    wx = [margin, margin + w0, margin + w + margin, margin + w + margin + w1]
    hx = lambda row: margin * row + h * (row - 1)
    out = np.stack([out] * 3, -1)

    sh = hx(row=1)
    color0 = data['color0']  # (rH, rW, 3)
    color1 = data['color1']  # (rH, rW, 3)
    out[sh: sh + h0, wx[0]: wx[1]] = color0
    out[sh: sh + h1, wx[2]: wx[3]] = color1

    sh = hx(row=2)
    img0 = np.stack([gray0] * 3, -1) * 0
    for cls in exclude:
        img0[data['mask0'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h0, wx[0]: wx[1]] = img0
    img1 = np.stack([gray1] * 3, -1) * 0
    for cls in exclude:
        img1[data['mask1'] == CLS_DICT[cls]] = PALETTE[CLS_DICT[cls]]
    out[sh: sh + h1, wx[2]: wx[3]] = img1

    # before outlier filtering
    sh = hx(row=3)
    mkpts0, mkpts1 = np.round(kpts0).astype(int), np.round(kpts1).astype(int)
    out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
    out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
    for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
        # display line end-points as circles
        c = (230, 216, 132)
        cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
        cv2.circle(out, (x1 + margin + w, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)

    # after outlier filtering
    if data['inliers'] is not None:
        sh = hx(row=4)
        inliers = data['inliers'][b_id]
        mkpts0, mkpts1 = np.round(kpts0).astype(int)[inliers], np.round(kpts1).astype(int)[inliers]
        out[sh: sh + h0, wx[0]: wx[1]] = np.stack([gray0] * 3, -1)
        out[sh: sh + h1, wx[2]: wx[3]] = np.stack([gray1] * 3, -1)
        for (x0, y0), (x1, y1) in zip(mkpts0, mkpts1):
            # display line end-points as circles
            c = (230, 216, 132)
            cv2.circle(out, (x0, y0 + sh), 3, c, -1, lineType=cv2.LINE_AA)
            cv2.circle(out, (x1 + margin + w, y1 + sh), 3, c, -1, lineType=cv2.LINE_AA)
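    # Canvas layout: row 1 shows the color frames, row 2 the excluded segmentation classes
    # painted with the ADE20K palette, row 3 all matches before robust fitting, and row 4
    # the RANSAC inliers; the text drawn below reports the match counts and pair metadata.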
    # Big text.
    text = [
        f' ',
        f'#Matches {len(kpts0)}',
        f'#Matches {sum(data["inliers"][b_id]) if data["inliers"] is not None else 0}',
    ]
    sc = min(H / 640., 1.0)
    Ht = int(30 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(text):
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX,
                    1.0 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), Ht * (i + 1)), cv2.FONT_HERSHEY_DUPLEX,
                    1.0 * sc, txt_color_fg, 1, cv2.LINE_AA)

    fingerprint = [
        'Dataset: {}'.format(data['dataset_name'][b_id]),
        'Scene ID: {}'.format(data['scene_id'][b_id]),
        'Pair ID: {}'.format(data['pair_id'][b_id]),
        'Image sizes: {} - {}'.format(data['imsize0'][b_id], data['imsize1'][b_id]),
    ]
    sc = min(H / 640., 1.0)
    Ht = int(18 * sc)  # text height
    txt_color_fg = (255, 255, 255)  # white
    txt_color_bg = (0, 0, 0)  # black
    for i, t in enumerate(reversed(fingerprint)):
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX,
                    .5 * sc, txt_color_bg, 2, cv2.LINE_AA)
        cv2.putText(out, t, (int(8 * sc), int(H - Ht * (i + .6))), cv2.FONT_HERSHEY_SIMPLEX,
                    .5 * sc, txt_color_fg, 1, cv2.LINE_AA)

    return out


if __name__ == '__main__':
    with torch.no_grad():
        main()
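# Example invocation (illustrative only; the script name is a placeholder and the flag
# values are assumptions based on the argparse definitions in main()):
#   python <this_script>.py --scene_name <video_id> --method GIM_DKM --skip 0 --gpu 0
# A later pass with --resize reuses the cached full-resolution GIM_DKM matches (see
# `cache_dir`) to crop and re-match around the covisible region.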