dschandra commited on
Commit
d179a4e
·
verified ·
1 Parent(s): 573c1ef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -23
app.py CHANGED
@@ -4,7 +4,7 @@ import numpy as np
4
  import torch
5
  from ultralytics import YOLO
6
  import gradio as gr
7
- from scipy.interpolate import interp1d
8
  import uuid
9
  import os
10
 
@@ -16,9 +16,9 @@ STUMPS_WIDTH = 0.2286 # meters (width of stumps)
16
  FRAME_RATE = 20 # Input video frame rate
17
  SLOW_MOTION_FACTOR = 2 # Reduced for faster output
18
  CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
- PITCH_ZONE_Y = 0.9 # Fraction of frame height for pitch zone
20
- IMPACT_ZONE_Y = 0.8 # Fraction of frame height for impact zone
21
- IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
22
  STUMPS_HEIGHT = 0.711 # meters (height of stumps)
23
 
24
  def process_video(video_path):
@@ -61,7 +61,7 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
61
  return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
62
  frame_height = frames[0].shape[0]
63
 
64
- # Filter to unique positions to reduce interpolation points
65
  unique_positions = [ball_positions[0]]
66
  for pos in ball_positions[1:]:
67
  if abs(pos[0] - unique_positions[-1][0]) > 10 or abs(pos[1] - unique_positions[-1][1]) > 10:
@@ -70,6 +70,10 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
70
  y_coords = [pos[1] for pos in unique_positions]
71
  times = np.array([i / FRAME_RATE for i in range(len(unique_positions))])
72
 
 
 
 
 
73
  pitch_idx = 0
74
  for i, y in enumerate(y_coords):
75
  if y > frame_height * PITCH_ZONE_Y:
@@ -94,12 +98,12 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
94
  times = times[:impact_idx + 1]
95
 
96
  try:
97
- fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
98
- fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
99
  except Exception as e:
100
  return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
101
 
102
- vis_trajectory = list(zip(x_coords, y_coords))
103
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 5)
104
  x_full = fx(t_full)
105
  y_full = fy(t_full)
@@ -118,9 +122,9 @@ def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_po
118
 
119
  frame_height, frame_width = frames[0].shape[:2]
120
  stumps_x = frame_width / 2
121
- stumps_y = frame_height * 0.9
122
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
123
- batsman_area_y = frame_height * 0.8
124
 
125
  pitch_x, pitch_y = pitch_point
126
  impact_x, impact_y = impact_point
@@ -150,7 +154,7 @@ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impac
150
  return None
151
  frame_height, frame_width = frames[0].shape[:2]
152
  stumps_x = frame_width / 2
153
- stumps_y = frame_height * 0.9
154
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
155
  stumps_height_pixels = frame_height * (STUMPS_HEIGHT / 3.0)
156
 
@@ -160,33 +164,38 @@ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impac
160
  trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
161
 
162
  for i, frame in enumerate(frames):
163
- # Draw stumps (single line for efficiency)
164
  cv2.line(frame, (int(stumps_x - stumps_width_pixels / 2), int(stumps_y)),
165
  (int(stumps_x + stumps_width_pixels / 2), int(stumps_y)), (255, 255, 255), 2)
 
 
 
 
166
 
167
- # Draw crease line
168
- cv2.line(frame, (0, int(stumps_y)), (frame_width, int(stumps_y)), (255, 255, 0), 2)
 
169
 
170
  if i in detection_frames and trajectory_points.size > 0:
171
  idx = detection_frames.index(i) + 1
172
  if idx <= len(trajectory_points):
173
- cv2.polylines(frame, [trajectory_points[:idx]], False, (255, 0, 0), 2)
174
 
175
  if pitch_point and i == pitch_frame:
176
  x, y = pitch_point
177
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 0), -1)
178
- cv2.putText(frame, "Pitching Factor", (int(x) + 10, int(y) - 10),
179
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
180
 
181
  if impact_point and i == impact_frame:
182
  x, y = impact_point
183
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
184
- cv2.putText(frame, "Impact Factor", (int(x) + 10, int(y) + 20),
185
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
186
 
187
  if impact_point and i == impact_frame and "Out" in decision:
188
- cv2.putText(frame, "Wicket Factor", (int(stumps_x) - 50, int(stumps_y) - 20),
189
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 1)
190
 
191
  for _ in range(SLOW_MOTION_FACTOR):
192
  out.write(frame)
@@ -212,10 +221,10 @@ iface = gr.Interface(
212
  inputs=gr.Video(label="Upload Video Clip"),
213
  outputs=[
214
  gr.Textbox(label="DRS Decision and Debug Log"),
215
- gr.Video(label="Optimized Slow-Motion Replay with Pitching Factor (Green), Impact Factor (Red), Wicket Factor (Orange), Stumps (White), Crease (Yellow)")
216
  ],
217
  title="AI-Powered DRS for LBW in Local Cricket",
218
- description="Upload a video clip of a cricket delivery to get an LBW decision and optimized slow-motion replay showing pitching factor (green circle), impact factor (red circle), wicket factor (orange text), stumps (white lines), and crease line (yellow line)."
219
  )
220
 
221
  if __name__ == "__main__":
 
4
  import torch
5
  from ultralytics import YOLO
6
  import gradio as gr
7
+ from scipy.interpolate import interp1d, UnivariateSpline
8
  import uuid
9
  import os
10
 
 
16
  FRAME_RATE = 20 # Input video frame rate
17
  SLOW_MOTION_FACTOR = 2 # Reduced for faster output
18
  CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
+ PITCH_ZONE_Y = 0.85 # Adjusted for pitch near stumps
20
+ IMPACT_ZONE_Y = 0.75 # Adjusted for impact near batsman leg
21
+ IMPACT_DELTA_Y = 30 # Reduced for finer impact detection
22
  STUMPS_HEIGHT = 0.711 # meters (height of stumps)
23
 
24
  def process_video(video_path):
 
61
  return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
62
  frame_height = frames[0].shape[0]
63
 
64
+ # Filter to unique positions
65
  unique_positions = [ball_positions[0]]
66
  for pos in ball_positions[1:]:
67
  if abs(pos[0] - unique_positions[-1][0]) > 10 or abs(pos[1] - unique_positions[-1][1]) > 10:
 
70
  y_coords = [pos[1] for pos in unique_positions]
71
  times = np.array([i / FRAME_RATE for i in range(len(unique_positions))])
72
 
73
+ # Smooth coordinates with spline interpolation
74
+ x_smooth = UnivariateSpline(times, x_coords, s=10)
75
+ y_smooth = UnivariateSpline(times, y_coords, s=10)
76
+
77
  pitch_idx = 0
78
  for i, y in enumerate(y_coords):
79
  if y > frame_height * PITCH_ZONE_Y:
 
98
  times = times[:impact_idx + 1]
99
 
100
  try:
101
+ fx = interp1d(times, x_smooth(times), kind='linear', fill_value="extrapolate")
102
+ fy = interp1d(times, y_smooth(times), kind='quadratic', fill_value="extrapolate")
103
  except Exception as e:
104
  return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
105
 
106
+ vis_trajectory = list(zip(x_smooth(times), y_smooth(times)))
107
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 5)
108
  x_full = fx(t_full)
109
  y_full = fy(t_full)
 
122
 
123
  frame_height, frame_width = frames[0].shape[:2]
124
  stumps_x = frame_width / 2
125
+ stumps_y = frame_height * 0.85 # Adjusted to align with pitch
126
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
127
+ batsman_area_y = frame_height * 0.75
128
 
129
  pitch_x, pitch_y = pitch_point
130
  impact_x, impact_y = impact_point
 
154
  return None
155
  frame_height, frame_width = frames[0].shape[:2]
156
  stumps_x = frame_width / 2
157
+ stumps_y = frame_height * 0.85 # Align with pitch
158
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
159
  stumps_height_pixels = frame_height * (STUMPS_HEIGHT / 3.0)
160
 
 
164
  trajectory_points = np.array(vis_trajectory, dtype=np.int32).reshape((-1, 1, 2))
165
 
166
  for i, frame in enumerate(frames):
167
+ # Draw stumps outline
168
  cv2.line(frame, (int(stumps_x - stumps_width_pixels / 2), int(stumps_y)),
169
  (int(stumps_x + stumps_width_pixels / 2), int(stumps_y)), (255, 255, 255), 2)
170
+ cv2.line(frame, (int(stumps_x - stumps_width_pixels / 2), int(stumps_y - stumps_height_pixels)),
171
+ (int(stumps_x - stumps_width_pixels / 2), int(stumps_y)), (255, 255, 255), 2)
172
+ cv2.line(frame, (int(stumps_x + stumps_width_pixels / 2), int(stumps_y - stumps_height_pixels)),
173
+ (int(stumps_x + stumps_width_pixels / 2), int(stumps_y)), (255, 255, 255), 2)
174
 
175
+ # Draw crease line at stumps
176
+ cv2.line(frame, (int(stumps_x - stumps_width_pixels / 2), int(stumps_y)),
177
+ (int(stumps_x + stumps_width_pixels / 2), int(stumps_y)), (255, 255, 0), 2)
178
 
179
  if i in detection_frames and trajectory_points.size > 0:
180
  idx = detection_frames.index(i) + 1
181
  if idx <= len(trajectory_points):
182
+ cv2.polylines(frame, [trajectory_points[:idx]], False, (0, 0, 255), 2) # Blue trajectory
183
 
184
  if pitch_point and i == pitch_frame:
185
  x, y = pitch_point
186
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 0), -1) # Green for pitching
187
+ cv2.putText(frame, "Pitching", (int(x) + 10, int(y) - 10),
188
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
189
 
190
  if impact_point and i == impact_frame:
191
  x, y = impact_point
192
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red for impact
193
+ cv2.putText(frame, "Impact", (int(x) + 10, int(y) + 20),
194
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)
195
 
196
  if impact_point and i == impact_frame and "Out" in decision:
197
+ cv2.putText(frame, "Wickets", (int(stumps_x) - 50, int(stumps_y) - 20),
198
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 165, 255), 1) # Orange for wickets
199
 
200
  for _ in range(SLOW_MOTION_FACTOR):
201
  out.write(frame)
 
221
  inputs=gr.Video(label="Upload Video Clip"),
222
  outputs=[
223
  gr.Textbox(label="DRS Decision and Debug Log"),
224
+ gr.Video(label="Optimized Slow-Motion Replay with Pitching (Green), Impact (Red), Wickets (Orange), Stumps (White), Crease (Yellow)")
225
  ],
226
  title="AI-Powered DRS for LBW in Local Cricket",
227
+ description="Upload a video clip of a cricket delivery to get an LBW decision and optimized slow-motion replay showing pitching (green circle), impact (red circle), wickets (orange text), stumps (white outline), and crease line (yellow line)."
228
  )
229
 
230
  if __name__ == "__main__":