dschandra commited on
Commit
449d194
·
verified ·
1 Parent(s): d179a4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -37
app.py CHANGED
@@ -1,10 +1,10 @@
1
-
2
  import cv2
3
  import numpy as np
4
  import torch
5
  from ultralytics import YOLO
6
  import gradio as gr
7
- from scipy.interpolate import interp1d, UnivariateSpline
 
8
  import uuid
9
  import os
10
 
@@ -15,10 +15,10 @@ model = YOLO("best.pt")
15
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
16
  FRAME_RATE = 20 # Input video frame rate
17
  SLOW_MOTION_FACTOR = 2 # Reduced for faster output
18
- CONF_THRESHOLD = 0.25 # Confidence threshold for detection
19
- PITCH_ZONE_Y = 0.85 # Adjusted for pitch near stumps
20
- IMPACT_ZONE_Y = 0.75 # Adjusted for impact near batsman leg
21
- IMPACT_DELTA_Y = 30 # Reduced for finer impact detection
22
  STUMPS_HEIGHT = 0.711 # meters (height of stumps)
23
 
24
  def process_video(video_path):
@@ -35,16 +35,18 @@ def process_video(video_path):
35
  ret, frame = cap.read()
36
  if not ret:
37
  break
38
- if frame_count % 2 == 0: # Process every 2nd frame
39
- frames.append(frame.copy())
40
- results = model.predict(frame, conf=CONF_THRESHOLD)
41
- detections = [det for det in results[0].boxes if det.cls == 0]
42
- if len(detections) == 1:
43
- x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
44
- ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
45
- detection_frames.append(len(frames) - 1)
46
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
47
- frames[-1] = frame
 
 
48
  debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections")
49
  frame_count += 1
50
  cap.release()
@@ -61,25 +63,18 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
61
  return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
62
  frame_height = frames[0].shape[0]
63
 
64
- # Filter to unique positions
65
- unique_positions = [ball_positions[0]]
66
- for pos in ball_positions[1:]:
67
- if abs(pos[0] - unique_positions[-1][0]) > 10 or abs(pos[1] - unique_positions[-1][1]) > 10:
68
- unique_positions.append(pos)
69
- x_coords = [pos[0] for pos in unique_positions]
70
- y_coords = [pos[1] for pos in unique_positions]
71
- times = np.array([i / FRAME_RATE for i in range(len(unique_positions))])
72
-
73
- # Smooth coordinates with spline interpolation
74
- x_smooth = UnivariateSpline(times, x_coords, s=10)
75
- y_smooth = UnivariateSpline(times, y_coords, s=10)
76
 
77
  pitch_idx = 0
78
  for i, y in enumerate(y_coords):
79
  if y > frame_height * PITCH_ZONE_Y:
80
  pitch_idx = i
81
  break
82
- pitch_point = unique_positions[pitch_idx]
83
  pitch_frame = detection_frames[pitch_idx]
84
 
85
  impact_idx = None
@@ -90,7 +85,7 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
90
  break
91
  if impact_idx is None:
92
  impact_idx = len(y_coords) - 1
93
- impact_point = unique_positions[impact_idx]
94
  impact_frame = detection_frames[impact_idx]
95
 
96
  x_coords = x_coords[:impact_idx + 1]
@@ -98,12 +93,12 @@ def estimate_trajectory(ball_positions, detection_frames, frames):
98
  times = times[:impact_idx + 1]
99
 
100
  try:
101
- fx = interp1d(times, x_smooth(times), kind='linear', fill_value="extrapolate")
102
- fy = interp1d(times, y_smooth(times), kind='quadratic', fill_value="extrapolate")
103
  except Exception as e:
104
  return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
105
 
106
- vis_trajectory = list(zip(x_smooth(times), y_smooth(times)))
107
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 5)
108
  x_full = fx(t_full)
109
  y_full = fy(t_full)
@@ -122,9 +117,9 @@ def lbw_decision(ball_positions, full_trajectory, frames, pitch_point, impact_po
122
 
123
  frame_height, frame_width = frames[0].shape[:2]
124
  stumps_x = frame_width / 2
125
- stumps_y = frame_height * 0.85 # Adjusted to align with pitch
126
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
127
- batsman_area_y = frame_height * 0.75
128
 
129
  pitch_x, pitch_y = pitch_point
130
  impact_x, impact_y = impact_point
@@ -154,7 +149,7 @@ def generate_slow_motion(frames, vis_trajectory, pitch_point, pitch_frame, impac
154
  return None
155
  frame_height, frame_width = frames[0].shape[:2]
156
  stumps_x = frame_width / 2
157
- stumps_y = frame_height * 0.85 # Align with pitch
158
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
159
  stumps_height_pixels = frame_height * (STUMPS_HEIGHT / 3.0)
160
 
@@ -228,4 +223,4 @@ iface = gr.Interface(
228
  )
229
 
230
  if __name__ == "__main__":
231
- iface.launch()
 
 
1
  import cv2
2
  import numpy as np
3
  import torch
4
  from ultralytics import YOLO
5
  import gradio as gr
6
+ from scipy.interpolate import interp1d
7
+ from scipy.ndimage import uniform_filter1d
8
  import uuid
9
  import os
10
 
 
15
  STUMPS_WIDTH = 0.2286 # meters (width of stumps)
16
  FRAME_RATE = 20 # Input video frame rate
17
  SLOW_MOTION_FACTOR = 2 # Reduced for faster output
18
+ CONF_THRESHOLD = 0.3 # Increased for better detection
19
+ PITCH_ZONE_Y = 0.8 # Adjusted for pitch near stumps
20
+ IMPACT_ZONE_Y = 0.7 # Adjusted for impact near batsman leg
21
+ IMPACT_DELTA_Y = 20 # Reduced for finer impact detection
22
  STUMPS_HEIGHT = 0.711 # meters (height of stumps)
23
 
24
  def process_video(video_path):
 
35
  ret, frame = cap.read()
36
  if not ret:
37
  break
38
+ # Process every frame for better tracking
39
+ frames.append(frame.copy())
40
+ # Preprocess frame for better detection
41
+ frame = cv2.convertScaleAbs(frame, alpha=1.2, beta=10) # Enhance contrast
42
+ results = model.predict(frame, conf=CONF_THRESHOLD)
43
+ detections = [det for det in results[0].boxes if det.cls == 0]
44
+ if len(detections) == 1:
45
+ x1, y1, x2, y2 = detections[0].xyxy[0].cpu().numpy()
46
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
47
+ detection_frames.append(len(frames) - 1)
48
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
49
+ frames[-1] = frame
50
  debug_log.append(f"Frame {frame_count}: {len(detections)} ball detections")
51
  frame_count += 1
52
  cap.release()
 
63
  return None, None, None, None, None, None, "Error: Fewer than 2 valid single-ball detections for trajectory"
64
  frame_height = frames[0].shape[0]
65
 
66
+ # Smooth coordinates with moving average
67
+ window_size = 3
68
+ x_coords = uniform_filter1d([pos[0] for pos in ball_positions], size=window_size, mode='nearest')
69
+ y_coords = uniform_filter1d([pos[1] for pos in ball_positions], size=window_size, mode='nearest')
70
+ times = np.array([i / FRAME_RATE for i in range(len(ball_positions))])
 
 
 
 
 
 
 
71
 
72
  pitch_idx = 0
73
  for i, y in enumerate(y_coords):
74
  if y > frame_height * PITCH_ZONE_Y:
75
  pitch_idx = i
76
  break
77
+ pitch_point = ball_positions[pitch_idx]
78
  pitch_frame = detection_frames[pitch_idx]
79
 
80
  impact_idx = None
 
85
  break
86
  if impact_idx is None:
87
  impact_idx = len(y_coords) - 1
88
+ impact_point = ball_positions[impact_idx]
89
  impact_frame = detection_frames[impact_idx]
90
 
91
  x_coords = x_coords[:impact_idx + 1]
 
93
  times = times[:impact_idx + 1]
94
 
95
  try:
96
+ fx = interp1d(times, x_coords, kind='linear', fill_value="extrapolate")
97
+ fy = interp1d(times, y_coords, kind='quadratic', fill_value="extrapolate")
98
  except Exception as e:
99
  return None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
100
 
101
+ vis_trajectory = list(zip(x_coords, y_coords))
102
  t_full = np.linspace(times[0], times[-1] + 0.5, len(times) + 5)
103
  x_full = fx(t_full)
104
  y_full = fy(t_full)
 
117
 
118
  frame_height, frame_width = frames[0].shape[:2]
119
  stumps_x = frame_width / 2
120
+ stumps_y = frame_height * 0.8 # Adjusted to align with pitch
121
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
122
+ batsman_area_y = frame_height * 0.7
123
 
124
  pitch_x, pitch_y = pitch_point
125
  impact_x, impact_y = impact_point
 
149
  return None
150
  frame_height, frame_width = frames[0].shape[:2]
151
  stumps_x = frame_width / 2
152
+ stumps_y = frame_height * 0.8 # Align with pitch
153
  stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
154
  stumps_height_pixels = frame_height * (STUMPS_HEIGHT / 3.0)
155
 
 
223
  )
224
 
225
  if __name__ == "__main__":
226
+ iface.launch()