|
import contextlib
import io
import os
import pathlib
import tempfile

import streamlit as st
import torch
from PIL import Image, ImageOps
import numpy as np
import cv2
import pandas as pd
|
|
|
|
|
# Checkpoints pickled on Windows contain ``pathlib.WindowsPath`` objects,
# which cannot be instantiated on POSIX systems. Aliasing WindowsPath to
# PosixPath lets such checkpoints unpickle there. On Windows itself the
# alias is unnecessary and harmful: ``pathlib.Path()`` would then resolve to
# a PosixPath, which is not usable on Windows.
if os.name != "nt":
    pathlib.WindowsPath = pathlib.PosixPath

st.title("YOLO Object Detection Web App")
|
|
|
|
|
# Sub-class names advertised to the user in the "Available Classes" list.
# "rubbish" is appended separately when the list is displayed.
default_sub_classes = [
    "container", "waste-paper", "plant", "transportation", "kitchenware",
    "rubbish bag", "chair", "wood", "electronics good", "sofa",
    "scrap metal", "carton", "bag", "tarpaulin", "accessory",
    "rubble", "table", "board", "mattress", "beverage",
    "tyre", "nylon", "rack", "styrofoam", "clothes",
    "toy", "furniture", "trolley", "carpet", "plastic cup",
]
|
|
|
|
|
# Seed the per-session video state the first time this session runs the script.
if "video_processed" not in st.session_state:
    for state_key, initial_value in (
        ("video_processed", False),
        ("output_video_path", None),
        ("detections_summary", None),
    ):
        st.session_state[state_key] = initial_value
|
|
|
|
|
@st.cache_resource
def load_model():
    """Load the custom detector weights through the local ``./yolov5`` hub repo.

    The weights come from ``./yolo/best.pt``. Because of ``st.cache_resource``
    the returned model object is created once and shared by reruns and
    sessions, so attribute changes on it (e.g. ``model.classes``) persist.

    Returns:
        The model object produced by ``torch.hub.load``.
    """
    return torch.hub.load(
        './yolov5',
        'custom',
        path='./yolo/best.pt',
        source='local',
        force_reload=False,
    )
|
|
|
model = load_model()

# Normalise the model's class names to an ``{index: name}`` dict. Recent
# YOLOv5 releases expose ``model.names`` as a dict, older ones as a list,
# and the rest of this script calls ``.items()`` / ``.values()`` on it.
_raw_class_names = model.names
if isinstance(_raw_class_names, dict):
    model_class_names = _raw_class_names
else:
    model_class_names = dict(enumerate(_raw_class_names))
|
|
|
|
|
def get_class_indices(class_list, names=None):
    """Resolve class names to model class indices, case-insensitively.

    Args:
        class_list: Iterable of class-name strings requested by the user.
        names: Optional ``{index: name}`` mapping, or a sequence of names
            whose positions are the indices. Defaults to the loaded model's
            ``model_class_names``.

    Returns:
        A tuple ``(indices, not_found)``: the matched indices in request
        order, and the requested names that matched no model class.
    """
    if names is None:
        names = model_class_names
    pairs = names.items() if hasattr(names, "items") else enumerate(names)

    # Build the lowercase lookup once rather than rescanning every model
    # class for each request; setdefault keeps the first index seen for a
    # duplicated name, matching the original first-match-wins scan.
    lookup = {}
    for index, name in pairs:
        lookup.setdefault(name.lower(), index)

    indices = []
    not_found = []
    for cls in class_list:
        index = lookup.get(cls.lower())
        if index is None:
            not_found.append(cls)
        else:
            indices.append(index)
    return indices, not_found
|
|
|
|
|
def annotate_image(frame, results):
    """Draw the detections held by ``results`` and return the first image.

    Args:
        frame: Accepted for call-site symmetry; it is not read.
        results: Detection results exposing ``render()`` and ``ims``.

    Returns:
        ``results.ims[0]`` after ``results.render()`` has run.
    """
    # render() is expected to update results.ims, so call it before reading.
    results.render()
    rendered_images = results.ims
    return rendered_images[0]
|
|
|
|
|
st.markdown("### Available Classes:") |
|
st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**") |
|
|
|
|
|
st.info("By default, the application will detect **rubbish** only.") |
|
|
|
|
|
custom_classes_input = st.text_input( |
|
"Enter classes (comma-separated) or type 'all' to detect everything:", |
|
"" |
|
) |
|
|
|
|
|
all_model_classes = list(model_class_names.values()) |
|
|
|
|
|
if custom_classes_input.strip() == "": |
|
|
|
selected_classes = ['rubbish'] |
|
st.info("No classes entered. Using default class: **rubbish**.") |
|
elif custom_classes_input.strip().lower() == "all": |
|
|
|
selected_classes = all_model_classes |
|
st.info("Detecting **all** available classes.") |
|
else: |
|
|
|
|
|
input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()] |
|
|
|
if 'rubbish' not in [cls.lower() for cls in input_classes]: |
|
selected_classes = input_classes + ['rubbish'] |
|
st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)") |
|
else: |
|
selected_classes = input_classes |
|
st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**") |
|
|
|
|
|
selected_class_indices, not_found_classes = get_class_indices(selected_classes) |
|
|
|
if not_found_classes: |
|
st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**") |
|
|
|
|
|
if selected_class_indices: |
|
|
|
model.classes = selected_class_indices |
|
|
|
|
|
st.header("Image Object Detection") |
|
|
|
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload") |
|
|
|
if uploaded_file is not None: |
|
try: |
|
|
|
image = Image.open(uploaded_file).convert('RGB') |
|
st.image(image, caption="Uploaded Image", use_column_width=True) |
|
st.write("Processing...") |
|
|
|
|
|
results = model(image) |
|
|
|
|
|
results_df = results.pandas().xyxy[0] |
|
|
|
|
|
filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])] |
|
|
|
if filtered_results.empty: |
|
st.warning("No objects detected for the selected classes.") |
|
else: |
|
|
|
st.write("### Detection Results") |
|
st.dataframe(filtered_results) |
|
|
|
|
|
annotated_image = annotate_image(np.array(image), results) |
|
|
|
|
|
annotated_pil = Image.fromarray(annotated_image) |
|
|
|
|
|
st.image(annotated_pil, caption="Annotated Image", use_column_width=True) |
|
|
|
|
|
img_byte_arr = io.BytesIO() |
|
annotated_pil.save(img_byte_arr, format='PNG') |
|
img_byte_arr = img_byte_arr.getvalue() |
|
|
|
|
|
st.download_button( |
|
label="Download Annotated Image", |
|
data=img_byte_arr, |
|
file_name='annotated_image.png', |
|
mime='image/png' |
|
) |
|
except Exception as e: |
|
st.error(f"An error occurred during image processing: {e}") |
|
|
|
|
|
st.header("Video Object Detection") |
|
|
|
uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload") |
|
|
|
if uploaded_video is not None: |
|
|
|
|
|
if st.session_state.get("uploaded_video_name") is None: |
|
st.session_state.uploaded_video_name = uploaded_video.name |
|
print("First time uploaded video" +st.session_state.uploaded_video_name) |
|
elif st.session_state.uploaded_video_name != uploaded_video.name: |
|
st.session_state.uploaded_video_name = uploaded_video.name |
|
print("Another time uploaded video" +st.session_state.uploaded_video_name) |
|
st.session_state.video_processed = False |
|
st.session_state.output_video_path = None |
|
st.session_state.detections_summary = None |
|
print("New uploaded video") |
|
|
|
|
|
if uploaded_video is None and st.session_state.video_processed: |
|
st.session_state.video_processed = False |
|
st.session_state.output_video_path = None |
|
st.session_state.detections_summary = None |
|
st.warning("Video upload has been cleared. You can upload a new video for processing.") |
|
|
|
if uploaded_video:
    if not st.session_state.video_processed:
        video_cap = None
        out = None
        input_path = None
        output_video_path = None
        try:
            with st.spinner("Processing video..."):
                # Keep the upload's extension so OpenCV can identify the
                # container, and use getvalue() so the whole upload is written
                # regardless of the buffer's current read position.
                suffix = os.path.splitext(uploaded_video.name)[1]
                with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
                    input_path = tfile.name
                    tfile.write(uploaded_video.getvalue())

                video_cap = cv2.VideoCapture(input_path)
                if not video_cap.isOpened():
                    raise ValueError("the uploaded video could not be opened")
                stframe = st.empty()

                fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                fps = video_cap.get(cv2.CAP_PROP_FPS)
                if not fps > 0:
                    # Some containers report no frame rate (0 or NaN); the
                    # writer needs a positive one, so assume a common default.
                    fps = 30.0
                width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tmp_out:
                    output_video_path = tmp_out.name
                out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
                if not out.isOpened():
                    raise ValueError("the annotated output video could not be created")

                # CAP_PROP_FRAME_COUNT is only an estimate for many containers
                # (it may be 0 or too small), so it drives the progress bar
                # only; frames are read until the decoder reports the end.
                frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                progress_bar = st.progress(0)

                all_detections = []
                frame_num = 0
                while True:
                    ret, frame = video_cap.read()
                    if not ret:
                        break

                    # OpenCV decodes BGR; the model and st.image get RGB.
                    frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    results = model(frame_rgb)

                    results_df = results.pandas().xyxy[0]
                    results_df['frame_num'] = frame_num
                    if not results_df.empty:
                        all_detections.append(results_df)

                    annotated_frame = annotate_image(frame_rgb, results)
                    out.write(cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR))
                    stframe.image(annotated_frame, channels="RGB", use_column_width=True)

                    frame_num += 1
                    if frame_count > 0:
                        progress_bar.progress(min(frame_num / frame_count, 1.0))

                progress_bar.progress(1.0)

                if all_detections:
                    detections_df = pd.concat(all_detections, ignore_index=True)
                    # Number of detected boxes per class name across all frames.
                    detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
                else:
                    detections_summary = None

                # Publish results only once everything above has succeeded.
                st.session_state.output_video_path = output_video_path
                st.session_state.detections_summary = detections_summary
                st.session_state.video_processed = True

            st.success("Video processing complete!")
        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing the app.
            st.error(f"An error occurred during video processing: {e}")
        finally:
            # Release handles on every path so the output file is flushed and
            # the temp files can be removed.
            if video_cap is not None:
                video_cap.release()
            if out is not None:
                out.release()
            # The uploaded copy is only needed while decoding.
            if input_path is not None:
                with contextlib.suppress(OSError):
                    os.remove(input_path)
            # Drop a half-written output so failed runs don't litter the temp dir.
            if output_video_path is not None and not st.session_state.video_processed:
                with contextlib.suppress(OSError):
                    os.remove(output_video_path)
|
|
|
|
|
# Once a video has been processed, offer the annotated file for download and
# show the per-class detection counts (or a warning when there were none).
if st.session_state.video_processed:
    try:
        with open(st.session_state.output_video_path, "rb") as annotated_file:
            st.download_button(
                label="Download Annotated Video",
                data=annotated_file,
                file_name="annotated_video.mp4",
                mime="video/mp4",
            )

        if st.session_state.detections_summary is None:
            st.warning("No objects detected in the video for the selected classes.")
        else:
            detections_summary = st.session_state.detections_summary
            st.write("### Detection Summary")
            st.dataframe(detections_summary)
    except Exception as e:
        st.error(f"An error occurred while preparing the download: {e}")
|
|
|
|
|
if custom_classes_input.strip().lower() == "all": |
|
st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}") |