AjaykumarPilla commited on
Commit
4057582
·
verified ·
1 Parent(s): ab03275

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -167
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
7
- import plotly.graph_objects as go
8
  import uuid
9
  import os
10
 
@@ -14,15 +13,12 @@ model = YOLO("best.pt")
14
  # Constants for LBW decision and video processing
15
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
16
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
17
- FRAME_RATE = 30 # Input video frame rate
18
- SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
19
- CONF_THRESHOLD = 0.2 # Lowered confidence threshold to improve detection
20
- IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
 
21
  IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
22
- PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
23
- STUMPS_HEIGHT = 0.71 # meters (stumps height)
24
- CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
25
- CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
26
 
27
  def process_video(video_path):
28
  if not os.path.exists(video_path):
@@ -30,7 +26,7 @@ def process_video(video_path):
30
  cap = cv2.VideoCapture(video_path)
31
  frames = []
32
  ball_positions = []
33
- detection_frames = []
34
  debug_log = []
35
 
36
  frame_count = 0
@@ -40,83 +36,85 @@ def process_video(video_path):
40
  break
41
  frame_count += 1
42
  frames.append(frame.copy())
43
- results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=640)
44
- detections = 0
45
- for detection in results[0].boxes:
46
- if detection.cls == 0: # Class 0 is the ball
47
- detections += 1
48
- x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
49
- ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
50
- detection_frames.append(frame_count - 1)
51
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
52
  frames[-1] = frame
53
- debug_log.append(f"Frame {frame_count}: {detections} ball detections")
54
  cap.release()
55
 
56
  if not ball_positions:
57
- debug_log.append("No balls detected in any frame")
58
  else:
59
- debug_log.append(f"Total ball detections: {len(ball_positions)}")
60
 
61
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
62
 
63
- def pixel_to_3d(x, y, frame_height, frame_width):
64
- """Convert 2D pixel coordinates to 3D real-world coordinates."""
65
- x_norm = x / frame_width
66
- y_norm = y / frame_height
67
- x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
68
- y_3d = y_norm * PITCH_LENGTH
69
- z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
70
- return x_3d, y_3d, z_3d
71
-
72
- def estimate_trajectory(ball_positions, frames, detection_frames):
73
  if len(ball_positions) < 2:
74
- return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
75
- frame_height, frame_width = frames[0].shape[:2]
76
 
 
77
  x_coords = [pos[0] for pos in ball_positions]
78
  y_coords = [pos[1] for pos in ball_positions]
79
  times = np.array(detection_frames) / FRAME_RATE
80
 
81
- pitch_point = ball_positions[0]
82
- pitch_frame = detection_frames[0]
 
 
 
 
 
 
83
 
 
84
  impact_idx = None
85
- impact_frame = None
86
  for i in range(1, len(y_coords)):
87
- if y_coords[i] > frame_height * IMPACT_ZONE_Y or abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y:
 
88
  impact_idx = i
89
- impact_frame = detection_frames[i]
90
  break
91
  if impact_idx is None:
92
  impact_idx = len(ball_positions) - 1
93
- impact_frame = detection_frames[-1]
94
  impact_point = ball_positions[impact_idx]
 
 
 
 
 
 
95
 
96
  try:
97
- fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
98
- fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='quadratic', fill_value="extrapolate")
99
  except Exception as e:
100
- return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
 
 
 
101
 
 
102
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
103
  x_full = fx(t_full)
104
  y_full = fy(t_full)
105
- trajectory_2d = list(zip(x_full, y_full))
106
 
107
- trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
108
- detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in ball_positions]
109
- pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
110
- impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
111
 
112
- debug_log = f"Trajectory estimated successfully\nPitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\nImpact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})"
113
- return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
114
-
115
- def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
116
  if not frames:
117
  return "Error: No frames processed", None, None, None
118
- if not trajectory or len(ball_positions) < 2:
119
- return "Not enough data (insufficient ball detections)", None, None, None
120
 
121
  frame_height, frame_width = frames[0].shape[:2]
122
  stumps_x = frame_width / 2
@@ -126,116 +124,52 @@ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
126
  pitch_x, pitch_y = pitch_point
127
  impact_x, impact_y = impact_point
128
 
 
129
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
130
- return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
 
 
131
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
132
- return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
133
- for x, y in trajectory:
 
 
134
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
135
- return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
136
- return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
137
-
138
- def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
139
- """Create 3D Plotly visualization for detections or trajectory."""
140
- stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
141
- stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
142
- stump_z = [0, 0, 0]
143
- stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
144
- bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
145
- bail_y = [PITCH_LENGTH, PITCH_LENGTH]
146
- bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
147
-
148
- stump_traces = []
149
- for i in range(3):
150
- stump_traces.append(go.Scatter3d(
151
- x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
152
- mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
153
- ))
154
- bail_traces = [
155
- go.Scatter3d(
156
- x=bail_x, y=bail_y, z=bail_z,
157
- mode='lines', line=dict(color='black', width=5), name='Bail'
158
- )
159
- ]
160
-
161
- if plot_type == "detections":
162
- x, y, z = zip(*detections_3d) if detections_3d else ([], [], [])
163
- scatter = go.Scatter3d(
164
- x=x, y=y, z=z, mode='markers',
165
- marker=dict(size=5, color='green'), name='Ball Detections'
166
- )
167
- pitch_scatter = go.Scatter3d(
168
- x=[pitch_point_3d[0]] if pitch_point_3d else [],
169
- y=[pitch_point_3d[1]] if pitch_point_3d else [],
170
- z=[pitch_point_3d[2]] if pitch_point_3d else [],
171
- mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
172
- )
173
- impact_scatter = go.Scatter3d(
174
- x=[impact_point_3d[0]] if impact_point_3d else [],
175
- y=[impact_point_3d[1]] if impact_point_3d else [],
176
- z=[impact_point_3d[2]] if impact_point_3d else [],
177
- mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
178
- )
179
- data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
180
- title = "3D Ball Detections"
181
- else:
182
- x, y, z = zip(*trajectory_3d) if trajectory_3d else ([], [], [])
183
- trajectory_line = go.Scatter3d(
184
- x=x, y=y, z=z, mode='lines',
185
- line=dict(color='blue', width=4), name='Ball Trajectory'
186
- )
187
- pitch_scatter = go.Scatter3d(
188
- x=[pitch_point_3d[0]] if pitch_point_3d else [],
189
- y=[pitch_point_3d[1]] if pitch_point_3d else [],
190
- z=[pitch_point_3d[2]] if pitch_point_3d else [],
191
- mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
192
- )
193
- impact_scatter = go.Scatter3d(
194
- x=[impact_point_3d[0]] if impact_point_3d else [],
195
- y=[impact_point_3d[1]] if impact_point_3d else [],
196
- z=[impact_point_3d[2]] if impact_point_3d else [],
197
- mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
198
- )
199
- data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
200
- title = "3D Ball Trajectory"
201
-
202
- layout = go.Layout(
203
- title=title,
204
- scene=dict(
205
- xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
206
- xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
207
- zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
208
- aspectratio=dict(x=1, y=4, z=0.5)
209
- ),
210
- showlegend=True
211
- )
212
- fig = go.Figure(data=data, layout=layout)
213
- return fig
214
-
215
- def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
216
  if not frames:
217
  return None
 
 
218
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
219
- out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
220
 
221
- if trajectory and detection_frames:
222
- trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
223
- else:
224
- trajectory_points = np.array([], dtype=np.int32)
225
 
226
  for i, frame in enumerate(frames):
 
227
  if i in detection_frames and trajectory_points.size > 0:
228
- cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
 
 
 
 
229
  if pitch_point and i == pitch_frame:
230
  x, y = pitch_point
231
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
232
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
233
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
 
 
234
  if impact_point and i == impact_frame:
235
  x, y = impact_point
236
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
237
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
238
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
 
239
  for _ in range(SLOW_MOTION_FACTOR):
240
  out.write(frame)
241
  out.release()
@@ -244,29 +178,15 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
244
  def drs_review(video):
245
  frames, ball_positions, detection_frames, debug_log = process_video(video)
246
  if not frames:
247
- return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
248
-
249
- trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
250
-
251
- if trajectory_2d is None:
252
- return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)
253
-
254
- decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
255
 
256
  output_path = f"output_{uuid.uuid4()}.mp4"
257
- slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
258
-
259
- detections_fig = None
260
- trajectory_fig = None
261
- if detections_3d:
262
- detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
263
- trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
264
 
265
  debug_output = f"{debug_log}\n{trajectory_log}"
266
- return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
267
- slow_motion_path,
268
- detections_fig,
269
- trajectory_fig)
270
 
271
  # Gradio interface
272
  iface = gr.Interface(
@@ -274,12 +194,10 @@ iface = gr.Interface(
274
  inputs=gr.Video(label="Upload Video Clip"),
275
  outputs=[
276
  gr.Textbox(label="DRS Decision and Debug Log"),
277
- gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
278
- gr.Plot(label="3D Ball Detections Plot"),
279
- gr.Plot(label="3D Ball Trajectory Plot")
280
  ],
281
  title="AI-Powered DRS for LBW in Local Cricket",
282
- description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show detections (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow)."
283
  )
284
 
285
  if __name__ == "__main__":
 
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
 
7
  import uuid
8
  import os
9
 
 
13
  # Constants for LBW decision and video processing
14
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
+ FRAME_RATE = 20 # Input video frame rate
17
+ SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS
18
+ CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
+ IMPACT_ZONE_Y = 0.85 # Fraction of frame height for impact zone
20
+ PITCH_ZONE_Y = 0.75 # Fraction of frame height for pitch zone
21
  IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
 
 
 
 
22
 
23
  def process_video(video_path):
24
  if not os.path.exists(video_path):
 
26
  cap = cv2.VideoCapture(video_path)
27
  frames = []
28
  ball_positions = []
29
+ detection_frames = [] # Track frames with exactly one detection
30
  debug_log = []
31
 
32
  frame_count = 0
 
36
  break
37
  frame_count += 1
38
  frames.append(frame.copy())
39
+ results = model.predict(frame, conf=CONF_THRESHOLD)
40
+ detections = [det for det in results[0].boxes if det.cls == 0] # Class 0 is cricketBall
41
+ if len(detections) == 1: # Only consider frames with exactly one detection
42
+ x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
43
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
44
+ detection_frames.append(frame_count - 1) # 0-based index
45
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
 
 
46
  frames[-1] = frame
47
+ debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections")
48
  cap.release()
49
 
50
  if not ball_positions:
51
+ debug_log.append("No valid single-ball detections in any frame")
52
  else:
53
+ debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")
54
 
55
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
56
 
57
+ def estimate_trajectory(ball_positions, detection_frames, frames):
 
 
 
 
 
 
 
 
 
58
  if len(ball_positions) < 2:
59
+ return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
60
+ frame_height = frames[0].shape[0]
61
 
62
+ # Extract x, y coordinates
63
  x_coords = [pos[0] for pos in ball_positions]
64
  y_coords = [pos[1] for pos in ball_positions]
65
  times = np.array(detection_frames) / FRAME_RATE
66
 
67
+ # Pitch point: first valid detection or when y exceeds PITCH_ZONE_Y
68
+ pitch_idx = 0
69
+ for i, y in enumerate(y_coords):
70
+ if y > frame_height * PITCH_ZONE_Y:
71
+ pitch_idx = i
72
+ break
73
+ pitch_point = ball_positions[pitch_idx]
74
+ pitch_frame = detection_frames[pitch_idx]
75
 
76
+ # Impact point: sudden y-change or y exceeds IMPACT_ZONE_Y
77
  impact_idx = None
 
78
  for i in range(1, len(y_coords)):
79
+ if (y_coords[i] > frame_height * IMPACT_ZONE_Y or
80
+ abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y):
81
  impact_idx = i
 
82
  break
83
  if impact_idx is None:
84
  impact_idx = len(ball_positions) - 1
 
85
  impact_point = ball_positions[impact_idx]
86
+ impact_frame = detection_frames[impact_idx]
87
+
88
+ # Use only detected positions for trajectory
89
+ x_coords = x_coords[:impact_idx + 1]
90
+ y_coords = y_coords[:impact_idx + 1]
91
+ times = times[:impact_idx + 1]
92
 
93
  try:
94
+ fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
95
+ fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
96
  except Exception as e:
97
+ return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
98
+
99
+ # Trajectory for visualization (detected frames only)
100
+ vis_trajectory = list(zip(x_coords, y_coords))
101
 
102
+ # Full trajectory for LBW (includes projection)
103
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
104
  x_full = fx(t_full)
105
  y_full = fy(t_full)
106
+ full_trajectory = list(zip(x_full, y_full))
107
 
108
+ debug_log = (f"Trajectory estimated successfully\n"
109
+ f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
110
+ f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})")
111
+ return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, debug_log
112
 
113
+ def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
 
 
 
114
  if not frames:
115
  return "Error: No frames processed", None, None, None
116
+ if not full_trajectory or len(ball_positions) < 2:
117
+ return "Not enough data (insufficient valid single-ball detections)", None, None, None
118
 
119
  frame_height, frame_width = frames[0].shape[:2]
120
  stumps_x = frame_width / 2
 
124
  pitch_x, pitch_y = pitch_point
125
  impact_x, impact_y = impact_point
126
 
127
+ # Check pitching point
128
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
129
+ return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", full_trajectory, pitch_point, impact_point
130
+
131
+ # Check impact point
132
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
133
+ return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
134
+
135
+ # Check trajectory hitting stumps
136
+ for x, y in full_trajectory:
137
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
138
+ return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
139
+ return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
140
+
141
+ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  if not frames:
143
  return None
144
+ frame_height, frame_width = frames[0].shape[:2]
145
+
146
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
147
+ out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
148
 
149
+ # Prepare trajectory points for visualization
150
+ trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
 
 
151
 
152
  for i, frame in enumerate(frames):
153
+ # Draw trajectory (blue line) only for detected frames
154
  if i in detection_frames and trajectory_points.size > 0:
155
+ idx = detection_frames.index(i) + 1
156
+ if idx <= len(trajectory_points):
157
+ cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
158
+
159
+ # Draw pitch point (red circle) only in pitch frame
160
  if pitch_point and i == pitch_frame:
161
  x, y = pitch_point
162
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
163
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
164
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
165
+
166
+ # Draw impact point (yellow circle) only in impact frame
167
  if impact_point and i == impact_frame:
168
  x, y = impact_point
169
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
170
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
171
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
172
+
173
  for _ in range(SLOW_MOTION_FACTOR):
174
  out.write(frame)
175
  out.release()
 
178
  def drs_review(video):
179
  frames, ball_positions, detection_frames, debug_log = process_video(video)
180
  if not frames:
181
+ return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
182
+ full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, trajectory_log = estimate_trajectory(ball_positions, detection_frames, frames)
183
+ decision, full_trajectory, pitch_point, impact_point = lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point)
 
 
 
 
 
184
 
185
  output_path = f"output_{uuid.uuid4()}.mp4"
186
+ slow_motion_path = generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path)
 
 
 
 
 
 
187
 
188
  debug_output = f"{debug_log}\n{trajectory_log}"
189
+ return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
 
 
 
190
 
191
  # Gradio interface
192
  iface = gr.Interface(
 
194
  inputs=gr.Video(label="Upload Video Clip"),
195
  outputs=[
196
  gr.Textbox(label="DRS Decision and Debug Log"),
197
+ gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)")
 
 
198
  ],
199
  title="AI-Powered DRS for LBW in Local Cricket",
200
+ description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."
201
  )
202
 
203
  if __name__ == "__main__":