AjaykumarPilla committed on
Commit 462fddf · verified · 1 Parent(s): d81de55

Update app.py

Files changed (1)
  1. app.py +141 -254
app.py CHANGED
@@ -1,287 +1,174 @@
  import cv2
  import numpy as np
- import pandas as pd
- import plotly.express as px
- import plotly.graph_objects as go
  import torch
  import gradio as gr
  import os
- import time
- from scipy.optimize import curve_fit
- import sys

- # Add yolov5 directory to sys.path
- sys.path.append(os.path.join(os.path.dirname(__file__), "yolov5"))

- # Import YOLOv5 modules
- from models.experimental import attempt_load
- from utils.general import non_max_suppression, xywh2xyxy

- # Cricket pitch dimensions (in meters)
- PITCH_LENGTH = 20.12 # Length of cricket pitch (stumps to stumps)
- PITCH_WIDTH = 3.05 # Width of pitch
- STUMP_HEIGHT = 0.71 # Stump height
- STUMP_WIDTH = 0.2286 # Stump width (including bails)
-
- # Model input size (adjust if yolov5s.pt was trained with a different size)
- MODEL_INPUT_SIZE = (640, 640) # (height, width)
- FRAME_SKIP = 2 # Process every 2nd frame
- MIN_DETECTIONS = 10 # Stop after 10 detections
- BATCH_SIZE = 4 # Process 4 frames at a time
- SLOW_MOTION_FACTOR = 3 # Duplicate each frame 3 times for slow motion
-
- # Load model
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = attempt_load("best.pt") # Load yolov5s.pt
- model.to(device).eval() # Move model to device and set to evaluation mode
-
- # Function to process video and detect ball
  def process_video(video_path):
      cap = cv2.VideoCapture(video_path)
-     frame_rate = cap.get(cv2.CAP_PROP_FPS)
-     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     positions = []
-     frame_numbers = []
-     bounce_frame = None
-     bounce_point = None
-     batch_frames = []
-     batch_frame_nums = []
-     frame_count = 0
-     start_time = time.time()
      while cap.isOpened():
-         frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
          ret, frame = cap.read()
          if not ret:
              break
-
-         # Skip frames
-         if frame_count % FRAME_SKIP != 0:
-             frame_count += 1
-             continue
-
-         # Resize frame to model input size
-         frame = cv2.resize(frame, MODEL_INPUT_SIZE, interpolation=cv2.INTER_AREA)
-         batch_frames.append(frame)
-         batch_frame_nums.append(frame_num)
          frame_count += 1
-
-         # Process batch when full or at end
-         if len(batch_frames) == BATCH_SIZE or not ret:
-             # Preprocess batch
-             batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in batch_frames]
-             batch = np.stack(batch) # [batch_size, H, W, 3]
-             batch = torch.from_numpy(batch).to(device).float() / 255.0
-             batch = batch.permute(0, 3, 1, 2) # [batch_size, 3, H, W]
-
-             # Run inference
-             frame_start_time = time.time()
-             with torch.no_grad():
-                 pred = model(batch)[0]
-             pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
-             print(f"Batch inference time: {time.time() - frame_start_time:.2f}s for {len(batch_frames)} frames")
-
-             # Process detections
-             for i, det in enumerate(pred):
-                 if det is not None and len(det):
-                     det = xywh2xyxy(det) # Convert to [x1, y1, x2, y2]
-                     for *xyxy, conf, cls in det:
-                         x_center = (xyxy[0] + xyxy[2]) / 2
-                         y_center = (xyxy[1] + xyxy[3]) / 2
-                         # Scale coordinates back to original frame size
-                         x_center = x_center * frame_width / MODEL_INPUT_SIZE[1]
-                         y_center = y_center * frame_height / MODEL_INPUT_SIZE[0]
-                         positions.append((x_center.item(), y_center.item()))
-                         frame_numbers.append(batch_frame_nums[i])
-
-                         # Detect bounce (lowest y_center point)
-                         if bounce_frame is None or y_center > positions[bounce_frame][1]:
-                             bounce_frame = len(frame_numbers) - 1
-                             bounce_point = (x_center.item(), y_center.item())
-
-             batch_frames = []
-             batch_frame_nums = []
-
-         # Early termination
-         if len(positions) >= MIN_DETECTIONS:
-             break
-
      cap.release()
-     print(f"Total video processing time: {time.time() - start_time:.2f}s")
-     return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height

- # Polynomial function for trajectory fitting
- def poly_func(x, a, b, c):
-     return a * x**2 + b * x + c

- # Predict trajectory and wicket inline path
- def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
-     if len(positions) < 3:
-         return None, None, "Insufficient detections for trajectory prediction"

-     x_coords = [p[0] for p in positions]
-     y_coords = [p[1] for p in positions]
-     frames = np.array(frame_numbers)

-     # Fit polynomial to x and y coordinates
      try:
-         popt_x, _ = curve_fit(poly_func, frames, x_coords)
-         popt_y, _ = curve_fit(poly_func, frames, y_coords)
-     except:
-         return None, None, "Failed to fit trajectory"
-
-     # Extrapolate to stumps
-     frame_max = max(frames) + 10
-     future_frames = np.linspace(min(frames), frame_max, 100)
-     x_pred = poly_func(future_frames, *popt_x)
-     y_pred = poly_func(future_frames, *popt_y)
-
-     # Wicket inline path (center line toward stumps)
-     stump_x = frame_width / 2
-     stump_y = frame_height
-     inline_x = np.linspace(min(x_coords), stump_x, 100)
-     inline_y = np.interp(inline_x, x_pred, y_pred)
-
-     # Check if trajectory hits stumps
-     stump_hit = False
-     for x, y in zip(x_pred, y_pred):
-         if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
-             stump_hit = True
-             break
-
-     lbw_decision = "OUT" if stump_hit else "NOT OUT"
-     return list(zip(future_frames, x_pred, y_pred)), list(zip(inline_x, inline_y)), lbw_decision
-
- # Map pitch location
- def map_pitch(bounce_point, frame_width, frame_height):
-     if bounce_point is None:
-         return None, "No bounce detected"
-
-     x, y = bounce_point
-     pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
-     pitch_y = (1 - y / frame_height) * PITCH_LENGTH
-     return pitch_x, pitch_y
-
- # Estimate ball speed
- def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
-     if len(positions) < 2:
-         return None, "Insufficient detections for speed estimation"
-
-     distances = []
-     for i in range(1, len(positions)):
-         x1, y1 = positions[i-1]
-         x2, y2 = positions[i]
-         pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
-         distances.append(pixel_dist)
-
-     pixel_to_meter = PITCH_LENGTH / frame_width
-     distances_m = [d * pixel_to_meter for d in distances]
-     time_interval = 1 / frame_rate
-     speeds = [d / time_interval for d in distances_m]
-     avg_speed_kmh = np.mean(speeds) * 3.6
-     return avg_speed_kmh, "Speed calculated successfully"
-
- # Main Gradio function with video overlay and slow motion
- def drs_analysis(video):
-     # Video is a file path (string) in Hugging Face Spaces
-     video_path = video if isinstance(video, str) else "temp_video.mp4"
-     if not isinstance(video, str):
-         with open(video_path, "wb") as f:
-             f.write(video.read())
-
-     # Process video for detections
-     positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
-     if not positions:
-         return None, None, "No ball detected in video", None
-
-     # Predict trajectory and wicket path
-     trajectory, inline_path, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
-     if trajectory is None:
-         return None, None, lbw_decision, None
-
-     pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
-     speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
-
-     # Create output video with overlays and slow motion
-     output_path = "output_video.mp4"
-     cap = cv2.VideoCapture(video_path)
      fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     out = cv2.VideoWriter(output_path, fourcc, frame_rate, (frame_width, frame_height))
-
-     frame_count = 0
-     positions_dict = dict(zip(frame_numbers, positions))

-     while cap.isOpened():
-         ret, frame = cap.read()
-         if not ret:
-             break
-
-         # Skip frames for consistency with detection
-         if frame_count % FRAME_SKIP != 0:
-             frame_count += 1
-             continue
-
-         # Overlay ball trajectory (red) and wicket inline path (blue)
-         if frame_count in positions_dict:
-             cv2.circle(frame, (int(positions_dict[frame_count][0]), int(positions_dict[frame_count][1])), 5, (0, 0, 255), -1) # Red dot
          if trajectory:
-             traj_x = [int(t[1]) for t in trajectory if t[0] >= frame_count]
-             traj_y = [int(t[2]) for t in trajectory if t[0] >= frame_count]
-             if traj_x and traj_y:
-                 for i in range(1, len(traj_x)):
-                     cv2.line(frame, (traj_x[i-1], traj_y[i-1]), (traj_x[i], traj_y[i]), (0, 0, 255), 2) # Red line
-         if inline_path:
-             inline_x = [int(x) for x, _ in inline_path]
-             inline_y = [int(y) for _, y in inline_path]
-             if inline_x and inline_y:
-                 for i in range(1, len(inline_x)):
-                     cv2.line(frame, (inline_x[i-1], inline_y[i-1]), (inline_x[i], inline_y[i]), (255, 0, 0), 2) # Blue line
-
-         # Overlay pitch map in top-right corner
-         if pitch_x is not None and pitch_y is not None:
-             map_width = 200
-             # Cap map_height to 25% of frame height to ensure it fits
-             map_height = min(int(map_width * PITCH_LENGTH / PITCH_WIDTH), frame_height // 4)
-             pitch_map = np.zeros((map_height, map_width, 3), dtype=np.uint8)
-             pitch_map[:] = (0, 255, 0) # Green pitch
-             cv2.rectangle(pitch_map, (0, map_height-10), (map_width, map_height), (0, 51, 51), -1) # Brown stumps
-             bounce_x = int((pitch_x + PITCH_WIDTH/2) / PITCH_WIDTH * map_width)
-             bounce_y = int((1 - pitch_y / PITCH_LENGTH) * map_height)
-             cv2.circle(pitch_map, (bounce_x, bounce_y), 5, (0, 0, 255), -1) # Red bounce point
-             # Ensure overlay fits within frame
-             overlay_region = frame[0:map_height, frame_width-map_width:frame_width]
-             if overlay_region.shape[0] >= map_height and overlay_region.shape[1] >= map_width:
-                 frame[0:map_height, frame_width-map_width:frame_width] = cv2.resize(pitch_map, (map_width, map_height))
-
-         # Add text annotations
-         text = f"LBW: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h"
-         cv2.putText(frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-         # Write frame multiple times for slow motion
-         for _ in range(SLOW_MOTION_FACTOR):
              out.write(frame)
-
-         frame_count += 1
-
-     cap.release()
      out.release()

-     if not isinstance(video, str):
-         os.remove(video_path)

-     return None, None, None, output_path

  # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("## Cricket DRS Analysis")
-     video_input = gr.Video(label="Upload Video Clip")
-     btn = gr.Button("Analyze")
-     trajectory_output = gr.Plot(label="Ball Trajectory")
-     pitch_output = gr.Plot(label="Pitch Map")
-     text_output = gr.Textbox(label="Analysis Results")
-     video_output = gr.Video(label="Processed Video")
-     btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])

  if __name__ == "__main__":
-     demo.launch()
 
  import cv2
  import numpy as np
  import torch
+ from ultralytics import YOLO
  import gradio as gr
+ from scipy.interpolate import interp1d
+ import uuid
  import os

+ # Load the trained YOLOv8n model from the Space's root directory
+ model = YOLO("best.pt") # Assumes best.pt is in the same directory as app.py

+ # Constants for LBW decision and video processing
+ STUMPS_WIDTH = 0.2286 # meters (width of stumps)
+ BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
+ FRAME_RATE = 30 # Input video frame rate
+ SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
+ CONF_THRESHOLD = 0.3 # Lowered confidence threshold for better detection

  def process_video(video_path):
+     # Initialize video capture
+     if not os.path.exists(video_path):
+         return [], [], "Error: Video file not found"
      cap = cv2.VideoCapture(video_path)
+     frames = []
+     ball_positions = []
+     debug_log = []

+     frame_count = 0
      while cap.isOpened():
          ret, frame = cap.read()
          if not ret:
              break
          frame_count += 1
+         frames.append(frame.copy()) # Store original frame
+         # Detect ball using the trained YOLOv8n model
+         results = model.predict(frame, conf=CONF_THRESHOLD)
+         detections = 0
+         for detection in results[0].boxes:
+             if detection.cls == 0: # Assuming class 0 is the ball
+                 detections += 1
+                 x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
+                 ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
+                 # Draw bounding box on frame for visualization
+                 cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
+                 frames[-1] = frame # Update frame with bounding box
+         debug_log.append(f"Frame {frame_count}: {detections} ball detections")
      cap.release()

+     if not ball_positions:
+         debug_log.append("No balls detected in any frame")
+     else:
+         debug_log.append(f"Total ball detections: {len(ball_positions)}")

+     return frames, ball_positions, "\n".join(debug_log)

+ def estimate_trajectory(ball_positions, frames):
+     # Simplified physics-based trajectory projection
+     if len(ball_positions) < 2:
+         return None, None, "Error: Fewer than 2 ball detections for trajectory"
+     # Extract x, y coordinates
+     x_coords = [pos[0] for pos in ball_positions]
+     y_coords = [pos[1] for pos in ball_positions]
+     times = np.arange(len(ball_positions)) / FRAME_RATE

+     # Interpolate to smooth trajectory
      try:
+         fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
+         fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
+     except Exception as e:
+         return None, None, f"Error in trajectory interpolation: {str(e)}"
+
+     # Project trajectory forward (0.5 seconds post-impact)
+     t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
+     x_future = fx(t_future)
+     y_future = fy(t_future)
+
+     return list(zip(x_future, y_future)), t_future, "Trajectory estimated successfully"
+
+ def lbw_decision(ball_positions, trajectory, frames):
+     # Simplified LBW logic
+     if not frames:
+         return "Error: No frames processed", None, None, None
+     if not trajectory or len(ball_positions) < 2:
+         return "Not enough data (insufficient ball detections)", None, None, None
+
+     # Assume stumps are at the bottom center of the frame (calibration needed)
+     frame_height, frame_width = frames[0].shape[:2]
+     stumps_x = frame_width / 2
+     stumps_y = frame_height * 0.9 # Approximate stumps position
+     stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0) # Assume 3m pitch width
+
+     # Store pitch and impact points
+     pitch_point = ball_positions[0]
+     impact_point = ball_positions[-1]
+
+     # Check pitching point
+     pitch_x, pitch_y = pitch_point
+     if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
+         return f"Not Out (Pitched outside line at x: {pitch_x:.1f}, y: {pitch_y:.1f})", trajectory, pitch_point, impact_point
+
+     # Check impact point
+     impact_x, impact_y = impact_point
+     if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
+         return f"Not Out (Impact outside line at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+
+     # Check trajectory hitting stumps
+     for x, y in trajectory:
+         if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
+             return f"Out (Ball hits stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+     return f"Not Out (Missing stumps, Pitch at x: {pitch_x:.1f}, y: {pitch_y:.1f}, Impact at x: {impact_x:.1f}, y: {impact_y:.1f})", trajectory, pitch_point, impact_point
+
+ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path):
+     # Generate very slow-motion video with ball detection, trajectory, and pitch/impact points
+     if not frames:
+         return None
      fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))

+     for frame in frames:
+         # Draw trajectory
          if trajectory:
+             for x, y in trajectory:
+                 cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1) # Blue dots for trajectory
+
+         # Draw pitch point (red circle with label)
+         if pitch_point:
+             x, y = pitch_point
+             cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
+             cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
+
+         # Draw impact point (yellow circle with label)
+         if impact_point:
+             x, y = impact_point
+             cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
+             cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
+                         cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
+
+         for _ in range(SLOW_MOTION_FACTOR): # Duplicate frames for very slow motion
              out.write(frame)
      out.release()
+     return output_path
+
+ def drs_review(video):
+     # Process video and generate DRS output
+     frames, ball_positions, debug_log = process_video(video)
+     if not frames:
+         return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None
+     trajectory, _, trajectory_log = estimate_trajectory(ball_positions, frames)
+     decision, trajectory, pitch_point, impact_point = lbw_decision(ball_positions, trajectory, frames)

+     # Generate slow-motion replay with enhanced annotations
+     output_path = f"output_{uuid.uuid4()}.mp4"
+     slow_motion_path = generate_slow_motion(frames, trajectory, pitch_point, impact_point, output_path)

+     # Combine debug logs for output
+     debug_output = f"{debug_log}\n{trajectory_log}"
+     return f"DRS Decision: {decision}\nDebug Log:\n{debug_output}", slow_motion_path

  # Gradio interface
+ iface = gr.Interface(
+     fn=drs_review,
+     inputs=gr.Video(label="Upload Video Clip"),
+     outputs=[
+         gr.Textbox(label="DRS Decision and Debug Log"),
+         gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue), Pitch Point (Red), Impact Point (Yellow)")
+     ],
+     title="AI-Powered DRS for LBW in Local Cricket",
+     description="Upload a video clip of a cricket delivery to get an LBW decision and very slow-motion replay showing ball detection (green boxes), trajectory (blue dots), pitch point (red circle), and impact point (yellow circle)."
+ )

  if __name__ == "__main__":
+     iface.launch()
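
For reference, a minimal sketch of how the updated pipeline could be exercised headlessly, outside the Gradio UI. It assumes best.pt sits next to app.py as the commit expects; the clip name "sample_delivery.mp4" is only a placeholder, not part of the commit:

# headless_check.py -- a sketch for local testing, not part of this commit
from app import drs_review  # importing app.py loads best.pt but does not launch Gradio

if __name__ == "__main__":
    # drs_review takes a video file path and returns (decision text + debug log, replay path)
    decision_text, replay_path = drs_review("sample_delivery.mp4")  # placeholder path
    print(decision_text)                          # LBW decision plus per-frame detection log
    print("Replay written to:", replay_path)      # annotated slow-motion MP4, or None on failure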