Sompote committed
Commit 408fdbe (verified) · 1 Parent(s): f33aeb4

Upload 2 files

Files changed (2):
  1. app.py +189 -54
  2. best_segment.pt +3 -0
app.py CHANGED
@@ -8,6 +8,7 @@ from PIL import Image
 import cv2
 from ultralytics import YOLO
 import os
+import random
 from streamlit_image_coordinates import streamlit_image_coordinates

 # Set page config
@@ -45,7 +46,7 @@ def initialize_models():
         best_model_path = "best_model_mobilenet_v3_v2.pth"
         if not os.path.exists(best_model_path):
             st.error(f"Model file not found: {best_model_path}")
-            return None, None, None
+            return None, None, None, None

         if device.type == 'cuda':
             model.load_state_dict(torch.load(best_model_path))
@@ -53,18 +54,27 @@ def initialize_models():
             model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')))
         model.eval()

-        # Load YOLO model
+        # Load YOLO model for object detection
         yolo_model_path = "yolo11s.onnx"
         if not os.path.exists(yolo_model_path):
             st.error(f"YOLO model file not found: {yolo_model_path}")
-            return device, model, None
+            return device, model, None, None

         yolo_model = YOLO(yolo_model_path)
-        return device, model, yolo_model
+
+        # Load YOLO segmentation model
+        seg_model_path = "best_segment.pt"
+        if not os.path.exists(seg_model_path):
+            st.error(f"YOLO segmentation model file not found: {seg_model_path}")
+            return device, model, yolo_model, None
+
+        seg_model = YOLO(seg_model_path)
+
+        return device, model, yolo_model, seg_model

     except Exception as e:
         st.error(f"Error initializing models: {str(e)}")
-        return None, None, None
+        return None, None, None, None

 def process_image(image, model, device):
     # Define image transformations
@@ -167,8 +177,31 @@ def merge_overlapping_detections(detections, iou_threshold=0.5):

     return merged_detections

+def get_segmentation_masks(image, seg_model, conf_threshold=0.25):
+    """Get segmentation masks from YOLO segmentation model."""
+    results = seg_model(image, conf=conf_threshold)
+
+    masks = []
+    if results and len(results) > 0 and results[0].masks is not None:
+        for i, mask in enumerate(results[0].masks.xy):
+            class_id = int(results[0].boxes.cls[i])
+            class_name = results[0].names[class_id]
+            confidence = float(results[0].boxes.conf[i])
+
+            # Convert mask to numpy array
+            mask_np = np.array(mask, dtype=np.int32)
+
+            masks.append({
+                'mask': mask_np,
+                'class': class_name,
+                'confidence': confidence,
+                'class_id': class_id
+            })
+
+    return masks, results
+
 def main():
-    st.title("Train obstruction detection V1")
+    st.title("Train obstruction detection V1.2")

     # Initialize session state
     if 'points' not in st.session_state:
@@ -177,13 +210,32 @@ def main():
         st.session_state.protection_area_defined = False
     if 'current_step' not in st.session_state:
         st.session_state.current_step = 1
+    if 'protection_method' not in st.session_state:
+        st.session_state.protection_method = "manual"
+    if 'segmentation_masks' not in st.session_state:
+        st.session_state.segmentation_masks = []
+    if 'selected_mask_index' not in st.session_state:
+        st.session_state.selected_mask_index = -1
+
+    # Initialize models
+    device, model, yolo_model, seg_model = initialize_models()

     # Create tabs for the two steps
     step1, step2 = st.tabs(["Step 1: Define Protection Area", "Step 2: Detect Objects"])

     with step1:
         st.header("Step 1: Define Protection Area")
-        st.write("Upload an image and define the protection area by clicking 4 points")
+
+        # Method selection
+        method = st.radio(
+            "Select method to define protection area:",
+            ["Manual (Click 4 points)", "Automatic Segmentation (Select a segment)"],
+            index=0 if st.session_state.protection_method == "manual" else 1,
+            key="method_selection"
+        )
+
+        # Update protection method in session state
+        st.session_state.protection_method = "manual" if method == "Manual (Click 4 points)" else "yolo"

         # File uploader for protection area definition
         setup_image = st.file_uploader("Choose an image for protection area setup", type=['jpg', 'jpeg', 'png'], key="setup_image")
@@ -199,57 +251,143 @@ def main():
             # Create a copy for drawing
             draw_image = cv_image.copy()

-            # Instructions
-            st.write("👆 Click directly on the image to add points for the protection area (need 4 points)")
-            st.write("🔄 Click 'Reset Points' to start over")
-
             # Reset button
-            if st.button('Reset Points'):
+            if st.button('Reset Points/Selection'):
                 st.session_state.points = []
                 st.session_state.protection_area_defined = False
+                st.session_state.selected_mask_index = -1
+                # Clear segmentation masks to force re-detection
+                st.session_state.segmentation_masks = []
+                if 'mask_colors' in st.session_state:
+                    del st.session_state.mask_colors
                 st.rerun()

-            # Display current image with points
-            if len(st.session_state.points) > 0:
-                # Draw existing points and lines
-                points = np.array(st.session_state.points, dtype=np.int32)
-                cv2.polylines(draw_image, [points],
-                              True if len(points) == 4 else False,
-                              (0, 255, 0), 2)
-                # Draw points with numbers
-                for i, point in enumerate(points):
-                    cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
-                    cv2.putText(draw_image, str(i+1),
-                                (point[0]+10, point[1]+10),
-                                cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
-
-            # Create columns for better layout
-            col1, col2 = st.columns([4, 1])
+            # Manual method
+            if st.session_state.protection_method == "manual":
+                # Instructions
+                st.write("👆 Click directly on the image to add points for the protection area (need 4 points)")
+
+                # Display current image with points
+                if len(st.session_state.points) > 0:
+                    # Draw existing points and lines
+                    points = np.array(st.session_state.points, dtype=np.int32)
+                    cv2.polylines(draw_image, [points],
+                                  True if len(points) == 4 else False,
+                                  (0, 255, 0), 2)
+                    # Draw points with numbers
+                    for i, point in enumerate(points):
+                        cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
+                        cv2.putText(draw_image, str(i+1),
+                                    (point[0]+10, point[1]+10),
+                                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
+
+                # Create columns for better layout
+                col1, col2 = st.columns([4, 1])
+
+                with col1:
+                    # Display the image and handle click events
+                    if len(st.session_state.points) < 4:
+                        clicked = streamlit_image_coordinates(
+                            cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
+                            key=f"image_coordinates_{len(st.session_state.points)}"
+                        )
+
+                        if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
+                            x, y = clicked['x'], clicked['y']
+                            if 0 <= x < width and 0 <= y < height:
+                                st.session_state.points.append([x, y])
+                                if len(st.session_state.points) == 4:
+                                    st.session_state.protection_area_defined = True
+                                st.rerun()
+                    else:
+                        st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
+
+                with col2:
+                    st.write(f"Points: {len(st.session_state.points)}/4")
+                    if len(st.session_state.points) > 0:
+                        st.write("Current Points:")
+                        for i, point in enumerate(st.session_state.points):
+                            st.write(f"Point {i+1}: ({point[0]}, {point[1]})")

-            with col1:
-                # Display the image and handle click events
-                if len(st.session_state.points) < 4:
-                    clicked = streamlit_image_coordinates(
-                        cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
-                        key=f"image_coordinates_{len(st.session_state.points)}"
-                    )
-
-                    if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
-                        x, y = clicked['x'], clicked['y']
-                        if 0 <= x < width and 0 <= y < height:
-                            st.session_state.points.append([x, y])
-                            if len(st.session_state.points) == 4:
-                                st.session_state.protection_area_defined = True
-                            st.rerun()
+            # YOLO Segmentation method
+            else:
+                if seg_model is None:
+                    st.error("YOLO segmentation model not loaded. Please check the error messages above.")
                 else:
-                    st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB))
-
-            with col2:
-                st.write(f"Points: {len(st.session_state.points)}/4")
-                if len(st.session_state.points) > 0:
-                    st.write("Current Points:")
-                    for i, point in enumerate(st.session_state.points):
-                        st.write(f"Point {i+1}: ({point[0]}, {point[1]})")
+                    # Always run segmentation when in YOLO mode to ensure fresh results
+                    with st.spinner("Running segmentation..."):
+                        masks, results = get_segmentation_masks(cv_image, seg_model)
+                        st.session_state.segmentation_masks = masks
+
+                        # Generate random colors for each mask
+                        st.session_state.mask_colors = []
+                        for _ in range(len(masks)):
+                            st.session_state.mask_colors.append([random.randint(0, 255) for _ in range(3)])
+
+                    # Display segmentation results
+                    if len(st.session_state.segmentation_masks) > 0:
+                        # Create a copy of the image for drawing masks
+                        mask_image = cv_image.copy()
+
+                        # Draw all masks with transparency
+                        for i, mask_data in enumerate(st.session_state.segmentation_masks):
+                            mask = mask_data['mask']
+                            color = st.session_state.mask_colors[i]
+
+                            # Create a blank image for this mask
+                            mask_overlay = np.zeros_like(mask_image)
+
+                            # Draw the filled polygon
+                            cv2.fillPoly(mask_overlay, [mask], color)
+
+                            # Add the mask to the image with transparency
+                            alpha = 0.4
+                            if i == st.session_state.selected_mask_index:
+                                alpha = 0.7  # Make selected mask more visible
+
+                            mask_image = cv2.addWeighted(mask_image, 1, mask_overlay, alpha, 0)
+
+                            # Draw the polygon outline
+                            line_thickness = 2
+                            if i == st.session_state.selected_mask_index:
+                                line_thickness = 4  # Make selected mask outline thicker
+
+                            cv2.polylines(mask_image, [mask], True, color, line_thickness)
+
+                            # Add class label
+                            class_name = mask_data['class']
+                            confidence = mask_data['confidence']
+                            label = f"{class_name} {confidence:.2f}"
+
+                            # Find a good position for the label (use the top-left point of the mask)
+                            label_pos = (int(mask[0][0]), int(mask[0][1]) - 10)
+                            put_text_with_background(mask_image, label, label_pos)
+
+                        # Display the image with masks
+                        col1, col2 = st.columns([4, 1])
+
+                        with col1:
+                            st.image(cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB))
+
+                        with col2:
+                            st.write("Available Segments:")
+                            for i, mask_data in enumerate(st.session_state.segmentation_masks):
+                                if st.button(f"Select {mask_data['class']} #{i+1}", key=f"select_mask_{i}"):
+                                    st.session_state.selected_mask_index = i
+                                    # Use the selected mask as protection area
+                                    st.session_state.points = mask_data['mask'].tolist()
+                                    st.session_state.protection_area_defined = True
+                                    st.rerun()
+
+                        # Add a re-detect button
+                        if st.button("Re-detect Segments"):
+                            st.session_state.segmentation_masks = []
+                            if 'mask_colors' in st.session_state:
+                                del st.session_state.mask_colors
+                            st.session_state.selected_mask_index = -1
+                            st.rerun()
+                    else:
+                        st.warning("No segmentation masks found in the image. Try another image or use manual method.")

     with step2:
         st.header("Step 2: Detect Objects")
@@ -264,9 +402,6 @@ def main():
         detection_image = st.file_uploader("Choose an image for detection", type=['jpg', 'jpeg', 'png'], key="detection_image")

         if detection_image is not None:
-            # Initialize models
-            device, model, yolo_model = initialize_models()
-
             if device is None or model is None:
                 st.error("Failed to initialize models. Please check the error messages above.")
                 return
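
For quick reference, a minimal standalone sketch of the segmentation flow that the new get_segmentation_masks function wraps. This is not part of the commit; it assumes ultralytics is installed, best_segment.pt has been pulled locally, and "test_image.jpg" is a placeholder path.

# Minimal sketch (not part of this commit): run the new segmentation weights
# directly and inspect the polygons that app.py stores as protection-area points.
import numpy as np
from ultralytics import YOLO

seg_model = YOLO("best_segment.pt")               # weights added in this commit
results = seg_model("test_image.jpg", conf=0.25)  # placeholder image path; same default conf as get_segmentation_masks

if results and results[0].masks is not None:
    for i, polygon in enumerate(results[0].masks.xy):
        class_id = int(results[0].boxes.cls[i])
        print(results[0].names[class_id],
              round(float(results[0].boxes.conf[i]), 2),
              np.array(polygon, dtype=np.int32).shape)  # (N, 2) polygon, used like st.session_state.points
else:
    print("No masks detected")
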
best_segment.pt ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b1b6f8dceeec2d8116f20b0d73084c5f9e33859bfd8f4891ef7e2a46cc674ac8
+size 20521693
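
The new weights file is tracked with Git LFS, so the repository itself only carries the pointer above. A minimal sketch, assuming best_segment.pt has been downloaded into the working directory, to check the local file against the pointer's oid and size:

# Minimal sketch (not part of this commit): verify a downloaded best_segment.pt
# against the sha256 oid and byte size recorded in the LFS pointer above.
import hashlib
import os

EXPECTED_OID = "b1b6f8dceeec2d8116f20b0d73084c5f9e33859bfd8f4891ef7e2a46cc674ac8"
EXPECTED_SIZE = 20521693
path = "best_segment.pt"  # assumed local download location

assert os.path.getsize(path) == EXPECTED_SIZE, "size does not match the LFS pointer"
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == EXPECTED_OID, "sha256 does not match the LFS pointer"
print("best_segment.pt matches the LFS pointer")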