AjaykumarPilla commited on
Commit
0ae5282
·
verified ·
1 Parent(s): 9007667

Update gully_drs_core/ball_detection.py

Browse files
Files changed (1) hide show
  1. gully_drs_core/ball_detection.py +22 -37
gully_drs_core/ball_detection.py CHANGED
@@ -5,53 +5,37 @@ import numpy as np
5
  from .model_utils import load_model
6
 
7
  def find_bounce_point(path):
8
- """
9
- Detects the bounce point by checking for a dip in the y-axis (trajectory).
10
- """
11
- for i in range(1, len(path) - 1):
12
- if path[i - 1][1] > path[i][1] < path[i + 1][1]: # y decreases then increases
13
  return path[i]
14
  return None
15
 
16
  def estimate_speed(ball_path, fps, px_to_m=0.01):
17
- """
18
- Estimate speed in km/h based on pixel distance and frame rate.
19
- Assumes 1 pixel ≈ 1cm (adjust px_to_m for better accuracy).
20
- """
21
  if len(ball_path) < 2:
22
  return 0.0
23
-
24
  p1 = ball_path[0]
25
- p2 = ball_path[min(5, len(ball_path) - 1)] # use the 5th frame ahead
26
-
27
- dx = p2[0] - p1[0]
28
- dy = p2[1] - p1[1]
29
  dist_px = (dx**2 + dy**2)**0.5
30
  dist_m = dist_px * px_to_m
31
- time_s = (min(5, len(ball_path) - 1)) / fps
32
-
33
  speed_kmh = (dist_m / time_s) * 3.6 if time_s > 0 else 0
34
  return round(speed_kmh, 1)
35
 
36
  def analyze_video(file_path):
37
- """
38
- Main processing function:
39
- - Detects the ball using YOLOv8
40
- - Builds trajectory from valid frames
41
- - Detects bounce, impact, stump zone intersection
42
- - Returns decision + video frame overlays
43
- """
44
  model = load_model()
45
  cap = cv2.VideoCapture(file_path)
46
  fps = cap.get(cv2.CAP_PROP_FPS)
47
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
48
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
49
 
50
- frames = []
51
  ball_path = []
 
 
 
52
 
53
- max_jump = 100 # pixels
54
  last_point = None
 
55
 
56
  while True:
57
  ret, frame = cap.read()
@@ -62,20 +46,22 @@ def analyze_video(file_path):
62
  valid_detection = None
63
 
64
  for r in results:
65
- ball_detections = [box for box in r.boxes if int(box.cls[0]) == 0] # class 0 = cricket ball
 
66
  if len(ball_detections) == 1:
67
  box = ball_detections[0]
68
  x1, y1, x2, y2 = map(int, box.xyxy[0])
69
- cx = (x1 + x2) // 2
70
- cy = (y1 + y2) // 2
71
 
72
- # Filter out sudden jumps in position
73
  if last_point:
74
- dx = cx - last_point[0]
75
- dy = cy - last_point[1]
76
  jump = (dx**2 + dy**2)**0.5
77
  if jump > max_jump:
78
- break # skip this frame
 
 
 
79
 
80
  valid_detection = (cx, cy)
81
  last_point = valid_detection
@@ -85,15 +71,13 @@ def analyze_video(file_path):
85
  ball_path.append(valid_detection)
86
 
87
  frames.append(frame)
 
88
 
89
  cap.release()
90
 
91
- # Calculate analysis outputs
92
  bounce_point = find_bounce_point(ball_path)
93
  impact_point = ball_path[-1] if ball_path else None
94
- speed_kmh = estimate_speed(ball_path, fps)
95
 
96
- # Define stump zone area
97
  stump_zone = (
98
  width // 2 - 30,
99
  height - 100,
@@ -101,13 +85,14 @@ def analyze_video(file_path):
101
  height
102
  )
103
 
104
- # LBW decision: does ball impact land in stump zone?
105
  decision = "OUT" if (
106
  impact_point and
107
  stump_zone[0] <= impact_point[0] <= stump_zone[2] and
108
  stump_zone[1] <= impact_point[1] <= stump_zone[3]
109
  ) else "NOT OUT"
110
 
 
 
111
  return {
112
  "trajectory": ball_path,
113
  "fps": fps,
 
5
  from .model_utils import load_model
6
 
7
  def find_bounce_point(path):
8
+ for i in range(1, len(path)-1):
9
+ if path[i-1][1] > path[i][1] < path[i+1][1]: # y dips = bounce
 
 
 
10
  return path[i]
11
  return None
12
 
13
  def estimate_speed(ball_path, fps, px_to_m=0.01):
 
 
 
 
14
  if len(ball_path) < 2:
15
  return 0.0
 
16
  p1 = ball_path[0]
17
+ p2 = ball_path[min(5, len(ball_path)-1)]
18
+ dx, dy = p2[0] - p1[0], p2[1] - p1[1]
 
 
19
  dist_px = (dx**2 + dy**2)**0.5
20
  dist_m = dist_px * px_to_m
21
+ time_s = (min(5, len(ball_path)-1)) / fps
 
22
  speed_kmh = (dist_m / time_s) * 3.6 if time_s > 0 else 0
23
  return round(speed_kmh, 1)
24
 
25
  def analyze_video(file_path):
 
 
 
 
 
 
 
26
  model = load_model()
27
  cap = cv2.VideoCapture(file_path)
28
  fps = cap.get(cv2.CAP_PROP_FPS)
29
+ width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
30
  height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
31
 
 
32
  ball_path = []
33
+ frames = []
34
+
35
+ max_jump = 100 # max allowed jump (pixels) between consecutive ball detections
36
 
 
37
  last_point = None
38
+ frame_idx = 0
39
 
40
  while True:
41
  ret, frame = cap.read()
 
46
  valid_detection = None
47
 
48
  for r in results:
49
+ # Accept only if exactly one detection of cricket ball class (e.g., class 0)
50
+ ball_detections = [box for box in r.boxes if int(box.cls[0]) == 0]
51
  if len(ball_detections) == 1:
52
  box = ball_detections[0]
53
  x1, y1, x2, y2 = map(int, box.xyxy[0])
54
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
 
55
 
56
+ # Check jump threshold from last point
57
  if last_point:
58
+ dx, dy = cx - last_point[0], cy - last_point[1]
 
59
  jump = (dx**2 + dy**2)**0.5
60
  if jump > max_jump:
61
+ # Reject outlier
62
+ frames.append(frame)
63
+ frame_idx += 1
64
+ continue
65
 
66
  valid_detection = (cx, cy)
67
  last_point = valid_detection
 
71
  ball_path.append(valid_detection)
72
 
73
  frames.append(frame)
74
+ frame_idx += 1
75
 
76
  cap.release()
77
 
 
78
  bounce_point = find_bounce_point(ball_path)
79
  impact_point = ball_path[-1] if ball_path else None
 
80
 
 
81
  stump_zone = (
82
  width // 2 - 30,
83
  height - 100,
 
85
  height
86
  )
87
 
 
88
  decision = "OUT" if (
89
  impact_point and
90
  stump_zone[0] <= impact_point[0] <= stump_zone[2] and
91
  stump_zone[1] <= impact_point[1] <= stump_zone[3]
92
  ) else "NOT OUT"
93
 
94
+ speed_kmh = estimate_speed(ball_path, fps)
95
+
96
  return {
97
  "trajectory": ball_path,
98
  "fps": fps,