File size: 5,426 Bytes
bba6cec
eac73a1
 
 
1c81ee4
eac73a1
df8ec21
 
 
eac73a1
 
 
df8ec21
bba6cec
 
 
 
b2ed51e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bba6cec
 
df8ec21
 
 
4a3054c
 
 
 
 
 
bba6cec
df8ec21
bba6cec
 
 
2dd8b34
df8ec21
 
 
 
 
506c444
eac73a1
506c444
b2ed51e
 
 
506c444
b2ed51e
495fa7e
 
 
 
b2ed51e
bba6cec
c3ddf22
 
eac73a1
bba6cec
1c15637
b2ed51e
 
506c444
b2ed51e
 
 
 
 
 
 
 
 
 
 
bba6cec
b2ed51e
 
1c15637
52700a9
 
2dd8b34
52700a9
 
 
3427eee
52700a9
 
 
b2ed51e
 
 
 
 
 
38c92af
b2ed51e
 
52700a9
 
b2ed51e
 
52700a9
 
 
b2ed51e
 
 
 
 
 
 
 
 
2dd8b34
df8ec21
bba6cec
eac73a1
 
a472ccb
eac73a1
 
b2ed51e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import mediapipe as mp
import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer
from sample_utils.download import download_file
from sample_utils.turn import get_ice_servers

# Logging setup
logger = logging.getLogger(__name__)

# Streamlit setup: page metadata must be configured before any other st.* call.
st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️")
# Inline CSS for the title/info banners (and a status-box class) rendered below.
st.markdown(
    """<style>
    .status-box {
        background: #f7f7f7;
        padding: 15px;
        border-radius: 8px;
        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
        margin-bottom: 20px;
        font-size: 18px;
    }
    .title {
        color: #2E86C1;
        font-size: 32px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 10px;
    }
    .info {
        text-align: center;
        font-size: 18px;
        margin-bottom: 20px;
        color: #333;
    }
    </style>""", unsafe_allow_html=True)

# Page header and usage hint, styled by the CSS injected above.
st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
st.markdown('<div class="info">Use your webcam for real-time squat detection.</div>', unsafe_allow_html=True)

# Initialize MediaPipe components
mp_pose = mp.solutions.pose  # pose solution: Pose class + PoseLandmark indices
mp_drawing = mp.solutions.drawing_utils  # helpers for drawing landmarks on frames

class Detection(NamedTuple):
    """One detection result pushed onto ``result_queue`` for each frame."""
    class_id: int  # numeric class id (this file always uses 0)
    label: str  # human-readable class name (this file always uses "Pose")
    score: float  # confidence score (this file always uses the fixed value 0.5)
    box: np.ndarray  # bounding box [x0, y0, x1, y1]; here always the full frame

def calculate_angle(a, b, c):
    """Return the angle at vertex *b* formed by points *a*-*b*-*c*.

    Each point is an (x, y) pair. The result is in degrees, folded into
    the range [0, 180].
    """
    pa, pb, pc = (np.array(p) for p in (a, b, c))
    # Signed angle between the rays b->c and b->a.
    theta = np.arctan2(pc[1] - pb[1], pc[0] - pb[0]) - np.arctan2(pa[1] - pb[1], pa[0] - pb[0])
    degrees = np.abs(np.degrees(theta))
    # Fold reflex angles back into [0, 180].
    return 360 - degrees if degrees > 180.0 else degrees

# Detection Queue
# The video callback pushes each frame's detections here; a consumer on the
# Streamlit side (not shown in this file) can drain it. NOTE(review): nothing
# in this file reads it, so it grows while the stream runs — confirm a consumer
# exists.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()

# Initialize MediaPipe Pose once
# A single estimator instance shared by every invocation of the frame callback.
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Per-frame WebRTC callback: estimate the pose, update squat counters,
    draw the angle/counter overlay, and return the annotated frame.

    Counter state (``stage``, ``correct``, ``incorrect``) lives in module
    globals so it persists across callback invocations.

    Fixes vs. previous version: removed the unused (and never-initialized)
    ``counterL`` global; ``incorrect`` was tracked but never incremented — a
    completed rep with poor hip form now counts toward it.
    """
    global correct, incorrect, stage
    # Lazy first-frame initialisation of the shared counter state.
    if 'stage' not in globals():
        stage = 'up'   # current squat phase: 'up' (standing) or 'down'
        correct = 0    # reps completed with the hip angle in the good range
        incorrect = 0  # reps completed outside that range

    image = frame.to_ndarray(format="bgr24")
    # Mirror the image horizontally so the preview behaves like a mirror.
    image = cv2.flip(image, 1)  # Flip code 1 means horizontal flip
    h, w = image.shape[:2]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # MediaPipe expects RGB

    results = pose.process(image_rgb)
    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

    # Report a single full-frame "Pose" detection whenever a body is found.
    detections = [
        Detection(
            class_id=0, label="Pose", score=0.5, box=np.array([0, 0, w, h])
        )
    ] if landmarks else []

    if landmarks:
        # Normalised (0..1) x/y coordinates of the left-side joints.
        hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x, landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
        kneeL = [landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x, landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y]
        ankleL = [landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x, landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
        shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x, landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]

        # Knee flexion (hip-knee-ankle) and hip lean measured against a
        # vertical reference point directly above the hip (y=0 is image top).
        angleKneeL = calculate_angle(hipL, kneeL, ankleL)
        angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])

        # Text anchors placed relative to the frame size.
        rel_point1 = (int(w * 0), int(h - h * 0.65))
        rel_point2 = (int(w * 0.17), int(h - h * 0.65))

        # Angle read-out panel (border rectangle + darker fill).
        cv2.rectangle(image, (0, 90), (200, 175), (127, 248, 236), -1)
        cv2.rectangle(image, (0, 93), (197, 173), (12, 85, 61), -1)
        cv2.putText(image, 'HipL', (10, 122), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)

        cv2.putText(image, 'KneeL', (125, 122), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(image, str(int(angleHipL)), rel_point1, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2, cv2.LINE_AA)
        cv2.putText(image, str(int(angleKneeL)), rel_point2, cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2, cv2.LINE_AA)

        # A rep completes when the knee re-extends past 110° out of 'down'.
        if angleKneeL > 110 and stage == 'down':
            stage = 'up'
            if 18 < angleHipL < 40:
                correct += 1
            else:
                # Completed rep with hip angle outside the good range.
                incorrect += 1

        # Entering the bottom of the squat.
        if angleKneeL < 110 and stage == 'up':
            stage = 'down'

    # Counter panel (border rectangle + darker fill) in the top-left corner.
    cv2.rectangle(image, (0, 0), (200, 83), (127, 248, 236), -1)
    cv2.rectangle(image, (0, 3), (197, 80), (12, 85, 61), -1)

    cv2.putText(image, 'Left', (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(image, str(correct), (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(image, 'STAGE', (110, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
    cv2.putText(image, stage, (77, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.5, (255, 255, 255), 2, cv2.LINE_AA)

    # NOTE(review): draw_landmarks is expected to be a no-op when
    # pose_landmarks is None — confirm against the installed mediapipe version.
    mp_drawing.draw_landmarks(
        image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
        mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
        mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2)
    )

    result_queue.put(detections)
    return av.VideoFrame.from_ndarray(image, format="bgr24")

# Start the WebRTC stream; each received video frame runs through
# video_frame_callback before being sent back to the browser.
webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    # ICE servers come from the sample helper; relay-only transport policy.
    rtc_configuration={"iceServers": get_ice_servers(), "iceTransportPolicy": "relay"},
    media_stream_constraints={"video": True, "audio": False},  # video only
    video_frame_callback=video_frame_callback,
    async_processing=True,  # process frames asynchronously
)