AjaykumarPilla committed on
Commit
9d8c79c
·
verified ·
1 Parent(s): ecfb2f8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +34 -41
app.py CHANGED
@@ -16,15 +16,15 @@ model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
16
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
17
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
18
  FRAME_RATE = 20 # Default frame rate, updated dynamically
19
- SLOW_MOTION_FACTOR = 3 # For very slow motion (3x slower)
20
- CONF_THRESHOLD = 0.25 # Lowered further to improve detection chances
21
- IMPACT_ZONE_Y = 0.8 # Fraction of frame height for impact zone
22
- IMPACT_VELOCITY_THRESHOLD = 1000 # Pixels/second for detecting impact
23
  PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
24
  STUMPS_HEIGHT = 0.71 # meters (stumps height)
25
  CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
26
  CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
27
- MAX_POSITION_JUMP = 50 # For smoother trajectory filtering
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
@@ -59,6 +59,7 @@ def process_video(video_path):
59
  if detections == 1: # Only process frames with exactly one ball detection
60
  for detection in results[0].boxes:
61
  if detection.cls == 0: # Class 0 is the ball
 
62
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
63
  # Scale coordinates back to original frame size
64
  x1 = x1 * frame_width / img_width
@@ -68,8 +69,10 @@ def process_video(video_path):
68
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
69
  detection_frames.append(frame_count - 1)
70
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
 
 
 
71
  frames[-1] = frame
72
- debug_log.append(f"Frame {frame_count}: {detections} ball detections")
73
  # Save debug frame
74
  cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
75
  cap.release()
@@ -108,6 +111,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
108
  filtered_positions.append(curr_pos)
109
  filtered_frames.append(detection_frames[i])
110
  else:
 
111
  continue
112
 
113
  if len(filtered_positions) < 2:
@@ -117,43 +121,32 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
117
  y_coords = [pos[1] for pos in filtered_positions]
118
  times = np.array(filtered_frames) / FRAME_RATE
119
 
120
- # Convert to 3D for pitch point detection
121
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
122
 
123
- # Pitch point: Detection with lowest z-coordinate (closest to ground)
124
- pitch_idx = min(range(len(detections_3d)), key=lambda i: detections_3d[i][2])
125
  pitch_point = filtered_positions[pitch_idx]
126
  pitch_frame = filtered_frames[pitch_idx]
127
 
128
- # Impact point: Detect sudden velocity change or impact zone
129
- impact_idx = None
130
- impact_frame = None
131
- velocities = [np.sqrt((x_coords[i] - x_coords[i-1])**2 + (y_coords[i] - y_coords[i-1])**2) / (times[i] - times[i-1])
132
- for i in range(1, len(x_coords))]
133
- for i in range(1, len(y_coords)):
134
- if velocities[i-1] > IMPACT_VELOCITY_THRESHOLD:
135
- impact_idx = i
136
- impact_frame = filtered_frames[i]
137
- break
138
- elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
139
- impact_idx = i
140
- impact_frame = filtered_frames[i]
141
- break
142
- if impact_idx is None:
143
- impact_idx = len(filtered_positions) - 1
144
- impact_frame = filtered_frames[-1]
145
  impact_point = filtered_positions[impact_idx]
 
146
 
147
  try:
148
  # Use linear interpolation for more stable trajectory
149
- fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
150
- fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='linear', fill_value="extrapolate")
151
  except Exception as e:
152
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
153
 
154
  # Generate dense points for all frames between first and last detection
155
  total_frames = max(detection_frames) - min(detection_frames) + 1
156
- t_full = np.linspace(times[0], times[impact_idx], total_frames * SLOW_MOTION_FACTOR)
157
  x_full = fx(t_full)
158
  y_full = fy(t_full)
159
  trajectory_2d = list(zip(x_full, y_full))
@@ -163,13 +156,13 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
163
  impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
164
 
165
  # Debug trajectory and points
166
- debug_log = (
167
- f"Trajectory estimated successfully\n"
168
- f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}\n"
169
- f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}\n"
170
- f"Detections in frames: {filtered_frames}\n"
171
- f"Velocities: {velocities}"
172
- )
173
  # Save trajectory plot for debugging
174
  import matplotlib.pyplot as plt
175
  plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
@@ -178,7 +171,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
178
  plt.legend()
179
  plt.savefig("trajectory_debug.png")
180
 
181
- return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, debug_log
182
 
183
  def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
184
  """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
@@ -291,18 +284,18 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
291
  frame_idx = i - min_frame if trajectory_indices else -1
292
  if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
293
  end_idx = trajectory_indices[frame_idx] + 1
294
- cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2)
295
  if pitch_point and i == pitch_frame:
296
  x, y = pitch_point
297
- cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1)
298
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
299
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
300
  if impact_point and i == impact_frame:
301
  x, y = impact_point
302
- cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1)
303
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
304
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
305
- for _ in range(SLOW_MOTION_FACTOR):
306
  out.write(frame)
307
  out.release()
308
  return output_path
 
16
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
17
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
18
  FRAME_RATE = 20 # Default frame rate, updated dynamically
19
+ SLOW_MOTION_FACTOR = 1.5 # Faster replay (e.g., 30 / 1.5 = 20 FPS)
20
+ CONF_THRESHOLD = 0.25 # Confidence threshold for detection
21
+ IMPACT_ZONE_Y = 0.9 # Adjusted to 90% of frame height for impact zone
22
+ IMPACT_VELOCITY_THRESHOLD = 500 # Reduced for better impact detection
23
  PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
24
  STUMPS_HEIGHT = 0.71 # meters (stumps height)
25
  CAMERA_HEIGHT = 2.0 # meters (assumed camera height)
26
  CAMERA_DISTANCE = 10.0 # meters (assumed camera distance from pitch)
27
+ MAX_POSITION_JUMP = 150 # Increased to include more detections
28
 
29
  def process_video(video_path):
30
  if not os.path.exists(video_path):
 
59
  if detections == 1: # Only process frames with exactly one ball detection
60
  for detection in results[0].boxes:
61
  if detection.cls == 0: # Class 0 is the ball
62
+ conf = detection.conf.cpu().numpy()[0]
63
  x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
64
  # Scale coordinates back to original frame size
65
  x1 = x1 * frame_width / img_width
 
69
  ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
70
  detection_frames.append(frame_count - 1)
71
  cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
72
+ debug_log.append(f"Frame {frame_count}: 1 ball detection, confidence={conf:.3f}")
73
+ else:
74
+ debug_log.append(f"Frame {frame_count}: {detections} ball detections")
75
  frames[-1] = frame
 
76
  # Save debug frame
77
  cv2.imwrite(f"debug_frame_{frame_count}.jpg", frame)
78
  cap.release()
 
111
  filtered_positions.append(curr_pos)
112
  filtered_frames.append(detection_frames[i])
113
  else:
114
+ debug_log.append(f"Filtered out detection at frame {detection_frames[i] + 1}: large jump ({distance:.1f} pixels)")
115
  continue
116
 
117
  if len(filtered_positions) < 2:
 
121
  y_coords = [pos[1] for pos in filtered_positions]
122
  times = np.array(filtered_frames) / FRAME_RATE
123
 
124
+ # Convert to 3D for visualization
125
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
126
 
127
+ # Pitch point: Detection with lowest y-coordinate (near bowler's end)
128
+ pitch_idx = min(range(len(filtered_positions)), key=lambda i: filtered_positions[i][1])
129
  pitch_point = filtered_positions[pitch_idx]
130
  pitch_frame = filtered_frames[pitch_idx]
131
 
132
+ # Impact point: Detection with highest y-coordinate after pitch point (near stumps)
133
+ post_pitch_indices = [i for i in range(len(filtered_positions)) if filtered_frames[i] > pitch_frame]
134
+ if not post_pitch_indices:
135
+ return None, None, None, None, None, None, None, None, None, "Error: No detections after pitch point"
136
+ impact_idx = max(post_pitch_indices, key=lambda i: filtered_positions[i][1])
 
 
 
 
 
 
 
 
 
 
 
 
137
  impact_point = filtered_positions[impact_idx]
138
+ impact_frame = filtered_frames[impact_idx]
139
 
140
  try:
141
  # Use linear interpolation for more stable trajectory
142
+ fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
143
+ fy = interp1d(times, y_coords, kind='linear', fill_value="extrapolate")
144
  except Exception as e:
145
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
146
 
147
  # Generate dense points for all frames between first and last detection
148
  total_frames = max(detection_frames) - min(detection_frames) + 1
149
+ t_full = np.linspace(times[0], times[-1], total_frames * SLOW_MOTION_FACTOR)
150
  x_full = fx(t_full)
151
  y_full = fy(t_full)
152
  trajectory_2d = list(zip(x_full, y_full))
 
156
  impact_point_3d = pixel_to_3d(impact_point[0], impact_point[1], frame_height, frame_width)
157
 
158
  # Debug trajectory and points
159
+ debug_log = [
160
+ f"Trajectory estimated successfully",
161
+ f"Pitch point at frame {pitch_frame + 1}: ({pitch_point[0]:.1f}, {pitch_point[1]:.1f}), 3D: {pitch_point_3d}",
162
+ f"Impact point at frame {impact_frame + 1}: ({impact_point[0]:.1f}, {impact_point[1]:.1f}), 3D: {impact_point_3d}",
163
+ f"Detections in frames: {filtered_frames}",
164
+ f"Velocities: {velocities}" if 'velocities' in locals() else "Velocities: Not calculated"
165
+ ]
166
  # Save trajectory plot for debugging
167
  import matplotlib.pyplot as plt
168
  plt.plot(x_coords, y_coords, 'bo-', label='Filtered Detections')
 
171
  plt.legend()
172
  plt.savefig("trajectory_debug.png")
173
 
174
+ return trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, "\n".join(debug_log)
175
 
176
  def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, plot_type="detections"):
177
  """Create 3D Plotly visualization for detections or trajectory using single-detection frames."""
 
284
  frame_idx = i - min_frame if trajectory_indices else -1
285
  if frame_idx >= 0 and frame_idx < total_frames and trajectory_points.size > 0:
286
  end_idx = trajectory_indices[frame_idx] + 1
287
+ cv2.polylines(frame, [trajectory_points[:end_idx]], False, (255, 0, 0), 2) # Blue line in BGR
288
  if pitch_point and i == pitch_frame:
289
  x, y = pitch_point
290
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 0, 255), -1) # Red circle
291
  cv2.putText(frame, "Pitch Point", (int(x) + 10, int(y) - 10),
292
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 255), 2)
293
  if impact_point and i == impact_frame:
294
  x, y = impact_point
295
+ cv2.circle(frame, (int(x), int(y)), 8, (0, 255, 255), -1) # Yellow circle
296
  cv2.putText(frame, "Impact Point", (int(x) + 10, int(y) + 20),
297
  cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 255), 2)
298
+ for _ in range(int(SLOW_MOTION_FACTOR)):
299
  out.write(frame)
300
  out.release()
301
  return output_path