Sompote committed on
Commit
e1a29fd
Β·
verified Β·
1 Parent(s): 99b3da4

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +164 -173
app.py CHANGED
@@ -54,7 +54,7 @@ def initialize_models():
54
  model.eval()
55
 
56
  # Load YOLO model
57
- yolo_model_path = "yolo11s.onnx" # Going up one directory since the app.py is in API22_FEB
58
  if not os.path.exists(yolo_model_path):
59
  st.error(f"YOLO model file not found: {yolo_model_path}")
60
  return device, model, None
@@ -168,197 +168,188 @@ def merge_overlapping_detections(detections, iou_threshold=0.5):
168
  return merged_detections
169
 
170
  def main():
171
- st.title("Traffic Light Detection with Protection Area")
172
 
173
- # Initialize session state for protection area points
174
  if 'points' not in st.session_state:
175
  st.session_state.points = []
176
- if 'processing_done' not in st.session_state:
177
- st.session_state.processing_done = False
 
 
178
 
179
- # File uploader
180
- uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])
181
 
182
- if uploaded_file is not None:
183
- # Convert uploaded file to PIL Image
184
- image = Image.open(uploaded_file).convert('RGB')
185
 
186
- # Convert to OpenCV format for drawing
187
- cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
188
- height, width = cv_image.shape[:2]
189
 
190
- # Create a copy for drawing
191
- draw_image = cv_image.copy()
192
-
193
- # Instructions
194
- st.write("πŸ‘† Click directly on the image to add points for the protection area (need 4 points)")
195
- st.write("πŸ”„ Click 'Reset Points' to start over")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
- # Reset button
198
- if st.button('Reset Points'):
199
- st.session_state.points = []
200
- st.session_state.processing_done = False
201
- st.rerun()
202
 
203
- # Display current image with points
204
- if len(st.session_state.points) > 0:
205
- # Draw existing points and lines
206
- points = np.array(st.session_state.points, dtype=np.int32)
207
- cv2.polylines(draw_image, [points],
208
- True if len(points) == 4 else False,
209
- (0, 255, 0), 2)
210
- # Draw points with numbers
211
- for i, point in enumerate(points):
212
- cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
213
- cv2.putText(draw_image, str(i+1),
214
- (point[0]+10, point[1]+10),
215
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
216
 
217
- # Create columns for better layout
218
- col1, col2 = st.columns([4, 1])
219
 
220
- with col1:
221
- # Display the image and handle click events
222
- if len(st.session_state.points) < 4 and not st.session_state.processing_done:
223
- # Create a placeholder for the image
224
- image_placeholder = st.empty()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
 
226
- # Display the image with current points
227
- clicked = streamlit_image_coordinates(
228
- cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
229
- key=f"image_coordinates_{len(st.session_state.points)}"
230
- )
231
 
232
- # Handle click events
233
- if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
234
- x, y = clicked['x'], clicked['y']
235
- if 0 <= x < width and 0 <= y < height:
236
- # Add new point
237
- new_points = st.session_state.points.copy()
238
- new_points.append([x, y])
239
- st.session_state.points = new_points
240
-
241
- # Update the image with the new point
242
- points = np.array(st.session_state.points, dtype=np.int32)
243
- if len(points) > 0:
244
- cv2.polylines(draw_image, [points],
245
- True if len(points) == 4 else False,
246
- (0, 255, 0), 2)
247
- for i, point in enumerate(points):
248
- cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
249
- cv2.putText(draw_image, str(i+1),
250
- (point[0]+10, point[1]+10),
251
- cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
252
-
253
- # Rerun to update the display
254
- st.rerun()
255
- else:
256
- # Just display the image if we're done adding points
257
- st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB), use_column_width=True)
258
-
259
- with col2:
260
- # Show progress
261
- st.write(f"Points: {len(st.session_state.points)}/4")
262
-
263
- # Show current points
264
- if len(st.session_state.points) > 0:
265
- st.write("Current Points:")
266
- for i, point in enumerate(st.session_state.points):
267
- st.write(f"Point {i+1}: ({point[0]}, {point[1]})")
268
-
269
- # Add option to remove last point
270
- if st.button("Remove Last Point"):
271
- st.session_state.points.pop()
272
- st.rerun()
273
-
274
- # Process button
275
- if len(st.session_state.points) == 4 and not st.session_state.processing_done:
276
- st.write("βœ… Protection area defined! Click 'Process Detection' to continue.")
277
- if st.button('Process Detection', type='primary'):
278
- st.session_state.processing_done = True
279
 
280
- # Initialize models
281
- device, model, yolo_model = initialize_models()
282
 
283
- if device is None or model is None:
284
- st.error("Failed to initialize models. Please check the error messages above.")
285
- return
 
 
 
 
 
 
 
 
 
 
286
 
287
- # Process image for red light detection
288
- is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)
 
289
 
290
- # Display red light detection results
291
- st.write("\nπŸ”₯ Red Light Detection Results:")
292
- st.write(f"Red Light Detected: {is_red_light}")
293
- st.write(f"Red Light Probability: {red_light_prob:.2%}")
294
- st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")
295
 
296
- if is_red_light and yolo_model is not None:
297
- # Draw protection area
298
- cv2.polylines(cv_image, [np.array(st.session_state.points)], True, (0, 255, 0), 2)
299
-
300
- # Run YOLO detection
301
- results = yolo_model(cv_image, conf=0.25)
302
-
303
- # Process detections
304
- detection_results = []
305
- for result in results:
306
- if result.boxes is not None:
307
- for box in result.boxes:
308
- class_id = int(box.cls[0])
309
- class_name = yolo_model.names[class_id]
310
-
311
- if class_name in ALLOWED_CLASSES:
312
- bbox = box.xyxy[0].cpu().numpy()
313
-
314
- if is_bbox_in_area(bbox, st.session_state.points, cv_image.shape):
315
- confidence = float(box.conf[0])
316
- detection_results.append({
317
- 'class': class_name,
318
- 'confidence': confidence,
319
- 'bbox': bbox
320
- })
321
-
322
- # Merge overlapping detections
323
- detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)
324
-
325
- # Draw detections
326
- for det in detection_results:
327
- bbox = det['bbox']
328
- # Draw detection box
329
- cv2.rectangle(cv_image,
330
- (int(bbox[0]), int(bbox[1])),
331
- (int(bbox[2]), int(bbox[3])),
332
- (0, 0, 255), 2)
333
-
334
- # Add label
335
- text = f"{det['class']}: {det['confidence']:.2%}"
336
- put_text_with_background(cv_image, text,
337
- (int(bbox[0]), int(bbox[1]) - 10))
338
-
339
- # Add status text
340
- status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
341
- put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
342
-
343
- count_text = f"Objects in Protection Area: {len(detection_results)}"
344
- put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)
345
-
346
- # Display results
347
- st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
348
-
349
- # Display detections
350
- if detection_results:
351
- st.write("\n🎯 Detected Objects in Protection Area:")
352
- for i, det in enumerate(detection_results, 1):
353
- st.write(f"\nObject {i}:")
354
- st.write(f"- Class: {det['class']}")
355
- st.write(f"- Confidence: {det['confidence']:.2%}")
356
- else:
357
- st.write("\nNo objects detected in protection area")
358
  else:
359
- status_text = f"Red Light: NOT DETECTED ({red_light_prob:.1%})"
360
- put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
361
- st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
 
 
362
 
363
  if __name__ == "__main__":
364
  main()
 
54
  model.eval()
55
 
56
  # Load YOLO model
57
+ yolo_model_path = "yolo11s.onnx"
58
  if not os.path.exists(yolo_model_path):
59
  st.error(f"YOLO model file not found: {yolo_model_path}")
60
  return device, model, None
 
168
  return merged_detections
169
 
170
  def main():
171
+ st.title("Train obstruciton detection V1")
172
 
173
+ # Initialize session state
174
  if 'points' not in st.session_state:
175
  st.session_state.points = []
176
+ if 'protection_area_defined' not in st.session_state:
177
+ st.session_state.protection_area_defined = False
178
+ if 'current_step' not in st.session_state:
179
+ st.session_state.current_step = 1
180
 
181
+ # Create tabs for the two steps
182
+ step1, step2 = st.tabs(["Step 1: Define Protection Area", "Step 2: Detect Objects"])
183
 
184
+ with step1:
185
+ st.header("Step 1: Define Protection Area")
186
+ st.write("Upload an image and define the protection area by clicking 4 points")
187
 
188
+ # File uploader for protection area definition
189
+ setup_image = st.file_uploader("Choose an image for protection area setup", type=['jpg', 'jpeg', 'png'], key="setup_image")
 
190
 
191
+ if setup_image is not None:
192
+ # Convert uploaded file to PIL Image
193
+ image = Image.open(setup_image).convert('RGB')
194
+
195
+ # Convert to OpenCV format for drawing
196
+ cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
197
+ height, width = cv_image.shape[:2]
198
+
199
+ # Create a copy for drawing
200
+ draw_image = cv_image.copy()
201
+
202
+ # Instructions
203
+ st.write("πŸ‘† Click directly on the image to add points for the protection area (need 4 points)")
204
+ st.write("πŸ”„ Click 'Reset Points' to start over")
205
+
206
+ # Reset button
207
+ if st.button('Reset Points'):
208
+ st.session_state.points = []
209
+ st.session_state.protection_area_defined = False
210
+ st.rerun()
211
+
212
+ # Display current image with points
213
+ if len(st.session_state.points) > 0:
214
+ # Draw existing points and lines
215
+ points = np.array(st.session_state.points, dtype=np.int32)
216
+ cv2.polylines(draw_image, [points],
217
+ True if len(points) == 4 else False,
218
+ (0, 255, 0), 2)
219
+ # Draw points with numbers
220
+ for i, point in enumerate(points):
221
+ cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
222
+ cv2.putText(draw_image, str(i+1),
223
+ (point[0]+10, point[1]+10),
224
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
225
+
226
+ # Create columns for better layout
227
+ col1, col2 = st.columns([4, 1])
228
+
229
+ with col1:
230
+ # Display the image and handle click events
231
+ if len(st.session_state.points) < 4:
232
+ clicked = streamlit_image_coordinates(
233
+ cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
234
+ key=f"image_coordinates_{len(st.session_state.points)}"
235
+ )
236
+
237
+ if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
238
+ x, y = clicked['x'], clicked['y']
239
+ if 0 <= x < width and 0 <= y < height:
240
+ st.session_state.points.append([x, y])
241
+ if len(st.session_state.points) == 4:
242
+ st.session_state.protection_area_defined = True
243
+ st.rerun()
244
+ else:
245
+ st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
246
+
247
+ with col2:
248
+ st.write(f"Points: {len(st.session_state.points)}/4")
249
+ if len(st.session_state.points) > 0:
250
+ st.write("Current Points:")
251
+ for i, point in enumerate(st.session_state.points):
252
+ st.write(f"Point {i+1}: ({point[0]}, {point[1]})")
253
+
254
+ with step2:
255
+ st.header("Step 2: Detect Objects")
256
 
257
+ if not st.session_state.protection_area_defined:
258
+ st.warning("⚠️ Please complete Step 1 first to define the protection area.")
259
+ return
 
 
260
 
261
+ st.write("Upload images to detect red lights and objects in the protection area")
 
 
 
 
 
 
 
 
 
 
 
 
262
 
263
+ # File uploader for detection
264
+ detection_image = st.file_uploader("Choose an image for detection", type=['jpg', 'jpeg', 'png'], key="detection_image")
265
 
266
+ if detection_image is not None:
267
+ # Initialize models
268
+ device, model, yolo_model = initialize_models()
269
+
270
+ if device is None or model is None:
271
+ st.error("Failed to initialize models. Please check the error messages above.")
272
+ return
273
+
274
+ # Load and process image
275
+ image = Image.open(detection_image).convert('RGB')
276
+ cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
277
+
278
+ # Process image for red light detection
279
+ is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)
280
+
281
+ # Display red light detection results
282
+ st.write("\nπŸ”₯ Red Light Detection Results:")
283
+ st.write(f"Red Light Detected: {is_red_light}")
284
+ st.write(f"Red Light Probability: {red_light_prob:.2%}")
285
+ st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")
286
+
287
+ if is_red_light and yolo_model is not None:
288
+ # Draw protection area
289
+ cv2.polylines(cv_image, [np.array(st.session_state.points)], True, (0, 255, 0), 2)
290
 
291
+ # Run YOLO detection
292
+ results = yolo_model(cv_image, conf=0.25)
 
 
 
293
 
294
+ # Process detections
295
+ detection_results = []
296
+ for result in results:
297
+ if result.boxes is not None:
298
+ for box in result.boxes:
299
+ class_id = int(box.cls[0])
300
+ class_name = yolo_model.names[class_id]
301
+
302
+ if class_name in ALLOWED_CLASSES:
303
+ bbox = box.xyxy[0].cpu().numpy()
304
+
305
+ if is_bbox_in_area(bbox, st.session_state.points, cv_image.shape):
306
+ confidence = float(box.conf[0])
307
+ detection_results.append({
308
+ 'class': class_name,
309
+ 'confidence': confidence,
310
+ 'bbox': bbox
311
+ })
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
+ # Merge overlapping detections
314
+ detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)
315
 
316
+ # Draw detections
317
+ for det in detection_results:
318
+ bbox = det['bbox']
319
+ # Draw detection box
320
+ cv2.rectangle(cv_image,
321
+ (int(bbox[0]), int(bbox[1])),
322
+ (int(bbox[2]), int(bbox[3])),
323
+ (0, 0, 255), 2)
324
+
325
+ # Add label
326
+ text = f"{det['class']}: {det['confidence']:.2%}"
327
+ put_text_with_background(cv_image, text,
328
+ (int(bbox[0]), int(bbox[1]) - 10))
329
 
330
+ # Add status text
331
+ status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
332
+ put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
333
 
334
+ count_text = f"Objects in Protection Area: {len(detection_results)}"
335
+ put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)
 
 
 
336
 
337
+ # Display results
338
+ st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
339
+
340
+ # Display detections
341
+ if detection_results:
342
+ st.write("\n🎯 Detected Objects in Protection Area:")
343
+ for i, det in enumerate(detection_results, 1):
344
+ st.write(f"\nObject {i}:")
345
+ st.write(f"- Class: {det['class']}")
346
+ st.write(f"- Confidence: {det['confidence']:.2%}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
  else:
348
+ st.write("\nNo objects detected in protection area")
349
+ else:
350
+ status_text = f"Red Light: NOT DETECTED ({red_light_prob:.1%})"
351
+ put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
352
+ st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
353
 
354
  if __name__ == "__main__":
355
  main()