File size: 2,196 Bytes
bb3ea39
 
f4b82b2
497a5c7
bb3ea39
f4b82b2
9651aac
f4b82b2
c9dadbf
f4b82b2
351ead9
 
f4b82b2
 
 
 
 
 
c9dadbf
f4b82b2
 
 
 
 
9651aac
c9dadbf
 
 
 
 
 
4e8ced7
9651aac
c9dadbf
 
f4b82b2
497a5c7
 
f4b82b2
c9dadbf
 
 
 
 
 
 
 
 
 
 
 
497a5c7
 
 
 
 
f4b82b2
1a2db09
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
 
1a2db09
 
 
 
 
f4b82b2
1a2db09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import gradio as gr

import torch
from torchvision import transforms

import fire_network

# Candidate scale factors for multi-scale inference.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]

device = 'cpu'

# Load the FIRE network from its checkpoint.
state = torch.load('fire.pth', map_location=device)
# The checkpoint already contains trained weights, so no ImageNet-pretrained model is needed.
state['net_params']['pretrained'] = None
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# This script only runs inference: switch layers such as BatchNorm/Dropout to
# evaluation behaviour (modules are constructed in training mode by default).
net.eval()

# Preprocessing: resize so the smaller edge is 1024 px, convert to a tensor, and
# normalize with the mean/std stored in the network's runtime configuration.
transform = transforms.Compose([
    transforms.Resize(1024),
    transforms.ToTensor(),
    transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std']))),
])


# Indices of the super-features to visualize (hand-picked; not referenced yet by
# the rest of this file).
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]


# NOTE(review): `plt` (matplotlib.pyplot) is never imported in this file, so this
# line raises NameError at startup; add `import matplotlib.pyplot as plt` to the
# imports. The 'tab10' colormap is presumably meant to color the visualized
# super-features — not used yet.
col = plt.get_cmap('tab10')

def _extract_superfeatures(image):
    """Run ``net.get_superfeatures`` on one image; return (features, attentions, strengths)."""
    # transforms.ToTensor yields a single (C, H, W) image; add a leading batch
    # dimension. Assumes get_superfeatures expects batched (N, C, H, W) input,
    # as is usual for torch networks — TODO confirm against fire_network.
    batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = net.get_superfeatures(batch, scales=scales)
    return output[0], output[1], output[2]


def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
    """Extract FIRE super-features from two images (work in progress).

    Each image is preprocessed with the module-level ``transform`` and passed
    through ``net.get_superfeatures`` at every factor in ``scales``. The code
    currently only prints diagnostics; it does not yet build or return a plot.

    Args:
        im1: First input image (a PIL image, as delivered by the Gradio input).
        im2: Second input image (PIL).
        scale_id: Index of the selected scale. Not used yet — features are
            currently extracted at all ``scales``.
        threshold: Binarization threshold selected in the UI. Not used yet.

    Returns:
        None (plotting is not implemented yet).
    """
    # Attentions and strengths are kept for the plotting step that is still to be written.
    feats1, attns1, strengths1 = _extract_superfeatures(im1)
    feats2, attns2, strengths2 = _extract_superfeatures(im2)

    # Debug output: number of per-scale feature entries and the first one's shape.
    print(len(feats1))
    print(feats1[0].shape)



# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"


iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        # 0-based index into `scales` (valid ids 0..len(scales)-1), consistent with
        # generate_matching_superfeatures' default scale_id=6 and the example below.
        gr.inputs.Slider(minimum=0, maximum=len(scales) - 1, step=1, default=2, label="Scale"),
        # Step grid 0, 5, ..., 255 contains both the default/example value (50)
        # and the upper bound (255).
        gr.inputs.Slider(minimum=0, maximum=255, step=5, default=50, label="Binarization Threshold"),
    ],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()