File size: 2,114 Bytes
bb3ea39
 
f4b82b2
bb3ea39
f4b82b2
bb3ea39
f4b82b2
9651aac
c9dadbf
9651aac
f4b82b2
c9dadbf
f4b82b2
 
 
 
 
 
 
c9dadbf
f4b82b2
 
 
 
 
9651aac
c9dadbf
 
 
 
 
 
1a2db09
9651aac
c9dadbf
 
f4b82b2
c9dadbf
 
f4b82b2
c9dadbf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
1a2db09
 
 
 
 
 
 
 
 
 
 
 
 
f4b82b2
 
1a2db09
 
 
 
 
f4b82b2
1a2db09
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import gradio as gr

import torch

from how.networks import how_net

import fire_network

import cv2

# Candidate image scales for multiscale inference.
scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]

# Device the network (and its inputs) run on: the GPU when one is available,
# otherwise the CPU. Previously `device` was used below without being defined.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the network from the local checkpoint. Tensors are first mapped to the
# CPU and then moved to `device` together with the module.
# NOTE: torch.load unpickles the file — only load trusted checkpoints.
state = torch.load('fire.pth', map_location='cpu')
state['net_params']['pretrained'] = None  # no need for the ImageNet-pretrained model
net = fire_network.init_network(**state['net_params']).to(device)
net.load_state_dict(state['state_dict'])
# This app only runs inference, so put every submodule in evaluation mode
# (affects layers such as batch-norm / dropout, if the network has any).
net.eval()

# Preprocessing applied to each input image before feature extraction; the
# normalisation statistics come from the network's runtime configuration.
# NOTE(review): `transforms` is not imported in the visible file — presumably
# `from torchvision import transforms`; add that import.
transform = transforms.Compose([
        transforms.Resize(1024),
        transforms.ToTensor(),
        transforms.Normalize(**dict(zip(["mean", "std"], net.runtime['mean_std'])))
        ])


# Indices of the super-features to visualize (how they were chosen is not
# recorded here).
sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]


# Colour map used to tell the selected super-features apart.
# NOTE(review): `plt` is not imported in the visible file — presumably
# `import matplotlib.pyplot as plt`; add that import.
col = plt.get_cmap('tab10')

def generate_matching_superfeatures(im1, im2, scale=6, threshold=50):
    """Extract FIRE super-features for a pair of images (work in progress).

    This is the callback of the Gradio interface, which passes its four inputs
    positionally: two images, the "Scale" slider, and the binarization
    threshold slider.

    Args:
        im1: First input image (a PIL image, per the ``type="pil"`` inputs).
        im2: Second input image (a PIL image).
        scale: Value of the "Scale" slider. Currently unused.
        threshold: Value of the binarization-threshold slider. Currently
            unused; it is accepted so the four-input interface can call this
            function without raising ``TypeError``.

    Returns:
        None for now. TODO: build and return a matplotlib figure for the
        interface's ``"plot"`` output.
    """
    # Send inputs to whatever device the network's weights live on, so this
    # function does not rely on a separately defined global `device`.
    device = next(net.parameters()).device

    # Preprocess the PIL images and add a leading batch dimension; presumably
    # the network expects a batched (1, C, H, W) input — confirm against
    # fire_network. (The previous code sent the raw PIL images to the network
    # and called cv2.imread on them, which expects a file path and whose
    # results were never used.)
    im1_tensor = transform(im1).unsqueeze(0).to(device)
    im2_tensor = transform(im2).unsqueeze(0).to(device)

    # Extract super-features, attention maps and strengths for both images.
    with torch.no_grad():
        output1 = net.get_superfeatures(im1_tensor, scales=scales)
        feats1, attns1, strengths1 = output1[0], output1[1], output1[2]

        output2 = net.get_superfeatures(im2_tensor, scales=scales)
        feats2, attns2, strengths2 = output2[0], output2[1], output2[2]

    # Debug output while the visualization is being developed.
    print(feats1.shape)
    print(attns1.shape)
    print(strengths1.shape)


# GRADIO APP
title = "Visualizing Super-features"
description = "TBD"  # TODO: replace the placeholder description
article = "<p style='text-align: center'><a href='https://github.com/naver/fire' target='_blank'>Original Github Repo</a></p>"


# The input components' values are passed positionally to `fn`, in order:
# image 1, image 2, scale slider, binarization-threshold slider.
iface = gr.Interface(
    fn=generate_matching_superfeatures,
    inputs=[
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Image(shape=(240, 240), type="pil"),
        gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
        gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarization Threshold")],
    outputs="plot",
    enable_queue=True,
    title=title,
    description=description,
    article=article,
    # Each example row supplies one value per input component above.
    examples=[["chateau_1.png", "chateau_2.png", 6, 50]],
)
iface.launch()