Update app.py
Browse files
app.py
CHANGED
@@ -1,124 +1,100 @@
|
|
1 |
-
import
|
|
|
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
-
from ultralytics import YOLO
|
5 |
-
from huggingface_hub import hf_hub_download
|
6 |
from PIL import Image
|
7 |
import os
|
8 |
-
import tempfile
|
9 |
-
import supervision as sv
|
10 |
-
|
11 |
-
# Title and description
|
12 |
-
st.title("DRS Review System - Ball Detection")
|
13 |
-
st.write("Upload an image or video to detect balls using a YOLOv5 model for Decision Review System (DRS).")
|
14 |
-
|
15 |
-
# Model loading
|
16 |
-
@st.cache_resource
|
17 |
-
def load_model():
|
18 |
-
# Replace 'your-username/your-repo' with your Hugging Face repository and model file
|
19 |
-
model_path = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
|
20 |
-
model = YOLO(model_path)
|
21 |
-
return model
|
22 |
-
|
23 |
-
model = load_model()
|
24 |
|
25 |
-
#
|
26 |
-
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
#
|
45 |
-
|
46 |
-
|
47 |
|
48 |
-
#
|
49 |
-
|
50 |
-
annotated_image = box_annotator.annotate(scene=image_rgb, detections=detections)
|
51 |
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
# Check if the uploaded file is a video
|
60 |
-
elif uploaded_file.type == "video/mp4":
|
61 |
-
st.subheader("Video Detection Results")
|
62 |
-
output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
|
63 |
|
|
|
|
|
|
|
|
|
|
|
64 |
# Process video
|
65 |
-
cap = cv2.VideoCapture(
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
# Progress bar
|
77 |
-
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
78 |
-
progress = st.progress(0)
|
79 |
-
frame_count = 0
|
80 |
-
|
81 |
-
# Process frames
|
82 |
-
while cap.isOpened():
|
83 |
-
ret, frame = cap.read()
|
84 |
-
if not ret:
|
85 |
-
break
|
86 |
-
|
87 |
-
# Run inference on frame
|
88 |
-
results = model(frame, conf=confidence_threshold)
|
89 |
-
detections = sv.Detections.from_ultralytics(results[0])
|
90 |
-
|
91 |
-
# Annotate frame
|
92 |
-
box_annotator = sv.BoxAnnotator()
|
93 |
-
annotated_frame = box_annotator.annotate(scene=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), detections=detections)
|
94 |
-
annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
|
95 |
-
|
96 |
-
# Write to output video
|
97 |
-
out.write(annotated_frame_bgr)
|
98 |
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
cap.release()
|
104 |
-
out.release()
|
105 |
|
106 |
-
#
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
with open(output_path, "rb") as file:
|
111 |
-
st.download_button(
|
112 |
-
label="Download Processed Video",
|
113 |
-
data=file,
|
114 |
-
file_name="processed_drs_video.mp4",
|
115 |
-
mime="video/mp4"
|
116 |
-
)
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
|
123 |
-
|
124 |
-
st.info("Please upload an image or video to start the DRS review.")
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from ultralytics import YOLO
|
4 |
import cv2
|
5 |
import numpy as np
|
|
|
|
|
6 |
from PIL import Image
|
7 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
+
# Load the YOLOv5 model
|
10 |
+
model = YOLO("best.pt")
|
11 |
|
12 |
+
def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5):
|
13 |
+
"""
|
14 |
+
Perform ball detection on image or video input.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
input_media: Uploaded image or video file
|
18 |
+
conf_threshold: Confidence threshold for detection
|
19 |
+
iou_threshold: IoU threshold for non-max suppression
|
20 |
+
|
21 |
+
Returns:
|
22 |
+
Annotated image or video path
|
23 |
+
"""
|
24 |
+
# Check if input is image or video based on file extension
|
25 |
+
file_extension = os.path.splitext(input_media)[1].lower()
|
26 |
+
|
27 |
+
if file_extension in ['.jpg', '.jpeg', '.png']:
|
28 |
+
# Process image
|
29 |
+
img = cv2.imread(input_media)
|
30 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
31 |
|
32 |
+
# Perform detection
|
33 |
+
results = model.predict(img, conf=conf_threshold, iou=iou_threshold)
|
|
|
34 |
|
35 |
+
# Draw bounding boxes
|
36 |
+
for box in results[0].boxes:
|
37 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
38 |
+
conf = box.conf[0]
|
39 |
+
label = f"Ball: {conf:.2f}"
|
40 |
+
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
41 |
+
cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
|
|
|
|
|
|
|
|
42 |
|
43 |
+
# Convert to PIL Image for Gradio output
|
44 |
+
output_img = Image.fromarray(img)
|
45 |
+
return output_img
|
46 |
+
|
47 |
+
elif file_extension in ['.mp4', '.avi', '.mov']:
|
48 |
# Process video
|
49 |
+
cap = cv2.VideoCapture(input_media)
|
50 |
+
output_path = "output_video.mp4"
|
51 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
52 |
+
out = cv2.VideoWriter(output_path, fourcc, 30.0,
|
53 |
+
(int(cap.get(3)), int(cap.get(4))))
|
54 |
+
|
55 |
+
while cap.isOpened():
|
56 |
+
ret, frame = cap.read()
|
57 |
+
if not ret:
|
58 |
+
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
+
# Perform detection
|
61 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
62 |
+
results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold)
|
|
|
|
|
|
|
63 |
|
64 |
+
# Draw bounding boxes
|
65 |
+
for box in results[0].boxes:
|
66 |
+
x1, y1, x2, y2 = map(int, box.xyxy[0])
|
67 |
+
conf = box.conf[0]
|
68 |
+
label = f"Ball: {conf:.2f}"
|
69 |
+
cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
|
70 |
+
cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
|
71 |
|
72 |
+
out.write(frame)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
|
74 |
+
cap.release()
|
75 |
+
out.release()
|
76 |
+
return output_path
|
77 |
+
|
78 |
+
else:
|
79 |
+
return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."
|
80 |
+
|
81 |
+
# Gradio interface
|
82 |
+
with gr.Blocks() as demo:
|
83 |
+
gr.Markdown("# Decision Review System (DRS) for Ball Detection")
|
84 |
+
gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")
|
85 |
+
|
86 |
+
--
|
87 |
+
|
88 |
+
input_media = gr.File(label="Upload Image or Video")
|
89 |
+
conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
|
90 |
+
iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
|
91 |
+
output = gr.Image(label="Output (Image or Video)")
|
92 |
+
submit_button = gr.Button("Detect Ball")
|
93 |
+
|
94 |
+
submit_button.click(
|
95 |
+
fn=detect_ball,
|
96 |
+
inputs=[input_media, conf_slider, iou_slider],
|
97 |
+
outputs=output
|
98 |
+
)
|
99 |
|
100 |
+
demo.launch()
|
|