"""Streamlit demo app showcasing several streamlit-webrtc usage patterns.

Pages offered in the sidebar:

* real time object detection with MobileNet SSD (sendrecv)
* real time video transforms with simple OpenCV filters (sendrecv)
* streaming a server-side media file to the browser (recvonly)
* receiving frames from the browser and rendering via ``st.image`` (sendonly)
* plain video loopback (sendrecv)
"""
import logging
import logging.handlers
import queue
import urllib.request
from pathlib import Path

try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal  # type: ignore

import av
import cv2
import numpy as np
import PIL
import PIL.Image  # `import PIL` alone does not load the `Image` submodule.
import streamlit as st
from aiortc.contrib.media import MediaPlayer
from streamlit_webrtc import (
    ClientSettings,
    VideoTransformerBase,
    WebRtcMode,
    webrtc_streamer,
)

HERE = Path(__file__).parent

logger = logging.getLogger(__name__)


# This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
def download_file(url, download_to: Path, expected_size=None):
    """Download ``url`` to ``download_to`` while showing progress in the UI.

    The download is skipped when the destination already exists and either
    its size equals ``expected_size`` or, when no expected size is given,
    the user does not press the "Download again?" button.

    Data is first written to a sibling ``<name>.part`` file and moved into
    place only after the transfer finishes, so an interrupted download never
    leaves a truncated file that a later run would treat as complete.

    Args:
        url: Location of the resource to fetch.
        download_to: Destination path; parent directories are created.
        expected_size: Optional size in bytes used to validate an existing
            file (and as the progress total if the server sends no
            ``Content-Length`` header).
    """
    # Don't download the file twice.
    # (If possible, verify the download using the file length.)
    if download_to.exists():
        if expected_size:
            if download_to.stat().st_size == expected_size:
                return
        else:
            st.info(f"{url} is already downloaded.")
            if not st.button("Download again?"):
                return

    download_to.parent.mkdir(parents=True, exist_ok=True)
    partial_path = download_to.with_name(download_to.name + ".part")

    # These are handles to two visual elements to animate.
    weights_warning, progress_bar = None, None
    try:
        weights_warning = st.warning("Downloading %s..." % url)
        progress_bar = st.progress(0)
        # Open the URL before touching the file system so that a failed
        # request does not create an empty destination file.
        with urllib.request.urlopen(url) as response:
            # The header may be absent (e.g. chunked transfer encoding).
            content_length = response.info()["Content-Length"]
            length = int(content_length) if content_length else 0
            if length <= 0 and expected_size:
                length = expected_size
            counter = 0
            MEGABYTES = 2.0 ** 20.0
            with open(partial_path, "wb") as output_file:
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)

                    # We perform animation by overwriting the elements.
                    if length > 0:
                        weights_warning.warning(
                            "Downloading %s... (%6.2f/%6.2f MB)"
                            % (url, counter / MEGABYTES, length / MEGABYTES)
                        )
                        progress_bar.progress(min(counter / length, 1.0))
                    else:
                        # Total size unknown: report only the bytes received.
                        weights_warning.warning(
                            "Downloading %s... (%6.2f MB)"
                            % (url, counter / MEGABYTES)
                        )

        if expected_size and counter != expected_size:
            logger.warning(
                "Downloaded %s has %d bytes, expected %d",
                url,
                counter,
                expected_size,
            )
        # Publish the finished file in one step.
        partial_path.replace(download_to)
    # Finally, we remove these visual elements by calling .empty().
    finally:
        # Only present here if the download did not complete.
        if partial_path.exists():
            partial_path.unlink()
        if weights_warning is not None:
            weights_warning.empty()
        if progress_bar is not None:
            progress_bar.empty()


def main():
    """Render the page selector and dispatch to the chosen demo page."""
    st.header("WebRTC demo")

    object_detection_page = "Real time object detection (sendrecv)"
    video_filters_page = (
        "Real time video transform with simple OpenCV filters (sendrecv)"
    )
    streaming_page = (
        "Consuming media files on server-side and streaming it to browser (recvonly)"
    )
    sendonly_page = "WebRTC is sendonly and images are shown via st.image() (sendonly)"
    loopback_page = "Simple video loopback (sendrecv)"

    # Insertion order defines the order of the options in the select box.
    pages = {
        object_detection_page: app_object_detection,
        video_filters_page: app_video_filters,
        streaming_page: app_streaming,
        sendonly_page: app_sendonly,
        loopback_page: app_loopback,
    }

    app_mode = st.sidebar.selectbox("Choose the app mode", list(pages))
    st.subheader(app_mode)

    pages[app_mode]()


def app_loopback():
    """ Simple video loopback """
    webrtc_streamer(
        key="loopback",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=None,  # NoOp
    )


def app_video_filters():
    """ Video transforms with OpenCV """

    class OpenCVVideoTransformer(VideoTransformerBase):
        """Apply the filter selected by ``type`` to every incoming frame."""

        # Name of the active filter; assigned from the radio widget below.
        type: Literal["noop", "cartoon", "edges", "rotate"]

        def __init__(self) -> None:
            self.type = "noop"

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the filtered frame as a BGR image array."""
            img = frame.to_ndarray(format="bgr24")

            if self.type == "noop":
                pass
            elif self.type == "cartoon":
                rows, cols = img.shape[:2]

                # prepare color
                img_color = cv2.pyrDown(cv2.pyrDown(img))
                for _ in range(6):
                    img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
                img_color = cv2.pyrUp(cv2.pyrUp(img_color))
                # Down/up-sampling can change the size when a dimension is
                # not a multiple of 4; bitwise_and needs equal shapes.
                if img_color.shape[:2] != (rows, cols):
                    img_color = cv2.resize(img_color, (cols, rows))

                # prepare edges (the frame is BGR, so use the BGR codes)
                img_edges = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img_edges = cv2.adaptiveThreshold(
                    cv2.medianBlur(img_edges, 7),
                    255,
                    cv2.ADAPTIVE_THRESH_MEAN_C,
                    cv2.THRESH_BINARY,
                    9,
                    2,
                )
                img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2BGR)

                # combine color and edges
                img = cv2.bitwise_and(img_color, img_edges)
            elif self.type == "edges":
                # perform edge detection
                img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
            elif self.type == "rotate":
                # rotate image by an angle proportional to the frame time
                rows, cols, _ = img.shape
                M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
                img = cv2.warpAffine(img, M, (cols, rows))

            return img

    webrtc_ctx = webrtc_streamer(
        key="opencv-filter",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=OpenCVVideoTransformer,
        async_transform=True,
    )

    transform_type = st.radio(
        "Select transform type", ("noop", "cartoon", "edges", "rotate")
    )
    # Push the widget value into the transformer, if one exists.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.type = transform_type

    st.markdown(
        "This demo is based on "
        "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "  # noqa: E501
        "Many thanks to the project."
    )


def app_object_detection():
    """Object detection demo with MobileNet SSD.
    This model and code are based on
    https://github.com/robmarkcole/object-detection-app
    """
    MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
    MODEL_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.caffemodel"
    PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
    PROTOTXT_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.prototxt.txt"

    CLASSES = [
        "background",
        "aeroplane",
        "bicycle",
        "bird",
        "boat",
        "bottle",
        "bus",
        "car",
        "cat",
        "chair",
        "cow",
        "diningtable",
        "dog",
        "horse",
        "motorbike",
        "person",
        "pottedplant",
        "sheep",
        "sofa",
        "train",
        "tvmonitor",
    ]
    # One random drawing color per class.
    COLORS = np.random.uniform(0, 255, size=(len(CLASSES), 3))

    download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
    download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)

    DEFAULT_CONFIDENCE_THRESHOLD = 0.5

    class NNVideoTransformer(VideoTransformerBase):
        """Run the Caffe network on each frame and draw the detections."""

        # Detections at or below this confidence are ignored; assigned from
        # the slider widget below.
        confidence_threshold: float

        def __init__(self) -> None:
            self._net = cv2.dnn.readNetFromCaffe(
                str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
            )
            self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD

        def _annotate_image(self, image, detections):
            """Draw boxes and labels on ``image`` in place.

            Returns:
                The annotated image and the list of label strings drawn.
            """
            # loop over the detections
            (h, w) = image.shape[:2]
            labels = []
            for i in range(detections.shape[2]):
                confidence = detections[0, 0, i, 2]

                if confidence > self.confidence_threshold:
                    # extract the index of the class label from the `detections`,
                    # then compute the (x, y)-coordinates of the bounding box for
                    # the object
                    idx = int(detections[0, 0, i, 1])
                    box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                    # Plain Python ints/floats are passed to the drawing
                    # calls instead of numpy scalars/arrays.
                    startX, startY, endX, endY = (int(v) for v in box)
                    color = COLORS[idx].tolist()

                    # display the prediction
                    label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
                    labels.append(label)
                    cv2.rectangle(image, (startX, startY), (endX, endY), color, 2)
                    # Put the text above the box unless it would be cut off.
                    y = startY - 15 if startY - 15 > 15 else startY + 15
                    cv2.putText(
                        image,
                        label,
                        (startX, y),
                        cv2.FONT_HERSHEY_SIMPLEX,
                        0.5,
                        color,
                        2,
                    )
            return image, labels

        def transform(self, frame: av.VideoFrame) -> np.ndarray:
            """Return the frame as a BGR array with detections drawn on it."""
            image = frame.to_ndarray(format="bgr24")
            blob = cv2.dnn.blobFromImage(
                cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
            )
            self._net.setInput(blob)
            detections = self._net.forward()
            annotated_image, labels = self._annotate_image(image, detections)
            # TODO: Show labels

            return annotated_image

    webrtc_ctx = webrtc_streamer(
        key="object-detection",
        mode=WebRtcMode.SENDRECV,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        video_transformer_factory=NNVideoTransformer,
        async_transform=True,
    )

    confidence_threshold = st.slider(
        "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
    )
    # Push the widget value into the transformer, if one exists.
    if webrtc_ctx.video_transformer:
        webrtc_ctx.video_transformer.confidence_threshold = confidence_threshold

    st.markdown(
        "This demo uses a model and code from "
        "https://github.com/robmarkcole/object-detection-app. "
        "Many thanks to the project."
    )


def app_streaming():
    """ Media streamings """
    MEDIAFILES = {
        "big_buck_bunny_720p_2mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_2mb.mp4",  # noqa: E501
            "local_file_path": HERE / "data/big_buck_bunny_720p_2mb.mp4",
            "type": "video",
        },
        "big_buck_bunny_720p_10mb.mp4": {
            "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_10mb.mp4",  # noqa: E501
            "local_file_path": HERE / "data/big_buck_bunny_720p_10mb.mp4",
            "type": "video",
        },
        "file_example_MP3_700KB.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_700KB.mp3",  # noqa: E501
            "local_file_path": HERE / "data/file_example_MP3_700KB.mp3",
            "type": "audio",
        },
        "file_example_MP3_5MG.mp3": {
            "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_5MG.mp3",  # noqa: E501
            "local_file_path": HERE / "data/file_example_MP3_5MG.mp3",
            "type": "audio",
        },
    }
    media_file_label = st.radio(
        "Select a media file to stream", tuple(MEDIAFILES.keys())
    )
    media_file_info = MEDIAFILES[media_file_label]
    download_file(media_file_info["url"], media_file_info["local_file_path"])

    def create_player():
        """Build a player that reads the selected local media file."""
        return MediaPlayer(str(media_file_info["local_file_path"]))

        # NOTE: To stream the video from webcam, use the code below.
        # return MediaPlayer(
        #     "1:none",
        #     format="avfoundation",
        #     options={"framerate": "30", "video_size": "1280x720"},
        # )

    # Request only the kind of media the selected file provides.
    # The key must match the one used when WEBRTC_CLIENT_SETTINGS is built
    # ("media_stream_constraints"); a misspelled key is silently ignored.
    # Note that this mutates the module-level settings object.
    WEBRTC_CLIENT_SETTINGS.update(
        {
            "media_stream_constraints": {
                "video": media_file_info["type"] == "video",
                "audio": media_file_info["type"] == "audio",
            }
        }
    )

    webrtc_streamer(
        key=f"media-streaming-{media_file_label}",
        mode=WebRtcMode.RECVONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
        player_factory=create_player,
    )


def app_sendonly():
    """A sample to use WebRTC in sendonly mode to transfer frames
    from the browser to the server and to render frames via `st.image`."""
    webrtc_ctx = webrtc_streamer(
        # Distinct from the key used by the loopback page.
        key="sendonly",
        mode=WebRtcMode.SENDONLY,
        client_settings=WEBRTC_CLIENT_SETTINGS,
    )

    if webrtc_ctx.video_receiver:
        # Placeholder that is overwritten with each received frame.
        image_loc = st.empty()
        while True:
            try:
                frame = webrtc_ctx.video_receiver.frames_queue.get(timeout=1)
            except queue.Empty:
                # No frame arrived within the timeout: stop receiving.
                logger.warning("Queue is empty. Stop the loop.")
                webrtc_ctx.video_receiver.stop()
                break

            img = frame.to_ndarray(format="bgr24")
            # PIL expects RGB channel order.
            img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
            image_loc.image(img)


# Settings shared by every page; defined after the page functions, which only
# look the name up when they are called.
WEBRTC_CLIENT_SETTINGS = ClientSettings(
    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
    media_stream_constraints={"video": True, "audio": True},
)

if __name__ == "__main__":
    # NOTE: the `force` argument of basicConfig requires Python 3.8+.
    logging.basicConfig(
        format="[%(asctime)s] %(levelname)7s from %(name)s in %(filename)s:%(lineno)d: "
        "%(message)s",
        force=True,
    )

    logger.setLevel(level=logging.DEBUG)

    st_webrtc_logger = logging.getLogger("streamlit_webrtc")
    st_webrtc_logger.setLevel(logging.DEBUG)

    main()