AjaykumarPilla commited on
Commit
b267b22
·
verified ·
1 Parent(s): 7e4d469

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -73
app.py CHANGED
@@ -4,7 +4,6 @@ import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
7
- import plotly.graph_objects as go
8
  import uuid
9
  import os
10
  import tempfile
@@ -13,7 +12,7 @@ import tempfile
13
  model = YOLO("best.pt")
14
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
15
 
16
- # Dynamically resolve ball class index
17
  ball_class_index = None
18
  for k, v in model.names.items():
19
  if v.lower() == "cricketball":
@@ -34,6 +33,26 @@ PITCH_LENGTH = 20.12
34
  STUMPS_HEIGHT = 0.71
35
  MAX_POSITION_JUMP = 30
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def process_video(video_path):
38
  if not os.path.exists(video_path):
39
  return [], [], [], "Error: Video file not found"
@@ -42,7 +61,6 @@ def process_video(video_path):
42
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
43
  frames, ball_positions, detection_frames, debug_log = [], [], [], []
44
  frame_count = 0
45
-
46
  while cap.isOpened():
47
  ret, frame = cap.read()
48
  if not ret:
@@ -62,23 +80,11 @@ def process_video(video_path):
62
  frames[-1] = frame
63
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
64
  cap.release()
65
-
66
- if not ball_positions:
67
- debug_log.append("No balls detected in any frame")
68
- else:
69
- debug_log.append(f"Total ball detections: {len(ball_positions)}")
70
- debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
71
-
72
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
73
 
74
  def find_bounce_point(ball_coords):
75
- """
76
- Detect bounce point using y-derivative reversal with early-frame suppression.
77
- Looks for where y increases then decreases (ball hits ground).
78
- """
79
  y_coords = [p[1] for p in ball_coords]
80
  min_index = None
81
-
82
  for i in range(2, len(y_coords) - 2):
83
  dy1 = y_coords[i] - y_coords[i - 1]
84
  dy2 = y_coords[i + 1] - y_coords[i]
@@ -86,89 +92,70 @@ def find_bounce_point(ball_coords):
86
  if i > len(y_coords) * 0.2:
87
  min_index = i
88
  break
89
-
90
- if min_index is not None:
91
- return ball_coords[min_index]
92
-
93
- return ball_coords[len(ball_coords)//2]
94
-
95
- def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
96
- if not frames or not trajectory or len(ball_positions) < 2:
97
- return "Not enough data", trajectory, pitch_point, impact_point
98
-
99
- frame_height, frame_width = frames[0].shape[:2]
100
- stumps_x = frame_width / 2
101
- stumps_y = frame_height * 0.9
102
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
103
-
104
- pitch_x, _ = pitch_point
105
- impact_x, impact_y = impact_point
106
-
107
- if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
108
- return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
109
- if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
110
- return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
111
- for x, y in trajectory:
112
- if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
113
- return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
114
- return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
115
 
116
  def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
117
  if len(ball_positions) < 2:
118
  return None, None, None, "Error: Not enough ball detections"
119
-
120
  filtered_positions = [ball_positions[0]]
121
  filtered_frames = [detection_frames[0]]
122
  for i in range(1, len(ball_positions)):
123
- prev, curr = filtered_positions[-1], ball_positions[i]
124
- if np.linalg.norm(np.array(curr) - np.array(prev)) <= MAX_POSITION_JUMP:
125
- filtered_positions.append(curr)
126
  filtered_frames.append(detection_frames[i])
127
-
128
  if len(filtered_positions) < 2:
129
  return None, None, None, "Error: Filtered detections too few"
130
-
131
  x_vals = [p[0] for p in filtered_positions]
132
  y_vals = [p[1] for p in filtered_positions]
133
  times = np.array(filtered_frames) / FRAME_RATE
134
-
135
  try:
136
  fx = interp1d(times, x_vals, kind='cubic', fill_value="extrapolate")
137
  fy = interp1d(times, y_vals, kind='cubic', fill_value="extrapolate")
138
  except Exception as e:
139
  return None, None, None, f"Interpolation error: {str(e)}"
140
-
141
- total_frames = max(filtered_frames) - min(filtered_frames) + 1
142
- t_full = np.linspace(times[0], times[-1], max(5, total_frames * SLOW_MOTION_FACTOR))
143
- x_full = fx(t_full)
144
- y_full = fy(t_full)
145
  trajectory = list(zip(x_full, y_full))
146
-
147
  pitch_point = find_bounce_point(filtered_positions)
148
  impact_point = filtered_positions[-1]
149
-
150
  return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
153
  if not frames or not trajectory:
154
  return None
155
-
156
- temp_file = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
157
  height, width = frames[0].shape[:2]
158
- out = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
159
-
160
- min_frame = min(detection_frames)
161
- max_frame = max(detection_frames)
162
  total_frames = max_frame - min_frame + 1
163
- traj_per_frame = max(1, len(trajectory) // total_frames)
164
- indices = [min(i * traj_per_frame, len(trajectory)-1) for i in range(total_frames)]
165
-
166
  for i, frame in enumerate(frames):
 
167
  idx = i - min_frame
168
  if 0 <= idx < len(indices):
169
  end_idx = indices[idx]
170
- points = np.array(trajectory[:end_idx+1], dtype=np.int32).reshape((-1, 1, 2))
171
- cv2.polylines(frame, [points], False, (255, 0, 0), 2)
172
  if pitch_point and i == detection_frames[0]:
173
  cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
174
  if impact_point and i == detection_frames[-1]:
@@ -176,21 +163,18 @@ def generate_replay(frames, trajectory, pitch_point, impact_point, detection_fra
176
  for _ in range(SLOW_MOTION_FACTOR):
177
  out.write(frame)
178
  out.release()
179
- return temp_file
180
 
181
  def drs_review(video):
182
  frames, ball_positions, detection_frames, debug_log = process_video(video)
183
  if not frames or not ball_positions:
184
  return "No frames or detections found.", None
185
-
186
  frame_height, frame_width = frames[0].shape[:2]
187
  trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
188
  if not trajectory:
189
  return f"{log}\n{debug_log}", None
190
-
191
  decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
192
  replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
193
-
194
  result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
195
  return result_log, replay_path
196
 
@@ -200,10 +184,10 @@ iface = gr.Interface(
200
  inputs=gr.Video(label="Upload Cricket Delivery Video"),
201
  outputs=[
202
  gr.Textbox(label="DRS Result and Debug Info"),
203
- gr.Video(label="Replay with Trajectory & Decision")
204
  ],
205
- title="GullyDRS - AI-Powered LBW Review",
206
- description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return a replay with an OUT/NOT OUT decision."
207
  )
208
 
209
  if __name__ == "__main__":
 
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
 
7
  import uuid
8
  import os
9
  import tempfile
 
12
  model = YOLO("best.pt")
13
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
+ # Resolve the class index for "cricketBall"
16
  ball_class_index = None
17
  for k, v in model.names.items():
18
  if v.lower() == "cricketball":
 
33
  STUMPS_HEIGHT = 0.71
34
  MAX_POSITION_JUMP = 30
35
 
36
+ def bezier_curve(points, num=100):
37
+ """Compute a quadratic Bezier curve from given points."""
38
+ points = np.array(points)
39
+ if len(points) < 3:
40
+ return points
41
+ p0, p1, p2 = points[0], points[len(points)//2], points[-1]
42
+ t = np.linspace(0, 1, num=num)
43
+ curve = (1 - t)[:, None]**2 * p0 + 2 * (1 - t)[:, None] * t[:, None] * p1 + t[:, None]**2 * p2
44
+ return curve
45
+
46
+ def draw_stumps_overlay(frame):
47
+ height, width = frame.shape[:2]
48
+ stumps_x = width // 2
49
+ stump_top = int(height * 0.1)
50
+ stump_bottom = int(height * 0.9)
51
+ stump_width_px = int(width * 0.03)
52
+ for offset in [-stump_width_px, 0, stump_width_px]:
53
+ cv2.line(frame, (stumps_x + offset, stump_top), (stumps_x + offset, stump_bottom), (0, 0, 0), 2)
54
+ return frame
55
+
56
  def process_video(video_path):
57
  if not os.path.exists(video_path):
58
  return [], [], [], "Error: Video file not found"
 
61
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
62
  frames, ball_positions, detection_frames, debug_log = [], [], [], []
63
  frame_count = 0
 
64
  while cap.isOpened():
65
  ret, frame = cap.read()
66
  if not ret:
 
80
  frames[-1] = frame
81
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
82
  cap.release()
 
 
 
 
 
 
 
83
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
84
 
85
  def find_bounce_point(ball_coords):
 
 
 
 
86
  y_coords = [p[1] for p in ball_coords]
87
  min_index = None
 
88
  for i in range(2, len(y_coords) - 2):
89
  dy1 = y_coords[i] - y_coords[i - 1]
90
  dy2 = y_coords[i + 1] - y_coords[i]
 
92
  if i > len(y_coords) * 0.2:
93
  min_index = i
94
  break
95
+ return ball_coords[min_index] if min_index else ball_coords[len(ball_coords)//2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
98
  if len(ball_positions) < 2:
99
  return None, None, None, "Error: Not enough ball detections"
 
100
  filtered_positions = [ball_positions[0]]
101
  filtered_frames = [detection_frames[0]]
102
  for i in range(1, len(ball_positions)):
103
+ if np.linalg.norm(np.array(ball_positions[i]) - np.array(filtered_positions[-1])) <= MAX_POSITION_JUMP:
104
+ filtered_positions.append(ball_positions[i])
 
105
  filtered_frames.append(detection_frames[i])
 
106
  if len(filtered_positions) < 2:
107
  return None, None, None, "Error: Filtered detections too few"
 
108
  x_vals = [p[0] for p in filtered_positions]
109
  y_vals = [p[1] for p in filtered_positions]
110
  times = np.array(filtered_frames) / FRAME_RATE
 
111
  try:
112
  fx = interp1d(times, x_vals, kind='cubic', fill_value="extrapolate")
113
  fy = interp1d(times, y_vals, kind='cubic', fill_value="extrapolate")
114
  except Exception as e:
115
  return None, None, None, f"Interpolation error: {str(e)}"
116
+ t_full = np.linspace(times[0], times[-1], max(5, len(times) * SLOW_MOTION_FACTOR))
117
+ x_full, y_full = fx(t_full), fy(t_full)
 
 
 
118
  trajectory = list(zip(x_full, y_full))
 
119
  pitch_point = find_bounce_point(filtered_positions)
120
  impact_point = filtered_positions[-1]
 
121
  return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
122
 
123
+ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
124
+ if not frames or not trajectory or len(ball_positions) < 2:
125
+ return "Not enough data", trajectory, pitch_point, impact_point
126
+ frame_height, frame_width = frames[0].shape[:2]
127
+ stumps_x = frame_width / 2
128
+ stumps_y = frame_height * 0.9
129
+ stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
130
+ pitch_x, _ = pitch_point
131
+ impact_x, _ = impact_point
132
+ if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
133
+ return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
134
+ if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
135
+ return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
136
+ for x, y in trajectory:
137
+ if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
138
+ return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
139
+ return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
140
+
141
  def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
142
  if not frames or not trajectory:
143
  return None
144
+ bezier = bezier_curve(trajectory)
145
+ out_path = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
146
  height, width = frames[0].shape[:2]
147
+ out = cv2.VideoWriter(out_path, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
148
+ min_frame, max_frame = min(detection_frames), max(detection_frames)
 
 
149
  total_frames = max_frame - min_frame + 1
150
+ traj_per_frame = max(1, len(bezier) // total_frames)
151
+ indices = [min(i * traj_per_frame, len(bezier)-1) for i in range(total_frames)]
 
152
  for i, frame in enumerate(frames):
153
+ frame = draw_stumps_overlay(frame)
154
  idx = i - min_frame
155
  if 0 <= idx < len(indices):
156
  end_idx = indices[idx]
157
+ pts = np.array(bezier[:end_idx+1], dtype=np.int32).reshape((-1, 1, 2))
158
+ cv2.polylines(frame, [pts], False, (255, 0, 0), 2)
159
  if pitch_point and i == detection_frames[0]:
160
  cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
161
  if impact_point and i == detection_frames[-1]:
 
163
  for _ in range(SLOW_MOTION_FACTOR):
164
  out.write(frame)
165
  out.release()
166
+ return out_path
167
 
168
  def drs_review(video):
169
  frames, ball_positions, detection_frames, debug_log = process_video(video)
170
  if not frames or not ball_positions:
171
  return "No frames or detections found.", None
 
172
  frame_height, frame_width = frames[0].shape[:2]
173
  trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
174
  if not trajectory:
175
  return f"{log}\n{debug_log}", None
 
176
  decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
177
  replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
 
178
  result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
179
  return result_log, replay_path
180
 
 
184
  inputs=gr.Video(label="Upload Cricket Delivery Video"),
185
  outputs=[
186
  gr.Textbox(label="DRS Result and Debug Info"),
187
+ gr.Video(label="Replay with Bezier Trajectory, Pitch & Impact Points")
188
  ],
189
+ title="GullyDRS - Enhanced DRS with Bezier Trajectory",
190
+ description="Upload a cricket delivery video. System overlays stumps, draws Bezier trajectory, and returns annotated replay with DRS decision."
191
  )
192
 
193
  if __name__ == "__main__":