YannisK commited on
Commit
c1911e8
·
1 Parent(s): dd99365
Files changed (1) hide show
  1. app.py +17 -11
app.py CHANGED
@@ -89,14 +89,7 @@ col = plt.get_cmap('tab10')
89
  def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
90
  print('im1:', im1.size)
91
  print('im2:', im2.size)
92
- # which sf
93
- sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
94
- if sf_ids.lower().startswith('r'):
95
- n_sf_ids = int(sf_ids[1:])
96
- sf_idx_ = np.random.randint(256, size=n_sf_ids)
97
- elif sf_ids != '':
98
- sf_idx_ = map(int, sf_ids.strip().split(','))
99
-
100
 
101
  # dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
102
  # loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
@@ -120,8 +113,8 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='
120
  attns2 = output2[1][0]
121
  strenghts2 = output2[2][0]
122
 
123
- feats1n = F.normalize(torch.squeeze(feats1), dim=1)
124
- feats2n = F.normalize(torch.squeeze(feats2), dim=1)
125
  print('feats1n.shape', feats1n.shape)
126
  ind_match = match(feats1n, feats2n)
127
  print('ind', ind_match)
@@ -139,7 +132,20 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='
139
  print(attns1.shape, attns2.shape)
140
  print(strenghts1.shape, strenghts2.shape)
141
 
142
- # if only_matching:
 
 
 
 
 
 
 
 
 
 
 
 
 
143
 
144
  # Store all binary SF att maps to show them all at once in the end
145
  all_att_bin1 = []
 
89
  def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
90
  print('im1:', im1.size)
91
  print('im2:', im2.size)
92
+
 
 
 
 
 
 
 
93
 
94
  # dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
95
  # loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
 
113
  attns2 = output2[1][0]
114
  strenghts2 = output2[2][0]
115
 
116
+ feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
117
+ feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
118
  print('feats1n.shape', feats1n.shape)
119
  ind_match = match(feats1n, feats2n)
120
  print('ind', ind_match)
 
132
  print(attns1.shape, attns2.shape)
133
  print(strenghts1.shape, strenghts2.shape)
134
 
135
+ # which sf
136
+ sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
137
+ if sf_ids.lower().startswith('r'):
138
+ n_sf_ids = int(sf_ids[1:])
139
+ sf_idx_ = np.random.randint(256, size=n_sf_ids)
140
+ elif sf_ids != '':
141
+ sf_idx_ = map(int, sf_ids.strip().split(','))
142
+
143
+ if only_matching:
144
+ sf_idx_ = [i for i in sf_idx_ if i in list(ind_match)]
145
+
146
+
147
+
148
+
149
 
150
  # Store all binary SF att maps to show them all at once in the end
151
  all_att_bin1 = []