File size: 1,928 Bytes
bb3ea39
 
f4b82b2
bb3ea39
f4b82b2
bb3ea39
f4b82b2
9651aac
 
f4b82b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9651aac
1a2db09
 
9651aac
f4b82b2
 
 
 
 
 
 
 
 
 
 
 
 
1a2db09
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
 
1a2db09
 
 
 
 
f4b82b2
1a2db09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import gradio as gr

import torch

from how.networks import how_net

import fire_network


# Possible scales for multiscale inference.
# These are module-level constants, so they must sit at column 0; the
# previous stray indentation raised "IndentationError: unexpected indent".
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
# Inference options bundle: the scale list plus a feature-count setting.
infer_opts = {"scales": scales, "features_num": 1000}


# Load net
# `device` was referenced below (and in the inference wrapper) without ever
# being defined; pick the GPU when one is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# NOTE: torch.load unpickles the checkpoint, so 'fire.pth' must come from a
# trusted source. The checkpoint is always materialised on the CPU first and
# moved to `device` afterwards.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None # no need for imagenet pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# The network is only used for inference in this script, so switch layers
# such as batch-norm/dropout to their evaluation behaviour.
net.eval()

# Image preprocessing pipeline: resize to 1024, convert to a tensor, then
# normalise with the mean/std pair stored in `net.runtime['mean_std']`.
# NOTE(review): `transforms` is not imported anywhere in the visible file;
# presumably it is `torchvision.transforms` -- confirm and add the import.
transforms_ = transforms.Compose(
    [
        transforms.Resize(1024),
        transforms.ToTensor(),
        transforms.Normalize(
            **{key: value for key, value in zip(("mean", "std"), net.runtime['mean_std'])}
        ),
    ]
)


# Wrapper
def _extract_superfeatures(im):
    """Run the network on one image and return (features, attentions, strengths).

    Args:
        im: Either a PIL image (what the Gradio inputs deliver) or a tensor.
            PIL images are preprocessed with the module-level `transforms_`
            pipeline; tensors are forwarded unchanged so callers that already
            pass prepared tensors keep working.

    Returns:
        The first three elements of `net.get_superfeatures(...)`'s result.
    """
    if isinstance(im, torch.Tensor):
        batch = im
    else:
        # A PIL image has no `.to(device)`; convert it to a normalised tensor.
        # Assumes the network expects a leading batch dimension -- TODO confirm.
        batch = transforms_(im).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        output = net.get_superfeatures(batch.to(device), scales=scales)
    return output[0], output[1], output[2]


def generate_matching_superfeatures(im1, im2, scale=6, threshold=50):
    """Extract super-features for two images (callback for the Gradio app).

    Args:
        im1: First image (PIL image or tensor).
        im2: Second image (PIL image or tensor).
        scale: Value of the "Scale" slider. Currently unused.
        threshold: Value of the binarization-threshold slider. Added with a
            default so the four-input interface can call this function while
            existing three-argument callers keep working. Currently unused.

    Returns:
        None. The matching/plotting step is not implemented yet.
    """
    # extract features
    feats1, attns1, strengths1 = _extract_superfeatures(im1)
    feats2, attns2, strengths2 = _extract_superfeatures(im2)

    # TODO: match the super-features of both images and build the figure for
    # the "plot" output; until then nothing is returned.
    return None




# GRADIO APP
# Static text shown on the demo page.
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"


# Interface wiring: two PIL image inputs plus two sliders are passed, in this
# order, as positional arguments to `generate_matching_superfeatures`; the
# callback's return value is rendered as a plot.
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
        # Label typo fixed ("Binarizatio" -> "Binarization").
        gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarization Threshold")],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    # One example row: a value for each of the four inputs above.
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()