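"""Gradio demo: robust vs. standard image retrieval.

A query image is embedded with an intermediate layer of either a robust or a
standard (non-robust) backbone, and its nearest neighbors are retrieved from a
FAISS-indexed Hugging Face dataset of artworks (descriptions and images).
"""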
from typing import Callable, Dict, Iterable

import faiss
import gradio as gr
import torch
import torch.nn as nn
import torchvision
from datasets import load_dataset
from omegaconf import OmegaConf
from PIL import Image
from timm import create_model
from torch import Tensor
from torchvision import transforms


def get_model(args, arch, load_from, arch_path):
    """Build a backbone from timm or torchvision, optionally loading a checkpoint."""
    if load_from == 'timm':
        model = create_model(arch, pretrained=True).to(args.PARAMETERS.device)
        print("Loaded model from timm")
    elif load_from == 'torchvision':
        if arch == 'resnet50':
            model = torchvision.models.resnet50(pretrained=False)
        else:
            raise ValueError(f"Unsupported torchvision arch: {arch}")
    else:
        raise ValueError(f"Unknown model source: {load_from}")
    if len(arch_path) > 0:
        print("Loading pretrained weights")
        # The checkpoint is expected to store the weights under a 'state_dict' key.
        model.load_state_dict(torch.load(arch_path, map_location='cpu')['state_dict'], strict=True)
    model.eval()
    return model


def get_transform(args):
    """Resize, center-crop, and tensorize an input image (no normalization)."""
    return transforms.Compose([
        transforms.Resize([args.PARAMETERS.img_resize, args.PARAMETERS.img_resize]),
        transforms.CenterCrop([args.PARAMETERS.img_crop, args.PARAMETERS.img_crop]),
        transforms.ToTensor(),
    ])


class FeatureExtractor(nn.Module):
    """Wraps a model and records the outputs of the given layers via forward hooks."""

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        self.layers = layers
        self._features: Dict[str, Tensor] = {layer: torch.empty(0) for layer in layers}

        for layer_id in layers:
            layer = dict(self.model.named_modules())[layer_id]
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        def fn(_, __, output):
            # Store this layer's output on every forward pass.
            self._features[layer_id] = output
        return fn

    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        _ = self.model(x)  # side effect: hooks populate self._features
        return self._features
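# Example usage (hypothetical layer name; inspect model.named_modules() to find
# the right one for your architecture — "avgpool" is the pooling layer of a
# torchvision ResNet-50):
#
#   extractor = FeatureExtractor(model, layers=["avgpool"])
#   feats = extractor(torch.randn(1, 3, 224, 224))["avgpool"]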


def _load_dataset(args):
    """Load the dataset and attach one FAISS index per precomputed embedding column."""
    if args.PARAMETERS.metric == 'L2':
        faiss_metric = faiss.METRIC_L2
    else:
        raise ValueError(f"Unsupported metric: {args.PARAMETERS.metric}")
    dataset = load_dataset(args.PARAMETERS.dataset, split='train')
    dataset = dataset.add_faiss_index(column=args.ROBUST.embedding_col, metric_type=faiss_metric)
    dataset = dataset.add_faiss_index(column=args.NONROBUST.embedding_col, metric_type=faiss_metric)
    return dataset
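# The config file is expected to provide the keys read above and below; the
# values here are illustrative assumptions, not the project's actual settings:
#
#   PARAMETERS:
#     device: cpu
#     img_resize: 256
#     img_crop: 224
#     metric: L2
#     dataset: <hub dataset id>
#   ROBUST:
#     arch: resnet50
#     load_from: torchvision
#     arch_path: <checkpoint path, or "">
#     layer: avgpool
#     embedding_col: <robust embedding column name>
#   NONROBUST:
#     # same fields for the standard model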


args = OmegaConf.load("configs/resnet.yaml")
wiki_dataset = _load_dataset(args)
TRANSFORMS = get_transform(args)
robust_model = get_model(args, args.ROBUST.arch, args.ROBUST.load_from, args.ROBUST.arch_path)
non_robust_model = get_model(args, args.NONROBUST.arch, args.NONROBUST.load_from, args.NONROBUST.arch_path)
fe_robust_model = FeatureExtractor(robust_model, layers=[args.ROBUST.layer])
fe_nonrobust_model = FeatureExtractor(non_robust_model, layers=[args.NONROBUST.layer])
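# The dataset must already contain one embedding column per model
# (args.ROBUST.embedding_col / args.NONROBUST.embedding_col). A minimal,
# hypothetical sketch of how such a column could be precomputed offline:
#
#   def embed(batch):
#       imgs = torch.stack([TRANSFORMS(img) for img in batch["image"]])
#       feats = fe_robust_model(imgs.to(args.PARAMETERS.device))[args.ROBUST.layer]
#       return {args.ROBUST.embedding_col: feats.view(len(imgs), -1).cpu().numpy()}
#
#   dataset = dataset.map(embed, batched=True)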


def retrieval_fn(image, radio):
    """Embed the query image with the selected model and return its 3 nearest neighbors."""
    # Gradio passes the uploaded image as a numpy array.
    image = Image.fromarray(image)
    image = TRANSFORMS(image).unsqueeze(0)
    image = image.to(args.PARAMETERS.device)

    if radio == 'robust':
        emb = fe_robust_model(image)[args.ROBUST.layer]
        emb = emb.view(1, -1).detach().cpu().numpy()
        scores, retrieved_examples = wiki_dataset.get_nearest_examples(
            index_name=args.ROBUST.embedding_col, query=emb, k=3)
    elif radio == 'standard':
        emb = fe_nonrobust_model(image)[args.NONROBUST.layer]
        emb = emb.view(1, -1).detach().cpu().numpy()
        scores, retrieved_examples = wiki_dataset.get_nearest_examples(
            index_name=args.NONROBUST.embedding_col, query=emb, k=3)
    else:
        raise ValueError(f"Unknown model choice: {radio}")
    return scores, retrieved_examples

def gradio_fn(image, radio):
    scores, retrieved_examples = retrieval_fn(image, radio)
    # Flatten to [desc1, img1, desc2, img2, desc3, img3] to match the order of
    # the output components registered with calculate_button.click below.
    outputs = []
    for description, match_image in zip(retrieved_examples['description'],
                                        retrieved_examples['image']):
        outputs.append(description)
        outputs.append(match_image)
    return outputs
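# Running this file starts the Gradio app (assuming it is saved as app.py):
#   python app.py
# By default Gradio serves the demo locally at http://127.0.0.1:7860.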

if __name__ == '__main__':
    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Robust vs Standard Image Retrieval")
        with gr.Tabs():
            with gr.TabItem("Upload your Image"):
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            image_input = gr.Image(label="Input Image")
                        with gr.Row():
                            radio_button = gr.Radio(["robust", "standard"],
                                                    value="robust",
                                                    label="OD Model")
                        with gr.Row():
                            calculate_button = gr.Button("Compute")
                    with gr.Column():
                        textbox1 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image1 = gr.Image(label="1st Best match")
                        textbox2 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image2 = gr.Image(label="2nd Best match")
                        textbox3 = gr.Textbox(label="Artist / Title / Style / Genre / Date")
                        output_image3 = gr.Image(label="3rd Best match")

                calculate_button.click(fn=gradio_fn,
                                       inputs=[image_input, radio_button],
                                       outputs=[textbox1, output_image1,
                                                textbox2, output_image2,
                                                textbox3, output_image3])
    demo.launch(share=False, debug=True)