AjaykumarPilla commited on
Commit
98a8dfa
·
verified ·
1 Parent(s): d7c6168

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -205
app.py DELETED
@@ -1,205 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
- from ultralytics import YOLO
5
- import gradio as gr
6
- from scipy.interpolate import interp1d
7
- import uuid
8
- import os
9
- import tempfile
10
-
11
- # Load YOLOv8 model
12
- model = YOLO("best.pt")
13
- model.to('cuda' if torch.cuda.is_available() else 'cpu')
14
-
15
- # Resolve ball class index
16
- ball_class_index = None
17
- for k, v in model.names.items():
18
- if v.lower() == "cricketball":
19
- ball_class_index = k
20
- break
21
- if ball_class_index is None:
22
- raise ValueError("Class 'cricketBall' not found in model.names")
23
-
24
- # Constants
25
- STUMPS_WIDTH = 0.2286
26
- BALL_DIAMETER = 0.073
27
- FRAME_RATE = 20
28
- SLOW_MOTION_FACTOR = 2 # Normal speed; increase to slow down
29
- CONF_THRESHOLD = 0.2
30
- IMPACT_ZONE_Y = 0.85
31
- IMPACT_DELTA_Y = 50
32
- PITCH_LENGTH = 20.12
33
- STUMPS_HEIGHT = 0.71
34
- MAX_POSITION_JUMP = 30
35
-
36
- def process_video(video_path):
37
- if not os.path.exists(video_path):
38
- return [], [], [], "Error: Video file not found"
39
- cap = cv2.VideoCapture(video_path)
40
- frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
41
- frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
42
- frames, ball_positions, detection_frames, debug_log = [], [], [], []
43
- frame_count = 0
44
-
45
- while cap.isOpened():
46
- ret, frame = cap.read()
47
- if not ret:
48
- break
49
- frame_count += 1
50
- frames.append(frame.copy())
51
-
52
- # 🚀 Faster inference with fixed optimized size
53
- results = model.predict(frame, conf=CONF_THRESHOLD, imgsz=640, iou=0.5, max_det=1)
54
-
55
- detections = 0
56
- for detection in results[0].boxes:
57
- if int(detection.cls) == ball_class_index:
58
- detections += 1
59
- if detections == 1:
60
- x1, y1, x2, y2 = detection.xyxy[0].cpu().numpy()
61
- ball_positions.append([(x1 + x2) / 2, (y1 + y2) / 2])
62
- detection_frames.append(frame_count - 1)
63
- cv2.rectangle(frame, (int(x1), int(y1)), (int(x2), int(y2)), (0, 255, 0), 2)
64
-
65
- frames[-1] = frame
66
- debug_log.append(f"Frame {frame_count}: {detections} ball detections")
67
-
68
- cap.release()
69
- return frames, ball_positions, detection_frames, "\n".join(debug_log)
70
-
71
- def find_bounce_point(ball_coords):
72
- y_coords = [p[1] for p in ball_coords]
73
- for i in range(2, len(y_coords) - 2):
74
- dy1 = y_coords[i] - y_coords[i - 1]
75
- dy2 = y_coords[i + 1] - y_coords[i]
76
- if dy1 > 0 and dy2 < 0:
77
- if i > len(y_coords) * 0.2:
78
- return ball_coords[i]
79
- return ball_coords[len(ball_coords) // 2]
80
-
81
- def estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width):
82
- if len(ball_positions) < 2:
83
- return None, None, None, "Error: Not enough ball detections"
84
-
85
- filtered_positions = [ball_positions[0]]
86
- filtered_frames = [detection_frames[0]]
87
- for i in range(1, len(ball_positions)):
88
- prev, curr = filtered_positions[-1], ball_positions[i]
89
- if np.linalg.norm(np.array(curr) - np.array(prev)) <= MAX_POSITION_JUMP:
90
- filtered_positions.append(curr)
91
- filtered_frames.append(detection_frames[i])
92
-
93
- if len(filtered_positions) < 2:
94
- return None, None, None, "Error: Filtered detections too few"
95
-
96
- x_vals = [p[0] for p in filtered_positions]
97
- y_vals = [p[1] for p in filtered_positions]
98
- times = np.array(filtered_frames) / FRAME_RATE
99
-
100
- try:
101
- fx = interp1d(times, x_vals, kind='cubic', fill_value="extrapolate")
102
- fy = interp1d(times, y_vals, kind='cubic', fill_value="extrapolate")
103
- except Exception as e:
104
- return None, None, None, f"Interpolation error: {str(e)}"
105
-
106
- total_frames = max(filtered_frames) - min(filtered_frames) + 1
107
- t_full = np.linspace(times[0], times[-1], max(5, total_frames * SLOW_MOTION_FACTOR))
108
- x_full = fx(t_full)
109
- y_full = fy(t_full)
110
- trajectory = list(zip(x_full, y_full))
111
-
112
- pitch_point = find_bounce_point(filtered_positions)
113
- impact_point = filtered_positions[-1]
114
-
115
- return trajectory, pitch_point, impact_point, "Trajectory estimated successfully"
116
-
117
- def lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point):
118
- if not frames or not trajectory or len(ball_positions) < 2:
119
- return "Not enough data", trajectory, pitch_point, impact_point
120
-
121
- frame_height, frame_width = frames[0].shape[:2]
122
- stumps_x = frame_width / 2
123
- stumps_y = frame_height * 0.9
124
- stumps_width_pixels = frame_width * (STUMPS_WIDTH / 3.0)
125
-
126
- pitch_x, _ = pitch_point
127
- impact_x, impact_y = impact_point
128
-
129
- if pitch_x < stumps_x - stumps_width_pixels / 2 or pitch_x > stumps_x + stumps_width_pixels / 2:
130
- return f"Not Out (Pitched outside line)", trajectory, pitch_point, impact_point
131
- if impact_x < stumps_x - stumps_width_pixels / 2 or impact_x > stumps_x + stumps_width_pixels / 2:
132
- return f"Not Out (Impact outside line)", trajectory, pitch_point, impact_point
133
- for x, y in trajectory:
134
- if abs(x - stumps_x) < stumps_width_pixels / 2 and abs(y - stumps_y) < frame_height * 0.1:
135
- return f"Out (Ball projected to hit stumps)", trajectory, pitch_point, impact_point
136
- return f"Not Out (Missing stumps)", trajectory, pitch_point, impact_point
137
-
138
- def generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames):
139
- if not frames or not trajectory:
140
- return None, None
141
-
142
- height, width = frames[0].shape[:2]
143
- slow_path = os.path.join(tempfile.gettempdir(), f"drs_slow_{uuid.uuid4()}.mp4")
144
- normal_path = os.path.join(tempfile.gettempdir(), f"drs_normal_{uuid.uuid4()}.mp4")
145
-
146
- slow_writer = cv2.VideoWriter(slow_path, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE / SLOW_MOTION_FACTOR, (width, height))
147
- normal_writer = cv2.VideoWriter(normal_path, cv2.VideoWriter_fourcc(*'mp4v'), FRAME_RATE, (width, height))
148
-
149
- min_frame = min(detection_frames)
150
- max_frame = max(detection_frames)
151
- total_frames = max_frame - min_frame + 1
152
- traj_per_frame = max(1, len(trajectory) // total_frames)
153
- indices = [min(i * traj_per_frame, len(trajectory) - 1) for i in range(total_frames)]
154
-
155
- for i, frame in enumerate(frames):
156
- frame_copy = frame.copy()
157
- idx = i - min_frame
158
- if 0 <= idx < len(indices):
159
- end_idx = indices[idx]
160
- points = np.array(trajectory[:end_idx + 1], dtype=np.int32).reshape((-1, 1, 2))
161
- cv2.polylines(frame, [points], False, (255, 0, 0), 2)
162
- cv2.polylines(frame_copy, [points], False, (255, 0, 0), 2)
163
- if pitch_point and i == detection_frames[0]:
164
- cv2.circle(frame, tuple(map(int, pitch_point)), 6, (0, 0, 255), -1)
165
- if impact_point and i == detection_frames[-1]:
166
- cv2.circle(frame, tuple(map(int, impact_point)), 6, (0, 255, 255), -1)
167
- for _ in range(SLOW_MOTION_FACTOR):
168
- slow_writer.write(frame)
169
- normal_writer.write(frame_copy)
170
-
171
- slow_writer.release()
172
- normal_writer.release()
173
- return slow_path, normal_path
174
-
175
- def drs_review(video):
176
- frames, ball_positions, detection_frames, debug_log = process_video(video)
177
- if not frames or not ball_positions:
178
- return "No frames or detections found.", None, None
179
-
180
- frame_height, frame_width = frames[0].shape[:2]
181
- trajectory, pitch_point, impact_point, log = estimate_trajectory(ball_positions, detection_frames, frame_height, frame_width)
182
- if not trajectory:
183
- return f"{log}\n{debug_log}", None, None
184
-
185
- decision, _, _, _ = lbw_decision(ball_positions, trajectory, frames, pitch_point, impact_point)
186
- slow_path, normal_path = generate_replay(frames, trajectory, pitch_point, impact_point, detection_frames)
187
-
188
- result_log = f"DRS Decision: {decision}\n\n{log}\n\n{debug_log}"
189
- return result_log, slow_path, normal_path
190
-
191
- # Gradio Interface
192
- iface = gr.Interface(
193
- fn=drs_review,
194
- inputs=gr.Video(label="Upload Cricket Delivery Video"),
195
- outputs=[
196
- gr.Textbox(label="DRS Result and Debug Info"),
197
- gr.Video(label="Slow-Motion Replay"),
198
- gr.Video(label="Normal-Speed Trajectory Only")
199
- ],
200
- title="GullyDRS - AI-Powered LBW Review",
201
- description="Upload a cricket delivery video. The system will track the ball, estimate trajectory, and return both slow-motion and normal-speed replays."
202
- )
203
-
204
- if __name__ == "__main__":
205
- iface.launch()