AjaykumarPilla commited on
Commit
42d2b87
·
verified ·
1 Parent(s): 19ecbb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +156 -60
app.py CHANGED
@@ -5,96 +5,192 @@ import cv2
5
  import numpy as np
6
  from PIL import Image
7
  import os
 
 
8
 
9
- # Load the YOLOv5 model
10
  model = YOLO("best.pt")
11
 
12
- def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  """
14
- Perform ball detection on image or video input.
15
 
16
  Args:
17
- input_media: Uploaded image or video file
18
  conf_threshold: Confidence threshold for detection
19
  iou_threshold: IoU threshold for non-max suppression
20
 
21
  Returns:
22
- Annotated image or video path
23
  """
24
- # Check if input is image or video based on file extension
25
- file_extension = os.path.splitext(input_media)[1].lower()
 
 
 
 
 
26
 
27
- if file_extension in ['.jpg', '.jpeg', '.png']:
28
- # Process image
29
- img = cv2.imread(input_media)
30
- img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31
-
32
- # Perform detection
33
- results = model.predict(img, conf=conf_threshold, iou=iou_threshold)
 
 
 
 
34
 
35
- # Draw bounding boxes
36
  for box in results[0].boxes:
37
  x1, y1, x2, y2 = map(int, box.xyxy[0])
38
  conf = box.conf[0]
39
  label = f"Ball: {conf:.2f}"
40
- cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
41
- cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
 
42
 
43
- # Convert to PIL Image for Gradio output
44
- output_img = Image.fromarray(img)
45
- return output_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
- elif file_extension in ['.mp4', '.avi', '.mov']:
48
- # Process video
49
- cap = cv2.VideoCapture(input_media)
50
- output_path = "output_video.mp4"
51
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
52
- out = cv2.VideoWriter(output_path, fourcc, 30.0,
53
- (int(cap.get(3)), int(cap.get(4))))
 
 
 
54
 
55
- while cap.isOpened():
56
- ret, frame = cap.read()
57
- if not ret:
58
- break
59
-
60
- # Perform detection
61
- frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
62
- results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold)
 
63
 
64
- # Draw bounding boxes
65
- for box in results[0].boxes:
66
- x1, y1, x2, y2 = map(int, box.xyxy[0])
67
- conf = box.conf[0]
68
- label = f"Ball: {conf:.2f}"
69
- cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
70
- cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
71
 
72
- out.write(frame)
73
-
74
- cap.release()
75
- out.release()
76
- return output_path
77
-
78
- else:
79
- return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."
80
 
81
  # Gradio interface
82
  with gr.Blocks() as demo:
83
- gr.Markdown("# Decision Review System (DRS) for Ball Detection")
84
- gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")
85
 
86
- --
87
-
88
- input_media = gr.File(label="Upload Image or Video")
89
  conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
90
  iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
91
- output = gr.Image(label="Output (Image or Video)")
92
- submit_button = gr.Button("Detect Ball")
 
 
93
 
94
  submit_button.click(
95
- fn=detect_ball,
96
- inputs=[input_media, conf_slider, iou_slider],
97
- outputs=output
98
  )
99
 
100
  demo.launch()
 
5
  import numpy as np
6
  from PIL import Image
7
  import os
8
+ import matplotlib.pyplot as plt
9
+ from scipy.interpolate import interp1d
10
 
11
+ # Load YOLOv5 model
12
  model = YOLO("best.pt")
13
 
14
+ class CentroidTracker:
15
+ def __init__(self, max_disappeared=50):
16
+ self.next_object_id = 0
17
+ self.objects = {}
18
+ self.disappeared = {}
19
+ self.max_disappeared = max_disappeared
20
+
21
+ def register(self, centroid):
22
+ self.objects[self.next_object_id] = centroid
23
+ self.disappeared[self.next_object_id] = 0
24
+ self.next_object_id += 1
25
+
26
+ def deregister(self, object_id):
27
+ del self.objects[object_id]
28
+ del self.disappeared[object_id]
29
+
30
+ def update(self, rects):
31
+ if len(rects) == 0:
32
+ for object_id in list(self.disappeared.keys()):
33
+ self.disappeared[object_id] += 1
34
+ if self.disappeared[object_id] > self.max_disappeared:
35
+ self.deregister(object_id)
36
+ return self.objects
37
+
38
+ input_centroids = np.zeros((len(rects), 2), dtype="int")
39
+ for (i, (x1, y1, x2, y2)) in enumerate(rects):
40
+ cX = int((x1 + x2) / 2.0)
41
+ cY = int((y1 + y2) / 2.0)
42
+ input_centroids[i] = (cX, cY)
43
+
44
+ if len(self.objects) == 0:
45
+ for i in range(len(input_centroids)):
46
+ self.register(input_centroids[i])
47
+ else:
48
+ object_ids = list(self.objects.keys())
49
+ object_centroids = list(self.objects.values())
50
+ D = np.sqrt(((input_centroids[:, None] - object_centroids) ** 2).sum(axis=2))
51
+ rows = D.min(axis=1).argsort()
52
+ cols = D.argmin(axis=1)[rows]
53
+ used_rows = set()
54
+ used_cols = set()
55
+ for (row, col) in zip(rows, cols):
56
+ if row in used_rows or col in used_cols:
57
+ continue
58
+ object_id = object_ids[col]
59
+ self.objects[object_id] = input_centroids[row]
60
+ self.disappeared[object_id] = 0
61
+ used_rows.add(row)
62
+ used_cols.add(col)
63
+ unused_rows = set(range(0, D.shape[0])).difference(used_rows)
64
+ unused_cols = set(range(0, D.shape[1])).difference(used_cols)
65
+ if D.shape[0] >= D.shape[1]:
66
+ for row in unused_rows:
67
+ self.register(input_centroids[row])
68
+ else:
69
+ for col in unused_cols:
70
+ object_id = object_ids[col]
71
+ self.disappeared[object_id] += 1
72
+ if self.disappeared[object_id] > self.max_disappeared:
73
+ self.deregister(object_id)
74
+ return self.objects
75
+
76
+ def detect_and_track_ball(video_path, conf_threshold=0.5, iou_threshold=0.5):
77
  """
78
+ Detect and track ball in video, generate pitch map, and predict LBW outcome.
79
 
80
  Args:
81
+ video_path: Path to uploaded video
82
  conf_threshold: Confidence threshold for detection
83
  iou_threshold: IoU threshold for non-max suppression
84
 
85
  Returns:
86
+ Tuple of (annotated video path, pitch map image path, LBW decision)
87
  """
88
+ # Initialize tracker
89
+ tracker = CentroidTracker(max_disappeared=10)
90
+ cap = cv2.VideoCapture(video_path)
91
+ output_path = "output_video.mp4"
92
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
93
+ out = cv2.VideoWriter(output_path, fourcc, 30.0,
94
+ (int(cap.get(3)), int(cap.get(4))))
95
 
96
+ # Store ball centroids for trajectory
97
+ centroids = []
98
+ pitch_points = []
99
+
100
+ while cap.isOpened():
101
+ ret, frame = cap.read()
102
+ if not ret:
103
+ break
104
+
105
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
106
+ results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold)
107
 
108
+ rects = []
109
  for box in results[0].boxes:
110
  x1, y1, x2, y2 = map(int, box.xyxy[0])
111
  conf = box.conf[0]
112
  label = f"Ball: {conf:.2f}"
113
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
114
+ cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
115
+ rects.append((x1, y1, x2, y2))
116
 
117
+ # Update tracker
118
+ objects = tracker.update(rects)
119
+ for object_id, centroid in objects.items():
120
+ cv2.circle(frame, centroid, 5, (0, 0, 255), -1)
121
+ centroids.append(centroid)
122
+
123
+ out.write(frame)
124
+
125
+ cap.release()
126
+ out.release()
127
+
128
+ # Generate pitch map
129
+ pitch_map_path = "pitch_map.png"
130
+ fig, ax = plt.subplots(figsize=(8, 4))
131
+ ax.set_xlim(0, 22) # Pitch length in meters (approx)
132
+ ax.set_ylim(-1.5, 1.5) # Pitch width (approx)
133
+ ax.set_xlabel("Length (m)")
134
+ ax.set_ylabel("Width (m)")
135
+ ax.set_title("Pitch Map with Ball Trajectory")
136
+
137
+ # Plot stumps
138
+ ax.plot([20.12, 20.12], [-0.135, 0.135], 'k-', lw=5) # Stumps at bowling end
139
+ ax.plot([0, 0], [-0.135, 0.135], 'k-', lw=5) # Stumps at batting end
140
+ ax.plot([0, 20.12], [0, 0], 'k--') # Pitch center line
141
+
142
+ # Map centroids to pitch coordinates (simplified scaling)
143
+ if centroids:
144
+ x_coords = [20.12 - (c[1] / cap.get(4)) * 20.12 for c in centroids] # Scale y to pitch length
145
+ y_coords = [(c[0] / cap.get(3)) * 2.7 - 1.35 for c in centroids] # Scale x to pitch width
146
+ ax.plot(x_coords, y_coords, 'ro-', label="Ball Trajectory")
147
+ pitch_points = list(zip(x_coords, y_coords))
148
 
149
+ ax.legend()
150
+ plt.savefig(pitch_map_path)
151
+ plt.close()
152
+
153
+ # LBW Decision (simplified physics-based model)
154
+ lbw_decision = "Not Out"
155
+ if pitch_points:
156
+ # Check pitching, impact, and wickets
157
+ pitching = any(0 <= x <= 20.12 and -0.135 <= y <= 0.135 for x, y in pitch_points[:len(pitch_points)//2])
158
+ impact = any(18 <= x <= 20.12 for x, y in pitch_points[len(pitch_points)//2:])
159
 
160
+ # Fit a quadratic curve to predict trajectory post-impact
161
+ if len(x_coords) > 2:
162
+ t = np.linspace(0, 1, len(x_coords))
163
+ f_x = interp1d(t, x_coords, kind='quadratic', fill_value="extrapolate")
164
+ f_y = interp1d(t, y_coords, kind='quadratic', fill_value="extrapolate")
165
+ t_future = np.array([1.5]) # Predict beyond impact
166
+ x_future = f_x(t_future)[0]
167
+ y_future = f_y(t_future)[0]
168
+ wickets = (18 <= x_future <= 20.12) and (-0.135 <= y_future <= 0.135)
169
 
170
+ if pitching and impact and wickets:
171
+ lbw_decision = "Out"
172
+ elif pitching and impact:
173
+ lbw_decision = "Umpire's Call" # Marginal case
 
 
 
174
 
175
+ return output_path, pitch_map_path, lbw_decision
 
 
 
 
 
 
 
176
 
177
  # Gradio interface
178
  with gr.Blocks() as demo:
179
+ gr.Markdown("# DRS Review System for Cricket")
180
+ gr.Markdown("Upload a cricket video to analyze ball tracking, pitch mapping, and LBW review. Adjust thresholds for detection accuracy.")
181
 
182
+ video_input = gr.Video(label="Upload Cricket Video")
 
 
183
  conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
184
  iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
185
+ output_video = gr.Video(label="Annotated Video with Ball Tracking")
186
+ output_image = gr.Image(label="Pitch Map")
187
+ output_text = gr.Textbox(label="LBW Decision")
188
+ submit_button = gr.Button("Analyze DRS")
189
 
190
  submit_button.click(
191
+ fn=detect_and_track_ball,
192
+ inputs=[video_input, conf_slider, iou_slider],
193
+ outputs=[output_video, output_image, output_text]
194
  )
195
 
196
  demo.launch()