YannisK commited on
Commit
70e3f4a
·
1 Parent(s): b70a5dc
Files changed (2) hide show
  1. app.py +8 -11
  2. requirements.txt +2 -2
app.py CHANGED
@@ -41,9 +41,7 @@ sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
41
  col = plt.get_cmap('tab10')
42
 
43
  def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
44
- print(im1.size)
45
- return
46
-
47
  im1_tensor = transform(im1).unsqueeze(0)
48
  im2_tensor = transform(im2).unsqueeze(0)
49
 
@@ -80,17 +78,17 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
80
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
81
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
82
  all_att_bin2.append(att_heat_bin)
83
-
84
 
85
  fin_img = []
86
  img1rsz = np.copy(im1)
87
- print(img1rsz.size)
88
  for j, att in enumerate(all_att_bin1):
89
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
90
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
91
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
92
- att = att.resize(shape)
93
- # att = resize(att, im1.size)
94
  mask2d = zip(*np.where(att==255))
95
  for m,n in mask2d:
96
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
@@ -104,8 +102,7 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
104
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
105
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
106
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
107
- att = att.resize(im2.shape)
108
- print('att:', att.shape)
109
  mask2d = zip(*np.where(att==255))
110
  for m,n in mask2d:
111
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
@@ -135,8 +132,8 @@ article = "<p style='text-align: center'><a href='https://github.com/naver/fire'
135
  iface = gr.Interface(
136
  fn=generate_matching_superfeatures,
137
  inputs=[
138
- gr.inputs.Image(shape=(240, 240), type="pil"),
139
- gr.inputs.Image(shape=(240, 240), type="pil"),
140
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
141
  gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarizatio Threshold")],
142
  outputs="plot",
 
41
  col = plt.get_cmap('tab10')
42
 
43
  def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
44
+
 
 
45
  im1_tensor = transform(im1).unsqueeze(0)
46
  im2_tensor = transform(im2).unsqueeze(0)
47
 
 
78
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
79
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
80
  all_att_bin2.append(att_heat_bin)
81
+ print(all_att_bin2[0].shape)
82
 
83
  fin_img = []
84
  img1rsz = np.copy(im1)
85
+ print(img1rsz.shape)
86
  for j, att in enumerate(all_att_bin1):
87
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
88
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
89
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
90
+ att = resize(att, im1.shape[:2])
91
+ print(att.shape)
92
  mask2d = zip(*np.where(att==255))
93
  for m,n in mask2d:
94
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
 
102
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
103
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
104
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
105
+ att = resize(att, im2.shape[:2])
 
106
  mask2d = zip(*np.where(att==255))
107
  for m,n in mask2d:
108
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
 
132
  iface = gr.Interface(
133
  fn=generate_matching_superfeatures,
134
  inputs=[
135
+ gr.inputs.Image(shape=(1024, 1024), type="numpy"),
136
+ gr.inputs.Image(shape=(1024, 1024), type="numpy"),
137
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
138
  gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarizatio Threshold")],
139
  outputs="plot",
requirements.txt CHANGED
@@ -1,6 +1,6 @@
 
1
  numpy
2
  pyaml
3
  matplotlib
4
  torch==1.10.2
5
- torchvision==0.11.3
6
- scikit-image
 
1
+ opencv-python
2
  numpy
3
  pyaml
4
  matplotlib
5
  torch==1.10.2
6
+ torchvision==0.11.3