AjaykumarPilla commited on
Commit
3da7a6d
·
verified ·
1 Parent(s): f2ea3f3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +223 -104
app.py CHANGED
@@ -3,28 +3,39 @@ import numpy as np
3
  import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
- from scipy.interpolate import interp1d, CubicSpline
 
7
  import uuid
8
  import os
9
 
10
- # Load the trained YOLOv8n model
11
  model = YOLO("best.pt")
 
12
 
13
  # Constants for LBW decision and video processing
14
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
- FRAME_RATE = 20 # Input video frame rate (reduced to 20 FPS)
17
- SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS (slower playback without being too slow)
18
- CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
- IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
 
 
 
 
 
 
20
 
21
  def process_video(video_path):
22
  if not os.path.exists(video_path):
23
  return [], [], [], "Error: Video file not found"
24
  cap = cv2.VideoCapture(video_path)
 
 
 
25
  frames = []
26
  ball_positions = []
27
- detection_frames = [] # Track frames with detections
28
  debug_log = []
29
 
30
  frame_count = 0
@@ -34,15 +45,17 @@ def process_video(video_path):
34
  break
35
  frame_count += 1
36
  frames.append(frame.copy())
37
- results = model.predict(frame, conf=CONF_THRESHOLD)
 
38
  detections = 0
39
  for detection in results[0].boxes:
40
- if detection.cls == 0: # Assuming class 0 is the ball
41
  detections += 1
42
- x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
43
- ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
44
- detection_frames.append(frame_count - 1) # Store frame index (0-based)
45
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
 
46
  frames[-1] = frame
47
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
48
  cap.release()
@@ -51,54 +64,164 @@ def process_video(video_path):
51
  debug_log.append("No balls detected in any frame")
52
  else:
53
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
 
54
 
55
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
56
 
57
- def smooth_trajectory(ball_positions, frames):
58
- if len(ball_positions) < 2:
59
- return None, "Error: Fewer than 2 ball detections for trajectory"
60
-
61
- # Extract x, y coordinates
62
- x_coords = [pos[0] for pos in ball_positions]
63
- y_coords = [pos[1] for pos in ball_positions]
64
- times = np.arange(len(ball_positions)) / FRAME_RATE
65
-
66
- # Use cubic spline interpolation to smooth the trajectory
67
- try:
68
- spline_x = CubicSpline(times, x_coords, bc_type='natural')
69
- spline_y = CubicSpline(times, y_coords, bc_type='natural')
70
- except Exception as e:
71
- return None, f"Error in trajectory smoothing: {str(e)}"
72
-
73
- # Project trajectory (detected + future for LBW decision)
74
- t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
75
- x_full = spline_x(t_full)
76
- y_full = spline_y(t_full)
77
- trajectory = list(zip(x_full, y_full))
78
 
79
- return trajectory, "Trajectory smoothed successfully"
 
 
 
80
 
81
- def detect_pitch_and_impact(ball_positions, frames, frame_height):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  pitch_point = None
83
- impact_point = None
84
-
85
- # Pitch detection: Ball hits the ground (vertical velocity and low height)
86
- pitch_threshold_y = frame_height * 0.75 # Ball reaches near the ground
87
-
88
- # For the impact point, we assume that the region of interest is near the stumps
89
- impact_zone_y_min = frame_height * 0.80
90
- impact_zone_y_max = frame_height * 0.85
91
-
92
- # Detect pitch and impact points
93
- for i, (x, y) in enumerate(ball_positions):
94
- if y > pitch_threshold_y and not pitch_point:
95
- pitch_point = (x, y) # Ball has hit the ground
96
 
97
- # Check if the ball is near the batsman's impact zone (e.g., near stumps)
98
- if impact_zone_y_min <= y <= impact_zone_y_max and not impact_point:
99
- impact_point = (x, y) # Ball has impacted the batsman (bat or pad)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
- return pitch_point, impact_point
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
104
  if not frames:
@@ -108,105 +231,101 @@ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
108
 
109
  frame_height, frame_width = frames[0].shape[:2]
110
  stumps_x = frame_width / 2
111
- stumps_y = frame_height * 0.9 # Position of the stumps at the bottom of the frame
112
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
113
 
114
  pitch_x, pitch_y = pitch_point
115
  impact_x, impact_y = impact_point
116
 
117
- # Check pitching point - the ball should land between stumps
118
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
119
  return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
120
-
121
- # Check impact point - the ball should hit within the stumps area
122
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
123
  return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
124
-
125
- # Check trajectory hitting stumps
126
  for x, y in trajectory:
127
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
128
  return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
129
-
130
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
131
 
132
- def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, output_path):
133
  if not frames:
134
  return None
 
135
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
136
- out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
137
-
138
- trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
139
-
140
- pitch_point_detected = False
141
- impact_point_detected = False
 
 
 
 
 
 
 
142
 
143
  for i, frame in enumerate(frames):
144
- # Draw trajectory (blue line) only for frames with detections
145
- if i in detection_frames and trajectory_points.size > 0:
146
- cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
147
-
148
- # Draw pitch point (red circle with label) when the ball touches the ground
149
- if pitch_point and not pitch_point_detected:
150
  x, y = pitch_point
151
- if y > frame.shape[0] * 0.75: # Adjust this threshold for the ground position
152
- pitch_point_detected = True
153
- if pitch_point_detected:
154
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
155
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
156
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
157
-
158
- # Draw impact point (yellow circle with label) when ball is near stumps
159
- if impact_point and not impact_point_detected:
160
  x, y = impact_point
161
- if y > frame.shape[0] * 0.85: # Adjust this threshold for impact point
162
- impact_point_detected = True
163
- if impact_point_detected:
164
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
165
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
166
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
167
-
168
- # Add wicket lines for the stumps
169
- stumps_x = frame.shape[1] // 2
170
- stumps_y = frame.shape[0] * 0.9
171
- stumps_width = frame.shape[1] * 0.1
172
- cv2.line(frame, (int(stumps_x - stumps_width / 2), int(stumps_y)),
173
- (int(stumps_x + stumps_width / 2), int(stumps_y)), (0, 255, 0), 3)
174
-
175
- # Write frames to output video
176
  for _ in range(SLOW_MOTION_FACTOR):
177
  out.write(frame)
178
-
179
  out.release()
180
  return output_path
181
 
182
  def drs_review(video):
183
  frames, ball_positions, detection_frames, debug_log = process_video(video)
184
  if not frames:
185
- return f"Error: Failed to process video", None
186
-
187
- trajectory, smoothing_log = smooth_trajectory(ball_positions, frames)
188
 
189
- # Detect pitch and impact points based on ball positions
190
- pitch_point, impact_point = detect_pitch_and_impact(ball_positions, frames, frames[0].shape[0])
 
 
191
 
192
- decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
193
 
194
  output_path = f"output_{uuid.uuid4()}.mp4"
195
- slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, output_path)
 
 
 
 
 
 
196
 
197
- return f"DRS Decision: {decision}", slow_motion_path
 
 
 
 
198
 
199
  # Gradio interface
200
  iface = gr.Interface(
201
  fn=drs_review,
202
  inputs=gr.Video(label="Upload Video Clip"),
203
  outputs=[
204
- gr.Textbox(label="DRS Decision"),
205
- gr.Video(label="Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow), Wicket Lines")
 
 
206
  ],
207
  title="AI-Powered DRS for LBW in Local Cricket",
208
- description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), impact point (yellow circle), and wicket lines."
209
  )
210
 
211
  if __name__ == "__main__":
212
- iface.launch()
 
3
  import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
+ from scipy.interpolate import interp1d
7
+ import plotly.graph_objects as go
8
  import uuid
9
  import os
10
 
11
+ # Load the trained YOLOv8n model with optimizations
12
  model = YOLO("best.pt")
13
+ model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
14
 
15
  # Constants for LBW decision and video processing
16
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
17
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
18
+ FRAME_RATE = 20 # Input video frame rate
19
+ SLOW_MOTION_FACTOR = 3 # For very slow motion (3x slower)
20
+ CONF_THRESHOLD = 0.2 # Confidence threshold
21
+ IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
22
+ IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
23
+ PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
24
+ STUMPS_HEIGHT = 0.71 # meters (stumps height)
25
+ CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
26
+ CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
27
+ MAX_POSITION_JUMP = 30 # Pixels, tightened for continuous trajectory
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
31
  return [], [], [], "Error: Video file not found"
32
  cap = cv2.VideoCapture(video_path)
33
+ # Get native video resolution
34
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
  frames = []
37
  ball_positions = []
38
+ detection_frames = []
39
  debug_log = []
40
 
41
  frame_count = 0
 
45
  break
46
  frame_count += 1
47
  frames.append(frame.copy())
48
+ # Use native resolution for inference
49
+ results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
50
  detections = 0
51
  for detection in results[0].boxes:
52
+ if detection.cls == 0: # Class 0 is the ball
53
  detections += 1
54
+ if detections == 1: # Only consider frames with exactly one detection
55
+ x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
56
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
57
+ detection_frames.append(frame_count - 1)
58
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
59
  frames[-1] = frame
60
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
61
  cap.release()
 
64
  debug_log.append("No balls detected in any frame")
65
  else:
66
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
67
+ debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
68
 
69
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
70
 
71
+ def pixel_to_3d(x, y, frame_height, frame_width):
72
+ """Convert 2D pixel coordinates to 3D real-world coordinates."""
73
+ x_norm = x / frame_width
74
+ y_norm = y / frame_height
75
+ x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
76
+ y_3d = y_norm * PITCH_LENGTH
77
+ z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
78
+ return x_3d, y_3d, z_3d
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
+ def estimate_trajectory(ball_positions, frames, detection_frames):
81
+ if len(ball_positions) < 2:
82
+ return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 ball detections for trajectory"
83
+ frame_height, frame_width = frames[0].shape[:2]
84
 
85
+ # Filter out sudden changes in position for continuous trajectory
86
+ filtered_positions = [ball_positions[0]]
87
+ filtered_frames = [detection_frames[0]]
88
+ for i in range(1, len(ball_positions)):
89
+ prev_pos = filtered_positions[-1]
90
+ curr_pos = ball_positions[i]
91
+ distance = np.sqrt((curr_pos[0] - prev_pos[0])**2 + (curr_pos[1] - prev_pos[1])**2)
92
+ if distance <= MAX_POSITION_JUMP:
93
+ filtered_positions.append(curr_pos)
94
+ filtered_frames.append(detection_frames[i])
95
+ else:
96
+ # Skip sudden jumps to maintain continuity
97
+ continue
98
+
99
+ if len(filtered_positions) < 2:
100
+ return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
101
+
102
+ x_coords = [pos[0] for pos in filtered_positions]
103
+ y_coords = [pos[1] for pos in filtered_positions]
104
+ times = np.array(filtered_frames) / FRAME_RATE
105
+
106
+ # Pitch point detection: Assume it happens when the ball reaches a certain low point on the y-axis
107
  pitch_point = None
108
+ pitch_frame = None
109
+ for i in range(1, len(y_coords)):
110
+ if y_coords[i] > frame_height * 0.75: # The ball reaches near the ground
111
+ pitch_point = filtered_positions[i]
112
+ pitch_frame = filtered_frames[i]
113
+ break
 
 
 
 
 
 
 
114
 
115
+ # Impact point detection: Look for sudden changes in the y-position (delta_y) or when ball enters impact zone
116
+ impact_idx = None
117
+ impact_frame = None
118
+ for i in range(1, len(y_coords)):
119
+ delta_y = abs(y_coords[i] - y_coords[i-1])
120
+ if delta_y > IMPACT_DELTA_Y:
121
+ impact_idx = i
122
+ impact_frame = filtered_frames[i]
123
+ break
124
+ elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
125
+ impact_idx = i
126
+ impact_frame = filtered_frames[i]
127
+ break
128
+ if impact_idx is None:
129
+ impact_idx = len(filtered_positions) - 1
130
+ impact_frame = filtered_frames[-1]
131
+ impact_point = filtered_positions[impact_idx]
132
 
133
+ try:
134
+ # Use cubic interpolation for smoother trajectory
135
+ fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
136
+ fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
137
+ except Exception as e:
138
+ return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
139
+
140
+ # Generate dense points for all frames between first and last detection
141
+ total_frames = max(detection_frames) - min(detection_frames) + 1
142
+ t_full = np.linspace(times[0], times[-1], total_frames * SLOW_MOTION_FACTOR)
143
+ x_full = fx(t_full)
144
+ y_full = fy(t_full)
145
+ trajectory_2d = list(zip(x_full, y_full))
146
+
147
+ trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
148
+ detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
149
+ pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width) if pitch_point else None
150
+ impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width) if impact_point else None
151
+
152
+ debug_log = (
153
+ f"Trajectory estimated successfully\n"
154
+ f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
155
+ f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})\n"
156
+ f"Detections in frames: {filtered_frames}"
157
+ )
158
+ return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
159
+
160
+ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
161
+ """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
162
+ stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
163
+ stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
164
+ stump_z = [0, 0, 0]
165
+ stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
166
+ bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
167
+ bail_y = [PITCH_LENGTH, PITCH_LENGTH]
168
+ bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
169
+
170
+ stump_traces = []
171
+ for i in range(3):
172
+ stump_traces.append(go.Scatter3d(
173
+ x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
174
+ mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
175
+ ))
176
+ bail_traces = [
177
+ go.Scatter3d(
178
+ x=bail_x, y=bail_y, z=bail_z,
179
+ mode='lines', line=dict(color='black', width=5), name='Bail'
180
+ )
181
+ ]
182
+
183
+ pitch_scatter = go.Scatter3d(
184
+ x=[pitch_point_3d[0]] if pitch_point_3d else [],
185
+ y=[pitch_point_3d[1]] if pitch_point_3d else [],
186
+ z=[pitch_point_3d[2]] if pitch_point_3d else [],
187
+ mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
188
+ )
189
+ impact_scatter = go.Scatter3d(
190
+ x=[impact_point_3d[0]] if impact_point_3d else [],
191
+ y=[impact_point_3d[1]] if impact_point_3d else [],
192
+ z=[impact_point_3d[2]] if impact_point_3d else [],
193
+ mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
194
+ )
195
+
196
+ if plot_type == "detections":
197
+ x, y, z = zip(*detections_3d) if detections_3d else ([], [], [])
198
+ scatter = go.Scatter3d(
199
+ x=x, y=y, z=z, mode='markers',
200
+ marker=dict(size=5, color='green'), name='Single Ball Detections'
201
+ )
202
+ data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
203
+ title = "3D Single Ball Detections"
204
+ else:
205
+ x, y, z = zip(*trajectory_3d) if trajectory_3d else ([], [], [])
206
+ trajectory_line = go.Scatter3d(
207
+ x=x, y=y, z=z, mode='lines',
208
+ line=dict(color='blue', width=4), name='Ball Trajectory (Single Detections)'
209
+ )
210
+ data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
211
+ title = "3D Ball Trajectory (Single Detections)"
212
+
213
+ layout = go.Layout(
214
+ title=title,
215
+ scene=dict(
216
+ xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
217
+ xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
218
+ zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
219
+ aspectratio=dict(x=1, y=4, z=0.5)
220
+ ),
221
+ showlegend=True
222
+ )
223
+ fig = go.Figure(data=data, layout=layout)
224
+ return fig
225
 
226
  def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
227
  if not frames:
 
231
 
232
  frame_height, frame_width = frames[0].shape[:2]
233
  stumps_x = frame_width / 2
234
+ stumps_y = frame_height * 0.9
235
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
236
 
237
  pitch_x, pitch_y = pitch_point
238
  impact_x, impact_y = impact_point
239
 
 
240
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
241
  return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
 
 
242
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
243
  return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
 
244
  for x, y in trajectory:
245
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
246
  return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
 
247
  return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
248
 
249
+ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
250
  if not frames:
251
  return None
252
+ frame_height, frame_width = frames[0].shape[:2]
253
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
254
+ out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
255
+
256
+ # Map trajectory points to all frames between first and last detection
257
+ if trajectory and detection_frames:
258
+ min_frame = min(detection_frames)
259
+ max_frame = max(detection_frames)
260
+ total_frames = max_frame - min_frame + 1
261
+ trajectory_points = np.array(trajectory, dtype=np.int32).reshape((-1, 1, 2))
262
+ traj_per_frame = len(trajectory) // total_frames
263
+ trajectory_indices = [i * traj_per_frame for i in range(total_frames)]
264
+ else:
265
+ trajectory_points = np.array([], dtype=np.int32)
266
+ trajectory_indices = []
267
 
268
  for i, frame in enumerate(frames):
269
+ frame_idx = i - min_frame if trajectory_indices else -1
270
+ if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
271
+ # Draw trajectory up to current frame
272
+ end_idx = trajectory_indices[frame_idx] + 1
273
+ cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
274
+ if pitch_point and i == pitch_frame:
275
  x, y = pitch_point
 
 
 
276
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
277
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
278
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
279
+ if impact_point and i == impact_frame:
 
 
280
  x, y = impact_point
 
 
 
281
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
282
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
283
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
 
 
 
 
 
 
 
 
 
284
  for _ in range(SLOW_MOTION_FACTOR):
285
  out.write(frame)
 
286
  out.release()
287
  return output_path
288
 
289
  def drs_review(video):
290
  frames, ball_positions, detection_frames, debug_log = process_video(video)
291
  if not frames:
292
+ return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
 
 
293
 
294
+ trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
295
+
296
+ if trajectory_2d is None:
297
+ return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)
298
 
299
+ decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
300
 
301
  output_path = f"output_{uuid.uuid4()}.mp4"
302
+ slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
303
+
304
+ detections_fig = None
305
+ trajectory_fig = None
306
+ if detections_3d:
307
+ detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
308
+ trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
309
 
310
+ debug_output = f"{debug_log}\n{trajectory_log}"
311
+ return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
312
+ slow_motion_path,
313
+ detections_fig,
314
+ trajectory_fig)
315
 
316
  # Gradio interface
317
  iface = gr.Interface(
318
  fn=drs_review,
319
  inputs=gr.Video(label="Upload Video Clip"),
320
  outputs=[
321
+ gr.Textbox(label="DRS Decision and Debug Log"),
322
+ gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
323
+ gr.Plot(label="3D Single Ball Detections Plot"),
324
+ gr.Plot(label="3D Ball Trajectory Plot (Single Detections)")
325
  ],
326
  title="AI-Powered DRS for LBW in Local Cricket",
327
+ description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show single-detection frames (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow)."
328
  )
329
 
330
  if __name__ == "__main__":
331
+ iface.launch()