Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,7 @@ from scipy.interpolate import interp1d
|
|
7 |
import plotly.graph_objects as go
|
8 |
import uuid
|
9 |
import os
|
|
|
10 |
|
11 |
# Load the trained YOLOv8n model with optimizations
|
12 |
model = YOLO("best.pt")
|
@@ -92,9 +93,6 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
92 |
if distance <= MAX_POSITION_JUMP:
|
93 |
filtered_positions.append(curr_pos)
|
94 |
filtered_frames.append(detection_frames[i])
|
95 |
-
else:
|
96 |
-
# Skip sudden jumps to maintain continuity
|
97 |
-
continue
|
98 |
|
99 |
if len(filtered_positions) < 2:
|
100 |
return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
|
@@ -106,7 +104,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
106 |
pitch_point = filtered_positions[0]
|
107 |
pitch_frame = filtered_frames[0]
|
108 |
|
109 |
-
#
|
110 |
impact_idx = None
|
111 |
impact_frame = None
|
112 |
for i in range(1, len(y_coords)):
|
@@ -125,17 +123,16 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
|
|
125 |
impact_point = filtered_positions[impact_idx]
|
126 |
|
127 |
try:
|
128 |
-
#
|
129 |
-
fx = interp1d(times
|
130 |
-
fy = interp1d(times
|
|
|
|
|
|
|
|
|
131 |
except Exception as e:
|
132 |
return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
|
133 |
|
134 |
-
t_full = np.linspace(times[0], times[-1], len(times) * 4) # Dense points for smooth trajectory
|
135 |
-
x_full = fx(t_full)
|
136 |
-
y_full = fy(t_full)
|
137 |
-
trajectory_2d = list(zip(x_full, y_full))
|
138 |
-
|
139 |
trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
|
140 |
detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
|
141 |
pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
|
@@ -249,6 +246,89 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
|
|
249 |
fig = go.Figure(data=data, layout=layout)
|
250 |
return fig
|
251 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
|
253 |
if not frames:
|
254 |
return None
|
@@ -282,17 +362,20 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
|
|
282 |
def drs_review(video):
|
283 |
frames, ball_positions, detection_frames, debug_log = process_video(video)
|
284 |
if not frames:
|
285 |
-
return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
|
286 |
|
287 |
trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
|
288 |
|
289 |
if trajectory_2d is None:
|
290 |
-
return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)
|
291 |
|
292 |
decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
|
293 |
|
294 |
-
|
295 |
-
slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame,
|
|
|
|
|
|
|
296 |
|
297 |
detections_fig = None
|
298 |
trajectory_fig = None
|
@@ -303,6 +386,7 @@ def drs_review(video):
|
|
303 |
debug_output = f"{debug_log}\n{trajectory_log}"
|
304 |
return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
|
305 |
slow_motion_path,
|
|
|
306 |
detections_fig,
|
307 |
trajectory_fig)
|
308 |
|
@@ -312,12 +396,13 @@ iface = gr.Interface(
|
|
312 |
inputs=gr.Video(label="Upload Video Clip"),
|
313 |
outputs=[
|
314 |
gr.Textbox(label="DRS Decision and Debug Log"),
|
315 |
-
gr.Video(label="
|
|
|
316 |
gr.Plot(label="3D Ball Detections Plot"),
|
317 |
gr.Plot(label="3D Ball Trajectory Plot")
|
318 |
],
|
319 |
title="AI-Powered DRS for LBW in Local Cricket",
|
320 |
-
description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show
|
321 |
)
|
322 |
|
323 |
if __name__ == "__main__":
|
|
|
7 |
import plotly.graph_objects as go
|
8 |
import uuid
|
9 |
import os
|
10 |
+
import plotly.io as pio
|
11 |
|
12 |
# Load the trained YOLOv8n model with optimizations
|
13 |
model = YOLO("best.pt")
|
|
|
93 |
if distance <= MAX_POSITION_JUMP:
|
94 |
filtered_positions.append(curr_pos)
|
95 |
filtered_frames.append(detection_frames[i])
|
|
|
|
|
|
|
96 |
|
97 |
if len(filtered_positions) < 2:
|
98 |
return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
|
|
|
104 |
pitch_point = filtered_positions[0]
|
105 |
pitch_frame = filtered_frames[0]
|
106 |
|
107 |
+
# Detect impact point
|
108 |
impact_idx = None
|
109 |
impact_frame = None
|
110 |
for i in range(1, len(y_coords)):
|
|
|
123 |
impact_point = filtered_positions[impact_idx]
|
124 |
|
125 |
try:
|
126 |
+
# Cubic interpolation for smooth trajectory
|
127 |
+
fx = interp1d(times, x_coords, kind='cubic', fill_value="extrapolate")
|
128 |
+
fy = interp1d(times, y_coords, kind='cubic', fill_value="extrapolate")
|
129 |
+
t_full = np.linspace(times[0], times[-1], len(times) * 4)
|
130 |
+
x_full = fx(t_full)
|
131 |
+
y_full = fy(t_full)
|
132 |
+
trajectory_2d = list(zip(x_full, y_full))
|
133 |
except Exception as e:
|
134 |
return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
|
135 |
|
|
|
|
|
|
|
|
|
|
|
136 |
trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
|
137 |
detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
|
138 |
pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
|
|
|
246 |
fig = go.Figure(data=data, layout=layout)
|
247 |
return fig
|
248 |
|
249 |
+
def generate_3d_trajectory_video(trajectory_3d, pitch_point_3d, impact_point_3d, detection_frames, pitch_frame, impact_frame, output_path):
|
250 |
+
if not trajectory_3d or not detection_frames:
|
251 |
+
return None
|
252 |
+
|
253 |
+
# Define video parameters
|
254 |
+
frame_width = 1280
|
255 |
+
frame_height = 720
|
256 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
257 |
+
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
|
258 |
+
|
259 |
+
# Wicket lines (stumps and bails)
|
260 |
+
stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
|
261 |
+
stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
|
262 |
+
stump_z = [0, 0, 0]
|
263 |
+
stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
|
264 |
+
bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
|
265 |
+
bail_y = [PITCH_LENGTH, PITCH_LENGTH]
|
266 |
+
bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
|
267 |
+
|
268 |
+
stump_traces = []
|
269 |
+
for i in range(3):
|
270 |
+
stump_traces.append(go.Scatter3d(
|
271 |
+
x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
|
272 |
+
mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
|
273 |
+
))
|
274 |
+
bail_traces = [
|
275 |
+
go.Scatter3d(
|
276 |
+
x=bail_x, y=bail_y, z=bail_z,
|
277 |
+
mode='lines', line=dict(color='black', width=5), name='Bail'
|
278 |
+
)
|
279 |
+
]
|
280 |
+
|
281 |
+
# Generate frames for each detection
|
282 |
+
for i, frame_idx in enumerate(detection_frames):
|
283 |
+
# Trajectory up to current frame
|
284 |
+
traj_idx = min(i * 4, len(trajectory_3d) - 1) # Match 2D trajectory density
|
285 |
+
x, y, z = zip(*trajectory_3d[:traj_idx + 1]) if trajectory_3d else ([], [], [])
|
286 |
+
trajectory_line = go.Scatter3d(
|
287 |
+
x=x, y=y, z=z, mode='lines',
|
288 |
+
line=dict(color='blue', width=4), name='Ball Trajectory'
|
289 |
+
)
|
290 |
+
|
291 |
+
# Pitch point (red marker) if at or after pitch_frame
|
292 |
+
pitch_scatter = go.Scatter3d(
|
293 |
+
x=[pitch_point_3d[0]] if frame_idx >= pitch_frame else [],
|
294 |
+
y=[pitch_point_3d[1]] if frame_idx >= pitch_frame else [],
|
295 |
+
z=[pitch_point_3d[2]] if frame_idx >= pitch_frame else [],
|
296 |
+
mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
|
297 |
+
)
|
298 |
+
|
299 |
+
# Impact point (yellow marker) if at or after impact_frame
|
300 |
+
impact_scatter = go.Scatter3d(
|
301 |
+
x=[impact_point_3d[0]] if frame_idx >= impact_frame else [],
|
302 |
+
y=[impact_point_3d[1]] if frame_idx >= impact_frame else [],
|
303 |
+
z=[impact_point_3d[2]] if frame_idx >= impact_frame else [],
|
304 |
+
mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
|
305 |
+
)
|
306 |
+
|
307 |
+
# Create frame
|
308 |
+
fig = go.Figure(
|
309 |
+
data=[trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces,
|
310 |
+
layout=go.Layout(
|
311 |
+
scene=dict(
|
312 |
+
xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
|
313 |
+
xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
|
314 |
+
zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
|
315 |
+
aspectratio=dict(x=1, y=4, z=0.5)
|
316 |
+
),
|
317 |
+
showlegend=True
|
318 |
+
)
|
319 |
+
)
|
320 |
+
|
321 |
+
# Render frame to image
|
322 |
+
img_bytes = pio.to_image(fig, format='png', width=frame_width, height=frame_height)
|
323 |
+
img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
|
324 |
+
|
325 |
+
# Write frame multiple times for slow motion
|
326 |
+
for _ in range(SLOW_MOTION_FACTOR):
|
327 |
+
out.write(img)
|
328 |
+
|
329 |
+
out.release()
|
330 |
+
return output_path
|
331 |
+
|
332 |
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
|
333 |
if not frames:
|
334 |
return None
|
|
|
362 |
def drs_review(video):
|
363 |
frames, ball_positions, detection_frames, debug_log = process_video(video)
|
364 |
if not frames:
|
365 |
+
return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None, None
|
366 |
|
367 |
trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
|
368 |
|
369 |
if trajectory_2d is None:
|
370 |
+
return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None, None)
|
371 |
|
372 |
decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
|
373 |
|
374 |
+
output_path_2d = f"output_2d_{uuid.uuid4()}.mp4"
|
375 |
+
slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path_2d)
|
376 |
+
|
377 |
+
output_path_3d = f"output_3d_{uuid.uuid4()}.mp4"
|
378 |
+
trajectory_video_path = generate_3d_trajectory_video(trajectory_3d, pitch_point_3d, impact_point_3d, detection_frames, pitch_frame, impact_frame, output_path_3d)
|
379 |
|
380 |
detections_fig = None
|
381 |
trajectory_fig = None
|
|
|
386 |
debug_output = f"{debug_log}\n{trajectory_log}"
|
387 |
return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
|
388 |
slow_motion_path,
|
389 |
+
trajectory_video_path,
|
390 |
detections_fig,
|
391 |
trajectory_fig)
|
392 |
|
|
|
396 |
inputs=gr.Video(label="Upload Video Clip"),
|
397 |
outputs=[
|
398 |
gr.Textbox(label="DRS Decision and Debug Log"),
|
399 |
+
gr.Video(label="2D Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
|
400 |
+
gr.Video(label="3D Trajectory Video with Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow), Wicket Lines (Black)"),
|
401 |
gr.Plot(label="3D Ball Detections Plot"),
|
402 |
gr.Plot(label="3D Ball Trajectory Plot")
|
403 |
],
|
404 |
title="AI-Powered DRS for LBW in Local Cricket",
|
405 |
+
description="Upload a video clip of a cricket delivery to get an LBW decision, a 2D slow-motion replay, a 3D trajectory video, and 3D visualizations. The 2D replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D video and plots show the trajectory (blue line), detections (green markers), pitch point (red), impact point (yellow), and wicket lines (black)."
|
406 |
)
|
407 |
|
408 |
if __name__ == "__main__":
|