File size: 5,115 Bytes
bba6cec
eac73a1
 
 
1c81ee4
eac73a1
df8ec21
 
 
eac73a1
 
 
df8ec21
bba6cec
 
 
 
a6c914a
768810e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bba6cec
 
df8ec21
 
 
4a3054c
 
 
 
 
 
bba6cec
df8ec21
bba6cec
 
 
2dd8b34
df8ec21
 
 
 
 
506c444
eac73a1
506c444
f33062d
 
 
506c444
768810e
495fa7e
 
 
 
768810e
bba6cec
eac73a1
bba6cec
1c15637
f33062d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17fe806
 
768810e
2dd8b34
df8ec21
bba6cec
eac73a1
 
a472ccb
eac73a1
 
f33062d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import mediapipe as mp
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from sample_utils.download import download_file
from sample_utils.turn import get_ice_servers

# Module-level logger for this script.
logger = logging.getLogger(__name__)

# Page configuration; keep this ahead of every other st.* call.
st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️")

# Custom CSS injected once for the status box, title and info line.
_PAGE_CSS = """<style>
    .status-box {
        background: #f7f7f7;
        padding: 15px;
        border-radius: 8px;
        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
        margin-bottom: 20px;
        font-size: 18px;
    }
    .title {
        color: #2E86C1;
        font-size: 32px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 10px;
    }
    .info {
        text-align: center;
        font-size: 18px;
        margin-bottom: 20px;
        color: #333;
    }
    </style>"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# Page header: title followed by a one-line description, both styled by the CSS above.
for _header_html in (
    '<div class="title">AI Squat Detection</div>',
    '<div class="info">Use your webcam for real-time squat detection.</div>',
):
    st.markdown(_header_html, unsafe_allow_html=True)

# MediaPipe handles: the pose solution and its landmark-drawing utilities.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

class Detection(NamedTuple):
    """A single detection result for one processed video frame."""

    class_id: int      # numeric class identifier
    label: str         # human-readable class name
    score: float       # confidence score attached to the detection
    box: np.ndarray    # bounding box values (the callback uses [0, 0, w, h])

def calculate_angle(a, b, c):
    """Return the angle ABC at vertex *b*, in degrees, folded into [0, 180].

    Each point is an indexable sequence whose elements 0 and 1 are the x and y
    coordinates. The two ray directions b->a and b->c are measured with
    ``arctan2``; their absolute difference is converted to degrees. Reflex
    results (> 180) are replaced by their complement to 360.
    """
    pa, pb, pc = (np.array(point) for point in (a, b, c))
    ray_to_c = np.arctan2(pc[1] - pb[1], pc[0] - pb[0])
    ray_to_a = np.arctan2(pa[1] - pb[1], pa[0] - pb[0])
    degrees = np.abs((ray_to_c - ray_to_a) * 180.0 / np.pi)
    # Prefer the interior angle over the reflex one.
    return 360 - degrees if degrees > 180.0 else degrees

# Single MediaPipe Pose estimator, created when the script runs and reused by
# video_frame_callback for every frame.
pose = mp_pose.Pose(
    min_detection_confidence=0.5,
    min_tracking_confidence=0.5,
)

# Detection lists emitted by video_frame_callback, one per processed frame.
# NOTE(review): no reader of this queue is visible in this script.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()

def _landmark_xy(landmarks, part):
    """Return ``[x, y]`` of the landmark identified by *part* (a ``PoseLandmark`` member)."""
    point = landmarks[part.value]
    return [point.x, point.y]


def _draw_panel(image, top, bottom):
    """Draw a bordered, filled status panel at the left edge spanning rows *top*..*bottom*.

    The outer rectangle covers columns 0-280; the inner fill is inset by 3 px
    on the top, bottom and right, which leaves a visible border.
    """
    cv2.rectangle(image, (0, top), (280, bottom), (127, 248, 236), -1)
    cv2.rectangle(image, (0, top + 3), (277, bottom - 3), (12, 85, 61), -1)


def _put_text(image, text, origin, scale):
    """Draw white, anti-aliased *text* (thickness 2) with its bottom-left corner at *origin*."""
    cv2.putText(image, text, origin, cv2.FONT_HERSHEY_SIMPLEX, scale,
                (255, 255, 255), 2, cv2.LINE_AA)


def _update_rep_count(angle_knee, angle_hip):
    """Advance the squat state machine and rep counters for one frame.

    ``stage`` moves 'up' -> 'down' when the knee angle enters the open
    interval (80, 110), and 'down' -> 'up' when it exceeds 110. Each
    'down' -> 'up' transition completes a rep. The rep counts as ``correct``
    when the hip angle on that frame lies strictly between 18 and 40 degrees,
    otherwise as ``incorrect``.
    """
    global stage, correct, incorrect
    if angle_knee > 110 and stage == 'down':
        stage = 'up'
        # NOTE(review): posture is judged on the frame where the knee
        # re-extends, not at the bottom of the squat — confirm this is intended.
        if 18 < angle_hip < 40:
            correct += 1
        else:
            incorrect += 1
    if 80 < angle_knee < 110 and stage == 'up':
        stage = 'down'


def _publish_latest(detections):
    """Put *detections* on ``result_queue`` after discarding any unread entries.

    No reader of ``result_queue`` is visible in this script, so appending on
    every frame would grow the queue without bound. Keeping only the newest
    entry caps memory while still letting any reader see the latest result.
    The call never blocks the video path.
    """
    while True:
        try:
            result_queue.get_nowait()
        except queue.Empty:
            break
    try:
        result_queue.put_nowait(detections)
    except queue.Full:
        # Only reachable if the queue is bounded and refilled concurrently;
        # dropping one frame's result beats stalling the video stream.
        pass


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Annotate one webcam frame with pose landmarks and squat-counter overlays.

    Runs MediaPipe Pose on the frame. When a pose is found, the callback
    computes the left knee and hip angles, draws them, and updates the
    module-level counters ``stage``, ``correct`` and ``incorrect``. The
    counter panel and the landmark skeleton are drawn on every frame.

    Args:
        frame: Incoming video frame delivered by ``webrtc_streamer``.

    Returns:
        A new ``bgr24`` frame containing the annotated image.
    """
    global stage, correct, incorrect
    # Lazily create the rep-counter state on the first processed frame.
    if 'stage' not in globals():
        stage = 'up'
        correct = 0
        incorrect = 0

    image = frame.to_ndarray(format="bgr24")
    h, w = image.shape[:2]
    results = pose.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

    # One whole-frame detection whenever a pose is present (fixed placeholder score).
    detections = [
        Detection(class_id=0, label="Pose", score=0.5, box=np.array([0, 0, w, h]))
    ] if landmarks else []

    if landmarks:
        hip = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_HIP)
        knee = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_KNEE)
        ankle = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_ANKLE)
        shoulder = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER)

        # NOTE(review): MediaPipe normalizes x by width and y by height, so
        # these angles are distorted on non-square frames. The thresholds in
        # _update_rep_count were presumably tuned on these values — confirm
        # before switching to pixel coordinates.
        angle_knee = calculate_angle(hip, knee, ankle)
        # Torso lean: angle at the hip between the shoulder and a point with
        # the hip's x at y = 0 (the top edge of the frame).
        angle_hip = calculate_angle(shoulder, hip, [hip[0], 0])

        # Angle readouts. The panel is at fixed pixels while the text position
        # scales with frame size, so they only line up at some resolutions.
        _draw_panel(image, 110, 225)
        text_y = int(h - h * 0.55)
        _put_text(image, str(int(angle_hip)), (0, text_y), 1.7)
        _put_text(image, str(int(angle_knee)), (int(w * 0.265625), text_y), 1.7)

        _update_rep_count(angle_knee, angle_hip)

    # Counter panel: correct-rep count on the left, current stage on the right.
    _draw_panel(image, 0, 103)
    _put_text(image, 'Left', (10, 22), 0.7)
    _put_text(image, str(correct), (10, 70), 1.7)
    _put_text(image, 'STAGE', (180, 22), 0.7)
    _put_text(image, stage, (147, 70), 1.7)

    mp_drawing.draw_landmarks(
        image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
        mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
        mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2)
    )

    _publish_latest(detections)
    return av.VideoFrame.from_ndarray(image, format="bgr24")

# WebRTC connection settings: ICE servers from the sample helper, with the
# transport policy restricted to relay (TURN) candidates.
_rtc_configuration = {
    "iceServers": get_ice_servers(),
    "iceTransportPolicy": "relay",
}

# Start the send/receive video component; every incoming frame is passed
# through video_frame_callback and the returned frame is shown to the user.
webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=_rtc_configuration,
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)