AjaykumarPilla committed
Commit 689fb64 (verified) · Parent(s): a653421

Update app.py

Files changed (1)
  1. app.py +254 -121
app.py CHANGED
@@ -1,154 +1,287 @@
 import cv2
 import numpy as np
 import torch
-from ultralytics import YOLO
 import gradio as gr
-from scipy.interpolate import interp1d
-import uuid
 import os

-# Load the trained YOLOv8n model from the Space's root directory
-model = YOLO("best.pt")  # Assumes best.pt is in the same directory as app.py

-# Constants for LBW decision and video processing
-STUMPS_WIDTH = 0.2286  # meters (width of stumps)
-BALL_DIAMETER = 0.073  # meters (approx. cricket ball diameter)
-FRAME_RATE = 30  # Input video frame rate
-SLOW_MOTION_FACTOR = 6  # For very slow motion (6x slower)
-CONF_THRESHOLD = 0.3  # Lowered confidence threshold for better detection

 def process_video(video_path):
-    # Initialize video capture
-    if not os.path.exists(video_path):
-        return [], [], "Error: Video file not found"
     cap = cv2.VideoCapture(video_path)
-    frames = []
-    ball_positions = []
-    debug_log = []
-
     frame_count = 0
     while cap.isOpened():
         ret, frame = cap.read()
         if not ret:
             break
         frame_count += 1
-        frames.append(frame.copy())  # Store original frame
-        # Detect ball using the trained YOLOv8n model
-        results = model.predict(frame, conf=CONF_THRESHOLD)
-        detections = 0
-        for detection in results[0].boxes:
-            if detection.cls == 0:  # Assuming class 0 is the ball
-                detections += 1
-                x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
-                ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
-                # Draw bounding box on frame for visualization
-                cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
-                frames[-1] = frame  # Update frame with bounding box
-        debug_log.append(f"Frame {frame_count}: {detections} ball detections")
     cap.release()

-    if not ball_positions:
-        debug_log.append("No balls detected in any frame")
-    else:
-        debug_log.append(f"Total ball detections: {len(ball_positions)}")

-    return frames, ball_positions, "\n".join(debug_log)

-def estimate_trajectory(ball_positions, frames):
-    # Simplified physics-based trajectory projection
-    if len(ball_positions) < 2:
-        return None, None, "Error: Fewer than 2 ball detections for trajectory"
-    # Extract x, y coordinates
-    x_coords = [pos[0] for pos in ball_positions]
-    y_coords = [pos[1] for pos in ball_positions]
-    times = np.arange(len(ball_positions)) / FRAME_RATE

-    # Interpolate to smooth trajectory
     try:
-        fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
-        fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
-    except Exception as e:
-        return None, None, f"Error in trajectory interpolation: {str(e)}"
-
-    # Project trajectory forward (0.5 seconds post-impact)
-    t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
-    x_future = fx(t_future)
-    y_future = fy(t_future)
-
-    return list(zip(x_future, y_future)), t_future, "Trajectory estimated successfully"
-
-def lbw_decision(ball_positions, trajectory, frames):
-    # Simplified LBW logic
-    if not frames:
-        return "Error: No frames processed", None
-    if not trajectory or len(ball_positions) < 2:
-        return "Not enough data (insufficient ball detections)", None
-
-    # Assume stumps are at the bottom center of the frame (calibration needed)
-    frame_height, frame_width = frames[0].shape[:2]
-    stumps_x = frame_width / 2
-    stumps_y = frame_height * 0.9  # Approximate stumps position
-    stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)  # Assume 3m pitch width
-
-    # Check pitching point (first detected position)
-    pitch_x, pitch_y = ball_positions[0]
-    if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
-        return "Not Out (Pitched outside line)", None
-
-    # Check impact point (last detected position)
-    impact_x, impact_y = ball_positions[-1]
-    if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
-        return "Not Out (Impact outside line)", None
-
-    # Check trajectory hitting stumps
-    for x, y in trajectory:
-        if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
-            return "Out", trajectory
-    return "Not Out (Missing stumps)", trajectory
-
-def generate_slow_motion(frames, trajectory, output_path):
-    # Generate very slow-motion video with ball detection and trajectory overlay
-    if not frames:
-        return None
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-    out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))

-    for frame in frames:
         if trajectory:
-            for x, y in trajectory:
-                cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)  # Blue dots for trajectory
-        for _ in range(SLOW_MOTION_FACTOR):  # Duplicate frames for very slow motion
             out.write(frame)
-    out.release()
-    return output_path

-def drs_review(video):
-    # Process video and generate DRS output
-    frames, ball_positions, debug_log = process_video(video)
-    if not frames:
-        return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
-    trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
-    decision, trajectory = lbw_decision(ball_positions, trajectory, frames)

-    # Generate slow-motion replay even if trajectory fails
-    output_path = f"output_{uuid.uuid4()}.mp4"
-    slow_motion_path = generate_slow_motion(frames, trajectory, output_path)

-    # Combine debug logs for output
-    debug_output = f"{debug_log}\n{trajectory_log}"
-    return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path

 # Gradio interface
-iface = gr.Interface(
-    fn=drs_review,
-    inputs=gr.Video(label="Upload Video Clip"),
-    outputs=[
-        gr.Textbox(label="DRS Decision and Debug Log"),
-        gr.Video(label="Very Slow-Motion Replay with Ball Detection and Trajectory")
-    ],
-    title="AI-Powered DRS for LBW in Local Cricket",
-    description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes) and trajectory (blue dots)."
-)

 if __name__ == "__main__":
-    iface.launch()
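
Side note on the removed approach: the old estimate_trajectory relied on scipy.interpolate.interp1d with fill_value="extrapolate" to project the ball path beyond the last detection. A minimal standalone sketch of that idea, with made-up detection coordinates (not from this repo):

    import numpy as np
    from scipy.interpolate import interp1d

    # Five synthetic ball centres at 30 fps (illustrative values only)
    times = np.arange(5) / 30.0
    xs = np.array([100.0, 120.0, 140.0, 160.0, 180.0])
    ys = np.array([300.0, 260.0, 240.0, 250.0, 280.0])

    # Same kinds as the removed code: linear in x, quadratic in y
    fx = interp1d(times, xs, kind='linear', fill_value="extrapolate")
    fy = interp1d(times, ys, kind='quadratic', fill_value="extrapolate")

    # Project 0.5 s past the last detection, as the old code did
    t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
    print(list(zip(fx(t_future), fy(t_future))))

Extrapolating a spline this far past the data is fragile, which is presumably why the commit moves to the least-squares quadratic fit below.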
 
 import cv2
 import numpy as np
+import pandas as pd
+import plotly.express as px
+import plotly.graph_objects as go
 import torch
 import gradio as gr
 import os
+import time
+from scipy.optimize import curve_fit
+import sys
+
+# Add the yolov5 directory to sys.path
+sys.path.append(os.path.join(os.path.dirname(__file__), "yolov5"))
+
+# Import YOLOv5 modules; non_max_suppression already returns boxes
+# in xyxy format, so no xywh2xyxy conversion is needed later
+from models.experimental import attempt_load
+from utils.general import non_max_suppression

+# Cricket pitch dimensions (in meters)
+PITCH_LENGTH = 20.12  # Length of cricket pitch (stumps to stumps)
+PITCH_WIDTH = 3.05  # Width of pitch
+STUMP_HEIGHT = 0.71  # Stump height
+STUMP_WIDTH = 0.2286  # Stump width (including bails)
+
+# Model input size (adjust if best.pt was trained with a different size).
+# Square, so the (height, width) tuple is also safe to pass to cv2.resize,
+# which expects (width, height).
+MODEL_INPUT_SIZE = (640, 640)  # (height, width)
+FRAME_SKIP = 2  # Process every 2nd frame
+MIN_DETECTIONS = 10  # Stop after 10 detections
+BATCH_SIZE = 4  # Process 4 frames at a time
+SLOW_MOTION_FACTOR = 3  # Duplicate each kept frame 3 times for slow motion
+
+# Load the trained YOLOv5 weights (best.pt in the Space's root directory)
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = attempt_load("best.pt")
+model.to(device).eval()  # Move model to device and set to evaluation mode
+
+# Process the video and detect the ball
 def process_video(video_path):
     cap = cv2.VideoCapture(video_path)
+    frame_rate = cap.get(cv2.CAP_PROP_FPS)
+    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+    positions = []
+    frame_numbers = []
+    bounce_frame = None
+    bounce_point = None
+    batch_frames = []
+    batch_frame_nums = []
     frame_count = 0
+
+    def run_batch(frames_buf, frame_nums_buf):
+        # Run batched inference on the buffered frames and record detections
+        nonlocal bounce_frame, bounce_point
+        batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames_buf]
+        batch = np.stack(batch)  # [batch_size, H, W, 3]
+        batch = torch.from_numpy(batch).to(device).float() / 255.0
+        batch = batch.permute(0, 3, 1, 2)  # [batch_size, 3, H, W]
+
+        # Run inference
+        batch_start_time = time.time()
+        with torch.no_grad():
+            pred = model(batch)[0]
+        pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
+        print(f"Batch inference time: {time.time() - batch_start_time:.2f}s for {len(frames_buf)} frames")
+
+        # Process detections; NMS output rows are already [x1, y1, x2, y2, conf, cls]
+        for i, det in enumerate(pred):
+            if det is not None and len(det):
+                for *xyxy, conf, cls in det:
+                    x_center = (xyxy[0] + xyxy[2]) / 2
+                    y_center = (xyxy[1] + xyxy[3]) / 2
+                    # Scale coordinates back to the original frame size
+                    x_center = x_center * frame_width / MODEL_INPUT_SIZE[1]
+                    y_center = y_center * frame_height / MODEL_INPUT_SIZE[0]
+                    positions.append((x_center.item(), y_center.item()))
+                    frame_numbers.append(frame_nums_buf[i])
+
+                    # Track the bounce (largest y = lowest point in the image)
+                    if bounce_frame is None or y_center > positions[bounce_frame][1]:
+                        bounce_frame = len(frame_numbers) - 1
+                        bounce_point = (x_center.item(), y_center.item())
+
+    start_time = time.time()
     while cap.isOpened():
+        frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
         ret, frame = cap.read()
         if not ret:
             break
+
+        # Skip frames
+        if frame_count % FRAME_SKIP != 0:
+            frame_count += 1
+            continue
+
+        # Resize the frame to the model input size
+        frame = cv2.resize(frame, MODEL_INPUT_SIZE, interpolation=cv2.INTER_AREA)
+        batch_frames.append(frame)
+        batch_frame_nums.append(frame_num)
+        frame_count += 1
+
+        # Process a batch when it is full
+        if len(batch_frames) == BATCH_SIZE:
+            run_batch(batch_frames, batch_frame_nums)
+            batch_frames = []
+            batch_frame_nums = []
+
+        # Early termination
+        if len(positions) >= MIN_DETECTIONS:
+            break
+
+    # Flush a partial final batch (the loop exits before one fills up)
+    if batch_frames and len(positions) < MIN_DETECTIONS:
+        run_batch(batch_frames, batch_frame_nums)
+
     cap.release()
+    print(f"Total video processing time: {time.time() - start_time:.2f}s")
+    return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height

+# Polynomial function for trajectory fitting
+def poly_func(x, a, b, c):
+    return a * x**2 + b * x + c

+# Predict the trajectory and wicket inline path
+def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
+    if len(positions) < 3:
+        return None, None, "Insufficient detections for trajectory prediction"

+    x_coords = [p[0] for p in positions]
+    y_coords = [p[1] for p in positions]
+    frames = np.array(frame_numbers)

+    # Fit a quadratic to the x and y coordinates over frame number
     try:
+        popt_x, _ = curve_fit(poly_func, frames, x_coords)
+        popt_y, _ = curve_fit(poly_func, frames, y_coords)
+    except Exception:
+        return None, None, "Failed to fit trajectory"
+
+    # Extrapolate toward the stumps
+    frame_max = max(frames) + 10
+    future_frames = np.linspace(min(frames), frame_max, 100)
+    x_pred = poly_func(future_frames, *popt_x)
+    y_pred = poly_func(future_frames, *popt_y)
+
+    # Wicket inline path (center line toward the stumps);
+    # note that np.interp assumes x_pred is increasing
+    stump_x = frame_width / 2
+    stump_y = frame_height
+    inline_x = np.linspace(min(x_coords), stump_x, 100)
+    inline_y = np.interp(inline_x, x_pred, y_pred)
+
+    # Check whether the extrapolated trajectory hits the stumps
+    stump_hit = False
+    for x, y in zip(x_pred, y_pred):
+        if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
+            stump_hit = True
+            break
+
+    lbw_decision = "OUT" if stump_hit else "NOT OUT"
+    return list(zip(future_frames, x_pred, y_pred)), list(zip(inline_x, inline_y)), lbw_decision
+
+# Map the bounce point to pitch coordinates (meters)
+def map_pitch(bounce_point, frame_width, frame_height):
+    if bounce_point is None:
+        return None, None  # No bounce detected
+
+    x, y = bounce_point
+    pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
+    pitch_y = (1 - y / frame_height) * PITCH_LENGTH
+    return pitch_x, pitch_y
+
+# Estimate the ball speed
+def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
+    if len(positions) < 2:
+        return None, "Insufficient detections for speed estimation"
+
+    pixel_to_meter = PITCH_LENGTH / frame_width
+    speeds = []
+    for i in range(1, len(positions)):
+        x1, y1 = positions[i-1]
+        x2, y2 = positions[i]
+        pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
+        # Use the actual frame gap between detections (frames are skipped),
+        # not a fixed one-frame interval
+        dt = (frame_numbers[i] - frame_numbers[i-1]) / frame_rate
+        if dt > 0:
+            speeds.append(pixel_dist * pixel_to_meter / dt)
+
+    if not speeds:
+        return None, "Insufficient detections for speed estimation"
+    avg_speed_kmh = np.mean(speeds) * 3.6  # m/s -> km/h
+    return avg_speed_kmh, "Speed calculated successfully"
+
+# Main Gradio function with video overlay and slow motion
+def drs_analysis(video):
+    # The video arrives as a file path (string) in Hugging Face Spaces
+    video_path = video if isinstance(video, str) else "temp_video.mp4"
+    if not isinstance(video, str):
+        with open(video_path, "wb") as f:
+            f.write(video.read())
+
+    # Process the video for detections
+    positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
+    if not positions:
+        return None, None, "No ball detected in video", None
+
+    # Predict the trajectory and wicket path
+    trajectory, inline_path, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
+    if trajectory is None:
+        return None, None, lbw_decision, None
+
+    pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
+    speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
+
+    # Create the output video with overlays and slow motion; keeping every
+    # 2nd frame and writing each 3 times makes playback 1.5x slower
+    output_path = "output_video.mp4"
+    cap = cv2.VideoCapture(video_path)
     fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    out = cv2.VideoWriter(output_path, fourcc, frame_rate, (frame_width, frame_height))
+
+    frame_count = 0
+    positions_dict = dict(zip(frame_numbers, positions))

+    while cap.isOpened():
+        ret, frame = cap.read()
+        if not ret:
+            break
+
+        # Skip frames for consistency with detection
+        if frame_count % FRAME_SKIP != 0:
+            frame_count += 1
+            continue
+
+        # Overlay the detected ball position (red dot)
+        if frame_count in positions_dict:
+            cv2.circle(frame, (int(positions_dict[frame_count][0]), int(positions_dict[frame_count][1])), 5, (0, 0, 255), -1)
         if trajectory:
+            # Predicted ball trajectory from this frame onward (red line)
+            traj_x = [int(t[1]) for t in trajectory if t[0] >= frame_count]
+            traj_y = [int(t[2]) for t in trajectory if t[0] >= frame_count]
+            if traj_x and traj_y:
+                for i in range(1, len(traj_x)):
+                    cv2.line(frame, (traj_x[i-1], traj_y[i-1]), (traj_x[i], traj_y[i]), (0, 0, 255), 2)
+        if inline_path:
+            # Wicket inline path (blue line)
+            inline_x = [int(x) for x, _ in inline_path]
+            inline_y = [int(y) for _, y in inline_path]
+            if inline_x and inline_y:
+                for i in range(1, len(inline_x)):
+                    cv2.line(frame, (inline_x[i-1], inline_y[i-1]), (inline_x[i], inline_y[i]), (255, 0, 0), 2)
+
+        # Overlay the pitch map in the top-right corner
+        if pitch_x is not None and pitch_y is not None:
+            map_width = 200
+            # Cap map_height at 25% of the frame height so the overlay fits
+            map_height = min(int(map_width * PITCH_LENGTH / PITCH_WIDTH), frame_height // 4)
+            pitch_map = np.zeros((map_height, map_width, 3), dtype=np.uint8)
+            pitch_map[:] = (0, 255, 0)  # Green pitch
+            cv2.rectangle(pitch_map, (0, map_height-10), (map_width, map_height), (0, 51, 51), -1)  # Brown band at the stumps end
+            bounce_x = int((pitch_x + PITCH_WIDTH/2) / PITCH_WIDTH * map_width)
+            bounce_y = int((1 - pitch_y / PITCH_LENGTH) * map_height)
+            cv2.circle(pitch_map, (bounce_x, bounce_y), 5, (0, 0, 255), -1)  # Red bounce point
+            # Ensure the overlay fits within the frame
+            overlay_region = frame[0:map_height, frame_width-map_width:frame_width]
+            if overlay_region.shape[0] >= map_height and overlay_region.shape[1] >= map_width:
+                frame[0:map_height, frame_width-map_width:frame_width] = cv2.resize(pitch_map, (map_width, map_height))
+
+        # Add text annotations; cv2.putText does not render "\n",
+        # so each line is drawn separately
+        cv2.putText(frame, f"LBW: {lbw_decision}", (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+        cv2.putText(frame, f"Speed: {speed_kmh:.2f} km/h", (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+
+        # Write each frame multiple times for slow motion
+        for _ in range(SLOW_MOTION_FACTOR):
             out.write(frame)

+        frame_count += 1
+
+    cap.release()
+    out.release()

+    if not isinstance(video, str):
+        os.remove(video_path)

+    # The two Plot outputs stay empty for now; no plotly figures are
+    # built yet even though plotly is imported above
+    return None, None, f"LBW: {lbw_decision} | Speed: {speed_kmh:.2f} km/h", output_path

 # Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("## Cricket DRS Analysis")
+    video_input = gr.Video(label="Upload Video Clip")
+    btn = gr.Button("Analyze")
+    trajectory_output = gr.Plot(label="Ball Trajectory")
+    pitch_output = gr.Plot(label="Pitch Map")
+    text_output = gr.Textbox(label="Analysis Results")
+    video_output = gr.Video(label="Processed Video")
+    btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])

 if __name__ == "__main__":
+    demo.launch()
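
The new fitting-and-extrapolation step is easy to sanity-check offline. A minimal sketch, using synthetic detections shaped like process_video output (the frame numbers and pixel centres are made up) and the same stump-hit test as predict_trajectory:

    import numpy as np
    from scipy.optimize import curve_fit

    FRAME_WIDTH, FRAME_HEIGHT = 1280, 720
    PITCH_WIDTH, STUMP_WIDTH = 3.05, 0.2286

    def poly_func(x, a, b, c):
        return a * x**2 + b * x + c

    # Synthetic detections: frame numbers and pixel centres
    frames = np.array([0, 2, 4, 6, 8])
    xs = np.array([600.0, 610.0, 622.0, 633.0, 645.0])
    ys = np.array([200.0, 320.0, 420.0, 500.0, 560.0])

    popt_x, _ = curve_fit(poly_func, frames, xs)
    popt_y, _ = curve_fit(poly_func, frames, ys)

    # Extrapolate 10 frames past the last detection, as the app does
    future = np.linspace(frames.min(), frames.max() + 10, 100)
    x_pred = poly_func(future, *popt_x)
    y_pred = poly_func(future, *popt_y)

    # Stump-hit test: within 50 px of the bottom edge and within a
    # scaled stump width of the frame centre
    stump_x, stump_y = FRAME_WIDTH / 2, FRAME_HEIGHT
    width_px = STUMP_WIDTH * FRAME_WIDTH / PITCH_WIDTH
    hit = any(abs(y - stump_y) < 50 and abs(x - stump_x) < width_px
              for x, y in zip(x_pred, y_pred))
    print("OUT" if hit else "NOT OUT")

With these numbers the fitted parabola in y flattens out before reaching the bottom edge, so it prints NOT OUT; a steeper synthetic ys series flips it to OUT.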