AjaykumarPilla commited on
Commit
c0552e3
·
verified ·
1 Parent(s): 4057582

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -78
app.py CHANGED
@@ -13,12 +13,10 @@ model = YOLO("best.pt")
13
  # Constants for LBW decision and video processing
14
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
- FRAME_RATE = 20 # Input video frame rate
17
- SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS
18
  CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
- IMPACT_ZONE_Y = 0.85 # Fraction of frame height for impact zone
20
- PITCH_ZONE_Y = 0.75 # Fraction of frame height for pitch zone
21
- IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
22
 
23
  def process_video(video_path):
24
  if not os.path.exists(video_path):
@@ -26,7 +24,7 @@ def process_video(video_path):
26
  cap = cv2.VideoCapture(video_path)
27
  frames = []
28
  ball_positions = []
29
- detection_frames = [] # Track frames with exactly one detection
30
  debug_log = []
31
 
32
  frame_count = 0
@@ -37,55 +35,55 @@ def process_video(video_path):
37
  frame_count += 1
38
  frames.append(frame.copy())
39
  results = model.predict(frame, conf=CONF_THRESHOLD)
40
- detections = [det for det in results[0].boxes if det.cls == 0] # Class 0 is cricketBall
41
- if len(detections) == 1: # Only consider frames with exactly one detection
42
- x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
43
- ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
44
- detection_frames.append(frame_count - 1) # 0-based index
45
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
 
 
46
  frames[-1] = frame
47
- debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections")
48
  cap.release()
49
 
50
  if not ball_positions:
51
- debug_log.append("No valid single-ball detections in any frame")
52
  else:
53
- debug_log.append(f"Total valid single-ball detections: {len(ball_positions)}")
54
 
55
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
56
 
57
- def estimate_trajectory(ball_positions, detection_frames, frames):
58
  if len(ball_positions) < 2:
59
- return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
 
60
  frame_height = frames[0].shape[0]
61
-
62
  # Extract x, y coordinates
63
  x_coords = [pos[0] for pos in ball_positions]
64
  y_coords = [pos[1] for pos in ball_positions]
65
- times = np.array(detection_frames) / FRAME_RATE
66
 
67
- # Pitch point: first valid detection or when y exceeds PITCH_ZONE_Y
68
- pitch_idx = 0
69
  for i, y in enumerate(y_coords):
70
- if y > frame_height * PITCH_ZONE_Y:
71
- pitch_idx = i
72
  break
73
- pitch_point = ball_positions[pitch_idx]
74
- pitch_frame = detection_frames[pitch_idx]
75
-
76
- # Impact point: sudden y-change or y exceeds IMPACT_ZONE_Y
77
  impact_idx = None
78
- for i in range(1, len(y_coords)):
79
- if (y_coords[i] > frame_height * IMPACT_ZONE_Y or
80
- abs(y_coords[i] - y_coords[i-1]) > IMPACT_DELTA_Y):
81
  impact_idx = i
82
  break
83
  if impact_idx is None:
84
- impact_idx = len(ball_positions) - 1
 
85
  impact_point = ball_positions[impact_idx]
86
- impact_frame = detection_frames[impact_idx]
87
 
88
- # Use only detected positions for trajectory
89
  x_coords = x_coords[:impact_idx + 1]
90
  y_coords = y_coords[:impact_idx + 1]
91
  times = times[:impact_idx + 1]
@@ -94,111 +92,118 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
94
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
95
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
96
  except Exception as e:
97
- return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
98
 
99
- # Trajectory for visualization (detected frames only)
100
- vis_trajectory = list(zip(x_coords, y_coords))
101
-
102
- # Full trajectory for LBW (includes projection)
103
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
104
  x_full = fx(t_full)
105
  y_full = fy(t_full)
106
- full_trajectory = list(zip(x_full, y_full))
107
 
108
- debug_log = (f"Trajectory estimated successfully\n"
109
- f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})\n"
110
- f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})")
111
- return full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, debug_log
112
 
113
- def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point):
114
  if not frames:
115
  return "Error: No frames processed", None, None, None
116
- if not full_trajectory or len(ball_positions) < 2:
117
- return "Not enough data (insufficient valid single-ball detections)", None, None, None
118
 
119
  frame_height, frame_width = frames[0].shape[:2]
120
  stumps_x = frame_width / 2
121
- stumps_y = frame_height * 0.9
122
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
123
 
124
  pitch_x, pitch_y = pitch_point
125
  impact_x, impact_y = impact_point
126
 
127
- # Check pitching point
128
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
129
- return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", full_trajectory, pitch_point, impact_point
130
 
131
- # Check impact point
132
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
133
- return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
134
 
135
  # Check trajectory hitting stumps
136
- for x, y in full_trajectory:
137
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
138
- return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
139
- return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", full_trajectory, pitch_point, impact_point
 
140
 
141
- def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path):
142
  if not frames:
143
  return None
144
- frame_height, frame_width = frames[0].shape[:2]
145
-
146
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
147
- out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
 
 
148
 
149
- # Prepare trajectory points for visualization
150
- trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
151
 
152
  for i, frame in enumerate(frames):
153
- # Draw trajectory (blue line) only for detected frames
154
  if i in detection_frames and trajectory_points.size > 0:
155
- idx = detection_frames.index(i) + 1
156
- if idx <= len(trajectory_points):
157
- cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
158
 
159
- # Draw pitch point (red circle) only in pitch frame
160
- if pitch_point and i == pitch_frame:
161
  x, y = pitch_point
 
 
 
162
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
163
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
164
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
165
 
166
- # Draw impact point (yellow circle) only in impact frame
167
- if impact_point and i == impact_frame:
168
  x, y = impact_point
 
 
 
169
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
170
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
171
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
172
 
 
 
 
 
 
 
 
 
173
  for _ in range(SLOW_MOTION_FACTOR):
174
  out.write(frame)
 
175
  out.release()
176
  return output_path
177
 
178
  def drs_review(video):
179
  frames, ball_positions, detection_frames, debug_log = process_video(video)
180
  if not frames:
181
- return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
182
- full_trajectory, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, trajectory_log = estimate_trajectory(ball_positions, detection_frames, frames)
183
- decision, full_trajectory, pitch_point, impact_point = lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_point)
184
 
185
  output_path = f"output_{uuid.uuid4()}.mp4"
186
- slow_motion_path = generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impact_point, impact_frame, detection_frames, output_path)
187
 
188
- debug_output = f"{debug_log}\n{trajectory_log}"
189
- return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path
190
 
191
  # Gradio interface
192
  iface = gr.Interface(
193
  fn=drs_review,
194
  inputs=gr.Video(label="Upload Video Clip"),
195
  outputs=[
196
- gr.Textbox(label="DRS Decision and Debug Log"),
197
- gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)")
198
  ],
199
  title="AI-Powered DRS for LBW in Local Cricket",
200
- description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle)."
201
  )
202
 
203
  if __name__ == "__main__":
204
- iface.launch()
 
13
  # Constants for LBW decision and video processing
14
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
15
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
16
+ FRAME_RATE = 20 # Input video frame rate (reduced to 20 FPS)
17
+ SLOW_MOTION_FACTOR = 3 # Adjusted for 20 FPS (slower playback without being too slow)
18
  CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
+ IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely (near stumps)
 
 
20
 
21
  def process_video(video_path):
22
  if not os.path.exists(video_path):
 
24
  cap = cv2.VideoCapture(video_path)
25
  frames = []
26
  ball_positions = []
27
+ detection_frames = [] # Track frames with detections
28
  debug_log = []
29
 
30
  frame_count = 0
 
35
  frame_count += 1
36
  frames.append(frame.copy())
37
  results = model.predict(frame, conf=CONF_THRESHOLD)
38
+ detections = 0
39
+ for detection in results[0].boxes:
40
+ if detection.cls == 0: # Assuming class 0 is the ball
41
+ detections += 1
42
+ x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
43
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
44
+ detection_frames.append(frame_count - 1) # Store frame index (0-based)
45
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
46
  frames[-1] = frame
47
+ debug_log.append(f"Frame {frame_count}: {detections} ball detections")
48
  cap.release()
49
 
50
  if not ball_positions:
51
+ debug_log.append("No balls detected in any frame")
52
  else:
53
+ debug_log.append(f"Total ball detections: {len(ball_positions)}")
54
 
55
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
56
 
57
+ def estimate_trajectory(ball_positions, frames):
58
  if len(ball_positions) < 2:
59
+ return None, None, None, "Error: Fewer than 2 ball detections for trajectory"
60
+
61
  frame_height = frames[0].shape[0]
62
+
63
  # Extract x, y coordinates
64
  x_coords = [pos[0] for pos in ball_positions]
65
  y_coords = [pos[1] for pos in ball_positions]
66
+ times = np.arange(len(ball_positions)) / FRAME_RATE
67
 
68
+ # Detect the pitch point: find when the ball touches the ground
69
+ pitch_point = None
70
  for i, y in enumerate(y_coords):
71
+ if y > frame_height * 0.75: # Threshold for ground contact (near the bottom of the frame)
72
+ pitch_point = ball_positions[i]
73
  break
74
+
75
+ # Find impact point (closest to batsman, near stumps)
 
 
76
  impact_idx = None
77
+ for i, y in enumerate(y_coords):
78
+ if y > frame_height * IMPACT_ZONE_Y: # Ball is near stumps/batsman
 
79
  impact_idx = i
80
  break
81
  if impact_idx is None:
82
+ impact_idx = len(ball_positions) - 1 # Fallback to last detection
83
+
84
  impact_point = ball_positions[impact_idx]
 
85
 
86
+ # Use positions up to impact for interpolation
87
  x_coords = x_coords[:impact_idx + 1]
88
  y_coords = y_coords[:impact_idx + 1]
89
  times = times[:impact_idx + 1]
 
92
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
93
  fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
94
  except Exception as e:
95
+ return None, None, None, f"Error in trajectory interpolation: {str(e)}"
96
 
97
+ # Project trajectory (detected + future for LBW decision)
 
 
 
98
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 10)
99
  x_full = fx(t_full)
100
  y_full = fy(t_full)
101
+ trajectory = list(zip(x_full, y_full))
102
 
103
+ return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
 
 
 
104
 
105
+ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
106
  if not frames:
107
  return "Error: No frames processed", None, None, None
108
+ if not trajectory or len(ball_positions) < 2:
109
+ return "Not enough data (insufficient ball detections)", None, None, None
110
 
111
  frame_height, frame_width = frames[0].shape[:2]
112
  stumps_x = frame_width / 2
113
+ stumps_y = frame_height * 0.9 # Position of the stumps at the bottom of the frame
114
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
115
 
116
  pitch_x, pitch_y = pitch_point
117
  impact_x, impact_y = impact_point
118
 
119
+ # Check pitching point - the ball should land between stumps
120
  if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
121
+ return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
122
 
123
+ # Check impact point - the ball should hit within the stumps area
124
  if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
125
+ return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
126
 
127
  # Check trajectory hitting stumps
128
+ for x, y in trajectory:
129
  if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
130
+ return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
131
+
132
+ return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
133
 
134
+ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, output_path):
135
  if not frames:
136
  return None
 
 
137
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
138
+ out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
139
+
140
+ trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
141
 
142
+ pitch_point_detected = False
143
+ impact_point_detected = False
144
 
145
  for i, frame in enumerate(frames):
146
+ # Draw trajectory (blue line) only for frames with detections
147
  if i in detection_frames and trajectory_points.size > 0:
148
+ cv2.polylines(frame, [trajectory_points[:detection_frames.index(i) + 1]], False, (255, 0, 0), 2)
 
 
149
 
150
+ # Draw pitch point (red circle with label) when the ball touches the ground
151
+ if pitch_point and not pitch_point_detected:
152
  x, y = pitch_point
153
+ if y > frame.shape[0] * 0.75: # Adjust this threshold for the ground position
154
+ pitch_point_detected = True
155
+ if pitch_point_detected:
156
  cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
157
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
158
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
159
 
160
+ # Draw impact point (yellow circle with label) when ball is near stumps
161
+ if impact_point and not impact_point_detected:
162
  x, y = impact_point
163
+ if y > frame.shape[0] * 0.85: # Adjust this threshold for impact point
164
+ impact_point_detected = True
165
+ if impact_point_detected:
166
  cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
167
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
168
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
169
 
170
+ # Add wicket lines for the stumps
171
+ stumps_x = frame.shape[1] // 2
172
+ stumps_y = frame.shape[0] * 0.9
173
+ stumps_width = frame.shape[1] * 0.1
174
+ cv2.line(frame, (int(stumps_x - stumps_width / 2), int(stumps_y)),
175
+ (int(stumps_x + stumps_width / 2), int(stumps_y)), (0, 255, 0), 3)
176
+
177
+ # Write frames to output video
178
  for _ in range(SLOW_MOTION_FACTOR):
179
  out.write(frame)
180
+
181
  out.release()
182
  return output_path
183
 
184
  def drs_review(video):
185
  frames, ball_positions, detection_frames, debug_log = process_video(video)
186
  if not frames:
187
+ return f"Error: Failed to process video", None
188
+ trajectory, pitch_point, impact_point, trajectory_log = estimate_trajectory(ball_positions, frames)
189
+ decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
190
 
191
  output_path = f"output_{uuid.uuid4()}.mp4"
192
+ slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, output_path)
193
 
194
+ return f"DRS Decision: {decision}", slow_motion_path
 
195
 
196
  # Gradio interface
197
  iface = gr.Interface(
198
  fn=drs_review,
199
  inputs=gr.Video(label="Upload Video Clip"),
200
  outputs=[
201
+ gr.Textbox(label="DRS Decision"),
202
+ gr.Video(label="Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow), Wicket Lines")
203
  ],
204
  title="AI-Powered DRS for LBW in Local Cricket",
205
+ description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection (green boxes), trajectory (blue line), pitch point (red circle), impact point (yellow circle), and wicket lines."
206
  )
207
 
208
  if __name__ == "__main__":
209
+ iface.launch()