|
import logging |
|
import queue |
|
from pathlib import Path |
|
from typing import List, NamedTuple |
|
import mediapipe as mp |
|
|
|
import av |
|
import cv2 |
|
import numpy as np |
|
import streamlit as st |
|
from streamlit_webrtc import WebRtcMode, webrtc_streamer |
|
|
|
from sample_utils.download import download_file |
|
from sample_utils.turn import get_ice_servers |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
# Page header: title plus a short usage hint rendered above the video widget.
st.title("AI Squat Detection using WebRTC")
st.info("Use your webcam for real-time squat detection.")

# MediaPipe pose-estimation solution and its landmark-drawing helpers,
# shared by the frame callback below.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
|
|
|
class Detection(NamedTuple):
    """One detection result pushed onto ``result_queue`` for consumers on the UI side."""

    # Numeric category id (always 0 for "Pose" in this app).
    class_id: int
    # Human-readable class name.
    label: str
    # Confidence score in [0, 1]; the frame callback hard-codes 0.7.
    score: float
    # Bounding box [x_min, y_min, x_max, y_max] in pixels (whole frame here).
    box: np.ndarray
|
|
|
|
|
|
|
def calculate_angle(a, b, c):
    """Return the angle in degrees (0-180) at vertex *b* formed by points a-b-c.

    Each point is any 2-element (x, y) sequence; only the first two
    components are used.
    """
    a, b, c = np.asarray(a), np.asarray(b), np.asarray(c)
    # Signed angle between ray b->c and ray b->a.
    theta = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    degrees = abs(np.degrees(theta))
    # Fold reflex angles back into the [0, 180] range.
    return 360.0 - degrees if degrees > 180.0 else degrees
|
|
|
# Rep-counting state shared with the WebRTC worker thread (the frame
# callback declares these `global`).  `stage` ('up'/'down') is created
# lazily inside the callback on the first frame.
counterL = 0   # NOTE(review): never written in this chunk — possibly vestigial
correct = 0    # squats completed with an acceptable hip angle
incorrect = 0  # NOTE(review): never incremented in this chunk — confirm intent

# Hand-off queue: the worker thread publishes per-frame Detection lists
# here so the Streamlit script thread can consume them safely.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
|
|
|
def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Per-frame WebRTC callback: estimate pose, update the squat-rep state
    machine, draw feedback overlays, and return the annotated frame.

    Runs on the streamlit-webrtc worker thread.  Rep state (`correct`,
    `incorrect`, `stage`) lives in module globals so it persists across
    frames; `counterL` is declared global but never used here.

    Parameters:
        frame: incoming camera frame from the browser.
    Returns:
        The same frame with counters, hints and the skeleton drawn on it.
    """
    global counterL, correct, incorrect, stage

    # Lazily initialise per-session state on the very first frame.
    if 'stage' not in globals():
        stage = 'up'
        correct = 0
        incorrect = 0

    image = frame.to_ndarray(format="bgr24")
    h, w = image.shape[:2]
    # MediaPipe expects RGB input.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # PERF/CORRECTNESS FIX: reuse one Pose instance across frames instead of
    # rebuilding the graph per frame inside `with`.  Construction is
    # expensive, and min_tracking_confidence only has an effect when the
    # same instance tracks landmarks across consecutive frames.
    pose = getattr(video_frame_callback, "_pose", None)
    if pose is None:
        pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
        video_frame_callback._pose = pose

    results = pose.process(image_rgb)
    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

    # Report a single whole-frame "Pose" detection whenever a body is found.
    detections = [
        Detection(
            class_id=0,
            label="Pose",
            score=0.7,
            box=np.array([0, 0, w, h]),
        )
    ] if landmarks else []

    if landmarks:
        def _xy(landmark_id):
            # Normalised (x, y) in [0, 1] for one pose landmark.
            lm = landmarks[landmark_id.value]
            return [lm.x, lm.y]

        hipL = _xy(mp_pose.PoseLandmark.LEFT_HIP)
        kneeL = _xy(mp_pose.PoseLandmark.LEFT_KNEE)
        ankleL = _xy(mp_pose.PoseLandmark.LEFT_ANKLE)
        shoulderL = _xy(mp_pose.PoseLandmark.LEFT_SHOULDER)
        footIndexL = _xy(mp_pose.PoseLandmark.LEFT_FOOT_INDEX)

        # Joint angles (degrees) driving the rep state machine.
        angleKneeL = calculate_angle(hipL, kneeL, ankleL)
        # Hip angle is measured against the vertical line through the hip.
        angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
        angleAnkleL = calculate_angle(footIndexL, ankleL, kneeL)  # currently unused downstream

        # BUG FIX: anchor the hip-angle label at the hip's pixel position.
        # The original scaled the angle *value* by (640, 480), which placed
        # the text at an arbitrary location; it also ignored the real frame
        # size (w, h) computed above.
        cv2.putText(image, str(angleHipL),
                    tuple(np.multiply(hipL, [w, h]).astype(int)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)

        # State machine: count a rep when rising out of a squat ('down' ->
        # 'up') with the hip bent into the acceptable 18-40 degree window.
        if angleKneeL > 110 and stage == 'down':
            stage = 'up'
            if 18 < angleHipL < 40:
                correct += 1

        # Entering the bottom of the squat.
        if 80 < angleKneeL < 110 and stage == 'up':
            stage = 'down'

        # Posture hints: too upright -> lean forward; too bent -> lean back.
        if 10 < angleHipL < 18:
            cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
            cv2.putText(image, "Bend Forward", (320, 200),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (150, 120, 255), 1, cv2.LINE_AA)

        if angleHipL > 45:
            cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
            cv2.putText(image, "Bend Backward", (320, 200),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (80, 120, 255), 1, cv2.LINE_AA)

    # Status banner: rep count and current stage in the top-left corner.
    cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)
    cv2.putText(image, 'Left', (10, 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, str(correct), (10, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(image, 'STAGE', (230, 12),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
    cv2.putText(image, stage, (230, 60),
                cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)

    # draw_landmarks is a no-op when pose_landmarks is None.
    mp_drawing.draw_landmarks(
        image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
        mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
        mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2))

    result_queue.put(detections)
    return av.VideoFrame.from_ndarray(image, format="bgr24")
|
|
|
|
|
|
|
|
|
|
|
# Launch the WebRTC widget: streams the webcam to the server, runs
# `video_frame_callback` on each frame, and renders the annotated video.
webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    # "relay" forces TURN so the stream also works behind restrictive NATs.
    rtc_configuration={"iceServers": get_ice_servers(), "iceTransportPolicy": "relay"},
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    # Process frames on a worker thread so the Streamlit script isn't blocked.
    async_processing=True,
)
|
|