# NOTE: web-page scrape header removed (was: "lsxi77777's picture / commit message / a930e1f");
# kept here as a comment so the file remains valid Python.
import torch
from torch import nn
import numpy as np
import cv2
# import torchvision.transforms as transforms
import torch.nn.functional as F
from yacs.config import CfgNode as CN
def lower_config(yacs_cfg):
    """Recursively convert a yacs CfgNode tree into a plain dict with lower-cased keys.

    Non-CfgNode values are returned unchanged (recursion base case).
    """
    if isinstance(yacs_cfg, CN):
        return {key.lower(): lower_config(value) for key, value in yacs_cfg.items()}
    return yacs_cfg
def upper_config(dict_cfg):
    """Recursively upper-case every key of a (possibly nested) dict.

    Any non-dict value is the recursion base case and is returned as-is.
    """
    if isinstance(dict_cfg, dict):
        return {key.upper(): upper_config(value) for key, value in dict_cfg.items()}
    return dict_cfg
class DataIOWrapper(nn.Module):
    """
    Pre-process data from different sources (OpenCV images or image paths)
    and feed it to a feature-matching model.

    The wrapped model is called with a dict batch containing 'image0'/'image1'
    (and optional padding masks) and is expected to add 'mkpts0_f', 'mkpts1_f'
    and 'mconf_f' entries in place.
    """

    def __init__(self, model, config, ckpt=None):
        """
        Args:
            model: matching network (e.g. XoFTR); mutates its input batch dict.
            config: dict with keys 'img0_resize', 'img1_resize', 'df',
                'padding', 'coarse_scale'.
            ckpt: optional checkpoint path; file must contain a dict with a
                'state_dict' entry.
        """
        super().__init__()
        self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
        torch.set_grad_enabled(False)
        self.model = model
        self.config = config
        self.img0_size = config['img0_resize']
        self.img1_size = config['img1_resize']
        self.df = config['df']
        self.padding = config['padding']
        self.coarse_scale = config['coarse_scale']
        if ckpt:
            # map_location so a GPU-saved checkpoint still loads on CPU-only hosts
            ckpt_dict = torch.load(ckpt, map_location=self.device)
            self.model.load_state_dict(ckpt_dict['state_dict'])
        # Always move the model to the chosen device, even when no checkpoint
        # is given, so it matches the device of the preprocessed tensors.
        self.model = self.model.eval().to(self.device)

    def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
        """Undistort, resize, pad and tensorize a single OpenCV image.

        Args:
            img: HxW or HxWx3 uint8 array (BGR if color).
            device: torch device the output tensor is placed on.
            resize: target size of the longer image side (None keeps size).
            df: if given, round the resized dims down to a multiple of df.
            padding: pad to a square of the longer side and return a validity mask.
            cam_K, dist: optional intrinsics + distortion; when both are given
                the image is undistorted first.
            gray_scale: convert BGR input to grayscale (the matcher expects it).

        Returns:
            (img_tensor, scale, mask, new_K, img_undistorted) where scale maps
            resized coords back to original coords, and new_K/img_undistorted
            are None unless undistortion was performed.
        """
        # the matcher takes grayscale input images
        if gray_scale and len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        h, w = img.shape[:2]
        new_K = None
        img_undistorted = None
        if cam_K is not None and dist is not None:
            new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w, h), 0, (w, h))
            img = cv2.undistort(img, cam_K, dist, None, new_K)
            img_undistorted = img.copy()
        if resize is not None:
            scale = resize / max(h, w)
            w_new, h_new = int(round(w * scale)), int(round(h * scale))
        else:
            w_new, h_new = w, h
        if df is not None:
            # round down to a multiple of df (network stride requirement)
            w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
        img = cv2.resize(img, (w_new, h_new))
        # np.float was removed in NumPy 1.20+; plain float is equivalent (float64)
        scale = np.array([w / w_new, h / h_new], dtype=float)
        if padding:  # pad to a square so both images share spatial dims
            pad_to = max(h_new, w_new)
            img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
            mask = torch.from_numpy(mask).to(device)
        else:
            mask = None
        if len(img.shape) == 2:  # grayscale image
            # use the requested device (hard-coded .cuda() broke CPU-only hosts)
            img = torch.from_numpy(img)[None][None].to(device).float() / 255.0
        else:  # color image
            img = torch.from_numpy(img).permute(2, 0, 1)[None].to(device).float() / 255.0
        return img, scale, mask, new_K, img_undistorted

    def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
        """Match two OpenCV images and return keypoints in original pixel coords.

        Returns a dict with 'matches' (Nx4), 'mkpts0'/'mkpts1' (Nx2), 'mconf',
        the input images, and — when intrinsics were supplied — the new camera
        matrices and undistorted images.
        """
        img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
            img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
        img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
            img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
        mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
        # rescale keypoints from resized image coords back to original coords
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1
        matches = np.concatenate([mkpts0, mkpts1], axis=1)
        data = {'matches': matches,
                'mkpts0': mkpts0,
                'mkpts1': mkpts1,
                'mconf': mconf,
                'img0': img0,
                'img1': img1
                }
        if K0 is not None and dist0 is not None:
            data.update({'new_K0': new_K0, 'img0_undistorted': img0_undistorted})
        if K1 is not None and dist1 is not None:
            data.update({'new_K1': new_K1, 'img1_undistorted': img1_undistorted})
        return data

    def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
        """Load two images from disk and match them (see from_cv_imgs).

        Raises:
            FileNotFoundError: if either image cannot be read (cv2.imread
                returns None silently; failing early gives a clear message).
        """
        imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE
        img0 = cv2.imread(img0_pth, imread_flag)
        img1 = cv2.imread(img1_pth, imread_flag)
        if img0 is None:
            raise FileNotFoundError(f"Could not read image: {img0_pth}")
        if img1 is None:
            raise FileNotFoundError(f"Could not read image: {img1_pth}")
        return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)

    def match_images(self, image0, image1, mask0, mask1):
        """Run the model on a pair of image tensors; returns (mkpts0, mkpts1, mconf) as numpy arrays."""
        batch = {'image0': image0, 'image1': image1}
        if mask0 is not None:  # img_padding is True
            if self.coarse_scale:
                # downsample validity masks to the model's coarse resolution
                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
                                                       scale_factor=self.coarse_scale,
                                                       mode='nearest',
                                                       recompute_scale_factor=False)[0].bool()
                batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)})
        # the model writes its outputs into `batch` in place
        self.model(batch)
        mkpts0 = batch['mkpts0_f'].cpu().numpy()
        mkpts1 = batch['mkpts1_f'].cpu().numpy()
        mconf = batch['mconf_f'].cpu().numpy()
        return mkpts0, mkpts1, mconf

    def pad_bottom_right(self, inp, pad_size, ret_mask=False):
        """Zero-pad a 2D (HxW) or 3D (CxHxW) array to a pad_size x pad_size square.

        Returns (padded, mask) where mask marks valid (original) pixels, or
        (padded, None) when ret_mask is False.
        """
        assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
        mask = None
        if inp.ndim == 2:
            padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
            padded[:inp.shape[0], :inp.shape[1]] = inp
            if ret_mask:
                mask = np.zeros((pad_size, pad_size), dtype=bool)
                mask[:inp.shape[0], :inp.shape[1]] = True
        elif inp.ndim == 3:
            padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
            padded[:, :inp.shape[1], :inp.shape[2]] = inp
            if ret_mask:
                mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
                mask[:, :inp.shape[1], :inp.shape[2]] = True
        else:
            raise NotImplementedError()
        return padded, mask