# Hugging Face Spaces: Running on Zero (ZeroGPU)
import torch | |
from torch import nn | |
import numpy as np | |
import cv2 | |
# import torchvision.transforms as transforms | |
import torch.nn.functional as F | |
from yacs.config import CfgNode as CN | |
def lower_config(yacs_cfg):
    """Recursively convert a yacs ``CfgNode`` tree into a plain dict with lowercase keys.

    Non-``CfgNode`` values (leaves) are returned unchanged.
    """
    if isinstance(yacs_cfg, CN):
        return {key.lower(): lower_config(value) for key, value in yacs_cfg.items()}
    return yacs_cfg
def upper_config(dict_cfg):
    """Recursively convert a plain dict tree into one with uppercase keys.

    Non-dict values (leaves) are returned unchanged. Inverse of ``lower_config``.
    """
    if isinstance(dict_cfg, dict):
        return {key.upper(): upper_config(value) for key, value in dict_cfg.items()}
    return dict_cfg
class DataIOWrapper(nn.Module):
    """
    Pre-process data from different sources (file paths or OpenCV images) and
    feed it to a matching model.

    The wrapped ``model`` is called with a batch dict holding 'image0'/'image1'
    (and optional 'mask0'/'mask1') and is expected to write 'mkpts0_f',
    'mkpts1_f' and 'mconf_f' back into that dict.
    """

    def __init__(self, model, config, ckpt=None):
        """
        Args:
            model: matching network (see class docstring for its batch contract).
            config: dict with keys 'img0_resize', 'img1_resize', 'df',
                'padding', 'coarse_scale'.
            ckpt: optional path to a checkpoint; the loaded object must contain
                a 'state_dict' entry.
        """
        super().__init__()
        self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
        torch.set_grad_enabled(False)  # inference-only wrapper
        self.model = model
        self.config = config
        self.img0_size = config['img0_resize']
        self.img1_size = config['img1_resize']
        self.df = config['df']
        self.padding = config['padding']
        self.coarse_scale = config['coarse_scale']
        if ckpt:
            # map_location allows loading GPU-trained checkpoints on CPU-only hosts
            ckpt_dict = torch.load(ckpt, map_location='cpu')
            self.model.load_state_dict(ckpt_dict['state_dict'])
        # Always move to the selected device, whether or not a checkpoint was given
        self.model = self.model.eval().to(self.device)

    def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
        """Undistort, resize, pad and tensorize a single image.

        Args:
            img: HxW or HxWx3 uint8 image (BGR if color).
            device: torch device for the output tensor/mask.
            resize: target size for the longest side, or None to keep size.
            df: if given, round the resized dims down to a multiple of df.
            padding: if truthy, zero-pad to a square and return a validity mask.
            cam_K, dist: optional intrinsics + distortion; when both are given
                the image is undistorted first.
            gray_scale: convert color input to grayscale (the matcher takes
                grayscale input).

        Returns:
            (img_tensor, scale, mask, new_K, img_undistorted) where scale maps
            resized coords back to original coords, and new_K/img_undistorted
            are None unless undistortion was performed.
        """
        # xoftr takes grayscale input images
        if gray_scale and len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
        h, w = img.shape[:2]
        new_K = None
        img_undistorted = None
        if cam_K is not None and dist is not None:
            new_K, roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w, h), 0, (w, h))
            img = cv2.undistort(img, cam_K, dist, None, new_K)
            img_undistorted = img.copy()
        if resize is not None:
            scale = resize / max(h, w)
            w_new, h_new = int(round(w * scale)), int(round(h * scale))
        else:
            w_new, h_new = w, h
        if df is not None:
            # round dims down to a multiple of df (network stride requirement)
            w_new, h_new = map(lambda x: int(x // df * df), [w_new, h_new])
        img = cv2.resize(img, (w_new, h_new))
        # FIX: np.float was removed in NumPy 1.24; use explicit float64
        scale = np.array([w / w_new, h / h_new], dtype=np.float64)
        if padding:  # pad to a square so both images share spatial dims
            pad_to = max(h_new, w_new)
            img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
            mask = torch.from_numpy(mask).to(device)
        else:
            mask = None
        if len(img.shape) == 2:  # grayscale image
            # FIX: use the requested device instead of hard-coded .cuda()
            img = torch.from_numpy(img)[None][None].to(device).float() / 255.0
        else:  # color image (HWC -> 1CHW)
            img = torch.from_numpy(img).permute(2, 0, 1)[None].to(device).float() / 255.0
        return img, scale, mask, new_K, img_undistorted

    def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
        """Match two OpenCV images; returns matches in ORIGINAL image coords.

        Returns a dict with 'matches' (Nx4), 'mkpts0'/'mkpts1' (Nx2), 'mconf',
        the input images, and (when intrinsics were given) the undistorted
        images plus their new camera matrices.
        """
        img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
            img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
        img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
            img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
        mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
        # rescale keypoints from network resolution back to original pixels
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1
        matches = np.concatenate([mkpts0, mkpts1], axis=1)
        data = {'matches': matches,
                'mkpts0': mkpts0,
                'mkpts1': mkpts1,
                'mconf': mconf,
                'img0': img0,
                'img1': img1
                }
        if K0 is not None and dist0 is not None:
            data.update({'new_K0': new_K0, 'img0_undistorted': img0_undistorted})
        if K1 is not None and dist1 is not None:
            data.update({'new_K1': new_K1, 'img1_undistorted': img1_undistorted})
        return data

    def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
        """Load two images from disk and match them (see ``from_cv_imgs``)."""
        imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE
        img0 = cv2.imread(img0_pth, imread_flag)
        img1 = cv2.imread(img1_pth, imread_flag)
        return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)

    def match_images(self, image0, image1, mask0, mask1):
        """Run the model on a pair of image tensors; returns (mkpts0, mkpts1, mconf) as numpy arrays."""
        batch = {'image0': image0, 'image1': image1}
        if mask0 is not None:  # img_padding is True
            if self.coarse_scale:
                # downsample padding masks to the coarse feature resolution
                [ts_mask_0, ts_mask_1] = F.interpolate(torch.stack([mask0, mask1], dim=0)[None].float(),
                                                       scale_factor=self.coarse_scale,
                                                       mode='nearest',
                                                       recompute_scale_factor=False)[0].bool()
                batch.update({'mask0': ts_mask_0.unsqueeze(0), 'mask1': ts_mask_1.unsqueeze(0)})
        self.model(batch)  # model writes results into `batch` in place
        mkpts0 = batch['mkpts0_f'].cpu().numpy()
        mkpts1 = batch['mkpts1_f'].cpu().numpy()
        mconf = batch['mconf_f'].cpu().numpy()
        return mkpts0, mkpts1, mconf

    def pad_bottom_right(self, inp, pad_size, ret_mask=False):
        """Zero-pad a 2-D (HxW) or 3-D (CxHxW) array to pad_size x pad_size at the bottom-right.

        Returns (padded, mask) where mask marks valid (original) pixels, or
        (padded, None) when ret_mask is False.

        Raises:
            NotImplementedError: for arrays that are not 2-D or 3-D.
        """
        assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
        mask = None
        if inp.ndim == 2:
            padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
            padded[:inp.shape[0], :inp.shape[1]] = inp
            if ret_mask:
                mask = np.zeros((pad_size, pad_size), dtype=bool)
                mask[:inp.shape[0], :inp.shape[1]] = True
        elif inp.ndim == 3:
            padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
            padded[:, :inp.shape[1], :inp.shape[2]] = inp
            if ret_mask:
                mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
                mask[:, :inp.shape[1], :inp.shape[2]] = True
        else:
            raise NotImplementedError()
        return padded, mask