Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -16,15 +16,15 @@ model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
|
|
16 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
17 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
18 |
FRAME_RATE = 20 # Default frame rate, updated dynamically
|
19 |
-
SLOW_MOTION_FACTOR =
|
20 |
-
CONF_THRESHOLD = 0.25 #
|
21 |
-
IMPACT_ZONE_Y = 0.
|
22 |
-
IMPACT_VELOCITY_THRESHOLD =
|
23 |
PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
|
24 |
STUMPS_HEIGHT = 0.71 # meters (stumps height)
|
25 |
CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
|
26 |
CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
|
27 |
-
MAX_POSITION_JUMP =
|
28 |
|
29 |
def process_video(video_path):
|
30 |
if not os.path.exists(video_path):
|
@@ -59,6 +59,7 @@ def process_video(video_path):
|
|
59 |
if detections == 1: # Only process frames with exactly one ball detection
|
60 |
for detection in results[0].boxes:
|
61 |
if detection.cls == 0: # Class 0 is the ball
|
|
|
62 |
x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
|
63 |
# Scale coordinates back to original frame size
|
64 |
x1 = x1 * frame_width / img_width
|
@@ -68,8 +69,10 @@ def process_video(video_path):
|
|
68 |
ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
|
69 |
detection_frames.append(frame_count - 1)
|
70 |
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
|
|
|
|
|
|
|
71 |
frames[-1] = frame
|
72 |
-
debug_log.append(f"Frame {frame_count}: {detections} ball detections")
|
73 |
# Save debug frame
|
74 |
cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
|
75 |
cap.release()
|
@@ -108,6 +111,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
108 |
filtered_positions.append(curr_pos)
|
109 |
filtered_frames.append(detection_frames[i])
|
110 |
else:
|
|
|
111 |
continue
|
112 |
|
113 |
if len(filtered_positions) < 2:
|
@@ -117,43 +121,32 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
117 |
y_coords = [pos[1] for pos in filtered_positions]
|
118 |
times = np.array(filtered_frames) / FRAME_RATE
|
119 |
|
120 |
-
# Convert to 3D for
|
121 |
detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
|
122 |
|
123 |
-
# Pitch point: Detection with lowest
|
124 |
-
pitch_idx = min(range(len(
|
125 |
pitch_point = filtered_positions[pitch_idx]
|
126 |
pitch_frame = filtered_frames[pitch_idx]
|
127 |
|
128 |
-
# Impact point:
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
for i in range(1, len(y_coords)):
|
134 |
-
if velocities[i-1] > IMPACT_VELOCITY_THRESHOLD:
|
135 |
-
impact_idx = i
|
136 |
-
impact_frame = filtered_frames[i]
|
137 |
-
break
|
138 |
-
elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
|
139 |
-
impact_idx = i
|
140 |
-
impact_frame = filtered_frames[i]
|
141 |
-
break
|
142 |
-
if impact_idx is None:
|
143 |
-
impact_idx = len(filtered_positions) - 1
|
144 |
-
impact_frame = filtered_frames[-1]
|
145 |
impact_point = filtered_positions[impact_idx]
|
|
|
146 |
|
147 |
try:
|
148 |
# Use linear interpolation for more stable trajectory
|
149 |
-
fx = interp1d(times
|
150 |
-
fy = interp1d(times
|
151 |
except Exception as e:
|
152 |
return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
|
153 |
|
154 |
# Generate dense points for all frames between first and last detection
|
155 |
total_frames = max(detection_frames) - min(detection_frames) + 1
|
156 |
-
t_full = np.linspace(times[0], times[
|
157 |
x_full = fx(t_full)
|
158 |
y_full = fy(t_full)
|
159 |
trajectory_2d = list(zip(x_full, y_full))
|
@@ -163,13 +156,13 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
163 |
impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
|
164 |
|
165 |
# Debug trajectory and points
|
166 |
-
debug_log =
|
167 |
-
f"Trajectory estimated successfully
|
168 |
-
f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}
|
169 |
-
f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}
|
170 |
-
f"Detections in frames: {filtered_frames}
|
171 |
-
f"Velocities: {velocities}"
|
172 |
-
|
173 |
# Save trajectory plot for debugging
|
174 |
import matplotlib.pyplot as plt
|
175 |
plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
|
@@ -178,7 +171,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
178 |
plt.legend()
|
179 |
plt.savefig("trajectory_debug.png")
|
180 |
|
181 |
-
return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
|
182 |
|
183 |
def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
|
184 |
"""Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
|
@@ -291,18 +284,18 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
|
|
291 |
frame_idx = i - min_frame if trajectory_indices else -1
|
292 |
if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
|
293 |
end_idx = trajectory_indices[frame_idx] + 1
|
294 |
-
cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
|
295 |
if pitch_point and i == pitch_frame:
|
296 |
x, y = pitch_point
|
297 |
-
cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
|
298 |
cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
|
299 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
300 |
if impact_point and i == impact_frame:
|
301 |
x, y = impact_point
|
302 |
-
cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
|
303 |
cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
|
304 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
305 |
-
for _ in range(SLOW_MOTION_FACTOR):
|
306 |
out.write(frame)
|
307 |
out.release()
|
308 |
return output_path
|
|
|
16 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
17 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
18 |
FRAME_RATE = 20 # Default frame rate, updated dynamically
|
19 |
+
SLOW_MOTION_FACTOR = 1.5 # Faster replay (e.g., 30 / 1.5 = 20 FPS)
|
20 |
+
CONF_THRESHOLD = 0.25 # Confidence threshold for detection
|
21 |
+
IMPACT_ZONE_Y = 0.9 # Adjusted to 90% of frame height for impact zone
|
22 |
+
IMPACT_VELOCITY_THRESHOLD = 500 # Reduced for better impact detection
|
23 |
PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
|
24 |
STUMPS_HEIGHT = 0.71 # meters (stumps height)
|
25 |
CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
|
26 |
CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
|
27 |
+
MAX_POSITION_JUMP = 150 # Increased to include more detections
|
28 |
|
29 |
def process_video(video_path):
|
30 |
if not os.path.exists(video_path):
|
|
|
59 |
if detections == 1: # Only process frames with exactly one ball detection
|
60 |
for detection in results[0].boxes:
|
61 |
if detection.cls == 0: # Class 0 is the ball
|
62 |
+
conf = detection.conf.cpu().numpy()[0]
|
63 |
x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
|
64 |
# Scale coordinates back to original frame size
|
65 |
x1 = x1 * frame_width / img_width
|
|
|
69 |
ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
|
70 |
detection_frames.append(frame_count - 1)
|
71 |
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
|
72 |
+
debug_log.append(f"Frame {frame_count}: 1 ball detection, confidence={conf:.3f}")
|
73 |
+
else:
|
74 |
+
debug_log.append(f"Frame {frame_count}: {detections} ball detections")
|
75 |
frames[-1] = frame
|
|
|
76 |
# Save debug frame
|
77 |
cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
|
78 |
cap.release()
|
|
|
111 |
filtered_positions.append(curr_pos)
|
112 |
filtered_frames.append(detection_frames[i])
|
113 |
else:
|
114 |
+
debug_log.append(f"Filtered out detection at frame {detection_frames[i] + 1}: large jump ({distance:.1f} pixels)")
|
115 |
continue
|
116 |
|
117 |
if len(filtered_positions) < 2:
|
|
|
121 |
y_coords = [pos[1] for pos in filtered_positions]
|
122 |
times = np.array(filtered_frames) / FRAME_RATE
|
123 |
|
124 |
+
# Convert to 3D for visualization
|
125 |
detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
|
126 |
|
127 |
+
# Pitch point: Detection with lowest y-coordinate (near bowler's end)
|
128 |
+
pitch_idx = min(range(len(filtered_positions)), key=lambda i: filtered_positions[i][1])
|
129 |
pitch_point = filtered_positions[pitch_idx]
|
130 |
pitch_frame = filtered_frames[pitch_idx]
|
131 |
|
132 |
+
# Impact point: Detection with highest y-coordinate after pitch point (near stumps)
|
133 |
+
post_pitch_indices = [i for i in range(len(filtered_positions)) if filtered_frames[i] > pitch_frame]
|
134 |
+
if not post_pitch_indices:
|
135 |
+
return None, None, None, None, None, None, None, None, None, "Error: No detections after pitch point"
|
136 |
+
impact_idx = max(post_pitch_indices, key=lambda i: filtered_positions[i][1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
137 |
impact_point = filtered_positions[impact_idx]
|
138 |
+
impact_frame = filtered_frames[impact_idx]
|
139 |
|
140 |
try:
|
141 |
# Use linear interpolation for more stable trajectory
|
142 |
+
fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
|
143 |
+
fy = interp1d(times, y_coords, kind='linear', fill_value="extrapolate")
|
144 |
except Exception as e:
|
145 |
return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
|
146 |
|
147 |
# Generate dense points for all frames between first and last detection
|
148 |
total_frames = max(detection_frames) - min(detection_frames) + 1
|
149 |
+
t_full = np.linspace(times[0], times[-1], total_frames * SLOW_MOTION_FACTOR)
|
150 |
x_full = fx(t_full)
|
151 |
y_full = fy(t_full)
|
152 |
trajectory_2d = list(zip(x_full, y_full))
|
|
|
156 |
impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
|
157 |
|
158 |
# Debug trajectory and points
|
159 |
+
debug_log = [
|
160 |
+
f"Trajectory estimated successfully",
|
161 |
+
f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}",
|
162 |
+
f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}",
|
163 |
+
f"Detections in frames: {filtered_frames}",
|
164 |
+
f"Velocities: {velocities}" if 'velocities' in locals() else "Velocities: Not calculated"
|
165 |
+
]
|
166 |
# Save trajectory plot for debugging
|
167 |
import matplotlib.pyplot as plt
|
168 |
plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
|
|
|
171 |
plt.legend()
|
172 |
plt.savefig("trajectory_debug.png")
|
173 |
|
174 |
+
return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "\n".join(debug_log)
|
175 |
|
176 |
def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
|
177 |
"""Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
|
|
|
284 |
frame_idx = i - min_frame if trajectory_indices else -1
|
285 |
if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
|
286 |
end_idx = trajectory_indices[frame_idx] + 1
|
287 |
+
cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2) # Blue line in BGR
|
288 |
if pitch_point and i == pitch_frame:
|
289 |
x, y = pitch_point
|
290 |
+
cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
|
291 |
cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
|
292 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
293 |
if impact_point and i == impact_frame:
|
294 |
x, y = impact_point
|
295 |
+
cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
|
296 |
cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
|
297 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
298 |
+
for _ in range(int(SLOW_MOTION_FACTOR)):
|
299 |
out.write(frame)
|
300 |
out.release()
|
301 |
return output_path
|