AjaykumarPilla committed
Commit 45d3ff2 · verified · 1 Parent(s): 08efaf0

Update app.py

Files changed (1):
  1. app.py +123 -49
app.py CHANGED
@@ -4,6 +4,7 @@ import torch
 from ultralytics import YOLO
 import gradio as gr
 from scipy.interpolate import interp1d
+import plotly.graph_objects as go
 import uuid
 import os
 
@@ -14,24 +15,25 @@ model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
 # Constants for LBW decision and video processing
 STUMPS_WIDTH = 0.2286 # meters (width of stumps)
 BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
-FRAME_RATE = 20 # Input video frame rate
+FRAME_RATE = 20 # Input video frame rate (to be updated dynamically)
 SLOW_MOTION_FACTOR = 3 # For very slow motion (3x slower)
-CONF_THRESHOLD = 0.2 # Confidence threshold
-IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
-IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
+CONF_THRESHOLD = 0.4 # Increased confidence threshold for better detection
+IMPACT_ZONE_Y = 0.8 # Adjusted fraction of frame height for impact zone
+IMPACT_VELOCITY_THRESHOLD = 1000 # Pixels/second for detecting impact
 PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
 STUMPS_HEIGHT = 0.71 # meters (stumps height)
 CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
 CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
-MAX_POSITION_JUMP = 30 # Pixels, tightened for continuous trajectory
+MAX_POSITION_JUMP = 50 # Increased for smoother trajectory
 
 def process_video(video_path):
     if not os.path.exists(video_path):
         return [], [], [], "Error: Video file not found"
     cap = cv2.VideoCapture(video_path)
-    # Get native video resolution
+    # Get native video resolution and frame rate
     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    FRAME_RATE = cap.get(cv2.CAP_PROP_FPS) or 20 # Use actual frame rate or default
     frames = []
     ball_positions = []
     detection_frames = []
@@ -44,8 +46,9 @@ def process_video(video_path):
             break
         frame_count += 1
         frames.append(frame.copy())
-        # Use smaller image size to speed up detection
-        results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(384, 640), iou=0.5, max_det=1)
+        # Enhance frame contrast for better detection
+        frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=10)
+        results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
         detections = 0
         for detection in results[0].boxes:
             if detection.cls == 0: # Class 0 is the ball
@@ -57,6 +60,8 @@ def process_video(video_path):
                 cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
         frames[-1] = frame
         debug_log.append(f"Frame {frame_count}: {detections} ball detections")
+        # Save debug frame
+        cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
     cap.release()
 
     if not ball_positions:
@@ -64,6 +69,7 @@ def process_video(video_path):
     else:
         debug_log.append(f"Total ball detections: {len(ball_positions)}")
         debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
+        debug_log.append(f"Video frame rate: {FRAME_RATE}")
 
     return frames, ball_positions, detection_frames, "\n".join(debug_log)
 
@@ -92,7 +98,6 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
             filtered_positions.append(curr_pos)
             filtered_frames.append(detection_frames[i])
         else:
-            # Skip sudden jumps to maintain continuity
             continue
 
     if len(filtered_positions) < 2:
@@ -102,21 +107,21 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
     y_coords = [pos[1] for pos in filtered_positions]
     times = np.array(filtered_frames) / FRAME_RATE
 
-    # Pitch point detection: Assume it happens when the ball reaches a certain low point on the y-axis
-    pitch_point = None
-    pitch_frame = None
-    for i in range(1, len(y_coords)):
-        if y_coords[i] > frame_height * 0.75: # The ball reaches near the ground
-            pitch_point = filtered_positions[i]
-            pitch_frame = filtered_frames[i]
-            break
+    # Convert to 3D for pitch point detection
+    detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
+
+    # Pitch point: Detection with lowest z-coordinate (closest to ground)
+    pitch_idx = min(range(len(detections_3d)), key=lambda i: detections_3d[i][2])
+    pitch_point = filtered_positions[pitch_idx]
+    pitch_frame = filtered_frames[pitch_idx]
 
-    # Impact point detection: Look for sudden changes in the y-position (delta_y) or when ball enters impact zone
+    # Impact point: Detect sudden velocity change or impact zone
     impact_idx = None
     impact_frame = None
+    velocities = [np.sqrt((x_coords[i] - x_coords[i-1])**2 + (y_coords[i] - y_coords[i-1])**2) / (times[i] - times[i-1])
+                  for i in range(1, len(x_coords))]
     for i in range(1, len(y_coords)):
-        delta_y = abs(y_coords[i] - y_coords[i-1])
-        if delta_y > IMPACT_DELTA_Y:
+        if velocities[i-1] > IMPACT_VELOCITY_THRESHOLD:
             impact_idx = i
             impact_frame = filtered_frames[i]
             break
@@ -130,52 +135,113 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
     impact_point = filtered_positions[impact_idx]
 
     try:
-        # Use cubic interpolation for smoother trajectory
-        fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
-        fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
+        # Use linear interpolation for more stable trajectory
+        fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
+        fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
     except Exception as e:
         return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
 
     # Generate dense points for all frames between first and last detection
     total_frames = max(detection_frames) - min(detection_frames) + 1
-    t_full = np.linspace(times[0], times[-1], total_frames * SLOW_MOTION_FACTOR)
+    t_full = np.linspace(times[0], times[impact_idx], total_frames * SLOW_MOTION_FACTOR)
    x_full = fx(t_full)
    y_full = fy(t_full)
    trajectory_2d = list(zip(x_full, y_full))
 
    trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
-    detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
-
-    # Handle missing pitch and impact points gracefully
-    pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width) if pitch_point else None
-    impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width) if impact_point else None
-
-    # Handle cases where no pitch/impact point is found
-    if pitch_point is None:
-        pitch_frame = "N/A"
-        pitch_point_3d = None # No 3D coordinates for pitch point
-    if impact_point is None:
-        impact_frame = "N/A"
-        impact_point_3d = None # No 3D coordinates for impact point
+    pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
+    impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
 
+    # Debug trajectory and points
     debug_log = (
         f"Trajectory estimated successfully\n"
-        f"Pitch point at frame {pitch_frame + 1 if pitch_frame != 'N/A' else 'N/A'}: {pitch_point if pitch_point else 'Not detected'}\n"
-        f"Impact point at frame {impact_frame + 1 if impact_frame != 'N/A' else 'N/A'}: {impact_point if impact_point else 'Not detected'}\n"
-        f"Detections in frames: {filtered_frames}"
+        f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}\n"
+        f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}\n"
+        f"Detections in frames: {filtered_frames}\n"
+        f"Velocities: {velocities}"
     )
+    # Save trajectory plot for debugging
+    import matplotlib.pyplot as plt
+    plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
+    plt.plot(pitch_point[0], pitch_point[1], 'ro', label='Pitch Point')
+    plt.plot(impact_point[0], impact_point[1], 'yo', label='Impact Point')
+    plt.legend()
+    plt.savefig("trajectory_debug.png")
+
     return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
 
+def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
+    """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
+    stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
+    stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
+    stump_z = [0, 0, 0]
+    stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
+    bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
+    bail_y = [PITCH_LENGTH, PITCH_LENGTH]
+    bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
+
+    stump_traces = []
+    for i in range(3):
+        stump_traces.append(go.Scatter3d(
+            x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
+            mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
+        ))
+    bail_traces = [
+        go.Scatter3d(
+            x=bail_x, y=bail_y, z=bail_z,
+            mode='lines', line=dict(color='black', width=5), name='Bail'
+        )
+    ]
+
+    pitch_scatter = go.Scatter3d(
+        x=[pitch_point_3d[0]] if pitch_point_3d else [],
+        y=[pitch_point_3d[1]] if pitch_point_3d else [],
+        z=[pitch_point_3d[2]] if pitch_point_3d else [],
+        mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
+    )
+    impact_scatter = go.Scatter3d(
+        x=[impact_point_3d[0]] if impact_point_3d else [],
+        y=[impact_point_3d[1]] if impact_point_3d else [],
+        z=[impact_point_3d[2]] if impact_point_3d else [],
+        mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
+    )
+
+    if plot_type == "detections":
+        x, y, z = zip(*detections_3d) if detections_3d else ([], [], [])
+        scatter = go.Scatter3d(
+            x=x, y=y, z=z, mode='markers',
+            marker=dict(size=5, color='green'), name='Single Ball Detections'
+        )
+        data = [scatter, pitch_scatter, impact_scatter] + stump_traces + bail_traces
+        title = "3D Single Ball Detections"
+    else:
+        x, y, z = zip(*trajectory_3d) if trajectory_3d else ([], [], [])
+        trajectory_line = go.Scatter3d(
+            x=x, y=y, z=z, mode='lines',
+            line=dict(color='blue', width=4), name='Ball Trajectory (Single Detections)'
+        )
+        data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
+        title = "3D Ball Trajectory (Single Detections)"
+
+    layout = go.Layout(
+        title=title,
+        scene=dict(
+            xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
+            xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
+            zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
+            aspectratio=dict(x=1, y=4, z=0.5)
+        ),
+        showlegend=True
+    )
+    fig = go.Figure(data=data, layout=layout)
+    return fig
+
 def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
     if not frames:
         return "Error: No frames processed", None, None, None
     if not trajectory or len(ball_positions) < 2:
         return "Not enough data (insufficient ball detections)", None, None, None
 
-    # Check for None values before unpacking
-    if pitch_point is None or impact_point is None:
-        return "Not Out (Unable to determine pitch or impact points)", trajectory, pitch_point, impact_point
-
     frame_height, frame_width = frames[0].shape[:2]
     stumps_x = frame_width / 2
     stumps_y = frame_height * 0.9
@@ -200,7 +266,6 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
     out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
 
-    # Map trajectory points to all frames between first and last detection
     if trajectory and detection_frames:
         min_frame = min(detection_frames)
         max_frame = max(detection_frames)
@@ -215,7 +280,6 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
     for i, frame in enumerate(frames):
         frame_idx = i - min_frame if trajectory_indices else -1
         if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
-            # Draw trajectory up to current frame
             end_idx = trajectory_indices[frame_idx] + 1
             cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
         if pitch_point and i == pitch_frame:
@@ -248,9 +312,17 @@ def drs_review(video):
     output_path = f"output_{uuid.uuid4()}.mp4"
     slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
 
+    detections_fig = None
+    trajectory_fig = None
+    if detections_3d:
+        detections_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "detections")
+        trajectory_fig = create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "trajectory")
+
     debug_output = f"{debug_log}\n{trajectory_log}"
     return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
-            slow_motion_path)
+            slow_motion_path,
+            detections_fig,
+            trajectory_fig)
 
 # Gradio interface
 iface = gr.Interface(
@@ -259,10 +331,12 @@ iface = gr.Interface(
     outputs=[
         gr.Textbox(label="DRS Decision and Debug Log"),
         gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
+        gr.Plot(label="3D Single Ball Detections Plot"),
+        gr.Plot(label="3D Ball Trajectory Plot (Single Detections)")
     ],
     title="AI-Powered DRS for LBW in Local Cricket",
-    description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."
+    description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show single-detection frames (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow)."
 )
 
 if __name__ == "__main__":
-    iface.launch()
+    iface.launch()
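A few of the changes above are worth illustrating in isolation. First, the frame rate: `FRAME_RATE = cap.get(cv2.CAP_PROP_FPS) or 20` inside `process_video` rebinds a function-local name, so the module-level `FRAME_RATE` read by `estimate_trajectory` and `generate_slow_motion` stays at 20. A minimal sketch of reading the FPS once and passing it around explicitly, assuming only OpenCV; `get_frame_rate` and `sample.mp4` are hypothetical names, not part of this repo:

# Sketch: read the container-reported FPS and return it, instead of rebinding
# a module-level constant inside a function (which only creates a local).
# "sample.mp4" is a placeholder path.
import cv2

def get_frame_rate(video_path, default_fps=20.0):
    """Return the FPS reported by the container, or default_fps if unavailable."""
    cap = cv2.VideoCapture(video_path)
    fps = cap.get(cv2.CAP_PROP_FPS)
    cap.release()
    return fps if fps and fps > 0 else default_fps  # some containers report 0

if __name__ == "__main__":
    print(get_frame_rate("sample.mp4"))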
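Next, the new pitch-point rule: it simply picks the detection whose 3D z-coordinate is lowest. `pixel_to_3d` is defined elsewhere in app.py and is not shown in this diff, so the sketch below uses a hypothetical `to_3d` stand-in purely to exercise the selection logic:

# Sketch of the lowest-z pitch-point selection. to_3d is a made-up stand-in
# for pixel_to_3d: it tags each pixel with a placeholder height (z) value
# where lower in the frame means closer to the ground.
def to_3d(x, y):
    return (x, y, 100.0 - 0.1 * y)  # placeholder mapping, not the real one

filtered_positions = [(300, 200), (310, 400), (320, 350)]
detections_3d = [to_3d(x, y) for x, y in filtered_positions]
pitch_idx = min(range(len(detections_3d)), key=lambda i: detections_3d[i][2])
print(pitch_idx, filtered_positions[pitch_idx])  # -> 1 (310, 400), nearest the ground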
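The impact test now thresholds frame-to-frame speed rather than a raw y-jump. A self-contained sketch of the same idea on synthetic pixel positions, assuming a 20 fps clip and the commit's 1000 px/s threshold; `find_impact_index` is a hypothetical helper name:

# Sketch: flag the first detection whose speed from the previous detection
# exceeds IMPACT_VELOCITY_THRESHOLD (pixels/second), mirroring the commit.
import numpy as np

IMPACT_VELOCITY_THRESHOLD = 1000  # pixels/second, as in the commit

def find_impact_index(positions, fps):
    """Return the first index whose frame-to-frame speed exceeds the threshold."""
    pts = np.asarray(positions, dtype=float)
    dt = 1.0 / fps
    speeds = np.linalg.norm(np.diff(pts, axis=0), axis=1) / dt
    over = np.nonzero(speeds > IMPACT_VELOCITY_THRESHOLD)[0]
    return int(over[0]) + 1 if over.size else None

positions = [(100, 50), (110, 80), (120, 110), (128, 250)]  # last step is a jump
print(find_impact_index(positions, fps=20))  # -> 3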
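The trajectory fit switches from cubic to linear interpolation, keeping fill_value="extrapolate" so the fit can also be evaluated past the last detection. A toy sketch with made-up times and x-coordinates, using the same scipy call as the commit:

# Sketch: linear interp1d fit with extrapolation enabled, on toy data.
import numpy as np
from scipy.interpolate import interp1d

times = np.array([0.00, 0.05, 0.10, 0.15])       # seconds, placeholder values
x_coords = np.array([100.0, 130.0, 160.0, 190.0])  # pixels, placeholder values

fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
t_dense = np.linspace(times[0], times[-1], 12)  # dense resample, as for the replay
print(fx(t_dense)[:3])
print(fx(0.20))  # evaluated beyond the last detection -> 220.0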
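Finally, `drs_review` now returns four values, so the Gradio outputs list gains two `gr.Plot` components. A minimal runnable sketch of that wiring, with a dummy handler standing in for the real pipeline; `review` and its return values are placeholders:

# Sketch: a 4-output gr.Interface matching the new outputs list above.
import gradio as gr
import plotly.graph_objects as go

def review(video):
    # Dummy figure and decision; the real app builds these from detections.
    fig = go.Figure(data=[go.Scatter3d(x=[0], y=[0], z=[0], mode='markers')])
    return "DRS Decision: Not Out (demo)", video, fig, fig

demo = gr.Interface(
    fn=review,
    inputs=gr.Video(label="Delivery clip"),
    outputs=[
        gr.Textbox(label="DRS Decision and Debug Log"),
        gr.Video(label="Slow-Motion Replay"),
        gr.Plot(label="3D Detections"),
        gr.Plot(label="3D Trajectory"),
    ],
)

if __name__ == "__main__":
    demo.launch()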