import logging
import queue
from typing import List, NamedTuple

import av
import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from sample_utils.turn import get_ice_servers

logger = logging.getLogger(__name__)

st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️", layout="wide")
st.markdown(
    """<style>
    .status-box {
        background: #f7f7f7;
        padding: 15px;
        border-radius: 8px;
        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
        margin-bottom: 20px;
        font-size: 18px;
    }
    .title {
        color: #2E86C1;
        font-size: 32px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 10px;
    }
    .info {
        text-align: center;
        font-size: 18px;
        margin-bottom: 20px;
        color: #333;
    }
    </style>""",
    unsafe_allow_html=True,
)

st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
st.markdown('<div class="info">Use your webcam for real-time squat detection.</div>', unsafe_allow_html=True)

mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
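
# Detection mirrors the result schema used by streamlit-webrtc's
# object-detection sample; the video callback below emits a single
# whole-frame "Pose" detection whenever landmarks are found.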
class Detection(NamedTuple):
    class_id: int
    label: str
    score: float
    box: np.ndarray

def calculate_angle(a, b, c):
    """Return the angle at vertex b (in degrees, 0-180) formed by points a-b-c."""
    a = np.array(a)
    b = np.array(b)
    c = np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    if angle > 180.0:
        angle = 360 - angle
    return angle
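
# Illustrative sanity check (not called by the app): a right angle at the vertex.
#   calculate_angle((0, 1), (0, 0), (1, 0))  # -> 90.0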


result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
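# video_frame_callback below runs on a worker thread; it hands per-frame
# detections back to the Streamlit script thread through this thread-safe queue.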

# Create the pose estimator once: constructing mp_pose.Pose per frame is
# expensive and resets landmark tracking between frames.
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Rep-counting state shared across frames.
stage = "up"
correct = 0
incorrect = 0


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    global correct, incorrect, stage

    image = frame.to_ndarray(format="bgr24")
    h, w = image.shape[:2]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    results = pose.process(image_rgb)
    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

    detections = (
        [Detection(class_id=0, label="Pose", score=0.5, box=np.array([0, 0, w, h]))]
        if landmarks
        else []
    )

    if landmarks:
        hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,
                landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
        kneeL = [landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x,
                 landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y]
        ankleL = [landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x,
                  landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
        shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x,
                     landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]

        # The knee angle (hip-knee-ankle) drives the up/down state machine; the
        # hip angle measures trunk lean against the vertical line through the hip.
        angleKneeL = calculate_angle(hipL, kneeL, ankleL)
        angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])

        # Rising back above ~110 degrees completes a rep; it counts as correct
        # only if the trunk lean stayed within 18-40 degrees.
        if angleKneeL > 110 and stage == "down":
            stage = "up"
            if 18 < angleHipL < 40:
                correct += 1
            else:
                incorrect += 1

        if 80 < angleKneeL < 110 and stage == "up":
            stage = "down"

    overlay_text = f"Correct: {correct} | Stage: {stage}"
    cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)
    cv2.putText(image, overlay_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1,
                (255, 255, 255), 2, cv2.LINE_AA)

    mp_drawing.draw_landmarks(
        image,
        results.pose_landmarks,
        mp_pose.POSE_CONNECTIONS,
        mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
        mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2),
    )

    result_queue.put(detections)

    return av.VideoFrame.from_ndarray(image, format="bgr24")


webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration={"iceServers": get_ice_servers(), "iceTransportPolicy": "relay"},
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)
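
# Note: the callback fills result_queue, but nothing in this script drains it.
# A minimal consumer sketch, following streamlit-webrtc's object-detection
# sample (assumes the context is captured as `webrtc_ctx = webrtc_streamer(...)`
# above):
#
# if webrtc_ctx.state.playing:
#     labels_placeholder = st.empty()
#     while True:
#         labels_placeholder.table(result_queue.get())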