dschandra commited on
Commit
f47a8e5
·
verified ·
1 Parent(s): 0d6064c

Update utils/video_processing.py

Browse files
Files changed (1) hide show
  1. utils/video_processing.py +22 -5
utils/video_processing.py CHANGED
@@ -12,12 +12,24 @@ MODEL_PATH = 'models/yolov8_model.pt'
12
  if not os.path.exists(MODEL_PATH):
13
  raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
14
 
15
- # Load YOLO model with weights_only=False for compatibility with PyTorch 2.6
16
  try:
17
- # Explicitly set weights_only=False to allow loading Ultralytics model metadata
18
- model = YOLO(MODEL_PATH, weights_only=False)
19
  except Exception as e:
20
- raise RuntimeError(f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. Ensure the model is a valid YOLOv8 .pt file from a trusted source.")
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
  def track_ball(video_path: str) -> list:
23
  """
@@ -89,4 +101,9 @@ def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
89
  cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
90
  (int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
91
  cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
92
- out
 
 
 
 
 
 
12
  if not os.path.exists(MODEL_PATH):
13
  raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
14
 
15
+ # Load YOLO model
16
  try:
17
+ # Load the model using Ultralytics YOLO
18
+ model = YOLO(MODEL_PATH)
19
  except Exception as e:
20
+ # If loading fails due to weights_only issue, try manual loading
21
+ try:
22
+ # Manually load the checkpoint with weights_only=False
23
+ checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
24
+ model = YOLO('yolov8n.yaml') # Load model architecture from YAML
25
+ model.load_state_dict(checkpoint['model'].state_dict()) # Load weights
26
+ except Exception as inner_e:
27
+ raise RuntimeError(
28
+ f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
29
+ f"Manual loading also failed: {str(inner_e)}. "
30
+ "Ensure the model is a valid YOLOv8 .pt file from a trusted source. "
31
+ "You may need to re-save the model or use a pre-trained model like yolov8n.pt."
32
+ )
33
 
34
  def track_ball(video_path: str) -> list:
35
  """
 
101
  cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
102
  (int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
103
  cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
104
+ out.write(frame)
105
+ frame_idx += 1
106
+
107
+ cap.release()
108
+ out.release()
109
+ return replay_path