import io
import pathlib
import tempfile

import cv2
import numpy as np
import pandas as pd
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image

# Import the model components from the unet directory
from unet.unet_model import UNet
# Remap WindowsPath to PosixPath so checkpoints pickled on Windows load on this system
pathlib.WindowsPath = pathlib.PosixPath
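# (YOLOv5 checkpoints store pathlib objects from the training machine; without this
# patch, torch.load on Linux/macOS typically raises
# "cannot instantiate 'WindowsPath' on your system".)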
st.title("Smart City Rubbish Detection Web Application")
def yolo():
    st.markdown(
        "<h1 style='text-align: center; font-size: 36px;'>YOLO Object Detection</h1>",
        unsafe_allow_html=True
    )
    st.markdown(
        "<h2 style='text-align: center; font-size: 30px;'>Using YOLOv5</h2>",
        unsafe_allow_html=True
    )
    # Define the available labels
    default_sub_classes = [
        "container",
        "waste-paper",
        "plant",
        "transportation",
        "kitchenware",
        "rubbish bag",
        "chair",
        "wood",
        "electronics good",
        "sofa",
        "scrap metal",
        "carton",
        "bag",
        "tarpaulin",
        "accessory",
        "rubble",
        "table",
        "board",
        "mattress",
        "beverage",
        "tyre",
        "nylon",
        "rack",
        "styrofoam",
        "clothes",
        "toy",
        "furniture",
        "trolley",
        "carpet",
        "plastic cup"
    ]
    # Initialize session state for video processing
    if 'video_processed' not in st.session_state:
        st.session_state.video_processed = False
        st.session_state.output_video_path = None
        st.session_state.detections_summary = None
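    # Streamlit reruns this script top-to-bottom on every interaction;
    # st.session_state is what persists these flags across reruns.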
    # Cache the model load so Streamlit reruns don't reload the weights
    @st.cache_resource
    def load_yolo_model():
        model = torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt', source='local', force_reload=False)
        return model
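    # Note: source='local' makes torch.hub.load read the hubconf.py in ./yolov5,
    # so that directory is expected to be a local clone of the YOLOv5 repository.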
    model = load_yolo_model()
    # Retrieve model class names
    model_class_names = model.names  # Dictionary {index: class_name}
    # Function to map class names to indices (case-insensitive)
    def get_class_indices(class_list):
        indices = []
        not_found = []
        for cls in class_list:
            found = False
            for index, name in model_class_names.items():
                if name.lower() == cls.lower():
                    indices.append(index)
                    found = True
                    break
            if not found:
                not_found.append(cls)
        return indices, not_found
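    # For example (hypothetical indices): get_class_indices(["Rubbish", "boat"])
    # would return ([7], ["boat"]) if "rubbish" is class 7 and "boat" is not in the model.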
    # Function to annotate images
    def annotate_image(frame, results):
        results.render()  # Updates results.ims with the annotated images
        annotated_frame = results.ims[0]  # Get the first (and only) image
        return annotated_frame
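    # results.render() draws boxes in place on results.ims, so each results object
    # should be annotated only once; the frame argument is currently unused.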
    # Inform the user about the available labels
    st.markdown("### Available Classes:")
    st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**")
    # Inform the user about the default detection
    st.info("By default, the application will detect **rubbish** only.")
    # User input for classes, separated by commas (optional)
    custom_classes_input = st.text_input(
        "Enter classes (comma-separated) or type 'all' to detect everything:",
        ""
    )
    # Retrieve all model classes
    all_model_classes = list(model_class_names.values())
    # Determine classes to use based on user input
    if custom_classes_input.strip() == "":
        # No input provided; use only 'rubbish'
        selected_classes = ['rubbish']
        st.info("No classes entered. Using default class: **rubbish**.")
    elif custom_classes_input.strip().lower() == "all":
        # User chose to detect all classes
        selected_classes = all_model_classes
        st.info("Detecting **all** available classes.")
    else:
        # User provided specific classes; split on commas and strip whitespace
        input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()]
        # Ensure 'rubbish' is always included
        if 'rubbish' not in [cls.lower() for cls in input_classes]:
            selected_classes = input_classes + ['rubbish']
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (including **rubbish**)")
        else:
            selected_classes = input_classes
            st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**")
    # Map selected class names to their indices
    selected_class_indices, not_found_classes = get_class_indices(selected_classes)
    if not_found_classes:
        st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**")
    # Proceed only if there are valid classes to detect
    if selected_class_indices:
        # Set the classes for the model
        model.classes = selected_class_indices
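        # Assigning .classes on a YOLOv5 hub model filters predictions inside the
        # AutoShape wrapper, so only the selected class indices are returned.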
        # --------------------- Image Upload and Processing ---------------------
        st.header("Image Object Detection")
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload")
        if uploaded_file is not None:
            try:
                # Convert the file to a PIL image
                image = Image.open(uploaded_file).convert('RGB')
                st.image(image, caption="Uploaded Image", use_column_width=True)
                st.write("Processing...")
                # Perform inference
                results = model(image)
                # Extract DataFrame from results
                results_df = results.pandas().xyxy[0]
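                # Each row holds xmin, ymin, xmax, ymax, confidence, class, name
                # (the standard YOLOv5 .pandas().xyxy format).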
                # Filter results to include only the selected classes
                filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])]
                if filtered_results.empty:
                    st.warning("No objects detected for the selected classes.")
                else:
                    # Display filtered results
                    st.write("### Detection Results")
                    st.dataframe(filtered_results)
                    # Annotate the image
                    annotated_image = annotate_image(np.array(image), results)
                    # Convert annotated image back to PIL format
                    annotated_pil = Image.fromarray(annotated_image)
                    # Display annotated image
                    st.image(annotated_pil, caption="Annotated Image", use_column_width=True)
                    # Convert annotated image to bytes
                    img_byte_arr = io.BytesIO()
                    annotated_pil.save(img_byte_arr, format='PNG')
                    img_byte_arr = img_byte_arr.getvalue()
                    # Add download button
                    st.download_button(
                        label="Download Annotated Image",
                        data=img_byte_arr,
                        file_name='annotated_image.png',
                        mime='image/png'
                    )
            except Exception as e:
                st.error(f"An error occurred during image processing: {e}")
        # --------------------- Video Upload and Processing ---------------------
        st.header("Video Object Detection")
        uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload")
        if uploaded_video is not None:
            # Track the uploaded file name so a new upload triggers reprocessing
            if st.session_state.get("uploaded_video_name") is None:
                st.session_state.uploaded_video_name = uploaded_video.name
                print("First uploaded video: " + st.session_state.uploaded_video_name)
            elif st.session_state.uploaded_video_name != uploaded_video.name:
                st.session_state.uploaded_video_name = uploaded_video.name
                print("New uploaded video: " + st.session_state.uploaded_video_name)
                st.session_state.video_processed = False
                st.session_state.output_video_path = None
                st.session_state.detections_summary = None
        # Reset session state if the video upload is removed
        if uploaded_video is None and st.session_state.video_processed:
            st.session_state.video_processed = False
            st.session_state.output_video_path = None
            st.session_state.detections_summary = None
            st.warning("Video upload has been cleared. You can upload a new video for processing.")
        if uploaded_video:
            if not st.session_state.video_processed:
                try:
                    with st.spinner("Processing video..."):
                        # Save uploaded video to a temporary file
                        tfile = tempfile.NamedTemporaryFile(delete=False)
                        tfile.write(uploaded_video.read())
                        tfile.close()
                        # Open the video file
                        video_cap = cv2.VideoCapture(tfile.name)
                        stframe = st.empty()  # Placeholder for displaying video frames
                        # Initialize VideoWriter for saving the output video
                        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
                        fps = video_cap.get(cv2.CAP_PROP_FPS)
                        width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                        height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                        output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name
                        out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height))
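                        # 'mp4v' (MPEG-4 Part 2) is widely supported by OpenCV builds;
                        # inline browser playback may require H.264 ('avc1'), which
                        # depends on the local OpenCV/FFmpeg build.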
                        frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT))
                        progress_bar = st.progress(0)
                        # Initialize list to collect all detections
                        all_detections = []
                        for frame_num in range(frame_count):
                            ret, frame = video_cap.read()  # Read a frame from the video
                            if not ret:
                                break
                            # Convert frame to RGB
                            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                            # Perform inference
                            results = model(frame_rgb)
                            # Extract DataFrame from results
                            results_df = results.pandas().xyxy[0]
                            results_df['frame_num'] = frame_num  # Add frame number for reference
                            # Append detections to the list
                            if not results_df.empty:
                                all_detections.append(results_df)
                            # Annotate the frame with detections
                            annotated_frame = annotate_image(frame_rgb, results)
                            # Convert annotated frame back to BGR for VideoWriter
                            annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR)
                            # Write the annotated frame to the output video
                            out.write(annotated_bgr)
                            # Display the annotated frame in Streamlit
                            stframe.image(annotated_frame, channels="RGB", use_column_width=True)
                            # Update progress bar
                            progress_percent = (frame_num + 1) / frame_count
                            progress_bar.progress(progress_percent)
                        video_cap.release()  # Release the video capture object
                        out.release()  # Release the VideoWriter object
                        # Save processed video path and detections summary to session state
                        st.session_state.output_video_path = output_video_path
                        if all_detections:
                            # Concatenate all detections into a single DataFrame
                            detections_df = pd.concat(all_detections, ignore_index=True)
                            # Group by class name and count detections
                            detections_summary = detections_df.groupby('name').size().reset_index(name='counts')
                            st.session_state.detections_summary = detections_summary
                        else:
                            st.session_state.detections_summary = None
                        # Mark video as processed
                        st.session_state.video_processed = True
                    st.success("Video processing complete!")
                except Exception as e:
                    st.error(f"An error occurred during video processing: {e}")
            # Display download button and detection summary if processed
            if st.session_state.video_processed:
                try:
                    # Create a download button for the annotated video
                    with open(st.session_state.output_video_path, "rb") as video_file:
                        st.download_button(
                            label="Download Annotated Video",
                            data=video_file,
                            file_name="annotated_video.mp4",
                            mime="video/mp4"
                        )
                    # Display detection table if there are detections
                    if st.session_state.detections_summary is not None:
                        detections_summary = st.session_state.detections_summary
                        st.write("### Detection Summary")
                        st.dataframe(detections_summary)
                    else:
                        st.warning("No objects detected in the video for the selected classes.")
                except Exception as e:
                    st.error(f"An error occurred while preparing the download: {e}")
    # Optionally, display all available classes when 'all' is selected
    if custom_classes_input.strip().lower() == "all":
        st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}")
# UNet model inference configuration
# Constants
IMG_SIZE = 128  # Resize dimension for the input image

# Load model function
@st.cache_resource
def load_unet_model():
    model = UNet(n_channels=3, n_classes=32)  # Adjust according to your model setup
    model.load_state_dict(torch.load("./model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True), strict=False)
    model.eval()
    return model
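# Note: strict=False tolerates state_dict keys that don't match the UNet definition,
# and weights_only=True restricts torch.load to tensor data (recent PyTorch versions).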
# Function to preprocess the image
def preprocess_image(image):
    transform = T.Compose([
        T.Resize((IMG_SIZE, IMG_SIZE)),  # Resize to match model input size
        T.ToTensor(),  # Convert to tensor
    ])
    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return image_tensor
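# The resulting tensor has shape (1, 3, IMG_SIZE, IMG_SIZE), the NCHW layout
# the UNet expects.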
# Function to postprocess the model output for display
def postprocess_mask(mask):
    # Convert mask to a numpy array and scale to 0-255
    mask_np = mask.squeeze().cpu().numpy()  # Remove batch and channel dimensions
    mask_np = (mask_np > 0.5).astype(np.uint8) * 255  # Binarize at 0.5 and scale to 0-255
    return mask_np
def unet():
    try:
        # Load the model
        model = load_unet_model()
        st.markdown(
            "<h1 style='text-align: center; font-size: 36px;'>UNet Segmentation</h1>",
            unsafe_allow_html=True
        )
        st.markdown(
            "<h2 style='text-align: center; font-size: 30px;'>Using UNet - PyTorch</h2>",
            unsafe_allow_html=True
        )
        # Display the file upload widget
        uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
        if uploaded_file is not None:
            st.write("Processing...")
            # Open and display the uploaded image
            image = Image.open(uploaded_file).convert("RGB")
            st.image(image, caption="Uploaded Image", use_column_width=True)
            # Preprocess the image
            input_tensor = preprocess_image(image)
            # Perform inference
            with torch.no_grad():  # Disable gradient calculation for inference
                output = model(input_tensor)
                prediction = torch.sigmoid(output)  # Apply sigmoid to get per-channel probabilities
            # Post-process the mask for display
            mask = postprocess_mask(prediction[0, 0])  # First batch item, first class channel
            # Display the segmentation mask
            st.image(mask, caption="Segmentation Mask", use_column_width=True)
    except Exception as e:
        st.error(f"An error occurred in UNet: {e}")
# Main page
if 'model_selected' not in st.session_state:
    st.session_state.model_selected = None

def main():
    # Radio button for model selection with consistent casing
    option = st.radio("Select Model:", ("Unet", "YOLO"))
    # Submit button to confirm selection
    if st.button("Choose"):
        st.session_state.model_selected = option
        st.success(f"Selected Model: {st.session_state.model_selected}")
    # Render the selected model's interface based on session state
    if st.session_state.model_selected == "Unet":
        unet()
    elif st.session_state.model_selected == "YOLO":
        yolo()

if __name__ == "__main__":
    main()