AjaykumarPilla commited on
Commit
0a4055d
·
verified ·
1 Parent(s): e68af4e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +103 -18
app.py CHANGED
@@ -7,6 +7,7 @@ from scipy.interpolate import interp1d
7
  import plotly.graph_objects as go
8
  import uuid
9
  import os
 
10
 
11
  # Load the trained YOLOv8n model with optimizations
12
  model = YOLO("best.pt")
@@ -92,9 +93,6 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
92
  if distance <= MAX_POSITION_JUMP:
93
  filtered_positions.append(curr_pos)
94
  filtered_frames.append(detection_frames[i])
95
- else:
96
- # Skip sudden jumps to maintain continuity
97
- continue
98
 
99
  if len(filtered_positions) < 2:
100
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
@@ -106,7 +104,7 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
106
  pitch_point = filtered_positions[0]
107
  pitch_frame = filtered_frames[0]
108
 
109
- # Prioritize sudden y-change for impact detection
110
  impact_idx = None
111
  impact_frame = None
112
  for i in range(1, len(y_coords)):
@@ -125,17 +123,16 @@ def estimate_trajectory(ball_positions, frames, detection_frames):
125
  impact_point = filtered_positions[impact_idx]
126
 
127
  try:
128
- # Use cubic interpolation for smoother trajectory
129
- fx = interp1d(times[:impact_idx + 1], x_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
130
- fy = interp1d(times[:impact_idx + 1], y_coords[:impact_idx + 1], kind='cubic', fill_value="extrapolate")
 
 
 
 
131
  except Exception as e:
132
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
133
 
134
- t_full = np.linspace(times[0], times[-1], len(times) * 4) # Dense points for smooth trajectory
135
- x_full = fx(t_full)
136
- y_full = fy(t_full)
137
- trajectory_2d = list(zip(x_full, y_full))
138
-
139
  trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
140
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
141
  pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
@@ -249,6 +246,89 @@ def create_3d_plot(detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d
249
  fig = go.Figure(data=data, layout=layout)
250
  return fig
251
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
253
  if not frames:
254
  return None
@@ -282,17 +362,20 @@ def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detectio
282
  def drs_review(video):
283
  frames, ball_positions, detection_frames, debug_log = process_video(video)
284
  if not frames:
285
- return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None
286
 
287
  trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
288
 
289
  if trajectory_2d is None:
290
- return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None)
291
 
292
  decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
293
 
294
- output_path = f"output_{uuid.uuid4()}.mp4"
295
- slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path)
 
 
 
296
 
297
  detections_fig = None
298
  trajectory_fig = None
@@ -303,6 +386,7 @@ def drs_review(video):
303
  debug_output = f"{debug_log}\n{trajectory_log}"
304
  return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
305
  slow_motion_path,
 
306
  detections_fig,
307
  trajectory_fig)
308
 
@@ -312,12 +396,13 @@ iface = gr.Interface(
312
  inputs=gr.Video(label="Upload Video Clip"),
313
  outputs=[
314
  gr.Textbox(label="DRS Decision and Debug Log"),
315
- gr.Video(label="Very Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
 
316
  gr.Plot(label="3D Ball Detections Plot"),
317
  gr.Plot(label="3D Ball Trajectory Plot")
318
  ],
319
  title="AI-Powered DRS for LBW in Local Cricket",
320
- description="Upload a video clip of a cricket delivery to get an LBW decision, a slow-motion replay, and 3D visualizations. The replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D plots show detections (green markers) and trajectory (blue line) with wicket lines (black), pitch point (red), and impact point (yellow)."
321
  )
322
 
323
  if __name__ == "__main__":
 
7
  import plotly.graph_objects as go
8
  import uuid
9
  import os
10
+ import plotly.io as pio
11
 
12
  # Load the trained YOLOv8n model with optimizations
13
  model = YOLO("best.pt")
 
93
  if distance <= MAX_POSITION_JUMP:
94
  filtered_positions.append(curr_pos)
95
  filtered_frames.append(detection_frames[i])
 
 
 
96
 
97
  if len(filtered_positions) < 2:
98
  return None, None, None, None, None, None, None, None, None, "Error: Fewer than 2 valid ball detections after filtering"
 
104
  pitch_point = filtered_positions[0]
105
  pitch_frame = filtered_frames[0]
106
 
107
+ # Detect impact point
108
  impact_idx = None
109
  impact_frame = None
110
  for i in range(1, len(y_coords)):
 
123
  impact_point = filtered_positions[impact_idx]
124
 
125
  try:
126
+ # Cubic interpolation for smooth trajectory
127
+ fx = interp1d(times, x_coords, kind='cubic', fill_value="extrapolate")
128
+ fy = interp1d(times, y_coords, kind='cubic', fill_value="extrapolate")
129
+ t_full = np.linspace(times[0], times[-1], len(times) * 4)
130
+ x_full = fx(t_full)
131
+ y_full = fy(t_full)
132
+ trajectory_2d = list(zip(x_full, y_full))
133
  except Exception as e:
134
  return None, None, None, None, None, None, None, None, None, f"Error in trajectory interpolation: {str(e)}"
135
 
 
 
 
 
 
136
  trajectory_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in trajectory_2d]
137
  detections_3d = [pixel_to_3d(x, y, frame_height, frame_width) for x, y in filtered_positions]
138
  pitch_point_3d = pixel_to_3d(pitch_point[0], pitch_point[1], frame_height, frame_width)
 
246
  fig = go.Figure(data=data, layout=layout)
247
  return fig
248
 
249
+ def generate_3d_trajectory_video(trajectory_3d, pitch_point_3d, impact_point_3d, detection_frames, pitch_frame, impact_frame, output_path):
250
+ if not trajectory_3d or not detection_frames:
251
+ return None
252
+
253
+ # Define video parameters
254
+ frame_width = 1280
255
+ frame_height = 720
256
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
257
+ out = cv2.VideoWriter(output_path, fourcc, FRAME_RATE / SLOW_MOTION_FACTOR, (frame_width, frame_height))
258
+
259
+ # Wicket lines (stumps and bails)
260
+ stump_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2, 0]
261
+ stump_y = [PITCH_LENGTH, PITCH_LENGTH, PITCH_LENGTH]
262
+ stump_z = [0, 0, 0]
263
+ stump_top_z = [STUMPS_HEIGHT, STUMPS_HEIGHT, STUMPS_HEIGHT]
264
+ bail_x = [-STUMPS_WIDTH/2, STUMPS_WIDTH/2]
265
+ bail_y = [PITCH_LENGTH, PITCH_LENGTH]
266
+ bail_z = [STUMPS_HEIGHT, STUMPS_HEIGHT]
267
+
268
+ stump_traces = []
269
+ for i in range(3):
270
+ stump_traces.append(go.Scatter3d(
271
+ x=[stump_x[i], stump_x[i]], y=[stump_y[i], stump_y[i]], z=[stump_z[i], stump_top_z[i]],
272
+ mode='lines', line=dict(color='black', width=5), name=f'Stump {i+1}'
273
+ ))
274
+ bail_traces = [
275
+ go.Scatter3d(
276
+ x=bail_x, y=bail_y, z=bail_z,
277
+ mode='lines', line=dict(color='black', width=5), name='Bail'
278
+ )
279
+ ]
280
+
281
+ # Generate frames for each detection
282
+ for i, frame_idx in enumerate(detection_frames):
283
+ # Trajectory up to current frame
284
+ traj_idx = min(i * 4, len(trajectory_3d) - 1) # Match 2D trajectory density
285
+ x, y, z = zip(*trajectory_3d[:traj_idx + 1]) if trajectory_3d else ([], [], [])
286
+ trajectory_line = go.Scatter3d(
287
+ x=x, y=y, z=z, mode='lines',
288
+ line=dict(color='blue', width=4), name='Ball Trajectory'
289
+ )
290
+
291
+ # Pitch point (red marker) if at or after pitch_frame
292
+ pitch_scatter = go.Scatter3d(
293
+ x=[pitch_point_3d[0]] if frame_idx >= pitch_frame else [],
294
+ y=[pitch_point_3d[1]] if frame_idx >= pitch_frame else [],
295
+ z=[pitch_point_3d[2]] if frame_idx >= pitch_frame else [],
296
+ mode='markers', marker=dict(size=8, color='red'), name='Pitch Point'
297
+ )
298
+
299
+ # Impact point (yellow marker) if at or after impact_frame
300
+ impact_scatter = go.Scatter3d(
301
+ x=[impact_point_3d[0]] if frame_idx >= impact_frame else [],
302
+ y=[impact_point_3d[1]] if frame_idx >= impact_frame else [],
303
+ z=[impact_point_3d[2]] if frame_idx >= impact_frame else [],
304
+ mode='markers', marker=dict(size=8, color='yellow'), name='Impact Point'
305
+ )
306
+
307
+ # Create frame
308
+ fig = go.Figure(
309
+ data=[trajectory_line, pitch_scatter, impact_scatter] + stump_traces + bail_traces,
310
+ layout=go.Layout(
311
+ scene=dict(
312
+ xaxis_title='X (meters)', yaxis_title='Y (meters)', zaxis_title='Z (meters)',
313
+ xaxis=dict(range=[-1.5, 1.5]), yaxis=dict(range=[0, PITCH_LENGTH]),
314
+ zaxis=dict(range=[0, STUMPS_HEIGHT * 2]), aspectmode='manual',
315
+ aspectratio=dict(x=1, y=4, z=0.5)
316
+ ),
317
+ showlegend=True
318
+ )
319
+ )
320
+
321
+ # Render frame to image
322
+ img_bytes = pio.to_image(fig, format='png', width=frame_width, height=frame_height)
323
+ img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
324
+
325
+ # Write frame multiple times for slow motion
326
+ for _ in range(SLOW_MOTION_FACTOR):
327
+ out.write(img)
328
+
329
+ out.release()
330
+ return output_path
331
+
332
  def generate_slow_motion(frames, trajectory, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path):
333
  if not frames:
334
  return None
 
362
  def drs_review(video):
363
  frames, ball_positions, detection_frames, debug_log = process_video(video)
364
  if not frames:
365
+ return f"Error: Failed to process video\nDebug Log:\n{debug_log}", None, None, None, None
366
 
367
  trajectory_2d, pitch_point, impact_point, pitch_frame, impact_frame, detections_3d, trajectory_3d, pitch_point_3d, impact_point_3d, trajectory_log = estimate_trajectory(ball_positions, frames, detection_frames)
368
 
369
  if trajectory_2d is None:
370
+ return (f"Error: {trajectory_log}\nDebug Log:\n{debug_log}", None, None, None, None)
371
 
372
  decision, trajectory_2d, pitch_point, impact_point = lbw_decision(ball_positions, trajectory_2d, frames, pitch_point, impact_point)
373
 
374
+ output_path_2d = f"output_2d_{uuid.uuid4()}.mp4"
375
+ slow_motion_path = generate_slow_motion(frames, trajectory_2d, pitch_point, impact_point, detection_frames, pitch_frame, impact_frame, output_path_2d)
376
+
377
+ output_path_3d = f"output_3d_{uuid.uuid4()}.mp4"
378
+ trajectory_video_path = generate_3d_trajectory_video(trajectory_3d, pitch_point_3d, impact_point_3d, detection_frames, pitch_frame, impact_frame, output_path_3d)
379
 
380
  detections_fig = None
381
  trajectory_fig = None
 
386
  debug_output = f"{debug_log}\n{trajectory_log}"
387
  return (f"DRS Decision: {decision}\nDebug Log:\n{debug_output}",
388
  slow_motion_path,
389
+ trajectory_video_path,
390
  detections_fig,
391
  trajectory_fig)
392
 
 
396
  inputs=gr.Video(label="Upload Video Clip"),
397
  outputs=[
398
  gr.Textbox(label="DRS Decision and Debug Log"),
399
+ gr.Video(label="2D Slow-Motion Replay with Ball Detection (Green), Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow)"),
400
+ gr.Video(label="3D Trajectory Video with Trajectory (Blue Line), Pitch Point (Red), Impact Point (Yellow), Wicket Lines (Black)"),
401
  gr.Plot(label="3D Ball Detections Plot"),
402
  gr.Plot(label="3D Ball Trajectory Plot")
403
  ],
404
  title="AI-Powered DRS for LBW in Local Cricket",
405
+ description="Upload a video clip of a cricket delivery to get an LBW decision, a 2D slow-motion replay, a 3D trajectory video, and 3D visualizations. The 2D replay shows ball detection (green boxes), trajectory (blue line), pitch point (red circle), and impact point (yellow circle). The 3D video and plots show the trajectory (blue line), detections (green markers), pitch point (red), impact point (yellow), and wicket lines (black)."
406
  )
407
 
408
  if __name__ == "__main__":