# Hugging Face Space file view residue (kept as a comment so the file parses):
# author whitphx (HF Staff) · last commit 0325cdc "Update app.py" · 5.13 kB
"""Object detection demo with MobileNet SSD.
This model and code are based on
https://github.com/robmarkcole/object-detection-app
"""
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from sample_utils.download import download_file
logger = logging.getLogger(__name__)

# Directory holding this script, and its parent directory; the downloaded
# model files are cached under ROOT/models.
HERE = Path(__file__).parent
ROOT = HERE.parent

# Remote sources of the MobileNet SSD Caffe weights and network definition.
MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501

# Local paths the two files above are downloaded to.
MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"
# Human-readable label for every class id the detector emits. The id read from
# a detection row is used directly as an index into this list, so the order
# matters; index 0 is "background".
CLASSES = [
    "background",                                          # 0
    "aeroplane", "bicycle", "bird", "boat", "bottle",      # 1-5
    "bus", "car", "cat", "chair", "cow",                   # 6-10
    "diningtable", "dog", "horse", "motorbike", "person",  # 11-15
    "pottedplant", "sheep", "sofa", "train", "tvmonitor",  # 16-20
]
@st.experimental_singleton  # type: ignore # See https://github.com/python/mypy/issues/7781, https://github.com/python/mypy/issues/12566 # noqa: E501
def generate_label_colors():
    """Return a random drawing colour for every entry of ``CLASSES``.

    The result is an array with one row per class and three channel values,
    each drawn uniformly from [0, 255). The Streamlit singleton decorator
    caches the result, so reruns reuse the same palette instead of re-rolling
    colours.
    """
    palette_shape = (len(CLASSES), 3)
    return np.random.uniform(0, 255, size=palette_shape)


# Row i is the colour used to draw boxes/labels for CLASSES[i].
COLORS = generate_label_colors()
# Initial value of the "Confidence threshold" slider.
DEFAULT_CONFIDENCE_THRESHOLD = 0.5

# Fetch the model weights and the network definition into ROOT/models.
# NOTE(review): expected_size is the byte size handed to download_file;
# presumably used to validate / skip an existing download — see
# sample_utils.download to confirm.
download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
class Detection(NamedTuple):
    """One detected object as shown in the labels table: label and confidence."""

    name: str  # entry of CLASSES for the detected class id
    prob: float  # model confidence for this detection, as a plain float
# Session-specific caching: each browser session loads the Caffe network once,
# stores it in its session state and reuses it on every rerun of this script.
cache_key = "object_detection_dnn"
if cache_key not in st.session_state:
    st.session_state[cache_key] = cv2.dnn.readNetFromCaffe(
        str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
    )
net = st.session_state[cache_key]

# Reserve the slot for the video widget before the slider is created, so the
# stream (filled in later, once the callback exists) is placed above it.
streaming_placeholder = st.empty()
confidence_threshold = st.slider(
    "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
)
def _annotate_image(image, detections):
    """Draw confident detections onto *image* and collect them as ``Detection``s.

    Args:
        image: BGR frame array; boxes and captions are drawn on it in place.
        detections: raw network output. Each row ``detections[0, 0, i]`` is
            read as ``[_, class_id, confidence, x1, y1, x2, y2]``, with the box
            corners expressed relative to the image width/height.

    Returns:
        ``(image, result)``: the same array that was passed in (now annotated)
        and a list with one ``Detection`` per row whose confidence is strictly
        above the module-level ``confidence_threshold``.
    """
    height, width = image.shape[:2]
    # Scales relative (x1, y1, x2, y2) corners to pixel coordinates.
    to_pixels = np.array([width, height, width, height])
    result: List[Detection] = []

    for row in detections[0, 0]:
        confidence = row[2]
        # Written as `not (>)` rather than `<=` so NaN confidences are skipped.
        if not (confidence > confidence_threshold):
            continue

        class_id = int(row[1])
        name = CLASSES[class_id]
        color = COLORS[class_id]
        start_x, start_y, end_x, end_y = (row[3:7] * to_pixels).astype("int")
        result.append(Detection(name=name, prob=float(confidence)))

        caption = f"{name}: {round(confidence * 100, 2)}%"
        # Put the caption above the box unless it would run off the top edge.
        text_y = start_y - 15 if start_y - 15 > 15 else start_y + 15
        cv2.rectangle(image, (start_x, start_y), (end_x, end_y), color, 2)
        cv2.putText(
            image,
            caption,
            (start_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            0.5,
            color,
            2,
        )

    return image, result
# Hand-off channel for detection results from the video callback (which runs in
# a worker thread) to the script thread that renders the labels table.
# maxsize=1 keeps only the newest frame's results. The callback produces one
# result per frame, so an unbounded queue would grow without limit whenever
# nobody drains it (e.g. while "Show the detected labels" is unchecked), and the
# displayed labels would lag ever further behind the video.
# TODO: A general-purpose shared state object may be more useful.
result_queue: queue.Queue = queue.Queue(maxsize=1)


def _put_latest(q: queue.Queue, item) -> None:
    """Put *item* into *q* without blocking, discarding a stale entry if full."""
    while True:
        try:
            q.put_nowait(item)
            return
        except queue.Full:
            try:
                # Drop the older, unread result to make room for the newer one.
                q.get_nowait()
            except queue.Empty:
                # The consumer emptied the queue in the meantime; retry the put.
                pass


def callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Run object detection on one video frame and return the annotated frame.

    Registered as ``video_frame_callback`` on ``webrtc_streamer``. It is called
    from another thread than the script, so it must be thread-safe. It shares
    results only through ``result_queue`` (``queue.Queue`` is thread-safe), and
    puts into it without ever blocking, so the video pipeline cannot stall.
    """
    image = frame.to_ndarray(format="bgr24")
    # Resize to the 300x300 network input. blobFromImage computes
    # (pixel - 127.5) * 0.007843 (~1/127.5), mapping pixels to about [-1, 1].
    blob = cv2.dnn.blobFromImage(
        cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
    )
    net.setInput(blob)
    detections = net.forward()
    annotated_image, result = _annotate_image(image, detections)

    _put_latest(result_queue, result)
    return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")
# Fill the placeholder reserved earlier with the WebRTC widget: video is sent
# from the browser, passed through `callback`, and the annotated frames are
# streamed back (SENDRECV). Audio is not requested.
with streaming_placeholder.container():
    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        video_frame_callback=callback,
        async_processing=True,
        media_stream_constraints={"video": True, "audio": False},
        rtc_configuration={
            "iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}],
        },
    )
# `and` short-circuits, so the stream state is only inspected when the
# checkbox is ticked — same evaluation order as two nested `if`s.
if st.checkbox("Show the detected labels", value=True) and webrtc_ctx.state.playing:
    labels_placeholder = st.empty()
    # NOTE: Frames are annotated in the video callback's thread while this loop
    # runs in the script thread; the two run asynchronously, so the labels
    # shown here are not strictly synchronized with the rendered video frames.
    # The loop never exits on its own.
    while True:
        try:
            result = result_queue.get(timeout=1.0)
        except queue.Empty:
            # No result within a second: pass None to the table instead.
            result = None
        labels_placeholder.table(result)
# Credit the upstream project. This is reached only when the labels loop
# above is not running, since that loop has no exit.
st.markdown(
    "This demo uses a model and code from https://github.com/robmarkcole/"
    "object-detection-app. Many thanks to the project."
)