AjaykumarPilla commited on
Commit
ef9f7bb
·
verified ·
1 Parent(s): 30b23a8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -58
app.py CHANGED
@@ -4,16 +4,15 @@ import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
7
- import plotly.graph_objects as go
8
  import uuid
9
  import os
10
  import tempfile
11
 
12
- # Load YOLOv8 model and resolve class index
13
  model = YOLO("best.pt")
14
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
15
 
16
- # Dynamically resolve ball class index
17
  ball_class_index = None
18
  for k, v in model.names.items():
19
  if v.lower() == "cricketball":
@@ -26,7 +25,7 @@ if ball_class_index is None:
26
  STUMPS_WIDTH = 0.2286
27
  BALL_DIAMETER = 0.073
28
  FRAME_RATE = 20
29
- SLOW_MOTION_FACTOR = 3
30
  CONF_THRESHOLD = 0.2
31
  IMPACT_ZONE_Y = 0.85
32
  IMPACT_DELTA_Y = 50
@@ -63,55 +62,17 @@ def process_video(video_path):
63
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
64
  cap.release()
65
 
66
- if not ball_positions:
67
- debug_log.append("No balls detected in any frame")
68
- else:
69
- debug_log.append(f"Total ball detections: {len(ball_positions)}")
70
- debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
71
-
72
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
73
 
74
  def find_bounce_point(ball_coords):
75
- """
76
- Detect bounce point using y-derivative reversal with early-frame suppression.
77
- Looks for where y increases then decreases (ball hits ground).
78
- """
79
  y_coords = [p[1] for p in ball_coords]
80
- min_index = None
81
-
82
  for i in range(2, len(y_coords) - 2):
83
  dy1 = y_coords[i] - y_coords[i - 1]
84
  dy2 = y_coords[i + 1] - y_coords[i]
85
  if dy1 > 0 and dy2 < 0:
86
  if i > len(y_coords) * 0.2:
87
- min_index = i
88
- break
89
-
90
- if min_index is not None:
91
- return ball_coords[min_index]
92
-
93
- return ball_coords[len(ball_coords)//2]
94
-
95
- def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
96
- if not frames or not trajectory or len(ball_positions) < 2:
97
- return "Not enough data", trajectory, pitch_point, impact_point
98
-
99
- frame_height, frame_width = frames[0].shape[:2]
100
- stumps_x = frame_width / 2
101
- stumps_y = frame_height * 0.9
102
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
103
-
104
- pitch_x, _ = pitch_point
105
- impact_x, impact_y = impact_point
106
-
107
- if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
108
- return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
109
- if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
110
- return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
111
- for x, y in trajectory:
112
- if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
113
- return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
114
- return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
115
 
116
  def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
117
  if len(ball_positions) < 2:
@@ -149,50 +110,79 @@ def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_wi
149
 
150
  return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
153
  if not frames or not trajectory:
154
- return None
155
 
156
- temp_file = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
157
  height, width = frames[0].shape[:2]
158
- out = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
 
 
 
 
159
 
160
  min_frame = min(detection_frames)
161
  max_frame = max(detection_frames)
162
  total_frames = max_frame - min_frame + 1
163
  traj_per_frame = max(1, len(trajectory) // total_frames)
164
- indices = [min(i * traj_per_frame, len(trajectory)-1) for i in range(total_frames)]
165
 
166
  for i, frame in enumerate(frames):
 
167
  idx = i - min_frame
168
  if 0 <= idx < len(indices):
169
  end_idx = indices[idx]
170
- points = np.array(trajectory[:end_idx+1], dtype=np.int32).reshape((-1, 1, 2))
171
  cv2.polylines(frame, [points], False, (255, 0, 0), 2)
 
172
  if pitch_point and i == detection_frames[0]:
173
  cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
174
  if impact_point and i == detection_frames[-1]:
175
  cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
176
  for _ in range(SLOW_MOTION_FACTOR):
177
- out.write(frame)
178
- out.release()
179
- return temp_file
 
 
 
180
 
181
  def drs_review(video):
182
  frames, ball_positions, detection_frames, debug_log = process_video(video)
183
  if not frames or not ball_positions:
184
- return "No frames or detections found.", None
185
 
186
  frame_height, frame_width = frames[0].shape[:2]
187
  trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
188
  if not trajectory:
189
- return f"{log}\n{debug_log}", None
190
 
191
  decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
192
- replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
193
 
194
  result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
195
- return result_log, replay_path
196
 
197
  # Gradio Interface
198
  iface = gr.Interface(
@@ -200,10 +190,11 @@ iface = gr.Interface(
200
  inputs=gr.Video(label="Upload Cricket Delivery Video"),
201
  outputs=[
202
  gr.Textbox(label="DRS Result and Debug Info"),
203
- gr.Video(label="Replay with Trajectory & Decision")
 
204
  ],
205
  title="GullyDRS - AI-Powered LBW Review",
206
- description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return a replay with an OUT/NOT OUT decision."
207
  )
208
 
209
  if __name__ == "__main__":
 
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
 
7
  import uuid
8
  import os
9
  import tempfile
10
 
11
+ # Load YOLOv8 model
12
  model = YOLO("best.pt")
13
  model.to('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
+ # Resolve class index for cricket ball
16
  ball_class_index = None
17
  for k, v in model.names.items():
18
  if v.lower() == "cricketball":
 
25
  STUMPS_WIDTH = 0.2286
26
  BALL_DIAMETER = 0.073
27
  FRAME_RATE = 20
28
+ SLOW_MOTION_FACTOR = 2 # Set to 1 for normal speed replay
29
  CONF_THRESHOLD = 0.2
30
  IMPACT_ZONE_Y = 0.85
31
  IMPACT_DELTA_Y = 50
 
62
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
63
  cap.release()
64
 
 
 
 
 
 
 
65
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
66
 
67
  def find_bounce_point(ball_coords):
 
 
 
 
68
  y_coords = [p[1] for p in ball_coords]
 
 
69
  for i in range(2, len(y_coords) - 2):
70
  dy1 = y_coords[i] - y_coords[i - 1]
71
  dy2 = y_coords[i + 1] - y_coords[i]
72
  if dy1 > 0 and dy2 < 0:
73
  if i > len(y_coords) * 0.2:
74
+ return ball_coords[i]
75
+ return ball_coords[len(ball_coords) // 2]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
78
  if len(ball_positions) < 2:
 
110
 
111
  return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
112
 
113
+ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
114
+ if not frames or not trajectory or len(ball_positions) < 2:
115
+ return "Not enough data", trajectory, pitch_point, impact_point
116
+
117
+ frame_height, frame_width = frames[0].shape[:2]
118
+ stumps_x = frame_width / 2
119
+ stumps_y = frame_height * 0.9
120
+ stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
121
+
122
+ pitch_x, _ = pitch_point
123
+ impact_x, impact_y = impact_point
124
+
125
+ if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
126
+ return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
127
+ if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
128
+ return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
129
+ for x, y in trajectory:
130
+ if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
131
+ return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
132
+ return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
133
+
134
  def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
135
  if not frames or not trajectory:
136
+ return None, None
137
 
 
138
  height, width = frames[0].shape[:2]
139
+ slow_path = os.path.join(tempfile.gettempdir(), f"drs_slow_{uuid.uuid4()}.mp4")
140
+ normal_path = os.path.join(tempfile.gettempdir(), f"drs_normal_{uuid.uuid4()}.mp4")
141
+
142
+ slow_writer = cv2.VideoWriter(slow_path, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
143
+ normal_writer = cv2.VideoWriter(normal_path, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE, (width, height))
144
 
145
  min_frame = min(detection_frames)
146
  max_frame = max(detection_frames)
147
  total_frames = max_frame - min_frame + 1
148
  traj_per_frame = max(1, len(trajectory) // total_frames)
149
+ indices = [min(i * traj_per_frame, len(trajectory) - 1) for i in range(total_frames)]
150
 
151
  for i, frame in enumerate(frames):
152
+ frame_copy = frame.copy()
153
  idx = i - min_frame
154
  if 0 <= idx < len(indices):
155
  end_idx = indices[idx]
156
+ points = np.array(trajectory[:end_idx + 1], dtype=np.int32).reshape((-1, 1, 2))
157
  cv2.polylines(frame, [points], False, (255, 0, 0), 2)
158
+ cv2.polylines(frame_copy, [points], False, (255, 0, 0), 2)
159
  if pitch_point and i == detection_frames[0]:
160
  cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
161
  if impact_point and i == detection_frames[-1]:
162
  cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
163
  for _ in range(SLOW_MOTION_FACTOR):
164
+ slow_writer.write(frame)
165
+ normal_writer.write(frame_copy)
166
+
167
+ slow_writer.release()
168
+ normal_writer.release()
169
+ return slow_path, normal_path
170
 
171
  def drs_review(video):
172
  frames, ball_positions, detection_frames, debug_log = process_video(video)
173
  if not frames or not ball_positions:
174
+ return "No frames or detections found.", None, None
175
 
176
  frame_height, frame_width = frames[0].shape[:2]
177
  trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
178
  if not trajectory:
179
+ return f"{log}\n{debug_log}", None, None
180
 
181
  decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
182
+ slow_path, normal_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
183
 
184
  result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
185
+ return result_log, slow_path, normal_path
186
 
187
  # Gradio Interface
188
  iface = gr.Interface(
 
190
  inputs=gr.Video(label="Upload Cricket Delivery Video"),
191
  outputs=[
192
  gr.Textbox(label="DRS Result and Debug Info"),
193
+ gr.Video(label="Slow-Motion Replay"),
194
+ gr.Video(label="Normal-Speed Trajectory Only")
195
  ],
196
  title="GullyDRS - AI-Powered LBW Review",
197
+ description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return both slow-motion and normal-speed replays."
198
  )
199
 
200
  if __name__ == "__main__":