from ultralytics import YOLO
import time
import os
import logging
import tempfile
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from utils.download import download_file
from utils.turn import get_ice_servers
from PIL import Image
import requests
from io import BytesIO

# CHANGE CODE BELOW HERE; REPLACE IT WITH THE ANALYSIS YOU WANT TO RUN.

# Update the string below to set the display title of the analysis
ANALYSIS_TITLE = "YOLOv8 Object Detection, Pose Estimation, and Action Detection"

# Load the YOLOv8 models
pose_model = YOLO("yolov8n-pose.pt")
object_model = YOLO("yolov8n.pt")


def detect_action(keypoints, prev_keypoints=None):
    # COCO-style keypoint indices used by the YOLOv8 pose model (listed here for reference)
    keypoint_dict = {
        0: "Nose", 1: "Left Eye", 2: "Right Eye", 3: "Left Ear", 4: "Right Ear",
        5: "Left Shoulder", 6: "Right Shoulder", 7: "Left Elbow", 8: "Right Elbow",
        9: "Left Wrist", 10: "Right Wrist", 11: "Left Hip", 12: "Right Hip",
        13: "Left Knee", 14: "Right Knee", 15: "Left Ankle", 16: "Right Ankle"
    }
    confidence_threshold = 0.5
    movement_threshold = 0.05

    def get_keypoint(idx):
        if idx < len(keypoints[0]):
            x, y, conf = keypoints[0][idx]
            return np.array([x, y]) if conf > confidence_threshold else None
        return None

    def calculate_angle(a, b, c):
        if a is None or b is None or c is None:
            return None
        ba = a - b
        bc = c - b
        cosine_angle = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc))
        angle = np.arccos(cosine_angle)
        return np.degrees(angle)

    def calculate_movement(current, previous):
        if current is None or previous is None:
            return None
        return np.linalg.norm(current - previous)

    nose = get_keypoint(0)
    left_shoulder = get_keypoint(5)
    right_shoulder = get_keypoint(6)
    left_elbow = get_keypoint(7)
    right_elbow = get_keypoint(8)
    left_wrist = get_keypoint(9)
    right_wrist = get_keypoint(10)
    left_hip = get_keypoint(11)
    right_hip = get_keypoint(12)
    left_knee = get_keypoint(13)
    right_knee = get_keypoint(14)
    left_ankle = get_keypoint(15)
    right_ankle = get_keypoint(16)

    if all(kp is None for kp in [nose, left_shoulder, right_shoulder,
                                 left_hip, right_hip, left_ankle, right_ankle]):
        return "waiting"

    # Calculate midpoints
    shoulder_midpoint = ((left_shoulder + right_shoulder) / 2
                         if left_shoulder is not None and right_shoulder is not None else None)
    hip_midpoint = ((left_hip + right_hip) / 2
                    if left_hip is not None and right_hip is not None else None)
    ankle_midpoint = ((left_ankle + right_ankle) / 2
                      if left_ankle is not None and right_ankle is not None else None)

    # Calculate angles
    spine_angle = calculate_angle(shoulder_midpoint, hip_midpoint, ankle_midpoint)
    left_arm_angle = calculate_angle(left_shoulder, left_elbow, left_wrist)
    right_arm_angle = calculate_angle(right_shoulder, right_elbow, right_wrist)
    left_leg_angle = calculate_angle(left_hip, left_knee, left_ankle)
    right_leg_angle = calculate_angle(right_hip, right_knee, right_ankle)

    # Calculate movement
    movement = None
    if prev_keypoints is not None:
        prev_ankle_midpoint = ((prev_keypoints[0][15][:2] + prev_keypoints[0][16][:2]) / 2
                               if len(prev_keypoints[0]) > 16 else None)
        movement = calculate_movement(ankle_midpoint, prev_ankle_midpoint)

    # Detect actions
    if spine_angle is not None:
        if spine_angle > 160:
            if movement is not None and movement > movement_threshold:
                if movement > movement_threshold * 3:
                    return "running"
                else:
                    return "walking"
            return "standing"
        elif 70 < spine_angle < 110:
            return "sitting"
        elif spine_angle < 30:
            return "lying"

    # Detect pointing
    if (left_arm_angle is not None and left_arm_angle > 150) or \
            (right_arm_angle is not None and right_arm_angle > 150):
        return "pointing"

    # Detect kicking
    if (left_leg_angle is not None and left_leg_angle > 120) or \
            (right_leg_angle is not None and right_leg_angle > 120):
        return "kicking"

    # Detect hitting
    if ((left_arm_angle is not None and 80 < left_arm_angle < 120) or
            (right_arm_angle is not None and 80 < right_arm_angle < 120)):
        if movement is not None and movement > movement_threshold * 2:
            return "hitting"

    return "waiting"

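# Illustrative note (not part of the app flow): detect_action expects `keypoints` shaped like
# the arrays produced below in analyze_frame, i.e. roughly (1, 17, 3) with one
# (x, y, confidence) row per COCO keypoint. A minimal sanity-check sketch, using hypothetical
# all-zero data:
#
#     dummy = np.zeros((1, 17, 3))          # no keypoint clears confidence_threshold
#     assert detect_action(dummy) == "waiting"
#
# Passing the previous frame's keypoints as `prev_keypoints` is what enables the
# movement-based labels ("walking", "running", "hitting").
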
def analyze_frame(frame: np.ndarray):
    start_time = time.time()
    img_container["input"] = frame
    frame = frame.copy()

    detections = []

    if show_labels in ["Object Detection", "Both"]:
        # Run YOLOv8 object detection on the frame
        object_results = object_model(frame, conf=0.5)
        for i, box in enumerate(object_results[0].boxes):
            class_id = int(box.cls)
            detection = {
                "label": object_model.names[class_id],
                "score": float(box.conf),
                "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
            }
            detections.append(detection)

    if show_labels in ["Pose Estimation", "Both"]:
        # Run YOLOv8 pose estimation on the frame
        pose_results = pose_model(frame, conf=0.5)
        for i, box in enumerate(pose_results[0].boxes):
            class_id = int(box.cls)
            detection = {
                "label": pose_model.names[class_id],
                "score": float(box.conf),
                "box_coords": [round(value.item(), 2) for value in box.xyxy.flatten()]
            }

            # Get keypoints for this detection if available
            try:
                if pose_results[0].keypoints is not None:
                    keypoints = pose_results[0].keypoints[i].data.cpu().numpy()

                    # Detect action using the keypoints
                    prev_keypoints = img_container.get("prev_keypoints")
                    action = detect_action(keypoints, prev_keypoints)
                    detection["action"] = action

                    # Store current keypoints for next frame
                    img_container["prev_keypoints"] = keypoints

                    # Calculate the average position of visible keypoints
                    visible_keypoints = keypoints[0][keypoints[0][:, 2] > 0.5][:, :2]
                    if len(visible_keypoints) > 0:
                        label_x, label_y = np.mean(visible_keypoints, axis=0).astype(int)
                    else:
                        # Fallback to the center of the bounding box if no keypoints are visible
                        x1, y1, x2, y2 = detection["box_coords"]
                        label_x = int((x1 + x2) / 2)
                        label_y = int((y1 + y2) / 2)
                else:
                    detection["action"] = "No keypoint data"
                    # Use the center of the bounding box for label position
                    x1, y1, x2, y2 = detection["box_coords"]
                    label_x = int((x1 + x2) / 2)
                    label_y = int((y1 + y2) / 2)
            except IndexError:
                detection["action"] = "Action detection failed"
                # Use the center of the bounding box for label position
                x1, y1, x2, y2 = detection["box_coords"]
                label_x = int((x1 + x2) / 2)
                label_y = int((y1 + y2) / 2)

            # Only display the action as the label
            label = detection.get('action', '')

            # Increase font scale and thickness to match box label size
            font_scale = 2.0
            thickness = 2

            # Get text size for label
            (label_width, label_height), _ = cv2.getTextSize(
                label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, thickness)

            # Calculate position for centered label
            label_y = label_y - 10  # 10 pixels above the calculated position

            # Draw yellow background for label
            cv2.rectangle(frame,
                          (label_x - label_width // 2 - 5, label_y - label_height - 5),
                          (label_x + label_width // 2 + 5, label_y + 5),
                          (0, 255, 255), -1)
            # Draw black text for label
            cv2.putText(frame, label, (label_x - label_width // 2, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), thickness)

            detections.append(detection)

    # Draw detections on the frame
    if show_labels == "Object Detection":
        frame = object_results[0].plot()
    elif show_labels == "Pose Estimation":
        frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True)
    else:  # Both
        frame = object_results[0].plot()
        frame = pose_results[0].plot(boxes=False, labels=False, kpt_line=True, img=frame)

    end_time = time.time()
    execution_time_ms = round((end_time - start_time) * 1000, 2)
    img_container["analysis_time"] = execution_time_ms

    img_container["detections"] = detections
    img_container["analyzed"] = frame

    return

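# Illustrative note (not part of the app flow): after analyze_frame runs, img_container holds
# the latest results that publish_frame later renders. A single entry of
# img_container["detections"] looks roughly like this (hypothetical values):
#
#     {
#         "label": "person",
#         "score": 0.91,
#         "box_coords": [34.5, 12.0, 220.75, 480.0],   # xyxy pixel coordinates
#         "action": "standing",                        # only present for pose detections
#     }
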
# # # # DO NOT TOUCH THE BELOW CODE (NOT NEEDED) # # #

# Suppress FFmpeg logs
os.environ["FFMPEG_LOG_LEVEL"] = "quiet"

# Suppress Streamlit logs using the logging module
logging.getLogger("streamlit").setLevel(logging.ERROR)

# Container to hold image data and analysis results
img_container = {"input": None, "analyzed": None,
                 "analysis_time": None, "detections": None}

# Logger for debugging and information
logger = logging.getLogger(__name__)


# Callback function to process video frames
# This function is called for each video frame in the WebRTC stream.
# It converts the frame to a numpy array in RGB format, analyzes the frame,
# and returns the original frame.
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Convert frame to numpy array in RGB format
    img = frame.to_ndarray(format="rgb24")
    analyze_frame(img)  # Analyze the frame
    return frame  # Return the original frame


# Get ICE servers for WebRTC
ice_servers = get_ice_servers()

# Streamlit UI configuration
st.set_page_config(layout="wide")

# Custom CSS for the Streamlit page
st.markdown(
    """
    """,
    unsafe_allow_html=True,
)

# Streamlit page title and subtitle
st.title(ANALYSIS_TITLE)
st.subheader("A Computer Vision Playground")

# Add a link to the README file
st.markdown(
    """
See the README to learn how to use this code to help you start your computer vision exploration.
""",
    unsafe_allow_html=True,
)

# Columns for input and output streams
col1, col2 = st.columns(2)

with col1:
    st.header("Input Stream")
    input_subheader = st.empty()
    input_placeholder = st.empty()  # Placeholder for input frame
    st.subheader("Input Options")

    # WebRTC streamer to get video input from the webcam
    webrtc_ctx = webrtc_streamer(
        key="input-webcam",
        mode=WebRtcMode.SENDONLY,
        rtc_configuration=ice_servers,
        video_frame_callback=video_frame_callback,
        media_stream_constraints={"video": True, "audio": False},
        async_processing=True,
    )

    # File uploader for images
    st.subheader("Upload an Image")
    uploaded_file = st.file_uploader(
        "Choose an image...", type=["jpg", "jpeg", "png"])

    # Text input for image URL
    st.subheader("Or Enter Image URL")
    image_url = st.text_input("Image URL")

    # Text input for YouTube URL
    st.subheader("Enter a YouTube URL")
    youtube_url = st.text_input("YouTube URL")
    yt_error = st.empty()  # Placeholder for YouTube error messages

    # File uploader for videos
    st.subheader("Upload a Video")
    uploaded_video = st.file_uploader(
        "Choose a video...", type=["mp4", "avi", "mov", "mkv"]
    )

    # Text input for video URL
    st.subheader("Or Enter Video Download URL")
    video_url = st.text_input("Video URL")

# Streamlit footer
st.markdown(
    """
If you want to set up your own computer vision playground, see here.
""",
    unsafe_allow_html=True,
)

# Function to initialize the analysis UI
# This function sets up the placeholders and UI elements in the analysis section.
# It creates placeholders for input and output frames, analysis time, and detected labels.
def analysis_init():
    global progress_bar, status_text, download_button, yt_error, analysis_time, \
        show_labels, labels_placeholder, input_subheader, input_placeholder, output_placeholder

    yt_error.empty()  # Clear any previous YouTube error message

    with col2:
        st.header("Analysis")
        input_subheader.subheader("Input Frame")
        st.subheader("Output Frame")
        output_placeholder = st.empty()  # Placeholder for output frame
        analysis_time = st.empty()  # Placeholder for analysis time
        show_labels = st.radio(
            "Choose Detection Type",
            ("Object Detection", "Pose Estimation", "Both"),
            index=2  # Set default to "Both" (index 2)
        )
        # Create a progress bar
        progress_bar = st.empty()
        status_text = st.empty()
        labels_placeholder = st.empty()  # Placeholder for labels
        download_button = st.empty()  # Placeholder for download button


# Function to publish frames and results to the Streamlit UI
# This function retrieves the latest frames and results from the global container
# and updates the placeholders in the Streamlit UI with the current input frame,
# analyzed frame, analysis time, and detected labels.
def publish_frame():
    img = img_container["input"]
    if img is None:
        return
    input_placeholder.image(img, channels="RGB")  # Display the input frame

    analyzed = img_container["analyzed"]
    if analyzed is None:
        return
    # Display the analyzed frame
    output_placeholder.image(analyzed, channels="RGB")

    time_ms = img_container["analysis_time"]
    if time_ms is None:
        return
    # Display the analysis time
    analysis_time.text(f"Analysis Time: {time_ms} ms")

    detections = img_container["detections"]
    if detections is None:
        return
    if show_labels:
        # Display the detections table when a detection type is selected
        labels_placeholder.table(detections)


# If the WebRTC streamer is playing, initialize and publish frames
if webrtc_ctx.state.playing:
    analysis_init()  # Initialize the analysis UI
    while True:
        publish_frame()  # Publish the frames and results
        time.sleep(0.1)  # Delay to control frame rate

# If an image is uploaded or a URL is provided, process the image
if uploaded_file is not None or image_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_file is not None:
        image = Image.open(uploaded_file)  # Open the uploaded image
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format
    else:
        response = requests.get(image_url)  # Download the image from the URL
        # Open the downloaded image
        image = Image.open(BytesIO(response.content))
        img = np.array(image.convert("RGB"))  # Convert the image to RGB format

    analyze_frame(img)  # Analyze the image
    publish_frame()  # Publish the results

# Function to process video files
# This function reads frames from a video file, analyzes each frame for object detection,
# pose estimation, and action detection, and updates the Streamlit UI with the current
# input frame, analyzed frame, and detected labels.
def process_video(video_path):
    cap = cv2.VideoCapture(video_path)  # Open the video file

    # Create a temporary file for the annotated video
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_video:
        temp_video_path = temp_video.name

    # save_annotated_video(video_path, temp_video_path)
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = int(cap.get(cv2.CAP_PROP_FPS))
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    out = cv2.VideoWriter(temp_video_path, fourcc, fps, (width, height))

    frame_count = 0
    while cap.isOpened():
        ret, frame = cap.read()  # Read a frame from the video
        if not ret:
            break  # Exit the loop if no more frames are available

        # Convert the frame from BGR to RGB format
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Analyze the frame for object detection, pose estimation, and action detection
        analyze_frame(rgb_frame)

        analyzed_frame = img_container["analyzed"]
        if analyzed_frame is not None:
            out.write(cv2.cvtColor(analyzed_frame, cv2.COLOR_RGB2BGR))

        publish_frame()  # Publish the results

        # Update progress
        frame_count += 1
        progress = min(100, int(frame_count / total_frames * 100))
        progress_bar.progress(progress)
        status_text.text(f"Processing video: {progress}% complete")

    cap.release()  # Release the video capture object
    out.release()

    # Add download button for annotated video
    with open(temp_video_path, "rb") as file:
        download_button.download_button(
            label="Download Annotated Video",
            data=file,
            file_name="annotated_video.mp4",
            mime="video/mp4"
        )

    # Clean up the temporary file
    os.unlink(temp_video_path)


# Function to get a direct video URL using the Cobalt API
def get_cobalt_video_url(youtube_url):
    cobalt_api_url = "https://api.cobalt.tools/api/json"
    headers = {
        "Accept": "application/json",
        "Content-Type": "application/json"
    }
    payload = {
        "url": youtube_url,
        "vCodec": "h264",
        "vQuality": "720",
        "aFormat": "mp3",
        "isAudioOnly": False
    }
    try:
        response = requests.post(cobalt_api_url, headers=headers, json=payload)
        response.raise_for_status()
        data = response.json()
        if data['status'] == 'stream':
            return data['url']
        elif data['status'] == 'redirect':
            return data['url']
        else:
            yt_error.error(f"Error: {data['text']}")
            return None
    except requests.exceptions.RequestException as e:
        yt_error.error(f"Error: Unable to process the YouTube URL. {str(e)}")
        return None


# If a YouTube URL is provided, process the video
if youtube_url:
    analysis_init()  # Initialize the analysis UI
    stream_url = get_cobalt_video_url(youtube_url)
    # stream_url = get_youtube_stream_url(youtube_url)
    if stream_url:
        process_video(stream_url)  # Process the video
    else:
        yt_error.error(
            "Unable to process the YouTube video. Please try a different URL or video format.")

# If a video is uploaded or a URL is provided, process the video
if uploaded_video is not None or video_url:
    analysis_init()  # Initialize the analysis UI

    if uploaded_video is not None:
        video_path = uploaded_video.name  # Get the name of the uploaded video
        with open(video_path, "wb") as f:  # Save the uploaded video to a file
            f.write(uploaded_video.getbuffer())
    else:
        # Download the video from the URL
        video_path = download_file(video_url)

    process_video(video_path)  # Process the video
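
# Usage note (assumptions: this file is saved as app.py, and the dependencies imported above —
# ultralytics, streamlit, streamlit-webrtc, av, opencv-python, pillow, numpy, requests — are
# installed along with the local utils package providing download_file and get_ice_servers):
#
#     streamlit run app.py
#
# The YOLOv8 weights (yolov8n.pt, yolov8n-pose.pt) are downloaded automatically by
# ultralytics on first use if they are not already present.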