import logging
import queue
from typing import List, NamedTuple

import av
import cv2
import mediapipe as mp
import numpy as np
import streamlit as st
from streamlit_webrtc import WebRtcMode, webrtc_streamer

from sample_utils.turn import get_ice_servers

# Logging setup
logger = logging.getLogger(__name__)

# Streamlit setup
st.title("AI Squat Detection using WebRTC")
st.info("Use your webcam for real-time squat detection.")

# Initialize MediaPipe components
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

class Detection(NamedTuple):
    class_id: int
    label: str
    score: float
    box: np.ndarray


# Angle calculation function
def calculate_angle(a, b, c):
    """Return the angle at vertex b, in degrees (0-180), formed by points a-b-c."""
    a = np.array(a)
    b = np.array(b)
    c = np.array(c)
    radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    if angle > 180.0:
        angle = 360 - angle
    return angle
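    # Example: calculate_angle((0, 1), (0, 0), (1, 0)) -> 90.0 (right angle at b)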

# Rep counters and squat stage (updated inside the frame callback)
counterL = 0   # left-leg rep counter
correct = 0    # squats completed with correct form
incorrect = 0  # squats flagged as incorrect
stage = 'up'   # current squat phase: 'up' or 'down'

# Detection Queue
result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
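# Detections are put on this queue from the frame callback (which runs in a
# separate worker thread); the main Streamlit script can read from it if the
# results need to be displayed.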

def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
    # Counters and squat stage are module-level so they persist across frames.
    global counterL, correct, incorrect, stage

    image = frame.to_ndarray(format="bgr24")
    h, w = image.shape[:2]
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
        results = pose.process(image_rgb)
        landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []

        # Report a single full-frame "Pose" detection when landmarks are found
        detections = [
            Detection(
                class_id=0,   # generic class id for pose detections
                label="Pose",
                score=0.7,    # fixed placeholder confidence for a detected pose
                box=np.array([0, 0, w, h]),  # full image as the bounding box
            )
        ] if landmarks else []

        if landmarks:
            hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,
                    landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
            kneeL = [landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x,
                     landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y]
            ankleL = [landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x,
                      landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
            shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x,
                         landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]
            footIndexL = [landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].x,
                          landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].y]

            # Calculate angles
            angleKneeL = calculate_angle(hipL, kneeL, ankleL)
            angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
            angleAnkleL = calculate_angle(footIndexL, ankleL, kneeL)
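            # angleKneeL: hip-knee-ankle (knee flexion); angleHipL: torso lean,
            # measured between the shoulder and a vertical line through the hip;
            # angleAnkleL: foot_index-ankle-knee (ankle flexion).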
            
            # Visualize the hip angle next to the left hip landmark
            cv2.putText(image, f"{angleHipL:.1f}",
                        tuple(np.multiply(hipL, [w, h]).astype(int)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)

            # Update squat stage and count correct reps
            if angleKneeL > 110 and stage == 'down':
                stage = 'up'
                if 18 < angleHipL < 40:
                    correct += 1

            if 80 < angleKneeL < 110 and stage == 'up':
                stage = 'down'
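
            # Rep logic summary: the squat is 'down' while the knee angle is in
            # 80-110 degrees; returning above 110 degrees counts a correct rep
            # only if the hip (torso-lean) angle is within 18-40 degrees at that
            # moment.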

            # Display form-feedback messages based on torso lean
            if 10 < angleHipL < 18:
                cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
                cv2.putText(image, "Bend Forward", (320, 200),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (150, 120, 255), 1, cv2.LINE_AA)

            if angleHipL > 45:
                cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
                cv2.putText(image, "Bend Backward", (320, 200),
                            cv2.FONT_HERSHEY_SIMPLEX, 1, (80, 120, 255), 1, cv2.LINE_AA)

            # # # stage 2

            # # # Incorrect movements

            # # 3. Knees not low enough
            # if 110 < angleKneeL < 130:
            #    cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
            #    cv2.putText(image,f"Lower Your Hips",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
       
        
            # # 3. Knees not low enough and not completed the squat 
            # if angleKneeL>130 and stage=='mid':
            #    cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
            #    cv2.putText(image,f"Lower Your Hips",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
            #    incorrect+=1
            #    stage='up'

            # # 4. Squat too deep
            # if angleKneeL < 80 and stage=='mid':
            #    cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
            #    cv2.putText(image,f"Squat too deep",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
            #    incorrect +=1
            #    stage='up'

            # # stage 4 
            # if (80 < angleKneeL < 110) and stage=='mid':
            #    if (18 < angleHipL < 40):  # Valid "down" position
            #       correct+=1
            #       stage='up'
            # if (angleKneeL>110):
            #     stage='mid'


            
        # cv2.putText(image,f"Correct:{correct}",
        #        (400,120),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,0),1,cv2.LINE_AA)
        # cv2.putText(image,f"Incorrect:{incorrect}",
        #        (400,150),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,0),1,cv2.LINE_AA)

        # Render the status box onto the camera frame
        cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)

        # Rep data (left leg)
        cv2.putText(image, 'Left', (10, 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.putText(image, str(correct), (10, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2, cv2.LINE_AA)

        # Stage data (left leg)
        cv2.putText(image, 'STAGE', (230, 12),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
        cv2.putText(image, stage, (230, 60),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)

        # Draw the pose skeleton on the frame
        mp_drawing.draw_landmarks(
            image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
            mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
            mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2),
        )

    result_queue.put(detections)
    return av.VideoFrame.from_ndarray(image, format="bgr24")



# WebRTC streamer configuration

webrtc_streamer(
    key="squat-detection",
    mode=WebRtcMode.SENDRECV,
    rtc_configuration={"iceServers": get_ice_servers(), "iceTransportPolicy": "relay"},
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)
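
# Optional usage sketch (not part of the original app): the Detection results
# pushed to `result_queue` are never read above. If they should be shown in the
# page, the context returned by `webrtc_streamer` could be captured and polled,
# e.g. (assumes the call above is assigned to `ctx`):
#
#   if ctx.state.playing:
#       labels_placeholder = st.empty()
#       while True:
#           labels_placeholder.table(result_queue.get())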