File size: 6,362 Bytes
a930e1f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import torch
from torch import nn
import numpy as np
import cv2
# import torchvision.transforms as transforms
import torch.nn.functional as F
from yacs.config import CfgNode as CN

def lower_config(yacs_cfg):
    """Recursively convert a yacs ``CfgNode`` into a plain dict with lower-cased keys.

    Any value that is not a ``CfgNode`` is returned unchanged, which is what
    terminates the recursion at leaf values.
    """
    if isinstance(yacs_cfg, CN):
        lowered = {}
        for key, value in yacs_cfg.items():
            lowered_key = key.lower()
            lowered[lowered_key] = lower_config(value)
        return lowered
    return yacs_cfg


def upper_config(dict_cfg):
    """Recursively copy a nested ``dict`` with every key upper-cased.

    Any value that is not a ``dict`` is returned unchanged, which is what
    terminates the recursion at leaf values.
    """
    if isinstance(dict_cfg, dict):
        uppered = {}
        for key, value in dict_cfg.items():
            uppered_key = key.upper()
            uppered[uppered_key] = upper_config(value)
        return uppered
    return dict_cfg


class DataIOWrapper(nn.Module):
    """
    Pre-process data from different sources and run the wrapped matcher on it.

    The wrapper takes OpenCV images (or paths to them), optionally undistorts,
    resizes and pads them, feeds them to ``model`` and returns the matched
    keypoints rescaled to the coordinates of the (undistorted) input images.

    Config keys read by the constructor: ``img0_resize``, ``img1_resize``,
    ``df``, ``padding`` and ``coarse_scale``.
    """

    def __init__(self, model, config, ckpt=None):
        """
        Args:
            model: matcher that is called as ``model(batch)`` and writes
                ``mkpts0_f``, ``mkpts1_f`` and ``mconf_f`` into ``batch``.
            config: mapping providing the keys listed in the class docstring.
            ckpt: optional checkpoint path. When given, its ``'state_dict'``
                entry is loaded into ``model`` and the model is switched to
                eval mode and moved to ``self.device``. Without a checkpoint
                the model is left exactly as passed in.
        """
        super().__init__()

        self.device = torch.device('cuda:{}'.format(0) if torch.cuda.is_available() else 'cpu')
        # NOTE: this disables autograd process-wide, not only for this wrapper.
        torch.set_grad_enabled(False)
        self.model = model
        self.config = config
        self.img0_size = config['img0_resize']
        self.img1_size = config['img1_resize']
        self.df = config['df']
        self.padding = config['padding']
        self.coarse_scale = config['coarse_scale']

        if ckpt:
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources.
            # map_location lets a checkpoint saved on GPU load on a CPU-only host.
            ckpt_dict = torch.load(ckpt, map_location=self.device)
            self.model.load_state_dict(ckpt_dict['state_dict'])
            self.model = self.model.eval().to(self.device)

    def preprocess_image(self, img, device, resize=None, df=None, padding=None, cam_K=None, dist=None, gray_scale=True):
        """
        Turn an OpenCV image into a normalised tensor ready for the model.

        Args:
            img: image array as produced by OpenCV, (H, W) or (H, W, C).
            device: torch device the image tensor and mask are moved to.
            resize: if not None, the longer image side is scaled to this length.
            df: if not None, the resized width/height are rounded down to a
                multiple of this value.
            padding: if truthy, the image is zero-padded at the bottom/right
                to a square and a validity mask is returned.
            cam_K, dist: camera matrix and distortion coefficients; the image
                is undistorted only when both are given.
            gray_scale: convert 3-dimensional input to grayscale first.

        Returns:
            Tuple ``(img, scale, mask, new_K, img_undistorted)``:
            ``img`` is a float tensor of shape (1, C, H, W) in [0, 1];
            ``scale`` is ``[w / w_new, h / h_new]`` as a float64 array;
            ``mask`` is a bool tensor (True on real pixels) or None without
            padding; ``new_K`` / ``img_undistorted`` are None unless the
            image was undistorted.
        """
        # xoftr takes grayscale input images
        if gray_scale and len(img.shape) == 3:
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        h, w = img.shape[:2]
        new_K = None
        img_undistorted = None
        if cam_K is not None and dist is not None:
            new_K, _roi = cv2.getOptimalNewCameraMatrix(cam_K, dist, (w, h), 0, (w, h))
            img = cv2.undistort(img, cam_K, dist, None, new_K)
            img_undistorted = img.copy()

        if resize is not None:
            scale = resize / max(h, w)
            w_new, h_new = int(round(w * scale)), int(round(h * scale))
        else:
            w_new, h_new = w, h

        if df is not None:
            w_new, h_new = (int(x // df * df) for x in (w_new, h_new))

        img = cv2.resize(img, (w_new, h_new))
        # np.float was removed in NumPy 1.24; float64 is what the alias meant.
        scale = np.array([w / w_new, h / h_new], dtype=np.float64)
        if padding:  # padding
            pad_to = max(h_new, w_new)
            if img.ndim == 3:
                # pad_bottom_right pads the last two axes, while OpenCV colour
                # images are (H, W, C): go through (C, H, W) and back.
                chw, mask = self.pad_bottom_right(
                    np.ascontiguousarray(img.transpose(2, 0, 1)), pad_to, ret_mask=True)
                img = np.ascontiguousarray(chw.transpose(1, 2, 0))
                mask = mask[0]  # the mask is identical for every channel
            else:
                img, mask = self.pad_bottom_right(img, pad_to, ret_mask=True)
            mask = torch.from_numpy(mask).to(device)
        else:
            mask = None
        # Honour the requested device instead of hard-coding CUDA.
        if len(img.shape) == 2:  # grayscale image
            img = torch.from_numpy(img)[None][None].to(device).float() / 255.0
        else:  # Color image
            img = torch.from_numpy(img).permute(2, 0, 1)[None].to(device).float() / 255.0
        return img, scale, mask, new_K, img_undistorted

    def from_cv_imgs(self, img0, img1, K0=None, K1=None, dist0=None, dist1=None):
        """
        Match two already-loaded OpenCV images.

        Returns:
            dict with ``matches`` (``mkpts0`` and ``mkpts1`` concatenated
            along axis 1), ``mkpts0``, ``mkpts1``, ``mconf``, and the input
            images ``img0`` / ``img1``. Keypoints are multiplied by the
            per-image resize scale, so they refer to the pre-resize image.
            ``new_K{i}`` and ``img{i}_undistorted`` are added for each image
            whose intrinsics and distortion were both supplied.
        """
        img0_tensor, scale0, mask0, new_K0, img0_undistorted = self.preprocess_image(
            img0, self.device, resize=self.img0_size, df=self.df, padding=self.padding, cam_K=K0, dist=dist0)
        img1_tensor, scale1, mask1, new_K1, img1_undistorted = self.preprocess_image(
            img1, self.device, resize=self.img1_size, df=self.df, padding=self.padding, cam_K=K1, dist=dist1)
        mkpts0, mkpts1, mconf = self.match_images(img0_tensor, img1_tensor, mask0, mask1)
        mkpts0 = mkpts0 * scale0
        mkpts1 = mkpts1 * scale1
        matches = np.concatenate([mkpts0, mkpts1], axis=1)
        data = {'matches': matches,
                'mkpts0': mkpts0,
                'mkpts1': mkpts1,
                'mconf': mconf,
                'img0': img0,
                'img1': img1
                }
        if K0 is not None and dist0 is not None:
            data.update({'new_K0': new_K0, 'img0_undistorted': img0_undistorted})
        if K1 is not None and dist1 is not None:
            data.update({'new_K1': new_K1, 'img1_undistorted': img1_undistorted})
        return data

    def from_paths(self, img0_pth, img1_pth, K0=None, K1=None, dist0=None, dist1=None, read_color=False):
        """
        Load two images from disk and match them via ``from_cv_imgs``.

        Args:
            read_color: read with ``cv2.IMREAD_COLOR`` instead of
                ``cv2.IMREAD_GRAYSCALE``.

        Raises:
            FileNotFoundError: if either image cannot be read.
        """
        imread_flag = cv2.IMREAD_COLOR if read_color else cv2.IMREAD_GRAYSCALE

        img0 = cv2.imread(img0_pth, imread_flag)
        img1 = cv2.imread(img1_pth, imread_flag)
        # cv2.imread signals failure by returning None rather than raising.
        for path, img in ((img0_pth, img0), (img1_pth, img1)):
            if img is None:
                raise FileNotFoundError(f"Could not read image: {path}")
        return self.from_cv_imgs(img0, img1, K0=K0, K1=K1, dist0=dist0, dist1=dist1)

    def _coarse_mask(self, mask):
        """Return ``mask`` with a leading batch axis, rescaled by ``coarse_scale`` if set."""
        if self.coarse_scale:
            mask = F.interpolate(mask[None, None].float(),
                                 scale_factor=self.coarse_scale,
                                 mode='nearest',
                                 recompute_scale_factor=False)[0, 0].bool()
        return mask.unsqueeze(0)

    def match_images(self, image0, image1, mask0, mask1):
        """
        Run the model on a pre-processed image pair.

        Args:
            image0, image1: image tensors as returned by ``preprocess_image``.
            mask0, mask1: padding masks, or None when no padding was applied.
                Each mask is rescaled on its own, so the two images may have
                different padded sizes. If ``coarse_scale`` is falsy the masks
                are passed at input resolution.
                NOTE(review): confirm the model accepts unscaled masks.

        Returns:
            ``(mkpts0, mkpts1, mconf)`` as NumPy arrays, read from the
            ``mkpts0_f``, ``mkpts1_f`` and ``mconf_f`` entries the model
            writes into the batch.
        """
        batch = {'image0': image0, 'image1': image1}
        if mask0 is not None:  # img_padding is True
            batch.update({'mask0': self._coarse_mask(mask0), 'mask1': self._coarse_mask(mask1)})
        self.model(batch)
        mkpts0 = batch['mkpts0_f'].cpu().numpy()
        mkpts1 = batch['mkpts1_f'].cpu().numpy()
        mconf = batch['mconf_f'].cpu().numpy()
        return mkpts0, mkpts1, mconf

    def pad_bottom_right(self, inp, pad_size, ret_mask=False):
        """
        Zero-pad the last two axes of ``inp`` to ``pad_size`` x ``pad_size``.

        Args:
            inp: 2-D array, or 3-D array whose leading axis is kept as is.
            pad_size: int target side length; must be at least as large as
                both of the last two dimensions of ``inp``.
            ret_mask: also build a bool mask of the padded shape that is True
                where original data was copied.

        Returns:
            ``(padded, mask)``; ``mask`` is None when ``ret_mask`` is False.

        Raises:
            NotImplementedError: if ``inp`` is neither 2-D nor 3-D.
        """
        assert isinstance(pad_size, int) and pad_size >= max(inp.shape[-2:]), f"{pad_size} < {max(inp.shape[-2:])}"
        mask = None
        if inp.ndim == 2:
            padded = np.zeros((pad_size, pad_size), dtype=inp.dtype)
            padded[:inp.shape[0], :inp.shape[1]] = inp
            if ret_mask:
                mask = np.zeros((pad_size, pad_size), dtype=bool)
                mask[:inp.shape[0], :inp.shape[1]] = True
        elif inp.ndim == 3:
            padded = np.zeros((inp.shape[0], pad_size, pad_size), dtype=inp.dtype)
            padded[:, :inp.shape[1], :inp.shape[2]] = inp
            if ret_mask:
                mask = np.zeros((inp.shape[0], pad_size, pad_size), dtype=bool)
                mask[:, :inp.shape[1], :inp.shape[2]] = True
        else:
            raise NotImplementedError()
        return padded, mask