phonghaitran commited on
Commit
d3d822f
·
1 Parent(s): 4dc9026

update new app.py

Browse files
Files changed (2) hide show
  1. .DS_Store +0 -0
  2. app.py +314 -25
.DS_Store ADDED
Binary file (6.15 kB). View file
 
app.py CHANGED
@@ -1,38 +1,327 @@
1
- from pathlib import Path
2
- from PIL import Image
3
-
4
  import pathlib
 
 
5
  import numpy as np
6
- import torch
7
- import streamlit as st
8
  import cv2
 
 
9
 
10
- #If you have linux (or deploying for linux) use:
11
  pathlib.WindowsPath = pathlib.PosixPath
12
 
13
- # Load YOLOv5 model
14
- model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=True)
15
-
16
  st.title("YOLO Object Detection Web App")
17
 
18
- # Upload image
19
- uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- if uploaded_file is not None:
22
- # Convert the file to an OpenCV image
23
- image = Image.open(uploaded_file)
24
- st.image(image, caption="Uploaded Image", use_column_width=True)
25
- st.write("Processing...")
26
 
27
- # Convert the image to a format compatible with YOLO
28
- image_np = np.array(image)
29
- image_cv = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)
 
 
 
 
 
 
 
 
30
 
31
- # Perform YOLO detection
32
- results = model(image_cv)
 
33
 
34
- # Render the results
35
- detected_image = np.squeeze(results.render())
 
 
 
 
36
 
37
- # Display result
38
- st.image(detected_image, caption="Detected Image", use_column_width=True)
 
 
1
+ import streamlit as st
2
+ import torch
 
3
  import pathlib
4
+ from PIL import Image
5
+ import io
6
  import numpy as np
 
 
7
  import cv2
8
+ import tempfile
9
+ import pandas as pd
10
 
11
+ # Adjust Path for Local Repository
12
  pathlib.WindowsPath = pathlib.PosixPath
13
 
 
 
 
14
  st.title("YOLO Object Detection Web App")
15
 
16
+ # Define the available labels
17
+ default_sub_classes = [
18
+ "container",
19
+ "waste-paper",
20
+ "plant",
21
+ "transportation",
22
+ "kitchenware",
23
+ "rubbish bag",
24
+ "chair",
25
+ "wood",
26
+ "electronics good",
27
+ "sofa",
28
+ "scrap metal",
29
+ "carton",
30
+ "bag",
31
+ "tarpaulin",
32
+ "accessory",
33
+ "rubble",
34
+ "table",
35
+ "board",
36
+ "mattress",
37
+ "beverage",
38
+ "tyre",
39
+ "nylon",
40
+ "rack",
41
+ "styrofoam",
42
+ "clothes",
43
+ "toy",
44
+ "furniture",
45
+ "trolley",
46
+ "carpet",
47
+ "plastic cup"
48
+ ]
49
+
50
+ # Initialize session state for video processing
51
+ if 'video_processed' not in st.session_state:
52
+ st.session_state.video_processed = False
53
+ st.session_state.output_video_path = None
54
+ st.session_state.detections_summary = None
55
+
56
+ # Cache the model loading to prevent repeated loads
57
+ @st.cache_resource
58
+ def load_model():
59
+ model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=False)
60
+ return model
61
+
62
+ model = load_model()
63
+
64
+ # Retrieve model class names
65
+ model_class_names = model.names # Dictionary {index: class_name}
66
+
67
+ # Function to map class names to indices (case-insensitive)
68
+ def get_class_indices(class_list):
69
+ indices = []
70
+ not_found = []
71
+ for cls in class_list:
72
+ found = False
73
+ for index, name in model_class_names.items():
74
+ if name.lower() == cls.lower():
75
+ indices.append(index)
76
+ found = True
77
+ break
78
+ if not found:
79
+ not_found.append(cls)
80
+ return indices, not_found
81
+
82
+ # Function to annotate images
83
+ def annotate_image(frame, results):
84
+ results.render() # Updates results.ims with the annotated images
85
+ annotated_frame = results.ims[0] # Get the first (and only) image
86
+ return annotated_frame
87
+
88
+ # Inform the user about the available labels
89
+ st.markdown("### Available Classes:")
90
+ st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
91
+
92
+ # Inform the user about the default detection
93
+ st.info("By default, the application will detect **rubbish** only.")
94
+
95
+ # User input for classes, separated by commas (optional)
96
+ custom_classes_input = st.text_input(
97
+ "Enter classes (comma-separated) or type 'all' to detect everything:",
98
+ ""
99
+ )
100
+
101
+ # Retrieve all model classes
102
+ all_model_classes = list(model_class_names.values())
103
+
104
+ # Determine classes to use based on user input
105
+ if custom_classes_input.strip() == "":
106
+ # No input provided; use only 'rubbish'
107
+ selected_classes = ['rubbish']
108
+ st.info("No classes entered. Using default class: **rubbish**.")
109
+ elif custom_classes_input.strip().lower() == "all":
110
+ # User chose to detect all classes
111
+ selected_classes = all_model_classes
112
+ st.info("Detecting **all** available classes.")
113
+ else:
114
+ # User provided specific classes
115
+ # Split the input string into a list of classes and remove any extra whitespace
116
+ input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
117
+ # Ensure 'rubbish' is included
118
+ if 'rubbish' not in [cls.lower() for cls in input_classes]:
119
+ selected_classes = input_classes + ['rubbish']
120
+ st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)")
121
+ else:
122
+ selected_classes = input_classes
123
+ st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")
124
+
125
+ # Map selected class names to their indices
126
+ selected_class_indices, not_found_classes = get_class_indices(selected_classes)
127
+
128
+ if not_found_classes:
129
+ st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")
130
+
131
+ # Proceed only if there are valid classes to detect
132
+ if selected_class_indices:
133
+ # Set the classes for the model
134
+ model.classes = selected_class_indices
135
+
136
+ # --------------------- Image Upload and Processing ---------------------
137
+ st.header("Image Object Detection")
138
+
139
+ uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")
140
+
141
+ if uploaded_file is not None:
142
+ try:
143
+ # Convert the file to a PIL image
144
+ image = Image.open(uploaded_file).convert('RGB')
145
+ st.image(image, caption="Uploaded Image", use_column_width=True)
146
+ st.write("Processing...")
147
+
148
+ # Perform inference
149
+ results = model(image)
150
+
151
+ # Extract DataFrame from results
152
+ results_df = results.pandas().xyxy[0]
153
+
154
+ # Filter results to include only selected classes
155
+ filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]
156
+
157
+ if filtered_results.empty:
158
+ st.warning("No objects detected for the selected classes.")
159
+ else:
160
+ # Display filtered results
161
+ st.write("### Detection Results")
162
+ st.dataframe(filtered_results)
163
+
164
+ # Annotate the image
165
+ annotated_image = annotate_image(np.array(image), results)
166
+
167
+ # Convert annotated image back to PIL format
168
+ annotated_pil = Image.fromarray(annotated_image)
169
+
170
+ # Display annotated image
171
+ st.image(annotated_pil, caption="Annotated Image", use_column_width=True)
172
+
173
+ # Convert annotated image to bytes
174
+ img_byte_arr = io.BytesIO()
175
+ annotated_pil.save(img_byte_arr, format='PNG')
176
+ img_byte_arr = img_byte_arr.getvalue()
177
+
178
+ # Add download button
179
+ st.download_button(
180
+ label="Download Annotated Image",
181
+ data=img_byte_arr,
182
+ file_name='annotated_image.png',
183
+ mime='image/png'
184
+ )
185
+ except Exception as e:
186
+ st.error(f"An error occurred during image processing: {e}")
187
+
188
+ # --------------------- Video Upload and Processing ---------------------
189
+ st.header("Video Object Detection")
190
+
191
+ uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
192
+
193
+ if uploaded_video is not None:
194
+ # Check if the uploaded video is different from the previously processed one
195
+ # Check if the uploaded video first time
196
+ if st.session_state.get("uploaded_video_name") is None:
197
+ st.session_state.uploaded_video_name = uploaded_video.name
198
+ print("First time uploaded video" +st.session_state.uploaded_video_name)
199
+ elif st.session_state.uploaded_video_name != uploaded_video.name:
200
+ st.session_state.uploaded_video_name = uploaded_video.name
201
+ print("Another time uploaded video" +st.session_state.uploaded_video_name)
202
+ st.session_state.video_processed = False
203
+ st.session_state.output_video_path = None
204
+ st.session_state.detections_summary = None
205
+ print("New uploaded video")
206
+
207
+ # Reset session state if video upload is removed
208
+ if uploaded_video is None and st.session_state.video_processed:
209
+ st.session_state.video_processed = False
210
+ st.session_state.output_video_path = None
211
+ st.session_state.detections_summary = None
212
+ st.warning("Video upload has been cleared. You can upload a new video for processing.")
213
+
214
+ if uploaded_video:
215
+ if not st.session_state.video_processed:
216
+ try:
217
+ with st.spinner("Processing video..."):
218
+ # Save uploaded video to a temporary file
219
+ tfile = tempfile.NamedTemporaryFile(delete=False)
220
+ tfile.write(uploaded_video.read())
221
+ tfile.close()
222
+
223
+ # Open the video file
224
+ video_cap = cv2.VideoCapture(tfile.name)
225
+ stframe = st.empty() # Placeholder for displaying video frames
226
+
227
+ # Initialize VideoWriter for saving the output video
228
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
229
+ fps = video_cap.get(cv2.CAP_PROP_FPS)
230
+ width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
231
+ height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
232
+ output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
233
+ out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
234
+
235
+ frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
236
+ progress_bar = st.progress(0)
237
+
238
+ # Initialize list to collect all detections
239
+ all_detections = []
240
+
241
+ for frame_num in range(frame_count):
242
+ ret, frame = video_cap.read() # Read a frame from the video
243
+ if not ret:
244
+ break
245
+
246
+ # Convert frame to RGB
247
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
248
+
249
+ # Perform inference
250
+ results = model(frame_rgb)
251
+
252
+ # Extract DataFrame from results
253
+ results_df = results.pandas().xyxy[0]
254
+ results_df['frame_num'] = frame_num # Optional: Add frame number for reference
255
+
256
+ # Append detections to the list
257
+ if not results_df.empty:
258
+ all_detections.append(results_df)
259
+
260
+ # Annotate the frame with detections
261
+ annotated_frame = annotate_image(frame_rgb, results)
262
+
263
+ # Convert annotated frame back to BGR for VideoWriter
264
+ annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
265
+
266
+ # Write the annotated frame to the output video
267
+ out.write(annotated_bgr)
268
+
269
+ # Display the annotated frame in Streamlit
270
+ stframe.image(annotated_frame, channels="RGB", use_column_width=True)
271
+
272
+ # Update progress bar
273
+ progress_percent = (frame_num + 1) / frame_count
274
+ progress_bar.progress(progress_percent)
275
+
276
+ video_cap.release() # Release the video capture object
277
+ out.release() # Release the VideoWriter object
278
+
279
+ # Save processed video path and detections summary to session state
280
+ st.session_state.output_video_path = output_video_path
281
+
282
+ if all_detections:
283
+ # Concatenate all detections into a single DataFrame
284
+ detections_df = pd.concat(all_detections, ignore_index=True)
285
+
286
+ # Optional: Group by class name and count detections
287
+ detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
288
+ st.session_state.detections_summary = detections_summary
289
+ else:
290
+ st.session_state.detections_summary = None
291
+
292
+ # Mark video as processed
293
+ st.session_state.video_processed = True
294
+
295
+ # st.session_state.uploaded_video_name = uploaded_video.name
296
+
297
+ st.success("Video processing complete!")
298
 
299
+ except Exception as e:
300
+ st.error(f"An error occurred during video processing: {e}")
 
 
 
301
 
302
+ # Display download button and detection summary if processed
303
+ if st.session_state.video_processed:
304
+ try:
305
+ # Create a download button for the annotated video
306
+ with open(st.session_state.output_video_path, "rb") as video_file:
307
+ st.download_button(
308
+ label="Download Annotated Video",
309
+ data=video_file,
310
+ file_name="annotated_video.mp4",
311
+ mime="video/mp4"
312
+ )
313
 
314
+ # Display detection table if there are detections
315
+ if st.session_state.detections_summary is not None:
316
+ detections_summary = st.session_state.detections_summary
317
 
318
+ st.write("### Detection Summary")
319
+ st.dataframe(detections_summary)
320
+ else:
321
+ st.warning("No objects detected in the video for the selected classes.")
322
+ except Exception as e:
323
+ st.error(f"An error occurred while preparing the download: {e}")
324
 
325
+ # Optionally, display all available classes when 'all' is selected
326
+ if custom_classes_input.strip().lower() == "all":
327
+ st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")