Spaces:
Build error
Build error
import gradio as gr | |
import cv2 | |
import torch | |
import torch.utils.data as data | |
from torchvision import transforms | |
from torch import nn | |
import torch.nn.functional as F | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from matplotlib import colors | |
from mpl_toolkits.axes_grid1 import ImageGrid | |
import fire_network | |
import numpy as np | |
from PIL import Image | |
# Possible scales for multiscale inference (indexed by the "Scale" slider).
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# --- SfM-120k (landmarks) model -------------------------------------------
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for imagenet pretrained model
net_sfm = fire_network.init_network(**state['net_params']).to(device)
net_sfm.load_state_dict(state['state_dict'])
# Inference only: switch layers such as BatchNorm/Dropout to eval behaviour.
net_sfm.eval()

# Collect the 'dim_reduction' parameters of the SfM model so they can be
# injected into the ImageNet checkpoint below.
# NOTE(review): presumably 'fire_imagenet.pth' ships without these weights,
# which is why they are borrowed from the SfM model -- confirm.
dim_red_params_dict = {
    name: param
    for name, param in net_sfm.named_parameters()
    if 'dim_reduction' in name
}

# --- ImageNet model ---------------------------------------------------------
state2 = torch.load('fire_imagenet.pth', map_location='cpu')
state2['net_params'] = state['net_params']
# A state dict is a mapping: merge with update(). The previous `+=` raised
# TypeError because dictionaries do not support in-place addition.
state2['state_dict'].update(dim_red_params_dict)
net_imagenet = fire_network.init_network(**state2['net_params']).to(device)
net_imagenet.load_state_dict(state2['state_dict'])
net_imagenet.eval()
# --------------------------------------- | |
# Preprocessing applied to every input image: resize (shorter side to 1024),
# convert to a tensor and normalise with the network's mean/std statistics.
# The statistics are read from `net_sfm`: no `net` name exists at module
# scope (it is only a local inside generate_matching_superfeatures), so the
# previous `net.runtime[...]` raised NameError at import time.
# NOTE(review): net_imagenet is built from the same net_params, so it is
# assumed to share the same 'mean_std' -- confirm.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net_sfm.runtime['mean_std']))),
])
# --------------------------------------- | |
# class ImgDataset(data.Dataset): | |
# def __init__(self, images, imsize): | |
# self.images = images | |
# self.imsize = imsize | |
# self.transform = transforms.Compose([transforms.ToTensor(), \ | |
# transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))]) | |
# def __getitem__(self, index): | |
# img = self.images[index] | |
# img.thumbnail((self.imsize, self.imsize), Image.Resampling.LANCZOS) | |
# print('after imresize:', img.size) | |
# return self.transform(img) | |
# def __len__(self): | |
# return len(self.images) | |
# --------------------------------------- | |
def match(query_feat, pos_feat, LoweRatioTh=0.9):
    """Return the indices of features that match across two feature sets.

    A column index ``j`` of ``pos_feat`` is kept when all of these hold:

    * its nearest row ``i`` in ``query_feat`` also has ``j`` as its nearest
      neighbour (reciprocal nearest neighbours);
    * the distance to the nearest row is at most ``LoweRatioTh`` times the
      distance to the second-nearest row (Lowe's ratio test);
    * ``i == j``, i.e. the match links features with the same index.

    Args:
        query_feat: 2-D tensor with one feature per row.
        pos_feat: 2-D tensor with one feature per row (same feature size).
        LoweRatioTh: maximum allowed nearest / second-nearest distance ratio.

    Returns:
        1-D integer tensor with the indices that satisfy all three tests.
    """
    # Nothing can match when either side has no features; argmin would
    # otherwise fail on the empty dimension.
    if query_feat.size(0) == 0 or pos_feat.size(0) == 0:
        return torch.empty(0, dtype=torch.long, device=query_feat.device)

    # first perform reciprocal nn
    dist = torch.cdist(query_feat, pos_feat)
    best1 = torch.argmin(dist, dim=1)  # nearest pos feature for each query
    best2 = torch.argmin(dist, dim=0)  # nearest query feature for each pos
    arange = torch.arange(best2.size(0), device=best2.device)
    reciprocal = best1[best2] == arange

    # check Lowe ratio test
    nearest = dist[best2, arange]
    dist2 = dist.clone()
    dist2[best2, arange] = float('Inf')  # mask out the nearest neighbour
    # Use the second-nearest *distance*. The previous code divided by the
    # argmin result, i.e. by an index, which made the ratio meaningless.
    second_nearest = dist2.min(dim=0).values
    ratio1to2 = nearest / second_nearest

    valid = torch.logical_and(reciprocal, ratio1to2 <= LoweRatioTh)
    pindices = torch.where(valid)[0]
    qindices = best2[pindices]

    # keep only the ones with same indices
    valid = pindices == qindices
    return pindices[valid]
# sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9] | |
# Qualitative palette used to paint each displayed super-feature.
col = plt.get_cmap('tab10')

# Super-feature IDs shown when the user does not request specific ones.
_DEFAULT_SF_IDS = (55, 14, 5, 4, 52, 57, 40, 9)

# Random mode draws IDs from range(_NUM_SF_IDS), i.e. 0-255.
_NUM_SF_IDS = 256

# tab10 entry used for the 1st, 2nd, ... displayed super-feature. The first
# nine entries reproduce the historical colour assignment; the last two add
# the remaining tab10 colours, after which the sequence cycles instead of
# running past the end of the 10-colour palette.
_SF_COLOR_ORDER = (9, 1, 2, 3, 4, 5, 6, 8, 9, 0, 7)


def _extract_superfeatures(net, image, scale):
    """Run `net` on one PIL image; return (L2-normalised features, attentions)."""
    tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = net.get_superfeatures(tensor, scales=[scale])
    feats = output[0][0]
    attns = output[1][0]
    # One row per super-feature after squeeze + transpose.
    feats_n = F.normalize(torch.t(torch.squeeze(feats)), dim=1)
    return feats_n, attns


def _select_sf_ids(sf_ids, ind_match, only_matching):
    """Turn the user's ID request into a concrete list of super-feature IDs."""
    matched = ind_match.tolist()
    spec = (sf_ids or '').strip()

    if spec.lower().startswith('r'):
        # 'rX' -> X random IDs; int() raises ValueError on a malformed count.
        n_sf_ids = int(spec[1:])
        if not only_matching:
            return [int(k) for k in np.random.randint(_NUM_SF_IDS, size=n_sf_ids)]
        if not matched:
            # randint(0, ...) would raise; there is simply nothing to show.
            return []
        picks = np.random.randint(len(matched), size=n_sf_ids)
        # dict.fromkeys drops duplicate draws while keeping their order.
        return list(dict.fromkeys(matched[p] for p in picks))

    if spec:
        # Materialise as a list: the IDs are iterated more than once.
        requested = [int(tok) for tok in spec.split(',') if tok.strip()]
    else:
        requested = list(_DEFAULT_SF_IDS)

    if only_matching:
        matched_set = set(matched)
        requested = [i for i in requested if i in matched_set]
    return requested


def _binarize_attention(attn_map, threshold):
    """Scale one attention map to 0-255 and threshold it to a 0/255 uint8 mask."""
    heat = np.array(attn_map.numpy(), dtype=np.float32)
    peak = np.max(heat)
    if peak == 0:
        # Avoid dividing by zero on an all-zero map.
        heat = np.zeros(heat.shape, dtype=np.uint8)
    else:
        heat = np.uint8(heat / peak * 255.0)
    return np.where(heat > threshold, 255, 0).astype(np.uint8)


def _paint_masks(image_bgr, masks, size):
    """Return a copy of `image_bgr` with each mask painted in its own colour."""
    painted = np.copy(image_bgr)
    for j, mask in enumerate(masks):
        # `size` is (width, height), which is what cv2.resize expects.
        mask = cv2.resize(mask, size, interpolation=cv2.INTER_NEAREST)
        palette_idx = _SF_COLOR_ORDER[j % len(_SF_COLOR_ORDER)]
        rgb = 255 * np.array(colors.to_rgba(col.colors[palette_idx]))[:3]
        # The image is BGR, so reverse the colour; one vectorised assignment
        # replaces the per-pixel Python loop.
        painted[mask == 255] = rgb[::-1]
    return painted


def _render_figure(image_bgr, num):
    """Draw `image_bgr` on pyplot figure `num`, clearing any previous content."""
    # Figures are reused by number between calls; clear them so images from
    # earlier requests do not pile up on the same axes.
    fig = plt.figure(num, clear=True)
    ax = fig.gca()
    ax.imshow(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
    ax.axis('off')
    fig.tight_layout()
    return fig


def generate_matching_superfeatures(im1, im2, model_fn, scale_id=6, threshold=50, sf_ids='', only_matching=True):
    """Visualise super-feature attention maps on a pair of images.

    Args:
        im1: first image (PIL image; expected to have three colour channels).
        im2: second image (same requirements as `im1`).
        model_fn: "ImageNet" selects `net_imagenet`; any other value selects
            `net_sfm`.
        scale_id: index into the module-level `scales` list.
        threshold: attention values (rescaled to 0-255) above this are painted.
        sf_ids: '' for the default IDs, a comma separated list of IDs, or
            'rX' for X random IDs.
        only_matching: when true, only show super-features returned by
            `match` for the two images.

    Returns:
        Tuple ``(fig1, fig2, ids)``: one matplotlib figure per image and the
        displayed super-feature IDs as a comma separated string.
    """
    net = net_imagenet if model_fn == "ImageNet" else net_sfm
    # int() is a no-op for integers and tolerates a float slider value.
    scale = scales[int(scale_id)]

    # extract features
    feats1n, attns1 = _extract_superfeatures(net, im1, scale)
    feats2n, attns2 = _extract_superfeatures(net, im2, scale)
    ind_match = match(feats1n, feats2n)

    # which sf
    sf_idx_ = _select_sf_ids(sf_ids, ind_match, only_matching)

    # Binary attention mask of every selected super-feature, per image.
    all_att_bin1 = [_binarize_attention(attns1[0, i, :, :], threshold) for i in sf_idx_]
    all_att_bin2 = [_binarize_attention(attns2[0, i, :, :], threshold) for i in sf_idx_]

    # RGB -> BGR copies that the masks are painted onto.
    im1_cv = np.array(im1)[:, :, ::-1].copy()
    im2_cv = np.array(im2)[:, :, ::-1].copy()
    img1rsz = _paint_masks(im1_cv, all_att_bin1, im1.size)
    img2rsz = _paint_masks(im2_cv, all_att_bin2, im2.size)

    fig1 = _render_figure(img1rsz, 1)
    fig2 = _render_figure(img2rsz, 2)
    return fig1, fig2, ','.join(map(str, sf_idx_))
# GRADIO APP
title = "Visualizing Super-features"
description = "This is a visualization demo for the ICLR 2022 paper <b><a href='https://github.com/naver/fire' target='_blank'>Learning Super-Features for Image Retrieval</a></p></b>"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
css = ".input_image, .input_image {height: 600px !important; width: 600px !important;} "

# The inputs are passed positionally to generate_matching_superfeatures, so
# their order must follow its signature:
# (im1, im2, model_fn, scale_id, threshold, sf_ids, only_matching).
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(type="pil", label="First Image"),
        gr.inputs.Image(type="pil", label="Second Image"),
        gr.inputs.Radio(["ImageNet", "SfM-120k (landmarks)"], label="Model", optional=False),
        gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
        gr.inputs.Slider(minimum=0, maximum=255, step=25, default=150, label="Binarization Threshold"),
        gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs", optional=True),
        gr.inputs.Checkbox(default=True, label="Show only matching SFs", optional=False),
    ],
    outputs=[
        gr.outputs.Image(type="plot", label="First Image SFs"),
        gr.outputs.Image(type="plot", label="Second Image SFs"),
        gr.outputs.Textbox(label="SFs"),
    ],
    title=title,
    theme='peach',
    layout="horizontal",
    description=description,
    article=article,
    css=css,
    # One value per input component, in input order. The rows previously
    # omitted the "Model" value, which shifted every later value onto the
    # wrong component (e.g. 150 landed on the 0-6 "Scale" slider).
    examples=[
        ["chateau_1.png", "chateau_2.png", "SfM-120k (landmarks)", 3, 150, 'r8', True],
        ["anafi1.jpeg", "anafi2.jpeg", "SfM-120k (landmarks)", 4, 150, 'r8', True],
        ["areopoli1.jpeg", "areopoli2.jpeg", "SfM-120k (landmarks)", 4, 150, 'r8', True],
    ],
)
iface.launch(enable_queue=True)