"""Object detection demo with MobileNet SSD. This model and code are based on https://github.com/robmarkcole/object-detection-app """ import logging import queue from pathlib import Path from typing import List, NamedTuple import av import cv2 import numpy as np import streamlit as st from streamlit_webrtc import WebRtcMode, webrtc_streamer from sample_utils.download import download_file HERE = Path(__file__).parent ROOT = HERE.parent logger = logging.getLogger(__name__) MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel" # noqa: E501 MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel" PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt" # noqa: E501 PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt" CLASSES = [ "background", "aeroplane", "bicycle", "bird", "boat", "bottle", "bus", "car", "cat", "chair", "cow", "diningtable", "dog", "horse", "motorbike", "person", "pottedplant", "sheep", "sofa", "train", "tvmonitor", ] @st.experimental_singleton # type: ignore # See https://github.com/python/mypy/issues/7781, https://github.com/python/mypy/issues/12566 # noqa: E501 def generate_label_colors(): return np.random.uniform(0, 255, size=(len(CLASSES), 3)) COLORS = generate_label_colors() download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564) download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353) DEFAULT_CONFIDENCE_THRESHOLD = 0.5 class Detection(NamedTuple): name: str prob: float # Session-specific caching cache_key = "object_detection_dnn" if cache_key in st.session_state: net = st.session_state[cache_key] else: net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)) st.session_state[cache_key] = net streaming_placeholder = st.empty() confidence_threshold = st.slider( "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05 ) def _annotate_image(image, detections): # loop over the detections (h, w) = image.shape[:2] result: List[Detection] = [] for i in np.arange(0, detections.shape[2]): confidence = detections[0, 0, i, 2] if confidence > confidence_threshold: # extract the index of the class label from the `detections`, # then compute the (x, y)-coordinates of the bounding box for # the object idx = int(detections[0, 0, i, 1]) box = detections[0, 0, i, 3:7] * np.array([w, h, w, h]) (startX, startY, endX, endY) = box.astype("int") name = CLASSES[idx] result.append(Detection(name=name, prob=float(confidence))) # display the prediction label = f"{name}: {round(confidence * 100, 2)}%" cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2) y = startY - 15 if startY - 15 > 15 else startY + 15 cv2.putText( image, label, (startX, y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, COLORS[idx], 2, ) return image, result result_queue: queue.Queue = ( queue.Queue() ) # TODO: A general-purpose shared state object may be more useful. def callback(frame: av.VideoFrame) -> av.VideoFrame: image = frame.to_ndarray(format="bgr24") blob = cv2.dnn.blobFromImage( cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5 ) net.setInput(blob) detections = net.forward() annotated_image, result = _annotate_image(image, detections) # NOTE: This `recv` method is called in another thread, # so it must be thread-safe. 
    result_queue.put(result)  # TODO:

    return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")


with streaming_placeholder.container():
    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
        video_frame_callback=callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )

if st.checkbox("Show the detected labels", value=True):
    if webrtc_ctx.state.playing:
        labels_placeholder = st.empty()
        # NOTE: The object-detecting video transformation and this loop that
        # displays the resulting labels run asynchronously in different threads.
        # Therefore, the rendered video frames and the labels shown here
        # are not strictly synchronized.
        while True:
            try:
                result = result_queue.get(timeout=1.0)
            except queue.Empty:
                result = None
            labels_placeholder.table(result)

st.markdown(
    "This demo uses a model and code from "
    "https://github.com/robmarkcole/object-detection-app. "
    "Many thanks to the project."
)
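# Usage note (a sketch, not part of the original app): assuming this file is saved
# as `object_detection.py` inside a project that also provides the `sample_utils`
# package imported above, it can be launched with:
#
#     pip install streamlit streamlit-webrtc opencv-python-headless
#     streamlit run object_detection.py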