AjaykumarPilla commited on
Commit
13e82ee
·
verified ·
1 Parent(s): c788af6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -33
app.py CHANGED
@@ -31,7 +31,7 @@ BATCH_SIZE = 4 # Process 4 frames at a time
31
 
32
  # Load model
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
- model = attempt_load("yolov5s.pt") # Load without map_location
35
  model.to(device).eval() # Move model to device and set to evaluation mode
36
 
37
  # Function to process video and detect ball
@@ -114,10 +114,10 @@ def process_video(video_path):
114
  def poly_func(x, a, b, c):
115
  return a * x**2 + b * x + c
116
 
117
- # Predict trajectory and LBW decision
118
  def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
119
  if len(positions) < 3:
120
- return None, "Insufficient detections for trajectory prediction"
121
 
122
  x_coords = [p[0] for p in positions]
123
  y_coords = [p[1] for p in positions]
@@ -128,7 +128,7 @@ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
128
  popt_x, _ = curve_fit(poly_func, frames, x_coords)
129
  popt_y, _ = curve_fit(poly_func, frames, y_coords)
130
  except:
131
- return None, "Failed to fit trajectory"
132
 
133
  # Extrapolate to stumps
134
  frame_max = max(frames) + 10
@@ -136,9 +136,13 @@ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
136
  x_pred = poly_func(future_frames, *popt_x)
137
  y_pred = poly_func(future_frames, *popt_y)
138
 
139
- # Check if trajectory hits stumps
140
  stump_x = frame_width / 2
141
  stump_y = frame_height
 
 
 
 
142
  stump_hit = False
143
  for x, y in zip(x_pred, y_pred):
144
  if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
@@ -146,7 +150,7 @@ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
146
  break
147
 
148
  lbw_decision = "OUT" if stump_hit else "NOT OUT"
149
- return list(zip(future_frames, x_pred, y_pred)), lbw_decision
150
 
151
  # Map pitch location
152
  def map_pitch(bounce_point, frame_width, frame_height):
@@ -177,27 +181,7 @@ def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
177
  avg_speed_kmh = np.mean(speeds) * 3.6
178
  return avg_speed_kmh, "Speed calculated successfully"
179
 
180
- # Create pitch map visualization
181
- def create_pitch_map(pitch_x, pitch_y):
182
- fig = go.Figure()
183
- fig.add_shape(
184
- type="rect", x0=-PITCH_WIDTH/2, y0=0, x1=PITCH_WIDTH/2, y1=PITCH_LENGTH,
185
- line=dict(color="Green"), fillcolor="Green", opacity=0.3
186
- )
187
- fig.add_shape(
188
- type="rect", x0=-STUMP_WIDTH/2, y0=PITCH_LENGTH-0.1, x1=STUMP_WIDTH/2, y1=PITCH_LENGTH,
189
- line=dict(color="Brown"), fillcolor="Brown"
190
- )
191
- if pitch_x is not None and pitch_y is not None:
192
- fig.add_trace(go.Scatter(x=[pitch_x], y=[pitch_y], mode="markers", marker=dict(size=10, color="Red"), name="Bounce Point"))
193
-
194
- fig.update_layout(
195
- title="Pitch Map", xaxis_title="Width (m)", yaxis_title="Length (m)",
196
- xaxis_range=[-PITCH_WIDTH/2, PITCH_WIDTH/2], yaxis_range=[0, PITCH_LENGTH]
197
- )
198
- return fig
199
-
200
- # Main Gradio function
201
  def drs_analysis(video):
202
  # Video is a file path (string) in Hugging Face Spaces
203
  video_path = video if isinstance(video, str) else "temp_video.mp4"
@@ -205,27 +189,80 @@ def drs_analysis(video):
205
  with open(video_path, "wb") as f:
206
  f.write(video.read())
207
 
 
208
  positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
209
  if not positions:
210
  return None, None, "No ball detected in video", None
211
 
212
- trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
 
213
  if trajectory is None:
214
  return None, None, lbw_decision, None
215
 
216
  pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
217
  speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
218
 
219
- trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
220
- fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
221
- fig_traj.update_yaxes(autorange="reversed")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
222
 
223
- fig_pitch = create_pitch_map(pitch_x, pitch_y)
 
224
 
225
  if not isinstance(video, str):
226
  os.remove(video_path)
227
 
228
- return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h", video_path
229
 
230
  # Gradio interface
231
  with gr.Blocks() as demo:
 
31
 
32
  # Load model
33
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+ model = attempt_load("best.pt") # Load without map_location
35
  model.to(device).eval() # Move model to device and set to evaluation mode
36
 
37
  # Function to process video and detect ball
 
114
  def poly_func(x, a, b, c):
115
  return a * x**2 + b * x + c
116
 
117
+ # Predict trajectory and wicket inline path
118
  def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
119
  if len(positions) < 3:
120
+ return None, None, "Insufficient detections for trajectory prediction"
121
 
122
  x_coords = [p[0] for p in positions]
123
  y_coords = [p[1] for p in positions]
 
128
  popt_x, _ = curve_fit(poly_func, frames, x_coords)
129
  popt_y, _ = curve_fit(poly_func, frames, y_coords)
130
  except:
131
+ return None, None, "Failed to fit trajectory"
132
 
133
  # Extrapolate to stumps
134
  frame_max = max(frames) + 10
 
136
  x_pred = poly_func(future_frames, *popt_x)
137
  y_pred = poly_func(future_frames, *popt_y)
138
 
139
+ # Wicket inline path (center line toward stumps)
140
  stump_x = frame_width / 2
141
  stump_y = frame_height
142
+ inline_x = np.linspace(min(x_coords), stump_x, 100)
143
+ inline_y = np.interp(inline_x, x_pred, y_pred)
144
+
145
+ # Check if trajectory hits stumps
146
  stump_hit = False
147
  for x, y in zip(x_pred, y_pred):
148
  if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
 
150
  break
151
 
152
  lbw_decision = "OUT" if stump_hit else "NOT OUT"
153
+ return list(zip(future_frames, x_pred, y_pred)), list(zip(inline_x, inline_y)), lbw_decision
154
 
155
  # Map pitch location
156
  def map_pitch(bounce_point, frame_width, frame_height):
 
181
  avg_speed_kmh = np.mean(speeds) * 3.6
182
  return avg_speed_kmh, "Speed calculated successfully"
183
 
184
+ # Main Gradio function with video overlay
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  def drs_analysis(video):
186
  # Video is a file path (string) in Hugging Face Spaces
187
  video_path = video if isinstance(video, str) else "temp_video.mp4"
 
189
  with open(video_path, "wb") as f:
190
  f.write(video.read())
191
 
192
+ # Process video for detections
193
  positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
194
  if not positions:
195
  return None, None, "No ball detected in video", None
196
 
197
+ # Predict trajectory and wicket path
198
+ trajectory, inline_path, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
199
  if trajectory is None:
200
  return None, None, lbw_decision, None
201
 
202
  pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
203
  speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
204
 
205
+ # Create output video with overlays
206
+ output_path = "output_video.mp4"
207
+ cap = cv2.VideoCapture(video_path)
208
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
209
+ out = cv2.VideoWriter(output_path, fourcc, frame_rate, (frame_width, frame_height))
210
+
211
+ frame_count = 0
212
+ positions_dict = dict(zip(frame_numbers, positions))
213
+
214
+ while cap.isOpened():
215
+ ret, frame = cap.read()
216
+ if not ret:
217
+ break
218
+
219
+ # Skip frames for consistency with detection
220
+ if frame_count % FRAME_SKIP != 0:
221
+ frame_count += 1
222
+ continue
223
+
224
+ # Overlay ball trajectory (red) and wicket inline path (blue)
225
+ if frame_count in positions_dict:
226
+ cv2.circle(frame, (int(positions_dict[frame_count][0]), int(positions_dict[frame_count][1])), 5, (0, 0, 255), -1) # Red dot
227
+ if trajectory:
228
+ traj_x = [int(t[1]) for t in trajectory if t[0] >= frame_count]
229
+ traj_y = [int(t[2]) for t in trajectory if t[0] >= frame_count]
230
+ if traj_x and traj_y:
231
+ for i in range(1, len(traj_x)):
232
+ cv2.line(frame, (traj_x[i-1], traj_y[i-1]), (traj_x[i], traj_y[i]), (0, 0, 255), 2) # Red line
233
+ if inline_path:
234
+ inline_x = [int(x) for x, _ in inline_path]
235
+ inline_y = [int(y) for _, y in inline_path]
236
+ if inline_x and inline_y:
237
+ for i in range(1, len(inline_x)):
238
+ cv2.line(frame, (inline_x[i-1], inline_y[i-1]), (inline_x[i], inline_y[i]), (255, 0, 0), 2) # Blue line
239
+
240
+ # Overlay pitch map in top-right corner
241
+ if pitch_x is not None and pitch_y is not None:
242
+ map_width = 200
243
+ map_height = int(map_width * PITCH_LENGTH / PITCH_WIDTH)
244
+ pitch_map = np.zeros((map_height, map_width, 3), dtype=np.uint8)
245
+ pitch_map[:] = (0, 255, 0) # Green pitch
246
+ cv2.rectangle(pitch_map, (0, map_height-10), (map_width, map_height), (0, 51, 51), -1) # Brown stumps
247
+ bounce_x = int((pitch_x + PITCH_WIDTH/2) / PITCH_WIDTH * map_width)
248
+ bounce_y = int((1 - pitch_y / PITCH_LENGTH) * map_height)
249
+ cv2.circle(pitch_map, (bounce_x, bounce_y), 5, (0, 0, 255), -1) # Red bounce point
250
+ frame[0:map_height, frame_width-map_width:frame_width] = cv2.resize(pitch_map, (map_width, map_height))
251
+
252
+ # Add text annotations
253
+ text = f"LBW: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h"
254
+ cv2.putText(frame, text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
255
+
256
+ out.write(frame)
257
+ frame_count += 1
258
 
259
+ cap.release()
260
+ out.release()
261
 
262
  if not isinstance(video, str):
263
  os.remove(video_path)
264
 
265
+ return None, None, None, output_path
266
 
267
  # Gradio interface
268
  with gr.Blocks() as demo: