AjaykumarPilla commited on
Commit
61be320
·
verified ·
1 Parent(s): a295d73

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -109
app.py CHANGED
@@ -1,124 +1,100 @@
1
- import streamlit as st
 
 
2
  import cv2
3
  import numpy as np
4
- from ultralytics import YOLO
5
- from huggingface_hub import hf_hub_download
6
  from PIL import Image
7
  import os
8
- import tempfile
9
- import supervision as sv
10
-
11
- # Title and description
12
- st.title("DRS Review System - Ball Detection")
13
- st.write("Upload an image or video to detect balls using a YOLOv5 model for Decision Review System (DRS).")
14
-
15
- # Model loading
16
- @st.cache_resource
17
- def load_model():
18
- # Replace 'your-username/your-repo' with your Hugging Face repository and model file
19
- model_path = hf_hub_download(repo_id="your-username/your-repo", filename="best.pt")
20
- model = YOLO(model_path)
21
- return model
22
-
23
- model = load_model()
24
 
25
- # Confidence threshold slider
26
- confidence_threshold = st.slider("Confidence Threshold", 0.0, 1.0, 0.7, 0.05)
27
 
28
- # File uploader for image or video
29
- uploaded_file = st.file_uploader("Upload an image or video", type=["jpg", "jpeg", "png", "mp4"])
30
-
31
- if uploaded_file is not None:
32
- # Create a temporary file to save the uploaded content
33
- tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.' + uploaded_file.name.split('.')[-1])
34
- tfile.write(uploaded_file.read())
35
- tfile.close()
36
- file_path = tfile.name
37
-
38
- # Check if the uploaded file is an image
39
- if uploaded_file.type in ["image/jpeg", "image/png"]:
40
- st.subheader("Image Detection Results")
41
- image = cv2.imread(file_path)
42
- image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
43
-
44
- # Run inference
45
- results = model(image, conf=confidence_threshold)
46
- detections = sv.Detections.from_ultralytics(results[0])
47
 
48
- # Annotate image
49
- box_annotator = sv.BoxAnnotator()
50
- annotated_image = box_annotator.annotate(scene=image_rgb, detections=detections)
51
 
52
- # Display result
53
- st.image(annotated_image, caption="Detected Balls", use_column_width=True)
54
-
55
- # Display detection details
56
- for score, label, box in zip(detections.confidence, detections.class_id, detections.xyxy):
57
- st.write(f"Detected ball with confidence {score:.2f} at coordinates {box.tolist()}")
58
-
59
- # Check if the uploaded file is a video
60
- elif uploaded_file.type == "video/mp4":
61
- st.subheader("Video Detection Results")
62
- output_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
63
 
 
 
 
 
 
64
  # Process video
65
- cap = cv2.VideoCapture(file_path)
66
- if not cap.isOpened():
67
- st.error("Error: Could not open video file.")
68
- else:
69
- # Get video properties
70
- width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
71
- height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
72
- fps = int(cap.get(cv2.CAP_PROP_FPS))
73
- fourcc = cv2.VideoWriter_fourcc(*'mp4v')
74
- out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
75
-
76
- # Progress bar
77
- total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
78
- progress = st.progress(0)
79
- frame_count = 0
80
-
81
- # Process frames
82
- while cap.isOpened():
83
- ret, frame = cap.read()
84
- if not ret:
85
- break
86
-
87
- # Run inference on frame
88
- results = model(frame, conf=confidence_threshold)
89
- detections = sv.Detections.from_ultralytics(results[0])
90
-
91
- # Annotate frame
92
- box_annotator = sv.BoxAnnotator()
93
- annotated_frame = box_annotator.annotate(scene=cv2.cvtColor(frame, cv2.COLOR_BGR2RGB), detections=detections)
94
- annotated_frame_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
95
-
96
- # Write to output video
97
- out.write(annotated_frame_bgr)
98
 
99
- # Update progress
100
- frame_count += 1
101
- progress.progress(frame_count / total_frames)
102
-
103
- cap.release()
104
- out.release()
105
 
106
- # Display video
107
- st.video(output_path)
 
 
 
 
 
108
 
109
- # Provide download link for processed video
110
- with open(output_path, "rb") as file:
111
- st.download_button(
112
- label="Download Processed Video",
113
- data=file,
114
- file_name="processed_drs_video.mp4",
115
- mime="video/mp4"
116
- )
117
 
118
- # Clean up temporary files
119
- os.remove(file_path)
120
- if os.path.exists(output_path):
121
- os.remove(output_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
 
123
- else:
124
- st.info("Please upload an image or video to start the DRS review.")
 
1
+ import gradio as gr
2
+ import torch
3
+ from ultralytics import YOLO
4
  import cv2
5
  import numpy as np
 
 
6
  from PIL import Image
7
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # Load the YOLOv5 model
10
+ model = YOLO("best.pt")
11
 
12
+ def detect_ball(input_media, conf_threshold=0.5, iou_threshold=0.5):
13
+ """
14
+ Perform ball detection on image or video input.
15
+
16
+ Args:
17
+ input_media: Uploaded image or video file
18
+ conf_threshold: Confidence threshold for detection
19
+ iou_threshold: IoU threshold for non-max suppression
20
+
21
+ Returns:
22
+ Annotated image or video path
23
+ """
24
+ # Check if input is image or video based on file extension
25
+ file_extension = os.path.splitext(input_media)[1].lower()
26
+
27
+ if file_extension in ['.jpg', '.jpeg', '.png']:
28
+ # Process image
29
+ img = cv2.imread(input_media)
30
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
31
 
32
+ # Perform detection
33
+ results = model.predict(img, conf=conf_threshold, iou=iou_threshold)
 
34
 
35
+ # Draw bounding boxes
36
+ for box in results[0].boxes:
37
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
38
+ conf = box.conf[0]
39
+ label = f"Ball: {conf:.2f}"
40
+ cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)
41
+ cv2.putText(img, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
 
 
 
 
42
 
43
+ # Convert to PIL Image for Gradio output
44
+ output_img = Image.fromarray(img)
45
+ return output_img
46
+
47
+ elif file_extension in ['.mp4', '.avi', '.mov']:
48
  # Process video
49
+ cap = cv2.VideoCapture(input_media)
50
+ output_path = "output_video.mp4"
51
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
52
+ out = cv2.VideoWriter(output_path, fourcc, 30.0,
53
+ (int(cap.get(3)), int(cap.get(4))))
54
+
55
+ while cap.isOpened():
56
+ ret, frame = cap.read()
57
+ if not ret:
58
+ break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
+ # Perform detection
61
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
62
+ results = model.predict(frame_rgb, conf=conf_threshold, iou=iou_threshold)
 
 
 
63
 
64
+ # Draw bounding boxes
65
+ for box in results[0].boxes:
66
+ x1, y1, x2, y2 = map(int, box.xyxy[0])
67
+ conf = box.conf[0]
68
+ label = f"Ball: {conf:.2f}"
69
+ cv2.rectangle(frame, (x1, y1), (x2, y2), (0, 255, 0), 2)
70
+ cv2.putText(frame, label, (x1, y1 - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
71
 
72
+ out.write(frame)
 
 
 
 
 
 
 
73
 
74
+ cap.release()
75
+ out.release()
76
+ return output_path
77
+
78
+ else:
79
+ return "Unsupported file format. Please upload an image (.jpg, .png) or video (.mp4, .avi, .mov)."
80
+
81
+ # Gradio interface
82
+ with gr.Blocks() as demo:
83
+ gr.Markdown("# Decision Review System (DRS) for Ball Detection")
84
+ gr.Markdown("Upload an image or video to detect the ball using a trained YOLOv5 model. Adjust confidence and IoU thresholds for detection.")
85
+
86
+ --
87
+
88
+ input_media = gr.File(label="Upload Image or Video")
89
+ conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="Confidence Threshold")
90
+ iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=0.5, label="IoU Threshold")
91
+ output = gr.Image(label="Output (Image or Video)")
92
+ submit_button = gr.Button("Detect Ball")
93
+
94
+ submit_button.click(
95
+ fn=detect_ball,
96
+ inputs=[input_media, conf_slider, iou_slider],
97
+ outputs=output
98
+ )
99
 
100
+ demo.launch()