Spaces:
Build error
Build error
edits
Browse files
app.py
CHANGED
@@ -89,14 +89,7 @@ col = plt.get_cmap('tab10')
|
|
89 |
def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
|
90 |
print('im1:', im1.size)
|
91 |
print('im2:', im2.size)
|
92 |
-
|
93 |
-
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
|
94 |
-
if sf_ids.lower().startswith('r'):
|
95 |
-
n_sf_ids = int(sf_ids[1:])
|
96 |
-
sf_idx_ = np.random.randint(256, size=n_sf_ids)
|
97 |
-
elif sf_ids != '':
|
98 |
-
sf_idx_ = map(int, sf_ids.strip().split(','))
|
99 |
-
|
100 |
|
101 |
# dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
|
102 |
# loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
|
@@ -120,8 +113,8 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='
|
|
120 |
attns2 = output2[1][0]
|
121 |
strenghts2 = output2[2][0]
|
122 |
|
123 |
-
feats1n = F.normalize(torch.squeeze(feats1), dim=1)
|
124 |
-
feats2n = F.normalize(torch.squeeze(feats2), dim=1)
|
125 |
print('feats1n.shape', feats1n.shape)
|
126 |
ind_match = match(feats1n, feats2n)
|
127 |
print('ind', ind_match)
|
@@ -139,7 +132,20 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='
|
|
139 |
print(attns1.shape, attns2.shape)
|
140 |
print(strenghts1.shape, strenghts2.shape)
|
141 |
|
142 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
# Store all binary SF att maps to show them all at once in the end
|
145 |
all_att_bin1 = []
|
|
|
89 |
def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids='', only_matching=True):
|
90 |
print('im1:', im1.size)
|
91 |
print('im2:', im2.size)
|
92 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
|
94 |
# dataset_ = ImgDataset(images=[im1, im2], imsize=1024)
|
95 |
# loader = torch.utils.data.DataLoader(dataset_, shuffle=False, pin_memory=True)
|
|
|
113 |
attns2 = output2[1][0]
|
114 |
strenghts2 = output2[2][0]
|
115 |
|
116 |
+
feats1n = F.normalize(torch.t(torch.squeeze(feats1)), dim=1)
|
117 |
+
feats2n = F.normalize(torch.t(torch.squeeze(feats2)), dim=1)
|
118 |
print('feats1n.shape', feats1n.shape)
|
119 |
ind_match = match(feats1n, feats2n)
|
120 |
print('ind', ind_match)
|
|
|
132 |
print(attns1.shape, attns2.shape)
|
133 |
print(strenghts1.shape, strenghts2.shape)
|
134 |
|
135 |
+
# which sf
|
136 |
+
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
|
137 |
+
if sf_ids.lower().startswith('r'):
|
138 |
+
n_sf_ids = int(sf_ids[1:])
|
139 |
+
sf_idx_ = np.random.randint(256, size=n_sf_ids)
|
140 |
+
elif sf_ids != '':
|
141 |
+
sf_idx_ = map(int, sf_ids.strip().split(','))
|
142 |
+
|
143 |
+
if only_matching:
|
144 |
+
sf_idx_ = [i for i in sf_idx_ if i in list(ind_match)]
|
145 |
+
|
146 |
+
|
147 |
+
|
148 |
+
|
149 |
|
150 |
# Store all binary SF att maps to show them all at once in the end
|
151 |
all_att_bin1 = []
|