Guru-25 committed
Commit b7c0c37 · verified · 1 Parent(s): 5f7e302
Files changed (2):
  1. app.py +125 -98
  2. requirements.txt +1 -3
app.py CHANGED
@@ -4,12 +4,14 @@ import numpy as np
 import tempfile
 import os
 import time
+import spaces
 from scripts.inference import GazePredictor
 from utils.ear_utils import BlinkDetector
 from gradio_webrtc import WebRTC
 from ultralytics import YOLO
 import torch
-import spaces # Add spaces import
+import json
+import requests

 def smooth_values(history, current_value, window_size=5):
     if current_value is not None:
@@ -34,12 +36,45 @@ def smooth_values(history, current_value, window_size=5):
     else:
         return history[-1] if history else None

+# --- Configure Twilio TURN servers for WebRTC ---
+def get_twilio_turn_credentials():
+    # Replace with your Twilio credentials or set as environment variables
+    twilio_account_sid = os.environ.get("TWILIO_ACCOUNT_SID", "")
+    twilio_auth_token = os.environ.get("TWILIO_AUTH_TOKEN", "")
+
+    if not twilio_account_sid or not twilio_auth_token:
+        print("Warning: Twilio credentials not found. Using default RTCConfiguration.")
+        return None
+
+    try:
+        response = requests.post(
+            f"https://api.twilio.com/2010-04-01/Accounts/{twilio_account_sid}/Tokens.json",
+            auth=(twilio_account_sid, twilio_auth_token),
+        )
+        data = response.json()
+        return data["ice_servers"]
+    except Exception as e:
+        print(f"Error fetching Twilio TURN credentials: {e}")
+        return None
+
+# Configure WebRTC
+ice_servers = get_twilio_turn_credentials()
+if ice_servers:
+    rtc_configuration = {"iceServers": ice_servers}
+else:
+    rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+
 # --- Model Paths ---
 GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
 DISTRACTION_MODEL_PATH = "best.pt"

 # --- Global Initializations ---
-blink_detector = BlinkDetector() # Keep BlinkDetector global as it is CPU-only
+gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
+blink_detector = BlinkDetector()
+
+# Load Distraction Model
+distraction_model = YOLO(DISTRACTION_MODEL_PATH)
+distraction_model.to('cpu')

 # Distraction Class Names
 distraction_class_names = [
@@ -71,12 +106,10 @@ EYE_CLOSURE_THRESHOLD = 10
 HEAD_STABILITY_THRESHOLD = 0.05
 DISTRACTION_CONF_THRESHOLD = 0.1

-@spaces.GPU
 def analyze_video(input_video):
-    local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH, device='cuda') # Load directly to CUDA
-    local_blink_detector = blink_detector # Use global CPU instance
-
     cap = cv2.VideoCapture(input_video)
+    local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
+    local_blink_detector = BlinkDetector()
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
     os.close(temp_fd)
@@ -200,11 +233,8 @@ def analyze_video(input_video):
     out.release()
     return temp_path

-@spaces.GPU
+@spaces.GPU(duration=30)  # Set duration to 30 seconds for real-time processing
 def analyze_distraction_video(input_video):
-    local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
-    local_distraction_model.to('cuda') # Move to GPU
-
     cap = cv2.VideoCapture(input_video)
     if not cap.isOpened():
         print("Error: Could not open video file.")
@@ -223,7 +253,7 @@ def analyze_distraction_video(input_video):
             break

         try:
-            results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
+            results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)

             display_text = "safe driving"
             alarm_action = None
@@ -272,6 +302,73 @@ def analyze_distraction_video(input_video):
     out.release()
     return temp_path

+@spaces.GPU(duration=30)  # Set duration to 30 seconds for real-time processing
+def process_distraction_frame(frame):
+    global stop_distraction_processing
+
+    if stop_distraction_processing:
+        return np.zeros((480, 640, 3), dtype=np.uint8)
+
+    if frame is None:
+        return np.zeros((480, 640, 3), dtype=np.uint8)
+
+    try:
+        # Run distraction detection model
+        results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
+
+        display_text = "safe driving"
+        alarm_action = None
+
+        for result in results:
+            if result.boxes is not None and len(result.boxes) > 0:
+                boxes = result.boxes.xyxy.cpu().numpy()
+                scores = result.boxes.conf.cpu().numpy()
+                classes = result.boxes.cls.cpu().numpy()
+
+                if len(boxes) > 0:
+                    # Draw bounding boxes
+                    for i, box in enumerate(boxes):
+                        x1, y1, x2, y2 = map(int, box)
+                        cls_id = int(classes[i])
+                        confidence = scores[i]
+
+                        if 0 <= cls_id < len(distraction_class_names):
+                            action = distraction_class_names[cls_id]
+                            color = (0, 255, 0) if action == "safe driving" else (0, 0, 255)
+                            cv2.rectangle(frame, (x1, y1), (x2, y2), color, 2)
+                            cv2.putText(frame, f"{action} {confidence:.2f}", (x1, y1-10),
+                                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)
+
+                            # Select highest confidence detection for status
+                            if i == scores.argmax():
+                                detected_action = action
+                                confidence_score = confidence
+                                display_text = f"{detected_action}: {confidence_score:.2f}"
+                                if detected_action != 'safe driving':
+                                    alarm_action = detected_action
+                        else:
+                            print(f"Warning: Detected class index {cls_id} out of bounds.")
+                            display_text = "Unknown Detection"
+
+        if alarm_action:
+            cv2.putText(frame, f"ALERT: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
+
+        # Always show current detection status
+        text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
+        cv2.putText(frame, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
+
+        # Convert BGR to RGB for Gradio display
+        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        return frame_rgb
+
+    except Exception as e:
+        print(f"Error processing frame for distraction detection: {e}")
+        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
+        if not error_frame.flags.writeable:
+            error_frame = error_frame.copy()
+        cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
+        return error_frame
+
 def terminate_gaze_stream():
     global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
     global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
@@ -292,15 +389,13 @@

 def terminate_distraction_stream():
     global stop_distraction_processing
-    print("Distraction Live Termination signal received. Stopping processing.")
+
+    print("Distraction Termination signal received. Stopping processing.")
     stop_distraction_processing = True
-    return "Distraction Live Processing Terminated."
+    return "Distraction Processing Terminated."

-@spaces.GPU
+@spaces.GPU(duration=30)  # Set duration to 30 seconds for real-time processing
 def process_gaze_frame(frame):
-    gaze_predictor_live = GazePredictor(GAZE_MODEL_PATH, device='cuda') # Load directly to CUDA
-    local_blink_detector = blink_detector # Use global CPU instance
-
     global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
     global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing

@@ -316,8 +411,11 @@ def process_gaze_frame(frame):
         start_time = current_time

     try:
-        head_pose_gaze, gaze_h, gaze_v = gaze_predictor_live.predict_gaze(frame)
-        ear, left_eye, right_eye, head_pose, left_iris, right_iris = local_blink_detector.detect_blinks(frame)
+        head_pose_gaze, gaze_h, gaze_v = gaze_predictor.predict_gaze(frame)
+        current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
+        smoothed_gaze = smooth_values(gaze_history, current_gaze)
+
+        ear, left_eye, right_eye, head_pose, left_iris, right_iris = blink_detector.detect_blinks(frame)

         if ear is None:
             cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
@@ -409,72 +507,11 @@
         cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
         return error_frame

-@spaces.GPU
-def process_distraction_frame(frame):
-    distraction_model_live = YOLO(DISTRACTION_MODEL_PATH)
-    distraction_model_live.to('cuda')
-
-    global stop_distraction_processing
-
-    if stop_distraction_processing:
-        return np.zeros((480, 640, 3), dtype=np.uint8)
-
-    if frame is None:
-        return np.zeros((480, 640, 3), dtype=np.uint8)
-
-    try:
-        frame_to_process = frame
-
-        results = distraction_model_live(frame_to_process, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
-
-        display_text = "safe driving"
-        alarm_action = None
-
-        for result in results:
-            if result.boxes is not None and len(result.boxes) > 0:
-                boxes = result.boxes.xyxy.cpu().numpy()
-                scores = result.boxes.conf.cpu().numpy()
-                classes = result.boxes.cls.cpu().numpy()
-
-                if len(boxes) > 0:
-                    max_score_idx = scores.argmax()
-                    detected_action_idx = int(classes[max_score_idx])
-                    if 0 <= detected_action_idx < len(distraction_class_names):
-                        detected_action = distraction_class_names[detected_action_idx]
-                        confidence = scores[max_score_idx]
-                        display_text = f"{detected_action}: {confidence:.2f}"
-                        if detected_action != 'safe driving':
-                            alarm_action = detected_action
-                    else:
-                        print(f"Warning: Detected class index {detected_action_idx} out of bounds.")
-                        display_text = "Unknown Detection"
-
-        frame_bgr = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
-        if alarm_action:
-            print(f"ALARM: Unsafe behavior detected - {alarm_action}!")
-            cv2.putText(frame_bgr, f"ALARM: {alarm_action}", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
-
-        text_color = (0, 255, 0) if alarm_action is None else (0, 255, 255)
-        cv2.putText(frame_bgr, display_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, text_color, 2)
-
-        frame_rgb_processed = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
-        return frame_rgb_processed
-
-    except Exception as e:
-        print(f"Error processing distraction frame: {e}")
-        error_frame = np.zeros((480, 640, 3), dtype=np.uint8)
-        if not error_frame.flags.writeable:
-            error_frame = error_frame.copy()
-        error_frame_bgr = cv2.cvtColor(error_frame, cv2.COLOR_RGB2BGR)
-        cv2.putText(error_frame_bgr, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
-        error_frame_rgb = cv2.cvtColor(error_frame_bgr, cv2.COLOR_BGR2RGB)
-        return error_frame_rgb
-
 def create_gaze_interface():
     with gr.Blocks() as gaze_demo:
         gr.Markdown("## Real-time Gaze & Drowsiness Tracking")
         with gr.Row():
-            webcam_stream = WebRTC(label="Webcam Stream")
+            webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
         with gr.Row():
             terminate_btn = gr.Button("Terminate Process")

@@ -489,20 +526,10 @@
     return gaze_demo

 def create_distraction_interface():
-    distraction_demo = gr.Interface(
-        fn=analyze_distraction_video,
-        inputs=gr.Video(sources=["upload", "webcam"], label="Input Video (Upload or Record)"),
-        outputs=gr.Video(label="Processed Video"),
-        title="Distraction Detection Analysis",
-        description="Upload or record a video to analyze driver distraction."
-    )
-    return distraction_demo
-
-def create_distraction_live_interface():
-    with gr.Blocks() as distraction_live_demo:
-        gr.Markdown("## Real-time Distraction Detection (Live)")
+    with gr.Blocks() as distraction_demo:
+        gr.Markdown("## Real-time Distraction Detection")
         with gr.Row():
-            webcam_stream = WebRTC(label="Webcam Stream")
+            webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
         with gr.Row():
             terminate_btn = gr.Button("Terminate Process")

@@ -514,7 +541,7 @@ def create_distraction_live_interface():

         terminate_btn.click(fn=terminate_distraction_stream, inputs=None, outputs=None)

-    return distraction_live_demo
+    return distraction_demo

 def create_video_interface():
     video_demo = gr.Interface(
@@ -527,8 +554,8 @@ def create_video_interface():
     return video_demo

 demo = gr.TabbedInterface(
-    [create_video_interface(), create_gaze_interface(), create_distraction_interface(), create_distraction_live_interface()],
-    ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Video Upload", "Distraction Detection (Live)"],
+    [create_video_interface(), create_gaze_interface(), create_distraction_interface()],
+    ["Gaze Video Upload", "Gaze & Drowsiness (Live)", "Distraction Detection (Live)"],
     title="Driver Monitoring System"
 )

@@ -545,4 +572,4 @@ if __name__ == "__main__":
     frame_count_webcam = 0
     stop_gaze_processing = False
     stop_distraction_processing = False
-    demo.launch()
+    demo.launch()
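
Note on the `@spaces.GPU(duration=30)` decorators added above: on Hugging Face ZeroGPU Spaces, a GPU is attached only while a call to a decorated function is running, which is why the commit now builds the models on CPU at module level instead of calling `.to('cuda')` eagerly. A minimal sketch of the pattern, assuming ZeroGPU hardware (the function name and workload here are hypothetical, not from this commit):

import spaces
import torch

@spaces.GPU(duration=30)  # GPU is allocated only for this call; duration caps it in seconds
def run_inference(x):
    # Inside the decorated call, CUDA is available on ZeroGPU hardware.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return (torch.as_tensor(x, dtype=torch.float32, device=device) * 2).cpu().tolist()
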
requirements.txt CHANGED
@@ -11,6 +11,4 @@ tensorflow
 pygame
 twilio
 ultralytics==8.3.93
-# torch==2.6.0 # Replace with ZeroGPU compatible version, e.g., 2.4.0
-torch==2.4.0 # Example compatible version
-spaces # Add spaces for ZeroGPU
+torch==2.6.0
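
For context, the diff only shows the `WebRTC(...)` constructor calls; the per-frame wiring sits in unchanged lines outside the hunks. A minimal sketch of how a handler such as `process_gaze_frame` or `process_distraction_frame` is typically attached, assuming gradio_webrtc's `.stream()` API (the passthrough handler and `time_limit` value are illustrative, not from this commit):

import gradio as gr
from gradio_webrtc import WebRTC

# Fallback STUN-only configuration, matching the diff's else branch.
rtc_configuration = {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}

def passthrough(frame):
    # Stand-in for process_gaze_frame / process_distraction_frame.
    return frame

with gr.Blocks() as demo:
    webcam_stream = WebRTC(label="Webcam Stream", rtc_configuration=rtc_configuration)
    # Each received frame goes to the handler; the return value streams back to the same component.
    webcam_stream.stream(fn=passthrough, inputs=[webcam_stream], outputs=[webcam_stream], time_limit=30)

if __name__ == "__main__":
    demo.launch()
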