AnsMurtaza commited on
Commit
1bc005f
·
verified ·
1 Parent(s): 0a1a23c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +214 -0
app.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ import tempfile
5
+ import numpy as np
6
+ import traceback
7
+ from tensorflow.keras.models import load_model
8
+ from tensorflow.keras import backend as K
9
+ from tensorflow.keras.preprocessing import image as keras_image
10
+
11
+ print("Initializing Gradio app...")
12
+
13
+ # Update this path to match where your .h5 models actually live
14
+ MODEL_DIR = r"D:\ML_Bench\License_Plate\mlp\vehicle-rear-master\models"
15
+
16
+
17
+ def get_model_choices():
18
+ if not os.path.exists(MODEL_DIR):
19
+ print(f"[ERROR] Model directory not found: {MODEL_DIR}")
20
+ return [], MODEL_DIR
21
+
22
+ model_files = [f for f in os.listdir(MODEL_DIR) if f.lower().endswith('.h5')]
23
+ model_files.sort()
24
+ print(f"[INFO] Found model files: {model_files}")
25
+ return model_files, MODEL_DIR
26
+
27
+
28
+ def process_load(img_path, shape_tuple):
29
+ """
30
+ Load and convert the image at img_path to a numpy array
31
+ of shape (height, width, 3) matching shape_tuple, scaled to [0,1].
32
+ """
33
+ try:
34
+ target_h, target_w = shape_tuple[0], shape_tuple[1]
35
+ print(f"[DEBUG] process_load: loading '{img_path}' at size ({target_h}, {target_w})")
36
+ img = keras_image.load_img(img_path, target_size=(target_h, target_w))
37
+ x = keras_image.img_to_array(img)
38
+ print(f"[DEBUG] process_load: result array shape {x.shape}")
39
+ return x
40
+ except Exception as e:
41
+ print(f"[ERROR] process_load failed for {img_path}: {e}")
42
+ traceback.print_exc()
43
+ raise
44
+
45
+
46
+ def predict(img1, img2, model_file_name):
47
+ print("\n[DEBUG] predict() called.")
48
+ print(f" img1 type: {type(img1)}, img2 type: {type(img2)}, model_file_name: {model_file_name}")
49
+
50
+ # 1. Make sure both images are uploaded
51
+ if img1 is None or img2 is None:
52
+ print("[ERROR] One or both images were not uploaded.")
53
+ return "Error: You must upload both Image 1 and Image 2."
54
+
55
+ # 2. Get list of available models each time
56
+ model_files, model_dir = get_model_choices()
57
+ if not model_files:
58
+ print("[ERROR] No .h5 model files found in directory.")
59
+ return f"Error: No .h5 model files found in {MODEL_DIR}."
60
+
61
+ # 3. Make sure a model was actually selected
62
+ if model_file_name is None or model_file_name not in model_files:
63
+ print(f"[ERROR] Invalid or missing model selection: {model_file_name}")
64
+ return "Error: Please select a valid model from the dropdown."
65
+
66
+ model_path = os.path.join(model_dir, model_file_name)
67
+ print(f"[INFO] Loading model from: {model_path}")
68
+
69
+ img1_path = img2_path = None
70
+ try:
71
+ # 4. Infer input size based on filename convention (defensively)
72
+ basename = os.path.splitext(model_file_name)[0]
73
+ parts = basename.split('_')
74
+ name = ""
75
+ if len(parts) >= 3 and parts[0].lower() == "model" and parts[1].lower() == "shape":
76
+ name = parts[2].lower()
77
+ print(f"[DEBUG] Inferred model base name: {name}")
78
+ else:
79
+ name = "smallvgg"
80
+ print(f"[DEBUG] Filename did not match 'model_shape_X'; using default '{name}'")
81
+
82
+ if name in ['resnet50', 'vgg16']:
83
+ image_size_h_c, image_size_w_c = 224, 224
84
+ elif name == 'googlenet':
85
+ image_size_h_c, image_size_w_c = 112, 112
86
+ else:
87
+ image_size_h_c, image_size_w_c = 128, 128
88
+
89
+ input_shape = (image_size_h_c, image_size_w_c, 3)
90
+ print(f"[INFO] Inferred input_shape = {input_shape} based on '{name}'")
91
+
92
+ # 5. Save PIL images to temporary files
93
+ try:
94
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f1:
95
+ img1_path = f1.name
96
+ img1.save(img1_path)
97
+ print(f"[DEBUG] img1 saved to: {img1_path} ({os.path.getsize(img1_path)} bytes)")
98
+
99
+ with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as f2:
100
+ img2_path = f2.name
101
+ img2.save(img2_path)
102
+ print(f"[DEBUG] img2 saved to: {img2_path} ({os.path.getsize(img2_path)} bytes)")
103
+
104
+ except Exception as e:
105
+ print(f"[ERROR] Failed to save uploaded images: {e}")
106
+ traceback.print_exc()
107
+ raise
108
+
109
+ # 6. Preprocess images
110
+ try:
111
+ img1_arr = process_load(img1_path, input_shape) / 255.0
112
+ img2_arr = process_load(img2_path, input_shape) / 255.0
113
+ img1_batch = np.expand_dims(img1_arr, axis=0)
114
+ img2_batch = np.expand_dims(img2_arr, axis=0)
115
+ print(f"[DEBUG] Prepared batches: img1_batch.shape = {img1_batch.shape}, img2_batch.shape = {img2_batch.shape}")
116
+ except Exception as e:
117
+ print(f"[ERROR] Preprocessing failed: {e}")
118
+ traceback.print_exc()
119
+ raise
120
+
121
+ # 7. Clear previous Keras session (avoid GPU OOM if repeated calls)
122
+ print("[DEBUG] Clearing Keras session.")
123
+ K.clear_session()
124
+
125
+ # 8. Load model
126
+ try:
127
+ print("[DEBUG] Calling load_model...")
128
+ model = load_model(model_path, compile=False)
129
+ print(f"[INFO] Model loaded successfully. Model inputs: {model.input_shape}")
130
+ except Exception as e:
131
+ error_str = f"Prediction failed: could not load model '{model_file_name}': {e}"
132
+ print(f"[ERROR] {error_str}")
133
+ traceback.print_exc()
134
+ return error_str
135
+
136
+ # 9. Perform prediction
137
+ try:
138
+ print("[DEBUG] Calling model.predict...")
139
+ Y_ = model.predict([img1_batch, img2_batch])
140
+ print(f"[DEBUG] Raw model output: {Y_}")
141
+ except Exception as e:
142
+ error_str = f"Prediction failed during model.predict: {e}"
143
+ print(f"[ERROR] {error_str}")
144
+ traceback.print_exc()
145
+ return error_str
146
+
147
+ # 10. Interpret prediction
148
+ try:
149
+ pred_idx = int(np.argmax(Y_[0]))
150
+ pred = 'positive' if pred_idx == 1 else 'negative'
151
+ confidence = float(np.max(Y_[0]))
152
+ result_str = f"Prediction: {pred} (confidence: {confidence:.2f})"
153
+ print(f"[INFO] {result_str}")
154
+ return result_str
155
+ except Exception as e:
156
+ error_str = f"Prediction failed while interpreting output: {e}"
157
+ print(f"[ERROR] {error_str}")
158
+ traceback.print_exc()
159
+ return error_str
160
+
161
+ except Exception as e:
162
+ # Catches any “unexpected” error we didn’t already return
163
+ error_msg = f"Prediction failed with unexpected error: {str(e)}"
164
+ print(f"[ERROR] {error_msg}")
165
+ traceback.print_exc()
166
+ return error_msg
167
+
168
+ finally:
169
+ # 11. Cleanup temp files
170
+ for path in [img1_path, img2_path]:
171
+ if path and os.path.exists(path):
172
+ try:
173
+ os.remove(path)
174
+ print(f"[DEBUG] Deleted temp file: {path}")
175
+ except Exception as e:
176
+ print(f"[WARNING] Could not delete temp file {path}: {e}")
177
+ traceback.print_exc()
178
+
179
+
180
+ def main():
181
+ model_files, _ = get_model_choices()
182
+ if not model_files:
183
+ print(f"[ERROR] No model files found. Please put .h5 files in {MODEL_DIR} before running.")
184
+ else:
185
+ print(f"[INFO] Models available: {model_files}")
186
+
187
+ with gr.Blocks() as demo:
188
+ gr.Markdown("# Siamese Shape Stream Prediction App")
189
+ gr.Markdown("Upload two images and select a model to get a similarity prediction.")
190
+
191
+ with gr.Row():
192
+ img1 = gr.Image(label="Image 1", type="pil")
193
+ img2 = gr.Image(label="Image 2", type="pil")
194
+
195
+ model_dropdown = gr.Dropdown(
196
+ choices=model_files,
197
+ value=model_files[0] if model_files else None,
198
+ label="Model File (.h5)",
199
+ interactive=True
200
+ )
201
+
202
+ predict_btn = gr.Button("Predict")
203
+ output = gr.Textbox(label="Prediction Output", lines=2)
204
+
205
+ predict_btn.click(
206
+ fn=predict,
207
+ inputs=[img1, img2, model_dropdown],
208
+ outputs=output
209
+ )
210
+
211
+ demo.launch(share=True, server_name="0.0.0.0", server_port=7860, debug=True, enable_queue=True)
212
+
213
+ if __name__ == "__main__":
214
+ main()