Spaces:
Build error
Build error
File size: 11,062 Bytes
bb3ea39 dd504c0 880da41 dd504c0 5dfb000 aa5e40b 402b433 880da41 402b433 880da41 f4b82b2 9651aac e7c9542 402b433 f4b82b2 c9dadbf f4b82b2 351ead9 22a02d6 f4b82b2 b3806d2 cb6bc7f f4b82b2 b3806d2 cb6bc7f 1d018e4 ce415a7 1d018e4 5dfb000 092e211 0f68393 092e211 5dfb000 5d89783 5dfb000 b489890 5dfb000 ea2b2ef eee2bf2 ea2b2ef c8e8ad1 a15c1b9 ea2b2ef 9651aac 0f77bb9 b2bf1e7 c9dadbf 411e4cb 0f77bb9 c1911e8 b2bf1e7 b3806d2 0cb19e3 b3806d2 5dfb000 388a978 5dfb000 b489890 f4b82b2 f981819 f4b82b2 c9dadbf 092e211 ea2b2ef c1911e8 6bc6828 ea2b2ef 092e211 402b433 c1911e8 cb347cf 9fec9a2 c1911e8 9fec9a2 c1911e8 cb3d625 c1911e8 a809352 402b433 7c408ba 402b433 0f77bb9 402b433 dd504c0 402b433 722b0aa e52b6e6 0f77bb9 402b433 7022373 402b433 dd504c0 402b433 3ea46fa 402b433 dca7dd8 0f77bb9 402b433 7022373 402b433 dd504c0 402b433 8222092 402b433 f4b82b2 0f77bb9 3d3c7f5 0f77bb9 a0c42c4 0f77bb9 a0c42c4 0f77bb9 3d3c7f5 0f77bb9 a0c42c4 0f77bb9 cc9878d 8222092 8ff5253 cb347cf 8ff5253 a92520a 9180d7e 8ff5253 cc9878d 2566e5b 8ff5253 cc9878d 1a2db09 274f0f4 1a2db09 2566e5b 43f0015 411e4cb 7f2fdd5 411e4cb 0cb19e3 f7b2f96 36eb9c8 0f77bb9 06125bf cc9878d 74bae21 1a2db09 0f77bb9 2134c75 1a2db09 0f77bb9 707ee9f ec02159 2639fe5 53874e7 c91b02f 700c051 1a2db09 2134c75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
import gradio as gr
import cv2
import torch
import torch.utils.data as data
from torchvision import transforms
from torch import nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib import colors
from mpl_toolkits.axes_grid1 import ImageGrid
import fire_network
import numpy as np
from PIL import Image
# Possible scales for multiscale inference; the UI "Scale" slider picks one by index.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# Load nets
# SfM-120k network (the demo's default model).
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])
# Inference only: eval mode makes normalisation/dropout layers behave
# deterministically instead of using per-batch statistics.
net_sfm.eval()

# ImageNet network: same architecture/params as the SfM one, with the SfM
# model's 'dim_reduction' parameters copied into its state dict (presumably
# because fire_imagenet.pth does not provide them -- TODO confirm).
dim_red_params_dict = {
    name: param
    for name, param in net_sfm.named_parameters()
    if 'dim_reduction' in name
}
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
state2['state_dict'] = dict(state2['state_dict'], **dim_red_params_dict)
net_imagenet = fire_network.init_network(**state2['net_params']).to(device)
# strict=False: keys missing from / unexpected in the checkpoint are tolerated.
net_imagenet.load_state_dict(state2['state_dict'], strict=False)
net_imagenet.eval()
# ---------------------------------------
# Preprocessing applied to every input image before feature extraction:
# resize with torchvision's Resize(1024), convert to a tensor and normalise
# with the mean/std stored in the SfM network's runtime config.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std']))),
])
# ---------------------------------------
def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Return the indices of super-features that match across two images.

    A column ``j`` of the distance matrix (feature ``j`` of ``pos_feat``) is
    kept when all of the following hold:

    1. it and its nearest query feature are mutual nearest neighbours;
    2. it passes Lowe's ratio test: nearest distance / second-nearest
       distance <= ``LoweRatioTh``;
    3. its nearest query feature has the same index ``j`` (the same
       super-feature slot in both images).

    Args:
        query_feat: (N, D) tensor of features from the first image.
        pos_feat: (M, D) tensor of features from the second image.
        LoweRatioTh: maximum accepted nearest/second-nearest distance ratio.

    Returns:
        1-D long tensor of the accepted indices (into ``pos_feat``).
    """
    if query_feat.size(0) == 0 or pos_feat.size(0) == 0:
        # argmin over an empty dimension raises; nothing can match anyway.
        return torch.empty(0, dtype=torch.long, device=pos_feat.device)

    dist = torch.cdist(query_feat, pos_feat)  # (N, M)
    cols = torch.arange(dist.size(1), device=dist.device)

    # Reciprocal nearest neighbours.
    best_query = torch.argmin(dist, dim=0)  # nearest query row for each column
    best_pos = torch.argmin(dist, dim=1)    # nearest column for each query row
    reciprocal = best_pos[best_query] == cols

    # Lowe ratio test: hide each column's nearest entry, then take the smallest
    # remaining *distance* in that column as the second-nearest distance.
    nearest = dist[best_query, cols]
    masked = dist.clone()
    masked[best_query, cols] = float('inf')
    second = masked.min(dim=0).values
    ratio = nearest / second
    valid = reciprocal & (ratio <= LoweRatioTh)

    # Keep only matches that pair a super-feature with the same index.
    pindices = torch.where(valid)[0]
    qindices = best_query[pindices]
    return pindices[pindices == qindices]
def clear_figures():
    """Close every open pyplot figure.

    Each request draws into numbered pyplot figures; closing all of them
    means the next request starts from fresh figures (instead of drawing on
    top of the previous request's axes) and their memory is released.
    """
    plt.close('all')


# Colour palette ('tab10') used for the super-feature overlays and legend.
col = plt.get_cmap('tab10')
def _extract_superfeatures(net, image, scale):
    """Run ``net`` on one PIL image at a single scale; return (features, attentions).

    Features are L2-normalised row-wise after squeeze + transpose, presumably
    giving an (N, D) matrix with one row per super-feature (TODO confirm the
    raw layout against fire_network). Attentions are returned as produced;
    ``attns[0, i]`` is super-feature ``i``'s map.
    """
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = net.get_superfeatures(batch, scales=[scale])
    # output[k][0] selects the first (and only) requested scale.
    feats = F.normalize(torch.t(torch.squeeze(output[0][0])), dim=1)
    attns = output[1][0]
    return feats, attns


def _parse_sf_ids(sf_ids):
    """Parse comma-separated IDs into a de-duplicated list of ints (order kept).

    Blank entries (e.g. from a trailing comma) are skipped; a non-integer entry
    raises ValueError.
    """
    tokens = (tok.strip() for tok in sf_ids.split(','))
    return list(dict.fromkeys(int(tok) for tok in tokens if tok))


def _select_sf_ids(matched, num_sf, random_mode, sf_ids, max_random=10):
    """Choose which super-feature IDs to display.

    Args:
        matched: list of super-feature indices accepted by ``match``.
        num_sf: number of super-features per image.
        random_mode: if true, return up to ``max_random`` distinct matched IDs.
        sf_ids: comma-separated IDs used when ``random_mode`` is false; if it
            is empty/None, ``max_random`` random IDs in [0, num_sf) are used.
        max_random: how many IDs to draw at random.

    Returns:
        List of ints; outside random mode only IDs present in ``matched`` are kept.
    """
    if random_mode:
        if not matched:
            return []
        picked = np.random.choice(matched, size=min(max_random, len(matched)), replace=False)
        return [int(i) for i in picked]
    if sf_ids and sf_ids.strip():
        candidates = _parse_sf_ids(sf_ids)
    else:
        picked = np.random.choice(num_sf, size=min(max_random, num_sf), replace=False)
        candidates = [int(i) for i in picked]
    matched_set = set(matched)
    return [i for i in candidates if i in matched_set]


def _binarize_attention(attn_map, threshold):
    """Rescale an attention map by its maximum to 0-255 and threshold it.

    Returns a uint8 mask that is 255 where the rescaled value exceeds
    ``threshold`` and 0 elsewhere. A map whose maximum is not positive
    (e.g. all zeros) yields an empty mask instead of dividing by zero.
    """
    heat = np.asarray(attn_map, dtype=np.float32)
    peak = heat.max()
    if not peak > 0:
        return np.zeros(heat.shape, dtype=np.uint8)
    scaled = np.clip(heat / peak * 255.0, 0, 255).astype(np.uint8)
    return np.where(scaled > threshold, 255, 0).astype(np.uint8)


def _paint_masks(image, masks):
    """Return an RGB uint8 copy of ``image`` with each mask painted in its palette colour.

    Masks are resized to the image size with nearest-neighbour interpolation;
    where masks overlap, later ones overwrite earlier ones.
    """
    canvas = np.array(image.convert('RGB'))
    for j, mask in enumerate(masks):
        mask = cv2.resize(mask, image.size, interpolation=cv2.INTER_NEAREST)
        rgb = 255 * np.array(colors.to_rgb(col.colors[j % len(col.colors)]))
        canvas[mask == 255] = rgb
    return canvas


def _image_figure(num, rgb_image):
    """Draw ``rgb_image`` on pyplot figure ``num`` (cleared first), axes hidden."""
    fig = plt.figure(num)
    fig.clf()  # figure numbers are reused across calls; start from a blank figure
    ax = fig.add_subplot(111)
    ax.imshow(rgb_image)
    ax.axis('off')
    fig.tight_layout()
    return fig


def _legend_figure(num, sf_ids):
    """Draw a legend mapping each palette colour to its super-feature ID on figure ``num``."""
    fig = plt.figure(num)
    fig.clf()
    ax = fig.add_subplot(111)
    # Empty lines serve only as coloured legend markers; they draw nothing.
    handles = [
        ax.plot([], [], marker='s', color=col.colors[j % len(col.colors)], ls='none')[0]
        for j in range(len(sf_ids))
    ]
    ax.legend(handles, [str(i) for i in sf_ids], framealpha=1, frameon=False,
              facecolor='w', fontsize=25, loc='center')
    ax.axis('off')
    fig.tight_layout()
    return fig


def generate_matching_superfeatures(
        im1, im2,
        Imagenet_model=False,
        scale_id=6, threshold=50,
        random_mode=False, sf_ids=''):
    """Visualise super-features that match between two images.

    Args:
        im1, im2: input PIL images.
        Imagenet_model: use the ImageNet network instead of the SfM-120k one.
        scale_id: index into ``scales`` selecting the single inference scale.
        threshold: attention maps are rescaled to 0-255 by their maximum;
            pixels above this value are painted.
        random_mode: show up to 10 randomly chosen matching super-features.
        sf_ids: comma-separated super-feature IDs to show when ``random_mode``
            is off (non-matching IDs are dropped). Empty/None falls back to
            10 random IDs, again keeping only matching ones.

    Returns:
        (fig1, fig2, fig_leg): pyplot figures 1, 2 and 3 with the painted
        images and a legend mapping colours to super-feature IDs.

    Raises:
        ValueError: if ``sf_ids`` contains a non-integer entry.
    """
    clear_figures()
    net = net_imagenet if Imagenet_model else net_sfm
    scale = scales[int(scale_id)]  # sliders may deliver floats

    feats1, attns1 = _extract_superfeatures(net, im1, scale)
    feats2, attns2 = _extract_superfeatures(net, im2, scale)
    matched = match(feats1, feats2).tolist()

    # attns1[0, i] is super-feature i's map, so dim 1 counts super-features.
    selected = _select_sf_ids(matched, attns1.shape[1], random_mode, sf_ids)

    masks1 = [_binarize_attention(attns1[0, i].cpu().numpy(), threshold) for i in selected]
    masks2 = [_binarize_attention(attns2[0, i].cpu().numpy(), threshold) for i in selected]

    fig1 = _image_figure(1, _paint_masks(im1, masks1))
    fig2 = _image_figure(2, _paint_masks(im2, masks2))
    fig_leg = _legend_figure(3, selected)
    return fig1, fig2, fig_leg
# GRADIO APP
# NOTE(review): gr.inputs / gr.outputs, theme='peach', layout=... and
# launch(enable_queue=...) are the legacy Gradio 2.x/3.x API and were removed
# in later Gradio releases -- pin the Gradio version or migrate.
title = "Visualizing Super-features"
description = "This is a visualization demo for the ICLR 2022 paper <b><a href='https://github.com/naver/fire' target='_blank'>Learning Super-Features for Image Retrieval</a></b>"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    # Order must match generate_matching_superfeatures' parameters:
    # im1, im2, Imagenet_model, scale_id, threshold, random_mode, sf_ids.
    inputs=[
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
        gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
        gr.inputs.Checkbox(default=False, label="ImageNet Model (Default: SfM-120k)"),
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=4, label="Scale"),
        gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
        gr.inputs.Checkbox(default=True, label="Show random (matching) SFs"),
        gr.inputs.Textbox(lines=1, default="", label="...or show specific SF IDs:", optional=True),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Image(type="plot", label="SF legend"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    # Each row: image 1, image 2, ImageNet model?, scale index, threshold,
    # random mode?, SF IDs.
    examples=[
        ["chateau_1.png", "chateau_2.png", False, 3, 150, False, '170,15,25,63,193,125,92,214,107'],
        ["areopoli1.jpeg", "areopoli2.jpeg", False, 4, 150, False, '205,2,163,130'],
        ["jaipur1.jpeg", "jaipur2.jpeg", False, 4, 50, False, '51,206,216,49,27'],
        ["basil1.jpeg", "basil2.jpeg", True, 4, 100, False, '75,152,19,36,156'],
        ["mill1.jpeg", "mill2.jpeg", False, 4, 100, False, '177,88,170,190,151,155'],
    ],
)

if __name__ == "__main__":
    iface.launch(enable_queue=True)