YannisK commited on
Commit
9595159
·
1 Parent(s): c91b02f
Files changed (2) hide show
  1. app.py +6 -73
  2. jaipur1.jpeg +2 -2
app.py CHANGED
@@ -41,29 +41,12 @@ state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict);
41
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
42
  net_imagenet.load_state_dict(state2['state_dict'], strict=False)
43
 
44
- # ---------------------------------------
45
  transform = transforms.Compose([
46
  transforms.Resize(1024),
47
  transforms.ToTensor(),
48
  transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std'])))
49
  ])
50
- # ---------------------------------------
51
-
52
- # class ImgDataset(data.Dataset):
53
- # def __init__(self, images, imsize):
54
- # self.images = images
55
- # self.imsize = imsize
56
- # self.transform = transforms.Compose([transforms.ToTensor(), \
57
- # transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))])
58
- # def __getitem__(self, index):
59
- # img = self.images[index]
60
- # img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS)
61
- # print('after imresize:', img.size)
62
- # return self.transform(img)
63
- # def __len__(self):
64
- # return len(self.images)
65
-
66
- # ---------------------------------------
67
 
68
  def match(query_feat, pos_feat, LoweRatioTh=0.9):
69
  # first perform reciprocal nn
@@ -87,14 +70,14 @@ def match(query_feat, pos_feat, LoweRatioTh=0.9):
87
  return pindices[valid]
88
 
89
 
90
- # sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
91
  def clear_figures():
92
  plt.figure().clear()
93
  plt.close()
94
  plt.cla()
95
  plt.clf()
96
 
97
- col = plt.get_cmap('tab10')
 
98
 
99
  def generate_matching_superfeatures(
100
  im1, im2,
@@ -105,15 +88,12 @@ def generate_matching_superfeatures(
105
  print('im2:', im2.size)
106
 
107
  clear_figures()
 
108
 
109
  net = net_sfm
110
  if Imagenet_model:
111
  net = net_imagenet
112
 
113
- # dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
114
- # loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
115
-
116
-
117
  im1_tensor = transform(im1).unsqueeze(0)
118
  im2_tensor = transform(im2).unsqueeze(0)
119
 
@@ -134,32 +114,17 @@ def generate_matching_superfeatures(
134
 
135
  feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
136
  feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
137
- print('feats1n.shape', feats1n.shape)
138
  ind_match = match(feats1n, feats2n)
139
- print('ind', ind_match)
140
- print('ind.shape', ind_match.shape)
141
- # outputs = []
142
- # for im_tensor in loader:
143
- # outputs.append(net.get_superfeatures(im_tensor.to(device), scales=[scales[scale_id]]))
144
- # feats1 = outputs[0][0][0]
145
- # attns1 = outputs[0][1][0]
146
- # strenghts1 = outputs[0][2][0]
147
- # feats2 = outputs[1][0][0]
148
- # attns2 = outputs[1][1][0]
149
- # strenghts2 = outputs[1][2][0]
150
- print(feats1.shape, feats2.shape)
151
- print(attns1.shape, attns2.shape)
152
- print(strenghts1.shape, strenghts2.shape)
153
 
154
  # which sf
155
- sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
156
  n_sf_ids = 10
157
  if random_mode or sf_ids == '':
158
  sf_idx_ = np.random.randint(256, size=n_sf_ids)
159
  else:
160
  sf_idx_ = map(int, sf_ids.strip().split(','))
161
 
162
- # if only_matching:
163
  if random_mode:
164
  sf_idx_ = [int(jj) for jj in ind_match[np.random.randint(len(list(ind_match)), size=n_sf_ids)].numpy()]
165
  sf_idx_ = list( dict.fromkeys(sf_idx_) )
@@ -172,11 +137,9 @@ def generate_matching_superfeatures(
172
  all_att_bin1 = []
173
  all_att_bin2 = []
174
  for n, i in enumerate(sf_idx_):
175
- # all_atts[n].append(attn[j][scale_id][0,i,:,:].numpy())
176
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
177
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
178
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
179
- # print(att_heat_bin)
180
  all_att_bin1.append(att_heat_bin)
181
 
182
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
@@ -187,19 +150,11 @@ def generate_matching_superfeatures(
187
 
188
  fin_img = []
189
  img1rsz = np.copy(im1_cv)
190
- print('im1:', im1.size)
191
- print('img1rsz:', img1rsz.shape)
192
  for j, att in enumerate(all_att_bin1):
193
  att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
194
- # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
195
- # att = cv2.resize(att, imgz[i].shape[:2][::-1])
196
- # att = att.resize(shape)
197
- # att = resize(att, im1.size)
198
  mask2d = zip(*np.where(att==255))
199
  for m,n in mask2d:
200
  col_ = col.colors[j]
201
- # col_ = col.colors[j] if j < 7 else col.colors[j+1]
202
- # if j == 0: col_ = col.colors[9]
203
  col_ = 255*np.array(colors.to_rgba(col_))[:3]
204
  img1rsz[m,n, :] = col_[::-1]
205
 
@@ -208,51 +163,35 @@ def generate_matching_superfeatures(
208
  print('img2rsz:', img2rsz.shape)
209
  for j, att in enumerate(all_att_bin2):
210
  att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
211
- # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
212
- # # att = cv2.resize(att, imgz[i].shape[:2][::-1])
213
- # att = att.resize(im2.shape)
214
- # print('att:', att.shape)
215
  mask2d = zip(*np.where(att==255))
216
  for m,n in mask2d:
217
  col_ = col.colors[j]
218
- # col_ = col.colors[j] if j < 7 else col.colors[j+1]
219
- # if j == 0: col_ = col.colors[9]
220
  col_ = 255*np.array(colors.to_rgba(col_))[:3]
221
  img2rsz[m,n, :] = col_[::-1]
222
 
223
  fig1 = plt.figure(1)
224
  plt.imshow(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
225
  ax1 = plt.gca()
226
- # ax1.axis('scaled')
227
  ax1.axis('off')
228
  plt.tight_layout()
229
- # fig1.canvas.draw()
230
 
231
  fig2 = plt.figure(2)
232
  plt.imshow(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
233
  ax2 = plt.gca()
234
- # ax2.axis('scaled')
235
  ax2.axis('off')
236
  plt.tight_layout()
237
- # fig2.canvas.draw()
238
 
239
  f = lambda m,c: plt.plot([],[],marker=m, color=c, ls="none")[0]
240
  handles = [f("s", col.colors[i]) for i in range(n_sf_ids)]
241
  fig_leg = plt.figure(3)
242
  legend = plt.legend(handles, sf_idx_, framealpha=1, frameon=False, facecolor='w',fontsize=25, loc="center")
243
- # fig_leg = legend.figure
244
- # fig_leg.canvas.draw()
245
  ax3 = plt.gca()
246
- # ax2.axis('scaled')
247
  ax3.axis('off')
248
  plt.tight_layout()
249
- # bbox = legend.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
250
 
251
-
252
  im1 = None
253
  im2 = None
254
  return fig1, fig2, fig_leg
255
- # ','.join(map(str, sf_idx_))
256
 
257
 
258
  # GRADIO APP
@@ -265,21 +204,16 @@ iface = gr.Interface(
265
  inputs=[
266
  gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
267
  gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
268
- # gr.inputs.Image(type="pil", label="First Image"),
269
- # gr.inputs.Image(type="pil", label="Second Image"),
270
  gr.inputs.Checkbox(default=False, label="ImageNet Model (Default: SfM-120k)"),
271
  gr.inputs.Slider(minimum=0, maximum=6, step=1, default=4, label="Scale"),
272
  gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
273
  gr.inputs.Checkbox(default=True, label="Show random (matching) SFs"),
274
  gr.inputs.Textbox(lines=1, default="", label="...or show specific SF IDs:", optional=True),
275
- # gr.inputs.Checkbox(default=True, label="Show only matching SFs"),
276
  ],
277
  outputs=[
278
  gr.outputs.Image(type="plot", label="First Image SFs"),
279
  gr.outputs.Image(type="plot", label="Second Image SFs"),
280
  gr.outputs.Image(type="plot", label="SF legend")],
281
- # gr.outputs.Textbox(label="SFs")],
282
- # outputs=gr.outputs.Image(shape=(1024,2048), type="plot"),
283
  title=title,
284
  theme='peach',
285
  layout="horizontal",
@@ -287,7 +221,6 @@ iface = gr.Interface(
287
  article=article,
288
  examples=[
289
  ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
290
- # ["anafi1.jpeg", "anafi2.jpeg", False, 4, 150, False, '178,190,144,47,241, 172'],
291
  ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
292
  ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
293
  ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
 
41
  net_imagenet = fire_network.init_network(**state['net_params']).to(device)
42
  net_imagenet.load_state_dict(state2['state_dict'], strict=False)
43
 
 
44
  transform = transforms.Compose([
45
  transforms.Resize(1024),
46
  transforms.ToTensor(),
47
  transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std'])))
48
  ])
49
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  def match(query_feat, pos_feat, LoweRatioTh=0.9):
52
  # first perform reciprocal nn
 
70
  return pindices[valid]
71
 
72
 
 
73
  def clear_figures():
74
  plt.figure().clear()
75
  plt.close()
76
  plt.cla()
77
  plt.clf()
78
 
79
+
80
+
81
 
82
  def generate_matching_superfeatures(
83
  im1, im2,
 
88
  print('im2:', im2.size)
89
 
90
  clear_figures()
91
+ col = plt.get_cmap('tab10')
92
 
93
  net = net_sfm
94
  if Imagenet_model:
95
  net = net_imagenet
96
 
 
 
 
 
97
  im1_tensor = transform(im1).unsqueeze(0)
98
  im2_tensor = transform(im2).unsqueeze(0)
99
 
 
114
 
115
  feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
116
  feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
 
117
  ind_match = match(feats1n, feats2n)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
  # which sf
120
+ sf_idx_ = []
121
  n_sf_ids = 10
122
  if random_mode or sf_ids == '':
123
  sf_idx_ = np.random.randint(256, size=n_sf_ids)
124
  else:
125
  sf_idx_ = map(int, sf_ids.strip().split(','))
126
 
127
+ # only_matching:
128
  if random_mode:
129
  sf_idx_ = [int(jj) for jj in ind_match[np.random.randint(len(list(ind_match)), size=n_sf_ids)].numpy()]
130
  sf_idx_ = list( dict.fromkeys(sf_idx_) )
 
137
  all_att_bin1 = []
138
  all_att_bin2 = []
139
  for n, i in enumerate(sf_idx_):
 
140
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
141
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
142
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
 
143
  all_att_bin1.append(att_heat_bin)
144
 
145
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
 
150
 
151
  fin_img = []
152
  img1rsz = np.copy(im1_cv)
 
 
153
  for j, att in enumerate(all_att_bin1):
154
  att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
 
 
 
 
155
  mask2d = zip(*np.where(att==255))
156
  for m,n in mask2d:
157
  col_ = col.colors[j]
 
 
158
  col_ = 255*np.array(colors.to_rgba(col_))[:3]
159
  img1rsz[m,n, :] = col_[::-1]
160
 
 
163
  print('img2rsz:', img2rsz.shape)
164
  for j, att in enumerate(all_att_bin2):
165
  att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
 
 
 
 
166
  mask2d = zip(*np.where(att==255))
167
  for m,n in mask2d:
168
  col_ = col.colors[j]
 
 
169
  col_ = 255*np.array(colors.to_rgba(col_))[:3]
170
  img2rsz[m,n, :] = col_[::-1]
171
 
172
  fig1 = plt.figure(1)
173
  plt.imshow(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
174
  ax1 = plt.gca()
 
175
  ax1.axis('off')
176
  plt.tight_layout()
 
177
 
178
  fig2 = plt.figure(2)
179
  plt.imshow(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
180
  ax2 = plt.gca()
 
181
  ax2.axis('off')
182
  plt.tight_layout()
 
183
 
184
  f = lambda m,c: plt.plot([],[],marker=m, color=c, ls="none")[0]
185
  handles = [f("s", col.colors[i]) for i in range(n_sf_ids)]
186
  fig_leg = plt.figure(3)
187
  legend = plt.legend(handles, sf_idx_, framealpha=1, frameon=False, facecolor='w',fontsize=25, loc="center")
 
 
188
  ax3 = plt.gca()
 
189
  ax3.axis('off')
190
  plt.tight_layout()
 
191
 
 
192
  im1 = None
193
  im2 = None
194
  return fig1, fig2, fig_leg
 
195
 
196
 
197
  # GRADIO APP
 
204
  inputs=[
205
  gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
206
  gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
 
 
207
  gr.inputs.Checkbox(default=False, label="ImageNet Model (Default: SfM-120k)"),
208
  gr.inputs.Slider(minimum=0, maximum=6, step=1, default=4, label="Scale"),
209
  gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
210
  gr.inputs.Checkbox(default=True, label="Show random (matching) SFs"),
211
  gr.inputs.Textbox(lines=1, default="", label="...or show specific SF IDs:", optional=True),
 
212
  ],
213
  outputs=[
214
  gr.outputs.Image(type="plot", label="First Image SFs"),
215
  gr.outputs.Image(type="plot", label="Second Image SFs"),
216
  gr.outputs.Image(type="plot", label="SF legend")],
 
 
217
  title=title,
218
  theme='peach',
219
  layout="horizontal",
 
221
  article=article,
222
  examples=[
223
  ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
 
224
  ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
225
  ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
226
  ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
jaipur1.jpeg CHANGED

Git LFS Details

  • SHA256: 32ef5bbeb649ce8b57699e12643e2d737db4cf1e9080906cefb34714ccd8f0c6
  • Pointer size: 132 Bytes
  • Size of remote file: 2.57 MB

Git LFS Details

  • SHA256: 2c4038427f0614a088c335e16f23fe7d645f1dcf2d2095fb2407e3e82d22b8bb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB