Pratyush101 committed on
Commit b2ed51e · verified · 1 Parent(s): 495fa7e

Update app.py

Files changed (1)
  1. app.py +80 -136
app.py CHANGED
@@ -3,13 +3,11 @@ import queue
 from pathlib import Path
 from typing import List, NamedTuple
 import mediapipe as mp
-
 import av
 import cv2
 import numpy as np
 import streamlit as st
 from streamlit_webrtc import WebRtcMode, webrtc_streamer
-
 from sample_utils.download import download_file
 from sample_utils.turn import get_ice_servers
 
@@ -17,8 +15,34 @@ from sample_utils.turn import get_ice_servers
 logger = logging.getLogger(__name__)
 
 # Streamlit setup
-st.title("AI Squat Detection using WebRTC")
-st.info("Use your webcam for real-time squat detection.")
+st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️")
+st.markdown(
+    """<style>
+    .status-box {
+        background: #f7f7f7;
+        padding: 15px;
+        border-radius: 8px;
+        box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
+        margin-bottom: 20px;
+        font-size: 18px;
+    }
+    .title {
+        color: #2E86C1;
+        font-size: 32px;
+        font-weight: bold;
+        text-align: center;
+        margin-bottom: 10px;
+    }
+    .info {
+        text-align: center;
+        font-size: 18px;
+        margin-bottom: 20px;
+        color: #333;
+    }
+    </style>""", unsafe_allow_html=True)
+
+st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
+st.markdown('<div class="info">Use your webcam for real-time squat detection.</div>', unsafe_allow_html=True)
 
 # Initialize MediaPipe components
 mp_pose = mp.solutions.pose
@@ -30,7 +54,6 @@ class Detection(NamedTuple):
     score: float
     box: np.ndarray
 
-
 # Angle calculation function
 def calculate_angle(a, b, c):
     a = np.array(a)
@@ -42,152 +65,73 @@ def calculate_angle(a, b, c):
         angle = 360 - angle
     return angle
 
-counterL = 0  # Counter checks for number of curls
-correct = 0
-incorrect = 0
-
 # Detection Queue
 result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
 
+# Initialize MediaPipe Pose once
+pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
+
 def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
-    global counterL, correct, incorrect, stage  # The change made
-    # Initialize stage if not defined
+    global counterL, correct, incorrect, stage
     if 'stage' not in globals():
         stage = 'up'
         correct = 0
         incorrect = 0
 
     image = frame.to_ndarray(format="bgr24")
     h, w = image.shape[:2]
     image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
-    with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
-        results = pose.process(image_rgb)
-        landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []
-
-        # Corrected detection logic
-        detections = [
-            Detection(
-                class_id=0,  # Assuming a generic class_id for pose detections
-                label="Pose",
-                score=0.7,  # Full confidence as pose landmarks were detected
-                box=np.array([0, 0, image.shape[1], image.shape[0]])  # Full image as bounding box
-            )
-        ] if landmarks else []
-
-        if landmarks:
-            hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,
-                    landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
-            kneeL = [landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x,
-                     landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y]
-            ankleL = [landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x,
-                      landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
-            shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x,
-                         landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]
-            footIndexL = [landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].x,
-                          landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].y]
-
-            # Calculate angles
-            angleKneeL = calculate_angle(hipL, kneeL, ankleL)
-            angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
-            angleAnkleL = calculate_angle(footIndexL, ankleL, kneeL)
-
-            # Visualize of left leg
-            cv2.putText(image, str(angleHipL), tuple(np.multiply(angleHipL, [640, 480]).astype(int)), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
-
-            # Update squat stage and count correct reps
-            if angleKneeL > 110 and stage == 'down':
-                stage = 'up'
-                if 18 < angleHipL < 40:
-                    correct += 1
-
-            if 80 < angleKneeL < 110 and stage == 'up':
-                stage = 'down'
-
-            # Display feedback messages
-            if 10 < angleHipL < 18:
-                cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
-                cv2.putText(image, f"Bend Forward", (320, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (150, 120, 255), 1, cv2.LINE_AA)
-
-            if angleHipL > 45:
-                cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
-                cv2.putText(image, f"Bend Backward", (320, 200), cv2.FONT_HERSHEY_SIMPLEX, 1, (80, 120, 255), 1, cv2.LINE_AA)
-
-            # # # stage 2
-            # # # Incorrect movements
-            # # 3. Knees not low enough
-            # if 110 < angleKneeL < 130:
-            #     cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
-            #     cv2.putText(image, f"Lower Your Hips", (230, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)
-
-            # # 3. Knees not low enough and not completed the squat
-            # if angleKneeL > 130 and stage == 'mid':
-            #     cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
-            #     cv2.putText(image, f"Lower Your Hips", (230, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)
-            #     incorrect += 1
-            #     stage = 'up'
-
-            # # 4. Squat too deep
-            # if angleKneeL < 80 and stage == 'mid':
-            #     cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
-            #     cv2.putText(image, f"Squat too deep", (230, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)
-            #     incorrect += 1
-            #     stage = 'up'
-
-            # # stage 4
-            # if (80 < angleKneeL < 110) and stage == 'mid':
-            #     if (18 < angleHipL < 40):  # Valid "down" position
-            #         correct += 1
-            #         stage = 'up'
-            # if (angleKneeL > 110):
-            #     stage = 'mid'
-
-            # cv2.putText(image, f"Correct:{correct}",
-            #             (400, 120), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1, cv2.LINE_AA)
-            # cv2.putText(image, f"Incorrect:{incorrect}",
-            #             (400, 150), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 0), 1, cv2.LINE_AA)
-
-            # Render Counter to our camera screen
-            # Setup Status box
-            cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)
-
-            # REP data
-            cv2.putText(image, 'Left', (10, 12),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
-            cv2.putText(image, str(correct),
-                        (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 2, cv2.LINE_AA)
-
-            # Stage data for left leg
-            cv2.putText(image, 'STAGE', (230, 12),
-                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 0), 1, cv2.LINE_AA)
-            cv2.putText(image, stage,
-                        (230, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 1, cv2.LINE_AA)
-
-        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS, mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2), mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2))
-
-        result_queue.put(detections)
-        return av.VideoFrame.from_ndarray(image, format="bgr24")
+    results = pose.process(image_rgb)
+    landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []
+
+    detections = [
+        Detection(
+            class_id=0, label="Pose", score=0.5, box=np.array([0, 0, w, h])
+        )
+    ] if landmarks else []
+
+    if landmarks:
+        hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x, landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
+        kneeL = [landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].x, landmarks[mp_pose.PoseLandmark.LEFT_KNEE.value].y]
+        ankleL = [landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].x, landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
+        shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x, landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]
+
+        angleKneeL = calculate_angle(hipL, kneeL, ankleL)
+        angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
+
+        rel_point1 = (int(w * 0), int(h - h * 0.55))
+        rel_point2 = (int(w * 0.265625), int(h - h * 0.55))
+
+        cv2.rectangle(image, (0, 110), (280, 225), (127, 248, 236), -1)
+        cv2.rectangle(image, (0, 113), (277, 222), (12, 85, 61), -1)
+        cv2.putText(image, str(int(angleHipL)), rel_point1, cv2.FONT_HERSHEY_SIMPLEX, 1.7, (255, 255, 255), 2, cv2.LINE_AA)
+        cv2.putText(image, str(int(angleKneeL)), rel_point2, cv2.FONT_HERSHEY_SIMPLEX, 1.7, (255, 255, 255), 2, cv2.LINE_AA)
+
+        if angleKneeL > 110 and stage == 'down':
+            stage = 'up'
+            if 18 < angleHipL < 40:
+                correct += 1
+
+        if 80 < angleKneeL < 110 and stage == 'up':
+            stage = 'down'
+
+        cv2.rectangle(image, (0, 0), (280, 103), (127, 248, 236), -1)
+        cv2.rectangle(image, (0, 3), (277, 100), (12, 85, 61), -1)
+
+        cv2.putText(image, 'Left', (10, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
+        cv2.putText(image, str(correct), (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.7, (255, 255, 255), 2, cv2.LINE_AA)
+        cv2.putText(image, 'STAGE', (180, 22), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2, cv2.LINE_AA)
+        cv2.putText(image, stage, (147, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.7, (255, 255, 255), 2, cv2.LINE_AA)
+
+        mp_drawing.draw_landmarks(
+            image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
+            mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),
+            mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2)
+        )
+
+    result_queue.put(detections)
+    return av.VideoFrame.from_ndarray(image, format="bgr24")
 
-# WebRTC streamer configuration
 
 webrtc_streamer(
     key="squat-detection",
@@ -196,4 +140,4 @@ webrtc_streamer(
     media_stream_constraints={"video": True, "audio": False},
     video_frame_callback=video_frame_callback,
     async_processing=True,
-)
+)
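Note: the diff shows only fragments of `calculate_angle` as context (`a = np.array(a)` and the `angle = 360 - angle` wrap-around). For readers following along, here is a minimal sketch of the usual three-point angle formulation those fragments are consistent with; this is an assumption, not necessarily the exact body in app.py:

```python
import numpy as np

def calculate_angle(a, b, c):
    """Interior angle at vertex b, in degrees, for 2D points a-b-c."""
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = abs(radians * 180.0 / np.pi)
    if angle > 180.0:
        angle = 360 - angle  # matches the wrap-around fragment visible in the diff
    return angle
```

With MediaPipe's normalized landmarks, `calculate_angle(hipL, kneeL, ankleL)` reads near 180° for a straight leg and drops as the knee bends.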
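The counting logic both versions share is a two-state machine with hysteresis on the left knee angle: 80-110° flips the stage to 'down', above 110° flips it back to 'up', and a rep counts as correct only when the hip angle sits in the 18-40° band at the bottom-to-top transition. Pulled out as a pure function (the function name is mine; the thresholds are the commit's), it can be unit-tested without a webcam:

```python
def update_squat_stage(stage: str, correct: int, angle_knee: float, angle_hip: float):
    """Advance the squat state machine by one frame; returns (stage, correct)."""
    if angle_knee > 110 and stage == 'down':
        stage = 'up'
        if 18 < angle_hip < 40:  # torso lean within the accepted range
            correct += 1
    if 80 < angle_knee < 110 and stage == 'up':
        stage = 'down'
    return stage, correct

# Example: stand (170°) -> squat (95°) -> stand (170°) counts exactly one rep.
stage, correct = 'up', 0
for knee, hip in [(170, 10), (95, 30), (170, 30)]:
    stage, correct = update_squat_stage(stage, correct, knee, hip)
assert correct == 1
```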
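One caveat the commit keeps rather than fixes: state lives in module globals guarded by `if 'stage' not in globals()`, so every connected stream shares one counter, and `counterL` is declared `global` but never assigned after this change. A sketch of a per-stream alternative using a callback factory (`make_squat_callback` is a hypothetical name, not in the commit):

```python
import av

def make_squat_callback():
    """Factory returning a callback that owns its own counters."""
    state = {"stage": "up", "correct": 0, "incorrect": 0}

    def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
        image = frame.to_ndarray(format="bgr24")
        # ... pose estimation and drawing as in app.py, mutating state ...
        return av.VideoFrame.from_ndarray(image, format="bgr24")

    return video_frame_callback
```

Each `webrtc_streamer(..., video_frame_callback=make_squat_callback())` call would then get isolated counters.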
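The main structural change is hoisting `mp_pose.Pose(...)` out of the per-frame `with` block, so the model graph is built once instead of on every frame. If the shared instance could ever be hit from two streams at once, serializing access is prudent, since MediaPipe solution objects are stateful across `process` calls; the lock below is my addition, not part of the commit:

```python
import threading

import mediapipe as mp

mp_pose = mp.solutions.pose
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
pose_lock = threading.Lock()

def process_frame(image_rgb):
    # Serialize access so concurrent streams cannot interleave graph state.
    with pose_lock:
        return pose.process(image_rgb)
```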
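`result_queue` is filled on every frame, but nothing in this diff consumes it. A sketch of the usual streamlit_webrtc pattern for draining it from the script body (assuming it runs in the same app.py, with `ctx` capturing the return value of the `webrtc_streamer` call):

```python
import queue

ctx = webrtc_streamer(
    key="squat-detection",
    media_stream_constraints={"video": True, "audio": False},
    video_frame_callback=video_frame_callback,
    async_processing=True,
)

if ctx.state.playing:
    placeholder = st.empty()
    while True:
        try:
            detections = result_queue.get(timeout=1.0)
        except queue.Empty:
            break  # no fresh result within a second; yield back to Streamlit
        placeholder.table([d._asdict() for d in detections])
```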