dschandra committed
Commit 2db7738 · verified · 1 Parent(s): 7bb911b

Upload 6 files

drs_modules/__init__.py ADDED
@@ -0,0 +1,7 @@
+ """Helper package for the DRS application.
+
+ This package groups together the individual components used by the
+ Digital Review System. Each submodule focuses on a specific part of
+ the pipeline: video processing, ball detection and tracking,
+ trajectory estimation, LBW decision logic and result visualisation.
+ """
drs_modules/detection.py ADDED
@@ -0,0 +1,142 @@
+ """Ball detection and tracking for the DRS application.
+
+ This module implements a simple motion‑based tracker to follow the cricket
+ ball in a video. Professional ball tracking systems use multiple high
+ frame‑rate cameras and sophisticated object detectors. Here, we rely on
+ background subtraction combined with circle detection (Hough circles) to
+ locate the ball in each frame. The tracker keeps the coordinates and
+ timestamps of the ball's centre so that downstream modules can
+ estimate its trajectory and predict whether it will hit the stumps.
+
+ The detection pipeline makes the following assumptions:
+
+ * Only one ball is present in the scene at a time.
+ * The ball is approximately circular in appearance.
+ * The camera is static or moves little compared to the ball.
+
+ These assumptions hold for many amateur cricket recordings but are
+ obviously simplified compared to a true DRS system.
+ """
+
+ from __future__ import annotations
+
+ import cv2
+ import numpy as np
+ from typing import Dict, List, Tuple
+
+
+ def detect_and_track_ball(video_path: str) -> Dict[str, List]:
+     """Detect and track the cricket ball in a video.
+
+     Parameters
+     ----------
+     video_path: str
+         Path to the trimmed video segment containing the delivery and appeal.
+
+     Returns
+     -------
+     Dict[str, List]
+         A dictionary containing:
+         ``centers``: list of (x, y) coordinates of the ball in successive frames.
+         ``timestamps``: list of timestamps (in seconds) corresponding to each centre.
+         ``radii``: list of detected circle radii (in pixels).
+     """
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise RuntimeError(f"Could not open video {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+
+     # Background subtractor for motion detection
+     bg_sub = cv2.createBackgroundSubtractorMOG2(history=500, varThreshold=32, detectShadows=False)
+
+     centers: List[Tuple[int, int]] = []
+     radii: List[int] = []
+     timestamps: List[float] = []
+
+     previous_center: Tuple[int, int] | None = None
+     frame_idx = 0
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         timestamp = frame_idx / fps
+
+         # Preprocess: grayscale and blur
+         gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)
+         blurred = cv2.GaussianBlur(gray, (5, 5), 0)
+
+         # Apply background subtraction to emphasise moving objects
+         fg_mask = bg_sub.apply(frame)
+         # Remove noise from mask
+         kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
+         fg_mask = cv2.morphologyEx(fg_mask, cv2.MORPH_OPEN, kernel)
+
+         detected_center: Tuple[int, int] | None = None
+         detected_radius: int | None = None
+
+         # Attempt to detect circles using Hough transform
+         circles = cv2.HoughCircles(
+             blurred,
+             cv2.HOUGH_GRADIENT,
+             dp=1.2,
+             minDist=20,
+             param1=50,
+             param2=30,
+             minRadius=3,
+             maxRadius=30,
+         )
+         if circles is not None:
+             circles = np.round(circles[0, :]).astype("int")
+             # Choose the circle closest to the previous detection to maintain
+             # continuity. If no previous detection exists, pick the circle
+             # with the smallest radius (likely the ball).
+             if previous_center is not None:
+                 min_dist = float("inf")
+                 chosen = None
+                 for x, y, r in circles:
+                     dist = (x - previous_center[0]) ** 2 + (y - previous_center[1]) ** 2
+                     if dist < min_dist:
+                         min_dist = dist
+                         chosen = (x, y, r)
+                 if chosen is not None:
+                     detected_center = (int(chosen[0]), int(chosen[1]))
+                     detected_radius = int(chosen[2])
+             else:
+                 # No previous centre: pick the smallest radius circle
+                 chosen = min(circles, key=lambda c: c[2])
+                 detected_center = (int(chosen[0]), int(chosen[1]))
+                 detected_radius = int(chosen[2])
+
+         # Fallback: use contours on the foreground mask to find moving blobs
+         if detected_center is None:
+             contours, _ = cv2.findContours(fg_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
+             # Filter contours by area to eliminate noise; choose the one
+             # closest to previous centre or the smallest area blob
+             candidates = []
+             for cnt in contours:
+                 area = cv2.contourArea(cnt)
+                 if 10 < area < 800:  # adjust thresholds as necessary
+                     x, y, w, h = cv2.boundingRect(cnt)
+                     cx = x + w // 2
+                     cy = y + h // 2
+                     candidates.append((cx, cy, w, h, area))
+             if candidates:
+                 if previous_center is not None:
+                     chosen = min(candidates, key=lambda c: (c[0] - previous_center[0]) ** 2 + (c[1] - previous_center[1]) ** 2)
+                 else:
+                     chosen = min(candidates, key=lambda c: c[4])
+                 cx, cy, w, h, _ = chosen
+                 detected_center = (int(cx), int(cy))
+                 detected_radius = int(max(w, h) / 2)
+
+         if detected_center is not None:
+             centers.append(detected_center)
+             radii.append(detected_radius or 5)
+             timestamps.append(timestamp)
+             previous_center = detected_center
+         # Increment frame index regardless of detection
+         frame_idx += 1
+
+     cap.release()
+     return {"centers": centers, "radii": radii, "timestamps": timestamps}
drs_modules/lbw_decision.py ADDED
@@ -0,0 +1,52 @@
+ """LBW decision logic for the cricket DRS demo.
+
+ This module encapsulates the high‑level logic used to decide whether a
+ batsman is out leg before wicket (LBW) based on the ball's trajectory.
+ In professional systems the decision depends on many factors: where
+ the ball pitched, the line it travelled relative to the stumps, the
+ height of impact on the pad/glove, and whether the batsman offered a
+ shot. To keep this example straightforward we apply a very simple
+ rule:
+
+ * If the predicted trajectory intersects the stumps, the batsman is
+   declared **OUT**.
+ * Otherwise the batsman is **NOT OUT**.
+
+ We also return the index of the frame deemed to be the impact frame.
+ Here we take the impact frame to be the last frame where the ball was
+ detected. In a more complete system one would detect the moment of
+ contact with the pad or glove and use that as the impact frame.
+ """
+
+ from __future__ import annotations
+
+ from typing import List, Tuple
+
+
+ def make_lbw_decision(
+     centers: List[Tuple[int, int]],
+     trajectory_model: dict,
+     will_hit_stumps: bool,
+ ) -> Tuple[str, int]:
+     """Return a simple LBW decision based on trajectory intersection.
+
+     Parameters
+     ----------
+     centers: list of tuple(int, int)
+         Sequence of detected ball centres.
+     trajectory_model: dict
+         The polynomial model fitted to the ball path (unused directly
+         here but included for extensibility).
+     will_hit_stumps: bool
+         Prediction that the ball's path intersects the stumps.
+
+     Returns
+     -------
+     Tuple[str, int]
+         The decision text (``"OUT"`` or ``"NOT OUT"``) and the index
+         of the impact frame. The impact frame is taken to be the
+         index of the last detection.
+     """
+     impact_frame_idx = len(centers) - 1 if centers else -1
+     decision = "OUT" if will_hit_stumps else "NOT OUT"
+     return decision, impact_frame_idx
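
Because the rule is a pure function of its inputs, it can be exercised in isolation. A tiny check with made-up values (the empty dict stands in for the unused trajectory model):

    from drs_modules.lbw_decision import make_lbw_decision

    decision, impact = make_lbw_decision([(100, 200), (120, 210)], {}, will_hit_stumps=True)
    assert decision == "OUT" and impact == 1
    decision, impact = make_lbw_decision([], {}, will_hit_stumps=False)
    assert decision == "NOT OUT" and impact == -1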
drs_modules/trajectory.py ADDED
@@ -0,0 +1,107 @@
+ """Trajectory estimation for the cricket ball.
+
+ Professional ball tracking systems reconstruct the ball's path in 3D
+ from several camera angles and then use physics or machine learning
+ models to project its flight. Here we implement a far simpler
+ approach. Given a sequence of ball centre coordinates extracted from
+ a single camera (behind the bowler), we fit a polynomial curve to
+ approximate the ball's trajectory in image space. We assume that the
+ ball travels roughly along a parabolic path, so a quadratic fit to
+ ``y`` as a function of ``x`` is appropriate for the vertical drop.
+
+ Because we lack explicit knowledge of the camera's field of view, the
+ stumps' location is not calibrated; instead it is approximated by a
+ fixed region near the bottom middle of the frame. If the projected
+ path enters that region, we say that the ball would have hit the
+ stumps.
+ """
+
+ from __future__ import annotations
+
+ import numpy as np
+ from typing import Callable, Dict, List, Tuple
+
+
+ def estimate_trajectory(centers: List[Tuple[int, int]], timestamps: List[float]) -> Dict[str, object]:
+     """Fit a polynomial to the ball's path.
+
+     Parameters
+     ----------
+     centers: list of tuple(int, int)
+         Detected ball centre positions in pixel coordinates (x, y).
+     timestamps: list of float
+         Timestamps (in seconds) corresponding to each detection. Unused
+         in the current implementation but retained for extensibility.
+
+     Returns
+     -------
+     dict
+         A dictionary with keys ``coeffs`` (the polynomial coefficients
+         [a, b, c] for ``y = a*x^2 + b*x + c``) and ``model`` (a
+         callable that accepts an x coordinate and returns the
+         predicted y coordinate).
+     """
+     if not centers:
+         # No detections; return a dummy model
+         return {"coeffs": np.array([0.0, 0.0, 0.0]), "model": lambda x: 0 * x}
+
+     xs = np.array([pt[0] for pt in centers], dtype=np.float64)
+     ys = np.array([pt[1] for pt in centers], dtype=np.float64)
+
+     # Require at least 3 points for a quadratic fit; otherwise fall back
+     # to a linear fit
+     if len(xs) >= 3:
+         coeffs = np.polyfit(xs, ys, 2)
+         def model(x: np.ndarray | float) -> np.ndarray | float:
+             return coeffs[0] * (x ** 2) + coeffs[1] * x + coeffs[2]
+     else:
+         coeffs = np.polyfit(xs, ys, 1)
+         def model(x: np.ndarray | float) -> np.ndarray | float:
+             return coeffs[0] * x + coeffs[1]
+
+     return {"coeffs": coeffs, "model": model}
+
+
+ def predict_stumps_intersection(trajectory: Dict[str, object]) -> bool:
+     """Predict whether the ball's trajectory will hit the stumps.
+
+     The stumps are assumed to lie roughly in the centre of the frame
+     along the horizontal axis and occupy the lower third of the
+     vertical axis. This heuristic works reasonably well for videos
+     captured from behind the bowler. In a production system you
+     would calibrate the exact position of the stumps from the pitch
+     geometry.
+
+     Parameters
+     ----------
+     trajectory: dict
+         Output of :func:`estimate_trajectory`, containing the fitted
+         polynomial ``model`` and its ``coeffs``.
+
+     Returns
+     -------
+     bool
+         True if the ball is predicted to hit the stumps, False otherwise.
+     """
+     model: Callable[[float], float] = trajectory["model"]
+
+     # The trajectory dictionary does not carry the original centres, so
+     # the frame dimensions cannot be recovered from the detections here.
+     # Instead we assume nominal dimensions typical of the processed
+     # videos; these defaults can be tuned for other resolutions.
+     frame_width = 640
+     frame_height = 360
+
+     # Estimate the ball's y position at the x coordinate corresponding
+     # to the middle stump: 50% of the frame width
+     stumps_x = frame_width * 0.5
+     predicted_y = model(stumps_x)
+
+     # Define the vertical bounds of the wicket region in pixels. The
+     # top of the stumps is taken to be roughly two thirds of the way
+     # down the frame and the bottom just above the frame edge. These
+     # ratios can be tuned.
+     stumps_y_low = frame_height * 0.65
+     stumps_y_high = frame_height * 0.95
+
+     return stumps_y_low <= predicted_y <= stumps_y_high
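
A quick sanity check of the fit is to feed in points sampled from a known parabola; the recovered coefficients should match up to the integer rounding of the centres. A self-contained example (not part of the commit):

    import numpy as np
    from drs_modules.trajectory import estimate_trajectory, predict_stumps_intersection

    xs = np.linspace(50, 600, 20)
    centers = [(int(x), int(0.001 * x ** 2 + 0.1 * x + 50)) for x in xs]
    traj = estimate_trajectory(centers, timestamps=[i / 30.0 for i in range(20)])
    print(traj["coeffs"])                     # roughly [0.001, 0.1, 50]
    print(predict_stumps_intersection(traj))  # False: y(320) is about 184, above the stumps band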
drs_modules/video_processing.py ADDED
@@ -0,0 +1,121 @@
+ """Video processing utilities for the DRS application.
+
+ This module provides helper functions to save uploaded videos to the
+ filesystem and to trim the last N seconds from a video. Using
+ OpenCV's ``VideoCapture`` and ``VideoWriter`` avoids external
+ dependencies like ffmpeg or moviepy, which may not be installed in
+ all execution environments.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import IO, Union
+
+ import cv2
+
+
+ def save_uploaded_video(name: str, file_obj: Union[bytes, str, Path, IO[bytes]]) -> str:
+     """Persist an uploaded video to a predictable location on disk.
+
+     When a user records or uploads a video in the Gradio interface, it
+     arrives as a temporary file object. To analyse the video later we
+     copy it into the working directory using its original filename.
+
+     Parameters
+     ----------
+     name: str
+         The original filename from the upload widget.
+     file_obj: Union[bytes, str, Path, IO[bytes]]
+         The object representing the uploaded video. Gradio passes the
+         upload as a temporary file object whose ``.name`` attribute
+         holds the temporary path. This function accepts the temporary
+         path, the raw bytes, or an open file handle.
+
+     Returns
+     -------
+     str
+         The absolute path where the video has been saved.
+     """
+     # Determine a safe output directory. Use the current working
+     # directory so that Gradio can later access the file by path.
+     output_dir = Path(os.getcwd()) / "user_videos"
+     output_dir.mkdir(exist_ok=True)
+
+     # Compose an output filename; avoid overwriting by appending an
+     # incrementing integer if necessary.
+     base_name = Path(name).stem
+     ext = Path(name).suffix or ".mp4"
+     counter = 0
+     dest = output_dir / f"{base_name}{ext}"
+     while dest.exists():
+         counter += 1
+         dest = output_dir / f"{base_name}_{counter}{ext}"
+
+     # If file_obj is a path, copy it directly; raw bytes are written
+     # out as-is; anything else is treated as a file-like object.
+     if isinstance(file_obj, (str, Path)):
+         shutil.copy(str(file_obj), dest)
+     elif isinstance(file_obj, bytes):
+         with open(dest, "wb") as f_out:
+             f_out.write(file_obj)
+     else:
+         # Gradio passes a file-like object with a `.read()` method
+         with open(dest, "wb") as f_out:
+             f_out.write(file_obj.read())
+     return str(dest)
+
+
+ def trim_last_seconds(input_path: str, output_path: str, seconds: int) -> None:
+     """Save the last ``seconds`` of a video to ``output_path``.
+
+     This function reads the entire video file, calculates the starting
+     frame corresponding to ``seconds`` before the end, and writes the
+     remaining frames to a new video using OpenCV. If the video is
+     shorter than the requested duration, the whole video is copied.
+
+     Parameters
+     ----------
+     input_path: str
+         Path to the source video file.
+     output_path: str
+         Path where the trimmed video will be saved.
+     seconds: int
+         The duration from the end of the video to retain.
+     """
+     cap = cv2.VideoCapture(input_path)
+     if not cap.isOpened():
+         raise RuntimeError(f"Unable to open video: {input_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS)
+     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+     if fps <= 0:
+         fps = 30.0  # default fallback
+     frames_to_keep = int(seconds * fps)
+     start_frame = max(total_frames - frames_to_keep, 0)
+
+     # Prepare writer with the same properties as the input
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     # Skip frames until start_frame
+     current = 0
+     while current < start_frame:
+         ret, _ = cap.read()
+         if not ret:
+             break
+         current += 1
+
+     # Write remaining frames
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+         out.write(frame)
+
+     cap.release()
+     out.release()
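
Both accepted input forms of save_uploaded_video, side by side; the temporary path is a placeholder for whatever Gradio hands over, so this only runs where such a file exists:

    from drs_modules.video_processing import save_uploaded_video

    saved = save_uploaded_video("over_42.mp4", "/tmp/gradio/over_42.mp4")  # temporary path
    with open(saved, "rb") as fh:
        copy = save_uploaded_video("over_42.mp4", fh)  # open file handle
    print(saved, copy)  # the second call gains a "_1" suffix instead of overwriting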
drs_modules/visualization.py ADDED
@@ -0,0 +1,219 @@
+ """Visualisation utilities for the DRS application.
+
+ This module contains functions to generate images and videos that
+ illustrate the ball's flight and the outcome of the LBW decision.
+ Using Matplotlib and OpenCV we create a 3D trajectory plot and an
+ annotated replay video. These assets are returned to the Gradio
+ interface for display to the user.
+ """
+
+ from __future__ import annotations
+
+ import cv2
+ import numpy as np
+ import matplotlib
+ matplotlib.use("Agg")  # Use a non‑interactive backend
+ import matplotlib.pyplot as plt
+ from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  # needed for 3D plots
+ from typing import List, Tuple, Callable
+
+
+ def generate_trajectory_plot(
+     centers: List[Tuple[int, int]],
+     trajectory: dict,
+     will_hit_stumps: bool,
+     output_path: str,
+ ) -> None:
+     """Create a 3D plot of the observed and predicted ball trajectory.
+
+     The x axis represents the horizontal pixel coordinate, the y axis
+     represents the vertical coordinate (top at 0), and the z axis
+     corresponds to the frame index (time). The predicted path is drawn
+     on the x–y plane at z=0 for clarity.
+
+     Parameters
+     ----------
+     centers: list of tuple(int, int)
+         Detected ball centre positions.
+     trajectory: dict
+         Output of :func:`drs_modules.trajectory.estimate_trajectory`.
+     will_hit_stumps: bool
+         Whether the ball is predicted to hit the stumps; controls the
+         colour of the predicted path.
+     output_path: str
+         Where to save the resulting PNG image.
+     """
+     if not centers:
+         # If no points, draw an empty figure
+         fig = plt.figure(figsize=(6, 4))
+         ax = fig.add_subplot(111, projection="3d")
+         ax.set_title("No ball detections")
+         ax.set_xlabel("X (pixels)")
+         ax.set_ylabel("Y (pixels)")
+         ax.set_zlabel("Frame index")
+         fig.tight_layout()
+         fig.savefig(output_path)
+         plt.close(fig)
+         return
+
+     xs = np.array([c[0] for c in centers])
+     ys = np.array([c[1] for c in centers])
+     zs = np.arange(len(centers))
+
+     # Compute predicted path along the full x range
+     model: Callable[[float], float] = trajectory["model"]
+     x_range = np.linspace(xs.min(), xs.max(), 100)
+     y_pred = model(x_range)
+
+     fig = plt.figure(figsize=(6, 4))
+     ax = fig.add_subplot(111, projection="3d")
+
+     # Plot observed points
+     ax.plot(xs, ys, zs, 'o-', label="Detected ball path", color="blue")
+
+     # Plot predicted path on z=0 plane
+     colour = "green" if will_hit_stumps else "red"
+     ax.plot(x_range, y_pred, np.zeros_like(x_range), '--', label="Predicted path", color=colour)
+
+     ax.set_xlabel("X (pixels)")
+     ax.set_ylabel("Y (pixels)")
+     ax.set_zlabel("Frame index")
+     ax.set_title("Ball trajectory (observed vs predicted)")
+     ax.legend()
+     ax.invert_yaxis()  # Invert y axis to match image coordinates
+     fig.tight_layout()
+     fig.savefig(output_path)
+     plt.close(fig)
+
+
+ def annotate_video_with_tracking(
+     video_path: str,
+     centers: List[Tuple[int, int]],
+     trajectory: dict,
+     will_hit_stumps: bool,
+     impact_frame_idx: int,
+     output_path: str,
+ ) -> None:
+     """Create an annotated replay video highlighting key elements.
+
+     The function reads the trimmed input video and writes out a new
+     video with the following overlays:
+
+     * The detected ball centre (small filled circle).
+     * A polyline showing the path of the ball up to the current
+       frame.
+     * The predicted trajectory across the frame, drawn as a dashed
+       curve.
+     * A rectangle representing the stumps zone at the bottom centre
+       of the frame; coloured green if the ball is predicted to hit
+       and red otherwise.
+     * The text "OUT" or "NOT OUT" displayed after the impact frame.
+     * Auto zoom effect on the impact frame by drawing a thicker
+       circle around the ball.
+
+     Parameters
+     ----------
+     video_path: str
+         Path to the trimmed input video.
+     centers: list of tuple(int, int)
+         Detected ball centres for each frame analysed.
+     trajectory: dict
+         Output of :func:`drs_modules.trajectory.estimate_trajectory`.
+     will_hit_stumps: bool
+         Whether the ball is predicted to hit the stumps.
+     impact_frame_idx: int
+         Index of the frame considered as the impact frame.
+     output_path: str
+         Where to save the annotated video.
+     """
+     cap = cv2.VideoCapture(video_path)
+     if not cap.isOpened():
+         raise RuntimeError(f"Could not open video {video_path}")
+
+     fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
+     width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+     height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+     fourcc = cv2.VideoWriter_fourcc(*"mp4v")
+     writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+     model: Callable[[float], float] = trajectory["model"]
+     # Precompute predicted path points for drawing on each frame
+     x_vals = np.linspace(0, width - 1, 50)
+     y_preds = model(x_vals)
+     # Ensure predicted y values stay within frame
+     y_preds_clamped = np.clip(y_preds, 0, height - 1).astype(int)
+
+     # Define stumps zone coordinates
+     stumps_width = int(width * 0.1)
+     stumps_height = int(height * 0.3)
+     stumps_x = int((width - stumps_width) / 2)
+     stumps_y = int(height * 0.65)
+     stumps_color = (0, 255, 0) if will_hit_stumps else (0, 0, 255)
+
+     frame_idx = 0
+     path_points: List[Tuple[int, int]] = []
+
+     while True:
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # Draw stumps region on every frame
+         cv2.rectangle(
+             frame,
+             (stumps_x, stumps_y),
+             (stumps_x + stumps_width, stumps_y + stumps_height),
+             stumps_color,
+             2,
+         )
+
+         # Draw predicted trajectory line (dashed effect by skipping points)
+         for i in range(len(x_vals) - 1):
+             if i % 4 != 0:
+                 continue
+             pt1 = (int(x_vals[i]), int(y_preds_clamped[i]))
+             pt2 = (int(x_vals[i + 1]), int(y_preds_clamped[i + 1]))
+             cv2.line(frame, pt1, pt2, stumps_color, 1, lineType=cv2.LINE_AA)
+
+         # If we have a centre for this frame, draw it and update the path (assumes one detection per frame)
+         if frame_idx < len(centers):
+             cx, cy = centers[frame_idx]
+             path_points.append((cx, cy))
+             # Draw past trajectory as a polyline
+             if len(path_points) > 1:
+                 cv2.polylines(frame, [np.array(path_points, dtype=np.int32)], False, (255, 0, 0), 2)
+             # Draw the ball centre (bigger on impact frame)
+             radius = 5
+             thickness = -1
+             colour = (255, 255, 255)
+             if frame_idx == impact_frame_idx:
+                 # Auto zoom effect: larger circle and thicker outline
+                 radius = 10
+                 thickness = 2
+                 colour = (0, 255, 255)
+             cv2.circle(frame, (cx, cy), radius, colour, thickness)
+         else:
+             # Continue drawing the path beyond detection frames
+             if len(path_points) > 1:
+                 cv2.polylines(frame, [np.array(path_points, dtype=np.int32)], False, (255, 0, 0), 2)
+
+         # After the impact frame, display the decision text
+         if frame_idx >= impact_frame_idx and impact_frame_idx >= 0:
+             decision_text = "OUT" if will_hit_stumps else "NOT OUT"
+             font = cv2.FONT_HERSHEY_SIMPLEX
+             cv2.putText(
+                 frame,
+                 decision_text,
+                 (50, 50),
+                 font,
+                 1.5,
+                 (0, 255, 0) if will_hit_stumps else (0, 0, 255),
+                 3,
+                 lineType=cv2.LINE_AA,
+             )
+
+         writer.write(frame)
+         frame_idx += 1
+
+     cap.release()
+     writer.release()
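
OpenCV's line primitive has no dash style, so the predicted path above fakes dashes by drawing only one of every four sampled segments. The same trick on a blank canvas, as a standalone sketch:

    import cv2
    import numpy as np

    canvas = np.zeros((360, 640, 3), dtype=np.uint8)
    x = np.linspace(0, 639, 50)
    y = np.clip(0.002 * (x - 320) ** 2 + 100, 0, 359)
    for i in range(0, len(x) - 1, 4):  # draw one segment, skip three
        cv2.line(canvas, (int(x[i]), int(y[i])), (int(x[i + 1]), int(y[i + 1])), (0, 0, 255), 1, cv2.LINE_AA)
    cv2.imwrite("dashed_curve.png", canvas)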