AjaykumarPilla committed on
Commit 6110fb8 · verified · 1 Parent(s): 349e71a

Update app.py

Files changed (1):
  app.py  +41 -41
app.py CHANGED
@@ -3,14 +3,15 @@ import numpy as np
 import pandas as pd
 import plotly.express as px
 import plotly.graph_objects as go
-from ultralytics import YOLO
+import torch
 import gradio as gr
 import os
 from scipy.interpolate import interp1d
 from scipy.optimize import curve_fit
 
-# Load YOLOv5 model
-model = YOLO("best.pt")  # Path to your best.pt
+# Load YOLOv5 model from yolov5 repository
+from yolov5.models.experimental import attempt_load
+from yolov5.utils.general import non_max_suppression, xywh2xyxy
 
 # Cricket pitch dimensions (in meters)
 PITCH_LENGTH = 20.12  # Length of cricket pitch (stumps to stumps)
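Note on the hunk above: it swaps the `ultralytics` package API for direct imports from a vendored `yolov5` checkout. For comparison only, the same custom weights can also be loaded through torch.hub, which bundles preprocessing and NMS. A minimal sketch, not what this commit does, assuming "best.pt" is a YOLOv5 checkpoint and `ultralytics/yolov5` is reachable or cached locally:

```python
# Alternative loading path via torch.hub (illustrative; not this commit's code).
import cv2
import torch

hub_model = torch.hub.load("ultralytics/yolov5", "custom", path="best.pt")
hub_model.conf = 0.25                        # confidence threshold

frame = cv2.imread("sample_frame.jpg")       # any BGR frame; path is illustrative
rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
results = hub_model(rgb)                     # AutoShape handles resize + NMS
detections = results.xyxy[0].cpu().numpy()   # rows: [x1, y1, x2, y2, conf, cls]
```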
@@ -18,6 +19,11 @@ PITCH_WIDTH = 3.05  # Width of pitch
 STUMP_HEIGHT = 0.71  # Stump height
 STUMP_WIDTH = 0.2286  # Stump width (including bails)
 
+# Load model
+device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+model = attempt_load("best.pt", map_location=device)
+model.eval()
+
 # Function to process video and detect ball
 def process_video(video_path):
     cap = cv2.VideoCapture(video_path)
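A caution on `attempt_load`: its signature has changed across yolov5 revisions (older checkouts take `map_location=`, newer ones take `device=`), so the call above depends on which revision is vendored. A small version-tolerant sketch, under that assumption:

```python
# Hedged sketch: tolerate both attempt_load signatures seen across yolov5
# revisions; which one applies depends on the vendored checkout.
import torch
from yolov5.models.experimental import attempt_load

def load_yolov5(weights: str = "best.pt"):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        net = attempt_load(weights, map_location=device)  # older revisions
    except TypeError:
        net = attempt_load(weights, device=device)        # newer revisions
    return net.eval(), device
```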
@@ -35,19 +41,30 @@ def process_video(video_path):
         if not ret:
             break
 
-        # Run YOLOv5 detection
-        results = model(frame)
-        detections = results[0].boxes.xywh.cpu().numpy()  # [x_center, y_center, width, height]
-
-        for det in detections:
-            x_center, y_center, _, _ = det
-            positions.append((x_center, y_center))
-            frame_numbers.append(frame_num)
-
-            # Detect bounce (lowest y_center point)
-            if bounce_frame is None or y_center > positions[bounce_frame][1]:
-                bounce_frame = len(frame_numbers) - 1
-                bounce_point = (x_center, y_center)
+        # Preprocess frame for YOLOv5
+        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        img = torch.from_numpy(img).to(device).float() / 255.0
+        img = img.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
+
+        # Run inference
+        with torch.no_grad():
+            pred = model(img)[0]
+            pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
+
+        # Process detections
+        for det in pred:
+            if det is not None and len(det):
+                det = xywh2xyxy(det)  # Convert to [x1, y1, x2, y2]
+                for *xyxy, conf, cls in det:
+                    x_center = (xyxy[0] + xyxy[2]) / 2
+                    y_center = (xyxy[1] + xyxy[3]) / 2
+                    positions.append((x_center.item(), y_center.item()))
+                    frame_numbers.append(frame_num)
+
+                    # Detect bounce (lowest y_center point)
+                    if bounce_frame is None or y_center > positions[bounce_frame][1]:
+                        bounce_frame = len(frame_numbers) - 1
+                        bounce_point = (x_center.item(), y_center.item())
 
     cap.release()
     return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
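Two assumptions in the hunk above are worth flagging. YOLOv5 models expect input dimensions that are multiples of the model stride (typically 32, with 640 the usual inference size), so feeding a raw full-resolution frame may fail or hurt accuracy. Also, `non_max_suppression` in the yolov5 repo already returns boxes as [x1, y1, x2, y2, conf, cls], so the extra `xywh2xyxy` call likely re-interprets coordinates that are already in corner format. A hedged sketch of the per-frame step under those assumptions, resizing to 640x640 and scaling the detected centre back to frame pixels:

```python
# Sketch only: fixed 640x640 resize (no letterboxing), best-confidence box kept.
import cv2
import torch
from yolov5.utils.general import non_max_suppression

IMG_SIZE = 640  # multiple of the model stride (32)

def detect_ball_center(model, frame, device, conf_thres=0.25, iou_thres=0.45):
    h0, w0 = frame.shape[:2]
    rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
    rgb = cv2.resize(rgb, (IMG_SIZE, IMG_SIZE))
    img = torch.from_numpy(rgb).to(device).float().permute(2, 0, 1) / 255.0
    with torch.no_grad():
        pred = model(img.unsqueeze(0))[0]
    det = non_max_suppression(pred, conf_thres=conf_thres, iou_thres=iou_thres)[0]
    if det is None or len(det) == 0:
        return None
    x1, y1, x2, y2, conf, cls = det[det[:, 4].argmax()].tolist()
    # Map the box centre from the 640x640 network input back to frame pixels.
    return ((x1 + x2) / 2 * w0 / IMG_SIZE, (y1 + y2) / 2 * h0 / IMG_SIZE)
```

Letterbox resizing, as used by the yolov5 detect script, would preserve aspect ratio better than the plain square resize sketched here.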
@@ -72,15 +89,15 @@ def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
     except:
         return None, "Failed to fit trajectory"
 
-    # Extrapolate to stumps (assume stumps at y=frame_height)
-    frame_max = max(frames) + 10  # Predict 10 frames ahead
+    # Extrapolate to stumps
+    frame_max = max(frames) + 10
     future_frames = np.linspace(min(frames), frame_max, 100)
     x_pred = poly_func(future_frames, *popt_x)
     y_pred = poly_func(future_frames, *popt_y)
 
     # Check if trajectory hits stumps
-    stump_x = frame_width / 2  # Assume stumps at center of frame
-    stump_y = frame_height  # Assume stumps at bottom of frame
+    stump_x = frame_width / 2
+    stump_y = frame_height
     stump_hit = False
     for x, y in zip(x_pred, y_pred):
         if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
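This hunk only trims comments; the fit itself (`poly_func`, `popt_x`, `popt_y`, `frames`) sits outside the hunk. For readers of the diff, a minimal sketch of what that fit is assumed to look like, a quadratic in frame index fitted with `scipy.optimize.curve_fit` (the actual `poly_func` in app.py may differ):

```python
# Assumed shape of the trajectory fit referenced above (illustrative only).
import numpy as np
from scipy.optimize import curve_fit

def poly_func(t, a, b, c):
    # Quadratic model: pixel coordinate as a function of frame index.
    return a * t**2 + b * t + c

frames = np.array(frame_numbers, dtype=float)          # from process_video()
xs = np.array([p[0] for p in positions], dtype=float)
ys = np.array([p[1] for p in positions], dtype=float)
popt_x, _ = curve_fit(poly_func, frames, xs)           # x(t)
popt_y, _ = curve_fit(poly_func, frames, ys)           # y(t); curvature captures the drop
```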
@@ -96,9 +113,8 @@ def map_pitch(bounce_point, frame_width, frame_height):
         return None, "No bounce detected"
 
     x, y = bounce_point
-    # Convert pixel coordinates to pitch coordinates
-    pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2  # Center at 0
-    pitch_y = (1 - y / frame_height) * PITCH_LENGTH  # Bottom of frame = 0
+    pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
+    pitch_y = (1 - y / frame_height) * PITCH_LENGTH
     return pitch_x, pitch_y
 
 # Estimate ball speed
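The mapping above is a straight linear rescale: x is normalised by frame width and re-centred on the pitch's 3.05 m width, and y is flipped so the bottom of the frame maps to 0 and scaled to the 20.12 m pitch length. A worked example with an assumed 1280x720 frame and a bounce detected at pixel (640, 600):

```python
# Illustrative numbers only; frame size and bounce pixel are assumptions.
PITCH_LENGTH, PITCH_WIDTH = 20.12, 3.05
frame_width, frame_height = 1280, 720
x, y = 640, 600                                              # detected bounce pixel
pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2  # 0.0 m (centre line)
pitch_y = (1 - y / frame_height) * PITCH_LENGTH              # ~3.35 m up the pitch
```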
@@ -106,7 +122,6 @@ def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
     if len(positions) < 2:
         return None, "Insufficient detections for speed estimation"
 
-    # Calculate distance in pixels between consecutive detections
     distances = []
    for i in range(1, len(positions)):
         x1, y1 = positions[i-1]
@@ -114,30 +129,24 @@ def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
         pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
         distances.append(pixel_dist)
 
-    # Convert to meters (assume pitch length = frame height)
     pixel_to_meter = PITCH_LENGTH / frame_width
     distances_m = [d * pixel_to_meter for d in distances]
-
-    # Speed in m/s
     time_interval = 1 / frame_rate
     speeds = [d / time_interval for d in distances_m]
-    avg_speed_kmh = np.mean(speeds) * 3.6  # Convert m/s to km/h
+    avg_speed_kmh = np.mean(speeds) * 3.6
     return avg_speed_kmh, "Speed calculated successfully"
 
 # Create pitch map visualization
 def create_pitch_map(pitch_x, pitch_y):
     fig = go.Figure()
-    # Draw pitch rectangle
     fig.add_shape(
         type="rect", x0=-PITCH_WIDTH/2, y0=0, x1=PITCH_WIDTH/2, y1=PITCH_LENGTH,
         line=dict(color="Green"), fillcolor="Green", opacity=0.3
     )
-    # Draw stumps
     fig.add_shape(
         type="rect", x0=-STUMP_WIDTH/2, y0=PITCH_LENGTH-0.1, x1=STUMP_WIDTH/2, y1=PITCH_LENGTH,
         line=dict(color="Brown"), fillcolor="Brown"
     )
-    # Plot bounce point
     if pitch_x is not None and pitch_y is not None:
         fig.add_trace(go.Scatter(x=[pitch_x], y=[pitch_y], mode="markers", marker=dict(size=10, color="Red"), name="Bounce Point"))
 
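Note that `pixel_to_meter` divides PITCH_LENGTH by `frame_width`, i.e. the scale assumes the 20.12 m pitch spans the frame horizontally; that is a per-camera calibration assumption rather than a general rule. A worked example of the resulting estimate for assumed values (30 fps footage, 1280 px wide frame, 40 px of ball movement between consecutive detections):

```python
# Illustrative numbers only; frame rate, width and pixel displacement are assumed.
frame_rate, frame_width = 30, 1280
pixel_dist = 40
pixel_to_meter = 20.12 / frame_width                 # ~0.0157 m per pixel
speed_ms = pixel_dist * pixel_to_meter * frame_rate  # ~18.9 m/s
avg_speed_kmh = speed_ms * 3.6                       # ~67.9 km/h
```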
 
@@ -149,36 +158,27 @@ def create_pitch_map(pitch_x, pitch_y):
 
 # Main Gradio function
 def drs_analysis(video):
-    # Save uploaded video temporarily
     video_path = "temp_video.mp4"
     with open(video_path, "wb") as f:
         f.write(video.read())
 
-    # Process video
     positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
     if not positions:
         return None, None, "No ball detected in video", None
 
-    # Predict trajectory
     trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
     if trajectory is None:
         return None, None, lbw_decision, None
 
-    # Map pitch
     pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
-
-    # Estimate speed
     speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
 
-    # Create trajectory plot
     trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
     fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
-    fig_traj.update_yaxes(autorange="reversed")  # Invert y-axis to match video frame
+    fig_traj.update_yaxes(autorange="reversed")
 
-    # Create pitch map
     fig_pitch = create_pitch_map(pitch_x, pitch_y)
 
-    # Clean up
     os.remove(video_path)
 
     return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h", video_path
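The Gradio wiring itself is outside this diff. Two things to verify against it: recent Gradio versions pass a file path string for a `gr.Video` input, in which case `video.read()` above would fail, and the final `return` hands back `video_path` after `os.remove(video_path)` has already deleted the file. A hedged sketch of an interface matching drs_analysis's four return values (component choices and labels here are assumptions, not code from this commit):

```python
# Sketch of a matching interface; labels and layout are illustrative.
import gradio as gr

demo = gr.Interface(
    fn=drs_analysis,
    inputs=gr.Video(label="Delivery clip"),
    outputs=[
        gr.Plot(label="Ball trajectory"),
        gr.Plot(label="Pitch map"),
        gr.Textbox(label="Decision"),
        gr.Video(label="Reviewed clip"),
    ],
    title="Cricket DRS Analysis",
)

if __name__ == "__main__":
    demo.launch()
```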
 