eusholli committed on
Commit 9a91192 · 1 Parent(s): 143a483

Added choice of object and/or pose detection

Files changed (1)
  1. app.py +129 -82
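The change gates each YOLOv8 model behind a Streamlit radio button. A minimal sketch of that pattern, assuming only what the diff shows: the model file names and radio options are taken from it, while choice and run_models are illustrative names rather than identifiers from app.py.

import streamlit as st
from ultralytics import YOLO

# Load both models once; the radio decides which of them run per frame.
pose_model = YOLO("yolov8n-pose.pt")   # pose estimation weights
object_model = YOLO("yolov8n.pt")      # object detection weights

choice = st.radio(
    "Choose Detection Type",
    ("Object Detection", "Pose Estimation", "Both"),
    index=2,  # default to "Both", as in the diff
)

def run_models(frame):
    # Run only the model(s) the user selected and return their Results.
    results = {}
    if choice in ("Object Detection", "Both"):
        results["objects"] = object_model(frame)[0]
    if choice in ("Pose Estimation", "Both"):
        results["pose"] = pose_model(frame)[0]
    return results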
app.py CHANGED
@@ -20,10 +20,11 @@ from io import BytesIO
20
  # CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
21
  # Update below string to set display title of analysis
22
 
23
- ANALYSIS_TITLE = "YOLO-8 Pose and Efficient Action Detection"
24
 
25
- # Load the YOLOv8 model for pose estimation
26
  pose_model = YOLO("yolov8n-pose.pt")
 
27
 
28
 
29
  def detect_action(keypoints, prev_keypoints=None):
@@ -134,78 +135,103 @@ def analyze_frame(frame: np.ndarray):
134
  img_container["input"] = frame
135
  frame = frame.copy()
136
 
137
- # Run YOLOv8 pose estimation on the frame
138
- pose_results = pose_model(frame)
139
-
140
  detections = []
141
 
142
- for i, box in enumerate(pose_results[0].boxes):
143
- class_id = int(box.cls)
144
- detection = {
145
- "label": pose_model.names[class_id],
146
- "score": float(box.conf),
147
- "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
148
- }
149
-
150
- # Get keypoints for this detection if available
151
- try:
152
- if pose_results[0].keypoints is not None:
153
- keypoints = pose_results[0].keypoints[i].data.cpu().numpy()
154
-
155
- # Detect action using the keypoints
156
- prev_keypoints = img_container.get("prev_keypoints")
157
- action = detect_action(keypoints, prev_keypoints)
158
- detection["action"] = action
159
-
160
- # Store current keypoints for next frame
161
- img_container["prev_keypoints"] = keypoints
162
- else:
163
- detection["action"] = "No keypoint data"
164
- except IndexError:
165
- detection["action"] = "Action detection failed"
166
-
167
- detections.append(detection)
168
-
169
- # Draw pose keypoints without bounding boxes
170
- frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True)
171
-
172
- for detection in detections:
173
- label = f"{detection['label']} {detection['score']:.2f}"
174
- action = detection['action']
175
-
176
- # Get bounding box coordinates
177
- x1, y1, x2, y2 = detection["box_coords"]
178
-
179
- # Increase font size and thickness
180
- font_scale = 0.7
181
- thickness = 2
182
-
183
- # Get text size for label and action
184
- (label_width, label_height), _ = cv2.getTextSize(
185
- label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
186
- (action_width, action_height), _ = cv2.getTextSize(
187
- action, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
188
-
189
- # Calculate positions for centered labels at the top of the box
190
- label_x = int((x1 + x2) / 2)
191
- label_y = int(y1) - 10 # 10 pixels above the top of the box
192
- action_y = label_y - label_height - 10 # 10 pixels above the label
193
-
194
- # Draw yellow background for label
195
- cv2.rectangle(frame, (label_x - label_width // 2 - 5, label_y - label_height - 5),
196
- (label_x + label_width // 2 + 5, label_y + 5), (0, 255, 255), -1)
197
-
198
- # Draw yellow background for action
199
- cv2.rectangle(frame, (label_x - action_width // 2 - 5, action_y - action_height - 5),
200
- (label_x + action_width // 2 + 5, action_y + 5), (0, 255, 255), -1)
201
-
202
- # Draw black text for label
203
- cv2.putText(frame, label, (label_x - label_width // 2, label_y),
204
- cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)
205
-
206
- # Draw black text for action
207
- cv2.putText(frame, action, (label_x - action_width // 2, action_y),
208
- cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)
209
 
210
  end_time = time.time()
211
  execution_time_ms = round((end_time - start_time) * 1000, 2)
@@ -328,6 +354,7 @@ with col1:
328
  # Text input for YouTube URL
329
  st.subheader("Enter a YouTube URL")
330
  youtube_url = st.text_input("YouTube URL")
 
331
 
332
  # File uploader for videos
333
  st.subheader("Upload a Video")
@@ -355,7 +382,9 @@ st.markdown(
355
 
356
 
357
  def analysis_init():
358
- global analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder
359
 
360
  with col2:
361
  st.header("Analysis")
@@ -364,9 +393,11 @@ def analysis_init():
364
  st.subheader("Output Frame")
365
  output_placeholder = st.empty() # Placeholder for output frame
366
  analysis_time = st.empty() # Placeholder for analysis time
367
- show_labels = st.checkbox(
368
- "Show the detected labels", value=True
369
- ) # Checkbox to show/hide labels
370
  labels_placeholder = st.empty() # Placeholder for labels
371
 
372
 
@@ -449,16 +480,28 @@ def process_video(video_path):
449
 
450
  # Function to get the video stream URL from YouTube using yt-dlp
451
 
452
-
453
  def get_youtube_stream_url(youtube_url):
454
  ydl_opts = {
455
- 'format': 'best[ext=mp4]',
456
  'quiet': True,
 
457
  }
 
458
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
459
- info_dict = ydl.extract_info(youtube_url, download=False)
460
- stream_url = info_dict['url']
461
- return stream_url
462
 
463
 
464
  # If a YouTube URL is provided, process the video
@@ -467,7 +510,11 @@ if youtube_url:
467
 
468
  stream_url = get_youtube_stream_url(youtube_url)
469
 
470
- process_video(stream_url) # Process the video
471
 
472
  # If a video is uploaded or a URL is provided, process the video
473
  if uploaded_video is not None or video_url:
 
20
  # CHANGE CODE BELOW HERE, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
21
  # Update below string to set display title of analysis
22
 
23
+ ANALYSIS_TITLE = "YOLO-8 Object Detection, Pose Estimation, and Action Detection"
24
 
25
+ # Load the YOLOv8 models
26
  pose_model = YOLO("yolov8n-pose.pt")
27
+ object_model = YOLO("yolov8n.pt")
28
 
29
 
30
  def detect_action(keypoints, prev_keypoints=None):
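detect_action() itself is untouched by this commit; only its signature appears above as context. For orientation, a hypothetical sketch of the kind of frame-to-frame heuristic such a signature suggests. The threshold, labels, and logic below are illustrative and are not taken from app.py.

import numpy as np

def detect_action_sketch(keypoints, prev_keypoints=None, move_thresh=5.0):
    # Hypothetical stand-in, not the app's detect_action(): classify a person
    # as "moving" or "still" from mean keypoint displacement between frames.
    if prev_keypoints is None:
        return "unknown"
    # Ultralytics COCO pose keypoints arrive as (1, 17, 3): x, y, confidence.
    cur = keypoints[0][:, :2]
    prev = prev_keypoints[0][:, :2]
    displacement = float(np.linalg.norm(cur - prev, axis=1).mean())
    return "moving" if displacement > move_thresh else "still"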
 
135
  img_container["input"] = frame
136
  frame = frame.copy()
137

138
  detections = []
139
 
140
+ if show_labels in ["Object Detection", "Both"]:
141
+ # Run YOLOv8 object detection on the frame
142
+ object_results = object_model(frame)
143
+
144
+ for i, box in enumerate(object_results[0].boxes):
145
+ class_id = int(box.cls)
146
+ detection = {
147
+ "label": object_model.names[class_id],
148
+ "score": float(box.conf),
149
+ "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
150
+ }
151
+ detections.append(detection)
152
+
153
+ if show_labels in ["Pose Estimation", "Both"]:
154
+ # Run YOLOv8 pose estimation on the frame
155
+ pose_results = pose_model(frame)
156
+
157
+ for i, box in enumerate(pose_results[0].boxes):
158
+ class_id = int(box.cls)
159
+ detection = {
160
+ "label": pose_model.names[class_id],
161
+ "score": float(box.conf),
162
+ "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
163
+ }
164
+
165
+ # Get keypoints for this detection if available
166
+ try:
167
+ if pose_results[0].keypoints is not None:
168
+ keypoints = pose_results[0].keypoints[i].data.cpu().numpy()
169
+
170
+ # Detect action using the keypoints
171
+ prev_keypoints = img_container.get("prev_keypoints")
172
+ action = detect_action(keypoints, prev_keypoints)
173
+ detection["action"] = action
174
+
175
+ # Store current keypoints for next frame
176
+ img_container["prev_keypoints"] = keypoints
177
+
178
+ # Calculate the average position of visible keypoints
179
+ visible_keypoints = keypoints[0][keypoints[0]
180
+ [:, 2] > 0.5][:, :2]
181
+ if len(visible_keypoints) > 0:
182
+ label_x, label_y = np.mean(
183
+ visible_keypoints, axis=0).astype(int)
184
+ else:
185
+ # Fallback to the center of the bounding box if no keypoints are visible
186
+ x1, y1, x2, y2 = detection["box_coords"]
187
+ label_x = int((x1 + x2) / 2)
188
+ label_y = int((y1 + y2) / 2)
189
+ else:
190
+ detection["action"] = "No keypoint data"
191
+ # Use the center of the bounding box for label position
192
+ x1, y1, x2, y2 = detection["box_coords"]
193
+ label_x = int((x1 + x2) / 2)
194
+ label_y = int((y1 + y2) / 2)
195
+ except IndexError:
196
+ detection["action"] = "Action detection failed"
197
+ # Use the center of the bounding box for label position
198
+ x1, y1, x2, y2 = detection["box_coords"]
199
+ label_x = int((x1 + x2) / 2)
200
+ label_y = int((y1 + y2) / 2)
201
+
202
+ # Only display the action as the label
203
+ label = detection.get('action', '')
204
+
205
+ # Increase font scale and thickness to match box label size
206
+ font_scale = 2.0
207
+ thickness = 2
208
+
209
+ # Get text size for label
210
+ (label_width, label_height), _ = cv2.getTextSize(
211
+ label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)
212
+
213
+ # Calculate position for centered label
214
+ label_y = label_y - 10 # 10 pixels above the calculated position
215
+
216
+ # Draw yellow background for label
217
+ cv2.rectangle(frame, (label_x - label_width // 2 - 5, label_y - label_height - 5),
218
+ (label_x + label_width // 2 + 5, label_y + 5), (0, 255, 255), -1)
219
+
220
+ # Draw black text for label
221
+ cv2.putText(frame, label, (label_x - label_width // 2, label_y),
222
+ cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)
223
+
224
+ detections.append(detection)
225
+
226
+ # Draw detections on the frame
227
+ if show_labels == "Object Detection":
228
+ frame = object_results[0].plot()
229
+ elif show_labels == "Pose Estimation":
230
+ frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True)
231
+ else: # Both
232
+ frame = object_results[0].plot()
233
+ frame = pose_results[0].plot(
234
+ boxes=False, labels=False, kpt_line=True, img=frame)
235
 
236
  end_time = time.time()
237
  execution_time_ms = round((end_time - start_time) * 1000, 2)
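The label-placement logic above averages only confident keypoints and falls back to the bounding-box centre. A minimal sketch of that step with toy data, assuming the (1, N, 3) x/y/confidence layout the COCO pose models produce (only three keypoints shown):

import numpy as np

# Shape (1, N, 3): each keypoint is (x, y, confidence); toy values only.
keypoints = np.array([[[320.0, 180.0, 0.91],
                       [318.0, 175.0, 0.88],
                       [  0.0,   0.0, 0.02]]])

# Keep the (x, y) of keypoints whose confidence clears 0.5, as in the diff.
visible = keypoints[0][keypoints[0][:, 2] > 0.5][:, :2]
if len(visible) > 0:
    label_x, label_y = np.mean(visible, axis=0).astype(int)
else:
    # Fall back to the centre of the detection's bounding box.
    x1, y1, x2, y2 = 300.0, 160.0, 340.0, 400.0  # toy box coordinates
    label_x, label_y = int((x1 + x2) / 2), int((y1 + y2) / 2)

The final overlay in this hunk relies on ultralytics' Results.plot() accepting an img= argument, so the pose skeleton is drawn onto the frame already annotated with object-detection boxes rather than onto a fresh copy.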
 
354
  # Text input for YouTube URL
355
  st.subheader("Enter a YouTube URL")
356
  youtube_url = st.text_input("YouTube URL")
357
+ yt_error = st.empty() # Placeholder for YouTube error messages
358
 
359
  # File uploader for videos
360
  st.subheader("Upload a Video")
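The new yt_error slot uses Streamlit's placeholder pattern: st.empty() reserves a spot in the layout that later calls can fill with a message or wipe clean. A minimal sketch (the message text is illustrative):

import streamlit as st

yt_error = st.empty()  # reserve an initially invisible slot

# ...later, wherever the URL is handled:
yt_error.error("Unable to process the YouTube URL")  # show an error in place
yt_error.empty()  # clear it again on the next attempt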
 
382
 
383
 
384
  def analysis_init():
385
+ global yt_error, analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder
386
+
387
+ yt_error.empty() # Clear any previous YouTube error message
388
 
389
  with col2:
390
  st.header("Analysis")
 
393
  st.subheader("Output Frame")
394
  output_placeholder = st.empty() # Placeholder for output frame
395
  analysis_time = st.empty() # Placeholder for analysis time
396
+ show_labels = st.radio(
397
+ "Choose Detection Type",
398
+ ("Object Detection", "Pose Estimation", "Both"),
399
+ index=2 # Set default to "Both" (index 2)
400
+ )
401
  labels_placeholder = st.empty() # Placeholder for labels
402
 
403
 
 
480
 
481
  # Function to get the video stream URL from YouTube using yt-dlp
482
 
 
483
  def get_youtube_stream_url(youtube_url):
484
  ydl_opts = {
485
+ 'format': 'bestvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best',
486
  'quiet': True,
487
+ 'no_warnings': True,
488
  }
489
+
490
  with yt_dlp.YoutubeDL(ydl_opts) as ydl:
491
+ try:
492
+ info_dict = ydl.extract_info(youtube_url, download=False)
493
+ if 'url' in info_dict:
494
+ return info_dict['url']
495
+ elif 'entries' in info_dict:
496
+ return info_dict['entries'][0]['url']
497
+ else:
498
+ yt_error.error(
499
+ "Unable to extract video URL. The video might be unavailable or restricted.")
500
+ return None
501
+ except yt_dlp.utils.DownloadError as e:
502
+ yt_error.error(
503
+ f"Error: Unable to process the YouTube URL. {str(e)}")
504
+ return None
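The extra key checks above exist because, with download=False, yt-dlp does not always return a single top-level 'url': a merged selector such as bestvideo+bestaudio may leave the component streams under 'requested_formats', and playlist links come back under 'entries'. A minimal extraction sketch (placeholder URL; a simpler single-format selector assumed):

import yt_dlp

ydl_opts = {
    "format": "best[ext=mp4]/best",  # single-format selector keeps 'url' present
    "quiet": True,
    "no_warnings": True,
}
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    info = ydl.extract_info("https://www.youtube.com/watch?v=...", download=False)

# Single video -> top-level 'url'; playlist -> first entry's 'url'.
stream_url = info.get("url") or (info.get("entries") or [{}])[0].get("url")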
505
 
506
 
507
  # If a YouTube URL is provided, process the video
 
510
 
511
  stream_url = get_youtube_stream_url(youtube_url)
512
 
513
+ if stream_url:
514
+ process_video(stream_url) # Process the video
515
+ else:
516
+ yt_error.error(
517
+ "Unable to process the YouTube video. Please try a different URL or video format.")
518
 
519
  # If a video is uploaded or a URL is provided, process the video
520
  if uploaded_video is not None or video_url:
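process_video() is unchanged by this commit; the stream URL works because yt-dlp hands back a direct media URL that OpenCV can open like any other source, provided OpenCV was built with FFmpeg support. A minimal consumption sketch (get_youtube_stream_url, youtube_url, and analyze_frame are app.py's own names; the loop body is illustrative, not process_video's actual implementation):

import cv2

stream_url = get_youtube_stream_url(youtube_url)  # as in the diff
if stream_url:
    cap = cv2.VideoCapture(stream_url)  # the FFmpeg backend handles HTTP(S) URLs
    while cap.isOpened():
        ok, frame = cap.read()
        if not ok:
            break
        analyze_frame(frame)  # per-frame analysis, as elsewhere in app.py
    cap.release()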