Upload 2 files
- app.py +63 -18
- best_model_mobilenet_v3_v2.pth +3 -0
app.py
CHANGED
@@ -33,16 +33,16 @@ def initialize_models():
     # Set device
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
-    # Initialize
-    model = models.
-    model.
-    nn.Linear(
+    # Initialize MobileNetV3 model
+    model = models.mobilenet_v3_small(weights=None)
+    model.classifier = nn.Sequential(
+        nn.Linear(576, 2),  # Direct mapping to output classes
         nn.Softmax(dim=1)
     )
     model = model.to(device)
 
     # Load model weights
-    best_model_path = "
+    best_model_path = "best_model_mobilenet_v3_v2.pth"
     if not os.path.exists(best_model_path):
         st.error(f"Model file not found: {best_model_path}")
         return None, None, None
@@ -54,7 +54,7 @@ def initialize_models():
     model.eval()
 
     # Load YOLO model
-    yolo_model_path = "yolo11s.onnx"
+    yolo_model_path = "../yolo11s.onnx"  # Going up one directory since the app.py is in API22_FEB
     if not os.path.exists(yolo_model_path):
         st.error(f"YOLO model file not found: {yolo_model_path}")
         return device, model, None
@@ -80,7 +80,8 @@ def process_image(image, model, device):
     # Perform inference
     with torch.no_grad():
         output = model(input_tensor)
-        probabilities = output[0]
+        probabilities = output[0]  # Get probabilities for both classes
+        # Class 0 is "No Red Light", Class 1 is "Red Light"
         no_red_light_prob = probabilities[0].item()
         red_light_prob = probabilities[1].item()
         is_red_light = red_light_prob > no_red_light_prob
@@ -128,6 +129,44 @@ def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, f
     # Put text
     cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
 
+def calculate_iou(box1, box2):
+    """Calculate Intersection over Union between two bounding boxes."""
+    x1 = max(box1[0], box2[0])
+    y1 = max(box1[1], box2[1])
+    x2 = min(box1[2], box2[2])
+    y2 = min(box1[3], box2[3])
+
+    intersection = max(0, x2 - x1) * max(0, y2 - y1)
+    box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
+    box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
+    union = box1_area + box2_area - intersection
+
+    return intersection / union if union > 0 else 0
+
+def merge_overlapping_detections(detections, iou_threshold=0.5):
+    """Merge overlapping detections of the same class."""
+    if not detections:
+        return []
+
+    # Sort detections by confidence
+    detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
+    merged_detections = []
+
+    while detections:
+        best_detection = detections.pop(0)
+        i = 0
+        while i < len(detections):
+            current_detection = detections[i]
+            if (current_detection['class'] == best_detection['class'] and
+                calculate_iou(current_detection['bbox'], best_detection['bbox']) >= iou_threshold):
+                # Remove the lower confidence detection
+                detections.pop(i)
+            else:
+                i += 1
+        merged_detections.append(best_detection)
+
+    return merged_detections
+
 def main():
     st.title("Traffic Light Detection with Protection Area")
 
@@ -279,17 +318,23 @@ def main():
                 'confidence': confidence,
                 'bbox': bbox
             })
-
-
-
-
-
-
-
-
-
-
-
+
+        # Merge overlapping detections
+        detection_results = merge_overlapping_detections(detection_results, iou_threshold=0.5)
+
+        # Draw detections
+        for det in detection_results:
+            bbox = det['bbox']
+            # Draw detection box
+            cv2.rectangle(cv_image,
+                          (int(bbox[0]), int(bbox[1])),
+                          (int(bbox[2]), int(bbox[3])),
+                          (0, 0, 255), 2)
+
+            # Add label
+            text = f"{det['class']}: {det['confidence']:.2%}"
+            put_text_with_background(cv_image, text,
+                                     (int(bbox[0]), int(bbox[1]) - 10))
 
         # Add status text
         status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
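Not part of the commit, but a quick sanity check of the two helpers added above: a minimal sketch exercising calculate_iou and merge_overlapping_detections on hand-made detections. The dict keys ('class', 'confidence', 'bbox') follow the diff; the sample boxes and the import path are assumptions for illustration.

# Illustrative only: sample detections are made up; keys match the diff.
from app import calculate_iou, merge_overlapping_detections  # assumed import path

detections = [
    {'class': 'traffic light', 'confidence': 0.91, 'bbox': [100, 50, 140, 130]},
    {'class': 'traffic light', 'confidence': 0.78, 'bbox': [104, 55, 142, 128]},  # overlaps the first
    {'class': 'traffic light', 'confidence': 0.85, 'bbox': [400, 60, 430, 120]},  # a separate light
]

print(calculate_iou(detections[0]['bbox'], detections[1]['bbox']))  # ~0.79, above the 0.5 threshold
merged = merge_overlapping_detections(detections, iou_threshold=0.5)
print(len(merged))  # 2: the overlapping pair collapses to the higher-confidence box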
best_model_mobilenet_v3_v2.pth
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a9d6dbfc5f368b8dd4f06f86e2ef088c0cd88c7bfd4f686800d2ef7b256b36f7
+size 3850192
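The new .pth file is committed as a Git LFS pointer; the actual weight blob (size 3850192 bytes, about 3.85 MB) lives in LFS and is resolved on checkout. The app.py hunk above starts at line 54, after the checkpoint has already been loaded, so the loading code itself is not shown; the following is only a hypothetical sketch of how such a checkpoint is typically loaded into the classifier head defined in this diff, assuming the file holds a plain state_dict.

# Hypothetical sketch; not the commit's actual loading code (those lines are not shown in the hunk).
import torch
import torch.nn as nn
from torchvision import models

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = models.mobilenet_v3_small(weights=None)
model.classifier = nn.Sequential(nn.Linear(576, 2), nn.Softmax(dim=1))  # head as defined in the diff
state_dict = torch.load("best_model_mobilenet_v3_v2.pth", map_location=device)  # assumes a plain state_dict
model.load_state_dict(state_dict)
model = model.to(device)
model.eval()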