File size: 2,590 Bytes
284eba0
5ff014d
284eba0
3e7dbba
42ff0a6
284eba0
0de8536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284eba0
0de8536
 
5ff014d
1d000d6
0de8536
5ff014d
284eba0
5ff014d
284eba0
0de8536
5ff014d
15f8afb
7c5859b
021cf70
 
 
 
 
834fb23
 
 
 
 
 
 
 
 
95eabb4
78cb3e9
284eba0
7c5859b
75b577b
 
 
 
e980426
7c5859b
 
75b577b
182cdaf
 
75b577b
182cdaf
75b577b
37216d8
0de8536
e980426
05b1421
834fb23
05b1421
0de8536
cbb0a6b
3e7dbba
f160696
284eba0
4191bbd
 
 
 
5eab787
f160696
834fb23
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import logging
import os

import gradio as gr
import tensorflow as tf
import gdown
from PIL import Image
import pillow_avif  # noqa: F401 -- imported only for its side effect of adding AVIF support to Pillow

# Spatial size (height, width, channels) every image is resized to before it
# is fed to the model.
input_shape = (32, 32, 3)
# Upscaled size; not referenced by the visible code in this file.
resized_shape = (224, 224, 3)
num_classes = 10

# Human-readable CIFAR-10 names, keyed by the position of the matching score
# in the model's output vector.
labels = dict(
    enumerate(
        (
            "plane",
            "car",
            "bird",
            "cat",
            "deer",
            "dog",
            "frog",
            "horse",
            "ship",
            "truck",
        )
    )
)

# Download the model file
def download_model(
    url="https://drive.google.com/uc?id=1fclkgIgUo26g014beN8UhCJ2TfvGG2sG",
    output="modelV2Lmixed.keras",
):
    """Fetch the trained Keras model from Google Drive and return its local path.

    A non-empty file already present at *output* is reused, so restarting the
    app does not pull the weights from Drive again. Delete the file to force a
    fresh download.

    Args:
        url: Google Drive download URL handed to ``gdown.download``.
        output: Local path the model file is written to.

    Returns:
        *output*, the path of the model file on disk.

    Raises:
        RuntimeError: If ``gdown`` reports the download failed by returning
            ``None`` instead of a path.
    """
    if os.path.isfile(output) and os.path.getsize(output) > 0:
        return output

    downloaded = gdown.download(url, output, quiet=False)
    if downloaded is None:
        # Some gdown versions signal failures (e.g. Drive quota or permission
        # problems) by returning None rather than raising; fail loudly here
        # instead of letting load_model choke on a missing file later.
        raise RuntimeError(f"Failed to download model from {url} to {output}")
    return output

# Obtain the weights file once and deserialize the Keras model at import time;
# predict_class reads this module-level `model` on every request.
model_file = download_model()
model = tf.keras.models.load_model(filepath=model_file)

# Perform image classification for single class output
# def predict_class(image):
#     img = tf.cast(image, tf.float32)
#     img = tf.image.resize(img, [input_shape[0], input_shape[1]])
#     img = tf.expand_dims(img, axis=0)
#     prediction = model.predict(img)
#     class_index = tf.argmax(prediction[0]).numpy()
#     predicted_class = labels[class_index]
#     return predicted_class

# Perform image classification for multi-class output
def predict_class(image):
    """Run the model on a single image and return its per-class scores.

    Args:
        image: A ``PIL.Image.Image`` (what the Gradio ``Image`` input with
            ``type="pil"`` delivers) or an already-3-channel ``(H, W, 3)``
            array-like.

    Returns:
        The model's output row for this image: one score per class, in the
        same order as the keys of ``labels``.
    """
    if isinstance(image, Image.Image) and image.mode != "RGB":
        # Uploads may be RGBA (PNG/AVIF with transparency), palette (GIF) or
        # greyscale. Those yield a 4-, 1- or 0-channel tensor, but the model
        # is built for 3 channels (input_shape), so normalise the mode first.
        image = image.convert("RGB")
    img = tf.cast(image, tf.float32)
    img = tf.image.resize(img, [input_shape[0], input_shape[1]])
    img = tf.expand_dims(img, axis=0)  # add the batch dimension
    # verbose=0 keeps Keras from printing a progress bar for every request.
    prediction = model.predict(img, verbose=0)
    return prediction[0]

# UI Design for single class output
# def classify_image(image):
#     predicted_class = predict_class(image)
#     output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>"
#     return output


# UI design for multi-class output

# Minimum top-class score required before the per-class scores are shown;
# anything less confident is reported as "NO_CIFAR10_CLASS".
CONFIDENCE_THRESHOLD = 0.98

_logger = logging.getLogger(__name__)


def classify_image(image):
    """Classify *image* and map the model's scores onto CIFAR-10 label names.

    Args:
        image: Value from the Gradio ``Image`` input (a PIL image), or
            ``None`` when the user submits without uploading anything.

    Returns:
        ``None`` when no image was supplied, which clears the ``Label`` output.
        Otherwise a ``{label_name: score}`` dict covering every class when the
        top score reaches ``CONFIDENCE_THRESHOLD``, else
        ``{"NO_CIFAR10_CLASS": 1}``.
    """
    if image is None:
        return None

    results = predict_class(image)
    _logger.debug("raw model scores: %s", results)
    output = {labels.get(i): float(score) for i, score in enumerate(results)}
    _logger.debug("scores by label: %s", output)

    if max(output.values()) >= CONFIDENCE_THRESHOLD:
        return output
    return {"NO_CIFAR10_CLASS": 1}


# Single image upload, delivered to classify_image as a PIL image.
inputs = gr.components.Image(type="pil", label="Upload an image")
# For single-class HTML output, use gr.outputs.HTML() here instead.
outputs = gr.components.Label(num_top_classes=4)

title = "<h1 style='text-align: center;'>Image Classifier</h1>"
description = "Upload an image and get the predicted class."
# css_code='body{background-image:url("file=wave.mp4");}'

# Bundled example images; each example row supplies the one image input.
example_files = ["00_plane.jpg", "01_car.jpg", "02_house.jpg", "03_cat.jpg", "04_deer.jpg"]

demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    examples=[[path] for path in example_files],
    # css=css_code,
)
demo.launch()