AjaykumarPilla committed on
Commit 6e725f6 · verified · 1 Parent(s): 58a64bf

Update app.py

Files changed (1)
  1. app.py +102 -256
app.py CHANGED
@@ -1,287 +1,133 @@
  import cv2
  import numpy as np
- import pandas as pd
- import plotly.express as px
- import plotly.graph_objects as go
  import torch
  import gradio as gr
  import os
- import time
- from scipy.optimize import curve_fit
- import sys
-
- # Add yolov5 directory to sys.path
- sys.path.append(os.path.join(os.path.dirname(__file__), "yolov5"))
-
- # Import YOLOv5 modules
- from models.experimental import attempt_load
- from utils.general import non_max_suppression, xywh2xyxy
-
- # Cricket pitch dimensions (in meters)
- PITCH_LENGTH = 20.12  # Length of cricket pitch (stumps to stumps)
- PITCH_WIDTH = 3.05  # Width of pitch
- STUMP_HEIGHT = 0.71  # Stump height
- STUMP_WIDTH = 0.2286  # Stump width (including bails)
-
- # Model input size (adjust if yolov5s.pt was trained with a different size)
- MODEL_INPUT_SIZE = (640, 640)  # (height, width)
- FRAME_SKIP = 2  # Process every 2nd frame
- MIN_DETECTIONS = 10  # Stop after 10 detections
- BATCH_SIZE = 4  # Process 4 frames at a time
- SLOW_MOTION_FACTOR = 3  # Duplicate each frame 3 times for slow motion
-
- # Load model
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- model = attempt_load("best.pt")  # Load yolov5s.pt
- model.to(device).eval()  # Move model to device and set to evaluation mode
-
- # Function to process video and detect ball
  def process_video(video_path):
      cap = cv2.VideoCapture(video_path)
-     frame_rate = cap.get(cv2.CAP_PROP_FPS)
-     frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-     frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-     positions = []
-     frame_numbers = []
-     bounce_frame = None
-     bounce_point = None
-     batch_frames = []
-     batch_frame_nums = []
-     frame_count = 0
-
-     start_time = time.time()
      while cap.isOpened():
-         frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
          ret, frame = cap.read()
          if not ret:
              break
-
-         # Skip frames
-         if frame_count % FRAME_SKIP != 0:
-             frame_count += 1
-             continue
-
-         # Resize frame to model input size
-         frame = cv2.resize(frame, MODEL_INPUT_SIZE, interpolation=cv2.INTER_AREA)
-         batch_frames.append(frame)
-         batch_frame_nums.append(frame_num)
-         frame_count += 1
-
-         # Process batch when full or at end
-         if len(batch_frames) == BATCH_SIZE or not ret:
-             # Preprocess batch
-             batch = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in batch_frames]
-             batch = np.stack(batch)  # [batch_size, H, W, 3]
-             batch = torch.from_numpy(batch).to(device).float() / 255.0
-             batch = batch.permute(0, 3, 1, 2)  # [batch_size, 3, H, W]
-
-             # Run inference
-             frame_start_time = time.time()
-             with torch.no_grad():
-                 pred = model(batch)[0]
-                 pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
-             print(f"Batch inference time: {time.time() - frame_start_time:.2f}s for {len(batch_frames)} frames")
-
-             # Process detections
-             for i, det in enumerate(pred):
-                 if det is not None and len(det):
-                     det = xywh2xyxy(det)  # Convert to [x1, y1, x2, y2]
-                     for *xyxy, conf, cls in det:
-                         x_center = (xyxy[0] + xyxy[2]) / 2
-                         y_center = (xyxy[1] + xyxy[3]) / 2
-                         # Scale coordinates back to original frame size
-                         x_center = x_center * frame_width / MODEL_INPUT_SIZE[1]
-                         y_center = y_center * frame_height / MODEL_INPUT_SIZE[0]
-                         positions.append((x_center.item(), y_center.item()))
-                         frame_numbers.append(batch_frame_nums[i])
-
-                         # Detect bounce (lowest y_center point)
-                         if bounce_frame is None or y_center > positions[bounce_frame][1]:
-                             bounce_frame = len(frame_numbers) - 1
-                             bounce_point = (x_center.item(), y_center.item())
-
-             batch_frames = []
-             batch_frame_nums = []
-
-             # Early termination
-             if len(positions) >= MIN_DETECTIONS:
-                 break
-
      cap.release()
-     print(f"Total video processing time: {time.time() - start_time:.2f}s")
-     return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
-
- # Polynomial function for trajectory fitting
- def poly_func(x, a, b, c):
-     return a * x**2 + b * x + c

- # Predict trajectory and wicket inline path
- def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
-     if len(positions) < 3:
-         return None, None, "Insufficient detections for trajectory prediction"

-     x_coords = [p[0] for p in positions]
-     y_coords = [p[1] for p in positions]
-     frames = np.array(frame_numbers)

-     # Fit polynomial to x and y coordinates
      try:
-         popt_x, _ = curve_fit(poly_func, frames, x_coords)
-         popt_y, _ = curve_fit(poly_func, frames, y_coords)
      except:
-         return None, None, "Failed to fit trajectory"
-
-     # Extrapolate to stumps
-     frame_max = max(frames) + 10
-     future_frames = np.linspace(min(frames), frame_max, 100)
-     x_pred = poly_func(future_frames, *popt_x)
-     y_pred = poly_func(future_frames, *popt_y)
-
-     # Wicket inline path (center line toward stumps)
-     stump_x = frame_width / 2
-     stump_y = frame_height
-     inline_x = np.linspace(min(x_coords), stump_x, 100)
-     inline_y = np.interp(inline_x, x_pred, y_pred)
-
-     # Check if trajectory hits stumps
-     stump_hit = False
-     for x, y in zip(x_pred, y_pred):
-         if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
-             stump_hit = True
-             break
-
-     lbw_decision = "OUT" if stump_hit else "NOT OUT"
-     return list(zip(future_frames, x_pred, y_pred)), list(zip(inline_x, inline_y)), lbw_decision
-
- # Map pitch location
- def map_pitch(bounce_point, frame_width, frame_height):
-     if bounce_point is None:
-         return None, "No bounce detected"
-
-     x, y = bounce_point
-     pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
-     pitch_y = (1 - y / frame_height) * PITCH_LENGTH
-     return pitch_x, pitch_y
-
- # Estimate ball speed
- def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
-     if len(positions) < 2:
-         return None, "Insufficient detections for speed estimation"
-
-     distances = []
-     for i in range(1, len(positions)):
-         x1, y1 = positions[i-1]
-         x2, y2 = positions[i]
-         pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
-         distances.append(pixel_dist)
-
-     pixel_to_meter = PITCH_LENGTH / frame_width
-     distances_m = [d * pixel_to_meter for d in distances]
-     time_interval = 1 / frame_rate
-     speeds = [d / time_interval for d in distances_m]
-     avg_speed_kmh = np.mean(speeds) * 3.6
-     return avg_speed_kmh, "Speed calculated successfully"
-
- # Main Gradio function with video overlay and slow motion
- def drs_analysis(video):
-     # Video is a file path (string) in Hugging Face Spaces
-     video_path = video if isinstance(video, str) else "temp_video.mp4"
-     if not isinstance(video, str):
-         with open(video_path, "wb") as f:
-             f.write(video.read())
-
-     # Process video for detections
-     positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
-     if not positions:
-         return None, None, "No ball detected in video", None
-
-     # Predict trajectory and wicket path
-     trajectory, inline_path, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
-     if trajectory is None:
-         return None, None, lbw_decision, None
-
-     pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
-     speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
-
-     # Create output video with overlays and slow motion
-     output_path = "output_video.mp4"
-     cap = cv2.VideoCapture(video_path)
      fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-     out = cv2.VideoWriter(output_path, fourcc, frame_rate, (frame_width, frame_height))
-
-     frame_count = 0
-     positions_dict = dict(zip(frame_numbers, positions))

-     while cap.isOpened():
-         ret, frame = cap.read()
-         if not ret:
-             break
-
-         # Skip frames for consistency with detection
-         if frame_count % FRAME_SKIP != 0:
-             frame_count += 1
-             continue
-
-         # Overlay ball trajectory (red) and wicket inline path (blue)
-         if frame_count in positions_dict:
-             cv2.circle(frame, (int(positions_dict[frame_count][0]), int(positions_dict[frame_count][1])), 5, (0, 0, 255), -1)  # Red dot
          if trajectory:
-             traj_x = [int(t[1]) for t in trajectory if t[0] >= frame_count]
-             traj_y = [int(t[2]) for t in trajectory if t[0] >= frame_count]
-             if traj_x and traj_y:
-                 for i in range(1, len(traj_x)):
-                     cv2.line(frame, (traj_x[i-1], traj_y[i-1]), (traj_x[i], traj_y[i]), (0, 0, 255), 2)  # Red line
-         if inline_path:
-             inline_x = [int(x) for x, _ in inline_path]
-             inline_y = [int(y) for _, y in inline_path]
-             if inline_x and inline_y:
-                 for i in range(1, len(inline_x)):
-                     cv2.line(frame, (inline_x[i-1], inline_y[i-1]), (inline_x[i], inline_y[i]), (255, 0, 0), 2)  # Blue line
-
-         # Overlay pitch map in top-right corner
-         if pitch_x is not None and pitch_y is not None:
-             map_width = 200
-             # Cap map_height to 25% of frame height to ensure it fits
-             map_height = min(int(map_width * PITCH_LENGTH / PITCH_WIDTH), frame_height // 4)
-             pitch_map = np.zeros((map_height, map_width, 3), dtype=np.uint8)
-             pitch_map[:] = (0, 255, 0)  # Green pitch
-             cv2.rectangle(pitch_map, (0, map_height-10), (map_width, map_height), (0, 51, 51), -1)  # Brown stumps
-             bounce_x = int((pitch_x + PITCH_WIDTH/2) / PITCH_WIDTH * map_width)
-             bounce_y = int((1 - pitch_y / PITCH_LENGTH) * map_height)
-             cv2.circle(pitch_map, (bounce_x, bounce_y), 5, (0, 0, 255), -1)  # Red bounce point
-             # Ensure overlay fits within frame
-             overlay_region = frame[0:map_height, frame_width-map_width:frame_width]
-             if overlay_region.shape[0] >= map_height and overlay_region.shape[1] >= map_width:
-                 frame[0:map_height, frame_width-map_width:frame_width] = cv2.resize(pitch_map, (map_width, map_height))
-
-         # Add text annotations
-         text = f"LBW: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h"
-         cv2.putText(frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-         # Write frame multiple times for slow motion
-         for _ in range(SLOW_MOTION_FACTOR):
-             out.write(frame)
-
-         frame_count += 1
-
-     cap.release()
      out.release()

-     if not isinstance(video, str):
-         os.remove(video_path)

-     return None, None, None, output_path

  # Gradio interface
- with gr.Blocks() as demo:
-     gr.Markdown("## Cricket DRS Analysis")
-     video_input = gr.Video(label="Upload Video Clip")
-     btn = gr.Button("Analyze")
-     trajectory_output = gr.Plot(label="Ball Trajectory")
-     pitch_output = gr.Plot(label="Pitch Map")
-     text_output = gr.Textbox(label="Analysis Results")
-     video_output = gr.Video(label="Processed Video")
-     btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])

  if __name__ == "__main__":
-     demo.launch()
 
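The main change in this commit is the detection backend: the removed code above loaded YOLOv5 weights through the yolov5 repository utilities (attempt_load, non_max_suppression, manual resizing and batching), while the added code below hands all of that to the ultralytics package. A minimal sketch of the equivalent single-frame detection with the new API, assuming the same best.pt weights and ball class 0 used in the new code; the input image path is only illustrative:

```python
import cv2
from ultralytics import YOLO

model = YOLO("best.pt")  # same weights file the new app.py loads

frame = cv2.imread("sample_frame.jpg")  # illustrative frame; any BGR image works
results = model.predict(frame, conf=0.5)  # preprocessing and NMS handled internally

for box in results[0].boxes:
    if int(box.cls) == 0:  # class 0 assumed to be the ball, as in the new code
        x1, y1, x2, y2 = box.xyxy[0].cpu().numpy()
        print("Ball center:", ((x1 + x2) / 2, (y1 + y2) / 2))
```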
  import cv2
  import numpy as np
  import torch
+ from ultralytics import YOLO
  import gradio as gr
+ from scipy.interpolate import interp1d
+ import uuid
  import os

+ # Load the trained YOLOv8n model from the Space's root directory
+ model = YOLO("best.pt")  # Assumes best.pt is in the same directory as app.py

+ # Constants for LBW decision
+ STUMPS_WIDTH = 0.2286  # meters (width of stumps)
+ BALL_DIAMETER = 0.073  # meters (approx. cricket ball diameter)
+ FRAME_RATE = 30  # Default frame rate for video processing

  def process_video(video_path):
+     # Initialize video capture
      cap = cv2.VideoCapture(video_path)
+     frames = []
+     ball_positions = []

      while cap.isOpened():
          ret, frame = cap.read()
          if not ret:
              break
+         frames.append(frame.copy())  # Store original frame
+         # Detect ball using the trained YOLOv8n model
+         results = model.predict(frame, conf=0.5)  # Adjust confidence threshold if needed
+         for detection in results[0].boxes:
+             if detection.cls == 0:  # Assuming class 0 is the ball
+                 x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
+                 ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
+                 # Draw bounding box on frame for visualization
+                 cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
+                 frames[-1] = frame  # Update frame with bounding box

      cap.release()
+     return frames, ball_positions

+ def estimate_trajectory(ball_positions, frames):
+     # Simplified trajectory projection: interpolate the detected positions and extrapolate forward
+     if len(ball_positions) < 2:
+         return None, None
+     # Extract x, y coordinates
+     x_coords = [pos[0] for pos in ball_positions]
+     y_coords = [pos[1] for pos in ball_positions]
+     times = np.arange(len(ball_positions)) / FRAME_RATE

+     # Interpolate to smooth the trajectory
      try:
+         fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
+         fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
      except:
+         return None, None
+
+     # Project trajectory forward (0.5 seconds post-impact)
+     t_future = np.linspace(times[-1], times[-1] + 0.5, 10)
+     x_future = fx(t_future)
+     y_future = fy(t_future)
+
+     return list(zip(x_future, y_future)), t_future
+
+ def lbw_decision(ball_positions, trajectory, frames):
+     # Simplified LBW logic
+     if not trajectory or len(ball_positions) < 2:
+         return "Not enough data", None
+
+     # Assume stumps are at the bottom center of the frame (calibration needed)
+     frame_height, frame_width = frames[0].shape[:2]
+     stumps_x = frame_width / 2
+     stumps_y = frame_height * 0.9  # Approximate stumps position
+     stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)  # Assume 3m pitch width
+
+     # Check pitching point (first detected position)
+     pitch_x, pitch_y = ball_positions[0]
+     if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
+         return "Not Out (Pitched outside line)", None
+
+     # Check impact point (last detected position)
+     impact_x, impact_y = ball_positions[-1]
+     if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
+         return "Not Out (Impact outside line)", None
+
+     # Check whether the projected trajectory hits the stumps
+     for x, y in trajectory:
+         if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
+             return "Out", trajectory
+     return "Not Out (Missing stumps)", trajectory
+
+ def generate_slow_motion(frames, trajectory, output_path):
+     # Generate slow-motion video with ball detection and trajectory overlay
      fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+     out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / 2, (frames[0].shape[1], frames[0].shape[0]))

+     for frame in frames:
          if trajectory:
+             for x, y in trajectory:
+                 cv2.circle(frame, (int(x), int(y)), 5, (255, 0, 0), -1)  # Blue dots for trajectory
+         out.write(frame)
+         out.write(frame)  # Duplicate frames for slow-motion effect
      out.release()
+     return output_path
+
+ def drs_review(video):
+     # Process video and generate DRS output
+     if not os.path.exists(video):
+         return "Error: Video file not found", None
+     frames, ball_positions = process_video(video)
+     trajectory, _ = estimate_trajectory(ball_positions, frames)
+     decision, trajectory = lbw_decision(ball_positions, trajectory, frames)

+     # Generate slow-motion replay
+     output_path = f"output_{uuid.uuid4()}.mp4"
+     slow_motion_path = generate_slow_motion(frames, trajectory, output_path)

+     return decision, slow_motion_path

  # Gradio interface
+ iface = gr.Interface(
+     fn=drs_review,
+     inputs=gr.Video(label="Upload Video Clip"),
+     outputs=[
+         gr.Textbox(label="DRS Decision"),
+         gr.Video(label="Slow-Motion Replay with Ball Detection and Trajectory")
+     ],
+     title="AI-Powered DRS for LBW in Local Cricket",
+     description="Upload a video clip of a cricket delivery to get an LBW decision and slow-motion replay showing ball detection and trajectory."
+ )

  if __name__ == "__main__":
+     iface.launch()
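To exercise the new pipeline outside the Gradio UI (for example on a local machine before pushing to the Space), a minimal sketch; drs_review and the best.pt weights come from the code above, while the clip filename is a placeholder:

```python
# Quick local check of the new pipeline without launching the Gradio interface.
# Assumes best.pt sits next to app.py and delivery.mp4 is a short clip of one delivery.
from app import drs_review

decision, replay_path = drs_review("delivery.mp4")
print("LBW decision:", decision)  # e.g. "Out" or "Not Out (Missing stumps)"
print("Slow-motion replay written to:", replay_path)
```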