veronhii commited on
Commit
5b75e0d
·
1 Parent(s): 363efde

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -66
app.py CHANGED
@@ -1,67 +1,55 @@
1
- # import gradio as gr
2
- # import tensorflow as tf
3
- # import numpy as np
4
-
5
- # num_classes = 200
6
- # IMG_HEIGHT = 300
7
- # IMG_WIDTH = 300
8
-
9
- # with open("classlabel.txt", 'r') as file:
10
- # CLASS_LABEL = [x.strip() for x in file.readlines()]
11
-
12
- # def normalize_image(img):
13
- # img = tf.cast(img, tf.float32) / 255.0
14
- # img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
15
- # return img
16
-
17
- # def load_model(model_name):
18
- # # Load the model based on the model_name input
19
- # if model_name == "model1":
20
- # return tf.keras.models.load_model("model/Xception.h5")
21
- # elif model_name == "model2":
22
- # return tf.keras.models.load_model("model/InceptionV3.h5")
23
- # elif model_name == "model3":
24
- # return tf.keras.models.load_model("model/InceptionResNetV2.h5")
25
- # elif model_name == "model4":
26
- # return tf.keras.models.load_model("model/DenseNet201.h5")
27
- # else:
28
- # raise ValueError("Invalid model_name")
29
-
30
- # def predict_top_classes(img, model_name):
31
- # img = img.convert('RGB')
32
- # img_data = normalize_image(img)
33
- # x = np.array(img_data)
34
- # x = np.expand_dims(x, axis=0)
35
- # model = load_model(model_name)
36
- # temp = model.predict(x)
37
-
38
- # idx = np.argsort(np.squeeze(temp))[::-1]
39
- # top5_value = np.asarray([temp[0][i] for i in idx[0:5])
40
- # top5_idx = idx[0:5]
41
-
42
- # return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
43
-
44
- # interface = gr.Interface(
45
- # predict_top_classes,
46
- # [
47
- # gr.inputs.Image(type='pil'),
48
- # gr.inputs.Button(label="Model 1 (Xception)", value="model1"),
49
- # gr.inputs.Button(label="Model 2 (InceptionV3)", value="model2"),
50
- # gr.inputs.Button(label="Model 3 (InceptionResNetV2)", value="model3"),
51
- # gr.inputs.Button(label="Model 4 (DenseNet201)", value="model4")
52
- # ],
53
- # outputs='label'
54
- # )
55
- # interface.launch()
56
-
57
-
58
  import gradio as gr
59
-
60
-
61
- def greet(name):
62
- return "Hello " + name
63
-
64
-
65
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
66
-
67
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
+ num_classes = 200
6
+ IMG_HEIGHT = 300
7
+ IMG_WIDTH = 300
8
+
9
+ with open("classlabel.txt", 'r') as file:
10
+ CLASS_LABEL = [x.strip() for x in file.readlines()]
11
+
12
+ def normalize_image(img):
13
+ img = tf.cast(img, tf.float32) / 255.0
14
+ img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
15
+ return img
16
+
17
+ def load_model(model_name):
18
+ # Load the model based on the model_name input
19
+ if model_name == "model1":
20
+ return tf.keras.models.load_model("model/Xception.h5")
21
+ elif model_name == "model2":
22
+ return tf.keras.models.load_model("model/InceptionV3.h5")
23
+ elif model_name == "model3":
24
+ return tf.keras.models.load_model("model/InceptionResNetV2.h5")
25
+ elif model_name == "model4":
26
+ return tf.keras.models.load_model("model/DenseNet201.h5")
27
+ else:
28
+ raise ValueError("Invalid model_name")
29
+
30
+ def predict_top_classes(img, model_name):
31
+ img = img.convert('RGB')
32
+ img_data = normalize_image(img)
33
+ x = np.array(img_data)
34
+ x = np.expand_dims(x, axis=0)
35
+ model = load_model(model_name)
36
+ temp = model.predict(x)
37
+
38
+ idx = np.argsort(np.squeeze(temp))[::-1]
39
+ top5_value = np.asarray([temp[0][i] for i in idx[0:5]])
40
+ top5_idx = idx[0:5]
41
+
42
+ return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
43
+
44
+ interface = gr.Interface(
45
+ predict_top_classes,
46
+ [
47
+ gr.inputs.Image(type='pil'),
48
+ gr.inputs.Button(label="Model 1 (Xception)", value="model1"),
49
+ gr.inputs.Button(label="Model 2 (InceptionV3)", value="model2"),
50
+ gr.inputs.Button(label="Model 3 (InceptionResNetV2)", value="model3"),
51
+ gr.inputs.Button(label="Model 4 (DenseNet201)", value="model4")
52
+ ],
53
+ outputs='label'
54
+ )
55
+ interface.launch()