Sompote committed
Commit c354b1a · verified · 1 Parent(s): a126aef

Upload 2 files

Files changed (2)
  1. app.py +63 -18
  2. best_model_mobilenet_v3_v2.pth +3 -0
app.py CHANGED
@@ -33,16 +33,16 @@ def initialize_models():
     # Set device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Initialize ShuffleNet model
-    model = models.shufflenet_v2_x0_5(weights=None)
-    model.fc = nn.Sequential(
-        nn.Linear(model.fc.in_features, 2),
+    # Initialize MobileNetV3 model
+    model = models.mobilenet_v3_small(weights=None)
+    model.classifier = nn.Sequential(
+        nn.Linear(576, 2),  # Direct mapping to output classes
         nn.Softmax(dim=1)
     )
     model = model.to(device)
 
     # Load model weights
-    best_model_path = "best_model_ShuffleNetV2.pth"
+    best_model_path = "best_model_mobilenet_v3_v2.pth"
     if not os.path.exists(best_model_path):
         st.error(f"Model file not found: {best_model_path}")
         return None, None, None
@@ -54,7 +54,7 @@ def initialize_models():
     model.eval()
 
     # Load YOLO model
-    yolo_model_path = "yolo11s.onnx"
+    yolo_model_path = "../yolo11s.onnx"  # Going up one directory since the app.py is in API22_FEB
     if not os.path.exists(yolo_model_path):
         st.error(f"YOLO model file not found: {yolo_model_path}")
         return device, model, None
@@ -80,7 +80,8 @@ def process_image(image, model, device):
     # Perform inference
     with torch.no_grad():
         output = model(input_tensor)
-        probabilities = output[0]
+        probabilities = output[0]  # Get probabilities for both classes
+        # Class 0 is "No Red Light", Class 1 is "Red Light"
     no_red_light_prob = probabilities[0].item()
     red_light_prob = probabilities[1].item()
     is_red_light = red_light_prob > no_red_light_prob
@@ -128,6 +129,44 @@ def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, f
     # Put text
     cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
 
+def calculate_iou(box1, box2):
+    """Calculate Intersection over Union between two bounding boxes."""
+    x1 = max(box1[0], box2[0])
+    y1 = max(box1[1], box2[1])
+    x2 = min(box1[2], box2[2])
+    y2 = min(box1[3], box2[3])
+
+    intersection = max(0, x2 - x1) * max(0, y2 - y1)
+    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = box1_area + box2_area - intersection
+
+    return intersection / union if union > 0 else 0
+
+def merge_overlapping_detections(detections, iou_threshold=0.5):
+    """Merge overlapping detections of the same class."""
+    if not detections:
+        return []
+
+    # Sort detections by confidence
+    detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
+    merged_detections = []
+
+    while detections:
+        best_detection = detections.pop(0)
+        i = 0
+        while i < len(detections):
+            current_detection = detections[i]
+            if (current_detection['class'] == best_detection['class'] and
+                    calculate_iou(current_detection['bbox'], best_detection['bbox']) >= iou_threshold):
+                # Remove the lower confidence detection
+                detections.pop(i)
+            else:
+                i += 1
+        merged_detections.append(best_detection)
+
+    return merged_detections
+
 def main():
     st.title("Traffic Light Detection with Protection Area")
 
@@ -279,17 +318,23 @@ def main():
                         'confidence': confidence,
                         'bbox': bbox
                     })
-
-                    # Draw detection
-                    cv2.rectangle(cv_image,
-                                  (int(bbox[0]), int(bbox[1])),
-                                  (int(bbox[2]), int(bbox[3])),
-                                  (0, 0, 255), 2)
-
-                    # Add label
-                    text = f"{class_name}: {confidence:.2%}"
-                    put_text_with_background(cv_image, text,
-                                             (int(bbox[0]), int(bbox[1]) - 10))
+
+            # Merge overlapping detections
+            detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)
+
+            # Draw detections
+            for det in detection_results:
+                bbox = det['bbox']
+                # Draw detection box
+                cv2.rectangle(cv_image,
+                              (int(bbox[0]), int(bbox[1])),
+                              (int(bbox[2]), int(bbox[3])),
+                              (0, 0, 255), 2)
+
+                # Add label
+                text = f"{det['class']}: {det['confidence']:.2%}"
+                put_text_with_background(cv_image, text,
+                                         (int(bbox[0]), int(bbox[1]) - 10))
 
     # Add status text
     status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
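
A quick sanity check on the new classifier head (not part of this commit): torchvision's mobilenet_v3_small flattens to 576 features before its classifier, so the hard-coded nn.Linear(576, 2) lines up with the backbone. A minimal sketch that rebuilds the head as in the diff and verifies the output shape on a dummy input (the 224x224 input size is an assumption about the app's preprocessing):

import torch
import torch.nn as nn
from torchvision import models

# Rebuild the head the same way the updated initialize_models() does
model = models.mobilenet_v3_small(weights=None)
model.classifier = nn.Sequential(
    nn.Linear(576, 2),   # 576 = flattened feature width of mobilenet_v3_small
    nn.Softmax(dim=1)
)
model.eval()

# Dummy RGB batch; 224x224 assumed, any reasonable input size works
with torch.no_grad():
    probs = model(torch.randn(1, 3, 224, 224))

print(probs.shape)          # torch.Size([1, 2])
print(float(probs.sum()))   # ~1.0, since Softmax normalizes over the two classes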
best_model_mobilenet_v3_v2.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9d6dbfc5f368b8dd4f06f86e2ef088c0cd88c7bfd4f686800d2ef7b256b36f7
+size 3850192
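
For reference, a standalone walk-through of the overlap logic behind the new calculate_iou() / merge_overlapping_detections() helpers in app.py, using two made-up boxes (illustration only, duplicating the same IoU arithmetic rather than importing app.py):

# Two hypothetical detections of the same class, boxes as (x1, y1, x2, y2)
box_a = (100, 100, 200, 200)   # confidence 0.90
box_b = (110, 105, 210, 205)   # confidence 0.75

# Same IoU arithmetic as calculate_iou()
ix1, iy1 = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
ix2, iy2 = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
intersection = max(0, ix2 - ix1) * max(0, iy2 - iy1)
area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
union = area_a + area_b - intersection
iou = intersection / union if union > 0 else 0

print(round(iou, 3))  # 0.747, above the 0.5 threshold, so merge_overlapping_detections()
                      # would drop the 0.75-confidence box and keep only the 0.90 one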