Spaces:
Build error
Build error
File size: 1,928 Bytes
bb3ea39 f4b82b2 bb3ea39 f4b82b2 bb3ea39 f4b82b2 9651aac f4b82b2 9651aac 1a2db09 9651aac f4b82b2 1a2db09 f4b82b2 1a2db09 f4b82b2 1a2db09 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
import gradio as gr
import torch
from how.networks import how_net
import fire_network
# Possible Scales for multiscale inference
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
infer_opts = {"scales": scales, "features_num": 1000}
# Load net
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
transforms_ = transforms.Compose([
transforms.Resize(1024),
transforms.ToTensor(),
transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))
])
# Wrapper
def generate_matching_superfeatures(im1, im2, scale=6):
# extract features
with torch.no_grad():
output1 = net.get_superfeatures(im1.to(device), scales=scales)
feats1 = output1[0]
attns1 = output1[1]
strenghts1 = output1[2]
output2 = net.get_superfeatures(im2.to(device), scales=scales)
feats2 = output2[0]
attns2 = output2[1]
strenghts2 = output2[2]
# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"
iface = gr.Interface(
fn=generate_matching_superfeatures,
inputs=[
gr.inputs.Image(shape=(240, 240), type="pil"),
gr.inputs.Image(shape=(240, 240), type="pil"),
gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarizatio Threshold")],
outputs="plot",
enable_queue=True,
title=title,
description=description,
article=article,
examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()
|