# Captured from a Hugging Face Space whose status was: Build error
import gradio as gr | |
import cv2 | |
import torch | |
import matplotlib.pyplot as plt | |
from matplotlib import cm | |
from matplotlib import colors | |
from mpl_toolkits.axes_grid1 import ImageGrid | |
from torchvision import transforms | |
import fire_network | |
import numpy as np | |
from PIL import Image | |
# Candidate scale factors for multiscale inference.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# Load the FIRE network from the local checkpoint onto `device`.
state = torch.load('fire.pth', map_location=device)
# The checkpoint already contains the trained weights, so there is no need
# for the ImageNet-pretrained backbone initialization.
state['net_params']['pretrained'] = None
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# nn.Module instances start in training mode; torch.no_grad() does not change
# that. Switch to inference mode so mode-dependent layers (e.g. BatchNorm or
# Dropout, if the network contains any) use their evaluation behavior.
net.eval()

# Preprocessing: resize (shorter side to 1024), convert to a tensor, and
# normalize with the (mean, std) pair stored in the network's runtime config.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])

# Indices of the Super-features whose attention maps are visualized.
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
# Qualitative colormap supplying one color per visualized Super-feature.
col = plt.get_cmap('tab10')
def _sf_color_rgb(j):
    """Return the uint8 RGB color for the j-th visualized Super-feature.

    Colors are taken from the ``col`` (tab10) palette: the first Super-feature
    uses entry 9, Super-features 1-6 use entries 1-6, and later ones skip
    entry 7 (the j-th uses entry j + 1).
    """
    if j == 0:
        palette_idx = 9
    elif j < 7:
        palette_idx = j
    else:
        palette_idx = j + 1
    rgb = 255 * np.array(colors.to_rgba(col.colors[palette_idx]))[:3]
    return rgb.astype(np.uint8)


def _binary_attention_masks(attns, sf_ids, threshold):
    """Binarize the attention maps of the selected Super-features.

    For each index ``i`` in ``sf_ids``, the map ``attns[0, i]`` is rescaled so
    that its maximum becomes 255, truncated to uint8, and thresholded: pixels
    strictly above ``threshold`` become 255, all others 0. A map whose maximum
    is not positive yields an all-zero mask instead of dividing by zero.

    Returns:
        A list of 2-D uint8 masks, in the order of ``sf_ids``.
    """
    masks = []
    for i in sf_ids:
        heat = attns[0, i, :, :].detach().cpu().numpy().astype(np.float32)
        peak = heat.max()
        if peak > 0:
            heat = np.uint8(heat / peak * 255.0)
            mask = np.where(heat > threshold, 255, 0).astype(np.uint8)
        else:
            mask = np.zeros(heat.shape, dtype=np.uint8)
        masks.append(mask)
    return masks


def _overlay_masks(image_rgb, masks):
    """Paint each mask onto a copy of ``image_rgb`` in its Super-feature color.

    Each mask is resized (nearest neighbour) to the image size. Masks are
    painted in order, so later masks overwrite earlier ones where they overlap.
    """
    out = np.array(image_rgb, copy=True)
    height, width = out.shape[:2]
    for j, mask in enumerate(masks):
        # cv2.resize takes dsize as (width, height).
        mask = cv2.resize(mask, (width, height), interpolation=cv2.INTER_NEAREST)
        out[mask == 255] = _sf_color_rgb(j)
    return out


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
    """Visualize where the same Super-features attend in two images.

    Args:
        im1, im2: PIL images (the Gradio inputs use ``type="pil"``).
        scale_id: 1-based index into ``scales`` selecting the single scale
            factor at which Super-features are extracted (with the current
            ``scales`` list, 1 -> 2.0 and 7 -> 0.25).
        threshold: each attention map is rescaled to 0-255 and pixels strictly
            above this value are highlighted.

    Returns:
        A matplotlib figure showing both images stacked vertically, each
        overlaid with the binarized attention maps of the Super-features
        listed in ``sf_idx_``.

    Raises:
        ValueError: if ``scale_id`` is outside ``1..len(scales)``.
    """
    # `scale_id` is an index (the UI slider spans 1..7); passing it directly
    # as a scale factor would e.g. upscale the input 6x for the default value.
    scale_idx = int(scale_id) - 1
    if not 0 <= scale_idx < len(scales):
        raise ValueError(
            f"scale_id must be between 1 and {len(scales)}, got {scale_id}")
    scale = scales[scale_idx]

    # Work in RGB: the normalization expects 3 channels and the overlay
    # colors are written as RGB.
    im1 = im1.convert('RGB')
    im2 = im2.convert('RGB')

    with torch.no_grad():
        # get_superfeatures returns (features, attentions, strengths); the
        # trailing [0] selects the entry for the single requested scale.
        attns1 = net.get_superfeatures(
            transform(im1).unsqueeze(0).to(device), scales=[scale])[1][0]
        attns2 = net.get_superfeatures(
            transform(im2).unsqueeze(0).to(device), scales=[scale])[1][0]

    fin_img = [
        _overlay_masks(np.array(im), _binary_attention_masks(attns, sf_idx_, threshold))
        for im, attns in ((im1, attns1), (im2, attns2))
    ]

    fig = plt.figure(figsize=(12, 25))
    grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.1)
    for ax, img in zip(grid, fin_img):
        # Images are already RGB, so they are shown without channel conversion.
        ax.imshow(img)
        ax.axis('scaled')
        ax.axis('off')
    plt.tight_layout()
    fig.suptitle("Matching SFs", fontsize=16)
    return fig
# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(1024, 1024), type="pil"),
        gr.inputs.Image(shape=(1024, 1024), type="pil"),
        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
        # step=1 so that the default and example threshold (50) are values
        # the slider can actually take (a step of 25 from 1 gives 1, 26, 51, ...).
        gr.inputs.Slider(minimum=1, maximum=255, step=1, default=50, label="Binarization Threshold"),
    ],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()