veronhii commited on
Commit
2fee968
·
1 Parent(s): a28b6a5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -27
app.py CHANGED
@@ -1,32 +1,38 @@
 
1
  import gradio as gr
2
  import tensorflow as tf
3
  import numpy as np
4
 
 
5
  num_classes = 200
6
  IMG_HEIGHT = 300
7
  IMG_WIDTH = 300
8
 
 
9
  with open("classlabel.txt", 'r') as file:
10
  CLASS_LABEL = [x.strip() for x in file.readlines()]
11
 
 
12
  def normalize_image(img):
13
  img = tf.cast(img, tf.float32) / 255.0
14
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
15
  return img
16
 
 
17
  def load_model(model_name):
18
  # Load the model based on the model_name input
19
- if model_name == 1:
20
  return tf.keras.models.load_model("model/Xception.h5")
21
- elif model_name == 2:
22
  return tf.keras.models.load_model("model/InceptionV3.h5")
23
- elif model_name == 3:
24
  return tf.keras.models.load_model("model/InceptionResNetV2.h5")
25
- elif model_name == 4:
26
  return tf.keras.models.load_model("model/DenseNet201.h5")
27
  else:
28
  raise ValueError("Invalid model_name")
29
 
 
30
  def predict_top_classes(img, model_name):
31
  img = img.convert('RGB')
32
  img_data = normalize_image(img)
@@ -38,32 +44,24 @@ def predict_top_classes(img, model_name):
38
  idx = np.argsort(np.squeeze(temp))[::-1]
39
  top5_value = np.asarray([temp[0][i] for i in idx[0:5]])
40
  top5_idx = idx[0:5]
41
-
 
42
  return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
43
 
44
- models = {
45
- "Xception": 1,
46
- "InceptionV3": 2,
47
- "InceptionResNetV2": 3,
48
- "DenseNet201": 4
49
- }
50
-
51
- def dropdown_example(choice, img):
52
- model_name = models[choice]
53
- return predict_top_classes(img, model_name)
54
-
55
- dropdown = gr.inputs.Dropdown(
56
- choices=list(models.keys()),
57
- type="index",
58
- label="Select a model"
59
- )
60
-
61
- image_input = gr.inputs.Image(type='pil')
62
-
63
  interface = gr.Interface(
64
- fn=dropdown_example,
65
- inputs=[dropdown, image_input],
 
 
 
 
 
 
 
 
66
  outputs='label'
67
  )
68
-
 
69
  interface.launch()
 
1
+ # Import libraries
2
  import gradio as gr
3
  import tensorflow as tf
4
  import numpy as np
5
 
6
+ # Initialize the number of classes, also the image's height and width
7
  num_classes = 200
8
  IMG_HEIGHT = 300
9
  IMG_WIDTH = 300
10
 
11
+ # Open the classlabel.txt to read the class labels
12
  with open("classlabel.txt", 'r') as file:
13
  CLASS_LABEL = [x.strip() for x in file.readlines()]
14
 
15
+ # Function to normalize the image
16
  def normalize_image(img):
17
  img = tf.cast(img, tf.float32) / 255.0
18
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
19
  return img
20
 
21
+ # Function to select and load the model
22
  def load_model(model_name):
23
  # Load the model based on the model_name input
24
+ if model_name == "Xception":
25
  return tf.keras.models.load_model("model/Xception.h5")
26
+ elif model_name == "InceptionV3":
27
  return tf.keras.models.load_model("model/InceptionV3.h5")
28
+ elif model_name == "InceptionResNetV2":
29
  return tf.keras.models.load_model("model/InceptionResNetV2.h5")
30
+ elif model_name == "DenseNet201":
31
  return tf.keras.models.load_model("model/DenseNet201.h5")
32
  else:
33
  raise ValueError("Invalid model_name")
34
 
35
+ # Main function, let the model make the prediction on the image uploaded
36
  def predict_top_classes(img, model_name):
37
  img = img.convert('RGB')
38
  img_data = normalize_image(img)
 
44
  idx = np.argsort(np.squeeze(temp))[::-1]
45
  top5_value = np.asarray([temp[0][i] for i in idx[0:5]])
46
  top5_idx = idx[0:5]
47
+
48
+ # Return the top 5 highest probability class labels
49
  return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
50
 
51
+ # Define the interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  interface = gr.Interface(
53
+ predict_top_classes,
54
+ [
55
+ gr.inputs.Image(type='pil'),
56
+ gr.inputs.Dropdown(
57
+ choices=["Xception","InceptionV3","InceptionResNetV2","DenseNet201"],
58
+ type="value",
59
+ label="Select a model",
60
+ info="Base model that done feature extraction and fine-tuning process"
61
+ )
62
+ ]
63
  outputs='label'
64
  )
65
+
66
+ # Launch the interface
67
  interface.launch()