from PIL import Image, ImageDraw # Import the model components from unet directory from unet.unet_model import UNet import streamlit as st import plotly.express as px import pandas as pd import numpy as np import torchvision.transforms as T import torch import pathlib import io import cv2 import tempfile # Adjust Path for Local Repository pathlib.WindowsPath = pathlib.PosixPath st.title("Smart city rubbish detection Web Application") def yolo(): st.markdown( "

Yolo object detection

", unsafe_allow_html=True ) st.markdown( "

Using Yolov5

", unsafe_allow_html=True ) # Define the available labels default_sub_classes = [ "container", "waste-paper", "plant", "transportation", "kitchenware", "rubbish bag", "chair", "wood", "electronics good", "sofa", "scrap metal", "carton", "bag", "tarpaulin", "accessory", "rubble", "table", "board", "mattress", "beverage", "tyre", "nylon", "rack", "styrofoam", "clothes", "toy", "furniture", "trolley", "carpet", "plastic cup" ] # Initialize session state for video processing if 'video_processed' not in st.session_state: st.session_state.video_processed = False st.session_state.output_video_path = None st.session_state.detections_summary = None # Cache the model loading to prevent repeated loads @st.cache_resource def load_model(): model = torch.hub.load('./yolov5', 'custom', path='./model/yolo/best.pt', source='local', force_reload=False) return model model = load_model() # Retrieve model class names model_class_names = model.names # Dictionary {index: class_name} # Function to map class names to indices (case-insensitive) def get_class_indices(class_list): indices = [] not_found = [] for cls in class_list: found = False for index, name in model_class_names.items(): if name.lower() == cls.lower(): indices.append(index) found = True break if not found: not_found.append(cls) return indices, not_found # Function to annotate images def annotate_image(frame, results): results.render() # Updates results.ims with the annotated images annotated_frame = results.ims[0] # Get the first (and only) image return annotated_frame # Inform the user about the available labels st.markdown("### Available Classes:") st.markdown("**" + ", ".join(default_sub_classes + ["rubbish"]) + "**") # Inform the user about the default detection st.info("By default, the application will detect **rubbish** only.") # User input for classes, separated by commas (optional) custom_classes_input = st.text_input( "Enter classes (comma-separated) or type 'all' to detect everything:", "" ) # Retrieve all model classes all_model_classes = list(model_class_names.values()) # Determine classes to use based on user input if custom_classes_input.strip() == "": # No input provided; use only 'rubbish' selected_classes = ['rubbish'] st.info("No classes entered. Using default class: **rubbish**.") elif custom_classes_input.strip().lower() == "all": # User chose to detect all classes selected_classes = all_model_classes st.info("Detecting **all** available classes.") else: # User provided specific classes # Split the input string into a list of classes and remove any extra whitespace input_classes = [cls.strip() for cls in custom_classes_input.split(",") if cls.strip()] # Ensure 'rubbish' is included if 'rubbish' not in [cls.lower() for cls in input_classes]: selected_classes = input_classes + ['rubbish'] st.info(f"Detecting the following classes: **{', '.join(selected_classes)}** (Including **rubbish**)") else: selected_classes = input_classes st.info(f"Detecting the following classes: **{', '.join(selected_classes)}**") # Map selected class names to their indices selected_class_indices, not_found_classes = get_class_indices(selected_classes) if not_found_classes: st.warning(f"The following classes were not found in the model and will be ignored: **{', '.join(not_found_classes)}**") # Proceed only if there are valid classes to detect if selected_class_indices: # Set the classes for the model model.classes = selected_class_indices # --------------------- Image Upload and Processing --------------------- st.header("Image Object Detection") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"], key="image_upload") if uploaded_file is not None: try: # Convert the file to a PIL image image = Image.open(uploaded_file).convert('RGB') st.image(image, caption="Uploaded Image", use_column_width=True) st.write("Processing...") # Perform inference results = model(image) # Extract DataFrame from results results_df = results.pandas().xyxy[0] # Filter results to include only selected classes filtered_results = results_df[results_df['name'].str.lower().isin([cls.lower() for cls in selected_classes])] if filtered_results.empty: st.warning("No objects detected for the selected classes.") else: # Display filtered results st.write("### Detection Results") st.dataframe(filtered_results) # Annotate the image annotated_image = annotate_image(np.array(image), results) # Convert annotated image back to PIL format annotated_pil = Image.fromarray(annotated_image) # Display annotated image st.image(annotated_pil, caption="Annotated Image", use_column_width=True) # Convert annotated image to bytes img_byte_arr = io.BytesIO() annotated_pil.save(img_byte_arr, format='PNG') img_byte_arr = img_byte_arr.getvalue() # Add download button st.download_button( label="Download Annotated Image", data=img_byte_arr, file_name='annotated_image.png', mime='image/png' ) except Exception as e: st.error(f"An error occurred during image processing: {e}") # --------------------- Video Upload and Processing --------------------- st.header("Video Object Detection") uploaded_video = st.file_uploader("Choose a video...", type=["mp4", "avi", "mov"], key="video_upload") if uploaded_video is not None: # Check if the uploaded video is different from the previously processed one # Check if the uploaded video first time if st.session_state.get("uploaded_video_name") is None: st.session_state.uploaded_video_name = uploaded_video.name print("First time uploaded video" +st.session_state.uploaded_video_name) elif st.session_state.uploaded_video_name != uploaded_video.name: st.session_state.uploaded_video_name = uploaded_video.name print("Another time uploaded video" +st.session_state.uploaded_video_name) st.session_state.video_processed = False st.session_state.output_video_path = None st.session_state.detections_summary = None print("New uploaded video") # Reset session state if video upload is removed if uploaded_video is None and st.session_state.video_processed: st.session_state.video_processed = False st.session_state.output_video_path = None st.session_state.detections_summary = None st.warning("Video upload has been cleared. You can upload a new video for processing.") if uploaded_video: if not st.session_state.video_processed: try: with st.spinner("Processing video..."): # Save uploaded video to a temporary file tfile = tempfile.NamedTemporaryFile(delete=False) tfile.write(uploaded_video.read()) tfile.close() # Open the video file video_cap = cv2.VideoCapture(tfile.name) stframe = st.empty() # Placeholder for displaying video frames # Initialize VideoWriter for saving the output video fourcc = cv2.VideoWriter_fourcc(*'mp4v') fps = video_cap.get(cv2.CAP_PROP_FPS) width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) output_video_path = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4').name out = cv2.VideoWriter(output_video_path, fourcc, fps, (width, height)) frame_count = int(video_cap.get(cv2.CAP_PROP_FRAME_COUNT)) progress_bar = st.progress(0) # Initialize list to collect all detections all_detections = [] for frame_num in range(frame_count): ret, frame = video_cap.read() # Read a frame from the video if not ret: break # Convert frame to RGB frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) # Perform inference results = model(frame_rgb) # Extract DataFrame from results results_df = results.pandas().xyxy[0] results_df['frame_num'] = frame_num # Optional: Add frame number for reference # Append detections to the list if not results_df.empty: all_detections.append(results_df) # Annotate the frame with detections annotated_frame = annotate_image(frame_rgb, results) # Convert annotated frame back to BGR for VideoWriter annotated_bgr = cv2.cvtColor(annotated_frame, cv2.COLOR_RGB2BGR) # Write the annotated frame to the output video out.write(annotated_bgr) # Display the annotated frame in Streamlit stframe.image(annotated_frame, channels="RGB", use_column_width=True) # Update progress bar progress_percent = (frame_num + 1) / frame_count progress_bar.progress(progress_percent) video_cap.release() # Release the video capture object out.release() # Release the VideoWriter object # Save processed video path and detections summary to session state st.session_state.output_video_path = output_video_path if all_detections: # Concatenate all detections into a single DataFrame detections_df = pd.concat(all_detections, ignore_index=True) # Optional: Group by class name and count detections detections_summary = detections_df.groupby('name').size().reset_index(name='counts') st.session_state.detections_summary = detections_summary else: st.session_state.detections_summary = None # Mark video as processed st.session_state.video_processed = True # st.session_state.uploaded_video_name = uploaded_video.name st.success("Video processing complete!") except Exception as e: st.error(f"An error occurred during video processing: {e}") # Display download button and detection summary if processed if st.session_state.video_processed: try: # Create a download button for the annotated video with open(st.session_state.output_video_path, "rb") as video_file: st.download_button( label="Download Annotated Video", data=video_file, file_name="annotated_video.mp4", mime="video/mp4" ) # Display detection table if there are detections if st.session_state.detections_summary is not None: detections_summary = st.session_state.detections_summary st.write("### Detection Summary") st.dataframe(detections_summary) else: st.warning("No objects detected in the video for the selected classes.") except Exception as e: st.error(f"An error occurred while preparing the download: {e}") # Optionally, display all available classes when 'all' is selected if custom_classes_input.strip().lower() == "all": st.info(f"The model is set to detect **all** available classes: {', '.join(all_model_classes)}") # Unet model training configuration # Constants IMG_SIZE = 128 # Resize dimension for the input image # Load model function @st.cache_resource def load_model(): model = UNet(n_channels=3, n_classes=32) # Adjust according to your model setup model.load_state_dict(torch.load("./model/unet/checkpoint_epoch5.pth", map_location="cpu", weights_only=True), strict=False) model.eval() return model # Function to preprocess the image def preprocess_image(image): transform = T.Compose([ T.Resize((IMG_SIZE, IMG_SIZE)), # Resize to match model input size T.ToTensor(), # Convert to tensor ]) image_tensor = transform(image).unsqueeze(0) # Add batch dimension return image_tensor # Function to postprocess the model output for display def postprocess_mask(mask): # Convert mask to a numpy array and scale to 0-255 mask_np = mask.squeeze().cpu().numpy() # Remove batch and channel dimensions mask_np = (mask_np > 0.5).astype(np.uint8) * 255 # Binarize and scale to 0-255 return mask_np def unet(): try: # Load the model model = load_model() st.markdown( "

Unet object detection

", unsafe_allow_html=True ) st.markdown( "

Using Unet - Pytorch

", unsafe_allow_html=True ) # Display the file upload widget uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"]) if uploaded_file is not None: st.write("Processing...") # Open and display the uploaded image image = Image.open(uploaded_file).convert("RGB") st.image(image, caption="Uploaded Image", use_column_width=True) # Preprocess the image input_tensor = preprocess_image(image) # Perform inference with torch.no_grad(): # Disable gradient calculation for inference output = model(input_tensor) prediction = torch.sigmoid(output) # Apply sigmoid to get probabilities # Post-process the mask for display mask = postprocess_mask(prediction[0, 0]) # Get the mask from the first batch item # Display the segmentation mask st.image(mask, caption="Segmentation Mask", use_column_width=True) except Exception as e: st.error(f"An error occurred in Unet: {e}") # Main page if 'model_selected' not in st.session_state: st.session_state.model_selected = None def main(): # Radio button for model selection with consistent casing option = st.radio("Select Model:", ("Unet", "YOLO")) # Submit button to confirm selection if st.button("Choose"): st.session_state.model_selected = option st.success(f"Selected Model: {st.session_state.model_selected}") # Render the selected model's interface based on session state if st.session_state.model_selected == "Unet": unet() elif st.session_state.model_selected == "YOLO": yolo() if __name__ == "__main__": main()