whitphx (HF staff) committed

Commit a01685d · Parent(s): 0325cdc

Update streamlit==1.19.0 and object detection demo

Files changed (3)
  1. app.py +65 -75
  2. pages/1_object_detection.py +65 -75
  3. requirements.txt +1 -1
app.py CHANGED
@@ -52,7 +52,14 @@ CLASSES = [
 ]
 
 
-@st.experimental_singleton  # type: ignore # See https://github.com/python/mypy/issues/7781, https://github.com/python/mypy/issues/12566 # noqa: E501
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+@st.cache_resource  # type: ignore
 def generate_label_colors():
     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
 
@@ -62,13 +69,6 @@ COLORS = generate_label_colors()
 download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
 download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
 
-DEFAULT_CONFIDENCE_THRESHOLD = 0.5
-
-
-class Detection(NamedTuple):
-    name: str
-    prob: float
-
 
 # Session-specific caching
 cache_key = "object_detection_dnn"
@@ -78,77 +78,70 @@ else:
     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
     st.session_state[cache_key] = net
 
-streaming_placeholder = st.empty()
+score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
 
-confidence_threshold = st.slider(
-    "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
-)
+# NOTE: The callback will be called in another thread,
+#       so use a queue here for thread-safety to pass the data
+#       from inside to outside the callback.
+# TODO: A general-purpose shared state object may be more useful.
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
 
 
-def _annotate_image(image, detections):
-    # loop over the detections
-    (h, w) = image.shape[:2]
-    result: List[Detection] = []
-    for i in np.arange(0, detections.shape[2]):
-        confidence = detections[0, 0, i, 2]
-
-        if confidence > confidence_threshold:
-            # extract the index of the class label from the `detections`,
-            # then compute the (x, y)-coordinates of the bounding box for
-            # the object
-            idx = int(detections[0, 0, i, 1])
-            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
-            (startX, startY, endX, endY) = box.astype("int")
-
-            name = CLASSES[idx]
-            result.append(Detection(name=name, prob=float(confidence)))
-
-            # display the prediction
-            label = f"{name}: {round(confidence * 100, 2)}%"
-            cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
-            y = startY - 15 if startY - 15 > 15 else startY + 15
-            cv2.putText(
-                image,
-                label,
-                (startX, y),
-                cv2.FONT_HERSHEY_SIMPLEX,
-                0.5,
-                COLORS[idx],
-                2,
-            )
-    return image, result
-
-
-result_queue: queue.Queue = (
-    queue.Queue()
-)  # TODO: A general-purpose shared state object may be more useful.
-
-
-def callback(frame: av.VideoFrame) -> av.VideoFrame:
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
     image = frame.to_ndarray(format="bgr24")
+
+    # Run inference
     blob = cv2.dnn.blobFromImage(
         cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
     )
     net.setInput(blob)
-    detections = net.forward()
-    annotated_image, result = _annotate_image(image, detections)
-
-    # NOTE: This `recv` method is called in another thread,
-    # so it must be thread-safe.
-    result_queue.put(result)  # TODO:
-
-    return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")
-
-
-with streaming_placeholder.container():
-    webrtc_ctx = webrtc_streamer(
-        key="object-detection",
-        mode=WebRtcMode.SENDRECV,
-        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
-        video_frame_callback=callback,
-        media_stream_constraints={"video": True, "audio": False},
-        async_processing=True,
-    )
+    output = net.forward()
+
+    h, w = image.shape[:2]
+
+    # Convert the output array into a structured form.
+    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+    output = output[output[:, 2] >= score_threshold]
+    detections = [
+        Detection(
+            class_id=int(detection[1]),
+            label=CLASSES[int(detection[1])],
+            score=float(detection[2]),
+            box=(detection[3:7] * np.array([w, h, w, h])),
+        )
+        for detection in output
+    ]
+
+    # Render bounding boxes and captions
+    for detection in detections:
+        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+        color = COLORS[detection.class_id]
+        xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.putText(
+            image,
+            caption,
+            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
+            2,
+        )
+
+    result_queue.put(detections)
+
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
+
+
+webrtc_ctx = webrtc_streamer(
+    key="object-detection",
+    mode=WebRtcMode.SENDRECV,
+    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+    video_frame_callback=video_frame_callback,
+    media_stream_constraints={"video": True, "audio": False},
+    async_processing=True,
+)
 
 if st.checkbox("Show the detected labels", value=True):
     if webrtc_ctx.state.playing:
@@ -159,10 +152,7 @@ if st.checkbox("Show the detected labels", value=True):
         # Then the rendered video frames and the labels displayed here
         # are not strictly synchronized.
         while True:
-            try:
-                result = result_queue.get(timeout=1.0)
-            except queue.Empty:
-                result = None
+            result = result_queue.get()
             labels_placeholder.table(result)
 
 st.markdown(
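
Note: the rewritten callback replaces the old per-index Python loop in _annotate_image with vectorized NumPy filtering over the raw net.forward() output. Below is a minimal standalone sketch of that decoding step, assuming the MobileNet-SSD layout where each output row is (batch_id, class_id, score, xmin, ymin, xmax, ymax) with normalized coordinates; the function name and dummy data are illustrative, not part of the commit.

    import numpy as np

    def decode_ssd_output(output, w, h, score_threshold):
        rows = output.squeeze()                     # (1, 1, N, 7) -> (N, 7)
        rows = rows[rows[:, 2] >= score_threshold]  # keep confident rows only
        class_ids = rows[:, 1].astype(int)
        scores = rows[:, 2]
        boxes = rows[:, 3:7] * np.array([w, h, w, h])  # scale to pixel coords
        return class_ids, scores, boxes

    # Example with fabricated detection rows:
    dummy = np.zeros((1, 1, 2, 7), dtype=np.float32)
    dummy[0, 0, 0] = [0, 15, 0.9, 0.1, 0.2, 0.5, 0.6]  # class 15, score 0.9
    dummy[0, 0, 1] = [0, 7, 0.3, 0.0, 0.0, 0.1, 0.1]   # below the 0.5 threshold
    ids, scores, boxes = decode_ssd_output(dummy, w=640, h=480, score_threshold=0.5)
    print(ids, scores, boxes)  # -> [15] [0.9] [[ 64.  96. 320. 288.]]
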
pages/1_object_detection.py CHANGED
@@ -52,7 +52,14 @@ CLASSES = [
 ]
 
 
-@st.experimental_singleton  # type: ignore # See https://github.com/python/mypy/issues/7781, https://github.com/python/mypy/issues/12566 # noqa: E501
+class Detection(NamedTuple):
+    class_id: int
+    label: str
+    score: float
+    box: np.ndarray
+
+
+@st.cache_resource  # type: ignore
 def generate_label_colors():
     return np.random.uniform(0, 255, size=(len(CLASSES), 3))
 
@@ -62,13 +69,6 @@ COLORS = generate_label_colors()
 download_file(MODEL_URL, MODEL_LOCAL_PATH, expected_size=23147564)
 download_file(PROTOTXT_URL, PROTOTXT_LOCAL_PATH, expected_size=29353)
 
-DEFAULT_CONFIDENCE_THRESHOLD = 0.5
-
-
-class Detection(NamedTuple):
-    name: str
-    prob: float
-
 
 # Session-specific caching
 cache_key = "object_detection_dnn"
@@ -78,77 +78,70 @@ else:
     net = cv2.dnn.readNetFromCaffe(str(PROTOTXT_LOCAL_PATH), str(MODEL_LOCAL_PATH))
     st.session_state[cache_key] = net
 
-streaming_placeholder = st.empty()
+score_threshold = st.slider("Score threshold", 0.0, 1.0, 0.5, 0.05)
 
-confidence_threshold = st.slider(
-    "Confidence threshold", 0.0, 1.0, DEFAULT_CONFIDENCE_THRESHOLD, 0.05
-)
+# NOTE: The callback will be called in another thread,
+#       so use a queue here for thread-safety to pass the data
+#       from inside to outside the callback.
+# TODO: A general-purpose shared state object may be more useful.
+result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
 
 
-def _annotate_image(image, detections):
-    # loop over the detections
-    (h, w) = image.shape[:2]
-    result: List[Detection] = []
-    for i in np.arange(0, detections.shape[2]):
-        confidence = detections[0, 0, i, 2]
-
-        if confidence > confidence_threshold:
-            # extract the index of the class label from the `detections`,
-            # then compute the (x, y)-coordinates of the bounding box for
-            # the object
-            idx = int(detections[0, 0, i, 1])
-            box = detections[0, 0, i, 3:7] * np.array([w, h, w, h])
-            (startX, startY, endX, endY) = box.astype("int")
-
-            name = CLASSES[idx]
-            result.append(Detection(name=name, prob=float(confidence)))
-
-            # display the prediction
-            label = f"{name}: {round(confidence * 100, 2)}%"
-            cv2.rectangle(image, (startX, startY), (endX, endY), COLORS[idx], 2)
-            y = startY - 15 if startY - 15 > 15 else startY + 15
-            cv2.putText(
-                image,
-                label,
-                (startX, y),
-                cv2.FONT_HERSHEY_SIMPLEX,
-                0.5,
-                COLORS[idx],
-                2,
-            )
-    return image, result
-
-
-result_queue: queue.Queue = (
-    queue.Queue()
-)  # TODO: A general-purpose shared state object may be more useful.
-
-
-def callback(frame: av.VideoFrame) -> av.VideoFrame:
+def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
     image = frame.to_ndarray(format="bgr24")
+
+    # Run inference
    blob = cv2.dnn.blobFromImage(
         cv2.resize(image, (300, 300)), 0.007843, (300, 300), 127.5
     )
     net.setInput(blob)
-    detections = net.forward()
-    annotated_image, result = _annotate_image(image, detections)
-
-    # NOTE: This `recv` method is called in another thread,
-    # so it must be thread-safe.
-    result_queue.put(result)  # TODO:
-
-    return av.VideoFrame.from_ndarray(annotated_image, format="bgr24")
-
-
-with streaming_placeholder.container():
-    webrtc_ctx = webrtc_streamer(
-        key="object-detection",
-        mode=WebRtcMode.SENDRECV,
-        rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
-        video_frame_callback=callback,
-        media_stream_constraints={"video": True, "audio": False},
-        async_processing=True,
-    )
+    output = net.forward()
+
+    h, w = image.shape[:2]
+
+    # Convert the output array into a structured form.
+    output = output.squeeze()  # (1, 1, N, 7) -> (N, 7)
+    output = output[output[:, 2] >= score_threshold]
+    detections = [
+        Detection(
+            class_id=int(detection[1]),
+            label=CLASSES[int(detection[1])],
+            score=float(detection[2]),
+            box=(detection[3:7] * np.array([w, h, w, h])),
+        )
+        for detection in output
+    ]
+
+    # Render bounding boxes and captions
+    for detection in detections:
+        caption = f"{detection.label}: {round(detection.score * 100, 2)}%"
+        color = COLORS[detection.class_id]
+        xmin, ymin, xmax, ymax = detection.box.astype("int")
+
+        cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, 2)
+        cv2.putText(
+            image,
+            caption,
+            (xmin, ymin - 15 if ymin - 15 > 15 else ymin + 15),
+            cv2.FONT_HERSHEY_SIMPLEX,
+            0.5,
+            color,
+            2,
+        )
+
+    result_queue.put(detections)
+
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
+
+
+webrtc_ctx = webrtc_streamer(
+    key="object-detection",
+    mode=WebRtcMode.SENDRECV,
+    rtc_configuration={"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]},
+    video_frame_callback=video_frame_callback,
+    media_stream_constraints={"video": True, "audio": False},
+    async_processing=True,
+)
 
 if st.checkbox("Show the detected labels", value=True):
     if webrtc_ctx.state.playing:
@@ -159,10 +152,7 @@ if st.checkbox("Show the detected labels", value=True):
         # Then the rendered video frames and the labels displayed here
         # are not strictly synchronized.
         while True:
-            try:
-                result = result_queue.get(timeout=1.0)
-            except queue.Empty:
-                result = None
+            result = result_queue.get()
             labels_placeholder.table(result)
 
 st.markdown(
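
Note: in both files the label loop now uses a plain blocking result_queue.get() instead of polling with timeout=1.0 and rendering None on queue.Empty. A minimal standalone sketch of this producer/consumer handoff between the WebRTC worker thread and the Streamlit script thread (the worker function below is a stand-in for video_frame_callback, and the data is fabricated):

    import queue
    import threading
    import time

    result_queue: "queue.Queue[list]" = queue.Queue()

    def worker() -> None:
        # Stand-in for video_frame_callback running on the WebRTC worker thread.
        for i in range(3):
            time.sleep(0.1)  # simulate per-frame inference latency
            result_queue.put([{"label": "person", "score": 0.9, "frame": i}])

    threading.Thread(target=worker, daemon=True).start()

    for _ in range(3):
        # Blocking get(): wait for the next frame's detections instead of
        # polling with a timeout and publishing None on queue.Empty.
        result = result_queue.get()
        print(result)
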
requirements.txt CHANGED
@@ -1,4 +1,4 @@
 opencv-python-headless==4.5.5.64
 pydub==0.25.1
-streamlit==1.17.0
+streamlit==1.19.0
 streamlit_webrtc==0.44.6
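
Note: the version bump is what enables the decorator swap in the two files above. st.cache_resource was introduced in Streamlit 1.18.0 as the successor to the deprecated st.experimental_singleton, so the old streamlit==1.17.0 pin had to move. A minimal sketch of the migration (the class count below is a placeholder, not the demo's real CLASSES list):

    import numpy as np
    import streamlit as st

    # Before (Streamlit <= 1.17):
    #
    #     @st.experimental_singleton
    #     def generate_label_colors(): ...
    #
    # After (Streamlit >= 1.18, as used by this commit):

    @st.cache_resource  # one shared instance across sessions and reruns
    def generate_label_colors():
        return np.random.uniform(0, 255, size=(20, 3))  # 20: placeholder class count
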