Sompote commited on
Commit
1173d05
Β·
verified Β·
1 Parent(s): dc023a3

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +33 -7
  2. app.py +319 -0
  3. best_model_ShuffleNetV2.pth +3 -0
  4. requirements.txt +8 -0
  5. yolo11s.onnx +3 -0
README.md CHANGED
@@ -1,13 +1,39 @@
1
  ---
2
- title: Tran Obstruction V1
3
- emoji: πŸ’»
4
- colorFrom: pink
5
- colorTo: red
6
  sdk: streamlit
7
- sdk_version: 1.42.0
8
  app_file: app.py
9
  pinned: false
10
- license: mit
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: Train obstruction detection V1
3
+ emoji: πŸš—
4
+ colorFrom: blue
5
+ colorTo: green
6
  sdk: streamlit
7
+ sdk_version: "1.29.0"
8
  app_file: app.py
9
  pinned: false
 
10
  ---
11
 
12
+
13
+ # Traffic Light Detection with Protection Area
14
+
15
+ This Streamlit app performs traffic light detection and monitors objects within a user-defined protection area. The app uses:
16
+ - ShuffleNetV2 for traffic light detection
17
+ - YOLO for object detection
18
+ - Interactive point selection for defining protection areas
19
+
20
+ ## Features
21
+ - Upload and process images
22
+ - Interactive protection area definition
23
+ - Traffic light state detection
24
+ - Object detection within protection area
25
+ - Real-time visualization
26
+
27
+ ## How to Use
28
+ 1. Upload an image using the file uploader
29
+ 2. Click on the image to define 4 points for the protection area
30
+ 3. Use "Reset Points" if you need to start over
31
+ 4. Click "Process Detection" when ready
32
+ 5. View the results showing traffic light state and detected objects
33
+
34
+ ## Models
35
+ - Traffic Light Detection: ShuffleNetV2
36
+ - Object Detection: YOLOv8
37
+
38
+ ## Requirements
39
+ See requirements.txt for all dependencies.
app.py ADDED
@@ -0,0 +1,319 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ from torchvision import models
7
+ from PIL import Image
8
+ import cv2
9
+ from ultralytics import YOLO
10
+ import os
11
+ from streamlit_image_coordinates import streamlit_image_coordinates
12
+
13
+ # Set page config
14
+ st.set_page_config(
15
+ page_title="Traffic Light Detection App",
16
+ layout="wide",
17
+ menu_items={
18
+ 'Get Help': 'https://github.com/yourusername/traffic-light-detection',
19
+ 'Report a bug': "https://github.com/yourusername/traffic-light-detection/issues",
20
+ 'About': "# Traffic Light Detection App\nThis app detects traffic lights and monitors objects in a protection area."
21
+ }
22
+ )
23
+
24
+ # Define allowed classes
25
+ ALLOWED_CLASSES = {
26
+ 'person', 'bicycle', 'car', 'motorcycle', 'bus', 'train', 'truck',
27
+ 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe'
28
+ }
29
+
30
+ @st.cache_resource
31
+ def initialize_models():
32
+ try:
33
+ # Set device
34
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+
36
+ # Initialize ShuffleNet model
37
+ model = models.shufflenet_v2_x0_5(weights=None)
38
+ model.fc = nn.Sequential(
39
+ nn.Linear(model.fc.in_features, 2),
40
+ nn.Softmax(dim=1)
41
+ )
42
+ model = model.to(device)
43
+
44
+ # Load model weights
45
+ best_model_path = "best_model_ShuffleNetV2.pth"
46
+ if not os.path.exists(best_model_path):
47
+ st.error(f"Model file not found: {best_model_path}")
48
+ return None, None, None
49
+
50
+ if device.type == 'cuda':
51
+ model.load_state_dict(torch.load(best_model_path))
52
+ else:
53
+ model.load_state_dict(torch.load(best_model_path, map_location=torch.device('cpu')))
54
+ model.eval()
55
+
56
+ # Load YOLO model
57
+ yolo_model_path = "yolo11s.onnx"
58
+ if not os.path.exists(yolo_model_path):
59
+ st.error(f"YOLO model file not found: {yolo_model_path}")
60
+ return device, model, None
61
+
62
+ yolo_model = YOLO(yolo_model_path)
63
+ return device, model, yolo_model
64
+
65
+ except Exception as e:
66
+ st.error(f"Error initializing models: {str(e)}")
67
+ return None, None, None
68
+
69
+ def process_image(image, model, device):
70
+ # Define image transformations
71
+ transform = transforms.Compose([
72
+ transforms.Resize((224, 224)),
73
+ transforms.ToTensor(),
74
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
75
+ ])
76
+
77
+ # Process image
78
+ input_tensor = transform(image).unsqueeze(0).to(device)
79
+
80
+ # Perform inference
81
+ with torch.no_grad():
82
+ output = model(input_tensor)
83
+ probabilities = output[0]
84
+ no_red_light_prob = probabilities[0].item()
85
+ red_light_prob = probabilities[1].item()
86
+ is_red_light = red_light_prob > no_red_light_prob
87
+
88
+ return is_red_light, red_light_prob, no_red_light_prob
89
+
90
+ def is_point_in_polygon(point, polygon):
91
+ """Check if a point is inside a polygon using ray casting algorithm."""
92
+ x, y = point
93
+ n = len(polygon)
94
+ inside = False
95
+ p1x, p1y = polygon[0]
96
+ for i in range(n + 1):
97
+ p2x, p2y = polygon[i % n]
98
+ if y > min(p1y, p2y):
99
+ if y <= max(p1y, p2y):
100
+ if x <= max(p1x, p2x):
101
+ if p1y != p2y:
102
+ xinters = (y - p1y) * (p2x - p1x) / (p2y - p1y) + p1x
103
+ if p1x == p2x or x <= xinters:
104
+ inside = not inside
105
+ p1x, p1y = p2x, p2y
106
+ return inside
107
+
108
+ def is_bbox_in_area(bbox, protection_area, image_shape):
109
+ """Check if bounding box center is in protection area."""
110
+ # Get bbox center point
111
+ center_x = (bbox[0] + bbox[2]) / 2
112
+ center_y = (bbox[1] + bbox[3]) / 2
113
+ return is_point_in_polygon((center_x, center_y), protection_area)
114
+
115
+ def put_text_with_background(img, text, position, font_scale=0.8, thickness=2, font=cv2.FONT_HERSHEY_SIMPLEX):
116
+ """Put text with background on image."""
117
+ # Get text size
118
+ (text_width, text_height), baseline = cv2.getTextSize(text, font, font_scale, thickness)
119
+
120
+ # Calculate background rectangle
121
+ padding = 5
122
+ bg_rect_pt1 = (position[0], position[1] - text_height - padding)
123
+ bg_rect_pt2 = (position[0] + text_width + padding * 2, position[1] + padding)
124
+
125
+ # Draw background rectangle
126
+ cv2.rectangle(img, bg_rect_pt1, bg_rect_pt2, (0, 0, 0), -1)
127
+
128
+ # Put text
129
+ cv2.putText(img, text, (position[0] + padding, position[1]), font, font_scale, (255, 255, 255), thickness)
130
+
131
+ def main():
132
+ st.title("Traffic Light Detection with Protection Area")
133
+
134
+ # Initialize session state for protection area points
135
+ if 'points' not in st.session_state:
136
+ st.session_state.points = []
137
+ if 'processing_done' not in st.session_state:
138
+ st.session_state.processing_done = False
139
+
140
+ # File uploader
141
+ uploaded_file = st.file_uploader("Choose an image", type=['jpg', 'jpeg', 'png'])
142
+
143
+ if uploaded_file is not None:
144
+ # Convert uploaded file to PIL Image
145
+ image = Image.open(uploaded_file).convert('RGB')
146
+
147
+ # Convert to OpenCV format for drawing
148
+ cv_image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
149
+ height, width = cv_image.shape[:2]
150
+
151
+ # Create a copy for drawing
152
+ draw_image = cv_image.copy()
153
+
154
+ # Instructions
155
+ st.write("πŸ‘† Click directly on the image to add points for the protection area (need 4 points)")
156
+ st.write("πŸ”„ Click 'Reset Points' to start over")
157
+
158
+ # Reset button
159
+ if st.button('Reset Points'):
160
+ st.session_state.points = []
161
+ st.session_state.processing_done = False
162
+ st.rerun()
163
+
164
+ # Display current image with points
165
+ if len(st.session_state.points) > 0:
166
+ # Draw existing points and lines
167
+ points = np.array(st.session_state.points, dtype=np.int32)
168
+ cv2.polylines(draw_image, [points],
169
+ True if len(points) == 4 else False,
170
+ (0, 255, 0), 2)
171
+ # Draw points with numbers
172
+ for i, point in enumerate(points):
173
+ cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
174
+ cv2.putText(draw_image, str(i+1),
175
+ (point[0]+10, point[1]+10),
176
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
177
+
178
+ # Create columns for better layout
179
+ col1, col2 = st.columns([4, 1])
180
+
181
+ with col1:
182
+ # Display the image and handle click events
183
+ if len(st.session_state.points) < 4 and not st.session_state.processing_done:
184
+ # Create a placeholder for the image
185
+ image_placeholder = st.empty()
186
+
187
+ # Display the image with current points
188
+ clicked = streamlit_image_coordinates(
189
+ cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB),
190
+ key=f"image_coordinates_{len(st.session_state.points)}"
191
+ )
192
+
193
+ # Handle click events
194
+ if clicked is not None and clicked.get('x') is not None and clicked.get('y') is not None:
195
+ x, y = clicked['x'], clicked['y']
196
+ if 0 <= x < width and 0 <= y < height:
197
+ # Add new point
198
+ new_points = st.session_state.points.copy()
199
+ new_points.append([x, y])
200
+ st.session_state.points = new_points
201
+
202
+ # Update the image with the new point
203
+ points = np.array(st.session_state.points, dtype=np.int32)
204
+ if len(points) > 0:
205
+ cv2.polylines(draw_image, [points],
206
+ True if len(points) == 4 else False,
207
+ (0, 255, 0), 2)
208
+ for i, point in enumerate(points):
209
+ cv2.circle(draw_image, tuple(point), 5, (0, 0, 255), -1)
210
+ cv2.putText(draw_image, str(i+1),
211
+ (point[0]+10, point[1]+10),
212
+ cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 0, 255), 2)
213
+
214
+ # Rerun to update the display
215
+ st.rerun()
216
+ else:
217
+ # Just display the image if we're done adding points
218
+ st.image(cv2.cvtColor(draw_image, cv2.COLOR_BGR2RGB), use_column_width=True)
219
+
220
+ with col2:
221
+ # Show progress
222
+ st.write(f"Points: {len(st.session_state.points)}/4")
223
+
224
+ # Show current points
225
+ if len(st.session_state.points) > 0:
226
+ st.write("Current Points:")
227
+ for i, point in enumerate(st.session_state.points):
228
+ st.write(f"Point {i+1}: ({point[0]}, {point[1]})")
229
+
230
+ # Add option to remove last point
231
+ if st.button("Remove Last Point"):
232
+ st.session_state.points.pop()
233
+ st.rerun()
234
+
235
+ # Process button
236
+ if len(st.session_state.points) == 4 and not st.session_state.processing_done:
237
+ st.write("βœ… Protection area defined! Click 'Process Detection' to continue.")
238
+ if st.button('Process Detection', type='primary'):
239
+ st.session_state.processing_done = True
240
+
241
+ # Initialize models
242
+ device, model, yolo_model = initialize_models()
243
+
244
+ if device is None or model is None:
245
+ st.error("Failed to initialize models. Please check the error messages above.")
246
+ return
247
+
248
+ # Process image for red light detection
249
+ is_red_light, red_light_prob, no_red_light_prob = process_image(image, model, device)
250
+
251
+ # Display red light detection results
252
+ st.write("\nπŸ”₯ Red Light Detection Results:")
253
+ st.write(f"Red Light Detected: {is_red_light}")
254
+ st.write(f"Red Light Probability: {red_light_prob:.2%}")
255
+ st.write(f"No Red Light Probability: {no_red_light_prob:.2%}")
256
+
257
+ if is_red_light and yolo_model is not None:
258
+ # Draw protection area
259
+ cv2.polylines(cv_image, [np.array(st.session_state.points)], True, (0, 255, 0), 2)
260
+
261
+ # Run YOLO detection
262
+ results = yolo_model(cv_image, conf=0.25)
263
+
264
+ # Process detections
265
+ detection_results = []
266
+ for result in results:
267
+ if result.boxes is not None:
268
+ for box in result.boxes:
269
+ class_id = int(box.cls[0])
270
+ class_name = yolo_model.names[class_id]
271
+
272
+ if class_name in ALLOWED_CLASSES:
273
+ bbox = box.xyxy[0].cpu().numpy()
274
+
275
+ if is_bbox_in_area(bbox, st.session_state.points, cv_image.shape):
276
+ confidence = float(box.conf[0])
277
+ detection_results.append({
278
+ 'class': class_name,
279
+ 'confidence': confidence,
280
+ 'bbox': bbox
281
+ })
282
+
283
+ # Draw detection
284
+ cv2.rectangle(cv_image,
285
+ (int(bbox[0]), int(bbox[1])),
286
+ (int(bbox[2]), int(bbox[3])),
287
+ (0, 0, 255), 2)
288
+
289
+ # Add label
290
+ text = f"{class_name}: {confidence:.2%}"
291
+ put_text_with_background(cv_image, text,
292
+ (int(bbox[0]), int(bbox[1]) - 10))
293
+
294
+ # Add status text
295
+ status_text = f"Red Light: DETECTED ({red_light_prob:.1%})"
296
+ put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
297
+
298
+ count_text = f"Objects in Protection Area: {len(detection_results)}"
299
+ put_text_with_background(cv_image, count_text, (10, 70), font_scale=0.8)
300
+
301
+ # Display results
302
+ st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
303
+
304
+ # Display detections
305
+ if detection_results:
306
+ st.write("\n🎯 Detected Objects in Protection Area:")
307
+ for i, det in enumerate(detection_results, 1):
308
+ st.write(f"\nObject {i}:")
309
+ st.write(f"- Class: {det['class']}")
310
+ st.write(f"- Confidence: {det['confidence']:.2%}")
311
+ else:
312
+ st.write("\nNo objects detected in protection area")
313
+ else:
314
+ status_text = f"Red Light: NOT DETECTED ({red_light_prob:.1%})"
315
+ put_text_with_background(cv_image, status_text, (10, 30), font_scale=1.0, thickness=2)
316
+ st.image(cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB))
317
+
318
+ if __name__ == "__main__":
319
+ main()
best_model_ShuffleNetV2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a84d301e6e12068e6d1f1f1ded010df866432d4c8b27382ec6dc58e6d8b78a4
3
+ size 1525166
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ torch
3
+ torchvision
4
+ numpy
5
+ Pillow
6
+ opencv-python-headless
7
+ ultralytics
8
+ streamlit-image-coordinates
yolo11s.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99a699d299959fb9307386ee7abd7a76af6798924fb129bcacd3b3a95c77dbf2
3
+ size 38011340