Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -8,15 +8,16 @@ import plotly.graph_objects as go
|
|
8 |
import uuid
|
9 |
import os
|
10 |
|
11 |
-
# Load the trained YOLOv8n model
|
12 |
model = YOLO("best.pt")
|
|
|
13 |
|
14 |
# Constants for LBW decision and video processing
|
15 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
16 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
17 |
-
FRAME_RATE = 30 # Input video frame rate
|
18 |
SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
|
19 |
-
CONF_THRESHOLD = 0.2 #
|
20 |
IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
|
21 |
IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
|
22 |
PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
|
@@ -28,6 +29,9 @@ def process_video(video_path):
|
|
28 |
if not os.path.exists(video_path):
|
29 |
return [], [], [], "Error: Video file not found"
|
30 |
cap = cv2.VideoCapture(video_path)
|
|
|
|
|
|
|
31 |
frames = []
|
32 |
ball_positions = []
|
33 |
detection_frames = []
|
@@ -40,7 +44,8 @@ def process_video(video_path):
|
|
40 |
break
|
41 |
frame_count += 1
|
42 |
frames.append(frame.copy())
|
43 |
-
|
|
|
44 |
detections = 0
|
45 |
for detection in results[0].boxes:
|
46 |
if detection.cls == 0: # Class 0 is the ball
|
@@ -57,6 +62,7 @@ def process_video(video_path):
|
|
57 |
debug_log.append("No balls detected in any frame")
|
58 |
else:
|
59 |
debug_log.append(f"Total ball detections: {len(ball_positions)}")
|
|
|
60 |
|
61 |
return frames, ball_positions, detection_frames, "\n".join(debug_log)
|
62 |
|
@@ -91,7 +97,6 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
91 |
impact_frame = detection_frames[i]
|
92 |
break
|
93 |
elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
|
94 |
-
# Fallback to y-position if no significant y-change
|
95 |
impact_idx = i
|
96 |
impact_frame = detection_frames[i]
|
97 |
break
|
@@ -206,7 +211,7 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
|
|
206 |
impact_scatter = go.Scatter3d(
|
207 |
x=[impact_point_3d[0]] if impact_point_3d else [],
|
208 |
y=[impact_point_3d[1]] if impact_point_3d else [],
|
209 |
-
z=[impact_point_3d[2]] if
|
210 |
mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
|
211 |
)
|
212 |
data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
|
@@ -228,8 +233,9 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
|
|
228 |
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
|
229 |
if not frames:
|
230 |
return None
|
|
|
231 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
232 |
-
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (
|
233 |
|
234 |
if trajectory and detection_frames:
|
235 |
trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
|
|
|
8 |
import uuid
|
9 |
import os
|
10 |
|
11 |
+
# Load the trained YOLOv8n model with optimizations
|
12 |
model = YOLO("best.pt")
|
13 |
+
model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
|
14 |
|
15 |
# Constants for LBW decision and video processing
|
16 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
17 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
18 |
+
FRAME_RATE = 30 # Input video frame rate (adjust if known)
|
19 |
SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
|
20 |
+
CONF_THRESHOLD = 0.2 # Confidence threshold
|
21 |
IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
|
22 |
IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
|
23 |
PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
|
|
|
29 |
if not os.path.exists(video_path):
|
30 |
return [], [], [], "Error: Video file not found"
|
31 |
cap = cv2.VideoCapture(video_path)
|
32 |
+
# Get native video resolution
|
33 |
+
frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
|
34 |
+
frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
|
35 |
frames = []
|
36 |
ball_positions = []
|
37 |
detection_frames = []
|
|
|
44 |
break
|
45 |
frame_count += 1
|
46 |
frames.append(frame.copy())
|
47 |
+
# Use native resolution for inference
|
48 |
+
results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
|
49 |
detections = 0
|
50 |
for detection in results[0].boxes:
|
51 |
if detection.cls == 0: # Class 0 is the ball
|
|
|
62 |
debug_log.append("No balls detected in any frame")
|
63 |
else:
|
64 |
debug_log.append(f"Total ball detections: {len(ball_positions)}")
|
65 |
+
debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
|
66 |
|
67 |
return frames, ball_positions, detection_frames, "\n".join(debug_log)
|
68 |
|
|
|
97 |
impact_frame = detection_frames[i]
|
98 |
break
|
99 |
elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
|
|
|
100 |
impact_idx = i
|
101 |
impact_frame = detection_frames[i]
|
102 |
break
|
|
|
211 |
impact_scatter = go.Scatter3d(
|
212 |
x=[impact_point_3d[0]] if impact_point_3d else [],
|
213 |
y=[impact_point_3d[1]] if impact_point_3d else [],
|
214 |
+
z=[impact_point_3d[2]] if impact_point_3d else [],
|
215 |
mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
|
216 |
)
|
217 |
data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
|
|
|
233 |
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
|
234 |
if not frames:
|
235 |
return None
|
236 |
+
frame_height, frame_width = frames[0].shape[:2]
|
237 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
238 |
+
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
|
239 |
|
240 |
if trajectory and detection_frames:
|
241 |
trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
|