AjaykumarPilla committed
Commit a295d73 · verified · 1 Parent(s): bd41d80

Update app.py

Files changed (1):
  1. app.py +120 -197
app.py CHANGED
@@ -1,201 +1,124 @@
 import cv2
 import numpy as np
-import pandas as pd
-import plotly.express as px
-import plotly.graph_objects as go
-import torch
-import gradio as gr
 import os
-from scipy.optimize import curve_fit
-import sys
-
-# Add yolov5 directory to sys.path
-sys.path.append(os.path.join(os.path.dirname(__file__), "yolov5"))
-
-# Import YOLOv5 modules
-from models.experimental import attempt_load
-from utils.general import non_max_suppression, xywh2xyxy
-
-# Cricket pitch dimensions (in meters)
-PITCH_LENGTH = 20.12  # Length of cricket pitch (stumps to stumps)
-PITCH_WIDTH = 3.05  # Width of pitch
-STUMP_HEIGHT = 0.71  # Stump height
-STUMP_WIDTH = 0.2286  # Stump width (including bails)
-
-# Load model
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-model = attempt_load("best.pt")  # Load without map_location
-model.to(device).eval()  # Move model to device and set to evaluation mode
-
-# Function to process video and detect ball
-def process_video(video_path):
-    cap = cv2.VideoCapture(video_path)
-    frame_rate = cap.get(cv2.CAP_PROP_FPS)
-    frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-    frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-    positions = []
-    frame_numbers = []
-    bounce_frame = None
-    bounce_point = None
-
-    while cap.isOpened():
-        frame_num = int(cap.get(cv2.CAP_PROP_POS_FRAMES))
-        ret, frame = cap.read()
-        if not ret:
-            break
-
-        # Preprocess frame for YOLOv5
-        img = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        img = torch.from_numpy(img).to(device).float() / 255.0
-        img = img.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]
-
         # Run inference
-        with torch.no_grad():
-            pred = model(img)[0]
-        pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45)
-
-        # Process detections
-        for det in pred:
-            if det is not None and len(det):
-                det = xywh2xyxy(det)  # Convert to [x1, y1, x2, y2]
-                for *xyxy, conf, cls in det:
-                    x_center = (xyxy[0] + xyxy[2]) / 2
-                    y_center = (xyxy[1] + xyxy[3]) / 2
-                    positions.append((x_center.item(), y_center.item()))
-                    frame_numbers.append(frame_num)
-
-                    # Detect bounce (lowest y_center point)
-                    if bounce_frame is None or y_center > positions[bounce_frame][1]:
-                        bounce_frame = len(frame_numbers) - 1
-                        bounce_point = (x_center.item(), y_center.item())
-
-    cap.release()
-    return positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height
-
-# Polynomial function for trajectory fitting
-def poly_func(x, a, b, c):
-    return a * x**2 + b * x + c
-
-# Predict trajectory and LBW decision
-def predict_trajectory(positions, frame_numbers, frame_width, frame_height):
-    if len(positions) < 3:
-        return None, "Insufficient detections for trajectory prediction"
-
-    x_coords = [p[0] for p in positions]
-    y_coords = [p[1] for p in positions]
-    frames = np.array(frame_numbers)
-
-    # Fit polynomial to x and y coordinates
-    try:
-        popt_x, _ = curve_fit(poly_func, frames, x_coords)
-        popt_y, _ = curve_fit(poly_func, frames, y_coords)
-    except:
-        return None, "Failed to fit trajectory"
-
-    # Extrapolate to stumps
-    frame_max = max(frames) + 10
-    future_frames = np.linspace(min(frames), frame_max, 100)
-    x_pred = poly_func(future_frames, *popt_x)
-    y_pred = poly_func(future_frames, *popt_y)
-
-    # Check if trajectory hits stumps
-    stump_x = frame_width / 2
-    stump_y = frame_height
-    stump_hit = False
-    for x, y in zip(x_pred, y_pred):
-        if abs(y - stump_y) < 50 and abs(x - stump_x) < STUMP_WIDTH * frame_width / PITCH_WIDTH:
-            stump_hit = True
-            break
-
-    lbw_decision = "OUT" if stump_hit else "NOT OUT"
-    return list(zip(future_frames, x_pred, y_pred)), lbw_decision
-
-# Map pitch location
-def map_pitch(bounce_point, frame_width, frame_height):
-    if bounce_point is None:
-        return None, "No bounce detected"
-
-    x, y = bounce_point
-    pitch_x = (x / frame_width) * PITCH_WIDTH - PITCH_WIDTH / 2
-    pitch_y = (1 - y / frame_height) * PITCH_LENGTH
-    return pitch_x, pitch_y
-
-# Estimate ball speed
-def estimate_speed(positions, frame_numbers, frame_rate, frame_width):
-    if len(positions) < 2:
-        return None, "Insufficient detections for speed estimation"
-
-    distances = []
-    for i in range(1, len(positions)):
-        x1, y1 = positions[i-1]
-        x2, y2 = positions[i]
-        pixel_dist = np.sqrt((x2 - x1)**2 + (y2 - y1)**2)
-        distances.append(pixel_dist)
-
-    pixel_to_meter = PITCH_LENGTH / frame_width
-    distances_m = [d * pixel_to_meter for d in distances]
-    time_interval = 1 / frame_rate
-    speeds = [d / time_interval for d in distances_m]
-    avg_speed_kmh = np.mean(speeds) * 3.6
-    return avg_speed_kmh, "Speed calculated successfully"
-
-# Create pitch map visualization
-def create_pitch_map(pitch_x, pitch_y):
-    fig = go.Figure()
-    fig.add_shape(
-        type="rect", x0=-PITCH_WIDTH/2, y0=0, x1=PITCH_WIDTH/2, y1=PITCH_LENGTH,
-        line=dict(color="Green"), fillcolor="Green", opacity=0.3
-    )
-    fig.add_shape(
-        type="rect", x0=-STUMP_WIDTH/2, y0=PITCH_LENGTH-0.1, x1=STUMP_WIDTH/2, y1=PITCH_LENGTH,
-        line=dict(color="Brown"), fillcolor="Brown"
-    )
-    if pitch_x is not None and pitch_y is not None:
-        fig.add_trace(go.Scatter(x=[pitch_x], y=[pitch_y], mode="markers", marker=dict(size=10, color="Red"), name="Bounce Point"))
-
-    fig.update_layout(
-        title="Pitch Map", xaxis_title="Width (m)", yaxis_title="Length (m)",
-        xaxis_range=[-PITCH_WIDTH/2, PITCH_WIDTH/2], yaxis_range=[0, PITCH_LENGTH]
-    )
-    return fig
-
-# Main Gradio function
-def drs_analysis(video):
-    video_path = "temp_video.mp4"
-    with open(video_path, "wb") as f:
-        f.write(video.read())
-
-    positions, frame_numbers, bounce_point, frame_rate, frame_width, frame_height = process_video(video_path)
-    if not positions:
-        return None, None, "No ball detected in video", None
-
-    trajectory, lbw_decision = predict_trajectory(positions, frame_numbers, frame_width, frame_height)
-    if trajectory is None:
-        return None, None, lbw_decision, None
-
-    pitch_x, pitch_y = map_pitch(bounce_point, frame_width, frame_height)
-    speed_kmh, speed_status = estimate_speed(positions, frame_numbers, frame_rate, frame_width)
-
-    trajectory_df = pd.DataFrame(trajectory, columns=["Frame", "X", "Y"])
-    fig_traj = px.line(trajectory_df, x="X", y="Y", title="Ball Trajectory (Pixel Coordinates)")
-    fig_traj.update_yaxes(autorange="reversed")
-
-    fig_pitch = create_pitch_map(pitch_x, pitch_y)
-
-    os.remove(video_path)
-
-    return fig_traj, fig_pitch, f"LBW Decision: {lbw_decision}\nSpeed: {speed_kmh:.2f} km/h", video_path
-
-# Gradio interface
-with gr.Blocks() as demo:
-    gr.Markdown("## Cricket DRS Analysis")
-    video_input = gr.Video(label="Upload Video Clip")
-    btn = gr.Button("Analyze")
-    trajectory_output = gr.Plot(label="Ball Trajectory")
-    pitch_output = gr.Plot(label="Pitch Map")
-    text_output = gr.Textbox(label="Analysis Results")
-    video_output = gr.Video(label="Processed Video")
-    btn.click(drs_analysis, inputs=video_input, outputs=[trajectory_output, pitch_output, text_output, video_output])
-
-if __name__ == "__main__":
-    demo.launch()
 
+import streamlit as st
 import cv2
 import numpy as np
+from ultralytics import YOLO
+from huggingface_hub import hf_hub_download
+from PIL import Image
 import os
+import tempfile
+import supervision as sv
+
+# Title and description
+st.title("DRS Review System - Ball Detection")
+st.write("Upload an image or video to detect balls using a YOLOv5 model for Decision Review System (DRS).")
+
+# Model loading
+@st.cache_resource
+def load_model():
+    # Replace 'your-username/your-repo' with your Hugging Face repository and model file
+    model_path = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
+    model = YOLO(model_path)
+    return model
+
+model = load_model()
+
+# Confidence threshold slider
+confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.7, 0.05)
+
+# File uploader for image or video
+uploaded_file = st.file_uploader("Upload an image or video", type=["jpg", "jpeg", "png", "mp4"])
+
+if uploaded_file is not None:
+    # Create a temporary file to save the uploaded content
+    tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.' + uploaded_file.name.split('.')[-1])
+    tfile.write(uploaded_file.read())
+    tfile.close()
+    file_path = tfile.name
+
+    # Check if the uploaded file is an image
+    if uploaded_file.type in ["image/jpeg", "image/png"]:
+        st.subheader("Image Detection Results")
+        image = cv2.imread(file_path)
+        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
         # Run inference
+        results = model(image, conf=confidence_threshold)
+        detections = sv.Detections.from_ultralytics(results[0])
+
+        # Annotate image
+        box_annotator = sv.BoxAnnotator()
+        annotated_image = box_annotator.annotate(scene=image_rgb, detections=detections)
+
+        # Display result
+        st.image(annotated_image, caption="Detected Balls", use_column_width=True)
+
+        # Display detection details
+        for score, label, box in zip(detections.confidence, detections.class_id, detections.xyxy):
+            st.write(f"Detected ball with confidence {score:.2f} at coordinates {box.tolist()}")
+
+    # Check if the uploaded file is a video
+    elif uploaded_file.type == "video/mp4":
+        st.subheader("Video Detection Results")
+        output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
+
+        # Process video
+        cap = cv2.VideoCapture(file_path)
+        if not cap.isOpened():
+            st.error("Error: Could not open video file.")
+        else:
+            # Get video properties
+            width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+            height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+            fps = int(cap.get(cv2.CAP_PROP_FPS))
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+            # Progress bar
+            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+            progress = st.progress(0)
+            frame_count = 0
+
+            # Process frames
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                # Run inference on frame
+                results = model(frame, conf=confidence_threshold)
+                detections = sv.Detections.from_ultralytics(results[0])
+
+                # Annotate frame
+                box_annotator = sv.BoxAnnotator()
+                annotated_frame = box_annotator.annotate(scene=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), detections=detections)
+                annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
+
+                # Write to output video
+                out.write(annotated_frame_bgr)
+
+                # Update progress
+                frame_count += 1
+                progress.progress(frame_count / total_frames)
+
+            cap.release()
+            out.release()
+
+            # Display video
+            st.video(output_path)
+
+            # Provide download link for processed video
+            with open(output_path, "rb") as file:
+                st.download_button(
+                    label="Download Processed Video",
+                    data=file,
+                    file_name="processed_drs_video.mp4",
+                    mime="video/mp4"
+                )
+
+    # Clean up temporary files (output_path exists only for video uploads)
+    os.remove(file_path)
+    if uploaded_file.type == "video/mp4" and os.path.exists(output_path):
+        os.remove(output_path)
+
+else:
+    st.info("Please upload an image or video to start the DRS review.")
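
This commit swaps the Gradio/yolov5-repo pipeline (attempt_load, non_max_suppression, trajectory fitting, pitch map) for a Streamlit app that downloads best.pt from the Hugging Face Hub and delegates detection and box drawing to ultralytics and supervision. The detection path can be sanity-checked outside Streamlit; below is a minimal sketch, not part of the commit, that keeps the commit's "your-username/your-repo" placeholder as-is and uses a hypothetical test image "sample_frame.jpg".

# standalone sketch of the new detection path (placeholder repo_id and file names)
import cv2
import supervision as sv
from huggingface_hub import hf_hub_download
from ultralytics import YOLO

# Download the weights from the Hub; 'your-username/your-repo' is the commit's placeholder.
weights = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
model = YOLO(weights)

image = cv2.imread("sample_frame.jpg")  # hypothetical test frame, BGR as in the app
results = model(image, conf=0.7)        # 0.7 mirrors the slider's default threshold
detections = sv.Detections.from_ultralytics(results[0])

# Draw boxes the same way the app does, then save the annotated frame.
annotated = sv.BoxAnnotator().annotate(scene=image.copy(), detections=detections)
cv2.imwrite("annotated_frame.jpg", annotated)
print(f"{len(detections)} detection(s) above threshold")

With streamlit, ultralytics, supervision, huggingface_hub, and opencv-python installed, the app itself should start with `streamlit run app.py`.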