whitphx committed
Commit e8c0c75 · 1 Parent(s): a4daff6

Update app.py

Files changed (1)
  1. app.py +32 -27
app.py CHANGED
@@ -5,7 +5,7 @@ import threading
 import time
 import urllib.request
 from pathlib import Path
-from typing import List, Union
+from typing import List, NamedTuple, Union
 
 try:
     from typing import Literal
@@ -15,7 +15,6 @@ except ImportError:
 import av
 import cv2
 import numpy as np
-import PIL
 import streamlit as st
 from aiortc.contrib.media import MediaPlayer
 
@@ -77,6 +76,12 @@ def download_file(url, download_to: Path, expected_size=None):
     progress_bar.empty()
 
 
+WEBRTC_CLIENT_SETTINGS = ClientSettings(
+    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+    media_stream_constraints={"video": True, "audio": True},
+)
+
+
 def main():
     st.header("WebRTC demo")
 
@@ -230,28 +235,32 @@ def app_object_detection():
 
     DEFAULT_CONFIDENCE_THRESHOLD = 0.5
 
+    class Detection(NamedTuple):
+        name: str
+        prob: float
+
     class MobileNetSSDVideoTransformer(VideoTransformerBase):
         confidence_threshold: float
-        _labels: Union[List[str], None]
-        _labels_lock: threading.Lock
+        _result: Union[List[Detection], None]
+        _result_lock: threading.Lock
 
         def __init__(self) -> None:
             self._net = cv2.dnn.readNetFromCaffe(
                 str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH)
             )
             self.confidence_threshold = DEFAULT_CONFIDENCE_THRESHOLD
-            self._labels = None
-            self._labels_lock = threading.Lock()
+            self._result = None
+            self._result_lock = threading.Lock()
 
         @property
-        def labels(self) -> Union[List[str], None]:
-            with self._labels_lock:
-                return self._labels
+        def result(self) -> Union[List[Detection], None]:
+            with self._result_lock:
+                return self._result
 
         def _annotate_image(self, image, detections):
             # loop over the detections
             (h, w) = image.shape[:2]
-            labels = []
+            result: List[Detection] = []
             for i in np.arange(0, detections.shape[2]):
                 confidence = detections[0, 0, i, 2]
 
@@ -263,9 +272,11 @@ def app_object_detection():
                     box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
                     (startX, startY, endX, endY) = box.astype("int")
 
+                    name = CLASSES[idx]
+                    result.append(Detection(name=name, prob=float(confidence)))
+
                     # display the prediction
-                    label = f"{CLASSES[idx]}: {round(confidence * 100, 2)}%"
-                    labels.append(label)
+                    label = f"{name}: {round(confidence * 100, 2)}%"
                     cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
                     y = startY - 15 if startY - 15 > 15 else startY + 15
                     cv2.putText(
@@ -277,7 +288,7 @@ def app_object_detection():
                         COLORS[idx],
                         2,
                     )
-            return image, labels
+            return image, result
 
         def transform(self, frame: av.VideoFrame) -> np.ndarray:
             image = frame.to_ndarray(format="bgr24")
@@ -286,12 +297,12 @@ def app_object_detection():
             )
             self._net.setInput(blob)
             detections = self._net.forward()
-            annotated_image, labels = self._annotate_image(image, detections)
+            annotated_image, result = self._annotate_image(image, detections)
 
             # NOTE: This `transform` method is called in another thread,
             # so it must be thread-safe.
-            with self._labels_lock:
-                self._labels = labels
+            with self._result_lock:
+                self._result = result
 
             return annotated_image
 
@@ -309,7 +320,7 @@ def app_object_detection():
     if webrtc_ctx.video_transformer:
         webrtc_ctx.video_transformer.confidence_threshold = confidence_threshold
 
-    if st.checkbox("Show the detected labels"):
+    if st.checkbox("Show the detected labels", value=True):
         if webrtc_ctx.state.playing:
             labels_placeholder = st.empty()
             # NOTE: The video transformation with object detection and
@@ -319,7 +330,7 @@ def app_object_detection():
             # are not synchronized.
             while True:
                 if webrtc_ctx.video_transformer:
-                    labels_placeholder.write(webrtc_ctx.video_transformer.labels)
+                    labels_placeholder.table(webrtc_ctx.video_transformer.result)
                 time.sleep(0.1)
 
     st.markdown(
@@ -371,7 +382,7 @@ def app_streaming():
 
     WEBRTC_CLIENT_SETTINGS.update(
         {
-            "fmedia_stream_constraints": {
+            "media_stream_constraints": {
                 "video": media_file_info["type"] == "video",
                 "audio": media_file_info["type"] == "audio",
             }
@@ -405,15 +416,9 @@ def app_sendonly():
                 webrtc_ctx.video_receiver.stop()
                 break
 
-            img = frame.to_ndarray(format="bgr24")
-            img = PIL.Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
-            image_loc.image(img)
-
+            img_rgb = frame.to_ndarray(format="rgb24")
+            image_loc.image(img_rgb)
 
-WEBRTC_CLIENT_SETTINGS = ClientSettings(
-    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
-    media_stream_constraints={"video": True, "audio": True},
-)
 
 if __name__ == "__main__":
     logging.basicConfig(
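The core of this change is how per-frame detection results are handed from the background `transform` thread to the Streamlit script thread: the transformer stores a list of `Detection` NamedTuples behind a lock, and the script loop reads them through a property and renders them with `labels_placeholder.table(...)`. Below is a minimal, standalone sketch of that pattern using only the standard library; the names `SharedDetections`, `publish`, and `latest` are illustrative and not part of app.py, which keeps the same state directly on `MobileNetSSDVideoTransformer`.

import threading
from typing import List, NamedTuple, Union


class Detection(NamedTuple):
    name: str
    prob: float


class SharedDetections:
    """Lock-protected handoff of the newest detections between threads."""

    def __init__(self) -> None:
        self._result: Union[List[Detection], None] = None
        self._lock = threading.Lock()

    def publish(self, result: List[Detection]) -> None:
        # Called from the video transform thread after each frame.
        with self._lock:
            self._result = result

    @property
    def latest(self) -> Union[List[Detection], None]:
        # Called from the Streamlit script thread, e.g. inside the while-loop
        # that refreshes labels_placeholder every 0.1 s.
        with self._lock:
            return self._result

Because `Detection` is a NamedTuple, a list of them converts cleanly into a two-column table (name, prob) when passed to Streamlit, which is what lets the placeholder switch from `write()` to `table()`.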