YannisK commited on
Commit
1e7dbb8
·
1 Parent(s): 4ee6e9d
Files changed (1) hide show
  1. app.py +15 -7
app.py CHANGED
@@ -37,13 +37,20 @@ transform = transforms.Compose([
37
  ])
38
 
39
 
40
- # which sf
41
- sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
42
 
43
  col = plt.get_cmap('tab10')
44
 
45
- def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
46
 
 
 
 
 
 
 
 
 
47
  im1_tensor = transform(im1).unsqueeze(0)
48
  im2_tensor = transform(im2).unsqueeze(0)
49
 
@@ -74,7 +81,7 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
74
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
75
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
76
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
77
- print(att_heat_bin)
78
  all_att_bin1.append(att_heat_bin)
79
 
80
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
@@ -164,10 +171,11 @@ css = ".input_image, .input_image {height: 600px !important; width: 600px !impor
164
  iface = gr.Interface(
165
  fn=generate_matching_superfeatures,
166
  inputs=[
167
- gr.inputs.Image(shape=(1024, 1024), type="pil"),
168
- gr.inputs.Image(shape=(1024, 1024), type="pil"),
169
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="Scale"),
170
- gr.inputs.Slider(minimum=1, maximum=255, step=25, default=100, label="Binarization Threshold")],
 
171
  outputs=["plot", "plot"],
172
  # outputs=gr.outputs.Image(shape=(1024,2048), type="plot"),
173
  title=title,
 
37
  ])
38
 
39
 
40
+ # sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
 
41
 
42
  col = plt.get_cmap('tab10')
43
 
44
+ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids=''):
45
 
46
+ # which sf
47
+ sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
48
+ if sf_ids.lower().startswith('r'):
49
+ n_sf_ids = int(sf_ids[1:])
50
+ sf_idx_ = np.random.randint(256, size=n_sf_ids)
51
+ elif sf_ids != '':
52
+ sf_idx_ = map(int, sf_ids.strip().split(','))
53
+
54
  im1_tensor = transform(im1).unsqueeze(0)
55
  im2_tensor = transform(im2).unsqueeze(0)
56
 
 
81
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
82
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
83
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
84
+ # print(att_heat_bin)
85
  all_att_bin1.append(att_heat_bin)
86
 
87
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
 
171
  iface = gr.Interface(
172
  fn=generate_matching_superfeatures,
173
  inputs=[
174
+ gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
175
+ gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
176
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="Scale"),
177
+ gr.inputs.Slider(minimum=1, maximum=255, step=25, default=100, label="Binarization Threshold"),
178
+ gr.inputs.Textbox(lines=1, default="", label="Super-feature IDs to show", optional=True)],
179
  outputs=["plot", "plot"],
180
  # outputs=gr.outputs.Image(shape=(1024,2048), type="plot"),
181
  title=title,