AjaykumarPilla commited on
Commit
30b23a8
·
verified ·
1 Parent(s): 9516612

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -239
app.py CHANGED
@@ -7,51 +7,54 @@ from scipy.interpolate import interp1d
7
  import plotly.graph_objects as go
8
  import uuid
9
  import os
 
10
 
11
- # Load the trained YOLOv8n model with optimizations
12
  model = YOLO("best.pt")
13
- model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
14
-
15
- # Constants for LBW decision and video processing
16
- STUMPS_WIDTH = 0.2286 # meters (width of stumps)
17
- BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
18
- FRAME_RATE = 20 # Input video frame rate
19
- SLOW_MOTION_FACTOR = 3 # For very slow motion (3x slower)
20
- CONF_THRESHOLD = 0.2 # Confidence threshold
21
- IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
22
- IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
23
- PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
24
- STUMPS_HEIGHT = 0.71 # meters (stumps height)
25
- CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
26
- CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
27
- MAX_POSITION_JUMP = 30 # Pixels, tightened for continuous trajectory
 
 
 
 
 
 
 
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
31
  return [], [], [], "Error: Video file not found"
32
  cap = cv2.VideoCapture(video_path)
33
- # Get native video resolution
34
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
- frames = []
37
- ball_positions = []
38
- detection_frames = []
39
- debug_log = []
40
-
41
  frame_count = 0
 
42
  while cap.isOpened():
43
  ret, frame = cap.read()
44
  if not ret:
45
  break
46
  frame_count += 1
47
  frames.append(frame.copy())
48
- # Use native resolution for inference
49
  results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
50
  detections = 0
51
  for detection in results[0].boxes:
52
- if detection.cls == 0: # Class 0 is the ball
53
  detections += 1
54
- if detections == 1: # Only consider frames with exactly one detection
55
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
56
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
57
  detection_frames.append(frame_count - 1)
@@ -68,258 +71,140 @@ def process_video(video_path):
68
 
69
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
70
 
71
- def pixel_to_3d(x, y, frame_height, frame_width):
72
- """Convert 2D pixel coordinates to 3D real-world coordinates."""
73
- x_norm = x / frame_width
74
- y_norm = y / frame_height
75
- x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
76
- y_3d = y_norm * PITCH_LENGTH
77
- z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
78
- return x_3d, y_3d, z_3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
- def estimate_trajectory(ball_positions, frames, detection_frames):
81
- if len(ball_positions) < 2:
82
- return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
83
  frame_height, frame_width = frames[0].shape[:2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
- # Filter out sudden changes in position for continuous trajectory
86
  filtered_positions = [ball_positions[0]]
87
  filtered_frames = [detection_frames[0]]
88
  for i in range(1, len(ball_positions)):
89
- prev_pos = filtered_positions[-1]
90
- curr_pos = ball_positions[i]
91
- distance = np.sqrt((curr_pos[0] - prev_pos[0])**2 + (curr_pos[1] - prev_pos[1])**2)
92
- if distance <= MAX_POSITION_JUMP:
93
- filtered_positions.append(curr_pos)
94
  filtered_frames.append(detection_frames[i])
95
- else:
96
- # Skip sudden jumps to maintain continuity
97
- continue
98
 
99
  if len(filtered_positions) < 2:
100
- return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
101
 
102
- x_coords = [pos[0] for pos in filtered_positions]
103
- y_coords = [pos[1] for pos in filtered_positions]
104
  times = np.array(filtered_frames) / FRAME_RATE
105
 
106
- pitch_point = filtered_positions[0]
107
- pitch_frame = filtered_frames[0]
108
-
109
- # Prioritize sudden y-change for impact detection
110
- impact_idx = None
111
- impact_frame = None
112
- for i in range(1, len(y_coords)):
113
- delta_y = abs(y_coords[i] - y_coords[i-1])
114
- if delta_y > IMPACT_DELTA_Y:
115
- impact_idx = i
116
- impact_frame = filtered_frames[i]
117
- break
118
- elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
119
- impact_idx = i
120
- impact_frame = filtered_frames[i]
121
- break
122
- if impact_idx is None:
123
- impact_idx = len(filtered_positions) - 1
124
- impact_frame = filtered_frames[-1]
125
- impact_point = filtered_positions[impact_idx]
126
-
127
  try:
128
- # Use cubic interpolation for smoother trajectory
129
- fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
130
- fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
131
  except Exception as e:
132
- return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
133
 
134
- # Generate dense points for all frames between first and last detection
135
- total_frames = max(detection_frames) - min(detection_frames) + 1
136
- t_full = np.linspace(times[0], times[-1], total_frames * SLOW_MOTION_FACTOR)
137
  x_full = fx(t_full)
138
  y_full = fy(t_full)
139
- trajectory_2d = list(zip(x_full, y_full))
140
-
141
- trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
142
- detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
143
- pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
144
- impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
145
-
146
- debug_log = (
147
- f"Trajectory estimated successfully\n"
148
- f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
149
- f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})\n"
150
- f"Detections in frames: {filtered_frames}"
151
- )
152
- return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
153
-
154
- def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
155
- """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
156
- stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
157
- stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
158
- stump_z = [0, 0, 0]
159
- stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
160
- bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
161
- bail_y = [PITCH_LENGTH, PITCH_LENGTH]
162
- bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
163
-
164
- stump_traces = []
165
- for i in range(3):
166
- stump_traces.append(go.Scatter3d(
167
- x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
168
- mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
169
- ))
170
- bail_traces = [
171
- go.Scatter3d(
172
- x=bail_x, y=bail_y, z=bail_z,
173
- mode='lines', line=dict(color='black', width=5), name='Bail'
174
- )
175
- ]
176
-
177
- pitch_scatter = go.Scatter3d(
178
- x=[pitch_point_3d[0]] if pitch_point_3d else [],
179
- y=[pitch_point_3d[1]] if pitch_point_3d else [],
180
- z=[pitch_point_3d[2]] if pitch_point_3d else [],
181
- mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
182
- )
183
- impact_scatter = go.Scatter3d(
184
- x=[impact_point_3d[0]] if impact_point_3d else [],
185
- y=[impact_point_3d[1]] if impact_point_3d else [],
186
- z=[impact_point_3d[2]] if impact_point_3d else [],
187
- mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
188
- )
189
-
190
- if plot_type == "detections":
191
- x, y, z = zip(*detections_3d) if detections_3d else ([], [], [])
192
- scatter = go.Scatter3d(
193
- x=x, y=y, z=z, mode='markers',
194
- marker=dict(size=5, color='green'), name='Single Ball Detections'
195
- )
196
- data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
197
- title = "3D Single Ball Detections"
198
- else:
199
- x, y, z = zip(*trajectory_3d) if trajectory_3d else ([], [], [])
200
- trajectory_line = go.Scatter3d(
201
- x=x, y=y, z=z, mode='lines',
202
- line=dict(color='blue', width=4), name='Ball Trajectory (Single Detections)'
203
- )
204
- data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
205
- title = "3D Ball Trajectory (Single Detections)"
206
-
207
- layout = go.Layout(
208
- title=title,
209
- scene=dict(
210
- xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
211
- xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
212
- zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
213
- aspectratio=dict(x=1, y=4, z=0.5)
214
- ),
215
- showlegend=True
216
- )
217
- fig = go.Figure(data=data, layout=layout)
218
- return fig
219
 
220
- def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
221
- if not frames:
222
- return "Error: No frames processed", None, None, None
223
- if not trajectory or len(ball_positions) < 2:
224
- return "Not enough data (insufficient ball detections)", None, None, None
225
 
226
- frame_height, frame_width = frames[0].shape[:2]
227
- stumps_x = frame_width / 2
228
- stumps_y = frame_height * 0.9
229
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
230
 
231
- pitch_x, pitch_y = pitch_point
232
- impact_x, impact_y = impact_point
 
233
 
234
- if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
235
- return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
236
- if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
237
- return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
238
- for x, y in trajectory:
239
- if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
240
- return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
241
- return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
242
 
243
- def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
244
- if not frames:
245
- return None
246
- frame_height, frame_width = frames[0].shape[:2]
247
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
248
- out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
249
-
250
- # Map trajectory points to all frames between first and last detection
251
- if trajectory and detection_frames:
252
- min_frame = min(detection_frames)
253
- max_frame = max(detection_frames)
254
- total_frames = max_frame - min_frame + 1
255
- trajectory_points = np.array(trajectory, dtype=np.int32).reshape((-1, 1, 2))
256
- traj_per_frame = len(trajectory) // total_frames
257
- trajectory_indices = [i * traj_per_frame for i in range(total_frames)]
258
- else:
259
- trajectory_points = np.array([], dtype=np.int32)
260
- trajectory_indices = []
261
 
262
  for i, frame in enumerate(frames):
263
- frame_idx = i - min_frame if trajectory_indices else -1
264
- if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
265
- # Draw trajectory up to current frame
266
- end_idx = trajectory_indices[frame_idx] + 1
267
- cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
268
- if pitch_point and i == pitch_frame:
269
- x, y = pitch_point
270
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
271
- cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
272
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
273
- if impact_point and i == impact_frame:
274
- x, y = impact_point
275
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
276
- cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
277
- cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
278
  for _ in range(SLOW_MOTION_FACTOR):
279
  out.write(frame)
280
  out.release()
281
- return output_path
282
 
283
  def drs_review(video):
284
  frames, ball_positions, detection_frames, debug_log = process_video(video)
285
- if not frames:
286
- return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
287
-
288
- trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
289
-
290
- if trajectory_2d is None:
291
- return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)
292
 
293
- decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
294
-
295
- output_path = f"output_{uuid.uuid4()}.mp4"
296
- slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
297
 
298
- detections_fig = None
299
- trajectory_fig = None
300
- if detections_3d:
301
- detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
302
- trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
303
 
304
- debug_output = f"{debug_log}\n{trajectory_log}"
305
- return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
306
- slow_motion_path,
307
- detections_fig,
308
- trajectory_fig)
309
 
310
- # Gradio interface
311
  iface = gr.Interface(
312
  fn=drs_review,
313
- inputs=gr.Video(label="Upload Video Clip"),
314
  outputs=[
315
- gr.Textbox(label="DRS Decision and Debug Log"),
316
- gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
317
- gr.Plot(label="3D Single Ball Detections Plot"),
318
- gr.Plot(label="3D Ball Trajectory Plot (Single Detections)")
319
  ],
320
- title="AI-Powered DRS for LBW in Local Cricket",
321
- description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show single-detection frames (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow)."
322
  )
323
 
324
  if __name__ == "__main__":
325
- iface.launch()
 
7
  import plotly.graph_objects as go
8
  import uuid
9
  import os
10
+ import tempfile
11
 
12
+ # Load YOLOv8 model and resolve class index
13
  model = YOLO("best.pt")
14
+ model.to('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ # Dynamically resolve ball class index
17
+ ball_class_index = None
18
+ for k, v in model.names.items():
19
+ if v.lower() == "cricketball":
20
+ ball_class_index = k
21
+ break
22
+ if ball_class_index is None:
23
+ raise ValueError("Class 'cricketBall' not found in model.names")
24
+
25
+ # Constants
26
+ STUMPS_WIDTH = 0.2286
27
+ BALL_DIAMETER = 0.073
28
+ FRAME_RATE = 20
29
+ SLOW_MOTION_FACTOR = 3
30
+ CONF_THRESHOLD = 0.2
31
+ IMPACT_ZONE_Y = 0.85
32
+ IMPACT_DELTA_Y = 50
33
+ PITCH_LENGTH = 20.12
34
+ STUMPS_HEIGHT = 0.71
35
+ MAX_POSITION_JUMP = 30
36
 
37
  def process_video(video_path):
38
  if not os.path.exists(video_path):
39
  return [], [], [], "Error: Video file not found"
40
  cap = cv2.VideoCapture(video_path)
 
41
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
42
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
43
+ frames, ball_positions, detection_frames, debug_log = [], [], [], []
 
 
 
 
44
  frame_count = 0
45
+
46
  while cap.isOpened():
47
  ret, frame = cap.read()
48
  if not ret:
49
  break
50
  frame_count += 1
51
  frames.append(frame.copy())
 
52
  results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
53
  detections = 0
54
  for detection in results[0].boxes:
55
+ if int(detection.cls) == ball_class_index:
56
  detections += 1
57
+ if detections == 1:
58
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
59
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
60
  detection_frames.append(frame_count - 1)
 
71
 
72
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
73
 
74
+ def find_bounce_point(ball_coords):
75
+ """
76
+ Detect bounce point using y-derivative reversal with early-frame suppression.
77
+ Looks for where y increases then decreases (ball hits ground).
78
+ """
79
+ y_coords = [p[1] for p in ball_coords]
80
+ min_index = None
81
+
82
+ for i in range(2, len(y_coords) - 2):
83
+ dy1 = y_coords[i] - y_coords[i - 1]
84
+ dy2 = y_coords[i + 1] - y_coords[i]
85
+ if dy1 > 0 and dy2 < 0:
86
+ if i > len(y_coords) * 0.2:
87
+ min_index = i
88
+ break
89
+
90
+ if min_index is not None:
91
+ return ball_coords[min_index]
92
+
93
+ return ball_coords[len(ball_coords)//2]
94
+
95
+ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
96
+ if not frames or not trajectory or len(ball_positions) < 2:
97
+ return "Not enough data", trajectory, pitch_point, impact_point
98
 
 
 
 
99
  frame_height, frame_width = frames[0].shape[:2]
100
+ stumps_x = frame_width / 2
101
+ stumps_y = frame_height * 0.9
102
+ stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
103
+
104
+ pitch_x, _ = pitch_point
105
+ impact_x, impact_y = impact_point
106
+
107
+ if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
108
+ return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
109
+ if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
110
+ return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
111
+ for x, y in trajectory:
112
+ if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
113
+ return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
114
+ return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
115
+
116
+ def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
117
+ if len(ball_positions) < 2:
118
+ return None, None, None, "Error: Not enough ball detections"
119
 
 
120
  filtered_positions = [ball_positions[0]]
121
  filtered_frames = [detection_frames[0]]
122
  for i in range(1, len(ball_positions)):
123
+ prev, curr = filtered_positions[-1], ball_positions[i]
124
+ if np.linalg.norm(np.array(curr) - np.array(prev)) <= MAX_POSITION_JUMP:
125
+ filtered_positions.append(curr)
 
 
126
  filtered_frames.append(detection_frames[i])
 
 
 
127
 
128
  if len(filtered_positions) < 2:
129
+ return None, None, None, "Error: Filtered detections too few"
130
 
131
+ x_vals = [p[0] for p in filtered_positions]
132
+ y_vals = [p[1] for p in filtered_positions]
133
  times = np.array(filtered_frames) / FRAME_RATE
134
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  try:
136
+ fx = interp1d(times, x_vals, kind='cubic', fill_value="extrapolate")
137
+ fy = interp1d(times, y_vals, kind='cubic', fill_value="extrapolate")
 
138
  except Exception as e:
139
+ return None, None, None, f"Interpolation error: {str(e)}"
140
 
141
+ total_frames = max(filtered_frames) - min(filtered_frames) + 1
142
+ t_full = np.linspace(times[0], times[-1], max(5, total_frames * SLOW_MOTION_FACTOR))
 
143
  x_full = fx(t_full)
144
  y_full = fy(t_full)
145
+ trajectory = list(zip(x_full, y_full))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
+ pitch_point = find_bounce_point(filtered_positions)
148
+ impact_point = filtered_positions[-1]
 
 
 
149
 
150
+ return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
 
 
 
151
 
152
+ def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
153
+ if not frames or not trajectory:
154
+ return None
155
 
156
+ temp_file = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
157
+ height, width = frames[0].shape[:2]
158
+ out = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
 
 
 
 
 
159
 
160
+ min_frame = min(detection_frames)
161
+ max_frame = max(detection_frames)
162
+ total_frames = max_frame - min_frame + 1
163
+ traj_per_frame = max(1, len(trajectory) // total_frames)
164
+ indices = [min(i * traj_per_frame, len(trajectory)-1) for i in range(total_frames)]
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
  for i, frame in enumerate(frames):
167
+ idx = i - min_frame
168
+ if 0 <= idx < len(indices):
169
+ end_idx = indices[idx]
170
+ points = np.array(trajectory[:end_idx+1], dtype=np.int32).reshape((-1, 1, 2))
171
+ cv2.polylines(frame, [points], False, (255, 0, 0), 2)
172
+ if pitch_point and i == detection_frames[0]:
173
+ cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
174
+ if impact_point and i == detection_frames[-1]:
175
+ cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
 
 
 
 
 
 
176
  for _ in range(SLOW_MOTION_FACTOR):
177
  out.write(frame)
178
  out.release()
179
+ return temp_file
180
 
181
  def drs_review(video):
182
  frames, ball_positions, detection_frames, debug_log = process_video(video)
183
+ if not frames or not ball_positions:
184
+ return "No frames or detections found.", None
 
 
 
 
 
185
 
186
+ frame_height, frame_width = frames[0].shape[:2]
187
+ trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
188
+ if not trajectory:
189
+ return f"{log}\n{debug_log}", None
190
 
191
+ decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
192
+ replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
 
 
 
193
 
194
+ result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
195
+ return result_log, replay_path
 
 
 
196
 
197
+ # Gradio Interface
198
  iface = gr.Interface(
199
  fn=drs_review,
200
+ inputs=gr.Video(label="Upload Cricket Delivery Video"),
201
  outputs=[
202
+ gr.Textbox(label="DRS Result and Debug Info"),
203
+ gr.Video(label="Replay with Trajectory & Decision")
 
 
204
  ],
205
+ title="GullyDRS - AI-Powered LBW Review",
206
+ description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return a replay with an OUT/NOT OUT decision."
207
  )
208
 
209
  if __name__ == "__main__":
210
+ iface.launch()