Blazer007 commited on
Commit
6e77c58
·
1 Parent(s): a88f3df

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +59 -0
app.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import gradio as gr
4
+ from huggingface_hub import from_pretrained_keras
5
+
6
+ model = from_pretrained_keras("keras-io/conv_mixer_image_classification")
7
+
8
+ class_names = [
9
+ "Airplane",
10
+ "Automobile",
11
+ "Bird",
12
+ "Cat",
13
+ "Deer",
14
+ "Dog",
15
+ "Frog",
16
+ "Horse",
17
+ "Ship",
18
+ "Truck",
19
+ ]
20
+
21
+ examples = [
22
+ ['./aeroplane.png'],
23
+ ['./horse.png'],
24
+ ['./ship.png'],
25
+ ['./truck.png']
26
+ ]
27
+
28
+ IMG_SIZE = 32
29
+
30
+ def infer(input_image):
31
+ image_tensor = tf.convert_to_tensor(input_image)
32
+ image_tensor.set_shape([None, None, 3])
33
+ image_tensor = tf.image.resize(image_tensor, (IMG_SIZE, IMG_SIZE))
34
+ predictions = model.predict(np.expand_dims((image_tensor), axis=0))
35
+ predictions = np.squeeze(predictions)
36
+ predictions = np.argmax(predictions)
37
+ predicted_label = class_names[predictions.item()]
38
+ return str(predicted_label)
39
+
40
+
41
+ input = gr.inputs.Image(shape=(IMG_SIZE, IMG_SIZE))
42
+ output = [gr.outputs.Label(label = "Output")]
43
+
44
+ title = "Image Classification using Conv Mixer Model"
45
+ description = "Upload an image or select from examples to classify it.<br>The allowed classes are - Airplane, Automobile, Bird, Cat, Deer, Dog, Frog, Horse, Ship, Truck.<br><p><b>Model Repo - https://huggingface.co/keras-io/conv_mixer_image_classification</b> <br><b>Keras Example - https://keras.io/examples/vision/convmixer//</b></p>"
46
+
47
+
48
+ article = "<div style='text-align: center;'><a href='https://twitter.com/_Blazer_007' target='_blank'>Space by Vivek Rai</a><br><a href='https://twitter.com/RisingSayak' target='_blank'>Keras example by Sayak Paul</a></div>"
49
+
50
+ gr_interface = gr.Interface(
51
+ infer,
52
+ input,
53
+ output,
54
+ examples=examples,
55
+ allow_flagging=False,
56
+ analytics_enabled=False,
57
+ title=title,
58
+ description=description,
59
+ article=article).launch(enable_queue=True, debug=True)