File size: 4,007 Bytes
da716ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120

# Imports needed before the timer starts; they are deliberately excluded
# from the "libraries loaded" measurement below.
import os, cv2, time, math
print("=> Loading libraries...")
# `start` is reused as a stopwatch by each start-up phase of this script.
start = time.time()

# Third-party imports placed after the timer on purpose so that their import
# cost is what the elapsed-time message below reports.
import requests, torch
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


# Report how long the imports above took.
print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()

# Model variant suffix and preprocessing hyper-parameters. `img_size` and
# `crop_pct` are the single source of truth for every transform built below.
size = "b"
img_size = 224
crop_pct = 0.9
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

# Build the network on the GPU, restore its weights and switch to eval mode.
model = create_model(f"tpmlp_{size}").cuda()
# NOTE(review): the third positional argument is presumably timm's `use_ema`
# flag; confirm against the installed timm version.
load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True)
model.eval()

# Class names, one per line. A timeout keeps start-up from hanging forever on
# a stalled connection, and raise_for_status() prevents an HTTP error page
# from silently being used as the label list.
response = requests.get("https://git.io/JJkYN", timeout=30)
response.raise_for_status()
labels = response.text.split("\n")

# timm evaluation transform, built from the same size/crop settings as the
# manual pipeline below instead of repeating the literals.
augs = create_transform(
    input_size=(3, img_size, img_size),
    is_training=False,
    use_prefetcher=False,
    crop_pct=crop_pct,
)


# Manual pipeline split in two stages so that the un-normalized tensor can be
# kept for visualisation while the normalized one is fed to the model.
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor()
])
normalize = transforms.Normalize(mean=torch.tensor(IMAGENET_DEFAULT_MEAN), std=torch.tensor(IMAGENET_DEFAULT_STD))

def transform(img):
    """Prepare an image for both display and inference.

    The image is converted to RGB and passed through the module-level
    ``resize`` pipeline; the result is then passed through ``normalize``.

    Returns:
        A pair ``(resized, normalized)``: the output of ``resize`` and the
        output of ``normalize`` applied to it.
    """
    rgb_image = img.convert("RGB")
    resized = resize(rgb_image)
    normalized = normalize(resized)
    return resized, normalized

def predict(inp):
    """Classify an image and render a Grad-CAM overlay for it.

    Args:
        inp: The input image (must support ``.convert("RGB")``), or ``None``
            when no image was supplied.

    Returns:
        A pair ``(confidences, cam_image)``: a dict mapping each label to the
        model's score for it as a float, and the overlay image produced by
        ``show_cam_on_image``.

    Raises:
        gr.Error: If ``inp`` is ``None``.
    """
    if inp is None:
        # Without this guard an empty input fails with an opaque
        # AttributeError on `None.convert`; gr.Error gives a readable message.
        raise gr.Error("Please provide an input image before predicting.")

    img, inp = transform(inp)
    inp = inp.unsqueeze(0)  # add the batch dimension
    with GradCAM(model=model, target_layers=[model.layers[3]], use_cuda=True) as cam:
        grayscale_cam, probs = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False, return_probs=True)

        # Here grayscale_cam has only one image in the batch
        grayscale_cam = grayscale_cam[0, :]
        probs = probs[0, :]

        # show_cam_on_image is given the un-normalized image in
        # height-width-channel order, hence the permute.
        cam_image = show_cam_on_image(img.permute(1, 2, 0).detach().cpu().numpy(), grayscale_cam, use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED)
        # Pair every score the model produced with its label rather than
        # assuming exactly 1000 classes.
        confidences = {labels[i]: float(p) for i, p in enumerate(probs)}
    return confidences, cam_image

print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")

# Create the example directory if needed; exist_ok avoids the check-then-create
# race and missing parents are created too.
os.makedirs("../example-imgs", exist_ok=True)

print("=> Loading examples.")
# Positions in the validation stream of the images offered as UI examples.
indices = [
    0,      # Coucal
    2,      # Volcano
    7,      # Sombrero
    9,      # Balance beam
    10,     # Sulphur-crested cockatoo
    11,     # Shower cap
    12,     # Petri dish INCORRECTLY CLASSIFIED as lens
    14,     # Angora rabbit
]
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
examples = []; idx = 0
wanted = set(indices)        # O(1) membership test inside the loop
last_wanted = max(indices)   # hoisted: no need to recompute per sample
start = time.time()
for idx, data in enumerate(ds):
    # Membership test (the previous `idx == indices` compared an int with a
    # list and was never true, so no example image was ever written).
    if idx in wanted:
        data['image'].save(f"../example-imgs/{idx}.png")
    # Stop only after the last wanted sample has been handled, so that it is
    # saved too and nothing beyond it is streamed.
    if idx >= last_wanted:
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")

# demo = gr.Interface(
#     fn=predict, 
#     inputs=gr.inputs.Image(type="pil"),
#     outputs=[gr.outputs.Label(num_top_classes=4), gr.outputs.Image(type="numpy")],
#     examples=[f"../example-imgs/{idx}.png" for idx in indices],
# )


# Assemble the demo page: a title banner, a row with the input image next to
# the two outputs, a row of action buttons and a row of clickable examples.
demo_theme = gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])
example_paths = [f"../example-imgs/{idx}.png" for idx in indices]

with gr.Blocks(theme=demo_theme) as demo:
    gr.HTML("""
    <h1 align="center">Interactive Demo</h1>
    <h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
    <br><br>
    """)

    # Input on the left, class scores in the middle, Grad-CAM on the right.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil", min_width=300)
        softmax = gr.Label(label="Model Predictions", num_top_classes=4, min_width=200)
        grad_cam = gr.Image(label="Grad-CAM", type="numpy", min_width=300)

    # "Predict" runs the model on the current image; the clear button resets it.
    with gr.Row():
        predict_button = gr.Button("Predict")
        predict_button.click(
            fn=predict,
            inputs=input_image,
            outputs=[softmax, grad_cam],
        )
        gr.ClearButton(input_image)

    # Clicking an example loads it and runs the prediction immediately.
    with gr.Row():
        gr.Examples(
            examples=example_paths,
            inputs=input_image,
            outputs=[softmax, grad_cam],
            fn=predict,
            run_on_click=True,
        )

# The example directory is explicitly allowed so its files can be served.
demo.launch(share=True, allowed_paths=["../example-imgs"])