Pratyush101 commited on
Commit
768810e
·
verified ·
1 Parent(s): c461f72

Update app.py

Browse files

Improve the UI of the AI squat detection app; review the changes or revert to the original if needed.

Files changed (1) hide show
  1. app.py +36 -110
app.py CHANGED
@@ -17,8 +17,34 @@ from sample_utils.turn import get_ice_servers
17
  logger = logging.getLogger(__name__)
18
 
19
  # Streamlit setup
20
- st.title("AI Squat Detection using WebRTC")
21
- st.info("Use your webcam for real-time squat detection.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
  # Initialize MediaPipe components
24
  mp_pose = mp.solutions.pose
@@ -30,7 +56,6 @@ class Detection(NamedTuple):
30
  score: float
31
  box: np.ndarray
32
 
33
-
34
  # Angle calculation function
35
  def calculate_angle(a, b, c):
36
  a = np.array(a)
@@ -42,21 +67,17 @@ def calculate_angle(a, b, c):
42
  angle = 360 - angle
43
  return angle
44
 
45
- counterL=0#Counter checks for number of curls
46
- correct=0
47
- incorrect=0
48
 
49
  # Detection Queue
50
  result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
51
 
52
  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
53
- global counterL, correct, incorrect, stage #The change made
54
- # Initialize stage if not defined
55
  if 'stage' not in globals():
56
  stage = 'up'
57
  correct = 0
58
  incorrect = 0
59
-
60
  image = frame.to_ndarray(format="bgr24")
61
  h, w = image.shape[:2]
62
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
@@ -65,16 +86,6 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
65
  results = pose.process(image_rgb)
66
  landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []
67
 
68
- # Corrected detection logic
69
- detections = [
70
- Detection(
71
- class_id=0, # Assuming a generic class_id for pose detections
72
- label="Pose",
73
- score=0.7, # Full confidence as pose landmarks were detected
74
- box=np.array([0, 0, image.shape[1], image.shape[0]]) # Full image as bounding box
75
- )
76
- ] if landmarks else []
77
-
78
  if landmarks:
79
  hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,
80
  landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
@@ -84,21 +95,11 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
84
  landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
85
  shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x,
86
  landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]
87
- footIndexL = [landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].x,
88
- landmarks[mp_pose.PoseLandmark.LEFT_FOOT_INDEX.value].y]
89
 
90
  # Calculate angles
91
  angleKneeL = calculate_angle(hipL, kneeL, ankleL)
92
  angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
93
- angleAnkleL = calculate_angle(footIndexL, ankleL, kneeL)
94
-
95
- #Visualize of left leg
96
- cv2.putText(image, str(angleHipL),tuple(np.multiply(angleHipL, [640, 480]).astype(int)),cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)
97
-
98
-
99
 
100
-
101
- # Update squat stage and count correct reps
102
  if angleKneeL > 110 and stage == 'down':
103
  stage = 'up'
104
  if 18 < angleHipL < 40:
@@ -107,87 +108,12 @@ def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
107
  if 80 < angleKneeL < 110 and stage == 'up':
108
  stage = 'down'
109
 
110
- # Display feedback messages
111
- if 10 < angleHipL < 18:
112
- cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
113
- cv2.putText(image,f"Bend Forward",(320,200),cv2.FONT_HERSHEY_SIMPLEX,1,(150,120,255),1,cv2.LINE_AA)
114
-
115
- if angleHipL > 45:
116
- cv2.rectangle(image, (310, 180), (450, 220), (0, 0, 0), -1)
117
- cv2.putText(image,f"Bend Backward",(320,200),cv2.FONT_HERSHEY_SIMPLEX,1,(80,120,255),1,cv2.LINE_AA)
118
-
119
-
120
-
121
-
122
-
123
- # # # stage 2
124
-
125
- # # # Incorrect movements
126
-
127
- # # 3. Knees not low enough
128
- # if 110 < angleKneeL < 130:
129
- # cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
130
- # cv2.putText(image,f"Lower Your Hips",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
131
-
132
-
133
- # # 3. Knees not low enough and not completed the squat
134
- # if angleKneeL>130 and stage=='mid':
135
- # cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
136
- # cv2.putText(image,f"Lower Your Hips",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
137
- # incorrect+=1
138
- # stage='up'
139
-
140
- # # 4. Squat too deep
141
- # if angleKneeL < 80 and stage=='mid':
142
- # cv2.rectangle(image, (220, 40), (450, 80), (0, 0, 0), -1)
143
- # cv2.putText(image,f"Squat too deep",(230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
144
- # incorrect +=1
145
- # stage='up'
146
-
147
- # # stage 4
148
- # if (80 < angleKneeL < 110) and stage=='mid':
149
- # if (18 < angleHipL < 40): # Valid "down" position
150
- # correct+=1
151
- # stage='up'
152
- # if (angleKneeL>110):
153
- # stage='mid'
154
-
155
-
156
-
157
- # cv2.putText(image,f"Correct:{correct}",
158
- # (400,120),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,0),1,cv2.LINE_AA)
159
- # cv2.putText(image,f"Incorrect:{incorrect}",
160
- # (400,150),cv2.FONT_HERSHEY_SIMPLEX,1,(0,0,0),1,cv2.LINE_AA)
161
-
162
- #Render Counter to our camera screen
163
- #Setup Status box
164
- cv2.rectangle(image,(0,0),(500,80),(245,117,16),-1)
165
-
166
- #REP data
167
-
168
- cv2.putText(image,'Left',(10,12),
169
- cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,0),1,cv2.LINE_AA)
170
-
171
- cv2.putText(image,str(correct),
172
- (10,60),cv2.FONT_HERSHEY_SIMPLEX,2,(255,255,255),2,cv2.LINE_AA)
173
-
174
- #Stage data for left leg
175
-
176
- cv2.putText(image,'STAGE',(230,12),
177
- cv2.FONT_HERSHEY_SIMPLEX,0.5,(0,0,0),1,cv2.LINE_AA)
178
-
179
- cv2.putText(image,stage,
180
- (230,60),cv2.FONT_HERSHEY_SIMPLEX,1,(255,255,255),1,cv2.LINE_AA)
181
-
182
-
183
- mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,mp_drawing.DrawingSpec(color=(255, 175, 0), thickness=2, circle_radius=2),mp_drawing.DrawingSpec(color=(0, 255, 200), thickness=2, circle_radius=2))
184
-
185
- result_queue.put(detections)
186
- return av.VideoFrame.from_ndarray(image, format="bgr24")
187
-
188
 
189
-
190
- # WebRTC streamer configuration
191
 
192
  webrtc_streamer(
193
  key="squat-detection",
@@ -196,4 +122,4 @@ webrtc_streamer(
196
  media_stream_constraints={"video": True, "audio": False},
197
  video_frame_callback=video_frame_callback,
198
  async_processing=True,
199
- )
 
17
  logger = logging.getLogger(__name__)
18
 
19
  # Streamlit setup
20
+ st.set_page_config(page_title="AI Squat Detection", page_icon="🏋️", layout="wide")
21
+ st.markdown(
22
+ """<style>
23
+ .status-box {
24
+ background: #f7f7f7;
25
+ padding: 15px;
26
+ border-radius: 8px;
27
+ box-shadow: 2px 2px 5px rgba(0,0,0,0.1);
28
+ margin-bottom: 20px;
29
+ font-size: 18px;
30
+ }
31
+ .title {
32
+ color: #2E86C1;
33
+ font-size: 32px;
34
+ font-weight: bold;
35
+ text-align: center;
36
+ margin-bottom: 10px;
37
+ }
38
+ .info {
39
+ text-align: center;
40
+ font-size: 18px;
41
+ margin-bottom: 20px;
42
+ color: #333;
43
+ }
44
+ </style>""", unsafe_allow_html=True)
45
+
46
+ st.markdown('<div class="title">AI Squat Detection</div>', unsafe_allow_html=True)
47
+ st.markdown('<div class="info">Use your webcam for real-time squat detection.</div>', unsafe_allow_html=True)
48
 
49
  # Initialize MediaPipe components
50
  mp_pose = mp.solutions.pose
 
56
  score: float
57
  box: np.ndarray
58
 
 
59
  # Angle calculation function
60
  def calculate_angle(a, b, c):
61
  a = np.array(a)
 
67
  angle = 360 - angle
68
  return angle
69
 
 
 
 
70
 
71
  # Detection Queue
72
  result_queue: "queue.Queue[List[Detection]]" = queue.Queue()
73
 
74
  def video_frame_callback(frame: av.VideoFrame) -> av.VideoFrame:
75
+ global counterL, correct, incorrect, stage
 
76
  if 'stage' not in globals():
77
  stage = 'up'
78
  correct = 0
79
  incorrect = 0
80
+
81
  image = frame.to_ndarray(format="bgr24")
82
  h, w = image.shape[:2]
83
  image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 
86
  results = pose.process(image_rgb)
87
  landmarks = results.pose_landmarks.landmark if results.pose_landmarks else []
88
 
 
 
 
 
 
 
 
 
 
 
89
  if landmarks:
90
  hipL = [landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].x,
91
  landmarks[mp_pose.PoseLandmark.LEFT_HIP.value].y]
 
95
  landmarks[mp_pose.PoseLandmark.LEFT_ANKLE.value].y]
96
  shoulderL = [landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].x,
97
  landmarks[mp_pose.PoseLandmark.LEFT_SHOULDER.value].y]
 
 
98
 
99
  # Calculate angles
100
  angleKneeL = calculate_angle(hipL, kneeL, ankleL)
101
  angleHipL = calculate_angle(shoulderL, hipL, [hipL[0], 0])
 
 
 
 
 
 
102
 
 
 
103
  if angleKneeL > 110 and stage == 'down':
104
  stage = 'up'
105
  if 18 < angleHipL < 40:
 
108
  if 80 < angleKneeL < 110 and stage == 'up':
109
  stage = 'down'
110
 
111
+ # Overlay feedback messages
112
+ overlay_text = f"Correct: {correct} | Stage: {stage}"
113
+ cv2.rectangle(image, (0, 0), (500, 80), (245, 117, 16), -1)
114
+ cv2.putText(image, overlay_text, (10, 60), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
115
 
116
+ return av.VideoFrame.from_ndarray(image, format="bgr24")
 
117
 
118
  webrtc_streamer(
119
  key="squat-detection",
 
122
  media_stream_constraints={"video": True, "audio": False},
123
  video_frame_callback=video_frame_callback,
124
  async_processing=True,
125
+ )