dschandra commited on
Commit
8ba6e46
·
verified ·
1 Parent(s): 0ff12de

Update lbw_detector.py

Browse files
Files changed (1) hide show
  1. lbw_detector.py +52 -0
lbw_detector.py CHANGED
@@ -26,3 +26,55 @@ def detect_objects_with_model(frame):
26
  # Convert output to mask
27
  mask = torch.sigmoid(output).squeeze().cpu().numpy()
28
  return mask # Assumed to be binary mask (ball/pad/stump segmentation)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  # Convert output to mask
27
  mask = torch.sigmoid(output).squeeze().cpu().numpy()
28
  return mask # Assumed to be binary mask (ball/pad/stump segmentation)
29
+
30
+ def analyze_video(video_path):
31
+ frames = extract_frames(video_path)
32
+
33
+ ball_positions = []
34
+ impact_frame_idx = None
35
+ impact_zone = "unknown"
36
+
37
+ for i, frame in enumerate(frames):
38
+ mask = detect_objects_with_model(frame)
39
+
40
+ # Very simple segmentation logic
41
+ ball_mask = mask[0] > 0.5 # channel 0 for ball
42
+ pad_mask = mask[1] > 0.5 if mask.ndim > 2 else None # channel 1 for pad
43
+
44
+ # Detect ball center
45
+ contours, _ = cv2.findContours(ball_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
46
+ if contours:
47
+ largest = max(contours, key=cv2.contourArea)
48
+ M = cv2.moments(largest)
49
+ if M['m00'] != 0:
50
+ cx = int(M['m10']/M['m00'])
51
+ cy = int(M['m01']/M['m00'])
52
+ ball_positions.append((i, cx, cy))
53
+
54
+ # Detect pad hit (optional logic: ball near pad area)
55
+ if pad_mask is not None and contours:
56
+ overlap = np.logical_and(ball_mask, pad_mask).sum()
57
+ if overlap > 10: # simple overlap threshold
58
+ impact_frame_idx = i
59
+ impact_zone = "pad"
60
+ break
61
+
62
+ # Run trajectory prediction if ball was detected
63
+ trajectory = predict_trajectory(ball_positions)
64
+
65
+ # Predict outcome
66
+ decision = "OUT" if trajectory_hits_stumps(trajectory) and impact_zone == "pad" else "NOT OUT"
67
+
68
+ # Visualize
69
+ result_path = draw_visuals(frames, ball_positions, trajectory, impact_frame_idx, decision)
70
+
71
+ return result_path, decision
72
+
73
+
74
+ def trajectory_hits_stumps(trajectory):
75
+ # Simple rule-based check (assuming stumps are around x=300 to 340 px for now)
76
+ for (x, y) in trajectory:
77
+ if 300 < x < 340 and y < 480: # ball projected height intersects stump zone
78
+ return True
79
+ return False
80
+