veronhii commited on
Commit
707d390
·
1 Parent(s): 5b75e0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -8
app.py CHANGED
@@ -16,13 +16,13 @@ def normalize_image(img):
16
 
17
  def load_model(model_name):
18
  # Load the model based on the model_name input
19
- if model_name == "model1":
20
  return tf.keras.models.load_model("model/Xception.h5")
21
- elif model_name == "model2":
22
  return tf.keras.models.load_model("model/InceptionV3.h5")
23
- elif model_name == "model3":
24
  return tf.keras.models.load_model("model/InceptionResNetV2.h5")
25
- elif model_name == "model4":
26
  return tf.keras.models.load_model("model/DenseNet201.h5")
27
  else:
28
  raise ValueError("Invalid model_name")
@@ -41,14 +41,22 @@ def predict_top_classes(img, model_name):
41
 
42
  return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
43
 
 
 
 
 
 
 
 
44
  interface = gr.Interface(
45
  predict_top_classes,
46
  [
47
  gr.inputs.Image(type='pil'),
48
- gr.inputs.Button(label="Model 1 (Xception)", value="model1"),
49
- gr.inputs.Button(label="Model 2 (InceptionV3)", value="model2"),
50
- gr.inputs.Button(label="Model 3 (InceptionResNetV2)", value="model3"),
51
- gr.inputs.Button(label="Model 4 (DenseNet201)", value="model4")
 
52
  ],
53
  outputs='label'
54
  )
 
16
 
17
  def load_model(model_name):
18
  # Load the model based on the model_name input
19
+ if model_name == 1:
20
  return tf.keras.models.load_model("model/Xception.h5")
21
+ elif model_name == 2:
22
  return tf.keras.models.load_model("model/InceptionV3.h5")
23
+ elif model_name == 3:
24
  return tf.keras.models.load_model("model/InceptionResNetV2.h5")
25
+ elif model_name == 4:
26
  return tf.keras.models.load_model("model/DenseNet201.h5")
27
  else:
28
  raise ValueError("Invalid model_name")
 
41
 
42
  return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
43
 
44
+ models = {
45
+ "Xception": 1,
46
+ "InceptionV3": 2,
47
+ "InceptionResNetV2": 3,
48
+ "DenseNet201": 4
49
+ }
50
+
51
  interface = gr.Interface(
52
  predict_top_classes,
53
  [
54
  gr.inputs.Image(type='pil'),
55
+ gr.Dropdown(
56
+ choices=list(models.keys()),
57
+ value=list(models.values()),
58
+ label="Select a model"
59
+ )
60
  ],
61
  outputs='label'
62
  )