Spaces:
Sleeping
Sleeping
Update lbw_detector.py
Browse files- lbw_detector.py +52 -0
lbw_detector.py
CHANGED
@@ -26,3 +26,55 @@ def detect_objects_with_model(frame):
|
|
26 |
# Convert output to mask
|
27 |
mask = torch.sigmoid(output).squeeze().cpu().numpy()
|
28 |
return mask # Assumed to be binary mask (ball/pad/stump segmentation)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
# Convert output to mask
|
27 |
mask = torch.sigmoid(output).squeeze().cpu().numpy()
|
28 |
return mask # Assumed to be binary mask (ball/pad/stump segmentation)
|
29 |
+
|
30 |
+
def analyze_video(video_path):
|
31 |
+
frames = extract_frames(video_path)
|
32 |
+
|
33 |
+
ball_positions = []
|
34 |
+
impact_frame_idx = None
|
35 |
+
impact_zone = "unknown"
|
36 |
+
|
37 |
+
for i, frame in enumerate(frames):
|
38 |
+
mask = detect_objects_with_model(frame)
|
39 |
+
|
40 |
+
# Very simple segmentation logic
|
41 |
+
ball_mask = mask[0] > 0.5 # channel 0 for ball
|
42 |
+
pad_mask = mask[1] > 0.5 if mask.ndim > 2 else None # channel 1 for pad
|
43 |
+
|
44 |
+
# Detect ball center
|
45 |
+
contours, _ = cv2.findContours(ball_mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
46 |
+
if contours:
|
47 |
+
largest = max(contours, key=cv2.contourArea)
|
48 |
+
M = cv2.moments(largest)
|
49 |
+
if M['m00'] != 0:
|
50 |
+
cx = int(M['m10']/M['m00'])
|
51 |
+
cy = int(M['m01']/M['m00'])
|
52 |
+
ball_positions.append((i, cx, cy))
|
53 |
+
|
54 |
+
# Detect pad hit (optional logic: ball near pad area)
|
55 |
+
if pad_mask is not None and contours:
|
56 |
+
overlap = np.logical_and(ball_mask, pad_mask).sum()
|
57 |
+
if overlap > 10: # simple overlap threshold
|
58 |
+
impact_frame_idx = i
|
59 |
+
impact_zone = "pad"
|
60 |
+
break
|
61 |
+
|
62 |
+
# Run trajectory prediction if ball was detected
|
63 |
+
trajectory = predict_trajectory(ball_positions)
|
64 |
+
|
65 |
+
# Predict outcome
|
66 |
+
decision = "OUT" if trajectory_hits_stumps(trajectory) and impact_zone == "pad" else "NOT OUT"
|
67 |
+
|
68 |
+
# Visualize
|
69 |
+
result_path = draw_visuals(frames, ball_positions, trajectory, impact_frame_idx, decision)
|
70 |
+
|
71 |
+
return result_path, decision
|
72 |
+
|
73 |
+
|
74 |
+
def trajectory_hits_stumps(trajectory):
|
75 |
+
# Simple rule-based check (assuming stumps are around x=300 to 340 px for now)
|
76 |
+
for (x, y) in trajectory:
|
77 |
+
if 300 < x < 340 and y < 480: # ball projected height intersects stump zone
|
78 |
+
return True
|
79 |
+
return False
|
80 |
+
|