AjaykumarPilla commited on
Commit
a8a99f7
·
verified ·
1 Parent(s): 2c04ca3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -7
app.py CHANGED
@@ -8,15 +8,16 @@ import plotly.graph_objects as go
8
  import uuid
9
  import os
10
 
11
- # Load the trained YOLOv8n model
12
  model = YOLO("best.pt")
 
13
 
14
  # Constants for LBW decision and video processing
15
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
16
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
17
- FRAME_RATE = 30 # Input video frame rate
18
  SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
19
- CONF_THRESHOLD = 0.2 # Lowered confidence threshold to improve detection
20
  IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
21
  IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
22
  PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
@@ -28,6 +29,9 @@ def process_video(video_path):
28
  if not os.path.exists(video_path):
29
  return [], [], [], "Error: Video file not found"
30
  cap = cv2.VideoCapture(video_path)
 
 
 
31
  frames = []
32
  ball_positions = []
33
  detection_frames = []
@@ -40,7 +44,8 @@ def process_video(video_path):
40
  break
41
  frame_count += 1
42
  frames.append(frame.copy())
43
- results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=640)
 
44
  detections = 0
45
  for detection in results[0].boxes:
46
  if detection.cls == 0: # Class 0 is the ball
@@ -57,6 +62,7 @@ def process_video(video_path):
57
  debug_log.append("No balls detected in any frame")
58
  else:
59
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
 
60
 
61
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
62
 
@@ -91,7 +97,6 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
91
  impact_frame = detection_frames[i]
92
  break
93
  elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
94
- # Fallback to y-position if no significant y-change
95
  impact_idx = i
96
  impact_frame = detection_frames[i]
97
  break
@@ -206,7 +211,7 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
206
  impact_scatter = go.Scatter3d(
207
  x=[impact_point_3d[0]] if impact_point_3d else [],
208
  y=[impact_point_3d[1]] if impact_point_3d else [],
209
- z=[impact_point_3d[2]] if pitch_point_3d else [],
210
  mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
211
  )
212
  data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
@@ -228,8 +233,9 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
228
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
229
  if not frames:
230
  return None
 
231
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
232
- out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frames[0].shape[1], frames[0].shape[0]))
233
 
234
  if trajectory and detection_frames:
235
  trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))
 
8
  import uuid
9
  import os
10
 
11
+ # Load the trained YOLOv8n model with optimizations
12
  model = YOLO("best.pt")
13
+ model.to('cuda' if torch.cuda.is_available() else 'cpu') # Use GPU if available
14
 
15
  # Constants for LBW decision and video processing
16
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
17
  BALL_DIAMETER = 0.073 # meters (approx. cricket ball diameter)
18
+ FRAME_RATE = 30 # Input video frame rate (adjust if known)
19
  SLOW_MOTION_FACTOR = 6 # For very slow motion (6x slower)
20
+ CONF_THRESHOLD = 0.2 # Confidence threshold
21
  IMPACT_ZONE_Y = 0.85 # Fraction of frame height where impact is likely
22
  IMPACT_DELTA_Y = 50 # Pixels for detecting sudden y-position change
23
  PITCH_LENGTH = 20.12 # meters (standard cricket pitch length)
 
29
  if not os.path.exists(video_path):
30
  return [], [], [], "Error: Video file not found"
31
  cap = cv2.VideoCapture(video_path)
32
+ # Get native video resolution
33
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
34
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
35
  frames = []
36
  ball_positions = []
37
  detection_frames = []
 
44
  break
45
  frame_count += 1
46
  frames.append(frame.copy())
47
+ # Use native resolution for inference
48
+ results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
49
  detections = 0
50
  for detection in results[0].boxes:
51
  if detection.cls == 0: # Class 0 is the ball
 
62
  debug_log.append("No balls detected in any frame")
63
  else:
64
  debug_log.append(f"Total ball detections: {len(ball_positions)}")
65
+ debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
66
 
67
  return frames, ball_positions, detection_frames, "\n".join(debug_log)
68
 
 
97
  impact_frame = detection_frames[i]
98
  break
99
  elif y_coords[i] > frame_height * IMPACT_ZONE_Y:
 
100
  impact_idx = i
101
  impact_frame = detection_frames[i]
102
  break
 
211
  impact_scatter = go.Scatter3d(
212
  x=[impact_point_3d[0]] if impact_point_3d else [],
213
  y=[impact_point_3d[1]] if impact_point_3d else [],
214
+ z=[impact_point_3d[2]] if impact_point_3d else [],
215
  mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
216
  )
217
  data = [trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces
 
233
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
234
  if not frames:
235
  return None
236
+ frame_height, frame_width = frames[0].shape[:2]
237
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
238
+ out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
239
 
240
  if trajectory and detection_frames:
241
  trajectory_points = np.array(trajectory[:len(detection_frames)], dtype=np.int32).reshape((-1, 1, 2))