from ultralytics import YOLO
import time
import os
import logging
import tempfile
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from utils.download import download_file
from utils.turn import get_ice_servers
from PIL import Image
import requests
from io import BytesIO

# CHANGE CODE BELOW HERE; REPLACE IT WITH YOUR OWN ANALYSIS.

# Update the string below to set the display title of the analysis
ANALYSIS_TITLE = "YOLOv8 Object Detection, Pose Estimation, and Action Detection"

# Load the YOLOv8 models
pose_model = YOLO("yolov8n-pose.pt")
object_model = YOLO("yolov8n.pt")
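
# Note on the data layout assumed by detect_action() below: the keypoints
# argument is expected to look like per-person YOLOv8 pose output, i.e. a
# (1, 17, 3) array where each row of keypoints[0] is (x, y, confidence) in
# COCO keypoint order. A minimal illustrative call (dummy values, not a real
# detection) would be:
#
#   kps = np.zeros((1, 17, 3))   # one person, 17 keypoints, all low confidence
#   detect_action(kps)           # -> "waiting", since no keypoint is confident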
def detect_action(keypoints, prev_keypoints=None):
    # Reference map of COCO keypoint indices (kept for readability; the code
    # below indexes keypoints by number directly)
    keypoint_dict = {
        0: "Nose", 1: "Left Eye", 2: "Right Eye", 3: "Left Ear", 4: "Right Ear",
        5: "Left Shoulder", 6: "Right Shoulder", 7: "Left Elbow", 8: "Right Elbow",
        9: "Left Wrist", 10: "Right Wrist", 11: "Left Hip", 12: "Right Hip",
        13: "Left Knee", 14: "Right Knee", 15: "Left Ankle", 16: "Right Ankle"
    }

    confidence_threshold = 0.5
    movement_threshold = 0.05

    def get_keypoint(idx):
        if idx < len(keypoints[0]):
            x, y, conf = keypoints[0][idx]
            return np.array([x, y]) if conf > confidence_threshold else None
        return None

    def calculate_angle(a, b, c):
        if a is None or b is None or c is None:
            return None
        ba = a - b
        bc = c - b
        cosine_angle = np.dot(ba, bc) / \
            (np.linalg.norm(ba) * np.linalg.norm(bc))
        # Clamp to [-1, 1] so floating-point rounding cannot produce NaN
        angle = np.arccos(np.clip(cosine_angle, -1.0, 1.0))
        return np.degrees(angle)

    def calculate_movement(current, previous):
        if current is None or previous is None:
            return None
        return np.linalg.norm(current - previous)

    nose = get_keypoint(0)
    left_shoulder = get_keypoint(5)
    right_shoulder = get_keypoint(6)
    left_elbow = get_keypoint(7)
    right_elbow = get_keypoint(8)
    left_wrist = get_keypoint(9)
    right_wrist = get_keypoint(10)
    left_hip = get_keypoint(11)
    right_hip = get_keypoint(12)
    left_knee = get_keypoint(13)
    right_knee = get_keypoint(14)
    left_ankle = get_keypoint(15)
    right_ankle = get_keypoint(16)

    if all(kp is None for kp in [nose, left_shoulder, right_shoulder, left_hip, right_hip, left_ankle, right_ankle]):
        return "waiting"

    # Calculate midpoints
    shoulder_midpoint = (left_shoulder + right_shoulder) / \
        2 if left_shoulder is not None and right_shoulder is not None else None
    hip_midpoint = (left_hip + right_hip) / \
        2 if left_hip is not None and right_hip is not None else None
    ankle_midpoint = (left_ankle + right_ankle) / \
        2 if left_ankle is not None and right_ankle is not None else None

    # Calculate angles
    spine_angle = calculate_angle(
        shoulder_midpoint, hip_midpoint, ankle_midpoint)
    left_arm_angle = calculate_angle(left_shoulder, left_elbow, left_wrist)
    right_arm_angle = calculate_angle(right_shoulder, right_elbow, right_wrist)
    left_leg_angle = calculate_angle(left_hip, left_knee, left_ankle)
    right_leg_angle = calculate_angle(right_hip, right_knee, right_ankle)

    # Calculate movement of the ankle midpoint between frames
    movement = None
    if prev_keypoints is not None:
        prev_ankle_midpoint = ((prev_keypoints[0][15][:2] + prev_keypoints[0][16][:2]) / 2
                               if len(prev_keypoints[0]) > 16 else None)
        movement = calculate_movement(ankle_midpoint, prev_ankle_midpoint)

    # Detect actions from the spine angle and movement
    if spine_angle is not None:
        if spine_angle > 160:
            if movement is not None and movement > movement_threshold:
                if movement > movement_threshold * 3:
                    return "running"
                else:
                    return "walking"
            return "standing"
        elif 70 < spine_angle < 110:
            return "sitting"
        elif spine_angle < 30:
            return "lying"

    # Detect pointing
    if (left_arm_angle is not None and left_arm_angle > 150) or (right_arm_angle is not None and right_arm_angle > 150):
        return "pointing"

    # Detect kicking
    if (left_leg_angle is not None and left_leg_angle > 120) or (right_leg_angle is not None and right_leg_angle > 120):
        return "kicking"

    # Detect hitting
    if ((left_arm_angle is not None and 80 < left_arm_angle < 120) or
            (right_arm_angle is not None and 80 < right_arm_angle < 120)):
        if movement is not None and movement > movement_threshold * 2:
            return "hitting"

    return "waiting"
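
# analyze_frame() below is the per-frame entry point: depending on the
# "Choose Detection Type" radio selection (the module-level show_labels set in
# analysis_init()), it runs the object-detection model, the pose model, or
# both, attaches a coarse action label to each detected person, and stores the
# annotated frame, the detection list, and the processing time in
# img_container for publish_frame() to display.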
def analyze_frame(frame: np.ndarray):
    start_time = time.time()
    img_container["input"] = frame
    frame = frame.copy()

    detections = []

    if show_labels in ["Object Detection", "Both"]:
        # Run YOLOv8 object detection on the frame
        object_results = object_model(frame, conf=0.5)
        for i, box in enumerate(object_results[0].boxes):
            class_id = int(box.cls)
            detection = {
                "label": object_model.names[class_id],
                "score": float(box.conf),
                "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
            }
            detections.append(detection)

    if show_labels in ["Pose Estimation", "Both"]:
        # Run YOLOv8 pose estimation on the frame
        pose_results = pose_model(frame, conf=0.5)
        for i, box in enumerate(pose_results[0].boxes):
            class_id = int(box.cls)
            detection = {
                "label": pose_model.names[class_id],
                "score": float(box.conf),
                "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
            }

            # Get keypoints for this detection if available
            try:
                if pose_results[0].keypoints is not None:
                    keypoints = pose_results[0].keypoints[i].data.cpu().numpy()

                    # Detect action using the keypoints
                    prev_keypoints = img_container.get("prev_keypoints")
                    action = detect_action(keypoints, prev_keypoints)
                    detection["action"] = action

                    # Store current keypoints for next frame
                    img_container["prev_keypoints"] = keypoints

                    # Calculate the average position of visible keypoints
                    visible_keypoints = keypoints[0][keypoints[0][:, 2] > 0.5][:, :2]
                    if len(visible_keypoints) > 0:
                        label_x, label_y = np.mean(
                            visible_keypoints, axis=0).astype(int)
                    else:
                        # Fallback to the center of the bounding box if no keypoints are visible
                        x1, y1, x2, y2 = detection["box_coords"]
                        label_x = int((x1 + x2) / 2)
                        label_y = int((y1 + y2) / 2)
                else:
                    detection["action"] = "No keypoint data"
                    # Use the center of the bounding box for label position
                    x1, y1, x2, y2 = detection["box_coords"]
                    label_x = int((x1 + x2) / 2)
                    label_y = int((y1 + y2) / 2)
            except IndexError:
                detection["action"] = "Action detection failed"
                # Use the center of the bounding box for label position
                x1, y1, x2, y2 = detection["box_coords"]
                label_x = int((x1 + x2) / 2)
                label_y = int((y1 + y2) / 2)

            # Only display the action as the label
            label = detection.get('action', '')

            # Increase font scale and thickness to match box label size
            font_scale = 2.0
            thickness = 2

            # Get text size for label
            (label_width, label_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)

            # Calculate position for centered label
            label_y = label_y - 10  # 10 pixels above the calculated position

            # Draw yellow background for label
            cv2.rectangle(frame, (label_x - label_width // 2 - 5, label_y - label_height - 5),
                          (label_x + label_width // 2 + 5, label_y + 5), (0, 255, 255), -1)
            # Draw black text for label
            cv2.putText(frame, label, (label_x - label_width // 2, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)

            detections.append(detection)

    # Draw detections on the frame. Pass img=frame so the action labels drawn
    # above with cv2 are kept in the plotted output instead of being discarded.
    if show_labels == "Object Detection":
        frame = object_results[0].plot(img=frame)
    elif show_labels == "Pose Estimation":
        frame = pose_results[0].plot(
            boxes=False, labels=False, kpt_line=True, img=frame)
    else:  # Both
        frame = object_results[0].plot(img=frame)
        frame = pose_results[0].plot(
            boxes=False, labels=False, kpt_line=True, img=frame)

    end_time = time.time()
    execution_time_ms = round((end_time - start_time) * 1000, 2)

    img_container["analysis_time"] = execution_time_ms
    img_container["detections"] = detections
    img_container["analyzed"] = frame

    return


#
#
#
# DO NOT TOUCH THE CODE BELOW (NO CHANGES NEEDED)
#
#

# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"

# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Container to hold image data and analysis results
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None}

# Logger for debugging and information
logger = logging.getLogger(__name__)
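
# Note: analyze_frame() also stores a "prev_keypoints" entry in img_container
# at runtime (the previous frame's pose keypoints), which detect_action() uses
# to estimate movement between consecutive frames.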

# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Convert frame to numpy array in RGB format
    img = frame.to_ndarray(format="rgb24")
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame
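
# The callback above only runs the analysis and stashes the results in
# img_container; it does not update the Streamlit UI itself. The main script
# picks the results up via publish_frame() (see the loop further below), since
# streamlit-webrtc invokes frame callbacks outside the main Streamlit script run.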

# Get ICE servers for WebRTC
ice_servers = get_ice_servers()

# Streamlit UI configuration
st.set_page_config(layout="wide")

# Custom CSS for the Streamlit page
st.markdown(
    """
    <style>
    .main {
        padding: 2rem;
    }
    h1, h2, h3 {
        font-family: 'Arial', sans-serif;
    }
    h1 {
        font-weight: 700;
        font-size: 2.5rem;
    }
    h2 {
        font-weight: 600;
        font-size: 2rem;
    }
    h3 {
        font-weight: 500;
        font-size: 1.5rem;
    }
    </style>
    """,
    unsafe_allow_html=True,
)

# Streamlit page title and subtitle
st.title(ANALYSIS_TITLE)
st.subheader("A Computer Vision Playground")

# Add a link to the README file
st.markdown(
    """
    <div style="text-align: left;">
        <p>See the <a href="https://huggingface.co/spaces/eusholli/sentiment-analyzer/blob/main/README.md"
        target="_blank">README</a> to learn how to use this code to help you start your computer vision exploration.</p>
    </div>
    """,
    unsafe_allow_html=True,
)

# Columns for input and output streams
col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    input_subheader = st.empty()
    input_placeholder = st.empty()  # Placeholder for input frame
    st.subheader("Input Options")

    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )
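
    # Note: WebRtcMode.SENDONLY means the browser only sends its camera stream
    # to the server; the processed frames are shown through the Streamlit
    # placeholders rather than being streamed back over WebRTC.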

    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")

    # Text input for YouTube URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")
    yt_error = st.empty()  # Placeholder for YouTube error messages

    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )

    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")

# Streamlit footer
st.markdown(
    """
    <div style="text-align: center; margin-top: 2rem;">
        <p>If you want to set up your own computer vision playground see <a href="https://huggingface.co/spaces/eusholli/computer-vision-playground/blob/main/README.md" target="_blank">here</a>.</p>
    </div>
    """,
    unsafe_allow_html=True
)

# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.
def analysis_init():
    global progress_bar, status_text, download_button, yt_error, analysis_time, show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder

    yt_error.empty()  # Clear any previous YouTube error message

    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time

        show_labels = st.radio(
            "Choose Detection Type",
            ("Object Detection", "Pose Estimation", "Both"),
            index=2  # Set default to "Both" (index 2)
        )

        # Create a progress bar and status text for video processing
        progress_bar = st.empty()
        status_text = st.empty()
        labels_placeholder = st.empty()  # Placeholder for labels
        download_button = st.empty()  # Placeholder for download button
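
# analysis_init() assigns several module-level names (show_labels,
# progress_bar, status_text, labels_placeholder, ...) that analyze_frame(),
# publish_frame(), and process_video() rely on, so it must run before any
# analysis path below; this is why every input branch starts by calling it.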

# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global container
# and updates the placeholders in the Streamlit UI with the current input frame,
# analyzed frame, analysis time, and detected labels.
def publish_frame():
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame

    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")

    time_ms = img_container["analysis_time"]
    if time_ms is None:
        return
    # Display the analysis time (named time_ms so it does not shadow the time module)
    analysis_time.text(f"Analysis Time: {time_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        labels_placeholder.table(
            detections
        )  # Display the detection results as a table

# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate
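
# The loop above runs for as long as the current Streamlit script run lasts;
# Streamlit interrupts and re-executes the script when widget state changes
# (for example when the WebRTC stream is stopped), so the loop does not need
# its own exit condition.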

# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format

    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results

# Function to process video files
# This function reads frames from a video file, analyzes each frame with the
# object-detection and pose models, and updates the Streamlit UI with the
# current input frame, analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file

    # Create a temporary file for the annotated video
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
        temp_video_path = temp_video.name

    # save_annotated_video(video_path, temp_video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Fall back to 30 fps if the source reports none
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))

    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available

        # Convert the frame from BGR to RGB format
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        # Analyze the frame with the detection and pose models
        analyze_frame(rgb_frame)
        analyzed_frame = img_container["analyzed"]
        if analyzed_frame is not None:
            out.write(cv2.cvtColor(analyzed_frame, cv2.COLOR_RGB2BGR))
        publish_frame()  # Publish the results

        # Update progress (guard against sources that do not report a frame count)
        frame_count += 1
        if total_frames > 0:
            progress = min(100, int(frame_count / total_frames * 100))
            progress_bar.progress(progress)
            status_text.text(f"Processing video: {progress}% complete")

    cap.release()  # Release the video capture object
    out.release()

    # Add download button for annotated video
    with open(temp_video_path, "rb") as file:
        download_button.download_button(
            label="Download Annotated Video",
            data=file,
            file_name="annotated_video.mp4",
            mime="video/mp4"
        )

    # Clean up the temporary file
    os.unlink(temp_video_path)
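
# Note: whether the 'mp4v' fourcc produces a playable .mp4 depends on the
# codecs available in the local OpenCV build; if the downloaded file does not
# play, writing with a different fourcc (or re-encoding with ffmpeg) may be
# needed.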

# Function to get video URL using Cobalt API
def get_cobalt_video_url(youtube_url):
    cobalt_api_url = "https://api.cobalt.tools/api/json"
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json"
    }
    payload = {
        "url": youtube_url,
        "vCodec": "h264",
        "vQuality": "720",
        "aFormat": "mp3",
        "isAudioOnly": False
    }
    try:
        response = requests.post(cobalt_api_url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()
        if data['status'] == 'stream':
            return data['url']
        elif data['status'] == 'redirect':
            return data['url']
        else:
            yt_error.error(f"Error: {data['text']}")
            return None
    except requests.exceptions.RequestException as e:
        yt_error.error(f"Error: Unable to process the YouTube URL. {str(e)}")
        return None

# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI

    stream_url = get_cobalt_video_url(youtube_url)
    # stream_url = get_youtube_stream_url(youtube_url)

    if stream_url:
        process_video(stream_url)  # Process the video
    else:
        yt_error.error(
            "Unable to process the YouTube video. Please try a different URL or video format.")

# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:
            # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)

    process_video(video_path)  # Process the video