Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -4,6 +4,7 @@ import torch
|
|
4 |
from ultralytics import YOLO
|
5 |
import gradio as gr
|
6 |
from scipy.interpolate import interp1d
|
|
|
7 |
import uuid
|
8 |
import os
|
9 |
|
@@ -13,13 +14,15 @@ model = YOLO("best.pt")
|
|
13 |
# Constants for LBW decision and video processing
|
14 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
15 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
16 |
-
FRAME_RATE =
|
17 |
-
SLOW_MOTION_FACTOR =
|
18 |
CONF_THRESHOLD = 0.25 # Confidence threshold for detection
|
19 |
-
IMPACT_ZONE_Y = 0.
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
23 |
|
24 |
def process_video(video_path):
|
25 |
if not os.path.exists(video_path):
|
@@ -27,7 +30,7 @@ def process_video(video_path):
|
|
27 |
cap = cv2.VideoCapture(video_path)
|
28 |
frames = []
|
29 |
ball_positions = []
|
30 |
-
detection_frames = []
|
31 |
debug_log = []
|
32 |
|
33 |
frame_count = 0
|
@@ -38,96 +41,82 @@ def process_video(video_path):
|
|
38 |
frame_count += 1
|
39 |
frames.append(frame.copy())
|
40 |
results = model.predict(frame, conf=CONF_THRESHOLD)
|
41 |
-
detections =
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
debug_log.append(f"Frame {frame_count}: 1 ball detection at (x: {x_center:.1f}, y: {y_center:.1f})")
|
50 |
-
else:
|
51 |
-
debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections (ignored)")
|
52 |
frames[-1] = frame
|
|
|
53 |
cap.release()
|
54 |
|
55 |
if not ball_positions:
|
56 |
-
debug_log.append("No
|
57 |
else:
|
58 |
-
debug_log.append(f"Total
|
59 |
|
60 |
return frames, ball_positions, detection_frames, "\n".join(debug_log)
|
61 |
|
62 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
if len(ball_positions) < 2:
|
64 |
-
return None, None, None, None, None,
|
65 |
-
frame_height = frames[0].shape[
|
66 |
|
67 |
-
# Extract x, y coordinates
|
68 |
x_coords = [pos[0] for pos in ball_positions]
|
69 |
y_coords = [pos[1] for pos in ball_positions]
|
70 |
times = np.array(detection_frames) / FRAME_RATE
|
71 |
|
72 |
-
|
73 |
-
|
74 |
-
for i, y in enumerate(y_coords):
|
75 |
-
if y > frame_height * PITCH_ZONE_Y:
|
76 |
-
pitch_idx = i
|
77 |
-
break
|
78 |
-
else:
|
79 |
-
debug_log = "Warning: No pitch point detected (y never exceeds PITCH_ZONE_Y), using first detection"
|
80 |
-
pitch_idx = 0
|
81 |
-
pitch_point = ball_positions[pitch_idx]
|
82 |
-
pitch_frame = detection_frames[pitch_idx]
|
83 |
|
84 |
-
# Impact point: sudden y-change, x stability, or y exceeds IMPACT_ZONE_Y
|
85 |
impact_idx = None
|
|
|
86 |
for i in range(1, len(y_coords)):
|
87 |
-
|
88 |
-
delta_x = abs(x_coords[i] - x_coords[i-1])
|
89 |
-
if (y_coords[i] > frame_height * IMPACT_ZONE_Y or
|
90 |
-
(delta_y > IMPACT_DELTA_Y and delta_x < IMPACT_DELTA_X)):
|
91 |
impact_idx = i
|
|
|
92 |
break
|
93 |
if impact_idx is None:
|
94 |
impact_idx = len(ball_positions) - 1
|
95 |
-
|
96 |
-
else:
|
97 |
-
debug_log = ""
|
98 |
impact_point = ball_positions[impact_idx]
|
99 |
-
impact_frame = detection_frames[impact_idx]
|
100 |
-
|
101 |
-
# Use only detected positions up to impact for trajectory
|
102 |
-
x_coords = x_coords[:impact_idx + 1]
|
103 |
-
y_coords = y_coords[:impact_idx + 1]
|
104 |
-
times = times[:impact_idx + 1]
|
105 |
|
106 |
try:
|
107 |
-
fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
|
108 |
-
fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
|
109 |
except Exception as e:
|
110 |
-
return None, None, None, None, None,
|
111 |
-
|
112 |
-
# Trajectory for visualization (detected frames only)
|
113 |
-
vis_trajectory = list(zip(x_coords, y_coords))
|
114 |
|
115 |
-
# Full trajectory for LBW (includes projection)
|
116 |
t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
|
117 |
x_full = fx(t_full)
|
118 |
y_full = fy(t_full)
|
119 |
-
|
120 |
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
|
126 |
-
|
|
|
|
|
|
|
127 |
if not frames:
|
128 |
return "Error: No frames processed", None, None, None
|
129 |
-
if not
|
130 |
-
return "Not enough data (insufficient
|
131 |
|
132 |
frame_height, frame_width = frames[0].shape[:2]
|
133 |
stumps_x = frame_width / 2
|
@@ -137,52 +126,111 @@ def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_po
|
|
137 |
pitch_x, pitch_y = pitch_point
|
138 |
impact_x, impact_y = impact_point
|
139 |
|
140 |
-
# Check pitching point
|
141 |
if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
|
142 |
-
return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})",
|
143 |
-
|
144 |
-
# Check impact point
|
145 |
if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
|
146 |
-
return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})",
|
147 |
-
|
148 |
-
# Check trajectory hitting stumps
|
149 |
-
for x, y in full_trajectory:
|
150 |
if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
|
151 |
-
return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})",
|
152 |
-
return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})",
|
153 |
-
|
154 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
155 |
if not frames:
|
156 |
return None
|
157 |
-
frame_height, frame_width = frames[0].shape[:2]
|
158 |
-
|
159 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
160 |
-
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (
|
161 |
|
162 |
-
|
163 |
-
trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
|
164 |
|
165 |
for i, frame in enumerate(frames):
|
166 |
-
# Draw trajectory (blue line) only for detected frames
|
167 |
if i in detection_frames and trajectory_points.size > 0:
|
168 |
-
|
169 |
-
if idx <= len(trajectory_points):
|
170 |
-
cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
|
171 |
-
|
172 |
-
# Draw pitch point (red circle) only in pitch frame
|
173 |
if pitch_point and i == pitch_frame:
|
174 |
x, y = pitch_point
|
175 |
cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
|
176 |
cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
|
177 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
178 |
-
|
179 |
-
# Draw impact point (yellow circle) only in impact frame
|
180 |
if impact_point and i == impact_frame:
|
181 |
x, y = impact_point
|
182 |
cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
|
183 |
cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
|
184 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
185 |
-
|
186 |
for _ in range(SLOW_MOTION_FACTOR):
|
187 |
out.write(frame)
|
188 |
out.release()
|
@@ -191,15 +239,29 @@ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impac
|
|
191 |
def drs_review(video):
|
192 |
frames, ball_positions, detection_frames, debug_log = process_video(video)
|
193 |
if not frames:
|
194 |
-
return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
|
195 |
-
|
196 |
-
decision,
|
197 |
|
198 |
output_path = f"output_{uuid.uuid4()}.mp4"
|
199 |
-
slow_motion_path = generate_slow_motion(frames,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
200 |
|
201 |
debug_output = f"{debug_log}\n{trajectory_log}"
|
202 |
-
return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
|
|
|
|
|
|
|
203 |
|
204 |
# Gradio interface
|
205 |
iface = gr.Interface(
|
@@ -207,10 +269,12 @@ iface = gr.Interface(
|
|
207 |
inputs=gr.Video(label="Upload Video Clip"),
|
208 |
outputs=[
|
209 |
gr.Textbox(label="DRS Decision and Debug Log"),
|
210 |
-
gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)")
|
|
|
|
|
211 |
],
|
212 |
title="AI-Powered DRS for LBW in Local Cricket",
|
213 |
-
description="Upload a video clip of a cricket delivery to get an LBW decision
|
214 |
)
|
215 |
|
216 |
if __name__ == "__main__":
|
|
|
4 |
from ultralytics import YOLO
|
5 |
import gradio as gr
|
6 |
from scipy.interpolate import interp1d
|
7 |
+
import plotly.graph_objects as go
|
8 |
import uuid
|
9 |
import os
|
10 |
|
|
|
14 |
# Constants for LBW decision and video processing
|
15 |
STUMPS_WIDTH = 0.2286 # meters (width of stumps)
|
16 |
BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
|
17 |
+
FRAME_RATE = 30 # Input video frame rate
|
18 |
+
SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
|
19 |
CONF_THRESHOLD = 0.25 # Confidence threshold for detection
|
20 |
+
IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
|
21 |
+
IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
|
22 |
+
PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
|
23 |
+
STUMPS_HEIGHT = 0.71 # meters (stumps height)
|
24 |
+
CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
|
25 |
+
CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
|
26 |
|
27 |
def process_video(video_path):
|
28 |
if not os.path.exists(video_path):
|
|
|
30 |
cap = cv2.VideoCapture(video_path)
|
31 |
frames = []
|
32 |
ball_positions = []
|
33 |
+
detection_frames = []
|
34 |
debug_log = []
|
35 |
|
36 |
frame_count = 0
|
|
|
41 |
frame_count += 1
|
42 |
frames.append(frame.copy())
|
43 |
results = model.predict(frame, conf=CONF_THRESHOLD)
|
44 |
+
detections = 0
|
45 |
+
for detection in results[0].boxes:
|
46 |
+
if detection.cls == 0: # Class 0 is the ball
|
47 |
+
detections += 1
|
48 |
+
x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
|
49 |
+
ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
|
50 |
+
detection_frames.append(frame_count - 1)
|
51 |
+
cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
|
|
|
|
|
|
|
52 |
frames[-1] = frame
|
53 |
+
debug_log.append(f"Frame {frame_count}: {detections} ball detections")
|
54 |
cap.release()
|
55 |
|
56 |
if not ball_positions:
|
57 |
+
debug_log.append("No balls detected in any frame")
|
58 |
else:
|
59 |
+
debug_log.append(f"Total ball detections: {len(ball_positions)}")
|
60 |
|
61 |
return frames, ball_positions, detection_frames, "\n".join(debug_log)
|
62 |
|
63 |
+
def pixel_to_3d(x, y, frame_height, frame_width):
|
64 |
+
"""Convert 2D pixel coordinates to 3D real-world coordinates."""
|
65 |
+
x_norm = x / frame_width
|
66 |
+
y_norm = y / frame_height
|
67 |
+
x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
|
68 |
+
y_3d = y_norm * PITCH_LENGTH
|
69 |
+
z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
|
70 |
+
return x_3d, y_3d, z_3d
|
71 |
+
|
72 |
+
def estimate_trajectory(ball_positions, frames, detection_frames):
|
73 |
if len(ball_positions) < 2:
|
74 |
+
return None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
|
75 |
+
frame_height, frame_width = frames[0].shape[:2]
|
76 |
|
|
|
77 |
x_coords = [pos[0] for pos in ball_positions]
|
78 |
y_coords = [pos[1] for pos in ball_positions]
|
79 |
times = np.array(detection_frames) / FRAME_RATE
|
80 |
|
81 |
+
pitch_point = ball_positions[0]
|
82 |
+
pitch_frame = detection_frames[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
83 |
|
|
|
84 |
impact_idx = None
|
85 |
+
impact_frame = None
|
86 |
for i in range(1, len(y_coords)):
|
87 |
+
if y_coords[i] > frame_height * IMPACT_ZONE_Y or abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y:
|
|
|
|
|
|
|
88 |
impact_idx = i
|
89 |
+
impact_frame = detection_frames[i]
|
90 |
break
|
91 |
if impact_idx is None:
|
92 |
impact_idx = len(ball_positions) - 1
|
93 |
+
impact_frame = detection_frames[-1]
|
|
|
|
|
94 |
impact_point = ball_positions[impact_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
|
96 |
try:
|
97 |
+
fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
|
98 |
+
fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='quadratic', fill_value="extrapolate")
|
99 |
except Exception as e:
|
100 |
+
return None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
|
|
|
|
|
|
|
101 |
|
|
|
102 |
t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
|
103 |
x_full = fx(t_full)
|
104 |
y_full = fy(t_full)
|
105 |
+
trajectory_2d = list(zip(x_full, y_full))
|
106 |
|
107 |
+
trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
|
108 |
+
detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in ball_positions]
|
109 |
+
pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
|
110 |
+
impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
|
111 |
|
112 |
+
debug_log = f"Trajectory estimated successfully\nPitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\nImpact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})"
|
113 |
+
return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
|
114 |
+
|
115 |
+
def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
|
116 |
if not frames:
|
117 |
return "Error: No frames processed", None, None, None
|
118 |
+
if not trajectory or len(ball_positions) < 2:
|
119 |
+
return "Not enough data (insufficient ball detections)", None, None, None
|
120 |
|
121 |
frame_height, frame_width = frames[0].shape[:2]
|
122 |
stumps_x = frame_width / 2
|
|
|
126 |
pitch_x, pitch_y = pitch_point
|
127 |
impact_x, impact_y = impact_point
|
128 |
|
|
|
129 |
if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
|
130 |
+
return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
|
|
|
|
|
131 |
if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
|
132 |
+
return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
|
133 |
+
for x, y in trajectory:
|
|
|
|
|
134 |
if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
|
135 |
+
return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
|
136 |
+
return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
|
137 |
+
|
138 |
+
def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
|
139 |
+
"""Create 3D Plotly visualization for detections or trajectory."""
|
140 |
+
# Wicket lines (stumps and bails)
|
141 |
+
stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
|
142 |
+
stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
|
143 |
+
stump_z = [0, 0, 0]
|
144 |
+
stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
|
145 |
+
bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
|
146 |
+
bail_y = [PITCH_LENGTH, PITCH_LENGTH]
|
147 |
+
bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
|
148 |
+
|
149 |
+
# Stumps (three vertical lines)
|
150 |
+
stump_traces = []
|
151 |
+
for i in range(3):
|
152 |
+
stump_traces.append(go.Scatter3d(
|
153 |
+
x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
|
154 |
+
mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
|
155 |
+
))
|
156 |
+
# Bails (two horizontal lines)
|
157 |
+
bail_traces = [
|
158 |
+
go.Scatter3d(
|
159 |
+
x=bail_x, y=bail_y, z=bail_z,
|
160 |
+
mode='lines', line=dict(color='black', width=5), name='Bail'
|
161 |
+
)
|
162 |
+
]
|
163 |
+
|
164 |
+
if plot_type == "detections":
|
165 |
+
# Ball detections plot
|
166 |
+
x, y, z = zip(*detections_3d)
|
167 |
+
scatter = go.Scatter3d(
|
168 |
+
x=x, y=y, z=z, mode='markers',
|
169 |
+
marker=dict(size=5, color='green'), name='Ball Detections'
|
170 |
+
)
|
171 |
+
# Pitch and impact points
|
172 |
+
pitch_scatter = go.Scatter3d(
|
173 |
+
x=[pitch_point_3d[0]], y=[pitch_point_3d[1]], z=[pitch_point_3d[2]],
|
174 |
+
mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
|
175 |
+
)
|
176 |
+
impact_scatter = go.Scatter3d(
|
177 |
+
x=[impact_point_3d[0]], y=[impact_point_3d[1]], z=[impact_point_3d[2]],
|
178 |
+
mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
|
179 |
+
)
|
180 |
+
data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
|
181 |
+
title = "3D Ball Detections"
|
182 |
+
else:
|
183 |
+
# Trajectory plot
|
184 |
+
x, y, z = zip(*trajectory_3d)
|
185 |
+
trajectory_line = go.Scatter3d(
|
186 |
+
x=x, y=y, z=z, mode='lines',
|
187 |
+
line=dict(color='blue', width=4), name='Ball Trajectory'
|
188 |
+
)
|
189 |
+
pitch_scatter = go.Scatter3d(
|
190 |
+
x=[pitch_point_3d[0]], y=[pitch_point_3d[1]], z=[pitch_point_3d[2]],
|
191 |
+
mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
|
192 |
+
)
|
193 |
+
impact_scatter = go.Scatter3d(
|
194 |
+
x=[impact_point_3d[0]], y=[impact_point_3d[1]], z=[impact_point_3d[2]],
|
195 |
+
mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
|
196 |
+
)
|
197 |
+
data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
|
198 |
+
title = "3D Ball Trajectory"
|
199 |
+
|
200 |
+
layout = go.Layout(
|
201 |
+
title=title,
|
202 |
+
scene=dict(
|
203 |
+
xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
|
204 |
+
xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
|
205 |
+
zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
|
206 |
+
aspectratio=dict(x=1, y=4, z=0.5)
|
207 |
+
),
|
208 |
+
showlegend=True
|
209 |
+
)
|
210 |
+
fig = go.Figure(data=data, layout=layout)
|
211 |
+
return fig
|
212 |
+
|
213 |
+
def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
|
214 |
if not frames:
|
215 |
return None
|
|
|
|
|
216 |
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
217 |
+
out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
|
218 |
|
219 |
+
trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
|
|
|
220 |
|
221 |
for i, frame in enumerate(frames):
|
|
|
222 |
if i in detection_frames and trajectory_points.size > 0:
|
223 |
+
cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
|
|
|
|
|
|
|
|
|
224 |
if pitch_point and i == pitch_frame:
|
225 |
x, y = pitch_point
|
226 |
cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
|
227 |
cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
|
228 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
|
|
|
|
|
229 |
if impact_point and i == impact_frame:
|
230 |
x, y = impact_point
|
231 |
cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
|
232 |
cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
|
233 |
cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
|
|
|
234 |
for _ in range(SLOW_MOTION_FACTOR):
|
235 |
out.write(frame)
|
236 |
out.release()
|
|
|
239 |
def drs_review(video):
|
240 |
frames, ball_positions, detection_frames, debug_log = process_video(video)
|
241 |
if not frames:
|
242 |
+
return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
|
243 |
+
trajectory, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
|
244 |
+
decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
|
245 |
|
246 |
output_path = f"output_{uuid.uuid4()}.mp4"
|
247 |
+
slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
|
248 |
+
|
249 |
+
detections_plot_path = f"detections_3d_{uuid.uuid4()}.html"
|
250 |
+
trajectory_plot_path = f"trajectory_3d_{uuid.uuid4()}.html"
|
251 |
+
if detections_3d:
|
252 |
+
detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
|
253 |
+
detections_fig.write_html(detections_plot_path)
|
254 |
+
trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
|
255 |
+
trajectory_fig.write_html(trajectory_plot_path)
|
256 |
+
else:
|
257 |
+
detections_plot_path = None
|
258 |
+
trajectory_plot_path = None
|
259 |
|
260 |
debug_output = f"{debug_log}\n{trajectory_log}"
|
261 |
+
return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
|
262 |
+
slow_motion_path,
|
263 |
+
detections_plot_path,
|
264 |
+
trajectory_plot_path)
|
265 |
|
266 |
# Gradio interface
|
267 |
iface = gr.Interface(
|
|
|
269 |
inputs=gr.Video(label="Upload Video Clip"),
|
270 |
outputs=[
|
271 |
gr.Textbox(label="DRS Decision and Debug Log"),
|
272 |
+
gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
|
273 |
+
gr.File(label="3D Ball Detections Plot (HTML)"),
|
274 |
+
gr.File(label="3D Ball Trajectory Plot (HTML)")
|
275 |
],
|
276 |
title="AI-Powered DRS for LBW in Local Cricket",
|
277 |
+
description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show detections and trajectory with wicket lines."
|
278 |
)
|
279 |
|
280 |
if __name__ == "__main__":
|