whitphx (HF staff) committed
Commit 0325cdc · 1 Parent(s): 583219b

Update app.py

Files changed (1)
  1. app.py +122 -626
app.py CHANGED
@@ -1,676 +1,172 @@
- import asyncio
  import logging
  import queue
- import threading
- import urllib.request
  from pathlib import Path
- from typing import List, NamedTuple, Optional

  import av
  import cv2
- import matplotlib.pyplot as plt
  import numpy as np
- import pydub
  import streamlit as st
- from aiortc.contrib.media import MediaPlayer

- from streamlit_webrtc import (
-     RTCConfiguration,
-     WebRtcMode,
-     WebRtcStreamerContext,
-     webrtc_streamer,
- )

  HERE = Path(__file__).parent

  logger = logging.getLogger(__name__)


- # This code is based on https://github.com/streamlit/demo-self-driving/blob/230245391f2dda0cb464008195a470751c01770b/streamlit_app.py#L48  # noqa: E501
- def download_file(url, download_to: Path, expected_size=None):
-     # Don't download the file twice.
-     # (If possible, verify the download using the file length.)
-     if download_to.exists():
-         if expected_size:
-             if download_to.stat().st_size == expected_size:
-                 return
-         else:
-             st.info(f"{url} is already downloaded.")
-             if not st.button("Download again?"):
-                 return
-
-     download_to.parent.mkdir(parents=True, exist_ok=True)
-
-     # These are handles to two visual elements to animate.
-     weights_warning, progress_bar = None, None
-     try:
-         weights_warning = st.warning("Downloading %s..." % url)
-         progress_bar = st.progress(0)
-         with open(download_to, "wb") as output_file:
-             with urllib.request.urlopen(url) as response:
-                 length = int(response.info()["Content-Length"])
-                 counter = 0.0
-                 MEGABYTES = 2.0 ** 20.0
-                 while True:
-                     data = response.read(8192)
-                     if not data:
-                         break
-                     counter += len(data)
-                     output_file.write(data)
-
-                     # We perform animation by overwriting the elements.
-                     weights_warning.warning(
-                         "Downloading %s... (%6.2f/%6.2f MB)"
-                         % (url, counter / MEGABYTES, length / MEGABYTES)
-                     )
-                     progress_bar.progress(min(counter / length, 1.0))
-     # Finally, we remove these visual elements by calling .empty().
-     finally:
-         if weights_warning is not None:
-             weights_warning.empty()
-         if progress_bar is not None:
-             progress_bar.empty()
-
-
- RTC_CONFIGURATION = RTCConfiguration(
-     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
- )
-
-
- def main():
-     st.header("WebRTC demo")
-
-     pages = {
-         "Real time object detection (sendrecv)": app_object_detection,
-         "Real time video transform with simple OpenCV filters (sendrecv)": app_video_filters,  # noqa: E501
-         "Real time audio filter (sendrecv)": app_audio_filter,
-         "Delayed echo (sendrecv)": app_delayed_echo,
-         "Consuming media files on server-side and streaming it to browser (recvonly)": app_streaming,  # noqa: E501
-         "WebRTC is sendonly and images are shown via st.image() (sendonly)": app_sendonly_video,  # noqa: E501
-         "WebRTC is sendonly and audio frames are visualized with matplotlib (sendonly)": app_sendonly_audio,  # noqa: E501
-         "Simple video and audio loopback (sendrecv)": app_loopback,
-         "Configure media constraints and HTML element styles with loopback (sendrecv)": app_media_constraints,  # noqa: E501
-         "Control the playing state programatically": app_programatically_play,
-         "Customize UI texts": app_customize_ui_texts,
-     }
-     page_titles = pages.keys()
-
-     page_title = st.sidebar.selectbox(
-         "Choose the app mode",
-         page_titles,
-     )
-     st.subheader(page_title)
-
-     page_func = pages[page_title]
-     page_func()
-
-     st.sidebar.markdown(
-         """
- ---
- <a href="https://www.buymeacoffee.com/whitphx" target="_blank"><img src="https://cdn.buymeacoffee.com/buttons/v2/default-yellow.png" alt="Buy Me A Coffee" width="180" height="50" ></a>
- """,  # noqa: E501
-         unsafe_allow_html=True,
-     )

-     logger.debug("=== Alive threads ===")
-     for thread in threading.enumerate():
-         if thread.is_alive():
-             logger.debug(f" {thread.name} ({thread.ident})")


- def app_loopback():
-     """Simple video loopback"""
-     webrtc_streamer(key="loopback")


- def app_video_filters():
-     """Video transforms with OpenCV"""

-     _type = st.radio("Select transform type", ("noop", "cartoon", "edges", "rotate"))

-     def callback(frame: av.VideoFrame) -> av.VideoFrame:
-         img = frame.to_ndarray(format="bgr24")

-         if _type == "noop":
-             pass
-         elif _type == "cartoon":
-             # prepare color
-             img_color = cv2.pyrDown(cv2.pyrDown(img))
-             for _ in range(6):
-                 img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
-             img_color = cv2.pyrUp(cv2.pyrUp(img_color))

-             # prepare edges
-             img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-             img_edges = cv2.adaptiveThreshold(
-                 cv2.medianBlur(img_edges, 7),
-                 255,
-                 cv2.ADAPTIVE_THRESH_MEAN_C,
-                 cv2.THRESH_BINARY,
-                 9,
-                 2,
-             )
-             img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
-
-             # combine color and edges
-             img = cv2.bitwise_and(img_color, img_edges)
-         elif _type == "edges":
-             # perform edge detection
-             img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
-         elif _type == "rotate":
-             # rotate image
-             rows, cols, _ = img.shape
-             M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
-             img = cv2.warpAffine(img, M, (cols, rows))
-
-         return av.VideoFrame.from_ndarray(img, format="bgr24")
-
-     webrtc_streamer(
-         key="opencv-filter",
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
-         video_frame_callback=callback,
-         media_stream_constraints={"video": True, "audio": False},
-         async_processing=True,
-     )

-     st.markdown(
-         "This demo is based on "
-         "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "  # noqa: E501
-         "Many thanks to the project."
-     )


- def app_audio_filter():
-     gain = st.slider("Gain", -10.0, +20.0, 1.0, 0.05)

-     def process_audio(frame: av.AudioFrame) -> av.AudioFrame:
-         raw_samples = frame.to_ndarray()
-         sound = pydub.AudioSegment(
-             data=raw_samples.tobytes(),
-             sample_width=frame.format.bytes,
-             frame_rate=frame.sample_rate,
-             channels=len(frame.layout.channels),
-         )

-         sound = sound.apply_gain(gain)

-         # Ref: https://github.com/jiaaro/pydub/blob/master/API.markdown#audiosegmentget_array_of_samples  # noqa
-         channel_sounds = sound.split_to_mono()
-         channel_samples = [s.get_array_of_samples() for s in channel_sounds]
-         new_samples: np.ndarray = np.array(channel_samples).T
-         new_samples = new_samples.reshape(raw_samples.shape)

-         new_frame = av.AudioFrame.from_ndarray(new_samples, layout=frame.layout.name)
-         new_frame.sample_rate = frame.sample_rate
-         return new_frame

-     webrtc_streamer(
-         key="audio-filter",
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
-         audio_frame_callback=process_audio,
-         async_processing=True,
-     )


- def app_delayed_echo():
-     delay = st.slider("Delay", 0.0, 5.0, 1.0, 0.05)
-
-     async def queued_video_frames_callback(
-         frames: List[av.VideoFrame],
-     ) -> List[av.VideoFrame]:
-         logger.debug("Delay: %f", delay)
-         # A standalone `await ...` is interpreted as an expression and
-         # the Streamlit magic's target, which leads implicit calls of `st.write`.
-         # To prevent it, fix it as `_ = await ...`, a statement.
-         # See https://discuss.streamlit.io/t/issue-with-asyncio-run-in-streamlit/7745/15
-         _ = await asyncio.sleep(delay)
-         return frames
-
-     async def queued_audio_frames_callback(
-         frames: List[av.AudioFrame],
-     ) -> List[av.AudioFrame]:
-         _ = await asyncio.sleep(delay)
-         return frames
-
-     webrtc_streamer(
-         key="delay",
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
-         queued_video_frames_callback=queued_video_frames_callback,
-         queued_audio_frames_callback=queued_audio_frames_callback,
-         async_processing=True,
      )


- def app_object_detection():
-     """Object detection demo with MobileNet SSD.
-     This model and code are based on
-     https://github.com/robmarkcole/object-detection-app
-     """
-     MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
-     MODEL_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.caffemodel"
-     PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
-     PROTOTXT_LOCAL_PATH = HERE / "./models/MobileNetSSD_deploy.prototxt.txt"
-
-     CLASSES = [
-         "background",
-         "aeroplane",
-         "bicycle",
-         "bird",
-         "boat",
-         "bottle",
-         "bus",
-         "car",
-         "cat",
-         "chair",
-         "cow",
-         "diningtable",
-         "dog",
-         "horse",
-         "motorbike",
-         "person",
-         "pottedplant",
-         "sheep",
-         "sofa",
-         "train",
-         "tvmonitor",
-     ]
-
-     @st.experimental_singleton
-     def generate_label_colors():
-         return np.random.uniform(0, 255, size=(len(CLASSES), 3))
-
-     COLORS = generate_label_colors()
-
-     download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
-     download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
-
-     DEFAULT_CONFIDENCE_THRESHOLD = 0.5
-
-     class Detection(NamedTuple):
-         name: str
-         prob: float
-
-     # Session-specific caching
-     cache_key = "object_detection_dnn"
-     if cache_key in st.session_state:
-         net = st.session_state[cache_key]
-     else:
-         net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
-         st.session_state[cache_key] = net
-
-     confidence_threshold = st.slider(
-         "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
-     )

-     def _annotate_image(image, detections):
-         # loop over the detections
-         (h, w) = image.shape[:2]
-         result: List[Detection] = []
-         for i in np.arange(0, detections.shape[2]):
-             confidence = detections[0, 0, i, 2]
-
-             if confidence > confidence_threshold:
-                 # extract the index of the class label from the `detections`,
-                 # then compute the (x, y)-coordinates of the bounding box for
-                 # the object
-                 idx = int(detections[0, 0, i, 1])
-                 box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
-                 (startX, startY, endX, endY) = box.astype("int")
-
-                 name = CLASSES[idx]
-                 result.append(Detection(name=name, prob=float(confidence)))
-
-                 # display the prediction
-                 label = f"{name}: {round(confidence * 100, 2)}%"
-                 cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
-                 y = startY - 15 if startY - 15 > 15 else startY + 15
-                 cv2.putText(
-                     image,
-                     label,
-                     (startX, y),
-                     cv2.FONT_HERSHEY_SIMPLEX,
-                     0.5,
-                     COLORS[idx],
-                     2,
-                 )
-         return image, result
-
-     result_queue = (
-         queue.Queue()
-     )  # TODO: A general-purpose shared state object may be more useful.
-
-     def callback(frame: av.VideoFrame) -> av.VideoFrame:
-         image = frame.to_ndarray(format="bgr24")
-         blob = cv2.dnn.blobFromImage(
-             cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
-         )
-         net.setInput(blob)
-         detections = net.forward()
-         annotated_image, result = _annotate_image(image, detections)
-
-         # NOTE: This `recv` method is called in another thread,
-         # so it must be thread-safe.
-         result_queue.put(result)  # TODO:
-
-         return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")

      webrtc_ctx = webrtc_streamer(
          key="object-detection",
          mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
          video_frame_callback=callback,
          media_stream_constraints={"video": True, "audio": False},
          async_processing=True,
      )

-     if st.checkbox("Show the detected labels", value=True):
-         if webrtc_ctx.state.playing:
-             labels_placeholder = st.empty()
-             # NOTE: The video transformation with object detection and
-             # this loop displaying the result labels are running
-             # in different threads asynchronously.
-             # Then the rendered video frames and the labels displayed here
-             # are not strictly synchronized.
-             while True:
-                 try:
-                     result = result_queue.get(timeout=1.0)
-                 except queue.Empty:
-                     result = None
-                 labels_placeholder.table(result)
-
-     st.markdown(
-         "This demo uses a model and code from "
-         "https://github.com/robmarkcole/object-detection-app. "
-         "Many thanks to the project."
-     )
-
-
- def app_streaming():
-     """Media streamings"""
-     MEDIAFILES = {
-         "big_buck_bunny_720p_2mb.mp4 (local)": {
-             "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_2mb.mp4",  # noqa: E501
-             "local_file_path": HERE / "data/big_buck_bunny_720p_2mb.mp4",
-             "type": "video",
-         },
-         "big_buck_bunny_720p_10mb.mp4 (local)": {
-             "url": "https://sample-videos.com/video123/mp4/720/big_buck_bunny_720p_10mb.mp4",  # noqa: E501
-             "local_file_path": HERE / "data/big_buck_bunny_720p_10mb.mp4",
-             "type": "video",
-         },
-         "file_example_MP3_700KB.mp3 (local)": {
-             "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_700KB.mp3",  # noqa: E501
-             "local_file_path": HERE / "data/file_example_MP3_700KB.mp3",
-             "type": "audio",
-         },
-         "file_example_MP3_5MG.mp3 (local)": {
-             "url": "https://file-examples-com.github.io/uploads/2017/11/file_example_MP3_5MG.mp3",  # noqa: E501
-             "local_file_path": HERE / "data/file_example_MP3_5MG.mp3",
-             "type": "audio",
-         },
-         "rtsp://wowzaec2demo.streamlock.net/vod/mp4:BigBuckBunny_115k.mov": {
-             "url": "rtsp://wowzaec2demo.streamlock.net/vod/mp4:BigBuckBunny_115k.mov",
-             "type": "video",
-         },
-     }
-     media_file_label = st.radio(
-         "Select a media source to stream", tuple(MEDIAFILES.keys())
-     )
-     media_file_info = MEDIAFILES[media_file_label]
-     if "local_file_path" in media_file_info:
-         download_file(media_file_info["url"], media_file_info["local_file_path"])
-
-     def create_player():
-         if "local_file_path" in media_file_info:
-             return MediaPlayer(str(media_file_info["local_file_path"]))
-         else:
-             return MediaPlayer(media_file_info["url"])
-
-         # NOTE: To stream the video from webcam, use the code below.
-         # return MediaPlayer(
-         #     "1:none",
-         #     format="avfoundation",
-         #     options={"framerate": "30", "video_size": "1280x720"},
-         # )
-
-     key = f"media-streaming-{media_file_label}"
-     ctx: Optional[WebRtcStreamerContext] = st.session_state.get(key)
-     if media_file_info["type"] == "video" and ctx and ctx.state.playing:
-         _type = st.radio(
-             "Select transform type", ("noop", "cartoon", "edges", "rotate")
-         )
-     else:
-         _type = "noop"
-
-     def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-         img = frame.to_ndarray(format="bgr24")
-
-         if _type == "noop":
-             pass
-         elif _type == "cartoon":
-             # prepare color
-             img_color = cv2.pyrDown(cv2.pyrDown(img))
-             for _ in range(6):
-                 img_color = cv2.bilateralFilter(img_color, 9, 9, 7)
-             img_color = cv2.pyrUp(cv2.pyrUp(img_color))
-
-             # prepare edges
-             img_edges = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
-             img_edges = cv2.adaptiveThreshold(
-                 cv2.medianBlur(img_edges, 7),
-                 255,
-                 cv2.ADAPTIVE_THRESH_MEAN_C,
-                 cv2.THRESH_BINARY,
-                 9,
-                 2,
-             )
-             img_edges = cv2.cvtColor(img_edges, cv2.COLOR_GRAY2RGB)
-
-             # combine color and edges
-             img = cv2.bitwise_and(img_color, img_edges)
-         elif _type == "edges":
-             # perform edge detection
-             img = cv2.cvtColor(cv2.Canny(img, 100, 200), cv2.COLOR_GRAY2BGR)
-         elif _type == "rotate":
-             # rotate image
-             rows, cols, _ = img.shape
-             M = cv2.getRotationMatrix2D((cols / 2, rows / 2), frame.time * 45, 1)
-             img = cv2.warpAffine(img, M, (cols, rows))
-
-         return av.VideoFrame.from_ndarray(img, format="bgr24")
-
-     webrtc_streamer(
-         key=key,
-         mode=WebRtcMode.RECVONLY,
-         rtc_configuration=RTC_CONFIGURATION,
-         media_stream_constraints={
-             "video": media_file_info["type"] == "video",
-             "audio": media_file_info["type"] == "audio",
-         },
-         player_factory=create_player,
-         video_frame_callback=video_frame_callback,
-     )
-
-     st.markdown(
-         "The video filter in this demo is based on "
-         "https://github.com/aiortc/aiortc/blob/2362e6d1f0c730a0f8c387bbea76546775ad2fe8/examples/server/server.py#L34. "  # noqa: E501
-         "Many thanks to the project."
-     )
-
-
- def app_sendonly_video():
-     """A sample to use WebRTC in sendonly mode to transfer frames
-     from the browser to the server and to render frames via `st.image`."""
-     webrtc_ctx = webrtc_streamer(
-         key="video-sendonly",
-         mode=WebRtcMode.SENDONLY,
-         rtc_configuration=RTC_CONFIGURATION,
-         media_stream_constraints={"video": True},
-     )
-
-     image_place = st.empty()
-
-     while True:
-         if webrtc_ctx.video_receiver:
-             try:
-                 video_frame = webrtc_ctx.video_receiver.get_frame(timeout=1)
-             except queue.Empty:
-                 logger.warning("Queue is empty. Abort.")
-                 break
-
-             img_rgb = video_frame.to_ndarray(format="rgb24")
-             image_place.image(img_rgb)
-         else:
-             logger.warning("AudioReciver is not set. Abort.")
-             break
-
-
- def app_sendonly_audio():
-     """A sample to use WebRTC in sendonly mode to transfer audio frames
-     from the browser to the server and visualize them with matplotlib
-     and `st.pyplot`."""
-     webrtc_ctx = webrtc_streamer(
-         key="sendonly-audio",
-         mode=WebRtcMode.SENDONLY,
-         audio_receiver_size=256,
-         rtc_configuration=RTC_CONFIGURATION,
-         media_stream_constraints={"audio": True},
-     )
-
-     fig_place = st.empty()
-
-     fig, [ax_time, ax_freq] = plt.subplots(
-         2, 1, gridspec_kw={"top": 1.5, "bottom": 0.2}
-     )
-
-     sound_window_len = 5000  # 5s
-     sound_window_buffer = None
-     while True:
-         if webrtc_ctx.audio_receiver:
              try:
-                 audio_frames = webrtc_ctx.audio_receiver.get_frames(timeout=1)
              except queue.Empty:
-                 logger.warning("Queue is empty. Abort.")
-                 break
-
-             sound_chunk = pydub.AudioSegment.empty()
-             for audio_frame in audio_frames:
-                 sound = pydub.AudioSegment(
-                     data=audio_frame.to_ndarray().tobytes(),
-                     sample_width=audio_frame.format.bytes,
-                     frame_rate=audio_frame.sample_rate,
-                     channels=len(audio_frame.layout.channels),
-                 )
-                 sound_chunk += sound
-
-             if len(sound_chunk) > 0:
-                 if sound_window_buffer is None:
-                     sound_window_buffer = pydub.AudioSegment.silent(
-                         duration=sound_window_len
-                     )
-
-                 sound_window_buffer += sound_chunk
-                 if len(sound_window_buffer) > sound_window_len:
-                     sound_window_buffer = sound_window_buffer[-sound_window_len:]
-
-             if sound_window_buffer:
-                 # Ref: https://own-search-and-study.xyz/2017/10/27/python%E3%82%92%E4%BD%BF%E3%81%A3%E3%81%A6%E9%9F%B3%E5%A3%B0%E3%83%87%E3%83%BC%E3%82%BF%E3%81%8B%E3%82%89%E3%82%B9%E3%83%9A%E3%82%AF%E3%83%88%E3%83%AD%E3%82%B0%E3%83%A9%E3%83%A0%E3%82%92%E4%BD%9C/  # noqa
-                 sound_window_buffer = sound_window_buffer.set_channels(
-                     1
-                 )  # Stereo to mono
-                 sample = np.array(sound_window_buffer.get_array_of_samples())
-
-                 ax_time.cla()
-                 times = (np.arange(-len(sample), 0)) / sound_window_buffer.frame_rate
-                 ax_time.plot(times, sample)
-                 ax_time.set_xlabel("Time")
-                 ax_time.set_ylabel("Magnitude")
-
-                 spec = np.fft.fft(sample)
-                 freq = np.fft.fftfreq(sample.shape[0], 1.0 / sound_chunk.frame_rate)
-                 freq = freq[: int(freq.shape[0] / 2)]
-                 spec = spec[: int(spec.shape[0] / 2)]
-                 spec[0] = spec[0] / 2
-
-                 ax_freq.cla()
-                 ax_freq.plot(freq, np.abs(spec))
-                 ax_freq.set_xlabel("Frequency")
-                 ax_freq.set_yscale("log")
-                 ax_freq.set_ylabel("Magnitude")
-
-                 fig_place.pyplot(fig)
-         else:
-             logger.warning("AudioReciver is not set. Abort.")
-             break
-
-
- def app_media_constraints():
-     """A sample to configure MediaStreamConstraints object"""
-     frame_rate = 5
-     webrtc_streamer(
-         key="media-constraints",
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
-         media_stream_constraints={
-             "video": {"frameRate": {"ideal": frame_rate}},
-         },
-         video_html_attrs={
-             "style": {"width": "50%", "margin": "0 auto", "border": "5px yellow solid"},
-             "controls": False,
-             "autoPlay": True,
-         },
-     )
-     st.write(f"The frame rate is set as {frame_rate}. Video style is changed.")

-
- def app_programatically_play():
-     """A sample of controlling the playing state from Python."""
-     playing = st.checkbox("Playing", value=True)
-
-     webrtc_streamer(
-         key="programatic_control",
-         desired_playing_state=playing,
-         mode=WebRtcMode.SENDRECV,
-         rtc_configuration=RTC_CONFIGURATION,
-     )
-
-
- def app_customize_ui_texts():
-     webrtc_streamer(
-         key="custom_ui_texts",
-         rtc_configuration=RTC_CONFIGURATION,
-         translations={
-             "start": "開始",
-             "stop": "停止",
-             "select_device": "デバイス選択",
-             "media_api_not_available": "Media APIが利用できない環境です",
-             "device_ask_permission": "メディアデバイスへのアクセスを許可してください",
-             "device_not_available": "メディアデバイスを利用できません",
-             "device_access_denied": "メディアデバイスへのアクセスが拒否されました",
-         },
-     )
-
-
- if __name__ == "__main__":
-     import os
-
-     DEBUG = os.environ.get("DEBUG", "false").lower() not in ["false", "no", "0"]
-
-     logging.basicConfig(
-         format="[%(asctime)s] %(levelname)7s from %(name)s in %(pathname)s:%(lineno)d: "
-         "%(message)s",
-         force=True,
-     )
-
-     logger.setLevel(level=logging.DEBUG if DEBUG else logging.INFO)
-
-     st_webrtc_logger = logging.getLogger("streamlit_webrtc")
-     st_webrtc_logger.setLevel(logging.DEBUG)
-
-     fsevents_logger = logging.getLogger("fsevents")
-     fsevents_logger.setLevel(logging.WARNING)
-
-     main()
+ """Object detection demo with MobileNet SSD.
+ This model and code are based on
+ https://github.com/robmarkcole/object-detection-app
+ """
+
  import logging
  import queue
  from pathlib import Path
+ from typing import List, NamedTuple

  import av
  import cv2
  import numpy as np
  import streamlit as st
+ from streamlit_webrtc import WebRtcMode, webrtc_streamer

+ from sample_utils.download import download_file

  HERE = Path(__file__).parent
+ ROOT = HERE.parent

  logger = logging.getLogger(__name__)


+ MODEL_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.caffemodel"  # noqa: E501
+ MODEL_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.caffemodel"
+ PROTOTXT_URL = "https://github.com/robmarkcole/object-detection-app/raw/master/model/MobileNetSSD_deploy.prototxt.txt"  # noqa: E501
+ PROTOTXT_LOCAL_PATH = ROOT / "./models/MobileNetSSD_deploy.prototxt.txt"

+ CLASSES = [
+     "background",
+     "aeroplane",
+     "bicycle",
+     "bird",
+     "boat",
+     "bottle",
+     "bus",
+     "car",
+     "cat",
+     "chair",
+     "cow",
+     "diningtable",
+     "dog",
+     "horse",
+     "motorbike",
+     "person",
+     "pottedplant",
+     "sheep",
+     "sofa",
+     "train",
+     "tvmonitor",
+ ]


+ @st.experimental_singleton  # type: ignore  # See https://github.com/python/mypy/issues/7781, https://github.com/python/mypy/issues/12566  # noqa: E501
+ def generate_label_colors():
+     return np.random.uniform(0, 255, size=(len(CLASSES), 3))


+ COLORS = generate_label_colors()

+ download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
+ download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)

+ DEFAULT_CONFIDENCE_THRESHOLD = 0.5


+ class Detection(NamedTuple):
+     name: str
+     prob: float


+ # Session-specific caching
+ cache_key = "object_detection_dnn"
+ if cache_key in st.session_state:
+     net = st.session_state[cache_key]
+ else:
+     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
+     st.session_state[cache_key] = net

+ streaming_placeholder = st.empty()

+ confidence_threshold = st.slider(
+     "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
+ )


+ def _annotate_image(image, detections):
+     # loop over the detections
+     (h, w) = image.shape[:2]
+     result: List[Detection] = []
+     for i in np.arange(0, detections.shape[2]):
+         confidence = detections[0, 0, i, 2]
+
+         if confidence > confidence_threshold:
+             # extract the index of the class label from the `detections`,
+             # then compute the (x, y)-coordinates of the bounding box for
+             # the object
+             idx = int(detections[0, 0, i, 1])
+             box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
+             (startX, startY, endX, endY) = box.astype("int")
+
+             name = CLASSES[idx]
+             result.append(Detection(name=name, prob=float(confidence)))
+
+             # display the prediction
+             label = f"{name}: {round(confidence * 100, 2)}%"
+             cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
+             y = startY - 15 if startY - 15 > 15 else startY + 15
+             cv2.putText(
+                 image,
+                 label,
+                 (startX, y),
+                 cv2.FONT_HERSHEY_SIMPLEX,
+                 0.5,
+                 COLORS[idx],
+                 2,
+             )
+     return image, result


+ result_queue: queue.Queue = (
+     queue.Queue()
+ )  # TODO: A general-purpose shared state object may be more useful.


+ def callback(frame: av.VideoFrame) -> av.VideoFrame:
+     image = frame.to_ndarray(format="bgr24")
+     blob = cv2.dnn.blobFromImage(
+         cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
      )
+     net.setInput(blob)
+     detections = net.forward()
+     annotated_image, result = _annotate_image(image, detections)

+     # NOTE: This `recv` method is called in another thread,
+     # so it must be thread-safe.
+     result_queue.put(result)  # TODO:

+     return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")


+ with streaming_placeholder.container():
      webrtc_ctx = webrtc_streamer(
          key="object-detection",
          mode=WebRtcMode.SENDRECV,
+         rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
          video_frame_callback=callback,
          media_stream_constraints={"video": True, "audio": False},
          async_processing=True,
      )

+ if st.checkbox("Show the detected labels", value=True):
+     if webrtc_ctx.state.playing:
+         labels_placeholder = st.empty()
+         # NOTE: The video transformation with object detection and
+         # this loop displaying the result labels are running
+         # in different threads asynchronously.
+         # Then the rendered video frames and the labels displayed here
+         # are not strictly synchronized.
+         while True:
              try:
+                 result = result_queue.get(timeout=1.0)
              except queue.Empty:
+                 result = None
+             labels_placeholder.table(result)

+ st.markdown(
+     "This demo uses a model and code from "
+     "https://github.com/robmarkcole/object-detection-app. "
+     "Many thanks to the project."
+ )
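
Note on the new import: the rewritten app.py pulls download_file from sample_utils.download, a helper module that is not part of this diff. Judging from the call sites (download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=...)) and from the download_file helper that this commit removes from app.py, it presumably keeps the same signature and behavior. A minimal sketch under that assumption — the sample_utils/download.py path and the exact progress-bar handling are guesses, not confirmed by this commit:

# sample_utils/download.py (hypothetical reconstruction, not part of this commit)
import urllib.request
from pathlib import Path

import streamlit as st


def download_file(url: str, download_to: Path, expected_size=None):
    # Skip the download when a file of the expected size is already present.
    if download_to.exists() and expected_size and download_to.stat().st_size == expected_size:
        return

    download_to.parent.mkdir(parents=True, exist_ok=True)

    # Transient Streamlit widgets that report download progress.
    weights_warning = st.warning(f"Downloading {url}...")
    progress_bar = st.progress(0)
    try:
        with open(download_to, "wb") as output_file:
            with urllib.request.urlopen(url) as response:
                length = int(response.info()["Content-Length"])
                counter = 0.0
                while True:
                    data = response.read(8192)
                    if not data:
                        break
                    counter += len(data)
                    output_file.write(data)
                    progress_bar.progress(min(counter / length, 1.0))
    finally:
        # Remove the transient widgets once the download ends.
        weights_warning.empty()
        progress_bar.empty()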