AjaykumarPilla commited on
Commit
c3429f6
·
verified ·
1 Parent(s): 150d1f9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +187 -182
app.py CHANGED
@@ -1,196 +1,201 @@
1
- import gradio as gr
2
- import torch
3
- from ultralytics import YOLO
4
  import cv2
5
  import numpy as np
6
- from PIL import Image
 
 
 
 
7
  import os
8
- import matplotlib.pyplot as plt
9
- from scipy.interpolate import interp1d
10
-
11
- # Load YOLOv5 model
12
- model = YOLO("best.pt")
13
-
14
- class CentroidTracker:
15
- def __init__(self, max_disappeared=50):
16
- self.next_object_id = 0
17
- self.objects = {}
18
- self.disappeared = {}
19
- self.max_disappeared = max_disappeared
20
-
21
- def register(self, centroid):
22
- self.objects[self.next_object_id] = centroid
23
- self.disappeared[self.next_object_id] = 0
24
- self.next_object_id += 1
25
-
26
- def deregister(self, object_id):
27
- del self.objects[object_id]
28
- del self.disappeared[object_id]
29
-
30
- def update(self, rects):
31
- if len(rects) == 0:
32
- for object_id in list(self.disappeared.keys()):
33
- self.disappeared[object_id] += 1
34
- if self.disappeared[object_id] > self.max_disappeared:
35
- self.deregister(object_id)
36
- return self.objects
37
-
38
- input_centroids = np.zeros((len(rects), 2), dtype="int")
39
- for (i, (x1, y1, x2, y2)) in enumerate(rects):
40
- cX = int((x1 + x2) / 2.0)
41
- cY = int((y1 + y2) / 2.0)
42
- input_centroids[i] = (cX, cY)
43
-
44
- if len(self.objects) == 0:
45
- for i in range(len(input_centroids)):
46
- self.register(input_centroids[i])
47
- else:
48
- object_ids = list(self.objects.keys())
49
- object_centroids = list(self.objects.values())
50
- D = np.sqrt(((input_centroids[:, None] - object_centroids) ** 2).sum(axis=2))
51
- rows = D.min(axis=1).argsort()
52
- cols = D.argmin(axis=1)[rows]
53
- used_rows = set()
54
- used_cols = set()
55
- for (row, col) in zip(rows, cols):
56
- if row in used_rows or col in used_cols:
57
- continue
58
- object_id = object_ids[col]
59
- self.objects[object_id] = input_centroids[row]
60
- self.disappeared[object_id] = 0
61
- used_rows.add(row)
62
- used_cols.add(col)
63
- unused_rows = set(range(0, D.shape[0])).difference(used_rows)
64
- unused_cols = set(range(0, D.shape[1])).difference(used_cols)
65
- if D.shape[0] >= D.shape[1]:
66
- for row in unused_rows:
67
- self.register(input_centroids[row])
68
- else:
69
- for col in unused_cols:
70
- object_id = object_ids[col]
71
- self.disappeared[object_id] += 1
72
- if self.disappeared[object_id] > self.max_disappeared:
73
- self.deregister(object_id)
74
- return self.objects
75
-
76
- def detect_and_track_ball(video_path, conf_threshold=0.5, iou_threshold=0.5):
77
- """
78
- Detect and track ball in video, generate pitch map, and predict LBW outcome.
79
-
80
- Args:
81
- video_path: Path to uploaded video
82
- conf_threshold: Confidence threshold for detection
83
- iou_threshold: IoU threshold for non-max suppression
84
-
85
- Returns:
86
- Tuple of (annotated video path, pitch map image path, LBW decision)
87
- """
88
- # Initialize tracker
89
- tracker = CentroidTracker(max_disappeared=10)
90
  cap = cv2.VideoCapture(video_path)
91
- output_path = "output_video.mp4"
92
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
93
- out = cv2.VideoWriter(output_path, fourcc, 30.0,
94
- (int(cap.get(3)), int(cap.get(4))))
95
-
96
- # Store ball centroids for trajectory
97
- centroids = []
98
- pitch_points = []
99
-
100
  while cap.isOpened():
 
101
  ret, frame = cap.read()
102
  if not ret:
103
  break
104
-
105
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
106
- results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold)
107
-
108
- rects = []
109
- for box in results[0].boxes:
110
- x1, y1, x2, y2 = map(int, box.xyxy[0])
111
- conf = box.conf[0]
112
- label = f"Ball: {conf:.2f}"
113
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
114
- cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
115
- rects.append((x1, y1, x2, y2))
116
-
117
- # Update tracker
118
- objects = tracker.update(rects)
119
- for object_id, centroid in objects.items():
120
- cv2.circle(frame, centroid, 5, (0, 0, 255), -1)
121
- centroids.append(centroid)
122
-
123
- out.write(frame)
124
-
 
 
 
 
 
125
  cap.release()
126
- out.release()
127
-
128
- # Generate pitch map
129
- pitch_map_path = "pitch_map.png"
130
- fig, ax = plt.subplots(figsize=(8, 4))
131
- ax.set_xlim(0, 22) # Pitch length in meters (approx)
132
- ax.set_ylim(-1.5, 1.5) # Pitch width (approx)
133
- ax.set_xlabel("Length (m)")
134
- ax.set_ylabel("Width (m)")
135
- ax.set_title("Pitch Map with Ball Trajectory")
136
-
137
- # Plot stumps
138
- ax.plot([20.12, 20.12], [-0.135, 0.135], 'k-', lw=5) # Stumps at bowling end
139
- ax.plot([0, 0], [-0.135, 0.135], 'k-', lw=5) # Stumps at batting end
140
- ax.plot([0, 20.12], [0, 0], 'k--') # Pitch center line
141
-
142
- # Map centroids to pitch coordinates (simplified scaling)
143
- if centroids:
144
- x_coords = [20.12 - (c[1] / cap.get(4)) * 20.12 for c in centroids] # Scale y to pitch length
145
- y_coords = [(c[0] / cap.get(3)) * 2.7 - 1.35 for c in centroids] # Scale x to pitch width
146
- ax.plot(x_coords, y_coords, 'ro-', label="Ball Trajectory")
147
- pitch_points = list(zip(x_coords, y_coords))
148
-
149
- ax.legend()
150
- plt.savefig(pitch_map_path)
151
- plt.close()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # LBW Decision (simplified physics-based model)
154
- lbw_decision = "Not Out"
155
- if pitch_points:
156
- # Check pitching, impact, and wickets
157
- pitching = any(0 <= x <= 20.12 and -0.135 <= y <= 0.135 for x, y in pitch_points[:len(pitch_points)//2])
158
- impact = any(18 <= x <= 20.12 for x, y in pitch_points[len(pitch_points)//2:])
159
-
160
- # Fit a quadratic curve to predict trajectory post-impact
161
- if len(x_coords) > 2:
162
- t = np.linspace(0, 1, len(x_coords))
163
- f_x = interp1d(t, x_coords, kind='quadratic', fill_value="extrapolate")
164
- f_y = interp1d(t, y_coords, kind='quadratic', fill_value="extrapolate")
165
- t_future = np.array([1.5]) # Predict beyond impact
166
- x_future = f_x(t_future)[0]
167
- y_future = f_y(t_future)[0]
168
- wickets = (18 <= x_future <= 20.12) and (-0.135 <= y_future <= 0.135)
169
-
170
- if pitching and impact and wickets:
171
- lbw_decision = "Out"
172
- elif pitching and impact:
173
- lbw_decision = "Umpire's Call" # Marginal case
174
-
175
- return output_path, pitch_map_path, lbw_decision
 
 
 
 
 
 
 
 
 
176
 
177
  # Gradio interface
178
  with gr.Blocks() as demo:
179
- gr.Markdown("# DRS Review System for Cricket")
180
- gr.Markdown("Upload a cricket video to analyze ball tracking, pitch mapping, and LBW review. Adjust thresholds for detection accuracy.")
181
-
182
- video_input = gr.Video(label="Upload Cricket Video")
183
- conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
184
- iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
185
- output_video = gr.Video(label="Annotated Video with Ball Tracking")
186
- output_image = gr.Image(label="Pitch Map")
187
- output_text = gr.Textbox(label="LBW Decision")
188
- submit_button = gr.Button("Analyze DRS")
189
-
190
- submit_button.click(
191
- fn=detect_and_track_ball,
192
- inputs=[video_input, conf_slider, iou_slider],
193
- outputs=[output_video, output_image, output_text]
194
- )
195
 
196
- demo.launch()
 
 
 
 
 
1
  import cv2
2
  import numpy as np
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import plotly.graph_objects as go
6
+ import torch
7
+ import gradio as gr
8
  import os
9
+ from scipy.optimize import curve_fit
10
+ import sys
11
+
12
+ # Add yolov5 directory to sys.path
13
+ sys.path.append(os.path.join(os.path.dirname(__file__), "yolov5"))
14
+
15
+ # Import YOLOv5 modules
16
+ from models.experimental import attempt_load
17
+ from utils.general import non_max_suppression, xywh2xyxy
18
+
19
+ # Cricket pitch dimensions (in meters)
20
+ PITCH_LENGTH = 20.12 # Length of cricket pitch (stumps to stumps)
21
+ PITCH_WIDTH = 3.05 # Width of pitch
22
+ STUMP_HEIGHT = 0.71 # Stump height
23
+ STUMP_WIDTH = 0.2286 # Stump width (including bails)
24
+
25
+ # Load model
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model = attempt_load("best.pt") # Load without map_location
28
+ model.to(device).eval() # Move model to device and set to evaluation mode
29
+
30
+ # Function to process video and detect ball
31
+ def process_video(video_path):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  cap = cv2.VideoCapture(video_path)
33
+ frame_rate = cap.get(cv2.CAP_PROP_FPS)
34
+ frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
35
+ frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
36
+ positions = []
37
+ frame_numbers = []
38
+ bounce_frame = None
39
+ bounce_point = None
40
+
 
41
  while cap.isOpened():
42
+ frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
43
  ret, frame = cap.read()
44
  if not ret:
45
  break
46
+
47
+ # Preprocess frame for YOLOv5
48
+ img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
+ img = torch.from_numpy(img).to(device).float() / 255.0
50
+ img = img.permute(2, 0, 1).unsqueeze(0) # [1, 3, H, W]
51
+
52
+ # Run inference
53
+ with torch.no_grad():
54
+ pred = model(img)[0]
55
+ pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
56
+
57
+ # Process detections
58
+ for det in pred:
59
+ if det is not None and len(det):
60
+ det = xywh2xyxy(det) # Convert to [x1, y1, x2, y2]
61
+ for *xyxy, conf, cls in det:
62
+ x_center = (xyxy[0] + xyxy[2]) / 2
63
+ y_center = (xyxy[1] + xyxy[3]) / 2
64
+ positions.append((x_center.item(), y_center.item()))
65
+ frame_numbers.append(frame_num)
66
+
67
+ # Detect bounce (lowest y_center point)
68
+ if bounce_frame is None or y_center > positions[bounce_frame][1]:
69
+ bounce_frame = len(frame_numbers) - 1
70
+ bounce_point = (x_center.item(), y_center.item())
71
+
72
  cap.release()
73
+ return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
74
+
75
+ # Polynomial function for trajectory fitting
76
+ def poly_func(x, a, b, c):
77
+ return a * x**2 + b * x + c
78
+
79
+ # Predict trajectory and LBW decision
80
+ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
81
+ if len(positions) < 3:
82
+ return None, "Insufficient detections for trajectory prediction"
83
+
84
+ x_coords = [p[0] for p in positions]
85
+ y_coords = [p[1] for p in positions]
86
+ frames = np.array(frame_numbers)
87
+
88
+ # Fit polynomial to x and y coordinates
89
+ try:
90
+ popt_x, _ = curve_fit(poly_func, frames, x_coords)
91
+ popt_y, _ = curve_fit(poly_func, frames, y_coords)
92
+ except:
93
+ return None, "Failed to fit trajectory"
94
+
95
+ # Extrapolate to stumps
96
+ frame_max = max(frames) + 10
97
+ future_frames = np.linspace(min(frames), frame_max, 100)
98
+ x_pred = poly_func(future_frames, *popt_x)
99
+ y_pred = poly_func(future_frames, *popt_y)
100
+
101
+ # Check if trajectory hits stumps
102
+ stump_x = frame_width / 2
103
+ stump_y = frame_height
104
+ stump_hit = False
105
+ for x, y in zip(x_pred, y_pred):
106
+ if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
107
+ stump_hit = True
108
+ break
109
+
110
+ lbw_decision = "OUT" if stump_hit else "NOT OUT"
111
+ return list(zip(future_frames, x_pred, y_pred)), lbw_decision
112
+
113
+ # Map pitch location
114
+ def map_pitch(bounce_point, frame_width, frame_height):
115
+ if bounce_point is None:
116
+ return None, "No bounce detected"
117
+
118
+ x, y = bounce_point
119
+ pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
120
+ pitch_y = (1 - y / frame_height) * PITCH_LENGTH
121
+ return pitch_x, pitch_y
122
+
123
+ # Estimate ball speed
124
+ def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
125
+ if len(positions) < 2:
126
+ return None, "Insufficient detections for speed estimation"
127
+
128
+ distances = []
129
+ for i in range(1, len(positions)):
130
+ x1, y1 = positions[i-1]
131
+ x2, y2 = positions[i]
132
+ pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
133
+ distances.append(pixel_dist)
134
+
135
+ pixel_to_meter = PITCH_LENGTH / frame_width
136
+ distances_m = [d * pixel_to_meter for d in distances]
137
+ time_interval = 1 / frame_rate
138
+ speeds = [d / time_interval for d in distances_m]
139
+ avg_speed_kmh = np.mean(speeds) * 3.6
140
+ return avg_speed_kmh, "Speed calculated successfully"
141
+
142
+ # Create pitch map visualization
143
+ def create_pitch_map(pitch_x, pitch_y):
144
+ fig = go.Figure()
145
+ fig.add_shape(
146
+ type="rect", x0=-PITCH_WIDTH/2, y0=0, x1=PITCH_WIDTH/2, y1=PITCH_LENGTH,
147
+ line=dict(color="Green"), fillcolor="Green", opacity=0.3
148
+ )
149
+ fig.add_shape(
150
+ type="rect", x0=-STUMP_WIDTH/2, y0=PITCH_LENGTH-0.1, x1=STUMP_WIDTH/2, y1=PITCH_LENGTH,
151
+ line=dict(color="Brown"), fillcolor="Brown"
152
+ )
153
+ if pitch_x is not None and pitch_y is not None:
154
+ fig.add_trace(go.Scatter(x=[pitch_x], y=[pitch_y], mode="markers", marker=dict(size=10, color="Red"), name="Bounce Point"))
155
 
156
+ fig.update_layout(
157
+ title="Pitch Map", xaxis_title="Width (m)", yaxis_title="Length (m)",
158
+ xaxis_range=[-PITCH_WIDTH/2, PITCH_WIDTH/2], yaxis_range=[0, PITCH_LENGTH]
159
+ )
160
+ return fig
161
+
162
+ # Main Gradio function
163
+ def drs_analysis(video):
164
+ video_path = "temp_video.mp4"
165
+ with open(video_path, "wb") as f:
166
+ f.write(video.read())
167
+
168
+ positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
169
+ if not positions:
170
+ return None, None, "No ball detected in video", None
171
+
172
+ trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
173
+ if trajectory is None:
174
+ return None, None, lbw_decision, None
175
+
176
+ pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
177
+ speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
178
+
179
+ trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
180
+ fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
181
+ fig_traj.update_yaxes(autorange="reversed")
182
+
183
+ fig_pitch = create_pitch_map(pitch_x, pitch_y)
184
+
185
+ os.remove(video_path)
186
+
187
+ return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h", video_path
188
 
189
  # Gradio interface
190
  with gr.Blocks() as demo:
191
+ gr.Markdown("## Cricket DRS Analysis")
192
+ video_input = gr.Video(label="Upload Video Clip")
193
+ btn = gr.Button("Analyze")
194
+ trajectory_output = gr.Plot(label="Ball Trajectory")
195
+ pitch_output = gr.Plot(label="Pitch Map")
196
+ text_output = gr.Textbox(label="Analysis Results")
197
+ video_output = gr.Video(label="Processed Video")
198
+ btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])
 
 
 
 
 
 
 
 
199
 
200
+ if __name__ == "__main__":
201
+ demo.launch()