File size: 4,746 Bytes
bba6cec
eac73a1
 
 
1c81ee4
eac73a1
 
df8ec21
 
 
eac73a1
 
 
 
df8ec21
bba6cec
 
 
 
768810e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bba6cec
 
df8ec21
 
 
4a3054c
 
 
 
 
 
bba6cec
df8ec21
bba6cec
 
 
2dd8b34
df8ec21
 
 
 
 
2f4a581
506c444
eac73a1
506c444
 
768810e
495fa7e
 
 
 
768810e
bba6cec
eac73a1
bba6cec
1c15637
bba6cec
 
2dd8b34
506c444
17fe806
 
 
 
 
 
 
 
 
 
bba6cec
c19a8a5
df8ec21
c19a8a5
2dd8b34
c19a8a5
2dd8b34
c19a8a5
2dd8b34
bba6cec
 
4d85c5b
 
495fa7e
 
 
 
 
 
 
 
 
768810e
 
 
 
1c15637
17fe806
 
 
 
768810e
2dd8b34
df8ec21
bba6cec
eac73a1
 
a472ccb
eac73a1
 
768810e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import logging
import queue
from pathlib import Path
from typing import List, NamedTuple
import mediapipe as mp

import av
import cv2
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from sample_utils.download import download_file
from sample_utils.turn import get_ice_servers

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# --- Streamlit page setup ----------------------------------------------------
# Page configuration is issued first, ahead of every other st.* call below.
st.set_page_config(
    page_title="AI Squat Detection",
    page_icon="🏋️",
    layout="wide",
)

# Inline stylesheet for the .status-box, .title and .info CSS classes.
_PAGE_STYLE = """<style>
    .status-box {
        background: #f7f7f7;
        padding: 15px;
        border-radius: 8px;
        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
        margin-bottom: 20px;
        font-size: 18px;
    }
    .title {
        color: #2E86C1;
        font-size: 32px;
        font-weight: bold;
        text-align: center;
        margin-bottom: 10px;
    }
    .info {
        text-align: center;
        font-size: 18px;
        margin-bottom: 20px;
        color: #333;
    }
    </style>"""
st.markdown(_PAGE_STYLE, unsafe_allow_html=True)

# Page heading followed by a one-line usage hint.
st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
st.markdown(
    '<div class="info">Use your webcam for real-time squat detection.</div>',
    unsafe_allow_html=True,
)

# MediaPipe pose solution and its landmark-drawing utilities.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

# Immutable record for one detection pushed onto the frame-result queue.
# Fields, in order: numeric class id, human-readable label, confidence score,
# and the bounding box as a numpy array.
Detection = NamedTuple(
    "Detection",
    [
        ("class_id", int),
        ("label", str),
        ("score", float),
        ("box", np.ndarray),
    ],
)

def calculate_angle(a, b, c):
    """Return the angle ABC in degrees, measured at vertex ``b``.

    ``a``, ``b`` and ``c`` are points given as anything ``np.array`` accepts;
    only their first two components (x, y) are used. The raw heading
    difference is folded so the result always lies in [0, 180].
    """
    pa, pb, pc = (np.array(point) for point in (a, b, c))
    heading_to_c = np.arctan2(pc[1] - pb[1], pc[0] - pb[0])
    heading_to_a = np.arctan2(pa[1] - pb[1], pa[0] - pb[0])
    degrees = np.abs((heading_to_c - heading_to_a) * 180.0 / np.pi)
    # The heading difference can exceed 180 degrees; report the smaller angle.
    return 360 - degrees if degrees > 180.0 else degrees


# Detection Queue: the frame callback puts one list of detections per frame here.
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()


def _landmark_xy(landmarks, part) -> List[float]:
    """Return ``[x, y]`` of one pose landmark.

    ``landmarks`` is the ``results.pose_landmarks.landmark`` sequence from
    MediaPipe Pose; ``part`` is a ``mp_pose.PoseLandmark`` member whose
    ``.value`` indexes into it.
    """
    point = landmarks[part.value]
    return [point.x, point.y]


def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    """Run pose estimation on one webcam frame, count squats, annotate the frame.

    Rep-counting state lives in module globals so it persists across frames:
    ``stage`` ("up" / "down") and the ``correct`` / ``incorrect`` rep counters.
    A rep starts when the left knee angle enters the open 80-110 degree band
    while "up". It completes when the knee angle rises above 110 degrees while
    "down". A completed rep counts as correct when the shoulder-hip angle,
    measured against a vertical line through the hip, lies strictly between
    18 and 40 degrees; otherwise it counts as incorrect.

    Args:
        frame: incoming video frame.

    Returns:
        A new BGR frame with the rep-counter overlay and the pose skeleton drawn.

    Side effects:
        Puts one ``List[Detection]`` on ``result_queue`` per call (empty when
        no pose landmarks were found) and mutates the globals listed above.
    """
    global correct, incorrect, stage
    if "stage" not in globals():
        # First frame ever: initialise the persistent rep-counting state.
        stage = "up"
        correct = 0
        incorrect = 0

    image = frame.to_ndarray(format="bgr24")
    h, w = image.shape[:2]
    # The frame is decoded as BGR; it is converted to RGB before pose.process.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # NOTE(review): a new Pose graph is constructed for every frame, which is
    # costly and discards any cross-frame tracking state. Consider reusing one
    # instance per session (TODO confirm MediaPipe Pose thread-safety first).
    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
        results = pose.process(image_rgb)
        landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

        # A single whole-image "Pose" detection when landmarks exist, else none.
        detections = (
            [
                Detection(
                    class_id=0,  # single generic class for pose detections
                    label="Pose",
                    score=0.5,  # fixed placeholder score, not read from MediaPipe
                    box=np.array([0, 0, w, h]),  # full image as bounding box
                )
            ]
            if landmarks
            else []
        )

        if landmarks:
            hipL = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_HIP)
            kneeL = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_KNEE)
            ankleL = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_ANKLE)
            shoulderL = _landmark_xy(landmarks, mp_pose.PoseLandmark.LEFT_SHOULDER)

            # Knee flexion, and torso angle at the hip against a reference point
            # with the hip's x and y = 0 (i.e. a vertical line through the hip).
            angleKneeL = calculate_angle(hipL, kneeL, ankleL)
            angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])

            # Rising back above 110 degrees completes a rep; grade it by hip angle.
            if angleKneeL > 110 and stage == "down":
                stage = "up"
                if 18 < angleHipL < 40:
                    correct += 1
                else:
                    incorrect += 1

            # Bending into the 80-110 degree band starts a rep.
            if 80 < angleKneeL < 110 and stage == "up":
                stage = "down"

        # Overlay feedback messages
        overlay_text = f"Correct: {correct} | Stage: {stage}"
        cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)
        cv2.putText(
            image, overlay_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1,
            (255, 255, 255), 2, cv2.LINE_AA,
        )

        mp_drawing.draw_landmarks(
            image,
            results.pose_landmarks,
            mp_pose.POSE_CONNECTIONS,
            mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
            mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2),
        )

    result_queue.put(detections)

    return av.VideoFrame.from_ndarray(image, format="bgr24")

# WebRTC connection settings: ICE servers from get_ice_servers(), with the
# transport policy restricted to relay candidates.
_rtc_configuration = {
    "iceServers": get_ice_servers(),
    "iceTransportPolicy": "relay",
}

# Start the send/receive video component; incoming video frames go through
# video_frame_callback, processed asynchronously. Audio is not requested.
webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration=_rtc_configuration,
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)