Update utils/video_processing.py
Browse files- utils/video_processing.py +22 -5
utils/video_processing.py
CHANGED
@@ -12,12 +12,24 @@ MODEL_PATH = 'models/yolov8_model.pt'
|
|
12 |
if not os.path.exists(MODEL_PATH):
|
13 |
raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
|
14 |
|
15 |
-
# Load YOLO model
|
16 |
try:
|
17 |
-
#
|
18 |
-
model = YOLO(MODEL_PATH
|
19 |
except Exception as e:
|
20 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
def track_ball(video_path: str) -> list:
|
23 |
"""
|
@@ -89,4 +101,9 @@ def generate_replay(video_path: str, trajectory: list, decision: str) -> str:
|
|
89 |
cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
|
90 |
(int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
|
91 |
cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
92 |
-
out
|
|
|
|
|
|
|
|
|
|
|
|
12 |
if not os.path.exists(MODEL_PATH):
|
13 |
raise FileNotFoundError(f"YOLO model file not found at {MODEL_PATH}. Please ensure 'yolov8_model.pt' is in the 'models/' directory.")
|
14 |
|
15 |
+
# Load YOLO model
|
16 |
try:
|
17 |
+
# Load the model using Ultralytics YOLO
|
18 |
+
model = YOLO(MODEL_PATH)
|
19 |
except Exception as e:
|
20 |
+
# If loading fails due to weights_only issue, try manual loading
|
21 |
+
try:
|
22 |
+
# Manually load the checkpoint with weights_only=False
|
23 |
+
checkpoint = torch.load(MODEL_PATH, map_location='cpu', weights_only=False)
|
24 |
+
model = YOLO('yolov8n.yaml') # Load model architecture from YAML
|
25 |
+
model.load_state_dict(checkpoint['model'].state_dict()) # Load weights
|
26 |
+
except Exception as inner_e:
|
27 |
+
raise RuntimeError(
|
28 |
+
f"Failed to load YOLO model from {MODEL_PATH}: {str(e)}. "
|
29 |
+
f"Manual loading also failed: {str(inner_e)}. "
|
30 |
+
"Ensure the model is a valid YOLOv8 .pt file from a trusted source. "
|
31 |
+
"You may need to re-save the model or use a pre-trained model like yolov8n.pt."
|
32 |
+
)
|
33 |
|
34 |
def track_ball(video_path: str) -> list:
|
35 |
"""
|
|
|
101 |
cv2.line(frame, (int(trajectory[i-1][0]), int(trajectory[i-1][1])),
|
102 |
(int(trajectory[i][0]), int(trajectory[i][1])), (255, 0, 0), 2)
|
103 |
cv2.putText(frame, f"Decision: {decision}", (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2)
|
104 |
+
out.write(frame)
|
105 |
+
frame_idx += 1
|
106 |
+
|
107 |
+
cap.release()
|
108 |
+
out.release()
|
109 |
+
return replay_path
|