Commit d3d822f
Parent: 4dc9026
update new app.py
.DS_Store: ADDED (binary file, 6.15 kB)
app.py: CHANGED
@@ -1,38 +1,327 @@
Old version (removed lines are prefixed "-", unchanged context lines are unprefixed; removed lines whose text was lost in extraction appear as a bare "-"):

-
-
-
 import pathlib
 import numpy as np
-import torch
-import streamlit as st
 import cv2

-#
 pathlib.WindowsPath = pathlib.PosixPath

-# Load YOLOv5 model
-model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=True)
-
 st.title("YOLO Object Detection Web App")

-#
-

-
-
-image = Image.open(uploaded_file)
-st.image(image, caption="Uploaded Image", use_column_width=True)
-st.write("Processing...")

-#
-
-

-
-

-
-

-
-
New version (added lines are prefixed "+", unchanged context lines are unprefixed):

+import streamlit as st
+import torch
 import pathlib
+from PIL import Image
+import io
 import numpy as np
 import cv2
+import tempfile
+import pandas as pd

+# Adjust Path for Local Repository
 pathlib.WindowsPath = pathlib.PosixPath

 st.title("YOLO Object Detection Web App")

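The pathlib alias above is a common workaround for loading YOLOv5 checkpoints that were pickled on Windows: unpickling them on Linux otherwise fails with "cannot instantiate 'WindowsPath' on your system". A more defensive variant would guard the patch by platform; this is an editor's sketch, and the platform check is not part of the commit:

    import pathlib
    import platform

    # Hypothetical guard: only alias WindowsPath on non-Windows hosts,
    # so Windows-pickled checkpoints can still be unpickled here.
    if platform.system() != "Windows":
        pathlib.WindowsPath = pathlib.PosixPath
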
+# Define the available labels
+default_sub_classes = [
+    "container",
+    "waste-paper",
+    "plant",
+    "transportation",
+    "kitchenware",
+    "rubbish bag",
+    "chair",
+    "wood",
+    "electronics good",
+    "sofa",
+    "scrap metal",
+    "carton",
+    "bag",
+    "tarpaulin",
+    "accessory",
+    "rubble",
+    "table",
+    "board",
+    "mattress",
+    "beverage",
+    "tyre",
+    "nylon",
+    "rack",
+    "styrofoam",
+    "clothes",
+    "toy",
+    "furniture",
+    "trolley",
+    "carpet",
+    "plastic cup"
+]
+
+# Initialize session state for video processing
+if 'video_processed' not in st.session_state:
+    st.session_state.video_processed = False
+    st.session_state.output_video_path = None
+    st.session_state.detections_summary = None
+
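st.session_state survives Streamlit's top-to-bottom script reruns, so the guard above seeds the three keys exactly once per browser session instead of resetting them on every widget interaction. The same pattern in isolation, with a hypothetical counter key:

    import streamlit as st

    # Seed once; later reruns skip this branch and keep the value.
    if "clicks" not in st.session_state:
        st.session_state.clicks = 0

    if st.button("Click me"):
        st.session_state.clicks += 1

    st.write("Clicked", st.session_state.clicks, "times")
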
+# Cache the model loading to prevent repeated loads
+@st.cache_resource
+def load_model():
+    model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=False)
+    return model
+
+model = load_model()
+
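@st.cache_resource memoizes the returned object for the life of the server process, so torch.hub.load now runs once rather than on every rerun, and force_reload=False stops YOLOv5 from re-preparing the local hub repo each time. The caching behavior, sketched with a hypothetical stand-in for the heavy load:

    import time
    import streamlit as st

    @st.cache_resource
    def expensive_init():
        time.sleep(2)  # stand-in for torch.hub.load(...)
        return {"weights": "loaded"}

    model = expensive_init()  # slow on the first run, instant afterwards
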
+# Retrieve model class names
+model_class_names = model.names  # Dictionary {index: class_name}
+
+# Function to map class names to indices (case-insensitive)
+def get_class_indices(class_list):
+    indices = []
+    not_found = []
+    for cls in class_list:
+        found = False
+        for index, name in model_class_names.items():
+            if name.lower() == cls.lower():
+                indices.append(index)
+                found = True
+                break
+        if not found:
+            not_found.append(cls)
+    return indices, not_found
+
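get_class_indices resolves user-typed labels against model.names case-insensitively and collects anything it cannot resolve, which is what drives the st.warning further down. With a hypothetical three-class model:

    model_class_names = {0: "rubbish", 1: "chair", 2: "sofa"}  # stand-in for model.names

    indices, not_found = get_class_indices(["Chair", "SOFA", "piano"])
    print(indices)    # [1, 2]
    print(not_found)  # ['piano']
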
+# Function to annotate images
+def annotate_image(frame, results):
+    results.render()  # Updates results.ims with the annotated images
+    annotated_frame = results.ims[0]  # Get the first (and only) image
+    return annotated_frame
+
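results.render() draws the predicted boxes and labels onto the copies of the input stored in results.ims, so the helper simply returns the first annotated array; its shape matches the input frame. A quick check, reusing np, model, and annotate_image from above (illustrative):

    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # dummy RGB frame
    results = model(frame)
    annotated = annotate_image(frame, results)
    print(annotated.shape)  # (480, 640, 3), same layout as the input
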
+# Inform the user about the available labels
+st.markdown("### Available Classes:")
+st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
+
+# Inform the user about the default detection
+st.info("By default, the application will detect **rubbish** only.")
+
+# User input for classes, separated by commas (optional)
+custom_classes_input = st.text_input(
+    "Enter classes (comma-separated) or type 'all' to detect everything:",
+    ""
+)
+
+# Retrieve all model classes
+all_model_classes = list(model_class_names.values())
+
+# Determine classes to use based on user input
+if custom_classes_input.strip() == "":
+    # No input provided; use only 'rubbish'
+    selected_classes = ['rubbish']
+    st.info("No classes entered. Using default class: **rubbish**.")
+elif custom_classes_input.strip().lower() == "all":
+    # User chose to detect all classes
+    selected_classes = all_model_classes
+    st.info("Detecting **all** available classes.")
+else:
+    # User provided specific classes
+    # Split the input string into a list of classes and remove any extra whitespace
+    input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
+    # Ensure 'rubbish' is included
+    if 'rubbish' not in [cls.lower() for cls in input_classes]:
+        selected_classes = input_classes + ['rubbish']
+        st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)")
+    else:
+        selected_classes = input_classes
+        st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")
+
+# Map selected class names to their indices
+selected_class_indices, not_found_classes = get_class_indices(selected_classes)
+
+if not_found_classes:
+    st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")
+
+# Proceed only if there are valid classes to detect
+if selected_class_indices:
+    # Set the classes for the model
+    model.classes = selected_class_indices
+
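Assigning model.classes tells YOLOv5's AutoShape wrapper to keep only those indices during non-max suppression, so filtering happens inside inference itself; the DataFrame filter in the image section below is a second pass over the same constraint. Illustrative use with made-up indices:

    model.classes = [0, 5]        # NMS now returns only these class indices
    results = model("image.jpg")  # hypothetical input
    model.classes = None          # reset: detect every class again
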
+    # --------------------- Image Upload and Processing ---------------------
+    st.header("Image Object Detection")
+
+    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")
+
+    if uploaded_file is not None:
+        try:
+            # Convert the file to a PIL image
+            image = Image.open(uploaded_file).convert('RGB')
+            st.image(image, caption="Uploaded Image", use_column_width=True)
+            st.write("Processing...")
+
+            # Perform inference
+            results = model(image)
+
+            # Extract DataFrame from results
+            results_df = results.pandas().xyxy[0]
+
+            # Filter results to include only selected classes
+            filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]
+
+            if filtered_results.empty:
+                st.warning("No objects detected for the selected classes.")
+            else:
+                # Display filtered results
+                st.write("### Detection Results")
+                st.dataframe(filtered_results)
+
+                # Annotate the image
+                annotated_image = annotate_image(np.array(image), results)
+
+                # Convert annotated image back to PIL format
+                annotated_pil = Image.fromarray(annotated_image)
+
+                # Display annotated image
+                st.image(annotated_pil, caption="Annotated Image", use_column_width=True)
+
+                # Convert annotated image to bytes
+                img_byte_arr = io.BytesIO()
+                annotated_pil.save(img_byte_arr, format='PNG')
+                img_byte_arr = img_byte_arr.getvalue()
+
+                # Add download button
+                st.download_button(
+                    label="Download Annotated Image",
+                    data=img_byte_arr,
+                    file_name='annotated_image.png',
+                    mime='image/png'
+                )
+        except Exception as e:
+            st.error(f"An error occurred during image processing: {e}")
+
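results.pandas().xyxy[0] flattens the detections for the first image into a DataFrame with one row per box and the columns xmin, ymin, xmax, ymax, confidence, class, name; the name column is what the case-insensitive filter above matches on. Reusing model and image from the app:

    results = model(image)
    df = results.pandas().xyxy[0]
    print(list(df.columns))
    # ['xmin', 'ymin', 'xmax', 'ymax', 'confidence', 'class', 'name']
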
+    # --------------------- Video Upload and Processing ---------------------
+    st.header("Video Object Detection")
+
+    uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
+
+    if uploaded_video is not None:
+        # Track the upload: remember a first upload, and reset state when a
+        # different video replaces the previously processed one
+        if st.session_state.get("uploaded_video_name") is None:
+            st.session_state.uploaded_video_name = uploaded_video.name
+            print("First time uploaded video: " + st.session_state.uploaded_video_name)
+        elif st.session_state.uploaded_video_name != uploaded_video.name:
+            st.session_state.uploaded_video_name = uploaded_video.name
+            print("Another uploaded video: " + st.session_state.uploaded_video_name)
+            st.session_state.video_processed = False
+            st.session_state.output_video_path = None
+            st.session_state.detections_summary = None
+            print("New uploaded video")
+
+    # Reset session state if video upload is removed
+    if uploaded_video is None and st.session_state.video_processed:
+        st.session_state.video_processed = False
+        st.session_state.output_video_path = None
+        st.session_state.detections_summary = None
+        st.warning("Video upload has been cleared. You can upload a new video for processing.")
+
+    if uploaded_video:
+        if not st.session_state.video_processed:
+            try:
+                with st.spinner("Processing video..."):
+                    # Save uploaded video to a temporary file
+                    tfile = tempfile.NamedTemporaryFile(delete=False)
+                    tfile.write(uploaded_video.read())
+                    tfile.close()
+
+                    # Open the video file
+                    video_cap = cv2.VideoCapture(tfile.name)
+                    stframe = st.empty()  # Placeholder for displaying video frames
+
+                    # Initialize VideoWriter for saving the output video
+                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                    fps = video_cap.get(cv2.CAP_PROP_FPS)
+                    width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                    height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                    output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
+                    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
+
+                    frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+                    progress_bar = st.progress(0)
+
+                    # Initialize list to collect all detections
+                    all_detections = []
+
+                    for frame_num in range(frame_count):
+                        ret, frame = video_cap.read()  # Read a frame from the video
+                        if not ret:
+                            break
+
+                        # Convert frame to RGB
+                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+                        # Perform inference
+                        results = model(frame_rgb)
+
+                        # Extract DataFrame from results
+                        results_df = results.pandas().xyxy[0]
+                        results_df['frame_num'] = frame_num  # Optional: Add frame number for reference
+
+                        # Append detections to the list
+                        if not results_df.empty:
+                            all_detections.append(results_df)
+
+                        # Annotate the frame with detections
+                        annotated_frame = annotate_image(frame_rgb, results)
+
+                        # Convert annotated frame back to BGR for VideoWriter
+                        annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
+
+                        # Write the annotated frame to the output video
+                        out.write(annotated_bgr)
+
+                        # Display the annotated frame in Streamlit
+                        stframe.image(annotated_frame, channels="RGB", use_column_width=True)
+
+                        # Update progress bar
+                        progress_percent = (frame_num + 1) / frame_count
+                        progress_bar.progress(progress_percent)
+
+                    video_cap.release()  # Release the video capture object
+                    out.release()  # Release the VideoWriter object
+
+                    # Save processed video path and detections summary to session state
+                    st.session_state.output_video_path = output_video_path
+
+                    if all_detections:
+                        # Concatenate all detections into a single DataFrame
+                        detections_df = pd.concat(all_detections, ignore_index=True)
+
+                        # Optional: Group by class name and count detections
+                        detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
+                        st.session_state.detections_summary = detections_summary
+                    else:
+                        st.session_state.detections_summary = None
+
+                    # Mark video as processed
+                    st.session_state.video_processed = True
+
+                    # st.session_state.uploaded_video_name = uploaded_video.name
+
+                    st.success("Video processing complete!")

+            except Exception as e:
+                st.error(f"An error occurred during video processing: {e}")
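The writer mirrors the input's frame rate and size, and 'mp4v' is the usual OpenCV FourCC for MPEG-4 video in an .mp4 container; frames handed to out.write must be BGR and exactly the declared size, or the writer can silently produce an unplayable file. The read-annotate-write loop reduced to its skeleton, with hypothetical file names:

    import cv2

    cap = cv2.VideoCapture("input.mp4")
    fps = cap.get(cv2.CAP_PROP_FPS)
    size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
            int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
    out = cv2.VideoWriter("output.mp4", cv2.VideoWriter_fourcc(*'mp4v'), fps, size)

    while True:
        ok, frame = cap.read()
        if not ok:
            break
        out.write(frame)  # frame: BGR ndarray matching `size`

    cap.release()
    out.release()
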

+    # Display download button and detection summary if processed
+    if st.session_state.video_processed:
+        try:
+            # Create a download button for the annotated video
+            with open(st.session_state.output_video_path, "rb") as video_file:
+                st.download_button(
+                    label="Download Annotated Video",
+                    data=video_file,
+                    file_name="annotated_video.mp4",
+                    mime="video/mp4"
+                )

+            # Display detection table if there are detections
+            if st.session_state.detections_summary is not None:
+                detections_summary = st.session_state.detections_summary

+                st.write("### Detection Summary")
+                st.dataframe(detections_summary)
+            else:
+                st.warning("No objects detected in the video for the selected classes.")
+        except Exception as e:
+            st.error(f"An error occurred while preparing the download: {e}")
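The summary table is plain pandas: the per-frame DataFrames are concatenated and grouped by the name column, so the counts aggregate over every frame of the video. In isolation:

    import pandas as pd

    frames = [pd.DataFrame({"name": ["rubbish", "chair"]}),
              pd.DataFrame({"name": ["rubbish"]})]
    detections_df = pd.concat(frames, ignore_index=True)
    print(detections_df.groupby('name').size().reset_index(name='counts'))
    #       name  counts
    # 0    chair       1
    # 1  rubbish       2
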

+# Optionally, display all available classes when 'all' is selected
+if custom_classes_input.strip().lower() == "all":
+    st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")
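With the yolov5 repository cloned next to app.py and the custom weights at ./yolo/best.pt (both paths come straight from the diff), the app starts the standard Streamlit way: streamlit run app.py.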