AjaykumarPilla committed on
Commit 37acdc4 · verified · 1 Parent(s): 004ee50

Update app.py

Files changed (1): app.py (+161 -97)
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from ultralytics import YOLO
 import gradio as gr
 from scipy.interpolate import interp1d
+import plotly.graph_objects as go
 import uuid
 import os
 
@@ -13,13 +14,15 @@ model = YOLO("best.pt")
 # Constants for LBW decision and video processing
 STUMPS_WIDTH = 0.2286 # meters (width of stumps)
 BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
-FRAME_RATE = 20 # Input video frame rate
-SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS
+FRAME_RATE = 30 # Input video frame rate
+SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
 CONF_THRESHOLD = 0.25 # Confidence threshold for detection
-IMPACT_ZONE_Y = 0.80 # Fraction of frame height for impact zone
-PITCH_ZONE_Y = 0.80 # Fraction of frame height for pitch zone
-IMPACT_DELTA_Y = 75 # Pixels for detecting sudden y-position change
-IMPACT_DELTA_X = 50 # Pixels for x-position stability
+IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
+IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
+PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
+STUMPS_HEIGHT = 0.71 # meters (stumps height)
+CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
+CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
 
 def process_video(video_path):
     if not os.path.exists(video_path):
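Note: of the geometry constants added above, STUMPS_WIDTH, STUMPS_HEIGHT and PITCH_LENGTH are what the 3D plots introduced later in this commit are built from. A quick sketch of the wicket coordinates they imply (values in meters, matching create_3d_plot further down):

    # Wicket geometry implied by the new constants, as used by create_3d_plot below.
    STUMPS_WIDTH = 0.2286   # meters
    STUMPS_HEIGHT = 0.71    # meters
    PITCH_LENGTH = 20.12    # meters

    stump_x = [-STUMPS_WIDTH / 2, STUMPS_WIDTH / 2, 0]   # two outer stumps and the middle stump
    stump_y = [PITCH_LENGTH] * 3                         # the wicket sits at the far end of the pitch
    print(stump_x)        # [-0.1143, 0.1143, 0]
    print(stump_y)        # [20.12, 20.12, 20.12]
    print(STUMPS_HEIGHT)  # each stump is drawn from z = 0 up to z = 0.71
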
@@ -27,7 +30,7 @@ def process_video(video_path):
     cap = cv2.VideoCapture(video_path)
     frames = []
     ball_positions = []
-    detection_frames = [] # Track frames with exactly one detection
+    detection_frames = []
     debug_log = []
 
     frame_count = 0
@@ -38,96 +41,82 @@ def process_video(video_path):
         frame_count += 1
         frames.append(frame.copy())
         results = model.predict(frame, conf=CONF_THRESHOLD)
-        detections = [det for det in results[0].boxes if det.cls == 0] # Class 0 is cricketBall
-        if len(detections) == 1: # Only consider frames with exactly one detection
-            x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
-            x_center = (x1 + x2) / 2
-            y_center = (y1 + y2) / 2
-            ball_positions.append([x_center, y_center])
-            detection_frames.append(frame_count - 1) # 0-based index
-            cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
-            debug_log.append(f"Frame {frame_count}: 1 ball detection at (x: {x_center:.1f}, y: {y_center:.1f})")
-        else:
-            debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections (ignored)")
+        detections = 0
+        for detection in results[0].boxes:
+            if detection.cls == 0: # Class 0 is the ball
+                detections += 1
+                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
+                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
+                detection_frames.append(frame_count - 1)
+                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
         frames[-1] = frame
+        debug_log.append(f"Frame {frame_count}: {detections} ball detections")
     cap.release()
 
     if not ball_positions:
-        debug_log.append("No valid single-ball detections in any frame")
+        debug_log.append("No balls detected in any frame")
     else:
-        debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")
+        debug_log.append(f"Total ball detections: {len(ball_positions)}")
 
     return frames, ball_positions, detection_frames, "\n".join(debug_log)
 
-def estimate_trajectory(ball_positions, detection_frames, frames):
+def pixel_to_3d(x, y, frame_height, frame_width):
+    """Convert 2D pixel coordinates to 3D real-world coordinates."""
+    x_norm = x / frame_width
+    y_norm = y / frame_height
+    x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
+    y_3d = y_norm * PITCH_LENGTH
+    z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
+    return x_3d, y_3d, z_3d
+
+def estimate_trajectory(ball_positions, frames, detection_frames):
     if len(ball_positions) < 2:
-        return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
-    frame_height = frames[0].shape[0]
+        return None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
+    frame_height, frame_width = frames[0].shape[:2]
 
-    # Extract x, y coordinates
     x_coords = [pos[0] for pos in ball_positions]
     y_coords = [pos[1] for pos in ball_positions]
    times = np.array(detection_frames) / FRAME_RATE
 
-    # Pitch point: first detection where y exceeds PITCH_ZONE_Y
-    pitch_idx = 0
-    for i, y in enumerate(y_coords):
-        if y > frame_height * PITCH_ZONE_Y:
-            pitch_idx = i
-            break
-    else:
-        debug_log = "Warning: No pitch point detected (y never exceeds PITCH_ZONE_Y), using first detection"
-        pitch_idx = 0
-    pitch_point = ball_positions[pitch_idx]
-    pitch_frame = detection_frames[pitch_idx]
+    pitch_point = ball_positions[0]
+    pitch_frame = detection_frames[0]
 
-    # Impact point: sudden y-change, x stability, or y exceeds IMPACT_ZONE_Y
     impact_idx = None
+    impact_frame = None
     for i in range(1, len(y_coords)):
-        delta_y = abs(y_coords[i] - y_coords[i-1])
-        delta_x = abs(x_coords[i] - x_coords[i-1])
-        if (y_coords[i] > frame_height * IMPACT_ZONE_Y or
-            (delta_y > IMPACT_DELTA_Y and delta_x < IMPACT_DELTA_X)):
+        if y_coords[i] > frame_height * IMPACT_ZONE_Y or abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y:
             impact_idx = i
+            impact_frame = detection_frames[i]
             break
     if impact_idx is None:
         impact_idx = len(ball_positions) - 1
-        debug_log = f"Warning: No impact point detected (no sudden y-change or y > IMPACT_ZONE_Y), using last detection (frame {detection_frames[impact_idx] + 1})"
-    else:
-        debug_log = ""
+        impact_frame = detection_frames[-1]
     impact_point = ball_positions[impact_idx]
-    impact_frame = detection_frames[impact_idx]
-
-    # Use only detected positions up to impact for trajectory
-    x_coords = x_coords[:impact_idx + 1]
-    y_coords = y_coords[:impact_idx + 1]
-    times = times[:impact_idx + 1]
 
     try:
-        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
-        fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
+        fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
+        fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='quadratic', fill_value="extrapolate")
     except Exception as e:
-        return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
-
-    # Trajectory for visualization (detected frames only)
-    vis_trajectory = list(zip(x_coords, y_coords))
+        return None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
 
-    # Full trajectory for LBW (includes projection)
     t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
     x_full = fx(t_full)
     y_full = fy(t_full)
-    full_trajectory = list(zip(x_full, y_full))
+    trajectory_2d = list(zip(x_full, y_full))
 
-    debug_log = (f"{debug_log}\nTrajectory estimated successfully\n"
-                 f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
-                 f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})")
-    return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, debug_log
+    trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
+    detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in ball_positions]
+    pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
+    impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
 
-def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
+    debug_log = f"Trajectory estimated successfully\nPitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\nImpact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})"
+    return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
+
+def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
     if not frames:
         return "Error: No frames processed", None, None, None
-    if not full_trajectory or len(ball_positions) < 2:
-        return "Not enough data (insufficient valid single-ball detections)", None, None, None
+    if not trajectory or len(ball_positions) < 2:
+        return "Not enough data (insufficient ball detections)", None, None, None
 
     frame_height, frame_width = frames[0].shape[:2]
     stumps_x = frame_width / 2
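For orientation, here is a minimal sketch of what the new pixel_to_3d mapping produces. The 1280x720 frame size and the sample pixels are assumptions; the function and constants are copied from the hunk above:

    # Hypothetical check of pixel_to_3d (function and constants copied from the diff; frame size assumed).
    PITCH_LENGTH = 20.12   # meters
    BALL_DIAMETER = 0.073  # meters

    def pixel_to_3d(x, y, frame_height, frame_width):
        x_norm = x / frame_width
        y_norm = y / frame_height
        x_3d = (x_norm - 0.5) * 3.0              # spreads the frame over ~3 m of width, centred on the pitch
        y_3d = y_norm * PITCH_LENGTH             # top row -> 0 m, bottom row -> 20.12 m down the pitch
        z_3d = (1 - y_norm) * BALL_DIAMETER * 5  # crude height proxy, at most ~0.37 m
        return x_3d, y_3d, z_3d

    print(pixel_to_3d(640, 0, 720, 1280))    # centre of the top row    -> (0.0, 0.0, 0.365)
    print(pixel_to_3d(640, 720, 720, 1280))  # centre of the bottom row -> (0.0, 20.12, 0.0)

So image height maps to distance along the pitch, and the z value is only a rough proxy for bounce height, not a calibrated quantity.
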
@@ -137,52 +126,111 @@ def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_po
     pitch_x, pitch_y = pitch_point
     impact_x, impact_y = impact_point
 
-    # Check pitching point
     if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
-        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", full_trajectory, pitch_point, impact_point
-
-    # Check impact point
+        return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
     if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
-        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
-
-    # Check trajectory hitting stumps
-    for x, y in full_trajectory:
+        return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+    for x, y in trajectory:
         if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
-            return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
-    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
-
-def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path):
+            return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+    return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+
+def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
+    """Create 3D Plotly visualization for detections or trajectory."""
+    # Wicket lines (stumps and bails)
+    stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
+    stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
+    stump_z = [0, 0, 0]
+    stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
+    bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
+    bail_y = [PITCH_LENGTH, PITCH_LENGTH]
+    bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
+
+    # Stumps (three vertical lines)
+    stump_traces = []
+    for i in range(3):
+        stump_traces.append(go.Scatter3d(
+            x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
+            mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
+        ))
+    # Bails (two horizontal lines)
+    bail_traces = [
+        go.Scatter3d(
+            x=bail_x, y=bail_y, z=bail_z,
+            mode='lines', line=dict(color='black', width=5), name='Bail'
+        )
+    ]
+
+    if plot_type == "detections":
+        # Ball detections plot
+        x, y, z = zip(*detections_3d)
+        scatter = go.Scatter3d(
+            x=x, y=y, z=z, mode='markers',
+            marker=dict(size=5, color='green'), name='Ball Detections'
+        )
+        # Pitch and impact points
+        pitch_scatter = go.Scatter3d(
+            x=[pitch_point_3d[0]], y=[pitch_point_3d[1]], z=[pitch_point_3d[2]],
+            mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
+        )
+        impact_scatter = go.Scatter3d(
+            x=[impact_point_3d[0]], y=[impact_point_3d[1]], z=[impact_point_3d[2]],
+            mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
+        )
+        data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
+        title = "3D Ball Detections"
+    else:
+        # Trajectory plot
+        x, y, z = zip(*trajectory_3d)
+        trajectory_line = go.Scatter3d(
+            x=x, y=y, z=z, mode='lines',
+            line=dict(color='blue', width=4), name='Ball Trajectory'
+        )
+        pitch_scatter = go.Scatter3d(
+            x=[pitch_point_3d[0]], y=[pitch_point_3d[1]], z=[pitch_point_3d[2]],
+            mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
+        )
+        impact_scatter = go.Scatter3d(
+            x=[impact_point_3d[0]], y=[impact_point_3d[1]], z=[impact_point_3d[2]],
+            mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
+        )
+        data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
+        title = "3D Ball Trajectory"
+
+    layout = go.Layout(
+        title=title,
+        scene=dict(
+            xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
+            xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
+            zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
+            aspectratio=dict(x=1, y=4, z=0.5)
+        ),
+        showlegend=True
+    )
+    fig = go.Figure(data=data, layout=layout)
+    return fig
+
+def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
     if not frames:
         return None
-    frame_height, frame_width = frames[0].shape[:2]
-
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
+    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
 
-    # Prepare trajectory points for visualization
-    trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
+    trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
 
     for i, frame in enumerate(frames):
-        # Draw trajectory (blue line) only for detected frames
         if i in detection_frames and trajectory_points.size > 0:
-            idx = detection_frames.index(i) + 1
-            if idx <= len(trajectory_points):
-                cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
-
-        # Draw pitch point (red circle) only in pitch frame
+            cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
         if pitch_point and i == pitch_frame:
             x, y = pitch_point
             cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
             cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
-
-        # Draw impact point (yellow circle) only in impact frame
         if impact_point and i == impact_frame:
             x, y = impact_point
             cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
             cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
-
         for _ in range(SLOW_MOTION_FACTOR):
             out.write(frame)
     out.release()
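One detail of generate_slow_motion worth noting as written: it both lowers the VideoWriter frame rate to FRAME_RATE / SLOW_MOTION_FACTOR and writes every source frame SLOW_MOTION_FACTOR times, so the two effects multiply. A quick check with the constants from this commit:

    # Effective slowdown of the replay, using the new constants.
    FRAME_RATE = 30
    SLOW_MOTION_FACTOR = 6

    output_fps = FRAME_RATE / SLOW_MOTION_FACTOR   # VideoWriter runs at 5 fps
    copies_per_frame = SLOW_MOTION_FACTOR          # each source frame is written 6 times
    slowdown = copies_per_frame * FRAME_RATE / output_fps
    print(slowdown)                                # 36.0: the replay is 36x slower, not 6x
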
@@ -191,15 +239,29 @@ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impac
 def drs_review(video):
     frames, ball_positions, detection_frames, debug_log = process_video(video)
     if not frames:
-        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
-    full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, trajectory_log = estimate_trajectory(ball_positions, detection_frames, frames)
-    decision, full_trajectory, pitch_point, impact_point = lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point)
+        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
+    trajectory, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
+    decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
 
     output_path = f"output_{uuid.uuid4()}.mp4"
-    slow_motion_path = generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path)
+    slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
+
+    detections_plot_path = f"detections_3d_{uuid.uuid4()}.html"
+    trajectory_plot_path = f"trajectory_3d_{uuid.uuid4()}.html"
+    if detections_3d:
+        detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
+        detections_fig.write_html(detections_plot_path)
+        trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
+        trajectory_fig.write_html(trajectory_plot_path)
+    else:
+        detections_plot_path = None
+        trajectory_plot_path = None
 
     debug_output = f"{debug_log}\n{trajectory_log}"
-    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
+    return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
+            slow_motion_path,
+            detections_plot_path,
+            trajectory_plot_path)
 
 # Gradio interface
 iface = gr.Interface(
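drs_review now hands the two Plotly figures back as standalone HTML files, which is the form the new gr.File outputs below expect. A minimal, self-contained sketch of that pattern; the placeholder figure and the make_plot_file name are illustrative, not part of app.py:

    # Hypothetical minimal example of the write_html + gr.File pattern used above.
    import uuid
    import plotly.graph_objects as go
    import gradio as gr

    def make_plot_file():
        # Placeholder figure; in app.py this would come from create_3d_plot(...).
        fig = go.Figure(data=[go.Scatter3d(x=[0, 1], y=[0, 1], z=[0, 1], mode='lines')])
        path = f"plot_{uuid.uuid4()}.html"
        fig.write_html(path)  # writes a self-contained HTML page for the figure
        return path           # returning a file path satisfies a gr.File output

    demo = gr.Interface(fn=make_plot_file, inputs=[], outputs=gr.File(label="3D Plot (HTML)"))
    # demo.launch()  # uncomment to serve the example locally
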
@@ -207,10 +269,12 @@ iface = gr.Interface(
     inputs=gr.Video(label="Upload Video Clip"),
     outputs=[
         gr.Textbox(label="DRS Decision and Debug Log"),
-        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)")
+        gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
+        gr.File(label="3D Ball Detections Plot (HTML)"),
+        gr.File(label="3D Ball Trajectory Plot (HTML)")
     ],
     title="AI-Powered DRS for LBW in Local Cricket",
-    description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."
+    description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show detections and trajectory with wicket lines."
 )
 
 if __name__ == "__main__":