Guru-25 committed on
Commit
5f7e302
·
verified ·
1 Parent(s): 6b22c31
Files changed (1) hide show
  1. app.py +21 -26
app.py CHANGED
@@ -39,14 +39,7 @@ GAZE_MODEL_PATH = os.path.join("models", "gaze_estimation_model.pth")
39
  DISTRACTION_MODEL_PATH = "best.pt"
40
 
41
  # --- Global Initializations ---
42
- # Load models on CPU initially
43
- gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
44
- blink_detector = BlinkDetector()
45
-
46
- # Load Distraction Model on CPU initially
47
- distraction_model = YOLO(DISTRACTION_MODEL_PATH)
48
- distraction_model.to('cpu')
49
-
50
 
51
  # Distraction Class Names
52
  distraction_class_names = [
@@ -78,10 +71,12 @@ EYE_CLOSURE_THRESHOLD = 10
78
  HEAD_STABILITY_THRESHOLD = 0.05
79
  DISTRACTION_CONF_THRESHOLD = 0.1
80
 
 
81
  def analyze_video(input_video):
 
 
 
82
  cap = cv2.VideoCapture(input_video)
83
- local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH)
84
- local_blink_detector = BlinkDetector()
85
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
86
  temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
87
  os.close(temp_fd)
@@ -205,7 +200,11 @@ def analyze_video(input_video):
205
  out.release()
206
  return temp_path
207
 
 
208
  def analyze_distraction_video(input_video):
 
 
 
209
  cap = cv2.VideoCapture(input_video)
210
  if not cap.isOpened():
211
  print("Error: Could not open video file.")
@@ -224,7 +223,7 @@ def analyze_distraction_video(input_video):
224
  break
225
 
226
  try:
227
- results = distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
228
 
229
  display_text = "safe driving"
230
  alarm_action = None
@@ -297,16 +296,14 @@ def terminate_distraction_stream():
297
  stop_distraction_processing = True
298
  return "Distraction Live Processing Terminated."
299
 
300
- @spaces.GPU # Add ZeroGPU decorator
301
  def process_gaze_frame(frame):
 
 
 
302
  global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
303
  global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
304
 
305
- try:
306
- gaze_predictor.model.to('cuda')
307
- except Exception as e:
308
- print(f"Warning: Could not move gaze model to CUDA: {e}")
309
-
310
  if stop_gaze_processing:
311
  return np.zeros((480, 640, 3), dtype=np.uint8)
312
 
@@ -319,11 +316,8 @@ def process_gaze_frame(frame):
319
  start_time = current_time
320
 
321
  try:
322
- head_pose_gaze, gaze_h, gaze_v = gaze_predictor.predict_gaze(frame)
323
- current_gaze = np.array([gaze_h, gaze_v]) if gaze_h is not None and gaze_v is not None else None
324
- smoothed_gaze = smooth_values(gaze_history, current_gaze)
325
-
326
- ear, left_eye, right_eye, head_pose, left_iris, right_iris = blink_detector.detect_blinks(frame)
327
 
328
  if ear is None:
329
  cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
@@ -415,11 +409,12 @@ def process_gaze_frame(frame):
415
  cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
416
  return error_frame
417
 
418
- @spaces.GPU # Add ZeroGPU decorator
419
  def process_distraction_frame(frame):
420
- global stop_distraction_processing
 
421
 
422
- distraction_model.to('cuda')
423
 
424
  if stop_distraction_processing:
425
  return np.zeros((480, 640, 3), dtype=np.uint8)
@@ -430,7 +425,7 @@ def process_distraction_frame(frame):
430
  try:
431
  frame_to_process = frame
432
 
433
- results = distraction_model(frame_to_process, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
434
 
435
  display_text = "safe driving"
436
  alarm_action = None
 
39
  DISTRACTION_MODEL_PATH = "best.pt"
40
 
41
  # --- Global Initializations ---
42
+ blink_detector = BlinkDetector() # Keep BlinkDetector global as it is CPU-only
 
 
 
 
 
 
 
43
 
44
  # Distraction Class Names
45
  distraction_class_names = [
 
71
  HEAD_STABILITY_THRESHOLD = 0.05
72
  DISTRACTION_CONF_THRESHOLD = 0.1
73
 
74
+ @spaces.GPU
75
  def analyze_video(input_video):
76
+ local_gaze_predictor = GazePredictor(GAZE_MODEL_PATH, device='cuda') # Load directly to CUDA
77
+ local_blink_detector = blink_detector # Use global CPU instance
78
+
79
  cap = cv2.VideoCapture(input_video)
 
 
80
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
81
  temp_fd, temp_path = tempfile.mkstemp(suffix='.mp4')
82
  os.close(temp_fd)
 
200
  out.release()
201
  return temp_path
202
 
203
+ @spaces.GPU
204
  def analyze_distraction_video(input_video):
205
+ local_distraction_model = YOLO(DISTRACTION_MODEL_PATH)
206
+ local_distraction_model.to('cuda') # Move to GPU
207
+
208
  cap = cv2.VideoCapture(input_video)
209
  if not cap.isOpened():
210
  print("Error: Could not open video file.")
 
223
  break
224
 
225
  try:
226
+ results = local_distraction_model(frame, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
227
 
228
  display_text = "safe driving"
229
  alarm_action = None
 
296
  stop_distraction_processing = True
297
  return "Distraction Live Processing Terminated."
298
 
299
+ @spaces.GPU
300
  def process_gaze_frame(frame):
301
+ gaze_predictor_live = GazePredictor(GAZE_MODEL_PATH, device='cuda') # Load directly to CUDA
302
+ local_blink_detector = blink_detector # Use global CPU instance
303
+
304
  global gaze_history, head_history, ear_history, stable_gaze_time, stable_head_time
305
  global eye_closed_time, blink_count, start_time, is_unconscious, frame_count_webcam, stop_gaze_processing
306
 
 
 
 
 
 
307
  if stop_gaze_processing:
308
  return np.zeros((480, 640, 3), dtype=np.uint8)
309
 
 
316
  start_time = current_time
317
 
318
  try:
319
+ head_pose_gaze, gaze_h, gaze_v = gaze_predictor_live.predict_gaze(frame)
320
+ ear, left_eye, right_eye, head_pose, left_iris, right_iris = local_blink_detector.detect_blinks(frame)
 
 
 
321
 
322
  if ear is None:
323
  cv2.putText(frame, "No face detected", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
 
409
  cv2.putText(error_frame, f"Error: {e}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 0, 0), 2)
410
  return error_frame
411
 
412
+ @spaces.GPU
413
  def process_distraction_frame(frame):
414
+ distraction_model_live = YOLO(DISTRACTION_MODEL_PATH)
415
+ distraction_model_live.to('cuda')
416
 
417
+ global stop_distraction_processing
418
 
419
  if stop_distraction_processing:
420
  return np.zeros((480, 640, 3), dtype=np.uint8)
 
425
  try:
426
  frame_to_process = frame
427
 
428
+ results = distraction_model_live(frame_to_process, conf=DISTRACTION_CONF_THRESHOLD, verbose=False)
429
 
430
  display_text = "safe driving"
431
  alarm_action = None