File size: 3,609 Bytes
9a960ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
"""
Gradio demo of image classification with OOD detection.

If an input image is probably OOD, the app flags it in the returned OOD scores, signalling that the model's prediction should not be trusted.
"""
import os
import pickle
import json
from glob import glob

import gradio as gr
from gradio.components import Image, Label, JSON
import numpy as np
import torch
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from torchvision.models.feature_extraction import create_feature_extractor, get_graph_node_names

import logging

_logger = logging.getLogger(__name__)

# Run on the GPU when one is available; predict() must place its inputs on this device too.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Number of top classes reported in the prediction label.
TOPK = 3

# load model
print("Loading model...")
model = timm.create_model("resnet50", pretrained=True)
model.to(device)
model.eval()

# dataset labels: JSON object whose (string) keys are class indices, converted to int here.
# Use a context manager so the file handle is closed, and an explicit encoding for portability.
with open("ilsvrc2012.json", encoding="utf-8") as labels_file:
    idx2label = {int(k): v for k, v in json.load(labels_file).items()}
_logger.debug(idx2label)

# inference-time preprocessing derived from the model's pretrained data config
config = resolve_data_config({}, model=model)
config["is_training"] = False
transform = create_transform(**config)

# graph node names usable as feature-extraction keys (debug aid only)
_logger.debug(get_graph_node_names(model)[0])

# graph nodes to extract: pooled penultimate features and the final logits
penultimate_features_key = "global_pool.flatten"
logits_key = "fc"
features_names = [penultimate_features_key, logits_key]

# create feature extractor (returns a dict keyed by the node names above)
feature_extractor = create_feature_extractor(model, features_names)

# OOD detector thresholds: predict() flags a score strictly below its threshold as OOD.
# NOTE(review): presumably calibrated offline; the energy threshold looks small for a
# logsumexp of ImageNet logits — confirm it matches the scale energy() returns.
msp_threshold = 0.3796
energy_threshold = 0.3781

## OOD score functions


def mahalanobis_penult(features):
    """Return the negated L2 norm of the feature vector as an OOD score.

    Despite the name, no Mahalanobis distance is computed: the score is the
    negated Euclidean norm taken along dim 1, so larger-norm features give
    lower scores.

    Args:
        features: tensor whose dim 1 is the feature dimension; assumes shape
            (1, dim), since ``.item()`` needs exactly one resulting value.

    Returns:
        float: ``-||features||_2`` for the single sample.
    """
    # torch.norm takes ``keepdim`` (not NumPy's ``keepdims``); the latter raised TypeError.
    scores = torch.norm(features, dim=1, keepdim=True)
    s = torch.min(scores, dim=1)[0]
    return -s.item()


def msp(logits):
    """Return the maximum softmax probability (MSP) of a single-sample logits tensor.

    Softmax is taken over dim 1; predict() flags values below ``msp_threshold`` as OOD.
    """
    class_probs = logits.softmax(dim=1)
    top_prob, _ = class_probs.max(dim=-1)
    return top_prob.item()


def energy(logits):
    """Return the log-sum-exp of a single-sample logits tensor (taken over dim 1).

    predict() flags values below ``energy_threshold`` as OOD.
    """
    log_partition = logits.logsumexp(dim=1)
    return log_partition.item()


def predict(image):
    """Classify ``image`` and compute OOD scores for the Gradio interface.

    Args:
        image: PIL image from the ``Image(type="pil")`` input, or ``None`` when
            no image was provided.

    Returns:
        tuple: ``(result, ood_scores)``. ``result`` maps the TOPK label names to
        their softmax probabilities. ``ood_scores`` holds the MSP and energy
        scores plus ``*_is_ood`` flags (score below its threshold). Returns
        ``(None, None)`` when ``image`` is ``None``.
    """
    if image is None:
        # presumably Gradio passes None on an empty submit; transform(None) would crash
        return None, None

    # forward pass; inputs must be on the same device as the model
    inputs = transform(image).unsqueeze(0).to(device)
    with torch.no_grad():
        features = feature_extractor(inputs)
    logits = features[logits_key]

    # top-TOPK predictions for the single image in the batch
    probabilities = torch.softmax(logits, dim=-1)
    top_probs, class_idxs = torch.topk(probabilities, TOPK)
    _logger.info(top_probs)
    _logger.info(class_idxs)

    # take row 0 rather than squeeze(): with TOPK == 1 squeeze() gives a 0-d tensor zip can't iterate
    result = {idx2label[i.item()]: p.item() for i, p in zip(class_idxs[0], top_probs[0])}

    # OOD scores: lower MSP / energy than the thresholds marks the input as likely OOD
    msp_score = msp(logits)
    energy_score = energy(logits)
    ood_scores = {
        "msp": msp_score,
        "msp_is_ood": msp_score < msp_threshold,
        "energy": energy_score,
        "energy_is_ood": energy_score < energy_threshold,
    }
    _logger.info(ood_scores)
    return result, ood_scores


def main():
    """Build the Gradio OOD-detection demo, launch it on port 7860, and close it once launch returns."""
    # Sort before the seeded shuffle: glob() order is filesystem-dependent, so without
    # sorting the "seeded" example order would differ between machines.
    examples = sorted(glob("images/imagenet/*.jpg")) + sorted(glob("images/ood/*.jpg"))
    # Local generator: deterministic shuffle without reseeding NumPy's global RNG.
    np.random.default_rng(42).shuffle(examples)

    # gradio interface
    interface = gr.Interface(
        fn=predict,
        inputs=Image(type="pil"),
        outputs=[
            Label(num_top_classes=TOPK, label="Model prediction"),
            JSON(label="OOD scores"),
        ],
        # No example images found -> show no examples rather than a 0-per-page gallery.
        examples=examples or None,
        examples_per_page=max(len(examples), 1),
        allow_flagging="never",
        theme="default",
        title="OOD Detection 🧐",
        description="Out-of-distribution (OOD) detection is an essential safety measure for machine learning models. This app demonstrates how these methods can be useful in determining whether the inputs of a ResNet-50 model trained on ImageNet-1K can be trusted by the model. Enjoy the demo!",
    )
    interface.launch(
        server_port=7860,
    )
    interface.close()


if __name__ == "__main__":
    # Root logging at WARNING (same value as the WARN alias); predict()'s info logs stay hidden.
    logging.basicConfig(level=logging.WARNING)

    # Close any Gradio apps previously launched in this process, then start the demo.
    gr.close_all()
    main()