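# Gradio demo for CS-Mixer (tpmlp): top-4 ImageNet predictions plus a Grad-CAM
# overlay for an uploaded or example image.
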
import os
os.system("pip install datasets einops tabulate opencv-python ttach -U")

import cv2, time, math
print("=> Loading libraries...")
start = time.time()

import requests, torch
import gradio as gr
from torchvision import transforms
from datasets import load_dataset
from timm.data import create_transform
from timm.models import create_model, load_checkpoint
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image


device = "cuda" if torch.cuda.is_available() else "cpu"

print(f"=> Libraries loaded in {time.time()- start:.2f} sec(s).")
print("=> Loading model...")
start = time.time()

size = "b"
img_size = 224
crop_pct = 0.9
IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)

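# Build the backbone through timm's registry ("tpmlp_b" must have been
# registered by the model package) and load the local checkpoint; the third
# positional argument of timm's load_checkpoint opts into EMA weights when the
# checkpoint contains them.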
model = create_model(f"tpmlp_{size}").to(device)
load_checkpoint(model, f"../tpmlp_{size}.pth.tar", True)
model.eval()

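# https://git.io/JJkYN redirects to the plain-text list of the 1,000 ImageNet
# class labels used in the official Gradio image-classification demo.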
response = requests.get("https://git.io/JJkYN")
labels = response.text.split("\n")

# timm's reference eval transform for this configuration; kept for reference
# but unused below, since Grad-CAM needs the un-normalized image and the
# equivalent pipeline is therefore rebuilt explicitly.
augs = create_transform(
    input_size=(3, img_size, img_size),
    is_training=False,
    use_prefetcher=False,
    crop_pct=crop_pct,
)


# timm eval protocol: resize the short side to img_size / crop_pct (248 for
# 224 @ 0.9), then center-crop to img_size and normalize with ImageNet stats.
scale_size = math.floor(img_size / crop_pct)
resize = transforms.Compose([
    transforms.Resize(scale_size),
    transforms.CenterCrop(img_size),
    transforms.ToTensor(),
])
normalize = transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)

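# Return both the un-normalized crop (for the Grad-CAM overlay) and the
# normalized tensor that the model consumes.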
def transform(img):
    img = resize(img.convert("RGB"))
    tensor = normalize(img)
    return img, tensor

def predict(inp):
    img, inp = transform(inp)
    inp = inp.unsqueeze(0).to(device)

    # Class probabilities from a plain forward pass; the published
    # pytorch_grad_cam API returns only the CAM, not the model outputs.
    with torch.no_grad():
        probs = model(inp).softmax(dim=-1)[0]

    # Grad-CAM over the final stage; with no explicit targets it explains the
    # top-scoring class. (Recent pytorch_grad_cam infers the device from the
    # model, so the deprecated use_cuda flag is omitted.)
    with GradCAM(model=model, target_layers=[model.layers[3]]) as cam:
        grayscale_cam = cam(input_tensor=inp, aug_smooth=False, eigen_smooth=False)
        grayscale_cam = grayscale_cam[0, :]  # single image in the batch

        cam_image = show_cam_on_image(
            img.permute(1, 2, 0).numpy(), grayscale_cam,
            use_rgb=True, image_weight=0.5, colormap=cv2.COLORMAP_TWILIGHT_SHIFTED,
        )
    confidences = {labels[i]: float(probs[i]) for i in range(1000)}
    return confidences, cam_image

print(f"=> Model (tpmlp_{size}) loaded in {time.time()- start:.2f} sec(s).")

os.makedirs("../example-imgs", exist_ok=True)

print("=> Loading examples.")
indices = [
    0,      # Coucal
    2,      # Volcano
    7,      # Sombrero
    9,      # Balance beam
    10,     # Sulphur-crested cockatoo
    11,     # Shower cap
    12,     # Petri dish INCORRECTLY CLASSIFIED as lens
    14,     # Angora rabbit
]
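# Stream the validation split so only these few images are downloaded and saved
# locally (imagenet-1k is gated on the Hugging Face Hub, so authentication with
# an account that has accepted the terms may be required).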
ds = load_dataset("imagenet-1k", split="validation", streaming=True)
idx = 0
start = time.time()
for data in ds:
    if idx in indices:
        data["image"].save(f"../example-imgs/{idx}.png")
    idx += 1
    if idx > max(indices):  # every requested example has been saved
        break
del ds
print(f"=> Examples loaded in {time.time()- start:.2f} sec(s).")



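# Three-panel layout: input image, top-4 class probabilities, and the Grad-CAM
# overlay; the Examples row re-runs predict when an example is clicked.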
with gr.Blocks(theme=gr.themes.Monochrome(font=[gr.themes.GoogleFont("DM Sans"), "sans-serif"])) as demo:
    gr.HTML("""
    <h1 align="center">Interactive Demo</h1>
    <h2 align="center">CS-Mixer: A Cross-Scale Vision MLP Model with Spatial–Channel Mixing</h2>
    <br><br>
    """)
    with gr.Row():
        input_image = gr.Image(type="pil", min_width=300, label="Input Image")
        softmax = gr.Label(num_top_classes=4, min_width=200, label="Model Predictions")
        grad_cam = gr.Image(type="numpy", min_width=300, label="Grad-CAM")
    with gr.Row():
        gr.Button("Predict").click(fn=predict, inputs=input_image, outputs=[softmax, grad_cam])
        gr.ClearButton(input_image)
    with gr.Row():
        gr.Examples([f"../example-imgs/{idx}.png" for idx in indices], inputs=input_image, outputs=[softmax, grad_cam], fn=predict, run_on_click=True)
            
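# Bind on all interfaces at port 8000; uncomment the ssl_* arguments (and
# ssl_verify) to serve over HTTPS with a local certificate.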
demo.launch(
    share=False, debug=False, allowed_paths=["../example-imgs"], server_name="0.0.0.0", # ssl_verify=False,
    server_port=8000, # ssl_certfile="/workspace/openssl/cert.pem", ssl_keyfile="/workspace/openssl/key.pem"
)