AjaykumarPilla committed on
Commit
61746ab
·
verified ·
1 Parent(s): 239672f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -55
app.py CHANGED
@@ -4,36 +4,37 @@ import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
7
- from scipy.signal import savgol_filter
8
  import plotly.graph_objects as go
9
  import uuid
10
  import os
 
11
 
12
- # Load the trained YOLOv8n model
13
  model = YOLO("best.pt")
14
  model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
15
 
16
- # Constants
17
- STUMPS_WIDTH = 0.2286
18
- BALL_DIAMETER = 0.073
19
- FRAME_RATE = 20
20
- SLOW_MOTION_FACTOR = 1.5
21
- CONF_THRESHOLD = 0.15
22
- IMPACT_ZONE_Y = 0.9
23
- PITCH_LENGTH = 20.12
24
- STUMPS_HEIGHT = 0.71
25
- CAMERA_HEIGHT = 2.0
26
- CAMERA_DISTANCE = 10.0
27
- MAX_POSITION_JUMP = 250
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
31
  return [], [], [], "Error: Video file not found"
32
  cap = cv2.VideoCapture(video_path)
 
33
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
- global FRAME_RATE
36
- FRAME_RATE = cap.get(cv2.CAP_PROP_FPS) or 20
37
  stride = 32
38
  img_width = ((frame_width + stride - 1) // stride) * stride
39
  img_height = ((frame_height + stride - 1) // stride) * stride
@@ -49,17 +50,18 @@ def process_video(video_path):
49
  break
50
  frame_count += 1
51
  frames.append(frame.copy())
 
52
  frame = cv2.convertScaleAbs(frame, alpha=1.5, beta=20)
53
  kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
54
  frame = cv2.filter2D(frame, -1, kernel)
55
  results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(img_height, img_width), iou=0.5, max_det=5)
56
  detections = sum(1 for detection in results[0].boxes if detection.cls == 0)
57
- if detections >= 1:
58
  max_conf = 0
59
  best_detection = None
60
  conf_scores = []
61
  for detection in results[0].boxes:
62
- if detection.cls == 0:
63
  conf = detection.conf.cpu().numpy()[0]
64
  conf_scores.append(conf)
65
  if conf > max_conf:
@@ -67,6 +69,7 @@ def process_video(video_path):
67
  best_detection = detection
68
  if best_detection:
69
  x1, y1, x2, y2 = best_detection.xyxy[0].cpu().numpy()
 
70
  x1 = x1 * frame_width / img_width
71
  x2 = x2 * frame_width / img_width
72
  y1 = y1 * frame_height / img_height
@@ -78,6 +81,7 @@ def process_video(video_path):
78
  else:
79
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
80
  frames[-1] = frame
 
81
  cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
82
  cap.release()
83
 
@@ -91,35 +95,33 @@ def process_video(video_path):
91
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
92
 
93
  def pixel_to_3d(x, y, frame_height, frame_width):
 
94
  x_norm = x / frame_width
95
  y_norm = y / frame_height
96
- x_3d = (x_norm - 0.5) * 3.0
97
  y_3d = y_norm * PITCH_LENGTH
98
- z_3d = (1 - y_norm) * BALL_DIAMETER * 5
99
  return x_3d, y_3d, z_3d
100
 
101
  def estimate_trajectory(ball_positions, frames, detection_frames):
102
  if len(ball_positions) < 2:
103
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 frames with one ball detection"
104
-
105
  frame_height, frame_width = frames[0].shape[:2]
106
  debug_log = []
107
 
 
108
  filtered_positions = [ball_positions[0]]
109
  filtered_frames = [detection_frames[0]]
110
-
111
  for i in range(1, len(ball_positions)):
112
  prev_pos = filtered_positions[-1]
113
  curr_pos = ball_positions[i]
114
- distance = np.linalg.norm(np.array(curr_pos) - np.array(prev_pos))
115
- frame_gap = detection_frames[i] - filtered_frames[-1]
116
- velocity = distance / frame_gap if frame_gap > 0 else 0
117
-
118
- if distance <= MAX_POSITION_JUMP and velocity < 100:
119
  filtered_positions.append(curr_pos)
120
  filtered_frames.append(detection_frames[i])
121
  else:
122
- debug_log.append(f"Filtered out frame {detection_frames[i]} due to sudden jump: distance={distance:.1f}, velocity={velocity:.1f}")
 
123
 
124
  if len(filtered_positions) < 2:
125
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
@@ -128,60 +130,64 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
128
  y_coords = [pos[1] for pos in filtered_positions]
129
  times = np.array(filtered_frames) / FRAME_RATE
130
 
131
- try:
132
- x_coords = savgol_filter(x_coords, window_length=5, polyorder=2, mode='nearest')
133
- y_coords = savgol_filter(y_coords, window_length=5, polyorder=2, mode='nearest')
134
- except Exception as e:
135
- return None, None, None, None, None, None, None, None, None, f"Smoothing error: {str(e)}"
136
 
 
137
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in zip(x_coords, y_coords)]
138
-
 
139
  pitch_idx = min(range(len(filtered_positions)), key=lambda i: y_coords[i])
140
  pitch_point = (x_coords[pitch_idx], y_coords[pitch_idx])
141
  pitch_frame = filtered_frames[pitch_idx]
142
 
 
143
  post_pitch_indices = [i for i in range(len(filtered_positions)) if filtered_frames[i] > pitch_frame]
144
  if not post_pitch_indices:
145
  return None, None, None, None, None, None, None, None, None, "Error: No detections after pitch point"
146
-
147
  impact_idx = max(post_pitch_indices, key=lambda i: y_coords[i])
148
  impact_point = (x_coords[impact_idx], y_coords[impact_idx])
149
  impact_frame = filtered_frames[impact_idx]
150
 
151
  try:
 
152
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
153
  fy = interp1d(times, y_coords, kind='linear', fill_value="extrapolate")
154
  except Exception as e:
155
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
156
 
 
157
  total_frames = max(detection_frames) - min(detection_frames) + 1
158
  t_full = np.linspace(min(detection_frames) / FRAME_RATE, max(detection_frames) / FRAME_RATE, int(total_frames * SLOW_MOTION_FACTOR))
159
  x_full = fx(t_full)
160
  y_full = fy(t_full)
161
  trajectory_2d = list(zip(x_full, y_full))
162
- trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
163
- pitch_point_3d = pixel_to_3d(*pitch_point, frame_height, frame_width)
164
- impact_point_3d = pixel_to_3d(*impact_point, frame_height, frame_width)
165
 
166
- import matplotlib.pyplot as plt
167
- plt.figure(figsize=(10, 6))
168
- plt.plot([p[0] for p in ball_positions], [p[1] for p in ball_positions], 'kx-', label='Original')
169
- plt.plot(x_coords, y_coords, 'bo-', label='Smoothed Trajectory')
170
- plt.scatter(pitch_point[0], pitch_point[1], color='red', label='Pitch Point')
171
- plt.scatter(impact_point[0], impact_point[1], color='yellow', label='Impact Point')
172
- plt.legend()
173
- plt.title("Ball Trajectory Filtering & Smoothing")
174
- plt.savefig("trajectory_smooth_debug.png")
175
 
 
176
  debug_log.extend([
177
  f"Trajectory estimated successfully",
178
- f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f})",
179
- f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f})"
 
 
180
  ])
 
 
 
 
 
 
 
181
 
182
  return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "\n".join(debug_log)
183
 
184
  def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
 
185
  stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
186
  stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
187
  stump_z = [0, 0, 0]
@@ -288,18 +294,18 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
288
  trajectory_indices = []
289
 
290
  for i, frame in enumerate(frames):
291
- frame_idx = i - min(detection_frames) if trajectory_indices else -1
292
- if 0 <= frame_idx < total_frames and trajectory_points.size > 0:
293
  end_idx = trajectory_indices[frame_idx] + 1
294
- cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
295
  if pitch_point and i == pitch_frame:
296
  x, y = pitch_point
297
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
298
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
299
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
300
  if impact_point and i == impact_frame:
301
  x, y = impact_point
302
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
303
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
304
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
305
  for _ in range(int(SLOW_MOTION_FACTOR)):
@@ -349,4 +355,4 @@ iface = gr.Interface(
349
  )
350
 
351
  if __name__ == "__main__":
352
- iface.launch()
 
4
  from ultralytics import YOLO
5
  import gradio as gr
6
  from scipy.interpolate import interp1d
 
7
  import plotly.graph_objects as go
8
  import uuid
9
  import os
10
+ from scipy.ndimage import uniform_filter1d
11
 
12
+ # Load the trained YOLOv8n model with optimizations
13
  model = YOLO("best.pt")
14
  model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
15
 
16
+ # Constants for LBW decision and video processing
17
+ STUMPS_WIDTH = 0.2286 # meters (width of stumps)
18
+ BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
19
+ FRAME_RATE = 20 # Default frame rate, updated dynamically
20
+ SLOW_MOTION_FACTOR = 1.5 # Faster replay (e.g., 30 / 1.5 = 20 FPS)
21
+ CONF_THRESHOLD = 0.15 # Lowered for better detection
22
+ IMPACT_ZONE_Y = 0.9 # Adjusted to 90% of frame height for impact zone
23
+ PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
24
+ STUMPS_HEIGHT = 0.71 # meters (stumps height)
25
+ CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
26
+ CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
27
+ MAX_POSITION_JUMP = 250 # Increased to include more detections
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
31
  return [], [], [], "Error: Video file not found"
32
  cap = cv2.VideoCapture(video_path)
33
+ # Get native video resolution and frame rate
34
  frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
  frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
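+ FRAME_RATE = cap.get(cv2.CAP_PROP_FPS) or 20 # Use actual frame rate or default. NOTE(review): this commit removed `global FRAME_RATE`, so this assignment binds a local; the module-level FRAME_RATE read by other functions stays at its default of 20.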
37
+ # Adjust image size to be multiple of 32 for YOLO
38
  stride = 32
39
  img_width = ((frame_width + stride - 1) // stride) * stride
40
  img_height = ((frame_height + stride - 1) // stride) * stride
 
50
  break
51
  frame_count += 1
52
  frames.append(frame.copy())
53
+ # Enhance frame contrast and sharpness
54
  frame = cv2.convertScaleAbs(frame, alpha=1.5, beta=20)
55
  kernel = np.array([[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]])
56
  frame = cv2.filter2D(frame, -1, kernel)
57
  results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(img_height, img_width), iou=0.5, max_det=5)
58
  detections = sum(1 for detection in results[0].boxes if detection.cls == 0)
59
+ if detections >= 1: # Process frames with at least one ball detection
60
  max_conf = 0
61
  best_detection = None
62
  conf_scores = []
63
  for detection in results[0].boxes:
64
+ if detection.cls == 0: # Class 0 is the ball
65
  conf = detection.conf.cpu().numpy()[0]
66
  conf_scores.append(conf)
67
  if conf > max_conf:
 
69
  best_detection = detection
70
  if best_detection:
71
  x1, y1, x2, y2 = best_detection.xyxy[0].cpu().numpy()
72
+ # Scale coordinates back to original frame size
73
  x1 = x1 * frame_width / img_width
74
  x2 = x2 * frame_width / img_width
75
  y1 = y1 * frame_height / img_height
 
81
  else:
82
  debug_log.append(f"Frame {frame_count}: {detections} ball detections")
83
  frames[-1] = frame
84
+ # Save debug frame
85
  cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
86
  cap.release()
87
 
 
95
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
96
 
97
  def pixel_to_3d(x, y, frame_height, frame_width):
98
+ """Convert 2D pixel coordinates to 3D real-world coordinates."""
99
  x_norm = x / frame_width
100
  y_norm = y / frame_height
101
+ x_3d = (x_norm - 0.5) * 3.0 # Center x at 0 (middle of pitch)
102
  y_3d = y_norm * PITCH_LENGTH
103
+ z_3d = (1 - y_norm) * BALL_DIAMETER * 5 # Scale to approximate ball bounce height
104
  return x_3d, y_3d, z_3d
105
 
106
  def estimate_trajectory(ball_positions, frames, detection_frames):
107
  if len(ball_positions) < 2:
108
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 frames with one ball detection"
 
109
  frame_height, frame_width = frames[0].shape[:2]
110
  debug_log = []
111
 
112
+ # Filter out sudden changes in position for continuous trajectory
113
  filtered_positions = [ball_positions[0]]
114
  filtered_frames = [detection_frames[0]]
 
115
  for i in range(1, len(ball_positions)):
116
  prev_pos = filtered_positions[-1]
117
  curr_pos = ball_positions[i]
118
+ distance = np.sqrt((curr_pos[0] - prev_pos[0])**2 + (curr_pos[1] - prev_pos[1])**2)
119
+ if distance <= MAX_POSITION_JUMP:
 
 
 
120
  filtered_positions.append(curr_pos)
121
  filtered_frames.append(detection_frames[i])
122
  else:
123
+ debug_log.append(f"Filtered out detection at frame {detection_frames[i] + 1}: large jump ({distance:.1f} pixels)")
124
+ continue
125
 
126
  if len(filtered_positions) < 2:
127
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
 
130
  y_coords = [pos[1] for pos in filtered_positions]
131
  times = np.array(filtered_frames) / FRAME_RATE
132
 
133
+ # Smooth coordinates to avoid sudden jumps
134
+ x_coords = uniform_filter1d(x_coords, size=3)
135
+ y_coords = uniform_filter1d(y_coords, size=3)
 
 
136
 
137
+ # Convert to 3D for visualization
138
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in zip(x_coords, y_coords)]
139
+
140
+ # Pitch point: Detection with lowest y-coordinate (near bowler's end)
141
  pitch_idx = min(range(len(filtered_positions)), key=lambda i: y_coords[i])
142
  pitch_point = (x_coords[pitch_idx], y_coords[pitch_idx])
143
  pitch_frame = filtered_frames[pitch_idx]
144
 
145
+ # Impact point: Detection with highest y-coordinate after pitch point (near stumps)
146
  post_pitch_indices = [i for i in range(len(filtered_positions)) if filtered_frames[i] > pitch_frame]
147
  if not post_pitch_indices:
148
  return None, None, None, None, None, None, None, None, None, "Error: No detections after pitch point"
 
149
  impact_idx = max(post_pitch_indices, key=lambda i: y_coords[i])
150
  impact_point = (x_coords[impact_idx], y_coords[impact_idx])
151
  impact_frame = filtered_frames[impact_idx]
152
 
153
  try:
154
+ # Use linear interpolation for stable trajectory
155
  fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
156
  fy = interp1d(times, y_coords, kind='linear', fill_value="extrapolate")
157
  except Exception as e:
158
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
159
 
160
+ # Generate dense points for all frames between first and last detection
161
  total_frames = max(detection_frames) - min(detection_frames) + 1
162
  t_full = np.linspace(min(detection_frames) / FRAME_RATE, max(detection_frames) / FRAME_RATE, int(total_frames * SLOW_MOTION_FACTOR))
163
  x_full = fx(t_full)
164
  y_full = fy(t_full)
165
  trajectory_2d = list(zip(x_full, y_full))
 
 
 
166
 
167
+ trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
168
+ pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
169
+ impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
 
 
 
 
 
 
170
 
171
+ # Debug trajectory and points
172
  debug_log.extend([
173
  f"Trajectory estimated successfully",
174
+ f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}",
175
+ f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}",
176
+ f"Detections in frames: {filtered_frames}",
177
+ f"Total filtered detections: {len(filtered_frames)}"
178
  ])
179
+ # Save trajectory plot for debugging
180
+ import matplotlib.pyplot as plt
181
+ plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
182
+ plt.plot(pitch_point[0], pitch_point[1], 'ro', label='Pitch Point')
183
+ plt.plot(impact_point[0], impact_point[1], 'yo', label='Impact Point')
184
+ plt.legend()
185
+ plt.savefig("trajectory_debug.png")
186
 
187
  return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "\n".join(debug_log)
188
 
189
  def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
190
+ """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
191
  stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
192
  stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
193
  stump_z = [0, 0, 0]
 
294
  trajectory_indices = []
295
 
296
  for i, frame in enumerate(frames):
297
+ frame_idx = i - min_frame if trajectory_indices else -1
298
+ if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
299
  end_idx = trajectory_indices[frame_idx] + 1
300
+ cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2) # Blue line in BGR
301
  if pitch_point and i == pitch_frame:
302
  x, y = pitch_point
303
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
304
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
305
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
306
  if impact_point and i == impact_frame:
307
  x, y = impact_point
308
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
309
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
310
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
311
  for _ in range(int(SLOW_MOTION_FACTOR)):
 
355
  )
356
 
357
  if __name__ == "__main__":
358
+ iface.launch()