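# Page residue from the Hugging Face file viewer, kept as a comment so the
# module parses: author Beehzod, commit fb5660a ("Update app.py", verified),
# raw / history / blame links, file size 13.1 kB.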
import logging
import os
import tempfile
import time
from io import BytesIO  # Import for handling byte streams

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from mtcnn import MTCNN  # Import MTCNN for face detection
from PIL import Image, ImageDraw  # Import PIL for image processing
from transformers import pipeline  # Import Hugging Face transformers pipeline
import requests
import yt_dlp

from utils.download import download_file
from utils.turn import get_ice_servers
# CHANGE THE CODE BELOW TO REPLACE IT WITH THE ANALYSIS YOU WANT.
# Update the ANALYSIS_TITLE string below to set the display title of the analysis.
# Add any imports your analysis needs at the top of the file.
# Initialize MTCNN for face detection
# NOTE(review): these models are built at module level, so every Streamlit
# rerun of the script constructs them again; consider caching them
# (e.g. st.cache_resource) if start-up time becomes a problem.
mtcnn = MTCNN()

# Hugging Face image-classification pipeline that maps a face crop to an
# expression label, backed by the trpakov/vit-face-expression checkpoint.
emotion_pipeline = pipeline(
    task="image-classification",
    model="trpakov/vit-face-expression",
)

# Heading rendered above the analysis section of the page.
ANALYSIS_TITLE = "Facial Sentiment Analysis"
# CHANGE THE CONTENTS OF THIS FUNCTION TO IMPLEMENT THE ANALYSIS YOU WANT.
#
#
# Function to analyze an input frame and generate an analyzed frame
# This function takes an input video frame and detects faces in it using MTCNN.
# For each detected face, it analyzes the sentiment (emotion) with the analyze_sentiment function,
# draws a rectangle around the face, and overlays the detected emotion on the frame.
# It also records the time taken to process the frame and stores it in a global container.
# Drawing parameters for annotations on the output image: TEXT_SIZE is the
# font scale used for the emotion label, LINE_SIZE is the thickness of the
# rectangle drawn around each face.
TEXT_SIZE, LINE_SIZE = 1, 2
# Analysis results are stored in img_container for display:
# img_container["input"] - holds the input frame contents - of type np.ndarray
# img_container["analyzed"] - holds the analyzed frame with any added annotations - of type np.ndarray
# img_container["analysis_time"] - holds how long the analysis took, in milliseconds
# img_container["detections"] - holds the analysis metadata results
def analyze_frame(frame: np.ndarray):
    """Detect faces in ``frame``, classify each one, and publish the results.

    Faces are found with MTCNN. Each face crop is classified with
    ``analyze_sentiment``, and a rectangle plus the predicted label is drawn on
    a copy of the frame. Results are written into the module-level
    ``img_container``:

    * ``"input"``: the unmodified input frame.
    * ``"analyzed"``: the annotated copy.
    * ``"analysis_time"``: elapsed wall-clock time in milliseconds, rounded to
      two decimals.
    * ``"detections"``: the list of MTCNN result dicts. Each dict gains a
      ``"label"`` key when its face could be classified.

    Args:
        frame: Image as a NumPy array. Every caller in this file passes RGB
            frames.

    Returns:
        None. Results are communicated through ``img_container``.
    """
    start_time = time.time()
    img_container["input"] = frame
    frame = frame.copy()  # annotate a copy so the stored input stays clean
    frame_h, frame_w = frame.shape[:2]

    results = mtcnn.detect_faces(frame)
    for result in results:
        x, y, w, h = (int(v) for v in result["box"])
        # MTCNN boxes are not guaranteed to lie inside the image. They may
        # start at negative x/y for faces touching the border. A negative
        # slice start wraps around and yields an empty or wrong crop, so
        # clamp the box to the frame before cropping.
        x0, y0 = max(x, 0), max(y, 0)
        x1, y1 = min(x + w, frame_w), min(y + h, frame_h)
        face = frame[y0:y1, x0:x1]
        if face.size == 0:
            continue  # nothing usable to classify; leave this detection unlabeled

        sentiment = analyze_sentiment(face)
        result["label"] = sentiment

        # Rectangle around the (clamped) face.
        cv2.rectangle(frame, (x0, y0), (x1, y1), (0, 0, 255), LINE_SIZE)

        (text_w, text_h), _baseline = cv2.getTextSize(
            sentiment, cv2.FONT_HERSHEY_SIMPLEX, TEXT_SIZE, 2
        )
        text_x = x0
        # Put the label just above the box. When the face touches the top
        # edge, keep the label inside the frame instead of drawing it off-screen.
        text_y = max(y0 - 10, text_h)

        # Black background behind the label for readability.
        cv2.rectangle(
            frame,
            (text_x, text_y - text_h),
            (text_x + text_w, text_y + 5),
            (0, 0, 0),
            cv2.FILLED,
        )
        cv2.putText(
            frame,
            sentiment,
            (text_x, text_y),
            cv2.FONT_HERSHEY_SIMPLEX,
            TEXT_SIZE,
            (255, 255, 255),
            2,
        )

    img_container["analysis_time"] = round((time.time() - start_time) * 1000, 2)
    img_container["detections"] = results
    img_container["analyzed"] = frame
# Function to analyze the sentiment (emotion) of a detected face
# This function converts the face from BGR to RGB format, then converts it to a PIL image,
# uses a pre-trained emotion detection model to get emotion predictions,
# and finally returns the most dominant emotion detected.
def analyze_sentiment(face):
# Convert face to RGB format
rgb_face = cv2.cvtColor(face, cv2.COLOR_BGR2RGB)
pil_image = Image.fromarray(rgb_face) # Convert to PIL image
results = emotion_pipeline(pil_image) # Run emotion detection on the image
dominant_emotion = max(results, key=lambda x: x["score"])[
"label"
] # Get the dominant emotion
return dominant_emotion # Return the detected emotion
#
#
# DO NOT TOUCH THE BELOW CODE (NOT NEEDED)
#
#
# Suppress FFmpeg logs
# NOTE(review): confirm that the FFmpeg/PyAV build in use actually reads this
# environment variable; otherwise it has no effect.
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"

# Only surface errors from Streamlit's loggers.
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Shared state: analyze_frame writes these slots and publish_frame reads them.
# Every slot is None until the first frame has been analyzed.
img_container = dict.fromkeys(("input", "analyzed", "analysis_time", "detections"))

# Module logger for debugging and informational messages.
logger = logging.getLogger(__name__)
# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Analyze one incoming WebRTC video frame and hand it back unchanged.

    The frame is converted to an ``rgb24`` NumPy array and passed to
    ``analyze_frame``, which records the input, the annotated image, the timing
    and the detections in ``img_container`` for the UI loop to display. The
    annotated image is not returned here; the original ``frame`` object is.
    """
    analyze_frame(frame.to_ndarray(format="rgb24"))
    return frame
# Get ICE servers for WebRTC
# NOTE(review): this value is passed directly as rtc_configuration to
# webrtc_streamer. Confirm that get_ice_servers() returns a full RTC
# configuration (e.g. {"iceServers": [...]}) and not just a list of servers.
ice_servers = get_ice_servers()

# Use the full browser width for the two-column layout.
st.set_page_config(layout="wide")

# Page-wide CSS: padding around the main area and heading typography.
st.markdown(
    body="""
<style>
.main {
padding: 2rem;
}
h1, h2, h3 {
font-family: 'Arial', sans-serif;
}
h1 {
font-weight: 700;
font-size: 2.5rem;
}
h2 {
font-weight: 600;
font-size: 2rem;
}
h3 {
font-weight: 500;
font-size: 1.5rem;
}
</style>
""",
    unsafe_allow_html=True,
)

# Page title.
st.title("Computer Vision Playground")

# Pointer to the README explaining how to reuse this template.
st.markdown(
    body="""
<div style="text-align: left;">
<p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
</div>
""",
    unsafe_allow_html=True,
)

# Heading for the configured analysis.
st.subheader(ANALYSIS_TITLE)
# Columns for input and output streams
col1, col2 = st.columns(2)
with col1:
st.header("Input Stream")
input_subheader = st.empty()
input_placeholder = st.empty() # Placeholder for input frame
st.subheader("Input Options")
# WebRTC streamer to get video input from the webcam
webrtc_ctx = webrtc_streamer(
key="input-webcam",
mode=WebRtcMode.SENDONLY,
rtc_configuration=ice_servers,
video_frame_callback=video_frame_callback,
media_stream_constraints={"video": True, "audio": False},
async_processing=True,
)
# File uploader for images
st.subheader("Upload an Image")
uploaded_file = st.file_uploader(
"Choose an image...", type=["jpg", "jpeg", "png"])
# # Text input for image URL
# st.subheader("Or Enter Image URL")
# image_url = st.text_input("Image URL")
# # Text input for YouTube URL
# st.subheader("Enter a YouTube URL")
# youtube_url = st.text_input("YouTube URL")
# # File uploader for videos
# st.subheader("Upload a Video")
# uploaded_video = st.file_uploader(
# "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
# )
# # Text input for video URL
# st.subheader("Or Enter Video Download URL")
# video_url = st.text_input("Video URL")
# # Streamlit footer
# st.markdown(
# """
# <div style="text-align: center; margin-top: 2rem;">
# <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
# </div>
# """,
# unsafe_allow_html=True
# )
# Function to initialize the analysis UI
# This function sets up the UI elements in the analysis column. It titles the input-frame section,
# creates placeholders for the output frame, analysis time, and detected labels, and adds the checkbox that toggles the labels.
def analysis_init():
    """Build the right-hand "Analysis" column and expose its placeholders.

    Titles the existing ``input_subheader`` slot ("Input Frame"). Then, inside
    ``col2``, creates the output-frame placeholder, the analysis-time
    placeholder, the "Show the detected labels" checkbox (ticked by default)
    and the labels placeholder. These are stored as module globals so that
    ``publish_frame`` can update them.
    """
    # Names this function (re)binds for publish_frame to use.
    global analysis_time, show_labels, labels_placeholder, output_placeholder
    # Names created in the input column and only read here.
    global input_subheader, input_placeholder

    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")

        st.subheader("Output Frame")
        output_placeholder = st.empty()
        analysis_time = st.empty()

        show_labels = st.checkbox(label="Show the detected labels", value=True)
        labels_placeholder = st.empty()
# Function to publish frames and results to the Streamlit UI
# This function reads the latest frames and results from the global img_container and updates
# the Streamlit placeholders with the current input frame, analyzed frame, analysis time, and detected labels.
def publish_frame():
    """Push the latest contents of ``img_container`` to the Streamlit placeholders.

    Shows, in order: the input frame, the analyzed frame, the analysis time,
    and the detections table. The table is only shown when the "Show the
    detected labels" checkbox is ticked. Returns early at the first value that
    has not been produced yet.
    """
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")

    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    output_placeholder.image(analyzed, channels="RGB")

    # Named elapsed_ms, not `time`, so this local does not shadow the `time`
    # module inside the function.
    elapsed_ms = img_container["analysis_time"]
    if elapsed_ms is None:
        return
    analysis_time.text(f"Analysis Time: {elapsed_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        labels_placeholder.table(detections)
# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    # Poll the shared container until this script run is stopped
    # (presumably by a Streamlit rerun); the frame callback keeps filling it.
    while True:
        publish_frame()
        time.sleep(0.1)  # Delay to control frame rate

# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
    else:
        # Bound the request so an unresponsive host cannot hang the app.
        # Fail with an explicit HTTP error instead of a confusing image-decode
        # error on an error page.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
    img = np.array(image.convert("RGB"))  # normalize to 3-channel RGB
    analyze_frame(img)
    publish_frame()
# Function to process video files
# This function reads frames from a video file, analyzes each frame for face detection and sentiment analysis,
# and updates the Streamlit UI with the current input frame, analyzed frame, and detected labels.
def process_video(video_path):
    """Analyze every frame of a video and publish the results as it goes.

    Args:
        video_path: Anything ``cv2.VideoCapture`` can open. This script passes
            both local file paths and direct stream URLs.

    Frames are read by OpenCV in BGR order, converted to RGB for
    ``analyze_frame``, and then shown via ``publish_frame``. The capture is
    released even if analysis or publishing raises.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        while cap.isOpened():
            ok, frame = cap.read()
            if not ok:
                break  # end of stream or read failure
            analyze_frame(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
            publish_frame()
    finally:
        cap.release()
# Function to get the video stream URL from YouTube using yt-dlp
def get_youtube_stream_url(youtube_url):
    """Resolve a YouTube video URL to a direct media URL without downloading.

    Args:
        youtube_url: URL of a single YouTube video. Any playlist parameter in
            the URL is ignored.

    Returns:
        The direct URL of the selected single-file format: the best MP4, or
        the best single-file format of any container if no MP4 exists.

    Raises:
        ValueError: If yt-dlp did not resolve a single direct URL.
    """
    ydl_opts = {
        # "best" picks a single file carrying both audio and video. That is
        # what one direct URL (and cv2.VideoCapture) can consume.
        'format': 'best[ext=mp4]/best',
        'quiet': True,
        # Without this, a watch URL that also carries a playlist id is expanded
        # into a playlist, whose info dict has no top-level 'url'.
        'noplaylist': True,
    }
    with yt_dlp.YoutubeDL(ydl_opts) as ydl:
        info_dict = ydl.extract_info(youtube_url, download=False)
    stream_url = info_dict.get('url')
    if not stream_url:
        raise ValueError(
            f"Could not resolve a direct stream URL for {youtube_url!r}"
        )
    return stream_url
# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI
    stream_url = get_youtube_stream_url(youtube_url)
    process_video(stream_url)

# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI
    temp_video_path = None
    if uploaded_video is not None:
        # Copy the upload into a private temporary file. Writing to
        # uploaded_video.name would let the client-supplied filename decide
        # where on disk the data lands (possibly overwriting files), and the
        # copy was never removed. The original extension is kept in case the
        # video backend relies on it.
        suffix = os.path.splitext(uploaded_video.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
            tmp.write(uploaded_video.getbuffer())
        video_path = temp_video_path = tmp.name
    else:
        # Download the video from the URL
        video_path = download_file(video_url)
    try:
        process_video(video_path)
    finally:
        if temp_video_path is not None:
            os.remove(temp_video_path)  # only delete the copy we created