AjaykumarPilla commited on
Commit
e0fcf03
·
verified ·
1 Parent(s): 98a8dfa

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +210 -0
app.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ from ultralytics import YOLO
5
+ import gradio as gr
6
+ from scipy.interpolate import interp1d
7
+ import plotly.graph_objects as go
8
+ import uuid
9
+ import os
10
+ import tempfile
11
+
12
+ # Load YOLOv8 model and resolve class index
13
+ model = YOLO("best.pt")
14
+ model.to('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ # Dynamically resolve ball class index
17
+ ball_class_index = None
18
+ for k, v in model.names.items():
19
+ if v.lower() == "cricketball":
20
+ ball_class_index = k
21
+ break
22
+ if ball_class_index is None:
23
+ raise ValueError("Class 'cricketBall' not found in model.names")
24
+
25
+ # Constants
26
+ STUMPS_WIDTH = 0.2286
27
+ BALL_DIAMETER = 0.073
28
+ FRAME_RATE = 20
29
+ SLOW_MOTION_FACTOR = 3
30
+ CONF_THRESHOLD = 0.2
31
+ IMPACT_ZONE_Y = 0.85
32
+ IMPACT_DELTA_Y = 50
33
+ PITCH_LENGTH = 20.12
34
+ STUMPS_HEIGHT = 0.71
35
+ MAX_POSITION_JUMP = 30
36
+
37
+ def process_video(video_path):
38
+ if not os.path.exists(video_path):
39
+ return [], [], [], "Error: Video file not found"
40
+ cap = cv2.VideoCapture(video_path)
41
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
42
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
43
+ frames, ball_positions, detection_frames, debug_log = [], [], [], []
44
+ frame_count = 0
45
+
46
+ while cap.isOpened():
47
+ ret, frame = cap.read()
48
+ if not ret:
49
+ break
50
+ frame_count += 1
51
+ frames.append(frame.copy())
52
+ results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=(frame_height, frame_width), iou=0.5, max_det=1)
53
+ detections = 0
54
+ for detection in results[0].boxes:
55
+ if int(detection.cls) == ball_class_index:
56
+ detections += 1
57
+ if detections == 1:
58
+ x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
59
+ ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
60
+ detection_frames.append(frame_count - 1)
61
+ cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
62
+ frames[-1] = frame
63
+ debug_log.append(f"Frame {frame_count}: {detections} ball detections")
64
+ cap.release()
65
+
66
+ if not ball_positions:
67
+ debug_log.append("No balls detected in any frame")
68
+ else:
69
+ debug_log.append(f"Total ball detections: {len(ball_positions)}")
70
+ debug_log.append(f"Video resolution: {frame_width}x{frame_height}")
71
+
72
+ return frames, ball_positions, detection_frames, "\n".join(debug_log)
73
+
74
+ def find_bounce_point(ball_coords):
75
+ """
76
+ Detect bounce point using y-derivative reversal with early-frame suppression.
77
+ Looks for where y increases then decreases (ball hits ground).
78
+ """
79
+ y_coords = [p[1] for p in ball_coords]
80
+ min_index = None
81
+
82
+ for i in range(2, len(y_coords) - 2):
83
+ dy1 = y_coords[i] - y_coords[i - 1]
84
+ dy2 = y_coords[i + 1] - y_coords[i]
85
+ if dy1 > 0 and dy2 < 0:
86
+ if i > len(y_coords) * 0.2:
87
+ min_index = i
88
+ break
89
+
90
+ if min_index is not None:
91
+ return ball_coords[min_index]
92
+
93
+ return ball_coords[len(ball_coords)//2]
94
+
95
+ def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
96
+ if not frames or not trajectory or len(ball_positions) < 2:
97
+ return "Not enough data", trajectory, pitch_point, impact_point
98
+
99
+ frame_height, frame_width = frames[0].shape[:2]
100
+ stumps_x = frame_width / 2
101
+ stumps_y = frame_height * 0.9
102
+ stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
103
+
104
+ pitch_x, _ = pitch_point
105
+ impact_x, impact_y = impact_point
106
+
107
+ if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
108
+ return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
109
+ if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
110
+ return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
111
+ for x, y in trajectory:
112
+ if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
113
+ return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
114
+ return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
115
+
116
+ def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
117
+ if len(ball_positions) < 2:
118
+ return None, None, None, "Error: Not enough ball detections"
119
+
120
+ filtered_positions = [ball_positions[0]]
121
+ filtered_frames = [detection_frames[0]]
122
+ for i in range(1, len(ball_positions)):
123
+ prev, curr = filtered_positions[-1], ball_positions[i]
124
+ if np.linalg.norm(np.array(curr) - np.array(prev)) <= MAX_POSITION_JUMP:
125
+ filtered_positions.append(curr)
126
+ filtered_frames.append(detection_frames[i])
127
+
128
+ if len(filtered_positions) < 2:
129
+ return None, None, None, "Error: Filtered detections too few"
130
+
131
+ x_vals = [p[0] for p in filtered_positions]
132
+ y_vals = [p[1] for p in filtered_positions]
133
+ times = np.array(filtered_frames) / FRAME_RATE
134
+
135
+ try:
136
+ fx = interp1d(times, x_vals, kind='cubic', fill_value="extrapolate")
137
+ fy = interp1d(times, y_vals, kind='cubic', fill_value="extrapolate")
138
+ except Exception as e:
139
+ return None, None, None, f"Interpolation error: {str(e)}"
140
+
141
+ total_frames = max(filtered_frames) - min(filtered_frames) + 1
142
+ t_full = np.linspace(times[0], times[-1], max(5, total_frames * SLOW_MOTION_FACTOR))
143
+ x_full = fx(t_full)
144
+ y_full = fy(t_full)
145
+ trajectory = list(zip(x_full, y_full))
146
+
147
+ pitch_point = find_bounce_point(filtered_positions)
148
+ impact_point = filtered_positions[-1]
149
+
150
+ return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
151
+
152
+ def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
153
+ if not frames or not trajectory:
154
+ return None
155
+
156
+ temp_file = os.path.join(tempfile.gettempdir(), f"drs_output_{uuid.uuid4()}.mp4")
157
+ height, width = frames[0].shape[:2]
158
+ out = cv2.VideoWriter(temp_file, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
159
+
160
+ min_frame = min(detection_frames)
161
+ max_frame = max(detection_frames)
162
+ total_frames = max_frame - min_frame + 1
163
+ traj_per_frame = max(1, len(trajectory) // total_frames)
164
+ indices = [min(i * traj_per_frame, len(trajectory)-1) for i in range(total_frames)]
165
+
166
+ for i, frame in enumerate(frames):
167
+ idx = i - min_frame
168
+ if 0 <= idx < len(indices):
169
+ end_idx = indices[idx]
170
+ points = np.array(trajectory[:end_idx+1], dtype=np.int32).reshape((-1, 1, 2))
171
+ cv2.polylines(frame, [points], False, (255, 0, 0), 2)
172
+ if pitch_point and i == detection_frames[0]:
173
+ cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
174
+ if impact_point and i == detection_frames[-1]:
175
+ cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
176
+ for _ in range(SLOW_MOTION_FACTOR):
177
+ out.write(frame)
178
+ out.release()
179
+ return temp_file
180
+
181
+ def drs_review(video):
182
+ frames, ball_positions, detection_frames, debug_log = process_video(video)
183
+ if not frames or not ball_positions:
184
+ return "No frames or detections found.", None
185
+
186
+ frame_height, frame_width = frames[0].shape[:2]
187
+ trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
188
+ if not trajectory:
189
+ return f"{log}\n{debug_log}", None
190
+
191
+ decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
192
+ replay_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
193
+
194
+ result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
195
+ return result_log, replay_path
196
+
197
+ # Gradio Interface
198
+ iface = gr.Interface(
199
+ fn=drs_review,
200
+ inputs=gr.Video(label="Upload Cricket Delivery Video"),
201
+ outputs=[
202
+ gr.Textbox(label="DRS Result and Debug Info"),
203
+ gr.Video(label="Replay with Trajectory & Decision")
204
+ ],
205
+ title="GullyDRS - AI-Powered LBW Review",
206
+ description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return a replay with an OUT/NOT OUT decision."
207
+ )
208
+
209
+ if __name__ == "__main__":
210
+ iface.launch()