dschandra commited on
Commit
5eaa37a
·
verified ·
1 Parent(s): 698479c

Update lbw_detector.py

Browse files
Files changed (1) hide show
  1. lbw_detector.py +27 -78
lbw_detector.py CHANGED
@@ -1,79 +1,28 @@
1
- from ultralytics import YOLO
2
- import cv2
3
  import numpy as np
4
- import os
5
-
6
- # Load YOLO model (custom-trained or pretrained with compatible classes)
7
- model = YOLO("yolov8n.pt") # Replace with "lbw_yolov8.pt" if custom-trained
8
-
9
- #model_path = os.path.join("models", "yolov8n.pt")
10
- #model = YOLO(model_path)
11
-
12
- # Target class IDs update based on your custom model class mapping
13
- CLASS_NAMES = {
14
- 0: "ball",
15
- 1: "bat",
16
- 2: "pad",
17
- 3: "stump",
18
- 4: "player"
19
- }
20
-
21
- def detect_lbw_event(frames):
22
- """
23
- Detects ball, bat, stump, and pad in each frame.
24
- Identifies impact and prepares coordinates for trajectory modeling.
25
-
26
- Returns:
27
- dict: {
28
- "ball_positions": [x, y] list per frame,
29
- "impact_frame": int,
30
- "impact_type": str,
31
- "objects_per_frame": [
32
- {"ball": (x, y), "pad": (x, y), ...}
33
- ]
34
- }
35
- """
36
- ball_positions = []
37
- impact_frame = -1
38
- impact_type = None
39
- objects_per_frame = []
40
-
41
- for idx, frame in enumerate(frames):
42
- results = model(frame)[0]
43
- frame_objects = {}
44
-
45
- for det in results.boxes.data:
46
- x1, y1, x2, y2, conf, cls = det.cpu().numpy()
47
- class_id = int(cls)
48
- class_name = CLASS_NAMES.get(class_id, "unknown")
49
- center_x = int((x1 + x2) / 2)
50
- center_y = int((y1 + y2) / 2)
51
- frame_objects[class_name] = (center_x, center_y)
52
-
53
- if class_name == "ball":
54
- ball_positions.append((idx, center_x, center_y))
55
-
56
- objects_per_frame.append(frame_objects)
57
-
58
- # Basic impact logic: ball overlaps pad or bat
59
- if "ball" in frame_objects and ("pad" in frame_objects or "bat" in frame_objects):
60
- bx, by = frame_objects["ball"]
61
- if "pad" in frame_objects:
62
- px, py = frame_objects["pad"]
63
- if abs(bx - px) < 30 and abs(by - py) < 30:
64
- impact_frame = idx
65
- impact_type = "pad"
66
- break
67
- if "bat" in frame_objects:
68
- tx, ty = frame_objects["bat"]
69
- if abs(bx - tx) < 30 and abs(by - ty) < 30:
70
- impact_frame = idx
71
- impact_type = "bat"
72
- break
73
-
74
- return {
75
- "ball_positions": ball_positions,
76
- "impact_frame": impact_frame,
77
- "impact_type": impact_type,
78
- "objects_per_frame": objects_per_frame
79
- }
 
1
+ # lbw_detector.py
2
+ import torch
3
  import numpy as np
4
+ from torchvision import transforms
5
+ import cv2
6
+ from utils import extract_frames
7
+ from trajectory_predictor import predict_trajectory
8
+ from visualizer import draw_visuals
9
+
10
+ # Load the custom LBW model
11
+ model_path = "models/lbw_drs_unet_model.pth"
12
+ device = "cpu" # Hugging Face Free Tier
13
+
14
+ model = torch.load(model_path, map_location=device)
15
+ model.eval()
16
+
17
+ transform = transforms.Compose([
18
+ transforms.ToTensor(),
19
+ ])
20
+
21
+ def detect_objects_with_model(frame):
22
+ """Run segmentation on a frame using the custom model"""
23
+ input_tensor = transform(frame).unsqueeze(0).to(device)
24
+ with torch.no_grad():
25
+ output = model(input_tensor)
26
+ # Convert output to mask
27
+ mask = torch.sigmoid(output).squeeze().cpu().numpy()
28
+ return mask # Assumed to be binary mask (ball/pad/stump segmentation)