Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -16,13 +16,13 @@ def normalize_image(img):
|
|
16 |
|
17 |
def load_model(model_name):
|
18 |
# Load the model based on the model_name input
|
19 |
-
if model_name ==
|
20 |
return tf.keras.models.load_model("model/Xception.h5")
|
21 |
-
elif model_name ==
|
22 |
return tf.keras.models.load_model("model/InceptionV3.h5")
|
23 |
-
elif model_name ==
|
24 |
return tf.keras.models.load_model("model/InceptionResNetV2.h5")
|
25 |
-
elif model_name ==
|
26 |
return tf.keras.models.load_model("model/DenseNet201.h5")
|
27 |
else:
|
28 |
raise ValueError("Invalid model_name")
|
@@ -41,14 +41,22 @@ def predict_top_classes(img, model_name):
|
|
41 |
|
42 |
return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
|
43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
interface = gr.Interface(
|
45 |
predict_top_classes,
|
46 |
[
|
47 |
gr.inputs.Image(type='pil'),
|
48 |
-
gr.
|
49 |
-
|
50 |
-
|
51 |
-
|
|
|
52 |
],
|
53 |
outputs='label'
|
54 |
)
|
|
|
16 |
|
17 |
def load_model(model_name):
|
18 |
# Load the model based on the model_name input
|
19 |
+
if model_name == 1:
|
20 |
return tf.keras.models.load_model("model/Xception.h5")
|
21 |
+
elif model_name == 2:
|
22 |
return tf.keras.models.load_model("model/InceptionV3.h5")
|
23 |
+
elif model_name == 3:
|
24 |
return tf.keras.models.load_model("model/InceptionResNetV2.h5")
|
25 |
+
elif model_name == 4:
|
26 |
return tf.keras.models.load_model("model/DenseNet201.h5")
|
27 |
else:
|
28 |
raise ValueError("Invalid model_name")
|
|
|
41 |
|
42 |
return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
|
43 |
|
44 |
+
models = {
|
45 |
+
"Xception": 1,
|
46 |
+
"InceptionV3": 2,
|
47 |
+
"InceptionResNetV2": 3,
|
48 |
+
"DenseNet201": 4
|
49 |
+
}
|
50 |
+
|
51 |
interface = gr.Interface(
|
52 |
predict_top_classes,
|
53 |
[
|
54 |
gr.inputs.Image(type='pil'),
|
55 |
+
gr.Dropdown(
|
56 |
+
choices=list(models.keys()),
|
57 |
+
value=list(models.values()),
|
58 |
+
label="Select a model"
|
59 |
+
)
|
60 |
],
|
61 |
outputs='label'
|
62 |
)
|