eusholli committed on
Commit 7eed7bb · 1 Parent(s): 9dc0fd2

commented app.py

Files changed (1)
  1. app.py +240 -90
app.py CHANGED
@@ -1,3 +1,5 @@
  import logging
  import queue
  from pathlib import Path
@@ -9,82 +11,156 @@ import numpy as np
  import streamlit as st
  from streamlit_webrtc import WebRtcMode, webrtc_streamer

- from utils.download import download_file
  from utils.turn import get_ice_servers

- from mtcnn import MTCNN
- from PIL import Image, ImageDraw
- from transformers import pipeline
  # Initialize the Hugging Face pipeline for facial emotion detection
  emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")

- img_container = {"webcam": None, "analyzed": None, "uploaded": None}

  # Initialize MTCNN for face detection
  mtcnn = MTCNN()

- HERE = Path(__file__).parent
- ROOT = HERE.parent
-
  logger = logging.getLogger(__name__)


  class Detection(NamedTuple):
      class_id: int
      label: str
      score: float
      box: np.ndarray

- result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
-
- # Function to analyze sentiment
- def analyze_sentiment(face):
-     rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
-     pil_image = Image.fromarray(rgb_face)
-     results = emotion_pipeline(pil_image)
-     dominant_emotion = max(results, key=lambda x: x['score'])['label']
-     return dominant_emotion

- TEXT_SIZE = 1
- LINE_SIZE = 2

- # Function to detect faces, analyze sentiment, and draw a red box around them
- def detect_and_draw_faces(frame):
-     results = mtcnn.detect_faces(frame)
-     for result in results:
-         x, y, w, h = result['box']
-         face = frame[y:y+h, x:x+w]
-         sentiment = analyze_sentiment(face)
-         cv2.rectangle(frame, (x, y), (x+w, y+h), (0, 0, 255), LINE_SIZE)
-         text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[0]
-         text_x = x
-         text_y = y - 10
-         background_tl = (text_x, text_y - text_size[1])
-         background_br = (text_x + text_size[0], text_y + 5)
-         cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
-         cv2.putText(frame, sentiment, (text_x, text_y), cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, (255, 255, 255), 2)
-     result_queue.put(results)
-     return frame


  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-     img = frame.to_ndarray(format="bgr24")
-     img_container["webcam"] = img
-     frame_with_boxes = detect_and_draw_faces(img.copy())
-     img_container["analyzed"] = frame_with_boxes
-     return frame


  ice_servers = get_ice_servers()

- # Streamlit UI
  st.markdown(
      """
      <style>
      .main {
-         background-color: #F7F7F7;
          padding: 2rem;
      }
      h1, h2, h3 {
-         color: #333333;
          font-family: 'Arial', sans-serif;
      }
      h1 {
@@ -99,31 +175,24 @@ st.markdown(
          font-weight: 500;
          font-size: 1.5rem;
      }
-     .stButton button {
-         background-color: #E60012;
-         color: white;
-         border-radius: 5px;
-         font-size: 16px;
-         padding: 0.5rem 1rem;
-     }
      </style>
      """,
-     unsafe_allow_html=True
  )

- st.title("Computer Vision Test Lab")
  st.subheader("Facial Sentiment Analysis")

- show_labels = st.checkbox("Show the detected labels", value=True)
-
  # Columns for input and output streams
  col1, col2 = st.columns(2)

  with col1:
      st.header("Input Stream")
-     st.subheader("Webcam")
      webrtc_ctx = webrtc_streamer(
-         key="object-detection",
          mode=WebRtcMode.SENDRECV,
          rtc_configuration=ice_servers,
          video_frame_callback=video_frame_callback,
@@ -131,50 +200,131 @@ with col1:
          async_processing=True,
      )
  st.subheader("Upload an Image")
  uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

- with col2:
-     st.header("Analysis")
-     input_subheader_placeholder = st.empty()
-     input_placeholder = st.empty()
-
-     output_subheader_placeholder = st.empty()
-     output_placeholder = st.empty()

-     if webrtc_ctx.state.playing:
-         labels_placeholder = st.empty()
-         input_subheader_placeholder.subheader("Input Frame")
-         output_subheader_placeholder.subheader("Output Frame")

-         while True:
              result = result_queue.get()
              if show_labels:
-                 labels_placeholder.table(result)

-             img = img_container["webcam"]
-             frame_with_boxes = img_container["analyzed"]

-             if img is None:
-                 continue

-             input_placeholder.image(img, channels="BGR")
-             output_placeholder.image(frame_with_boxes, channels="BGR")

-     if uploaded_file is not None:
-         input_subheader_placeholder.subheader("Input Frame")
-         output_subheader_placeholder.subheader("Output Frame")

-         image = Image.open(uploaded_file)
-         img = np.array(image.convert("RGB"))  # Ensure image is in RGB format
-         img_container["uploaded"] = img
-         analyzed_img = detect_and_draw_faces(img.copy())
-         input_placeholder.image(img)
-         output_placeholder.image(analyzed_img)

-         result = result_queue.get()
-         if show_labels:
-             labels_placeholder = st.empty()
-             labels_placeholder.table(result)

-
 
+ import time
+ import os
  import logging
  import queue
  from pathlib import Path

  import streamlit as st
  from streamlit_webrtc import WebRtcMode, webrtc_streamer

+ from utils.download import download_file
  from utils.turn import get_ice_servers

+ from mtcnn import MTCNN  # Import MTCNN for face detection
+ from PIL import Image, ImageDraw  # Import PIL for image processing
+ from transformers import pipeline  # Import Hugging Face transformers pipeline
+
+ import requests
+ from io import BytesIO  # Import for handling byte streams
+
+ # CHANGE THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS.
+ #
+ #
+ # Function to analyze an input frame and generate an analyzed frame
+ # This function takes an input video frame, detects faces in it using MTCNN,
+ # then for each detected face, it analyzes the sentiment (emotion) using the analyze_sentiment function,
+ # draws a rectangle around the face, and overlays the detected emotion on the frame.
+ # It also records the time taken to process the frame and stores it in a global container.
+ # Constants for text and line size in the output image
+ TEXT_SIZE = 1
+ LINE_SIZE = 2
+
+
+ def analyze_frame(frame):
+     start_time = time.time()  # Start timing the analysis
+     img_container["input"] = frame  # Store the input frame
+     frame = frame.copy()  # Create a copy of the frame to modify
+
+     results = mtcnn.detect_faces(frame)  # Detect faces in the frame
+     for result in results:
+         x, y, w, h = result["box"]  # Get the bounding box of the detected face
+         face = frame[y : y + h, x : x + w]  # Extract the face from the frame
+         sentiment = analyze_sentiment(face)  # Analyze the sentiment of the face
+         # Draw a rectangle around the face
+         cv2.rectangle(frame, (x, y), (x + w, y + h), (0, 0, 255), LINE_SIZE)
+         text_size = cv2.getTextSize(sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2)[
+             0
+         ]
+         text_x = x
+         text_y = y - 10
+         background_tl = (text_x, text_y - text_size[1])
+         background_br = (text_x + text_size[0], text_y + 5)
+         # Draw a black background for the text
+         cv2.rectangle(frame, background_tl, background_br, (0, 0, 0), cv2.FILLED)
+         # Put the sentiment text on the image
+         cv2.putText(
+             frame,
+             sentiment,
+             (text_x, text_y),
+             cv2.FONT_HERSHEY_SIMPLEX,
+             TEXT_SIZE,
+             (255, 255, 255),
+             2,
+         )
+
+     end_time = time.time()  # End timing the analysis
+     execution_time_ms = round(
+         (end_time - start_time) * 1000, 2
+     )  # Calculate execution time in milliseconds
+     img_container["analysis_time"] = execution_time_ms  # Store the execution time
+
+     result_queue.put(results)  # Put the results in the result queue
+     img_container["analyzed"] = frame  # Store the analyzed frame
+
+     return  # End of the function
+
+
+ # Function to analyze the sentiment (emotion) of a detected face
+ # This function converts the face from BGR to RGB format, then converts it to a PIL image,
+ # uses a pre-trained emotion detection model to get emotion predictions,
+ # and finally returns the most dominant emotion detected.
+ def analyze_sentiment(face):
+     rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)  # Convert face to RGB format
+     pil_image = Image.fromarray(rgb_face)  # Convert to PIL image
+     results = emotion_pipeline(pil_image)  # Run emotion detection on the image
+     dominant_emotion = max(results, key=lambda x: x["score"])[
+         "label"
+     ]  # Get the dominant emotion
+     return dominant_emotion  # Return the detected emotion
+
+
+ # Suppress FFmpeg logs
+ os.environ["FFMPEG_LOG_LEVEL"] = "quiet"
+
+ # Suppress TensorFlow or PyTorch progress bars
+ import tensorflow as tf
+
+ tf.get_logger().setLevel("ERROR")
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+ # Suppress PyTorch logs
+ import torch
+
+ logging.getLogger().setLevel(logging.WARNING)
+ torch.set_num_threads(1)
+ logging.getLogger("torch").setLevel(logging.ERROR)
+
+ # Suppress Streamlit logs using the logging module
+ logging.getLogger("streamlit").setLevel(logging.ERROR)
+
  # Initialize the Hugging Face pipeline for facial emotion detection
  emotion_pipeline = pipeline("image-classification", model="trpakov/vit-face-expression")

+ # Container to hold image data and analysis results
+ img_container = {"input": None, "analyzed": None, "analysis_time": None}

  # Initialize MTCNN for face detection
  mtcnn = MTCNN()

+ # Logger for debugging and information
  logger = logging.getLogger(__name__)

+
+ # Named tuple to store detection results
  class Detection(NamedTuple):
      class_id: int
      label: str
      score: float
      box: np.ndarray


+ # Queue to store detection results
+ result_queue: "queue.Queue[List[Detection]]" = queue.Queue()


+ # Callback function to process video frames
+ # This function is called for each video frame in the WebRTC stream.
+ # It converts the frame to a numpy array in RGB format, analyzes the frame,
+ # and returns the original frame.
  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
+     img = frame.to_ndarray(format="rgb24")  # Convert frame to numpy array in RGB format
+     analyze_frame(img)  # Analyze the frame
+     return frame  # Return the original frame

+
+ # Get ICE servers for WebRTC
  ice_servers = get_ice_servers()

+ # Streamlit UI configuration
+ st.set_page_config(layout="wide")
+
+ # Custom CSS for the Streamlit page
  st.markdown(
      """
      <style>
      .main {
          padding: 2rem;
      }
      h1, h2, h3 {
          font-family: 'Arial', sans-serif;
      }
      h1 {
          font-weight: 500;
          font-size: 1.5rem;
      }
      </style>
      """,
+     unsafe_allow_html=True,
  )
+ # Streamlit page title and subtitle
+ st.title("Computer Vision Playground")
  st.subheader("Facial Sentiment Analysis")

  # Columns for input and output streams
  col1, col2 = st.columns(2)

  with col1:
      st.header("Input Stream")
+     st.subheader("input")
+     # WebRTC streamer to get video input from the webcam
      webrtc_ctx = webrtc_streamer(
+         key="input-webcam",
          mode=WebRtcMode.SENDRECV,
          rtc_configuration=ice_servers,
          video_frame_callback=video_frame_callback,
          async_processing=True,
      )

+     # File uploader for images
      st.subheader("Upload an Image")
      uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])

+     # Text input for image URL
+     st.subheader("Or Enter Image URL")
+     image_url = st.text_input("Image URL")

+     # File uploader for videos
+     st.subheader("Upload a Video")
+     uploaded_video = st.file_uploader(
+         "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
+     )

+     # Text input for video URL
+     st.subheader("Or Enter Video URL")
+     video_url = st.text_input("Video URL")
+
+
+ # Function to initialize the analysis UI
+ # This function sets up the placeholders and UI elements in the analysis section.
+ # It creates placeholders for input and output frames, analysis time, and detected labels.
+ def analysis_init():
+     global analysis_time, show_labels, labels_placeholder, input_placeholder, output_placeholder
+
+     with col2:
+         st.header("Analysis")
+         st.subheader("Input Frame")
+         input_placeholder = st.empty()  # Placeholder for input frame
+
+         st.subheader("Output Frame")
+         output_placeholder = st.empty()  # Placeholder for output frame
+         analysis_time = st.empty()  # Placeholder for analysis time
+         show_labels = st.checkbox(
+             "Show the detected labels", value=True
+         )  # Checkbox to show/hide labels
+         labels_placeholder = st.empty()  # Placeholder for labels
+
+
+ # Function to publish frames and results to the Streamlit UI
+ # This function retrieves the latest frames and results from the global container and result queue,
+ # and updates the placeholders in the Streamlit UI with the current input frame, analyzed frame, analysis time, and detected labels.
+ def publish_frame():
+     if not result_queue.empty():
          result = result_queue.get()
          if show_labels:
+             labels_placeholder.table(
+                 result
+             )  # Display labels if the checkbox is checked
+
+     img = img_container["input"]
+     if img is None:
+         return
+     input_placeholder.image(img, channels="RGB")  # Display the input frame
+
+     analyzed = img_container["analyzed"]
+     if analyzed is None:
+         return
+     output_placeholder.image(analyzed, channels="RGB")  # Display the analyzed frame
+
+     time = img_container["analysis_time"]
+     if time is None:
+         return
+     analysis_time.text(f"Analysis Time: {time} ms")  # Display the analysis time
+
+
+ # If the WebRTC streamer is playing, initialize and publish frames
+ if webrtc_ctx.state.playing:
+     analysis_init()  # Initialize the analysis UI
+     while True:
+         publish_frame()  # Publish the frames and results
+         time.sleep(0.1)  # Delay to control frame rate
+
+
+ # If an image is uploaded or a URL is provided, process the image
+ if uploaded_file is not None or image_url:
+     analysis_init()  # Initialize the analysis UI
+
+     if uploaded_file is not None:
+         image = Image.open(uploaded_file)  # Open the uploaded image
+         img = np.array(image.convert("RGB"))  # Convert the image to RGB format
+     else:
+         response = requests.get(image_url)  # Download the image from the URL
+         image = Image.open(BytesIO(response.content))  # Open the downloaded image
+         img = np.array(image.convert("RGB"))  # Convert the image to RGB format
+
+     analyze_frame(img)  # Analyze the image
+     publish_frame()  # Publish the results
+
+
+ # Function to process video files
+ # This function reads frames from a video file, analyzes each frame for face detection and sentiment analysis,
+ # and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
+ def process_video(video_path):
+     cap = cv2.VideoCapture(video_path)  # Open the video file
+     while cap.isOpened():
+         ret, frame = cap.read()  # Read a frame from the video
+         if not ret:
+             break  # Exit the loop if no more frames are available

+         input_placeholder.image(frame)  # Display the current frame as the input frame
+         analyze_frame(
+             frame
+         )  # Analyze the frame for face detection and sentiment analysis
+         publish_frame()  # Publish the results

+         if not result_queue.empty():
+             result = result_queue.get()
+             if show_labels:
+                 labels_placeholder.table(
+                     result
+                 )  # Display labels if the checkbox is checked

+     cap.release()  # Release the video capture object


+ # If a video is uploaded or a URL is provided, process the video
+ if uploaded_video is not None or video_url:
+     analysis_init()  # Initialize the analysis UI

+     if uploaded_video is not None:
+         video_path = uploaded_video.name  # Get the name of the uploaded video
+         with open(video_path, "wb") as f:
+             f.write(uploaded_video.getbuffer())  # Save the uploaded video to a file
+     else:
+         video_path = download_file(video_url)  # Download the video from the URL

+     process_video(video_path)  # Process the video
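
The banner comment on the new analyze_frame() ("CHANGE THIS FUNCTION, USE TO REPLACE WITH YOUR WANTED ANALYSIS") marks it as the intended extension point. As a rough sketch (not part of this commit), a replacement only has to keep the same contract the rest of app.py relies on: store the input frame, the analyzed frame, and the timing in img_container, and push its detections onto result_queue so publish_frame() can render them. The Canny edge overlay below is an arbitrary stand-in for the face/emotion analysis; the names img_container, result_queue, time, and cv2 are the ones already defined in the committed file.

# Illustrative sketch only: replaces the MTCNN + emotion analysis with a trivial
# edge-detection overlay while keeping the img_container / result_queue contract.
def analyze_frame(frame):
    start_time = time.time()
    img_container["input"] = frame  # untouched frame for the "Input Frame" panel
    frame = frame.copy()

    gray = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, 100, 200)  # arbitrary example analysis
    frame[edges > 0] = (255, 0, 0)  # paint detected edges red (frame is RGB)

    img_container["analysis_time"] = round((time.time() - start_time) * 1000, 2)
    result_queue.put([])  # no labelled detections in this sketch
    img_container["analyzed"] = frame  # shown in the "Output Frame" panel

Because video_frame_callback() runs in streamlit-webrtc's worker thread while the main script polls publish_frame() in its while True loop, any replacement should stay non-blocking in the same way the committed version does.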