File size: 2,128 Bytes
49bc02a
 
 
 
 
 
 
 
 
 
 
5b7f9a4
49bc02a
 
5b7f9a4
49bc02a
 
 
 
 
 
5b7f9a4
49bc02a
 
 
5b7f9a4
49bc02a
 
 
 
5b7f9a4
49bc02a
 
 
 
5b7f9a4
49bc02a
 
 
 
 
 
 
5b7f9a4
49bc02a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5b7f9a4
 
 
 
 
 
49bc02a
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import cv2
import json
import gradio as gr
import numpy as np
import tensorflow as tf

from backbone import create_name_vit
from backbone import ClassificationModel



# Settings for the ViT-L/16 classifier served by this demo. The entries are
# consumed further down in this file: "backbone_name" and "backbone_params"
# are forwarded to create_name_vit, the top-level "dropout_rate" goes to
# ClassificationModel, and "pretrained" is the path handed to load_weights.
vit_l16_512 = {
    "backbone_name": "vit-l/16",
    "backbone_params": {
        # Same edge length as the default resize target in
        # resize_with_normalization.
        "image_size": 512,
        "representation_size": 0,
        "attention_dropout_rate": 0.,
        "dropout_rate": 0.,
        # Same channel count as the (1, 1, 3) normalization constant used in
        # resize_with_normalization.
        "channels": 3
    },
    "dropout_rate": 0.,
    "pretrained": "./weights/vit_l16_512/model-weights"
}

# Init backbone: the name is passed positionally and every entry of
# "backbone_params" is passed as a keyword argument.
backbone = create_name_vit(vit_l16_512["backbone_name"], **vit_l16_512["backbone_params"])

# Init classification model on top of the backbone.
# NOTE(review): num_classes=1000 presumably matches the ImageNet-1k label file
# loaded below -- confirm if the label file ever changes.
model = ClassificationModel(
    backbone=backbone,
    dropout_rate=vit_l16_512["dropout_rate"],
    num_classes=1000
)

# Load weights from the path configured under "pretrained".
model.load_weights(vit_l16_512["pretrained"])
# The model is only used through model.predict in this file, so it is marked
# as non-trainable.
model.trainable = False

# Load ImageNet idx to label mapping. Keys are looked up as strings
# (str(idx)) in classify_image. The encoding is given explicitly so the file
# is decoded as UTF-8 regardless of the platform's default locale encoding.
with open("assets/imagenet_1000_idx2labels.json", encoding="utf-8") as f:
    idx_to_label = json.load(f)


def resize_with_normalization(image, size=(512, 512)):
    """Resize an image and rescale its pixel values for the model.

    Args:
        image: Image data accepted by ``tf.cast`` and ``tf.image.resize``.
            Presumably an H x W x 3 array with values in [0, 255] -- confirm
            against callers.
        size: Target ``(height, width)`` forwarded to ``tf.image.resize``.
            The default is an immutable tuple so it cannot be mutated across
            calls.

    Returns:
        A float32 tensor with a leading batch axis of length 1 whose values
        are ``(pixel - 127.5) / 127.5``.
    """
    # Built once and used for both the shift and the scale. The (1, 1, 3)
    # shape broadcasts over height and width, one value per channel.
    half_range = tf.constant(127.5, shape=(1, 1, 3), dtype=tf.float32)

    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, size)
    image = (image - half_range) / half_range
    # Add the batch axis expected by model.predict.
    return tf.expand_dims(image, axis=0)

def softmax_stable(x, axis=None):
    """Return the softmax of *x*, computed in a numerically stable way.

    The maximum is subtracted before exponentiating so that large logits do
    not overflow; this does not change the mathematical result.

    Args:
        x: Array-like of logits.
        axis: Axis along which to normalize. ``None`` (the default)
            normalizes over all elements, which is the behavior for a single
            1-D logit vector.

    Returns:
        A NumPy array with the same shape as *x* whose entries sum to 1 along
        ``axis``.
    """
    x = np.asarray(x)
    # Exponentiate once and reuse the result for numerator and denominator.
    exps = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return exps / exps.sum(axis=axis, keepdims=True)

def classify_image(img, top_k):
    """Classify an image and return its most probable labels.

    Args:
        img: Image data convertible with ``tf.convert_to_tensor``.
        top_k: Number of labels to return. It is cast to ``int`` (a UI slider
            may deliver a float) and clamped to ``[1, number of classes]``.
            Without the clamp, ``top_k == 0`` would slice ``[-0:]`` and
            return every class instead of none.

    Returns:
        A dict mapping label name to probability rounded to 4 decimals,
        ordered from most to least probable.
    """
    img = tf.convert_to_tensor(img)
    img = resize_with_normalization(img)

    # predict returns one row per batch item; the batch holds a single image.
    pred_logits = model.predict(img, batch_size=1, workers=8)[0]
    pred_probs = softmax_stable(pred_logits)

    top_k = max(1, min(int(top_k), len(pred_probs)))
    # argsort is ascending, so reverse it and keep the first top_k indices.
    top_k_labels = pred_probs.argsort()[::-1][:top_k]

    return {
        idx_to_label[str(idx)]: round(float(pred_probs[idx]), 4)
        for idx in top_k_labels
    }


# Gradio UI: an image input plus a "top_k" slider, wired to classify_image.
# - The slider starts at 1 and moves in whole steps, because top_k is used as
#   a slice length and 0 or fractional values are not meaningful.
# - The output uses the top-level gr.Label component, consistent with the
#   gr.Image / gr.Slider inputs, instead of the deprecated gr.outputs
#   namespace.
demo = gr.Interface(
    classify_image,
    inputs=[gr.Image(), gr.Slider(1, 1000, value=5, step=1)],
    outputs=gr.Label(),
    title="Image Classification with Kakao Brain ViT",
    examples=[
        ["assets/halloween-gaf8ad7ebc_1920.jpeg", 5],
        ["assets/IMG_4484.jpeg", 5],
        ["assets/IMG_4737.jpeg", 5],
        ["assets/IMG_4740.jpeg", 5],
    ],
)
demo.launch()