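"""Smart-city rubbish detection web app.

Streamlit front end offering two detectors: a custom-trained YOLOv5 model for
object detection on images and video, and a UNet model for semantic
segmentation. Launch it with Streamlit (the file name here is illustrative):

    streamlit run app.py
"""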
from PIL import Image
import streamlit as st
import pandas as pd
import numpy as np
import torchvision.transforms as T
import torch
import pathlib
import io
import cv2
import tempfile

from unet.unet_model import UNet

# The YOLOv5 checkpoint was saved on Windows; alias WindowsPath to PosixPath so
# its pickled paths can be resolved on POSIX hosts (reverse this on Windows).
pathlib.WindowsPath = pathlib.PosixPath

st.title("Smart City Rubbish Detection Web Application")
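

# ---------------------------------------------------------------------------
# YOLOv5 object detection
# ---------------------------------------------------------------------------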
def yolo():
    st.markdown(
        "<h1 style='text-align: center; font-size: 36px;'>YOLO object detection</h1>",
        unsafe_allow_html=True
    )
    st.markdown(
        "<h2 style='text-align: center; font-size: 30px;'>Using YOLOv5</h2>",
        unsafe_allow_html=True
    )
    default_sub_classes = [
        "container", "waste-paper", "plant", "transportation", "kitchenware",
        "rubbish bag", "chair", "wood", "electronics good", "sofa",
        "scrap metal", "carton", "bag", "tarpaulin", "accessory",
        "rubble", "table", "board", "mattress", "beverage",
        "tyre", "nylon", "rack", "styrofoam", "clothes",
        "toy", "furniture", "trolley", "carpet", "plastic cup",
    ]

    # Initialize session state for the video-processing results.
    if 'video_processed' not in st.session_state:
        st.session_state.video_processed = False
        st.session_state.output_video_path = None
        st.session_state.detections_summary = None

    @st.cache_resource
    def load_yolo_model():
        # Load the custom-trained weights from the local YOLOv5 clone.
        return torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt',
                              source='local', force_reload=False)

    model = load_yolo_model()

    # Mapping of class index -> class name exposed by the model.
    model_class_names = model.names

    def get_class_indices(class_list):
        """Map class names to model indices, collecting any names the model lacks."""
        indices = []
        not_found = []
        for cls in class_list:
            found = False
            for index, name in model_class_names.items():
                if name.lower() == cls.lower():
                    indices.append(index)
                    found = True
                    break
            if not found:
                not_found.append(cls)
        return indices, not_found

    def annotate_image(frame, results):
        # results.render() draws the boxes onto the image copies held inside
        # `results`; the `frame` argument itself is left untouched.
        results.render()
        return results.ims[0]
st.markdown("### Available Classes:") |
|
st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**") |
|
|
|
|
|
st.info("By default, the application will detect **rubbish** only.") |
|
|
|
|
|

    custom_classes_input = st.text_input(
        "Enter classes (comma-separated) or type 'all' to detect everything:",
        ""
    )

    all_model_classes = list(model_class_names.values())

    # Resolve the user's input into the list of classes to detect.
    if custom_classes_input.strip() == "":
        selected_classes = ['rubbish']
        st.info("No classes entered. Using default class: **rubbish**.")
    elif custom_classes_input.strip().lower() == "all":
        selected_classes = all_model_classes
        st.info("Detecting **all** available classes.")
    else:
        input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
        # Always include 'rubbish', the application's primary target class.
        if 'rubbish' not in [cls.lower() for cls in input_classes]:
            selected_classes = input_classes + ['rubbish']
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (including **rubbish**)")
        else:
            selected_classes = input_classes
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")

    selected_class_indices, not_found_classes = get_class_indices(selected_classes)

    if not_found_classes:
        st.warning(
            "The following classes were not found in the model and will be "
            f"ignored: **{', '.join(not_found_classes)}**"
        )

    if selected_class_indices:
        # Restrict YOLOv5 inference to the selected class indices.
        model.classes = selected_class_indices
st.header("Image Object Detection") |
|
|
|
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload") |
|
|
|

    if uploaded_file is not None:
        try:
            image = Image.open(uploaded_file).convert('RGB')
            st.image(image, caption="Uploaded Image", use_column_width=True)
            st.write("Processing...")

            # Run inference and extract the detections as a DataFrame.
            results = model(image)
            results_df = results.pandas().xyxy[0]

            # Keep only detections belonging to the selected classes.
            filtered_results = results_df[
                results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])
            ]

            if filtered_results.empty:
                st.warning("No objects detected for the selected classes.")
            else:
                st.write("### Detection Results")
                st.dataframe(filtered_results)

                # Draw the bounding boxes and display the annotated image.
                annotated_image = annotate_image(np.array(image), results)
                annotated_pil = Image.fromarray(annotated_image)
                st.image(annotated_pil, caption="Annotated Image", use_column_width=True)

                # Offer the annotated image as a PNG download.
                img_byte_arr = io.BytesIO()
                annotated_pil.save(img_byte_arr, format='PNG')
                st.download_button(
                    label="Download Annotated Image",
                    data=img_byte_arr.getvalue(),
                    file_name='annotated_image.png',
                    mime='image/png'
                )
        except Exception as e:
            st.error(f"An error occurred during image processing: {e}")
st.header("Video Object Detection") |
|
|
|
uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload") |
|
|
|

    if uploaded_video is not None:
        # Reset the cached results whenever a different video is uploaded.
        if st.session_state.get("uploaded_video_name") is None:
            st.session_state.uploaded_video_name = uploaded_video.name
            print("First uploaded video: " + st.session_state.uploaded_video_name)
        elif st.session_state.uploaded_video_name != uploaded_video.name:
            st.session_state.uploaded_video_name = uploaded_video.name
            print("New uploaded video: " + st.session_state.uploaded_video_name)
            st.session_state.video_processed = False
            st.session_state.output_video_path = None
            st.session_state.detections_summary = None

    if uploaded_video is None and st.session_state.video_processed:
        # The upload was cleared; discard the previous run's results.
        st.session_state.video_processed = False
        st.session_state.output_video_path = None
        st.session_state.detections_summary = None
        st.warning("Video upload has been cleared. You can upload a new video for processing.")

    if uploaded_video:
        if not st.session_state.video_processed:
            try:
                with st.spinner("Processing video..."):
                    # Persist the upload to a temporary file so OpenCV can read it.
                    tfile = tempfile.NamedTemporaryFile(delete=False)
                    tfile.write(uploaded_video.read())
                    tfile.close()

                    video_cap = cv2.VideoCapture(tfile.name)
                    stframe = st.empty()

                    # Write the annotated output with the same fps and frame size
                    # as the input ('mp4v' is widely writable, though browsers may
                    # need H.264 for inline playback).
                    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                    fps = video_cap.get(cv2.CAP_PROP_FPS)
                    width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                    height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                    output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
                    out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))

                    frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                    progress_bar = st.progress(0)

                    all_detections = []

                    for frame_num in range(frame_count):
                        ret, frame = video_cap.read()
                        if not ret:
                            break

                        # OpenCV delivers BGR frames; YOLOv5 expects RGB.
                        frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        results = model(frame_rgb)

                        # Record this frame's detections for the summary table.
                        results_df = results.pandas().xyxy[0]
                        results_df['frame_num'] = frame_num
                        if not results_df.empty:
                            all_detections.append(results_df)

                        # Annotate, write to the output file, and preview live.
                        annotated_frame = annotate_image(frame_rgb, results)
                        annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
                        out.write(annotated_bgr)
                        stframe.image(annotated_frame, channels="RGB", use_column_width=True)

                        progress_bar.progress((frame_num + 1) / frame_count)

                    video_cap.release()
                    out.release()

                    st.session_state.output_video_path = output_video_path

                    if all_detections:
                        # Aggregate per-class detection counts across all frames.
                        detections_df = pd.concat(all_detections, ignore_index=True)
                        st.session_state.detections_summary = (
                            detections_df.groupby('name').size().reset_index(name='counts')
                        )
                    else:
                        st.session_state.detections_summary = None

                    st.session_state.video_processed = True

                st.success("Video processing complete!")
            except Exception as e:
                st.error(f"An error occurred during video processing: {e}")

        if st.session_state.video_processed:
            try:
                # Offer the annotated video as a download.
                with open(st.session_state.output_video_path, "rb") as video_file:
                    st.download_button(
                        label="Download Annotated Video",
                        data=video_file,
                        file_name="annotated_video.mp4",
                        mime="video/mp4"
                    )

                if st.session_state.detections_summary is not None:
                    st.write("### Detection Summary")
                    st.dataframe(st.session_state.detections_summary)
                else:
                    st.warning("No objects detected in the video for the selected classes.")
            except Exception as e:
                st.error(f"An error occurred while preparing the download: {e}")

    if custom_classes_input.strip().lower() == "all":
        st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")
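

# ---------------------------------------------------------------------------
# UNet semantic segmentation
# ---------------------------------------------------------------------------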
IMG_SIZE = 128


@st.cache_resource
def load_unet_model():
    # 3 input channels (RGB); the checkpoint was trained with 32 output classes.
    model = UNet(n_channels=3, n_classes=32)
    model.load_state_dict(
        torch.load("./model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True),
        strict=False,
    )
    model.eval()
    return model


def preprocess_image(image):
    # Resize to the network's training resolution and add a batch dimension.
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),
        T.ToTensor(),
    ])
    return transform(image).unsqueeze(0)


def postprocess_mask(mask):
    # Threshold the sigmoid probabilities into a binary 0/255 mask.
    mask_np = mask.squeeze().cpu().numpy()
    return (mask_np > 0.5).astype(np.uint8) * 255


def unet():
    try:
        model = load_unet_model()

        st.markdown(
            "<h1 style='text-align: center; font-size: 36px;'>UNet object detection</h1>",
            unsafe_allow_html=True
        )
        st.markdown(
            "<h2 style='text-align: center; font-size: 30px;'>Using UNet - PyTorch</h2>",
            unsafe_allow_html=True
        )

        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            st.write("Processing...")

            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", use_column_width=True)

            input_tensor = preprocess_image(image)

            # Forward pass without gradient tracking.
            with torch.no_grad():
                output = model(input_tensor)
                prediction = torch.sigmoid(output)

            # Only channel 0 of the prediction is thresholded into the mask.
            mask = postprocess_mask(prediction[0, 0])

            st.image(mask, caption="Segmentation Mask", use_column_width=True)
    except Exception as e:
        st.error(f"An error occurred in UNet: {e}")


if 'model_selected' not in st.session_state:
    st.session_state.model_selected = None


def main():
    option = st.radio("Select Model:", ("UNet", "YOLO"))

    if st.button("Choose"):
        st.session_state.model_selected = option
        st.success(f"Selected Model: {st.session_state.model_selected}")

    if st.session_state.model_selected == "UNet":
        unet()
    elif st.session_state.model_selected == "YOLO":
        yolo()


if __name__ == "__main__":
    main()