# SuperFeatures / app.py
# Hugging Face Space by YannisK, commit 880da41 ("temp"), 2.23 kB.
import gradio as gr
import torch
import matplotlib.pyplot as plt
from torchvision import transforms
import fire_network
# Possible scales for multi-scale inference.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
device = 'cpu'

# Load the FIRE network from its checkpoint onto `device`.
state = torch.load('fire.pth', map_location=device)
# No need to fetch ImageNet-pretrained backbone weights: the checkpoint's
# own state_dict is loaded right below and overrides them.
state['net_params']['pretrained'] = None
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# This app only runs inference; without eval() layers such as BatchNorm and
# Dropout would keep their training behaviour (batch statistics, random drops).
net.eval()

# Preprocessing applied to every input image: resize (an int size resizes the
# shorter edge to 1024), convert to a tensor, then normalize with the
# statistics stored on the network. `mean_std` is presumably a (mean, std)
# pair -- the zip below maps it onto Normalize's keyword arguments.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])

# Indices of the super-features presumably selected for visualization
# (not referenced by the code visible here -- TODO confirm use).
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
# Categorical colormap, presumably for colouring the selected super-features.
col = plt.get_cmap('tab10')
def _extract_superfeatures(image):
    """Run the FIRE network on one PIL image and return its super-features.

    The image is preprocessed with the module-level ``transform``, given a
    leading batch axis, moved to ``device`` and passed to
    ``net.get_superfeatures`` with all module-level ``scales``.

    Returns:
        A ``(feats, attns, strengths)`` triple: the first three elements of
        the value returned by ``net.get_superfeatures``.
    """
    # ToTensor produces an unbatched tensor; add a batch axis because the
    # network presumably expects (N, C, H, W) input -- TODO confirm in fire_network.
    image_tensor = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():  # inference only, no autograd graph needed
        output = net.get_superfeatures(image_tensor, scales=scales)
    return output[0], output[1], output[2]


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
    """Extract super-features for two images (work in progress).

    Args:
        im1: First input image (a PIL image, per the Gradio ``type="pil"`` input).
        im2: Second input image (a PIL image).
        scale_id: Scale selector from the UI. NOTE(review): currently unused;
            all ``scales`` are passed to the network.
        threshold: Binarization threshold from the UI. NOTE(review): currently
            unused.

    Returns:
        Nothing yet (None). TODO: build and return a matplotlib figure for the
        Gradio ``"plot"`` output once the matching visualization is written.
    """
    # Bug fix: the original passed the raw PIL images (which have no `.to`)
    # instead of the transformed tensors to the network.
    feats1, attns1, strengths1 = _extract_superfeatures(im1)
    feats2, attns2, strengths2 = _extract_superfeatures(im2)
    # NOTE(review): attns*/strengths*/feats2 are not used yet; they are kept
    # for the matching visualization still to be implemented.

    # Debug output: number of feature tensors and the shape of the first one.
    print(len(feats1))
    print(feats1[0].shape)
# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"

iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        # NOTE(review): range 1..7 with default 2, while the function's
        # scale_id default and the example below use 6 -- confirm whether
        # this is a 1-based index into the 7-entry `scales` list.
        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
        # NOTE(review): with minimum=1 and step=25 the reachable values are
        # 1, 26, 51, ...; the default/example value 50 is off that grid.
        gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarization Threshold"),
    ],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()