File size: 2,453 Bytes
284eba0
 
 
3e7dbba
284eba0
0de8536
 
 
 
 
 
 
 
 
 
 
 
 
 
 
284eba0
0de8536
 
 
 
 
 
284eba0
0de8536
284eba0
0de8536
 
15f8afb
0de8536
92e65a5
 
 
 
 
 
 
 
 
0de8536
 
 
4e2ea8a
0de8536
7f5ab8a
284eba0
0de8536
e980426
 
 
 
 
0de8536
e980426
7f17609
9822b4d
 
 
 
7f17609
cbb0a6b
0de8536
e980426
 
49ddf98
e980426
49ddf98
0de8536
cbb0a6b
3e7dbba
f160696
284eba0
4191bbd
 
 
 
49ddf98
f160696
49ddf98
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
import gdown
import gradio as gr
import numpy as np
import tensorflow as tf
from PIL import Image

# (height, width, channels) that predict_class resizes every upload to.
input_shape = (32, 32, 3)
# NOTE(review): not referenced anywhere in this file; presumably the size the
# model upsamples to internally — confirm before relying on it.
resized_shape = (224, 224, 3)

# Output index -> display name. The ten names appear to be the CIFAR-10
# classes; classify_image pairs score i of the model output with labels[i].
labels = dict(
    enumerate(
        ["plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck"]
    )
)
num_classes = len(labels)

# Download the model file
def download_model():
    url = "https://drive.google.com/uc?id=12700bE-pomYKoVQ214VrpBoJ7akXcTpL"
    output = "modelV2Lmixed.keras"
    gdown.download(url, output, quiet=False)
    return output

# Obtain the weights file at import time and deserialize it once; the
# module-level `model` is what predict_class runs inference with.
model_file = download_model()
model = tf.keras.models.load_model(filepath=model_file)

# Perform image classification
# def predict_class(image):
#     img = tf.cast(image, tf.float32)
#     img = tf.image.resize(img, [input_shape[0], input_shape[1]])
#     img = tf.expand_dims(img, axis=0)
#     prediction = model.predict(img)
#     class_index = tf.argmax(prediction[0]).numpy()
#     predicted_class = labels[class_index]
#     return predicted_class

def predict_class(image):
    img = tf.cast(image, tf.float32)
    img = tf.image.resize(img, [input_shape[0], input_shape[1]])
    img = tf.expand_dims(img, axis=0)
    prediction = model.predict(img)
    return prediction[0]

# UI Design
# def classify_image(image):
#     predicted_class = predict_class(image)
#     output = f"<h2>Predicted Class: <span style='text-transform:uppercase';>{predicted_class}</span></h2>"
#     return output

def classify_image(image):
    results = predict_class(image)
    # output = {}
    # for index in range(len(results)):
    #     predicted_label = labels.get(index)
    #     score = results[index]
    #     output[predicted_label] = str(score)
    output = {labels.get(i): float(results[i]) for i in range(len(results))}
    return output



# Gradio UI: one image upload in, the five highest class confidences out.
# gr.inputs.* / gr.outputs.* were deprecated in Gradio 3.x and removed in 4.0;
# the top-level components are the supported equivalents.
inputs = gr.Image(type="pil", label="Upload an image")
outputs = gr.Label(num_top_classes=5)

title = "<h1 style='text-align: center;'>Image Classifier</h1>"
description = "Upload an image and get the predicted class."

demo = gr.Interface(
    fn=classify_image,
    inputs=inputs,
    outputs=outputs,
    title=title,
    examples=[
        ["00_plane.jpg"],
        ["01_car.jpg"],
        ["02_bird.jpg"],
        ["03_cat.jpg"],
        ["04_deer.jpg"],
    ],
    description=description,
)
# Interface(enable_queue=True) was removed in Gradio 4.0; .queue() enables
# request queuing on both 3.x and 4.x.
demo.queue().launch()