phonghaitran committed
Commit 721557d · 1 Parent(s): 310c1ed

yolov5 design with unet, split into radio buttons for choosing the model

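The commit message refers to the model picker added at the bottom of app.py: the page now shows a radio button for choosing between the UNet and YOLO interfaces and only renders the confirmed choice. A minimal standalone sketch of that Streamlit pattern (run_unet/run_yolo are placeholders standing in for the real yolo()/unet() pages in this commit):

import streamlit as st

def run_unet():
    st.write("UNet segmentation interface")  # placeholder for the real unet() page

def run_yolo():
    st.write("YOLO detection interface")  # placeholder for the real yolo() page

# Remember the confirmed choice across Streamlit reruns
if 'model_selected' not in st.session_state:
    st.session_state.model_selected = None

option = st.radio("Select Model:", ("Unet", "YOLO"))  # radio button for model selection
if st.button("Choose"):                               # confirm the selection
    st.session_state.model_selected = option

# Render only the interface that was confirmed
if st.session_state.model_selected == "Unet":
    run_unet()
elif st.session_state.model_selected == "YOLO":
    run_yolo()
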
.DS_Store CHANGED
Binary files a/.DS_Store and b/.DS_Store differ
 
app.py CHANGED
@@ -1,327 +1,437 @@
 import streamlit as st
 import torch
 import pathlib
-from PIL import Image
 import io
-import numpy as np
 import cv2
 import tempfile
-import pandas as pd

 # Adjust Path for Local Repository
 pathlib.WindowsPath = pathlib.PosixPath

 st.title("YOLO Object Detection Web App")

-# Define the available labels
-default_sub_classes = [
-    "container",
-    "waste-paper",
-    "plant",
-    "transportation",
-    "kitchenware",
-    "rubbish bag",
-    "chair",
-    "wood",
-    "electronics good",
-    "sofa",
-    "scrap metal",
-    "carton",
-    "bag",
-    "tarpaulin",
-    "accessory",
-    "rubble",
-    "table",
-    "board",
-    "mattress",
-    "beverage",
-    "tyre",
-    "nylon",
-    "rack",
-    "styrofoam",
-    "clothes",
-    "toy",
-    "furniture",
-    "trolley",
-    "carpet",
-    "plastic cup"
-]
-
-# Initialize session state for video processing
-if 'video_processed' not in st.session_state:
-    st.session_state.video_processed = False
-    st.session_state.output_video_path = None
-    st.session_state.detections_summary = None
-
-# Cache the model loading to prevent repeated loads
-@st.cache_resource
-def load_model():
-    model = torch.hub.load('./yolov5', 'custom', path='./yolo/best.pt', source='local', force_reload=False)
-    return model

-model = load_model()
-
-# Retrieve model class names
-model_class_names = model.names # Dictionary {index: class_name}
-
-# Function to map class names to indices (case-insensitive)
-def get_class_indices(class_list):
-    indices = []
-    not_found = []
-    for cls in class_list:
-        found = False
-        for index, name in model_class_names.items():
-            if name.lower() == cls.lower():
-                indices.append(index)
-                found = True
-                break
-        if not found:
-            not_found.append(cls)
-    return indices, not_found
-
-# Function to annotate images
-def annotate_image(frame, results):
-    results.render() # Updates results.ims with the annotated images
-    annotated_frame = results.ims[0] # Get the first (and only) image
-    return annotated_frame
-
-# Inform the user about the available labels
-st.markdown("### Available Classes:")
-st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
-
-# Inform the user about the default detection
-st.info("By default, the application will detect **rubbish** only.")
-
-# User input for classes, separated by commas (optional)
-custom_classes_input = st.text_input(
-    "Enter classes (comma-separated) or type 'all' to detect everything:",
-    ""
-)
-
-# Retrieve all model classes
-all_model_classes = list(model_class_names.values())
-
-# Determine classes to use based on user input
-if custom_classes_input.strip() == "":
-    # No input provided; use only 'rubbish'
-    selected_classes = ['rubbish']
-    st.info("No classes entered. Using default class: **rubbish**.")
-elif custom_classes_input.strip().lower() == "all":
-    # User chose to detect all classes
-    selected_classes = all_model_classes
-    st.info("Detecting **all** available classes.")
-else:
-    # User provided specific classes
-    # Split the input string into a list of classes and remove any extra whitespace
-    input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
-    # Ensure 'rubbish' is included
-    if 'rubbish' not in [cls.lower() for cls in input_classes]:
-        selected_classes = input_classes + ['rubbish']
-        st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)")
    else:
-        selected_classes = input_classes
-        st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")

-# Map selected class names to their indices
-selected_class_indices, not_found_classes = get_class_indices(selected_classes)

-if not_found_classes:
-    st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")

-# Proceed only if there are valid classes to detect
-if selected_class_indices:
-    # Set the classes for the model
-    model.classes = selected_class_indices

-    # --------------------- Image Upload and Processing ---------------------
-    st.header("Image Object Detection")

-    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")

-    if uploaded_file is not None:
-        try:
-            # Convert the file to a PIL image
-            image = Image.open(uploaded_file).convert('RGB')
-            st.image(image, caption="Uploaded Image", use_column_width=True)
-            st.write("Processing...")

-            # Perform inference
-            results = model(image)
-
-            # Extract DataFrame from results
-            results_df = results.pandas().xyxy[0]
-
-            # Filter results to include only selected classes
-            filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]
-
-            if filtered_results.empty:
-                st.warning("No objects detected for the selected classes.")
-            else:
-                # Display filtered results
-                st.write("### Detection Results")
-                st.dataframe(filtered_results)
-
-                # Annotate the image
-                annotated_image = annotate_image(np.array(image), results)
-
-                # Convert annotated image back to PIL format
-                annotated_pil = Image.fromarray(annotated_image)
-
-                # Display annotated image
-                st.image(annotated_pil, caption="Annotated Image", use_column_width=True)
-
-                # Convert annotated image to bytes
-                img_byte_arr = io.BytesIO()
-                annotated_pil.save(img_byte_arr, format='PNG')
-                img_byte_arr = img_byte_arr.getvalue()
-
-                # Add download button
-                st.download_button(
-                    label="Download Annotated Image",
-                    data=img_byte_arr,
-                    file_name='annotated_image.png',
-                    mime='image/png'
-                )
-        except Exception as e:
-            st.error(f"An error occurred during image processing: {e}")
-
-    # --------------------- Video Upload and Processing ---------------------
-    st.header("Video Object Detection")
-
-    uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
-
-    if uploaded_video is not None:
-        # Check if the uploaded video is different from the previously processed one
-        # Check if the uploaded video first time
-        if st.session_state.get("uploaded_video_name") is None:
-            st.session_state.uploaded_video_name = uploaded_video.name
-            print("First time uploaded video" +st.session_state.uploaded_video_name)
-        elif st.session_state.uploaded_video_name != uploaded_video.name:
-            st.session_state.uploaded_video_name = uploaded_video.name
-            print("Another time uploaded video" +st.session_state.uploaded_video_name)
            st.session_state.video_processed = False
            st.session_state.output_video_path = None
            st.session_state.detections_summary = None
-            print("New uploaded video")
-
-    # Reset session state if video upload is removed
-    if uploaded_video is None and st.session_state.video_processed:
-        st.session_state.video_processed = False
-        st.session_state.output_video_path = None
-        st.session_state.detections_summary = None
-        st.warning("Video upload has been cleared. You can upload a new video for processing.")

-    if uploaded_video:
-        if not st.session_state.video_processed:
-            try:
-                with st.spinner("Processing video..."):
-                    # Save uploaded video to a temporary file
-                    tfile = tempfile.NamedTemporaryFile(delete=False)
-                    tfile.write(uploaded_video.read())
-                    tfile.close()

-                    # Open the video file
-                    video_cap = cv2.VideoCapture(tfile.name)
-                    stframe = st.empty() # Placeholder for displaying video frames

-                    # Initialize VideoWriter for saving the output video
-                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-                    fps = video_cap.get(cv2.CAP_PROP_FPS)
-                    width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-                    height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-                    output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
-                    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

-                    frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
-                    progress_bar = st.progress(0)

-                    # Initialize list to collect all detections
-                    all_detections = []

-                    for frame_num in range(frame_count):
-                        ret, frame = video_cap.read() # Read a frame from the video
-                        if not ret:
-                            break

-                        # Convert frame to RGB
-                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

-                        # Perform inference
-                        results = model(frame_rgb)

-                        # Extract DataFrame from results
-                        results_df = results.pandas().xyxy[0]
-                        results_df['frame_num'] = frame_num # Optional: Add frame number for reference

-                        # Append detections to the list
-                        if not results_df.empty:
-                            all_detections.append(results_df)

-                        # Annotate the frame with detections
-                        annotated_frame = annotate_image(frame_rgb, results)

-                        # Convert annotated frame back to BGR for VideoWriter
-                        annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)

-                        # Write the annotated frame to the output video
-                        out.write(annotated_bgr)

-                        # Display the annotated frame in Streamlit
-                        stframe.image(annotated_frame, channels="RGB", use_column_width=True)

-                        # Update progress bar
-                        progress_percent = (frame_num + 1) / frame_count
-                        progress_bar.progress(progress_percent)

-                    video_cap.release() # Release the video capture object
-                    out.release() # Release the VideoWriter object

-                    # Save processed video path and detections summary to session state
-                    st.session_state.output_video_path = output_video_path

-                    if all_detections:
-                        # Concatenate all detections into a single DataFrame
-                        detections_df = pd.concat(all_detections, ignore_index=True)

-                        # Optional: Group by class name and count detections
-                        detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
-                        st.session_state.detections_summary = detections_summary
-                    else:
-                        st.session_state.detections_summary = None

-                    # Mark video as processed
-                    st.session_state.video_processed = True

-                    # st.session_state.uploaded_video_name = uploaded_video.name

-                    st.success("Video processing complete!")

            except Exception as e:
-                st.error(f"An error occurred during video processing: {e}")

-        # Display download button and detection summary if processed
-        if st.session_state.video_processed:
-            try:
-                # Create a download button for the annotated video
-                with open(st.session_state.output_video_path, "rb") as video_file:
-                    st.download_button(
-                        label="Download Annotated Video",
-                        data=video_file,
-                        file_name="annotated_video.mp4",
-                        mime="video/mp4"
-                    )

-                # Display detection table if there are detections
-                if st.session_state.detections_summary is not None:
-                    detections_summary = st.session_state.detections_summary

-                    st.write("### Detection Summary")
-                    st.dataframe(detections_summary)
-                else:
-                    st.warning("No objects detected in the video for the selected classes.")
-            except Exception as e:
-                st.error(f"An error occurred while preparing the download: {e}")

-# Optionally, display all available classes when 'all' is selected
-if custom_classes_input.strip().lower() == "all":
-    st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")

+from PIL import Image, ImageDraw
+
+# Import the model components from unet directory
+from unet.unet_model import UNet
+
 import streamlit as st
+import plotly.express as px
+import pandas as pd
+import numpy as np
+import torchvision.transforms as T
+
 import torch
 import pathlib
 import io
 import cv2
 import tempfile

 # Adjust Path for Local Repository
 pathlib.WindowsPath = pathlib.PosixPath

 st.title("YOLO Object Detection Web App")

+def yolo():
+    st.markdown(
+        "<h1 style='text-align: center; font-size: 36px;'>Smart city rubbish detection</h1>",
+        unsafe_allow_html=True
+    )
+    st.markdown(
+        "<h2 style='text-align: center; font-size: 30px;'>Presented by team 2</h2>",
+        unsafe_allow_html=True
+    )
+
+    # Define the available labels
+    default_sub_classes = [
+        "container",
+        "waste-paper",
+        "plant",
+        "transportation",
+        "kitchenware",
+        "rubbish bag",
+        "chair",
+        "wood",
+        "electronics good",
+        "sofa",
+        "scrap metal",
+        "carton",
+        "bag",
+        "tarpaulin",
+        "accessory",
+        "rubble",
+        "table",
+        "board",
+        "mattress",
+        "beverage",
+        "tyre",
+        "nylon",
+        "rack",
+        "styrofoam",
+        "clothes",
+        "toy",
+        "furniture",
+        "trolley",
+        "carpet",
+        "plastic cup"
+    ]
+
+    # Initialize session state for video processing
+    if 'video_processed' not in st.session_state:
+        st.session_state.video_processed = False
+        st.session_state.output_video_path = None
+        st.session_state.detections_summary = None

+    # Cache the model loading to prevent repeated loads
+    @st.cache_resource
+    def load_model():
+        model = torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt', source='local', force_reload=False)
+        return model
+
+    model = load_model()
+
+    # Retrieve model class names
+    model_class_names = model.names # Dictionary {index: class_name}
+
+    # Function to map class names to indices (case-insensitive)
+    def get_class_indices(class_list):
+        indices = []
+        not_found = []
+        for cls in class_list:
+            found = False
+            for index, name in model_class_names.items():
+                if name.lower() == cls.lower():
+                    indices.append(index)
+                    found = True
+                    break
+            if not found:
+                not_found.append(cls)
+        return indices, not_found
+
+    # Function to annotate images
+    def annotate_image(frame, results):
+        results.render() # Updates results.ims with the annotated images
+        annotated_frame = results.ims[0] # Get the first (and only) image
+        return annotated_frame
+
+    # Inform the user about the available labels
+    st.markdown("### Available Classes:")
+    st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
+
+    # Inform the user about the default detection
+    st.info("By default, the application will detect **rubbish** only.")
+
+    # User input for classes, separated by commas (optional)
+    custom_classes_input = st.text_input(
+        "Enter classes (comma-separated) or type 'all' to detect everything:",
+        ""
+    )
+
+    # Retrieve all model classes
+    all_model_classes = list(model_class_names.values())
+
+    # Determine classes to use based on user input
+    if custom_classes_input.strip() == "":
+        # No input provided; use only 'rubbish'
+        selected_classes = ['rubbish']
+        st.info("No classes entered. Using default class: **rubbish**.")
+    elif custom_classes_input.strip().lower() == "all":
+        # User chose to detect all classes
+        selected_classes = all_model_classes
+        st.info("Detecting **all** available classes.")
    else:
+        # User provided specific classes
+        # Split the input string into a list of classes and remove any extra whitespace
+        input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
+        # Ensure 'rubbish' is included
+        if 'rubbish' not in [cls.lower() for cls in input_classes]:
+            selected_classes = input_classes + ['rubbish']
+            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)")
+        else:
+            selected_classes = input_classes
+            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")
+
+    # Map selected class names to their indices
+    selected_class_indices, not_found_classes = get_class_indices(selected_classes)
+
+    if not_found_classes:
+        st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")
+
+    # Proceed only if there are valid classes to detect
+    if selected_class_indices:
+        # Set the classes for the model
+        model.classes = selected_class_indices
+
+        # --------------------- Image Upload and Processing ---------------------
+        st.header("Image Object Detection")
+
+        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")
+
+        if uploaded_file is not None:
+            try:
+                # Convert the file to a PIL image
+                image = Image.open(uploaded_file).convert('RGB')
+                st.image(image, caption="Uploaded Image", use_column_width=True)
+                st.write("Processing...")
+
+                # Perform inference
+                results = model(image)

+                # Extract DataFrame from results
+                results_df = results.pandas().xyxy[0]

+                # Filter results to include only selected classes
+                filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]

+                if filtered_results.empty:
+                    st.warning("No objects detected for the selected classes.")
+                else:
+                    # Display filtered results
+                    st.write("### Detection Results")
+                    st.dataframe(filtered_results)

+                    # Annotate the image
+                    annotated_image = annotate_image(np.array(image), results)

+                    # Convert annotated image back to PIL format
+                    annotated_pil = Image.fromarray(annotated_image)

+                    # Display annotated image
+                    st.image(annotated_pil, caption="Annotated Image", use_column_width=True)

+                    # Convert annotated image to bytes
+                    img_byte_arr = io.BytesIO()
+                    annotated_pil.save(img_byte_arr, format='PNG')
+                    img_byte_arr = img_byte_arr.getvalue()
+
+                    # Add download button
+                    st.download_button(
+                        label="Download Annotated Image",
+                        data=img_byte_arr,
+                        file_name='annotated_image.png',
+                        mime='image/png'
+                    )
+            except Exception as e:
+                st.error(f"An error occurred during image processing: {e}")
+
+        # --------------------- Video Upload and Processing ---------------------
+        st.header("Video Object Detection")
+
+        uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
+
+        if uploaded_video is not None:
+            # Check if the uploaded video is different from the previously processed one
+            # Check if the uploaded video first time
+            if st.session_state.get("uploaded_video_name") is None:
+                st.session_state.uploaded_video_name = uploaded_video.name
+                print("First time uploaded video" +st.session_state.uploaded_video_name)
+            elif st.session_state.uploaded_video_name != uploaded_video.name:
+                st.session_state.uploaded_video_name = uploaded_video.name
+                print("Another time uploaded video" +st.session_state.uploaded_video_name)
+                st.session_state.video_processed = False
+                st.session_state.output_video_path = None
+                st.session_state.detections_summary = None
+                print("New uploaded video")
+
+        # Reset session state if video upload is removed
+        if uploaded_video is None and st.session_state.video_processed:
            st.session_state.video_processed = False
            st.session_state.output_video_path = None
            st.session_state.detections_summary = None
+            st.warning("Video upload has been cleared. You can upload a new video for processing.")

+        if uploaded_video:
+            if not st.session_state.video_processed:
+                try:
+                    with st.spinner("Processing video..."):
+                        # Save uploaded video to a temporary file
+                        tfile = tempfile.NamedTemporaryFile(delete=False)
+                        tfile.write(uploaded_video.read())
+                        tfile.close()

+                        # Open the video file
+                        video_cap = cv2.VideoCapture(tfile.name)
+                        stframe = st.empty() # Placeholder for displaying video frames

+                        # Initialize VideoWriter for saving the output video
+                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+                        fps = video_cap.get(cv2.CAP_PROP_FPS)
+                        width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+                        height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+                        output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
+                        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

+                        frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
+                        progress_bar = st.progress(0)

+                        # Initialize list to collect all detections
+                        all_detections = []

+                        for frame_num in range(frame_count):
+                            ret, frame = video_cap.read() # Read a frame from the video
+                            if not ret:
+                                break

+                            # Convert frame to RGB
+                            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

+                            # Perform inference
+                            results = model(frame_rgb)

+                            # Extract DataFrame from results
+                            results_df = results.pandas().xyxy[0]
+                            results_df['frame_num'] = frame_num # Optional: Add frame number for reference

+                            # Append detections to the list
+                            if not results_df.empty:
+                                all_detections.append(results_df)

+                            # Annotate the frame with detections
+                            annotated_frame = annotate_image(frame_rgb, results)

+                            # Convert annotated frame back to BGR for VideoWriter
+                            annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)

+                            # Write the annotated frame to the output video
+                            out.write(annotated_bgr)

+                            # Display the annotated frame in Streamlit
+                            stframe.image(annotated_frame, channels="RGB", use_column_width=True)

+                            # Update progress bar
+                            progress_percent = (frame_num + 1) / frame_count
+                            progress_bar.progress(progress_percent)

+                        video_cap.release() # Release the video capture object
+                        out.release() # Release the VideoWriter object

+                        # Save processed video path and detections summary to session state
+                        st.session_state.output_video_path = output_video_path

+                        if all_detections:
+                            # Concatenate all detections into a single DataFrame
+                            detections_df = pd.concat(all_detections, ignore_index=True)

+                            # Optional: Group by class name and count detections
+                            detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
+                            st.session_state.detections_summary = detections_summary
+                        else:
+                            st.session_state.detections_summary = None
+
+                        # Mark video as processed
+                        st.session_state.video_processed = True

+                        # st.session_state.uploaded_video_name = uploaded_video.name

+                        st.success("Video processing complete!")

+                except Exception as e:
+                    st.error(f"An error occurred during video processing: {e}")

+            # Display download button and detection summary if processed
+            if st.session_state.video_processed:
+                try:
+                    # Create a download button for the annotated video
+                    with open(st.session_state.output_video_path, "rb") as video_file:
+                        st.download_button(
+                            label="Download Annotated Video",
+                            data=video_file,
+                            file_name="annotated_video.mp4",
+                            mime="video/mp4"
+                        )
+
+                    # Display detection table if there are detections
+                    if st.session_state.detections_summary is not None:
+                        detections_summary = st.session_state.detections_summary
+
+                        st.write("### Detection Summary")
+                        st.dataframe(detections_summary)
+                    else:
+                        st.warning("No objects detected in the video for the selected classes.")
                except Exception as e:
+                    st.error(f"An error occurred while preparing the download: {e}")

+    # Optionally, display all available classes when 'all' is selected
+    if custom_classes_input.strip().lower() == "all":
+        st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")
+
+# Unet model training configuration
+
+# Constants
+IMG_SIZE = 128 # Resize dimension for the input image
+
+# Load model function
+@st.cache_resource
+def load_model():
+    model = UNet(n_channels=3, n_classes=32) # Adjust according to your model setup
+    model.load_state_dict(torch.load("/Users/phongporter/Documents/GITHUB/cos40007-team/streamlit_unet/model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True), strict=False)
+    model.eval()
+    return model

+# Function to preprocess the image
+def preprocess_image(image):
+    transform = T.Compose([
+        T.Resize((IMG_SIZE, IMG_SIZE)), # Resize to match model input size
+        T.ToTensor(), # Convert to tensor
+    ])
+    image_tensor = transform(image).unsqueeze(0) # Add batch dimension
+    return image_tensor
+
+# Function to postprocess the model output for display
+def postprocess_mask(mask):
+    # Convert mask to a numpy array and scale to 0-255
+    mask_np = mask.squeeze().cpu().numpy() # Remove batch and channel dimensions
+    mask_np = (mask_np > 0.5).astype(np.uint8) * 255 # Binarize and scale to 0-255
+    return mask_np
+
+def unet():
+    try:
+        # Load the model
+        model = load_model()
+
+        st.markdown(
+            "<h1 style='text-align: center; font-size: 36px;'>Smart city rubbish detection</h1>",
+            unsafe_allow_html=True
+        )
+        st.markdown(
+            "<h2 style='text-align: center; font-size: 30px;'>Presented by team 2</h2>",
+            unsafe_allow_html=True
+        )
+
+        # Display the file upload widget
+        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
+        if uploaded_file is not None:
+            # Open and display the uploaded image
+            image = Image.open(uploaded_file).convert("RGB")
+            st.image(image, caption="Uploaded Image", use_column_width=True)

+            # Preprocess the image
+            input_tensor = preprocess_image(image)

+            # Perform inference
+            with torch.no_grad(): # Disable gradient calculation for inference
+                output = model(input_tensor)
+                prediction = torch.sigmoid(output) # Apply sigmoid to get probabilities
+
+            # Post-process the mask for display
+            mask = postprocess_mask(prediction[0, 0]) # Get the mask from the first batch item
+
+            # Display the segmentation mask
+            st.image(mask, caption="Segmentation Mask", use_column_width=True)
+    except Exception as e:
+        st.error(f"An error occurred in Unet: {e}")
+
+# Main page
+if 'model_selected' not in st.session_state:
+    st.session_state.model_selected = None
+
+def main():
+    st.markdown(
+        "<h1 style='text-align: center; font-size: 36px;'>Unet </h1>",
+        unsafe_allow_html=True
+    )
+
+    # Radio button for model selection with consistent casing
+    option = st.radio("Select Model:", ("Unet", "YOLO"))
+
+    # Submit button to confirm selection
+    if st.button("Choose"):
+        st.session_state.model_selected = option
+        st.success(f"Selected Model: {st.session_state.model_selected}")
+
+    # Render the selected model's interface based on session state
+    if st.session_state.model_selected == "Unet":
+        unet()
+    elif st.session_state.model_selected == "YOLO":
+        yolo()
+
+if __name__ == "__main__":
+    main()
model/.DS_Store ADDED
Binary file (6.15 kB).
 
{unet → model/unet}/checkpoint_epoch5.pth RENAMED
File without changes
{yolo → model/yolo}/best.pt RENAMED
File without changes
unet/__init__.py ADDED
@@ -0,0 +1 @@
+from .unet_model import UNet
unet/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (220 Bytes).
 
unet/__pycache__/unet_model.cpython-312.pyc ADDED
Binary file (2.21 kB).
 
unet/__pycache__/unet_parts.cpython-312.pyc ADDED
Binary file (4.46 kB).
 
unet/unet_model.py ADDED
@@ -0,0 +1,36 @@
+""" Full assembly of the parts to form the complete network """
+
+from .unet_parts import *
+
+
+class UNet(nn.Module):
+    def __init__(self, n_channels, n_classes, bilinear=False):
+        super(UNet, self).__init__()
+        self.n_channels = n_channels
+        self.n_classes = n_classes
+        self.bilinear = bilinear
+
+        self.inc = DoubleConv(n_channels, 64)
+        self.down1 = Down(64, 128)
+        self.down2 = Down(128, 256)
+        self.down3 = Down(256, 512)
+        factor = 2 if bilinear else 1
+        self.down4 = Down(512, 1024 // factor)
+        self.up1 = Up(1024, 512 // factor, bilinear)
+        self.up2 = Up(512, 256 // factor, bilinear)
+        self.up3 = Up(256, 128 // factor, bilinear)
+        self.up4 = Up(128, 64, bilinear)
+        self.outc = OutConv(64, n_classes)
+
+    def forward(self, x):
+        x1 = self.inc(x)
+        x2 = self.down1(x1)
+        x3 = self.down2(x2)
+        x4 = self.down3(x3)
+        x5 = self.down4(x4)
+        x = self.up1(x5, x4)
+        x = self.up2(x, x3)
+        x = self.up3(x, x2)
+        x = self.up4(x, x1)
+        logits = self.outc(x)
+        return logits
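For reference, the assembled network above can be exercised on its own. A minimal sketch (assuming the unet package is importable, and reusing the 3-channel, 32-class configuration and 128-pixel input size that app.py uses):

import torch
from unet.unet_model import UNet

net = UNet(n_channels=3, n_classes=32)   # same configuration as load_model() in app.py
net.eval()
with torch.no_grad():
    dummy = torch.randn(1, 3, 128, 128)  # IMG_SIZE = 128, as in preprocess_image()
    logits = net(dummy)
print(logits.shape)                      # expected: torch.Size([1, 32, 128, 128])

Because the final layer is a plain 1x1 convolution with no activation, the output is raw per-class logits at the input resolution; app.py applies torch.sigmoid before thresholding the mask.
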
unet/unet_parts.py ADDED
@@ -0,0 +1,77 @@
+""" Parts of the U-Net model """
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class DoubleConv(nn.Module):
+    """(convolution => [BN] => ReLU) * 2"""
+
+    def __init__(self, in_channels, out_channels, mid_channels=None):
+        super().__init__()
+        if not mid_channels:
+            mid_channels = out_channels
+        self.double_conv = nn.Sequential(
+            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(mid_channels),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
+            nn.BatchNorm2d(out_channels),
+            nn.ReLU(inplace=True)
+        )
+
+    def forward(self, x):
+        return self.double_conv(x)
+
+
+class Down(nn.Module):
+    """Downscaling with maxpool then double conv"""
+
+    def __init__(self, in_channels, out_channels):
+        super().__init__()
+        self.maxpool_conv = nn.Sequential(
+            nn.MaxPool2d(2),
+            DoubleConv(in_channels, out_channels)
+        )
+
+    def forward(self, x):
+        return self.maxpool_conv(x)
+
+
+class Up(nn.Module):
+    """Upscaling then double conv"""
+
+    def __init__(self, in_channels, out_channels, bilinear=True):
+        super().__init__()
+
+        # if bilinear, use the normal convolutions to reduce the number of channels
+        if bilinear:
+            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
+            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
+        else:
+            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
+            self.conv = DoubleConv(in_channels, out_channels)
+
+    def forward(self, x1, x2):
+        x1 = self.up(x1)
+        # input is CHW
+        diffY = x2.size()[2] - x1.size()[2]
+        diffX = x2.size()[3] - x1.size()[3]
+
+        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
+                        diffY // 2, diffY - diffY // 2])
+        # if you have padding issues, see
+        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
+        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
+        x = torch.cat([x2, x1], dim=1)
+        return self.conv(x)
+
+
+class OutConv(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(OutConv, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        return self.conv(x)