File size: 5,426 Bytes
bba6cec eac73a1 1c81ee4 eac73a1 df8ec21 eac73a1 df8ec21 bba6cec b2ed51e bba6cec df8ec21 4a3054c bba6cec df8ec21 bba6cec 2dd8b34 df8ec21 506c444 eac73a1 506c444 b2ed51e 506c444 b2ed51e 495fa7e b2ed51e bba6cec c3ddf22 eac73a1 bba6cec 1c15637 b2ed51e 506c444 b2ed51e bba6cec b2ed51e 1c15637 52700a9 2dd8b34 52700a9 3427eee 52700a9 b2ed51e 38c92af b2ed51e 52700a9 b2ed51e 52700a9 b2ed51e 2dd8b34 df8ec21 bba6cec eac73a1 a472ccb eac73a1 b2ed51e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import mediapipe as mp
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from sample_utils.download import download_file
from sample_utils.turn import get_ice_servers
# Module-level logger.
logger = logging.getLogger(__name__)

# Page configuration is issued first, before any other st.* call in this file.
st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️")

# Custom CSS injected into the page. The .title and .info classes are used by
# the header markup below; .status-box is declared but not used by it.
_PAGE_CSS = """<style>
.status-box {
background: #f7f7f7;
padding: 15px;
border-radius: 8px;
box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
margin-bottom: 20px;
font-size: 18px;
}
.title {
color: #2E86C1;
font-size: 32px;
font-weight: bold;
text-align: center;
margin-bottom: 10px;
}
.info {
text-align: center;
font-size: 18px;
margin-bottom: 20px;
color: #333;
}
</style>"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page header: title line followed by a one-line usage hint.
st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
st.markdown('<div class="info">Use your webcam for real-time squat detection.</div>', unsafe_allow_html=True)
# MediaPipe handles used below: the Pose solution module (landmark enum,
# connections, estimator class) and its drawing helpers.
mp_pose, mp_drawing = mp.solutions.pose, mp.solutions.drawing_utils
# Record type for the entries the frame callback publishes on result_queue.
# Fields:
#   class_id - integer class index
#   label    - class name string
#   score    - confidence value
#   box      - bounding box coordinates held in a NumPy array
Detection = NamedTuple(
    "Detection",
    [
        ("class_id", int),
        ("label", str),
        ("score", float),
        ("box", np.ndarray),
    ],
)
def calculate_angle(a, b, c):
    """Return the angle ABC, in degrees, measured at vertex ``b``.

    Each point is indexable with at least two coordinates; only indices 0 and
    1 are used. The result is folded into the range [0, 180].
    """
    first, vertex, last = (np.array(point) for point in (a, b, c))
    # Difference between the directions vertex->last and vertex->first.
    theta = (np.arctan2(last[1] - vertex[1], last[0] - vertex[0])
             - np.arctan2(first[1] - vertex[1], first[0] - vertex[0]))
    degrees = np.abs(theta * 180.0 / np.pi)
    # Report the interior (smaller) angle rather than the reflex one.
    return 360 - degrees if degrees > 180.0 else degrees
# MediaPipe Pose estimator, created at module level and reused by the video
# callback for every frame.
# NOTE(review): this single instance is shared by every stream served by this
# module namespace; confirm that is acceptable for concurrent sessions.
pose = mp_pose.Pose(min_tracking_confidence=0.5, min_detection_confidence=0.5)

# Queue onto which the video callback publishes per-frame detection lists.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
# Pixel baselines of the two angle read-outs inside the angle panel drawn at
# (0, 90)-(200, 175). They used to be computed as fractions of the frame size,
# (0, 0.35*h) and (0.17*w, 0.35*h), which only land inside that fixed-size
# panel for a 640x480 frame. These are exactly the values those formulas give
# at 640x480.
_HIP_ANGLE_ORG = (0, 168)
_KNEE_ANGLE_ORG = (108, 168)


def _landmark_xy(landmarks, part):
    """Return ``[x, y]`` of pose landmark ``part`` (a ``PoseLandmark`` member)."""
    point = landmarks[part.value]
    return [point.x, point.y]


def _put_white_text(image, text, org, scale):
    """Draw ``text`` on ``image`` in white, 2 px thick, anti-aliased, at ``org``."""
    cv2.putText(image, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale,
                (255, 255, 255), 2, cv2.LINE_AA)


def _discard_pending(q):
    """Remove every item currently waiting in ``q`` without blocking."""
    try:
        while True:
            q.get_nowait()
    except queue.Empty:
        pass


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one webcam frame with left-side squat angles and a rep counter.

    The frame is mirrored horizontally and run through the shared MediaPipe
    ``pose`` estimator. When a pose is found:

    * the left hip and knee angles are drawn;
    * the rep state is updated and shown;
    * the pose skeleton is overlaid.

    Side effects:
        * Reads and writes the module globals ``stage``, ``correct`` and
          ``incorrect``, creating them on the first call.
        * Replaces any unread entries in ``result_queue`` with this frame's
          detections list. The list holds one whole-frame ``Detection`` if a
          pose was found, otherwise it is empty.

    Args:
        frame: Incoming video frame.

    Returns:
        A new ``bgr24`` frame containing the mirrored, annotated image.
    """
    global correct, incorrect, stage
    # First call: create the rep-counter state as module globals so it
    # persists across frames.
    if 'stage' not in globals():
        stage = 'up'
        correct = 0
        incorrect = 0

    image = frame.to_ndarray(format="bgr24")
    image = cv2.flip(image, 1)  # flip code 1 = horizontal mirror
    h, w = image.shape[:2]
    # The frame is BGR; convert to RGB before pose estimation.
    results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

    detections = []
    if landmarks:
        detections.append(
            Detection(class_id=0, label="Pose", score=0.5, box=np.array([0, 0, w, h]))
        )

        part = mp_pose.PoseLandmark
        hip = _landmark_xy(landmarks, part.LEFT_HIP)
        knee = _landmark_xy(landmarks, part.LEFT_KNEE)
        ankle = _landmark_xy(landmarks, part.LEFT_ANKLE)
        shoulder = _landmark_xy(landmarks, part.LEFT_SHOULDER)

        # Angle at the knee between hip and ankle.
        angle_knee = calculate_angle(hip, knee, ankle)
        # Angle at the hip between the shoulder and a point directly above
        # the hip at y = 0.
        angle_hip = calculate_angle(shoulder, hip, [hip[0], 0])
        # NOTE(review): landmark x/y are presumably normalised separately by
        # frame width and height (per MediaPipe docs; confirm). If so, these
        # are not true image angles on non-square frames. The thresholds below
        # were presumably tuned on these values, so they are left unchanged.

        # Angle panel: two overlapping filled rectangles form a bordered box.
        cv2.rectangle(image, (0, 90), (200, 175), (127, 248, 236), -1)
        cv2.rectangle(image, (0, 93), (197, 173), (12, 85, 61), -1)
        _put_white_text(image, 'HipL', (10, 122), 0.7)
        _put_white_text(image, 'KneeL', (125, 122), 0.7)
        _put_white_text(image, str(int(angle_hip)), _HIP_ANGLE_ORG, 1.5)
        _put_white_text(image, str(int(angle_knee)), _KNEE_ANGLE_ORG, 1.5)

        # A rep ends when the knee angle rises above 110 degrees after having
        # dropped below it. It counts as correct only if the hip angle is
        # strictly between 18 and 40 degrees at that moment.
        if angle_knee > 110 and stage == 'down':
            stage = 'up'
            if 18 < angle_hip < 40:
                correct += 1
            # NOTE(review): `incorrect` is never incremented or displayed;
            # possibly a rep failing the hip check was meant to count here.
        if angle_knee < 110 and stage == 'up':
            stage = 'down'

        # Counter panel: correct-rep count and current stage.
        cv2.rectangle(image, (0, 0), (200, 83), (127, 248, 236), -1)
        cv2.rectangle(image, (0, 3), (197, 80), (12, 85, 61), -1)
        _put_white_text(image, 'Left', (10, 22), 0.7)
        _put_white_text(image, str(correct), (10, 70), 1.5)
        _put_white_text(image, 'STAGE', (110, 22), 0.7)
        _put_white_text(image, stage, (77, 70), 1.5)

    # Skeleton overlay: landmark style first, connection style second.
    mp_drawing.draw_landmarks(
        image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
        mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
        mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2)
    )

    # Nothing in this file reads result_queue, so an unconditional put() per
    # frame would grow it for the whole session. Keep only the latest entry:
    # any reader still gets the most recent detections, and memory stays
    # bounded.
    _discard_pending(result_queue)
    result_queue.put(detections)
    return av.VideoFrame.from_ndarray(image, format="bgr24")
# WebRTC component: SENDRECV mode with video only (audio disabled). Each video
# frame is passed to video_frame_callback, and async_processing is enabled.
# NOTE(review): "iceTransportPolicy": "relay" restricts the connection to
# TURN-relayed candidates. If get_ice_servers() ever returns no TURN server,
# the connection presumably cannot be established; confirm against
# sample_utils.turn.
webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration={"iceServers": get_ice_servers(), "iceTransportPolicy": "relay"},
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)