Vincentqyw committed on
Commit
2c8b554
·
1 Parent(s): 44ae162

add: rord libs

Browse files
.gitignore CHANGED
@@ -1,8 +1,6 @@
1
  build/
2
-
3
- lib/
4
  bin/
5
-
6
  cmake_modules/
7
  cmake-build-debug/
8
  .idea/
 
1
  build/
2
+ # lib
 
3
  bin/
 
4
  cmake_modules/
5
  cmake-build-debug/
6
  .idea/
third_party/RoRD/lib/__init__.py ADDED
File without changes
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_combined.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ import random
5
+
6
+ import h5py
7
+ import numpy as np
8
+ from PIL import Image
9
+ from tqdm import tqdm
10
+ import joblib
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+ from torch.utils.data import DataLoader
15
+
16
+ from lib.utils import preprocess_image
17
+ from lib.utils import preprocess_image, grid_positions, upscale_positions
18
+ from lib.dataloaders.datasetPhotoTourism_ipr import PhotoTourismIPR
19
+ from lib.dataloaders.datasetPhotoTourism_real import PhotoTourism
20
+
21
+ from sys import exit, argv
22
+ import cv2
23
+ import csv
24
+
25
+ np.random.seed(0)
26
+
27
+
28
class PhotoTourismCombined(Dataset):
    """Mixture of the real PhotoTourism pairs and the synthetic IPR pairs.

    Every item is a ``(sample, flag)`` tuple. ``flag`` is 1 when the sample
    was drawn from ``PhotoTourismIPR`` and 0 when it was drawn from
    ``PhotoTourism``. The source is picked at random on each access, with
    probability ``ipr_pref`` of choosing the IPR dataset.

    Args:
        base_path: Dataset root, forwarded to both wrapped datasets.
        preprocessing: Preprocessing mode, forwarded to both wrapped datasets.
        ipr_pref: Probability in [0, 1] of serving an IPR sample.
        train: Forwarded to both wrapped datasets.
        cropSize: Crop size, forwarded as ``image_size`` / ``cropSize``.
    """

    def __init__(self, base_path, preprocessing, ipr_pref=0.5, train=True, cropSize=256):
        self.base_path = base_path
        self.preprocessing = preprocessing
        self.cropSize = cropSize

        self.ipr_pref = ipr_pref

        print("[INFO] Building Original Dataset")
        self.PTReal = PhotoTourism(base_path, preprocessing=preprocessing, train=train, image_size=cropSize)
        self.PTReal.build_dataset()
        # Length of the real dataset at construction time (used by __len__).
        self.dataset_len1 = len(self.PTReal)

        print("[INFO] Building IPR Dataset")
        self.PTipr = PhotoTourismIPR(base_path, preprocessing=preprocessing, train=train, cropSize=cropSize)
        self.PTipr.build_dataset()
        # Length of the IPR dataset at construction time (used by __len__).
        self.dataset_len2 = len(self.PTipr)

    def __getitem__(self, idx):
        """Return ``(sample, 1)`` from IPR or ``(sample, 0)`` from the real set.

        ``idx`` ranges over the combined length, so it is wrapped with the
        length of the dataset that is actually indexed. The live length is
        used because both wrapped datasets drop entries they fail to load,
        which makes the lengths cached in ``__init__`` go stale.
        """
        if random.random() < self.ipr_pref:
            return (self.PTipr[idx % len(self.PTipr)], 1)
        return (self.PTReal[idx % len(self.PTReal)], 0)

    def __len__(self):
        """Return the summed size of both datasets as measured in ``__init__``."""
        return self.dataset_len2 + self.dataset_len1
71
+
72
+
73
if __name__ == "__main__":
    # Smoke test: build the combined dataset and iterate over it once.
    # The crop size must be passed by keyword: the third positional
    # parameter is ``ipr_pref``, so a positional 256 would force every
    # sample to come from the IPR dataset and leave cropSize at its default.
    pt = PhotoTourismCombined("/scratch/udit/phototourism/", 'caffe', cropSize=256)
    dl = DataLoader(pt, batch_size=1, num_workers=2)
    for _ in dl:
        pass
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_ipr.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from sys import exit, argv
3
+ import csv
4
+ import random
5
+
6
+ import joblib
7
+ import numpy as np
8
+ import cv2
9
+ from PIL import Image
10
+ from tqdm import tqdm
11
+
12
+ import torch
13
+ from torch.utils.data import Dataset
14
+
15
+ from lib.utils import preprocess_image, grid_positions, upscale_positions
16
+
17
+ np.random.seed(0)
18
+
19
+
20
class PhotoTourismIPR(Dataset):
    """Image pairs related by a random synthetic homography.

    Each sample is a random square crop of a PhotoTourism image together
    with a warped copy of that crop, the homography used for the warp and a
    grid of point correspondences between the two.

    Args:
        base_path: Dataset root; images are read from
            ``<base_path>/<scene>/dense/images``.
        preprocessing: Mode forwarded to ``preprocess_image``.
        train: Stored on the instance; only affects logging in
            ``getImageFiles``.
        cropSize: Side length, in pixels, of the square crop.
    """

    def __init__(self, base_path, preprocessing, train=True, cropSize=256):
        self.base_path = base_path
        self.train = train
        self.preprocessing = preprocessing
        # Paths of images at least cropSize x cropSize; filled by build_dataset().
        self.valid_images = []
        self.cropSize = cropSize

    def getImageFiles(self):
        """Return the paths of every file in the image folder of each listed scene."""
        img_files = []
        img_path = "dense/images"
        if self.train:
            print("Inside training!!")

        # NOTE(review): the scene list is read for both train and validation so
        # that `scenes` is always bound; no separate validation list is visible
        # here -- confirm whether one should be used when train is False.
        with open(os.path.join("configs", "train_scenes_small.txt")) as f:
            scenes = f.read().strip("\n").split("\n")

        print("[INFO]",scenes)
        for scene in scenes:
            image_dir = os.path.join(self.base_path, scene, img_path)
            img_names = os.listdir(image_dir)
            img_files += [os.path.join(image_dir, img) for img in img_names]
        return img_files

    def imgCrop(self, img1):
        """Return a random ``cropSize`` x ``cropSize`` crop of a PIL image."""
        w, h = img1.size
        # randint's upper bound is exclusive, so +1 keeps the last valid offset
        # reachable and makes an image of exactly cropSize pixels legal
        # (build_dataset admits such images).
        left = np.random.randint(low=0, high=w - self.cropSize + 1)
        upper = np.random.randint(low=0, high=h - self.cropSize + 1)

        return img1.crop((left, upper, left + self.cropSize, upper + self.cropSize))

    def getGrid(self, im1, im2, H, scaling_steps=3):
        """Map a regular grid of ``im1`` positions through ``H``.

        Args:
            im1: Source image array; only its first two dimensions are used.
            im2: Warped image (unused, kept for interface compatibility).
            H: 3x3 homography applied to (x, y, 1) column vectors.
            scaling_steps: The grid is built at 1 / 2**scaling_steps resolution
                and upscaled back with ``upscale_positions``.

        Returns:
            ``(pos1, pos2)``: float32 arrays of shape (2, N), in the same row
            order produced by ``upscale_positions``, restricted to points whose
            mapped position lies more than 2 pixels inside ``im1``'s bounds.
        """
        h1, w1 = int(im1.shape[0]/(2**scaling_steps)), int(im1.shape[1]/(2**scaling_steps))
        device = torch.device("cpu")

        fmap_pos1 = grid_positions(h1, w1, device)
        pos1 = upscale_positions(fmap_pos1, scaling_steps=scaling_steps).data.cpu().numpy()

        # Swap the two coordinate rows before applying H, then swap back below.
        pos1[[0, 1]] = pos1[[1, 0]]

        # imgRotH returns an np.matrix; use a plain ndarray so the results are
        # ordinary arrays rather than matrices.
        H = np.asarray(H)
        ones = np.ones((1, pos1.shape[1]))
        pos1Homo = np.vstack((pos1, ones))
        pos2Homo = np.dot(H, pos1Homo)
        pos2Homo = pos2Homo/pos2Homo[2, :]
        pos2 = pos2Homo[0:2, :]

        pos1[[0, 1]] = pos1[[1, 0]]
        pos2[[0, 1]] = pos2[[1, 0]]
        pos1 = pos1.astype(np.float32)
        pos2 = pos2.astype(np.float32)

        # Keep only correspondences that land inside the image with a 2px margin.
        inside = (
            (pos2[0] > 2) & (pos2[0] < im1.shape[0] - 2)
            & (pos2[1] > 2) & (pos2[1] < im1.shape[1] - 2)
        )
        return pos1[:, inside], pos2[:, inside]

    def imgRotH(self, img1, min=0, max=360):
        """Sample a random homography that rotates about the image centre.

        The result is the product of a rotation about the centre by an integer
        number of degrees in ``[min, max)``, a small random shear and a small
        random projective term.

        Returns:
            ``(H, theta)``: the 3x3 ``np.matrix`` and the angle in radians.
        """
        width, height = img1.size
        theta = np.random.randint(low=min, high=max) * (np.pi / 180)
        Tx = width / 2
        Ty = height / 2
        sx = random.uniform(-1e-2, 1e-2)
        sy = random.uniform(-1e-2, 1e-2)
        p1 = random.uniform(-1e-4, 1e-4)
        p2 = random.uniform(-1e-4, 1e-4)

        alpha = np.cos(theta)
        beta = np.sin(theta)

        He = np.matrix([[alpha, beta, Tx * (1 - alpha) - Ty * beta], [-beta, alpha, beta * Tx + (1 - alpha) * Ty], [0, 0, 1]])
        Ha = np.matrix([[1, sy, 0], [sx, 1, 0], [0, 0, 1]])
        Hp = np.matrix([[1, 0, 0], [0, 1, 0], [p1, p2, 1]])

        H = He @ Ha @ Hp

        return H, theta

    def build_dataset(self):
        """Record every image that is at least ``cropSize`` in both dimensions."""
        print("Building Dataset.")

        for img in tqdm(self.getImageFiles()):
            # Only the size is needed; the context manager closes the file handle.
            with Image.open(img) as img1:
                width, height = img1.size
            if width < self.cropSize or height < self.cropSize:
                continue

            self.valid_images.append(img)

    def __len__(self):
        return len(self.valid_images)

    def __getitem__(self, idx):
        """Return a crop, its warped copy, the homography and correspondences.

        An image that cannot be turned into a sample is removed from
        ``valid_images`` and the entry that slides into position ``idx`` is
        tried instead.

        Raises:
            IndexError: If ``idx`` is out of range (also once every candidate
                at ``idx`` has been dropped).
        """
        while True:
            try:
                img = self.valid_images[idx]
            except IndexError:
                # Follow the sequence protocol instead of exiting the process.
                print("IndexError")
                raise

            try:
                with Image.open(img) as pil_img:
                    # Force three channels so the sample layout does not
                    # depend on the source file's mode.
                    img1 = self.imgCrop(pil_img.convert('RGB'))
                width, height = img1.size

                H, theta = self.imgRotH(img1, min=0, max=360)

                img1 = np.array(img1)
                img2 = cv2.warpPerspective(img1, H, dsize=(width, height))
                img2 = np.array(img2)

                pos1, pos2 = self.getGrid(img1, img2, H)

                # pos arrays are (2, N); an empty sample has N == 0.
                if pos1.shape[1] == 0 or pos2.shape[1] == 0:
                    raise ValueError("no valid correspondences")
                break
            except Exception:
                # Best effort: drop the offending image and retry at this index.
                del self.valid_images[idx]

        img1 = preprocess_image(img1, preprocessing=self.preprocessing)
        img2 = preprocess_image(img2, preprocessing=self.preprocessing)

        return {
            'image1': torch.from_numpy(img1.astype(np.float32)),
            'image2': torch.from_numpy(img2.astype(np.float32)),
            'pos1': torch.from_numpy(pos1.astype(np.float32)),
            'pos2': torch.from_numpy(pos2.astype(np.float32)),
            'H': np.array(H),
            'theta': np.array([theta])
        }
161
+
162
+
163
if __name__ == '__main__':
    # Smoke test: index the dataset rooted at argv[1] and print the shapes of
    # the first sample followed by the dataset length.
    training_dataset = PhotoTourismIPR(argv[1], 'caffe')
    training_dataset.build_dataset()

    sample = training_dataset[0]
    shapes = [sample[key].shape for key in ('image1', 'image2', 'pos1', 'pos2')]
    print(*shapes, len(training_dataset))
third_party/RoRD/lib/dataloaders/datasetPhotoTourism_real.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import os
3
+ import time
4
+ from tqdm import tqdm
5
+
6
+ import h5py
7
+ import numpy as np
8
+ from PIL import Image
9
+
10
+ import torch
11
+ from torch.utils.data import Dataset
12
+ from lib.utils import preprocess_image
13
+
14
+ import joblib
15
+
16
+
17
class PhotoTourism(Dataset):
    """Pairs of overlapping PhotoTourism images with depth, intrinsics and pose.

    Args:
        base_path: Dataset root. It holds the scene list file and one folder
            per scene containing ``<scene>.npz``.
        train: Selects ``train_scenes.txt.bkp`` (True) or ``valid_scenes.txt``.
        preprocessing: Mode forwarded to ``preprocess_image``.
        min_overlap_ratio: Lower bound on the pair overlap (inclusive).
        max_overlap_ratio: Upper bound on the pair overlap (inclusive).
        max_scale_ratio: Upper bound on the pair scale ratio (inclusive).
        pairs_per_scene: Number of pairs sampled (with replacement) per scene.
        image_size: Side length of the square crops.
    """

    def __init__(
            self,
            base_path='/scratch/udit/phototourism',
            train=True,
            preprocessing=None,
            min_overlap_ratio=.5,
            max_overlap_ratio=1,
            max_scale_ratio=np.inf,
            pairs_per_scene=500,
            image_size=256
    ):
        if train:
            scene_list_path = os.path.join(base_path, "train_scenes.txt.bkp")
        else:
            scene_list_path = os.path.join(base_path, "valid_scenes.txt")
        with open(scene_list_path, 'r') as f:
            self.scenes = [line.strip('\n') for line in f.readlines()]

        self.base_path = base_path

        self.train = train

        self.preprocessing = preprocessing

        self.min_overlap_ratio = min_overlap_ratio
        self.max_overlap_ratio = max_overlap_ratio
        self.max_scale_ratio = max_scale_ratio

        self.pairs_per_scene = pairs_per_scene

        self.image_size = image_size

        # List of per-pair metadata dicts; filled by build_dataset().
        self.dataset = []

    def build_dataset(self):
        """Sample image pairs for every scene, or load them from the cache.

        The sampled list is shuffled and written to ``orig_PT_2.gz`` under
        ``base_path``; when that file already exists it is loaded instead.
        For validation the NumPy RNG is seeded with 42 for the duration of the
        build and its previous state is restored afterwards.
        """
        # NOTE(review): the same cache file is used for train and validation,
        # and joblib.load unpickles it -- only use caches from a trusted source.
        cache_path = os.path.join(self.base_path, "orig_PT_2.gz")
        if os.path.exists(cache_path):
            self.dataset = joblib.load(cache_path)
            return

        self.dataset = []
        saved_state = None
        if not self.train:
            saved_state = np.random.get_state()
            np.random.seed(42)
            print('Building the validation dataset...')
        else:
            print('Building a new training dataset...')

        try:
            for scene in tqdm(self.scenes, total=len(self.scenes)):
                self._add_scene_pairs(scene)
            np.random.shuffle(self.dataset)
            joblib.dump(self.dataset, cache_path, 3)
        finally:
            # Restore the caller's RNG state even if the build fails part-way.
            if saved_state is not None:
                np.random.set_state(saved_state)

    def _add_scene_pairs(self, scene):
        """Append up to ``pairs_per_scene`` sampled pairs of ``scene`` to the dataset."""
        scene_info_path = os.path.join(
            self.base_path, scene, '%s.npz' % scene
        )
        if not os.path.exists(scene_info_path):
            return

        # NOTE(review): allow_pickle=True executes pickled content; the scene
        # files must come from a trusted source.
        scene_info = np.load(scene_info_path, allow_pickle=True)
        overlap_matrix = scene_info['overlap_matrix']
        scale_ratio_matrix = scene_info['scale_ratio_matrix']

        valid = np.logical_and(
            np.logical_and(
                overlap_matrix >= self.min_overlap_ratio,
                overlap_matrix <= self.max_overlap_ratio
            ),
            scale_ratio_matrix <= self.max_scale_ratio
        )

        pairs = np.vstack(np.where(valid))
        if pairs.shape[1] == 0:
            # No admissible pair in this scene: skip it, keep the other scenes.
            return
        selected_ids = np.random.choice(pairs.shape[1], self.pairs_per_scene)

        image_paths = scene_info['image_paths']
        depth_paths = scene_info['depth_paths']
        points3D_id_to_2D = scene_info['points3D_id_to_2D']
        points3D_id_to_ndepth = scene_info['points3D_id_to_ndepth']
        intrinsics = scene_info['intrinsics']
        poses = scene_info['poses']

        for pair_idx in selected_ids:
            idx1 = pairs[0, pair_idx]
            idx2 = pairs[1, pair_idx]
            # 3D points observed in both images.
            matches = np.array(list(
                points3D_id_to_2D[idx1].keys() &
                points3D_id_to_2D[idx2].keys()
            ))

            # Scale filtering
            matches_nd1 = np.array([points3D_id_to_ndepth[idx1][match] for match in matches])
            matches_nd2 = np.array([points3D_id_to_ndepth[idx2][match] for match in matches])
            scale_ratio = np.maximum(matches_nd1 / matches_nd2, matches_nd2 / matches_nd1)
            matches = matches[np.where(scale_ratio <= self.max_scale_ratio)[0]]
            if matches.size == 0:
                # np.random.choice cannot draw from an empty set.
                continue

            point3D_id = np.random.choice(matches)
            point2D1 = points3D_id_to_2D[idx1][point3D_id]
            point2D2 = points3D_id_to_2D[idx2][point3D_id]
            nd1 = points3D_id_to_ndepth[idx1][point3D_id]
            nd2 = points3D_id_to_ndepth[idx2][point3D_id]
            # The two stored coordinates are swapped here; crop() then treats
            # entries 0/2 as the first array axis and 1/3 as the second.
            central_match = np.array([
                point2D1[1], point2D1[0],
                point2D2[1], point2D2[0]
            ])
            self.dataset.append({
                'image_path1': image_paths[idx1],
                'depth_path1': depth_paths[idx1],
                'intrinsics1': intrinsics[idx1],
                'pose1': poses[idx1],
                'image_path2': image_paths[idx2],
                'depth_path2': depth_paths[idx2],
                'intrinsics2': intrinsics[idx2],
                'pose2': poses[idx2],
                'central_match': central_match,
                'scale_ratio': max(nd1 / nd2, nd2 / nd1)
            })

    def __len__(self):
        return len(self.dataset)

    def _load_view(self, pair_metadata, view):
        """Load image, depth, intrinsics and pose of view ``'1'`` or ``'2'``.

        Raises:
            ValueError: If the depth map has a negative (or NaN) minimum or
                its height/width differ from the image's.
        """
        depth_path = os.path.join(
            self.base_path, pair_metadata['depth_path' + view]
        )
        with h5py.File(depth_path, 'r') as hdf5_file:
            depth = np.array(hdf5_file['/depth'])
        if not np.min(depth) >= 0:
            raise ValueError('invalid depth values in %s' % depth_path)

        image_path = os.path.join(
            self.base_path, pair_metadata['image_path' + view]
        )
        with Image.open(image_path) as pil_image:
            image = np.array(pil_image.convert('RGB'))
        if image.shape[0] != depth.shape[0] or image.shape[1] != depth.shape[1]:
            raise ValueError('image/depth size mismatch for %s' % image_path)

        return image, depth, pair_metadata['intrinsics' + view], pair_metadata['pose' + view]

    def recover_pair(self, pair_metadata):
        """Load both views of a pair and crop them around the central match.

        Returns:
            ``(image1, depth1, intrinsics1, pose1, bbox1,
            image2, depth2, intrinsics2, pose2, bbox2)`` where each ``bbox`` is
            the top-left corner of the crop as returned by ``crop``.
        """
        image1, depth1, intrinsics1, pose1 = self._load_view(pair_metadata, '1')
        image2, depth2, intrinsics2, pose2 = self._load_view(pair_metadata, '2')

        central_match = pair_metadata['central_match']
        image1, bbox1, image2, bbox2 = self.crop(image1, image2, central_match)

        depth1 = depth1[
            bbox1[0] : bbox1[0] + self.image_size,
            bbox1[1] : bbox1[1] + self.image_size
        ]
        depth2 = depth2[
            bbox2[0] : bbox2[0] + self.image_size,
            bbox2[1] : bbox2[1] + self.image_size
        ]

        return (
            image1, depth1, intrinsics1, pose1, bbox1,
            image2, depth2, intrinsics2, pose2, bbox2
        )

    def _crop_origin(self, center, extent):
        """Return the crop start along one axis, centred and clamped to the image."""
        start = max(int(center) - self.image_size // 2, 0)
        if start + self.image_size >= extent:
            start = extent - self.image_size
        return start

    def crop(self, image1, image2, central_match):
        """Cut an ``image_size`` square around the central match from each image.

        Returns:
            ``(crop1, bbox1, crop2, bbox2)`` where each ``bbox`` is a length-2
            array holding the crop's start index along the first and second
            array axes.
        """
        bbox1_i = self._crop_origin(central_match[0], image1.shape[0])
        bbox1_j = self._crop_origin(central_match[1], image1.shape[1])

        bbox2_i = self._crop_origin(central_match[2], image2.shape[0])
        bbox2_j = self._crop_origin(central_match[3], image2.shape[1])

        return (
            image1[
                bbox1_i : bbox1_i + self.image_size,
                bbox1_j : bbox1_j + self.image_size
            ],
            np.array([bbox1_i, bbox1_j]),
            image2[
                bbox2_i : bbox2_i + self.image_size,
                bbox2_j : bbox2_j + self.image_size
            ],
            np.array([bbox2_i, bbox2_j])
        )

    def __getitem__(self, idx):
        """Return the tensors of pair ``idx``.

        A pair that fails to load or validate is removed from the dataset and
        the pair that slides into its position is tried instead. An index past
        the end is clamped to the last entry.

        Raises:
            IndexError: If the dataset is (or has become) empty.
        """
        while True:
            n_pairs = len(self.dataset)
            if n_pairs == 0:
                # Without this guard the retry loop would never terminate.
                raise IndexError('PhotoTourism dataset is empty')
            if idx >= n_pairs:
                idx = n_pairs - 1
            pair_metadata = self.dataset[idx]

            try:
                (
                    image1, depth1, intrinsics1, pose1, bbox1,
                    image2, depth2, intrinsics2, pose2, bbox2
                ) = self.recover_pair(pair_metadata)
                image1 = preprocess_image(image1, preprocessing=self.preprocessing)
                image2 = preprocess_image(image2, preprocessing=self.preprocessing)
                if image1.shape != image2.shape:
                    raise ValueError('crops of a pair differ in shape')
                break
            except Exception:
                # Best effort: discard the broken pair and retry.
                del self.dataset[idx]

        return {
            'image1': torch.from_numpy(image1.astype(np.float32)),
            'depth1': torch.from_numpy(depth1.astype(np.float32)),
            'intrinsics1': torch.from_numpy(intrinsics1.astype(np.float32)),
            'pose1': torch.from_numpy(pose1.astype(np.float32)),
            'bbox1': torch.from_numpy(bbox1.astype(np.float32)),
            'image2': torch.from_numpy(image2.astype(np.float32)),
            'depth2': torch.from_numpy(depth2.astype(np.float32)),
            'intrinsics2': torch.from_numpy(intrinsics2.astype(np.float32)),
            'pose2': torch.from_numpy(pose2.astype(np.float32)),
            'bbox2': torch.from_numpy(bbox2.astype(np.float32))
        }
third_party/RoRD/lib/exceptions.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
class EmptyTensorError(Exception):
    """Exception type named for the case of an unexpectedly empty tensor."""
3
+
4
+
5
class NoGradientError(Exception):
    """Exception type named for the case where no gradient is available."""
third_party/RoRD/lib/extractMatchTop.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import numpy as np
3
+ import imageio
4
+ import torch
5
+ from tqdm import tqdm
6
+ import time
7
+ import scipy
8
+ import scipy.io
9
+ import scipy.misc
10
+
11
+ from lib.model_test import D2Net
12
+ from lib.utils import preprocess_image
13
+ from lib.pyramid import process_multiscale
14
+
15
+ import cv2
16
+ import matplotlib.pyplot as plt
17
+ import os
18
+ from sys import exit, argv
19
+ from PIL import Image
20
+ from skimage.feature import match_descriptors
21
+ from skimage.measure import ransac
22
+ from skimage.transform import ProjectiveTransform, AffineTransform
23
+ import pydegensac
24
+
25
+
26
def extractSingle(image, model, device):
    """Run single-scale feature extraction on one preprocessed image.

    Args:
        image: Image tensor without a batch dimension; it is moved to
            ``device`` and given a leading batch axis.
        model: Network handed to ``process_multiscale``.
        device: Target device for the image tensor.

    Returns:
        Dict with ``'keypoints'`` (first two columns swapped relative to what
        ``process_multiscale`` returns), ``'scores'`` and ``'descriptors'``.
    """
    batch = image.to(device).unsqueeze(0)
    with torch.no_grad():
        raw_keypoints, scores, descriptors = process_multiscale(
            batch,
            model,
            scales=[1]
        )

    return {
        # Column order (0, 1, 2) -> (1, 0, 2).
        'keypoints': raw_keypoints[:, [1, 0, 2]],
        'scores': scores,
        'descriptors': descriptors,
    }
43
+
44
+
45
def siftMatching(img1, img2, HFile1, HFile2, device):
    """Match two images with SIFT, mutual nearest neighbours and RANSAC.

    Args:
        img1: Path of the first image.
        img2: Path of the second image.
        HFile1: Path of a ``.npy`` homography applied to the first image before
            matching, or None to match the image as is.
        HFile2: Same for the second image.
        device: Torch device used for descriptor matching.

    Returns:
        ``([], [])`` when no descriptors are found or fewer than 5 mutual
        matches exist. NOTE(review): this early return has two elements while
        the other returns have four; it is kept unchanged for compatibility.

        When ``HFile1`` is None: ``(src_pts, dst_pts, image3, image3)`` with
        (N, 2) inlier coordinates in the matched images.

        Otherwise: ``(orgSrc, orgDst, matchImg, image3)`` with (2, N) inlier
        coordinates mapped back through the inverse homographies and
        ``matchImg`` drawn on the images re-read from disk.
    """
    rgbFile1, rgbFile2 = img1, img2

    # Load each homography independently so that either one may be omitted.
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None

    with Image.open(rgbFile1) as pil_image:
        img1 = np.array(pil_image.convert('RGB'))
    if H1 is not None:
        img1 = cv2.warpPerspective(img1, H1, dsize=(400, 400))

    with Image.open(rgbFile2) as pil_image:
        img2 = np.array(pil_image.convert('RGB'))
    if H2 is not None:
        img2 = cv2.warpPerspective(img2, H2, dsize=(400, 400))

    # Newer OpenCV builds expose SIFT in the main module; older contrib builds
    # only provide it under xfeatures2d.
    sift_factory = getattr(cv2, 'SIFT_create', None)
    if sift_factory is None:
        sift_factory = cv2.xfeatures2d.SIFT_create
    sift = sift_factory()

    kp1, des1 = sift.detectAndCompute(img1, None)
    kp2, des2 = sift.detectAndCompute(img2, None)
    if des1 is None or des2 is None:
        # detectAndCompute yields None descriptors when no keypoint is found.
        return [], []

    matches = mnn_matcher(
        torch.from_numpy(des1).float().to(device=device),
        torch.from_numpy(des2).float().to(device=device)
    )

    src_pts = np.float32([kp1[m[0]].pt for m in matches]).reshape(-1, 2)
    dst_pts = np.float32([kp2[m[1]].pt for m in matches]).reshape(-1, 2)

    if src_pts.shape[0] < 5 or dst_pts.shape[0] < 5:
        return [], []

    _, inliers = pydegensac.findHomography(src_pts, dst_pts, 8.0, 0.99, 10000)
    src_pts = src_pts[inliers]
    dst_pts = dst_pts[inliers]

    inlier_keypoints_left = [cv2.KeyPoint(float(x), float(y), 1) for x, y in src_pts]
    inlier_keypoints_right = [cv2.KeyPoint(float(x), float(y), 1) for x, y in dst_pts]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(src_pts))]

    #### Visualization ####
    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None)
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    if HFile1 is None:
        return src_pts, dst_pts, image3, image3

    # An image that was not warped maps back through the identity.
    back_H2 = np.eye(3) if H2 is None else H2
    orgSrc, orgDst = orgKeypoints(src_pts, dst_pts, H1, back_H2)
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3
119
+
120
+
121
def orgKeypoints(src_pts, dst_pts, H1, H2):
    """Map matched points back through the inverses of two homographies.

    Args:
        src_pts: (N, 2) points in the first warped image.
        dst_pts: (N, 2) points in the second warped image (same N).
        H1: 3x3 homography that produced the first warped image.
        H2: 3x3 homography that produced the second warped image.

    Returns:
        ``(orgSrc, orgDst)``: two (2, N) arrays of de-homogenised coordinates.
    """
    unit_column = np.ones((src_pts.shape[0], 1))

    def _back_project(points, homography):
        # Lift to homogeneous coordinates, undo the warp, then normalise by
        # the third row.
        homogeneous = np.hstack((points, unit_column)).T
        restored = np.linalg.inv(homography) @ homogeneous
        restored = restored / restored[2, :]
        return np.asarray(restored)[0:2, :]

    return _back_project(src_pts, H1), _back_project(dst_pts, H2)
137
+
138
+
139
def drawOrg(image1, image2, orgSrc, orgDst):
    """Draw matched points on two images placed side by side.

    Args:
        image1: First image as read by ``cv2.imread``.
        image2: Second image as read by ``cv2.imread``.
        orgSrc: (2, N) point coordinates in ``image1`` (row 0 = x, row 1 = y).
        orgDst: (2, N) point coordinates in ``image2``.

    Returns:
        The horizontally concatenated image with a circle at every point and a
        line per correspondence. With zero matches the plain concatenation is
        returned.
    """
    img1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)
    img2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)

    # cv2.circle draws in place, so img1/img2 are used directly. Binding the
    # result only inside the loop left the name undefined for zero matches.
    for i in range(orgSrc.shape[1]):
        cv2.circle(img1, (int(orgSrc[0, i]), int(orgSrc[1, i])), 3, (0, 0, 255), 1)
    for i in range(orgDst.shape[1]):
        cv2.circle(img2, (int(orgDst[0, i]), int(orgDst[1, i])), 3, (0, 0, 255), 1)

    im4 = cv2.hconcat([img1, img2])
    x_offset = img1.shape[1]  # second image starts after the first one's width
    for i in range(orgSrc.shape[1]):
        im4 = cv2.line(im4, (int(orgSrc[0, i]), int(orgSrc[1, i])), (int(orgDst[0, i]) + x_offset, int(orgDst[1, i])), (0, 255, 0), 1)
    im4 = cv2.cvtColor(im4, cv2.COLOR_BGR2RGB)

    return im4
156
+
157
+
158
+
159
def getPerspKeypoints(rgbFile1, rgbFile2, HFile1, HFile2, model, device):
    """Match two images with learned features and RANSAC.

    Args:
        rgbFile1: Path of the first image.
        rgbFile2: Path of the second image.
        HFile1: Path of a ``.npy`` homography applied to the first image before
            extraction, or None.
        HFile2: Same for the second image.
        model: Network handed to ``extractSingle``.
        device: Torch device used for extraction and matching.

    Returns:
        When ``HFile1`` is None: ``(pos_a, pos_b, image3, image3)`` with (N, 2)
        inlier coordinates in the processed images.

        Otherwise: ``(orgSrc, orgDst, matchImg, image3)`` with (2, N) inlier
        coordinates mapped back through the inverse homographies.
    """
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None

    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    feat1 = extractSingle(igp1, model, device)
    feat2 = extractSingle(igp2, model, device)

    matches = mnn_matcher(
        torch.from_numpy(feat1['descriptors']).to(device=device),
        torch.from_numpy(feat2['descriptors']).to(device=device),
    )
    pos_a = feat1["keypoints"][matches[:, 0], : 2]
    pos_b = feat2["keypoints"][matches[:, 1], : 2]

    _, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    inlier_keypoints_left = [cv2.KeyPoint(float(point[0]), float(point[1]), 1) for point in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(float(point[0]), float(point[1]), 1) for point in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    if HFile1 is None:
        return pos_a, pos_b, image3, image3

    # A second image that was not warped maps back through the identity; using
    # an unbound H2 here previously raised NameError.
    back_H2 = np.eye(3) if H2 is None else H2
    orgSrc, orgDst = orgKeypoints(pos_a, pos_b, H1, back_H2)
    # Reproject matches to perspective View
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3
206
+
207
+
208
+
209
+ ###### Ensemble
210
def read_and_process_image(img_path, resize=None, H=None, h=None, w=None, preprocessing='caffe'):
    """Load an image and return its network input plus the raw RGB array.

    Args:
        img_path: Path of the image file.
        resize: Optional size passed to ``PIL.Image.resize``.
        H: Optional 3x3 homography; when given, the image is warped to a
            400x400 canvas.
        h: Unused.
        w: Unused.
        preprocessing: Mode forwarded to ``preprocess_image``.

    Returns:
        ``(igp1, img1)``: the preprocessed float32 tensor and the (possibly
        warped) RGB array it was computed from.
    """
    pil_image = Image.open(img_path)
    if resize:
        pil_image = pil_image.resize(resize)
    if pil_image.mode != 'RGB':
        pil_image = pil_image.convert('RGB')

    rgb = np.array(pil_image)
    if H is not None:
        rgb = cv2.warpPerspective(rgb, H, dsize=(400, 400))

    network_input = preprocess_image(rgb, preprocessing=preprocessing)
    return torch.from_numpy(network_input.astype(np.float32)), rgb
223
+
224
def mnn_matcher_scorer(descriptors_a, descriptors_b, k=np.inf):
    """Mutual-nearest-neighbour matching that also returns match scores.

    Args:
        descriptors_a: (Na, D) descriptor tensor.
        descriptors_b: (Nb, D) descriptor tensor.
        k: Unused.

    Returns:
        ``(matches, scores)``: an (M, 2) tensor of index pairs into a and b,
        and the (M,) dot-product similarity of each pair.
    """
    similarity = descriptors_a @ descriptors_b.t()
    best_score_ab, best_ab = similarity.max(dim=1)
    best_ba = similarity.max(dim=0)[1]

    index_a = torch.arange(0, similarity.shape[0], device=descriptors_a.device)
    # Keep a -> b matches whose b -> a nearest neighbour points back to a.
    mutual = index_a == best_ba[best_ab]

    pairs = torch.stack([index_a[mutual], best_ab[mutual]]).t()
    return pairs, best_score_ab[mutual]
234
+
235
def mnn_matcher(descriptors_a, descriptors_b):
    """Return mutual nearest neighbours as an (M, 2) NumPy index array.

    Args:
        descriptors_a: (Na, D) descriptor tensor.
        descriptors_b: (Nb, D) descriptor tensor.

    Returns:
        Array whose rows are ``(index_in_a, index_in_b)`` for every pair that
        is each other's best match under dot-product similarity.
    """
    similarity = descriptors_a @ descriptors_b.t()
    _, best_ab = similarity.max(dim=1)
    _, best_ba = similarity.max(dim=0)

    index_a = torch.arange(0, similarity.shape[0], device=descriptors_a.device)
    mutual = best_ba[best_ab] == index_a

    pairs = torch.stack([index_a[mutual], best_ab[mutual]]).t()
    return pairs.data.cpu().numpy()
244
+
245
+
246
def getPerspKeypointsEnsemble(model1, model2, rgbFile1, rgbFile2, HFile1, HFile2, device):
    """Match two images with an ensemble of two models.

    Both models extract features from both images. Each model's mutual nearest
    neighbours are pooled, the better-scoring half of the pooled matches is
    kept, and RANSAC selects the inliers.

    Args:
        model1: First network handed to ``process_multiscale``.
        model2: Second network handed to ``process_multiscale``.
        rgbFile1: Path of the first image.
        rgbFile2: Path of the second image.
        HFile1: Path of a ``.npy`` homography applied to the first image before
            extraction, or None.
        HFile2: Same for the second image.
        device: Torch device used for extraction and matching.

    Returns:
        ``(orgSrc, orgDst, matchImg, image3)``: (2, N) inlier coordinates
        mapped back through the inverse homographies (identity when a
        homography is omitted), the match drawing on the images re-read from
        disk, and the match drawing on the processed images.
    """
    H1 = np.load(HFile1) if HFile1 is not None else None
    H2 = np.load(HFile2) if HFile2 is not None else None

    igp1, img1 = read_and_process_image(rgbFile1, H=H1)
    igp2, img2 = read_and_process_image(rgbFile2, H=H2)

    def _extract(image, model):
        # Single-scale extraction; swaps the first two keypoint columns.
        keypoints, _, descriptors = process_multiscale(
            image.to(device).unsqueeze(0),
            model,
            scales=[1]
        )
        return keypoints[:, [1, 0, 2]], descriptors

    with torch.no_grad():
        keypoints_a1, descriptors_a1 = _extract(igp1, model1)
        keypoints_a2, descriptors_a2 = _extract(igp1, model2)
        keypoints_b1, descriptors_b1 = _extract(igp2, model1)
        keypoints_b2, descriptors_b2 = _extract(igp2, model2)

    # calculating matches for both models
    matches1, dist_1 = mnn_matcher_scorer(
        torch.from_numpy(descriptors_a1).to(device=device),
        torch.from_numpy(descriptors_b1).to(device=device),
    )
    matches2, dist_2 = mnn_matcher_scorer(
        torch.from_numpy(descriptors_a2).to(device=device),
        torch.from_numpy(descriptors_b2).to(device=device),
    )

    # Keep the top half of the pooled matches by similarity, then route each
    # kept index back to the model it came from.
    full_dist = torch.cat([dist_1, dist_2])
    k_final = len(full_dist) // 2
    top_ids = torch.topk(full_dist, k=k_final)[1]
    from_model1 = top_ids < len(dist_1)

    matches1 = matches1[top_ids[from_model1]].data.cpu().numpy()
    matches2 = matches2[top_ids[~from_model1] - len(dist_1)].data.cpu().numpy()

    pos_a1 = keypoints_a1[matches1[:, 0], : 2]
    pos_b1 = keypoints_b1[matches1[:, 1], : 2]

    pos_a2 = keypoints_a2[matches2[:, 0], : 2]
    pos_b2 = keypoints_b2[matches2[:, 1], : 2]

    pos_a = np.concatenate([pos_a1, pos_a2], 0)
    pos_b = np.concatenate([pos_b1, pos_b2], 0)

    _, inliers = pydegensac.findHomography(pos_a, pos_b, 8.0, 0.99, 10000)
    pos_a = pos_a[inliers]
    pos_b = pos_b[inliers]

    inlier_keypoints_left = [cv2.KeyPoint(float(point[0]), float(point[1]), 1) for point in pos_a]
    inlier_keypoints_right = [cv2.KeyPoint(float(point[0]), float(point[1]), 1) for point in pos_b]
    placeholder_matches = [cv2.DMatch(idx, idx, 1) for idx in range(len(pos_a))]

    image3 = cv2.drawMatches(img1, inlier_keypoints_left, img2, inlier_keypoints_right, placeholder_matches, None, matchColor=[0, 255, 0])
    image3 = cv2.cvtColor(image3, cv2.COLOR_BGR2RGB)

    # An image that was not warped maps back through the identity; referencing
    # an unbound H1/H2 here previously raised NameError.
    identity = np.eye(3)
    orgSrc, orgDst = orgKeypoints(
        pos_a, pos_b,
        identity if H1 is None else H1,
        identity if H2 is None else H2,
    )
    matchImg = drawOrg(cv2.imread(rgbFile1), cv2.imread(rgbFile2), orgSrc, orgDst)

    return orgSrc, orgDst, matchImg, image3
351
+
352
+
353
if __name__ == '__main__':
    # Usage: <script> SRC_IMAGE TRG_IMAGE SRC_H.npy TRG_H.npy
    WEIGHTS = '../models/rord.pth'

    if len(argv) != 5:
        exit('usage: %s SRC_IMAGE TRG_IMAGE SRC_H TRG_H' % argv[0])
    srcR, trgR, srcH, trgH = argv[1:5]

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # getPerspKeypoints expects a model object and a torch device, not a
    # weights path and the string 'gpu'.
    # NOTE(review): the constructor keywords follow upstream D2-Net usage;
    # confirm against lib/model_test.py.
    model = D2Net(model_file=WEIGHTS, use_relu=True, use_cuda=use_cuda)

    # getPerspKeypoints returns four values.
    orgSrc, orgDst, matchImg, image3 = getPerspKeypoints(srcR, trgR, srcH, trgH, model, device)
third_party/RoRD/lib/loss.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+ import os
4
+ import numpy as np
5
+
6
+ import torch
7
+ import torch.nn.functional as F
8
+
9
+ from lib.utils import (
10
+ grid_positions,
11
+ upscale_positions,
12
+ downscale_positions,
13
+ savefig,
14
+ imshow_image
15
+ )
16
+ from lib.exceptions import NoGradientError, EmptyTensorError
17
+
18
+ matplotlib.use('Agg')
19
+
20
+
21
def _hardest_negative_distance(
    anchor_descriptors, matched_fmap_pos, candidate_fmap_pos,
    candidate_descriptors, safe_radius
):
    """Return, per anchor, the smallest penalised distance to any candidate.

    The descriptor distance is ``2 - 2 * <anchor, candidate>``.  Candidates
    whose feature-map position is within `safe_radius` of the anchor's matched
    position (measured as the largest absolute coordinate difference) get 10
    added to their distance before the minimum over candidates is taken.
    """
    coord_gap = torch.max(
        torch.abs(
            matched_fmap_pos.unsqueeze(2).float() -
            candidate_fmap_pos.unsqueeze(1)
        ),
        dim=0
    )[0]
    outside_radius = coord_gap > safe_radius
    distances = 2 - 2 * (anchor_descriptors.t() @ candidate_descriptors)
    penalised = distances + (1 - outside_radius.float()) * 10.
    return torch.min(penalised, dim=1)[0]


def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
):
    """Compute the score-weighted margin loss for one batch of image pairs.

    For every pair in the batch, positions of image 1 are warped into image 2
    with `warp` (depth, intrinsics, pose and bbox come from `batch`).  Pairs
    for which the warp raises EmptyTensorError, or which keep fewer than 128
    correspondences, are skipped.  Each remaining pair contributes

        sum(s1 * s2 * relu(margin + positive - hardest_negative)) / sum(s1 * s2)

    and the result is averaged over the contributing pairs.

    When `plot` is true and ``batch['batch_idx']`` is a multiple of
    ``batch['log_interval']``, a four-panel figure (images with the
    correspondences, and both score maps) is saved below `plot_path`.

    Returns:
        A one-element float tensor on `device`.

    Raises:
        NoGradientError: if no pair of the batch contributed to the loss.
    """
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    n_valid_samples = 0

    for idx_in_batch in range(batch['image1'].size(0)):
        # Annotations
        depth1 = batch['depth1'][idx_in_batch].to(device)  # [h1, w1]
        intrinsics1 = batch['intrinsics1'][idx_in_batch].to(device)  # [3, 3]
        pose1 = batch['pose1'][idx_in_batch].view(4, 4).to(device)  # [4, 4]
        bbox1 = batch['bbox1'][idx_in_batch].to(device)  # [2]

        depth2 = batch['depth2'][idx_in_batch].to(device)
        intrinsics2 = batch['intrinsics2'][idx_in_batch].to(device)
        pose2 = batch['pose2'][idx_in_batch].view(4, 4).to(device)
        bbox2 = batch['bbox2'][idx_in_batch].to(device)

        # Network output for this pair
        features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = features1.size()
        flat_scores1 = output['scores1'][idx_in_batch].view(-1)

        features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = features2.size()
        score_map2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(features1.view(c, -1), dim=0)
        all_descriptors2 = F.normalize(features2.view(c, -1), dim=0)

        # Warp the positions from image 1 to image 2
        grid1 = grid_positions(h1, w1, device)
        upscaled1 = upscale_positions(grid1, scaling_steps=scaling_steps)
        try:
            pos1, pos2, ids = warp(
                upscaled1,
                depth1, intrinsics1, pose1, bbox1,
                depth2, intrinsics2, pose2, bbox2
            )
        except EmptyTensorError:
            continue
        fmap_pos1 = grid1[:, ids]
        descriptors1 = all_descriptors1[:, ids]
        scores1 = flat_scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions in image 2
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        rows2 = fmap_pos2[0, :]
        cols2 = fmap_pos2[1, :]
        descriptors2 = F.normalize(features2[:, rows2, cols2], dim=0)
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        # Hardest negatives: image-1 anchors against all of image 2, then
        # image-2 anchors against all of image 1.
        negative_distance2 = _hardest_negative_distance(
            descriptors1, fmap_pos2, grid_positions(h2, w2, device),
            all_descriptors2, safe_radius
        )
        negative_distance1 = _hardest_negative_distance(
            descriptors2, fmap_pos1, grid_positions(h1, w1, device),
            all_descriptors1, safe_radius
        )

        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        scores2 = score_map2[rows2, cols2]
        pair_weights = scores1 * scores2

        loss = loss + (
            torch.sum(pair_weights * F.relu(margin + diff)) /
            torch.sum(pair_weights)
        )

        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            pos1_np = pos1.cpu().numpy()
            pos2_np = pos2.cpu().numpy()
            # One random colour per correspondence, shared by both images
            colours = np.random.rand(pos1_np.shape[1], 3)
            n_panels = 4

            plt.figure()

            plt.subplot(1, n_panels, 1)
            plt.imshow(imshow_image(
                batch['image1'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            ))
            plt.scatter(
                pos1_np[1, :], pos1_np[0, :],
                s=0.25**2, c=colours, marker=',', alpha=0.5
            )
            plt.axis('off')

            plt.subplot(1, n_panels, 2)
            plt.imshow(
                output['scores1'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')

            plt.subplot(1, n_panels, 3)
            plt.imshow(imshow_image(
                batch['image2'][idx_in_batch].cpu().numpy(),
                preprocessing=batch['preprocessing']
            ))
            plt.scatter(
                pos2_np[1, :], pos2_np[0, :],
                s=0.25**2, c=colours, marker=',', alpha=0.5
            )
            plt.axis('off')

            plt.subplot(1, n_panels, 4)
            plt.imshow(
                output['scores2'][idx_in_batch].data.cpu().numpy(),
                cmap='Reds'
            )
            plt.axis('off')

            file_name = '%s.%02d.%02d.%d.png' % (
                'train' if batch['train'] else 'valid',
                batch['epoch_idx'],
                batch['batch_idx'] // batch['log_interval'],
                idx_in_batch
            )
            savefig(os.path.join(plot_path, file_name), dpi=300)
            plt.close()

    # No pair survived the filtering: there is nothing to back-propagate
    if n_valid_samples == 0:
        raise NoGradientError

    return loss / n_valid_samples
191
+
192
+
193
def interpolate_depth(pos, depth):
    """Bilinearly interpolate `depth` at the sub-pixel positions `pos`.

    `pos` holds row coordinates in ``pos[0]`` and column coordinates in
    ``pos[1]``; `depth` is a 2-D map.  A position is kept only if its four
    surrounding integer corners lie inside the map and all four corner depths
    are strictly positive.

    Returns:
        ``[interpolated_depth, pos, ids]`` where `pos` contains the kept
        positions and `ids` their indices into the input.

    Raises:
        EmptyTensorError: if no position survives either filter.
    """
    device = pos.device
    h, w = depth.size()

    rows = pos[0, :]
    cols = pos[1, :]

    # Integer corners around every position
    row_lo = torch.floor(rows).long()
    row_hi = torch.ceil(rows).long()
    col_lo = torch.floor(cols).long()
    col_hi = torch.ceil(cols).long()

    # Valid corners: all four must fall inside the depth map
    inside = (row_lo >= 0) & (col_lo >= 0) & (row_hi < h) & (col_hi < w)

    ids = torch.arange(0, pos.size(1), device=device)[inside]
    if ids.size(0) == 0:
        raise EmptyTensorError

    row_lo, row_hi = row_lo[inside], row_hi[inside]
    col_lo, col_hi = col_lo[inside], col_hi[inside]

    # Valid depth: every corner needs a strictly positive depth value
    has_depth = (
        (depth[row_lo, col_lo] > 0) &
        (depth[row_lo, col_hi] > 0) &
        (depth[row_hi, col_lo] > 0) &
        (depth[row_hi, col_hi] > 0)
    )

    ids = ids[has_depth]
    if ids.size(0) == 0:
        raise EmptyTensorError

    row_lo, row_hi = row_lo[has_depth], row_hi[has_depth]
    col_lo, col_hi = col_lo[has_depth], col_hi[has_depth]

    # Interpolation weights from the offset to the top-left corner
    rows = rows[ids]
    cols = cols[ids]
    frac_row = rows - row_lo.float()
    frac_col = cols - col_lo.float()

    interpolated_depth = (
        (1 - frac_row) * (1 - frac_col) * depth[row_lo, col_lo] +
        (1 - frac_row) * frac_col * depth[row_lo, col_hi] +
        frac_row * (1 - frac_col) * depth[row_hi, col_lo] +
        frac_row * frac_col * depth[row_hi, col_hi]
    )

    kept_pos = torch.cat([rows.view(1, -1), cols.view(1, -1)], dim=0)

    return [interpolated_depth, kept_pos, ids]
289
+
290
+
291
def uv_to_pos(uv):
    """Swap the two rows of `uv`: row 1 comes first, row 0 second."""
    first_row = uv[1, :].view(1, -1)
    second_row = uv[0, :].view(1, -1)
    return torch.cat((first_row, second_row), dim=0)
293
+
294
+
295
def warp(
        pos1,
        depth1, intrinsics1, pose1, bbox1,
        depth2, intrinsics2, pose2, bbox2
):
    """Warp positions of image 1 into image 2 using depth and camera poses.

    Steps, as implemented below:
      1. interpolate `depth1` at `pos1` (invalid positions are dropped);
      2. back-project to 3-D with `intrinsics1`, after adding the crop offset
         `bbox1` and a half-pixel shift;
      3. move the points with ``pose2 @ inverse(pose1)`` and project them
         with `intrinsics2`, removing the `bbox2` offset and half-pixel shift;
      4. interpolate `depth2` at the projected positions and keep only points
         whose projected depth is within 0.05 of the interpolated depth.

    Returns:
        ``(pos1, pos2, ids)``: the surviving positions in both images and
        their indices into the original `pos1`.

    Raises:
        EmptyTensorError: if no point survives (raised here or by
            `interpolate_depth`).
    """
    device = pos1.device

    Z1, pos1, ids = interpolate_depth(pos1, depth1)

    # COLMAP convention
    u1 = pos1[1, :] + bbox1[1] + .5
    v1 = pos1[0, :] + bbox1[0] + .5

    X1 = (u1 - intrinsics1[0, 2]) * (Z1 / intrinsics1[0, 0])
    Y1 = (v1 - intrinsics1[1, 2]) * (Z1 / intrinsics1[1, 1])

    XYZ1_hom = torch.cat([
        X1.view(1, -1),
        Y1.view(1, -1),
        Z1.view(1, -1),
        torch.ones(1, Z1.size(0), device=device)
    ], dim=0)
    # torch.chain_matmul is deprecated in recent PyTorch releases; the plain
    # matrix-product operator computes the same product on every version.
    XYZ2_hom = pose2 @ torch.inverse(pose1) @ XYZ1_hom
    XYZ2 = XYZ2_hom[: -1, :] / XYZ2_hom[-1, :].view(1, -1)

    uv2_hom = torch.matmul(intrinsics2, XYZ2)
    uv2 = uv2_hom[: -1, :] / uv2_hom[-1, :].view(1, -1)

    u2 = uv2[0, :] - bbox2[1] - .5
    v2 = uv2[1, :] - bbox2[0] - .5
    uv2 = torch.cat([u2.view(1, -1), v2.view(1, -1)], dim=0)

    annotated_depth, pos2, new_ids = interpolate_depth(uv_to_pos(uv2), depth2)

    ids = ids[new_ids]
    pos1 = pos1[:, new_ids]
    estimated_depth = XYZ2[2, new_ids]

    # Reject points whose projected depth disagrees with the depth map of
    # image 2 by 0.05 or more.
    inlier_mask = torch.abs(estimated_depth - annotated_depth) < 0.05

    ids = ids[inlier_mask]
    if ids.size(0) == 0:
        raise EmptyTensorError

    pos2 = pos2[:, inlier_mask]
    pos1 = pos1[:, inlier_mask]

    return pos1, pos2, ids
third_party/RoRD/lib/losses/lossPhotoTourism.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib
2
+ import matplotlib.pyplot as plt
3
+
4
+ import numpy as np
5
+ import cv2
6
+ from sys import exit
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+
11
+ from lib.utils import (
12
+ grid_positions,
13
+ upscale_positions,
14
+ downscale_positions,
15
+ savefig,
16
+ imshow_image
17
+ )
18
+ from lib.exceptions import NoGradientError, EmptyTensorError
19
+
20
+ matplotlib.use('Agg')
21
+
22
+
23
def _hardest_negative_distance(
    anchor_descriptors, matched_fmap_pos, candidate_fmap_pos,
    candidate_descriptors, safe_radius
):
    """Return, per anchor, the smallest penalised distance to any candidate.

    The descriptor distance is ``2 - 2 * <anchor, candidate>``.  Candidates
    whose feature-map position is within `safe_radius` of the anchor's matched
    position (measured as the largest absolute coordinate difference) get 10
    added to their distance before the minimum over candidates is taken.
    """
    coord_gap = torch.max(
        torch.abs(
            matched_fmap_pos.unsqueeze(2).float() -
            candidate_fmap_pos.unsqueeze(1)
        ),
        dim=0
    )[0]
    outside_radius = coord_gap > safe_radius
    distances = 2 - 2 * (anchor_descriptors.t() @ candidate_descriptors)
    penalised = distances + (1 - outside_radius.float()) * 10.
    return torch.min(penalised, dim=1)[0]


def loss_function(
        model, batch, device, margin=1, safe_radius=4, scaling_steps=3, plot=False, plot_path=None
):
    """Compute the score-weighted margin loss from precomputed correspondences.

    Correspondences are read from ``batch['pos1']`` / ``batch['pos2']``;
    `idsAlign` turns the image-1 positions into flat feature-map indices.
    Pairs with fewer than 128 correspondences are skipped.  Each remaining
    pair contributes

        sum(s1 * s2 * relu(margin + positive - hardest_negative)) / sum(s1 * s2)

    and the result is averaged over the contributing pairs.

    When `plot` is true and ``batch['batch_idx']`` is a multiple of
    ``batch['log_interval']``, `drawTraining` is called to save figures
    below `plot_path`.

    Returns:
        A one-element float tensor on `device`.

    Raises:
        NoGradientError: if no pair of the batch contributed to the loss.
    """
    output = model({
        'image1': batch['image1'].to(device),
        'image2': batch['image2'].to(device)
    })

    loss = torch.tensor(np.array([0], dtype=np.float32), device=device)
    n_valid_samples = 0

    for idx_in_batch in range(batch['image1'].size(0)):
        # Network output for this pair
        features1 = output['dense_features1'][idx_in_batch]
        c, h1, w1 = features1.size()
        flat_scores1 = output['scores1'][idx_in_batch].view(-1)

        features2 = output['dense_features2'][idx_in_batch]
        _, h2, w2 = features2.size()
        score_map2 = output['scores2'][idx_in_batch]

        all_descriptors1 = F.normalize(features1.view(c, -1), dim=0)
        all_descriptors2 = F.normalize(features2.view(c, -1), dim=0)

        grid1 = grid_positions(h1, w1, device)

        pos1 = batch['pos1'][idx_in_batch].to(device)
        pos2 = batch['pos2'][idx_in_batch].to(device)

        # Flat feature-map indices of the image-1 correspondences
        ids = idsAlign(pos1, device, h1, w1)

        fmap_pos1 = grid1[:, ids]
        descriptors1 = all_descriptors1[:, ids]
        scores1 = flat_scores1[ids]

        # Skip the pair if not enough GT correspondences are available
        if ids.size(0) < 128:
            continue

        # Descriptors at the corresponding positions in image 2
        fmap_pos2 = torch.round(
            downscale_positions(pos2, scaling_steps=scaling_steps)
        ).long()
        rows2 = fmap_pos2[0, :]
        cols2 = fmap_pos2[1, :]

        descriptors2 = F.normalize(features2[:, rows2, cols2], dim=0)
        positive_distance = 2 - 2 * (
            descriptors1.t().unsqueeze(1) @ descriptors2.t().unsqueeze(2)
        ).squeeze()

        # Hardest negatives: image-1 anchors against all of image 2, then
        # image-2 anchors against all of image 1.
        negative_distance2 = _hardest_negative_distance(
            descriptors1, fmap_pos2, grid_positions(h2, w2, device),
            all_descriptors2, safe_radius
        )
        negative_distance1 = _hardest_negative_distance(
            descriptors2, fmap_pos1, grid_positions(h1, w1, device),
            all_descriptors1, safe_radius
        )

        diff = positive_distance - torch.min(
            negative_distance1, negative_distance2
        )

        scores2 = score_map2[rows2, cols2]
        pair_weights = scores1 * scores2

        loss = loss + (
            torch.sum(pair_weights * F.relu(margin + diff)) /
            (torch.sum(pair_weights))
        )

        n_valid_samples += 1

        if plot and batch['batch_idx'] % batch['log_interval'] == 0:
            drawTraining(batch['image1'], batch['image2'], pos1, pos2, batch, idx_in_batch, output, save=True, plot_path=plot_path)

    # No pair survived the filtering: there is nothing to back-propagate
    if n_valid_samples == 0:
        raise NoGradientError

    return loss / n_valid_samples
136
+
137
+
138
def idsAlign(pos1, device, h1, w1, scaling_steps=3):
    """Convert positions into flat (row-major) feature-map indices.

    `pos1` holds row coordinates in ``pos1[0]`` and column coordinates in
    ``pos1[1]``.  The positions are downscaled with `downscale_positions`
    and combined as ``w1 * row + col``; the result is rounded to the nearest
    integer.

    Args:
        pos1: 2 x N tensor of positions.
        device: device the returned index tensor is moved to.
        h1: feature-map height (unused; kept for interface compatibility).
        w1: feature-map width, i.e. the row stride of the flat index.
        scaling_steps: number of downscaling steps; the default of 3 is the
            value that used to be hard-coded here.

    Returns:
        A 1-D long tensor of N flat indices on `device`.
    """
    pos1D = downscale_positions(pos1, scaling_steps=scaling_steps)

    # One vectorised expression replaces the former per-element Python loop.
    flat_index = w1 * pos1D[0, :] + pos1D[1, :]

    # Cast to float32 before rounding, as the original torch.Tensor(list)
    # construction did.
    return torch.round(flat_index.float()).long().to(device)
153
+
154
+
155
def drawTraining(image1, image2, pos1, pos2, batch, idx_in_batch, output, save=False, plot_path="train_viz"):
    """Visualise the correspondences and score maps of one training pair.

    Two outputs are produced:
      * a four-panel matplotlib figure (image 1 with its points, score map 1,
        image 2 with its points, score map 2);
      * an OpenCV image with both pictures side by side, every fifth point
        marked and every fifth correspondence joined by a line.

    Args:
        image1, image2: image batches; the sample at `idx_in_batch` is drawn.
        pos1, pos2: 2 x N position tensors (row coordinates first).
        batch: dict providing 'preprocessing' and, when saving, 'train',
            'epoch_idx', 'batch_idx' and 'log_interval'.
        idx_in_batch: index of the pair inside the batch.
        output: network output dict with 'scores1' and 'scores2'.
        save: write both images below `plot_path` instead of showing them.
        plot_path: output directory used when `save` is true.
    """
    def _out_path(train_prefix, valid_prefix):
        # Built lazily so the bookkeeping keys are only needed when saving.
        return plot_path + '/%s.%02d.%02d.%d.png' % (
            train_prefix if batch['train'] else valid_prefix,
            batch['epoch_idx'],
            batch['batch_idx'] // batch['log_interval'],
            idx_in_batch
        )

    pos1_aux = pos1.cpu().numpy()
    pos2_aux = pos2.cpu().numpy()

    k = pos1_aux.shape[1]
    col = np.random.rand(k, 3)
    n_sp = 4
    plt.figure()
    plt.subplot(1, n_sp, 1)
    # Fix: draw the sample the positions and score maps belong to; the
    # previous code always drew the first image of the batch.
    im1 = imshow_image(
        image1[idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im1)
    plt.scatter(
        pos1_aux[1, :], pos1_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 2)
    plt.imshow(
        output['scores1'][idx_in_batch].data.cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 3)
    im2 = imshow_image(
        image2[idx_in_batch].cpu().numpy(),
        preprocessing=batch['preprocessing']
    )
    plt.imshow(im2)
    plt.scatter(
        pos2_aux[1, :], pos2_aux[0, :],
        s=0.25**2, c=col, marker=',', alpha=0.5
    )
    plt.axis('off')
    plt.subplot(1, n_sp, 4)
    plt.imshow(
        output['scores2'][idx_in_batch].data.cpu().numpy(),
        cmap='Reds'
    )
    plt.axis('off')

    if save:
        savefig(_out_path('train', 'valid'), dpi=300)
    else:
        plt.show()

    plt.close()

    im1 = cv2.cvtColor(im1, cv2.COLOR_BGR2RGB)
    im2 = cv2.cvtColor(im2, cv2.COLOR_BGR2RGB)

    # Fix: OpenCV drawing functions need integer pixel coordinates; the
    # positions are floats, so cast them exactly like the line drawing below.
    for i in range(0, pos1_aux.shape[1], 5):
        im1 = cv2.circle(im1, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), 1, (0, 0, 255), 2)
    for i in range(0, pos2_aux.shape[1], 5):
        im2 = cv2.circle(im2, (int(pos2_aux[1, i]), int(pos2_aux[0, i])), 1, (0, 0, 255), 2)

    im3 = cv2.hconcat([im1, im2])

    # Image 2 sits to the right of image 1, hence the column offset.
    x_offset = im1.shape[1]
    for i in range(0, pos1_aux.shape[1], 5):
        im3 = cv2.line(im3, (int(pos1_aux[1, i]), int(pos1_aux[0, i])), (int(pos2_aux[1, i]) + x_offset, int(pos2_aux[0, i])), (0, 255, 0), 1)

    if save:
        # Fix: the validation file used the same name as the figure saved
        # above and silently overwrote it; it now gets its own '_corr' name.
        cv2.imwrite(_out_path('train_corr', 'valid_corr'), im3)
    else:
        cv2.imshow('Image', im3)
        cv2.waitKey(0)
third_party/RoRD/lib/model.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ import torchvision.models as models
6
+
7
+
8
class DenseFeatureExtractionModule(nn.Module):
    """VGG16 feature layers truncated after conv4_3 (no trailing ReLU).

    All parameters are frozen; with `finetune_feature_extraction` the last
    two parameter tensors (conv4_3) stay trainable.
    """

    def __init__(self, finetune_feature_extraction=False, use_cuda=True):
        super().__init__()

        backbone = models.vgg16()
        # Names of the entries of `backbone.features`, in order
        layer_names = [
            'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2',
            'pool1',
            'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2',
            'pool2',
            'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3',
            'pool3',
            'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3',
            'pool4',
            'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
            'pool5'
        ]
        n_kept = layer_names.index('conv4_3') + 1
        kept_layers = list(backbone.features.children())[:n_kept]

        self.model = nn.Sequential(*kept_layers)

        self.num_channels = 512

        # Freeze everything except, optionally, the trailing conv4_3 tensors
        params = list(self.model.parameters())
        n_trainable = 2 if finetune_feature_extraction else 0
        first_trainable = len(params) - n_trainable
        for position, param in enumerate(params):
            param.requires_grad = position >= first_trainable

        if use_cuda:
            self.model = self.model.cuda()

    def forward(self, batch):
        """Return the truncated network's output for `batch`."""
        return self.model(batch)
47
+
48
+
49
class SoftDetectionModule(nn.Module):
    """Turn dense features into one soft detection score map per sample."""

    def __init__(self, soft_local_max_size=3):
        super().__init__()

        self.soft_local_max_size = soft_local_max_size
        # Padding that keeps the pooled map the same size as the input
        self.pad = self.soft_local_max_size // 2

    def forward(self, batch):
        """Return scores of shape [b, h, w]; each sample's scores sum to 1.

        The score of a location is the maximum over channels of
        (soft local-max score) * (ratio to the depth-wise maximum), divided
        by the sum of that quantity over the sample.
        """
        n = batch.size(0)
        window = self.soft_local_max_size

        activations = F.relu(batch)

        # Soft local maximum: exp of the activation, relative to the sample
        # maximum, over the sum of exps in the surrounding window.
        sample_max = activations.view(n, -1).max(dim=1)[0]
        exp = torch.exp(activations / sample_max.view(n, 1, 1, 1))
        padded_exp = F.pad(exp, [self.pad] * 4, mode='constant', value=1.)
        window_sum = window ** 2 * F.avg_pool2d(padded_exp, window, stride=1)
        local_max_score = exp / window_sum

        # Ratio of each channel to the strongest channel at that location
        channel_max = activations.max(dim=1)[0]
        depth_wise_max_score = activations / channel_max.unsqueeze(1)

        score = (local_max_score * depth_wise_max_score).max(dim=1)[0]

        total = score.view(n, -1).sum(dim=1).view(n, 1, 1)
        return score / total
82
+
83
+
84
class D2Net(nn.Module):
    """Training-time network: dense feature extraction plus soft detection."""

    def __init__(self, model_file=None, use_cuda=True):
        super().__init__()

        self.dense_feature_extraction = DenseFeatureExtractionModule(
            finetune_feature_extraction=True,
            use_cuda=use_cuda
        )

        self.detection = SoftDetectionModule()

        if model_file is not None:
            # Without CUDA the checkpoint tensors are remapped to the CPU
            load_kwargs = {} if use_cuda else {'map_location': 'cpu'}
            checkpoint = torch.load(model_file, **load_kwargs)
            self.load_state_dict(checkpoint['model'])

    def forward(self, batch):
        """Process both images of every pair in a single pass.

        `batch` provides 'image1' and 'image2'; they are concatenated along
        the batch dimension, and the features and scores are split back into
        the two halves.
        """
        n_pairs = batch['image1'].size(0)

        stacked = torch.cat([batch['image1'], batch['image2']], dim=0)
        dense_features = self.dense_feature_extraction(stacked)
        scores = self.detection(dense_features)

        return {
            'dense_features1': dense_features[:n_pairs, :, :, :],
            'scores1': scores[:n_pairs, :, :],
            'dense_features2': dense_features[n_pairs:, :, :, :],
            'scores2': scores[n_pairs:, :, :]
        }
third_party/RoRD/lib/model_test.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
class DenseFeatureExtractionModule(nn.Module):
    """Test-time feature extractor.

    A stack of 3x3 convolutions: two max-pooled stages, a third stage
    followed by a stride-1 average pool, and three dilated 512-channel
    convolutions.  A final ReLU is applied only when `use_relu` is set.
    """

    def __init__(self, use_relu=True, use_cuda=True):
        super().__init__()

        # The layer order (and so the Sequential indices) must not change:
        # checkpoint keys are derived from these positions.
        layers = [
            # Stage 1: 64 channels, then 2x down-sampling
            nn.Conv2d(3, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
        ]
        layers += [
            # Stage 2: 128 channels, then 2x down-sampling
            nn.Conv2d(64, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, stride=2),
        ]
        layers += [
            # Stage 3: 256 channels, then a stride-1 average pool
            nn.Conv2d(128, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.AvgPool2d(2, stride=1),
        ]
        layers += [
            # Stage 4: 512 channels with dilation 2, no ReLU after the last
            nn.Conv2d(256, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, 3, padding=2, dilation=2),
        ]
        self.model = nn.Sequential(*layers)
        self.num_channels = 512

        self.use_relu = use_relu

        if use_cuda:
            self.model = self.model.cuda()

    def forward(self, batch):
        """Return the dense features, rectified when `use_relu` is set."""
        features = self.model(batch)
        return F.relu(features) if self.use_relu else features
46
+
47
+
48
class D2Net(nn.Module):
    """Test-time network: features, hard detections and sub-pixel offsets."""

    def __init__(self, model_file=None, use_relu=True, use_cuda=False):
        super().__init__()

        self.dense_feature_extraction = DenseFeatureExtractionModule(
            use_relu=use_relu, use_cuda=use_cuda
        )

        self.detection = HardDetectionModule()

        self.localization = HandcraftedLocalizationModule()

        if model_file is not None:
            # Without CUDA the checkpoint tensors are remapped to the CPU
            load_kwargs = {} if use_cuda else {'map_location': 'cpu'}
            checkpoint = torch.load(model_file, **load_kwargs)
            self.load_state_dict(checkpoint['model'])

    def forward(self, batch):
        """Return a dict with 'dense_features', 'detections', 'displacements'."""
        # The values are unused; the unpacking fails early for inputs that
        # are not 4-dimensional.
        _, _, _h, _w = batch.size()

        dense_features = self.dense_feature_extraction(batch)
        detections = self.detection(dense_features)
        displacements = self.localization(dense_features)

        return {
            'dense_features': dense_features,
            'detections': detections,
            'displacements': displacements
        }
79
+
80
+
81
class HardDetectionModule(nn.Module):
    """Binary keypoint detector on dense feature maps.

    A location (channel, row, column) is detected when its value is
    simultaneously
      * the maximum over channels at that pixel,
      * the maximum of its 3x3 spatial neighbourhood, and
      * not edge-like: the finite-difference Hessian has a positive
        determinant and ``trace**2 / det <= (t + 1)**2 / t`` with
        ``t = edge_threshold``.
    """

    def __init__(self, edge_threshold=5):
        super().__init__()

        self.edge_threshold = edge_threshold

        # Second-derivative finite-difference kernels.  They are plain tensor
        # attributes (not buffers), so they are absent from the state dict
        # and get moved to the input's device on every call.
        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    @staticmethod
    def _filter_channels(batch, kernel):
        """Apply one 3x3 kernel independently to every channel of `batch`."""
        b, c, h, w = batch.size()
        return F.conv2d(
            batch.view(-1, 1, h, w), kernel.to(batch.device), padding=1
        ).view(b, c, h, w)

    def forward(self, batch):
        """Return a boolean [b, c, h, w] tensor marking the detections."""
        # Fix: the per-pixel channel maximum has shape [b, h, w]; it needs an
        # explicit channel axis to broadcast against [b, c, h, w].  Without
        # it the comparison was only correct for a batch size of 1.
        depth_wise_max = torch.max(batch, dim=1)[0]
        is_depth_wise_max = (batch == depth_wise_max.unsqueeze(1))
        del depth_wise_max

        local_max = F.max_pool2d(batch, 3, stride=1, padding=1)
        is_local_max = (batch == local_max)
        del local_max

        dii = self._filter_channels(batch, self.dii_filter)
        dij = self._filter_channels(batch, self.dij_filter)
        djj = self._filter_channels(batch, self.djj_filter)

        det = dii * djj - dij * dij
        tr = dii + djj
        del dii, dij, djj

        threshold = (self.edge_threshold + 1) ** 2 / self.edge_threshold
        # Where det is 0 the ratio is inf/nan; `det > 0` rejects those cells.
        is_not_edge = (tr * tr / det <= threshold) & (det > 0)

        detected = is_depth_wise_max & is_local_max & is_not_edge
        del is_depth_wise_max, is_local_max, is_not_edge

        return detected
133
+
134
+
135
class HandcraftedLocalizationModule(nn.Module):
    """Sub-pixel refinement from finite-difference derivatives.

    For every location the step ``-H^-1 * g`` is computed, where ``g`` is
    the central-difference gradient and ``H`` the finite-difference Hessian
    of the feature map.
    """

    def __init__(self):
        super().__init__()

        # First-derivative kernels (central differences along i and j)
        self.di_filter = torch.tensor(
            [[0, -0.5, 0], [0, 0, 0], [0, 0.5, 0]]
        ).view(1, 1, 3, 3)
        self.dj_filter = torch.tensor(
            [[0, 0, 0], [-0.5, 0, 0.5], [0, 0, 0]]
        ).view(1, 1, 3, 3)

        # Second-derivative kernels
        self.dii_filter = torch.tensor(
            [[0, 1., 0], [0, -2., 0], [0, 1., 0]]
        ).view(1, 1, 3, 3)
        self.dij_filter = 0.25 * torch.tensor(
            [[1., 0, -1.], [0, 0., 0], [-1., 0, 1.]]
        ).view(1, 1, 3, 3)
        self.djj_filter = torch.tensor(
            [[0, 0, 0], [1., -2., 1.], [0, 0, 0]]
        ).view(1, 1, 3, 3)

    @staticmethod
    def _filter_channels(batch, kernel):
        """Apply one 3x3 kernel independently to every channel of `batch`."""
        b, c, h, w = batch.size()
        return F.conv2d(
            batch.view(-1, 1, h, w), kernel.to(batch.device), padding=1
        ).view(b, c, h, w)

    def forward(self, batch):
        """Return the steps stacked as [b, 2, c, h, w] (i first, then j)."""
        hess_ii = self._filter_channels(batch, self.dii_filter)
        hess_ij = self._filter_channels(batch, self.dij_filter)
        hess_jj = self._filter_channels(batch, self.djj_filter)
        det = hess_ii * hess_jj - hess_ij * hess_ij

        # Closed-form inverse of the symmetric 2x2 Hessian
        inv_hess_00 = hess_jj / det
        inv_hess_01 = -hess_ij / det
        inv_hess_11 = hess_ii / det
        del hess_ii, hess_ij, hess_jj, det

        grad_i = self._filter_channels(batch, self.di_filter)
        grad_j = self._filter_channels(batch, self.dj_filter)

        step_i = -(inv_hess_00 * grad_i + inv_hess_01 * grad_j)
        step_j = -(inv_hess_01 * grad_i + inv_hess_11 * grad_j)
        del inv_hess_00, inv_hess_01, inv_hess_11, grad_i, grad_j

        return torch.stack([step_i, step_j], dim=1)
third_party/RoRD/lib/pyramid.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from lib.exceptions import EmptyTensorError
6
+ from lib.utils import interpolate_dense_features, upscale_positions
7
+
8
+
9
def process_multiscale(image, model, scales=(.5, 1, 2)):
    """Extract keypoints, scores and descriptors over an image pyramid.

    For each scale the image is resized, dense features are extracted and
    summed with the (resized) features of the previous scale, detections are
    taken from `model.detection`, refined with `model.localization`, and
    descriptors are interpolated at the refined positions.  Feature-map
    pixels that already produced a detection at an earlier scale are banned
    from later ones.

    Args:
        image: tensor of shape [1, C, H, W] (a batch size of 1 is asserted).
        model: object exposing `dense_feature_extraction` (with
            `num_channels`), `detection` and `localization`.
        scales: iterable of resize factors, processed in order.

    Returns:
        ``(keypoints, scores, descriptors)`` as NumPy arrays: keypoints is
        N x 3 (two coordinates rescaled to the input resolution, then
        ``1 / scale``), scores has N entries (feature value at the detection
        divided by the 1-based pyramid level), descriptors is N x num_channels.
    """
    b, _, h_init, w_init = image.size()
    device = image.device
    assert(b == 1)

    all_keypoints = torch.zeros([3, 0])
    all_descriptors = torch.zeros([
        model.dense_feature_extraction.num_channels, 0
    ])
    all_scores = torch.zeros(0)

    previous_dense_features = None
    banned = None
    for idx, scale in enumerate(scales):
        current_image = F.interpolate(
            image, scale_factor=scale,
            mode='bilinear', align_corners=True
        )
        _, _, h_level, w_level = current_image.size()

        dense_features = model.dense_feature_extraction(current_image)
        del current_image

        _, _, h, w = dense_features.size()

        # Sum the feature maps.
        if previous_dense_features is not None:
            dense_features += F.interpolate(
                previous_dense_features, size=[h, w],
                mode='bilinear', align_corners=True
            )
            # Fix: release the reference by rebinding instead of `del`.  A
            # deleted name followed by the `continue` below used to raise
            # UnboundLocalError on the next scale (or at the final cleanup).
            previous_dense_features = None

        # Recover detections.
        detections = model.detection(dense_features)
        if banned is not None:
            banned = F.interpolate(banned.float(), size=[h, w]).bool()
            detections = torch.min(detections, ~banned)
            banned = torch.max(
                torch.max(detections, dim=1)[0].unsqueeze(1), banned
            )
        else:
            banned = torch.max(detections, dim=1)[0].unsqueeze(1)
        # Rows of fmap_pos: channel, i, j of every detection
        fmap_pos = torch.nonzero(detections[0].cpu()).t()
        del detections

        # Recover displacements.
        displacements = model.localization(dense_features)[0].cpu()
        displacements_i = displacements[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        displacements_j = displacements[
            1, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ]
        del displacements

        # Keep only refinements that stay within half a cell
        mask = torch.min(
            torch.abs(displacements_i) < 0.5,
            torch.abs(displacements_j) < 0.5
        )
        fmap_pos = fmap_pos[:, mask]
        valid_displacements = torch.stack([
            displacements_i[mask],
            displacements_j[mask]
        ], dim=0)
        del mask, displacements_i, displacements_j

        fmap_keypoints = fmap_pos[1:, :].float() + valid_displacements
        del valid_displacements

        try:
            raw_descriptors, _, ids = interpolate_dense_features(
                fmap_keypoints.to(device),
                dense_features[0]
            )
        except EmptyTensorError:
            # Nothing usable at this scale.  Still hand the accumulated
            # features to the next scale, as the normal path does.
            previous_dense_features = dense_features
            continue
        fmap_pos = fmap_pos.to(device)
        fmap_keypoints = fmap_keypoints.to(device)
        fmap_pos = fmap_pos[:, ids]
        fmap_keypoints = fmap_keypoints[:, ids]
        del ids

        keypoints = upscale_positions(fmap_keypoints, scaling_steps=2)
        del fmap_keypoints

        descriptors = F.normalize(raw_descriptors, dim=0).cpu()
        del raw_descriptors

        # Map from the resized image back to the input resolution
        keypoints[0, :] *= h_init / h_level
        keypoints[1, :] *= w_init / w_level

        fmap_pos = fmap_pos.cpu()
        keypoints = keypoints.cpu()

        keypoints = torch.cat([
            keypoints,
            torch.ones([1, keypoints.size(1)]) * 1 / scale,
        ], dim=0)

        scores = dense_features[
            0, fmap_pos[0, :], fmap_pos[1, :], fmap_pos[2, :]
        ].cpu() / (idx + 1)
        del fmap_pos

        all_keypoints = torch.cat([all_keypoints, keypoints], dim=1)
        all_descriptors = torch.cat([all_descriptors, descriptors], dim=1)
        all_scores = torch.cat([all_scores, scores], dim=0)
        del keypoints, descriptors

        previous_dense_features = dense_features
        del dense_features
    del previous_dense_features, banned

    keypoints = all_keypoints.t().detach().numpy()
    del all_keypoints
    scores = all_scores.detach().numpy()
    del all_scores
    descriptors = all_descriptors.t().detach().numpy()
    del all_descriptors
    return keypoints, scores, descriptors
third_party/RoRD/lib/utils.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+
3
+ import numpy as np
4
+
5
+ import torch
6
+
7
+ from lib.exceptions import EmptyTensorError
8
+
9
+
10
def preprocess_image(image, preprocessing=None):
    """Turn a channel-last image array into a channel-first float array.

    The input is cast to float32 and its axes are permuted with
    ``[2, 0, 1]`` (last axis moved to the front). Depending on
    ``preprocessing`` the result is then normalized:

    * ``None``    -- no further processing.
    * ``'caffe'`` -- channel order is reversed and a per-channel mean is
      subtracted.
    * ``'torch'`` -- values are divided by 255, then a per-channel mean is
      subtracted and the result is divided by a per-channel std.

    Raises:
        ValueError: if ``preprocessing`` is none of the values above.
    """
    # Cast first (this copies), then move the channel axis to the front.
    chw = np.transpose(image.astype(np.float32), [2, 0, 1])

    if preprocessing is None:
        return chw

    if preprocessing == 'caffe':
        # Per-channel mean, laid out to match the reversed channel order.
        channel_mean = np.array([103.939, 116.779, 123.68]).reshape([3, 1, 1])
        # RGB -> BGR, then zero-center by the mean pixel.
        return chw[:: -1, :, :] - channel_mean

    if preprocessing == 'torch':
        channel_mean = np.array([0.485, 0.456, 0.406]).reshape([3, 1, 1])
        channel_std = np.array([0.229, 0.224, 0.225]).reshape([3, 1, 1])
        unit_range = chw / 255.0
        return (unit_range - channel_mean) / channel_std

    raise ValueError('Unknown preprocessing parameter.')
29
+
30
+
31
def imshow_image(image, preprocessing=None):
    """Convert a channel-first float image back to a displayable uint8 image.

    This applies the inverse of the normalization selected by
    ``preprocessing`` (same options as ``preprocess_image``), moves the
    first axis to the end with ``[1, 2, 0]``, rounds, and casts to uint8.

    Args:
        image: array indexed as (channel, row, column); the 'caffe' and
            'torch' branches broadcast against 3-element per-channel
            constants.
        preprocessing: ``None``, ``'caffe'`` or ``'torch'``.

    Returns:
        A uint8 array with the channel axis last.

    Raises:
        ValueError: if ``preprocessing`` is none of the values above.
    """
    if preprocessing is None:
        pass
    elif preprocessing == 'caffe':
        mean = np.array([103.939, 116.779, 123.68])
        image = image + mean.reshape([3, 1, 1])
        # Reverse the channel order again, undoing the flip applied by the
        # 'caffe' branch of preprocess_image.
        image = image[:: -1, :, :]
    elif preprocessing == 'torch':
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        image = image * std.reshape([3, 1, 1]) + mean.reshape([3, 1, 1])
        image *= 255.0
    else:
        raise ValueError('Unknown preprocessing parameter.')
    image = np.transpose(image, [1, 2, 0])
    # Clamp before the cast: casting values outside [0, 255] to uint8 wraps
    # around (or is platform-dependent for negatives) instead of saturating.
    image = np.clip(np.round(image), 0, 255).astype(np.uint8)
    return image
49
+
50
+
51
def grid_positions(h, w, device, matrix=False):
    """Build the (row, column) index of every cell of an ``h`` x ``w`` grid.

    Args:
        h: number of rows.
        w: number of columns.
        device: device on which the index tensors are created.
        matrix: if true, return a ``2 x h x w`` tensor; otherwise return a
            ``2 x (h * w)`` tensor with the grid flattened row by row.

    Returns:
        A float tensor whose first slice holds row indices and whose second
        slice holds column indices.
    """
    row_index = torch.arange(0, h, device=device).float()
    col_index = torch.arange(0, w, device=device).float()

    # Tile each 1-D index across the other axis to get two h x w maps.
    row_map = row_index.view(-1, 1).repeat(1, w)
    col_map = col_index.view(1, -1).repeat(h, 1)

    if matrix:
        return torch.stack([row_map, col_map], dim=0)

    flat_rows = row_map.view(1, -1)
    flat_cols = col_map.view(1, -1)
    return torch.cat([flat_rows, flat_cols], dim=0)
62
+
63
+
64
def upscale_positions(pos, scaling_steps=0):
    """Map positions to a finer grid by applying ``p -> 2 * p + 0.5``.

    The mapping is applied ``scaling_steps`` times in sequence; with zero
    steps ``pos`` is returned as is.
    """
    upscaled = pos
    for _step in range(scaling_steps):
        doubled = upscaled * 2
        upscaled = doubled + 0.5
    return upscaled
68
+
69
+
70
def downscale_positions(pos, scaling_steps=0):
    """Map positions to a coarser grid by applying ``p -> (p - 0.5) / 2``.

    The mapping is applied ``scaling_steps`` times in sequence; with zero
    steps ``pos`` is returned as is.
    """
    downscaled = pos
    for _step in range(scaling_steps):
        shifted = downscaled - 0.5
        downscaled = shifted / 2
    return downscaled
74
+
75
+
76
def interpolate_dense_features(pos, dense_features, return_corners=False):
    """Bilinearly interpolate a dense feature map at fractional positions.

    Positions whose four surrounding integer cells do not all lie inside
    the feature map are discarded.

    Args:
        pos: tensor whose row 0 holds the row coordinate and row 1 the
            column coordinate of each query position.
        dense_features: 3-D tensor; its last two dimensions are the grid
            height and width that positions are checked against.
        return_corners: if true, also return the integer corner indices.

    Returns:
        A list ``[descriptors, pos, ids]`` where ``descriptors`` are the
        interpolated features of the kept positions, ``pos`` holds the kept
        positions stacked as two rows, and ``ids`` are the indices of the
        kept positions within the input. With ``return_corners`` a fourth
        element is appended: the corner indices stacked in the order
        top-left, top-right, bottom-left, bottom-right, each as an
        ``(i, j)`` pair.

    Raises:
        EmptyTensorError: if no position has all four corners inside the
            feature map.
    """
    device = pos.device

    ids = torch.arange(0, pos.size(1), device=device)

    _, h, w = dense_features.size()

    i = pos[0, :]
    j = pos[1, :]

    # Integer cells surrounding each position (floor = top/left,
    # ceil = bottom/right). When a coordinate is integral, floor == ceil.
    i_floor = torch.floor(i).long()
    i_ceil = torch.ceil(i).long()
    j_floor = torch.floor(j).long()
    j_ceil = torch.ceil(j).long()

    # A position is usable only if all four corners are inside the map.
    # Logical AND states this directly and does not rely on elementwise
    # min being defined for boolean tensors.
    valid_corners = (
        (i_floor >= 0) & (j_floor >= 0) & (i_ceil < h) & (j_ceil < w)
    )

    ids = ids[valid_corners]
    if ids.size(0) == 0:
        raise EmptyTensorError

    i_top_left = i_floor[valid_corners]
    j_top_left = j_floor[valid_corners]

    i_top_right = i_floor[valid_corners]
    j_top_right = j_ceil[valid_corners]

    i_bottom_left = i_ceil[valid_corners]
    j_bottom_left = j_floor[valid_corners]

    i_bottom_right = i_ceil[valid_corners]
    j_bottom_right = j_ceil[valid_corners]

    # Interpolation weights from the fractional offset to the top-left cell.
    i = i[ids]
    j = j[ids]
    dist_i_top_left = i - i_top_left.float()
    dist_j_top_left = j - j_top_left.float()
    w_top_left = (1 - dist_i_top_left) * (1 - dist_j_top_left)
    w_top_right = (1 - dist_i_top_left) * dist_j_top_left
    w_bottom_left = dist_i_top_left * (1 - dist_j_top_left)
    w_bottom_right = dist_i_top_left * dist_j_top_left

    # Each weight vector broadcasts over the leading (feature) dimension.
    descriptors = (
        w_top_left * dense_features[:, i_top_left, j_top_left] +
        w_top_right * dense_features[:, i_top_right, j_top_right] +
        w_bottom_left * dense_features[:, i_bottom_left, j_bottom_left] +
        w_bottom_right * dense_features[:, i_bottom_right, j_bottom_right]
    )

    pos = torch.cat([i.view(1, -1), j.view(1, -1)], dim=0)

    if not return_corners:
        return [descriptors, pos, ids]

    corners = torch.stack([
        torch.stack([i_top_left, j_top_left], dim=0),
        torch.stack([i_top_right, j_top_right], dim=0),
        torch.stack([i_bottom_left, j_bottom_left], dim=0),
        torch.stack([i_bottom_right, j_bottom_right], dim=0)
    ], dim=0)
    return [descriptors, pos, ids, corners]
153
+
154
+
155
def savefig(filepath, fig=None, dpi=None):
    """Save a figure to ``filepath`` with axes hidden and no outer padding.

    Args:
        filepath: destination passed to ``Figure.savefig``.
        fig: figure to save; the current pyplot figure is used when omitted.
        dpi: resolution forwarded to ``Figure.savefig``.
    """
    # TomNorway - https://stackoverflow.com/a/53516034
    if fig is None:
        fig = plt.gcf()

    # Adjust the figure being saved. The pyplot-level call would act on the
    # current figure, which is not necessarily the one passed in.
    fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
    for ax in fig.axes:
        ax.axis('off')
        ax.margins(0, 0)
        ax.xaxis.set_major_locator(plt.NullLocator())
        ax.yaxis.set_major_locator(plt.NullLocator())

    fig.savefig(filepath, pad_inches=0, bbox_inches='tight', dpi=dpi)