veronhii commited on
Commit
3662000
·
1 Parent(s): 2710676

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +35 -17
app.py CHANGED
@@ -14,21 +14,39 @@ def normalize_image(img):
14
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
15
  return img
16
 
17
- def predict_top_classes(img):
18
- img = img.convert('RGB')
19
- img_data = normalize_image(img)
20
- x = np.array(img_data)
21
- x = np.expand_dims(x, axis=0)
22
- temp = model.predict(x)
23
-
24
- idx = np.argsort(np.squeeze(temp))[::-1]
25
- top5_value = np.asarray([temp[0][i] for i in idx[0:5]])
26
- top5_idx = idx[0:5]
27
-
28
- return {CLASS_LABEL[i]:str(v) for i,v in zip(top5_idx,top5_value)}
29
-
30
-
31
- model = tf.keras.models.load_model("Xception3.h5")
32
-
33
- interface = gr.Interface(predict_top_classes, gr.inputs.Image(type='pil'), outputs='label')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  interface.launch()
 
14
  img = tf.image.resize(img, (IMG_HEIGHT, IMG_WIDTH), method='bilinear')
15
  return img
16
 
17
+ def load_model(model_name):
18
+ # Load the model based on the model_name input
19
+ if model_name == "model1":
20
+ return tf.keras.models.load_model("model1.h5")
21
+ elif model_name == "model2":
22
+ return tf.keras.models.load_model("model2.h5")
23
+ elif model_name == "model3":
24
+ return tf.keras.models.load_model("model3.h5")
25
+ else:
26
+ raise ValueError("Invalid model_name")
27
+
28
+ def predict_top_classes(img, model_name):
29
+ img = img.convert('RGB')
30
+ img_data = normalize_image(img)
31
+ x = np.array(img_data)
32
+ x = np.expand_dims(x, axis=0)
33
+ model = load_model(model_name)
34
+ temp = model.predict(x)
35
+
36
+ idx = np.argsort(np.squeeze(temp))[::-1]
37
+ top5_value = np.asarray([temp[0][i] for i in idx[0:5])
38
+ top5_idx = idx[0:5]
39
+
40
+ return {CLASS_LABEL[i]: str(v) for i, v in zip(top5_idx, top5_value)}
41
+
42
+ interface = gr.Interface(
43
+ predict_top_classes,
44
+ [
45
+ gr.inputs.Image(type='pil'),
46
+ gr.inputs.Button(label="Model 1 (Xception)", value="model1"),
47
+ gr.inputs.Button(label="Model 2 (InceptionV3)", value="model2"),
48
+ gr.inputs.Button(label="Model 3 (InceptionResNetV2)", value="model3"),
49
+ ],
50
+ outputs='label'
51
+ )
52
  interface.launch()