randomshit11 committed (verified)
Commit 52e4943 · Parent(s): 3fa373b

Update app.py

Files changed (1):
  1. app.py +417 -237

app.py CHANGED
@@ -1,24 +1,10 @@
1
  import streamlit as st
2
  import cv2
3
-
4
- from tensorflow.keras.models import Model
5
- from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
6
- Bidirectional, Permute, multiply)
7
-
8
- import numpy as np
9
- import mediapipe as mp
10
- import math
11
- import streamlit as st
12
- import cv2
13
  import mediapipe as mp
14
  import math
15
- from VideoProcessor import VideoProcessor
16
- # from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
17
- import av
18
- from io import BytesIO
19
- import av
20
  from PIL import Image
21
- video_processor = VideoProcessor()
 
22
  ## Build and Load Model
23
  def attention_block(inputs, time_steps):
24
  """
@@ -61,17 +47,6 @@ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num
61
 
62
  return model
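For context, the build_model named in this hunk header is reproduced, commented out, further down in this diff; the following is a condensed sketch of that builder (the weights path is taken from the commented-out copy), not an authoritative restatement of the training setup.

from tensorflow.keras.models import Model
from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
                                     Bidirectional, Permute, multiply)

def attention_block(inputs, time_steps):
    # Luong-style multiplicative attention over the Bi-LSTM time steps.
    a = Permute((2, 1))(inputs)
    a = Dense(time_steps, activation='softmax')(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    return multiply([inputs, a_probs], name='attention_mul')

def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
    # Bi-LSTM -> attention -> dense classifier over 30-frame keypoint sequences.
    inputs = Input(shape=(sequence_length, num_input_values))
    lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
    x = Flatten()(attention_block(lstm_out, sequence_length))
    x = Dense(2 * HIDDEN_UNITS, activation='relu')(x)
    x = Dropout(0.5)(x)
    outputs = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=[inputs], outputs=outputs)
    model.load_weights("./models/LSTM_Attention.h5")  # path as in the commented-out copy
    return model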
63
 
64
- HIDDEN_UNITS = 256
65
- model = build_model(HIDDEN_UNITS)
66
- threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
67
- threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
68
- threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
69
-
70
- ## Mediapipe
71
- mp_pose = mp.solutions.pose # Pre-trained pose estimation model from Google Mediapipe
72
- mp_drawing = mp.solutions.drawing_utils # Supported Mediapipe visualization tools
73
- pose = mp_pose.Pose(min_detection_confidence=threshold1, min_tracking_confidence=threshold2) # mediapipe pose model
74
-
75
  ## Real Time Machine Learning and Computer Vision Processes
76
  class VideoProcessor:
77
  def __init__(self):
@@ -79,31 +54,28 @@ class VideoProcessor:
79
  self.actions = np.array(['curl', 'press', 'squat'])
80
  self.sequence_length = 30
81
  self.colors = [(245,117,16), (117,245,16), (16,117,245)]
82
- self.threshold = threshold3
83
 
84
  # Detection variables
85
  self.sequence = []
86
  self.current_action = ''
87
-
88
- # Rep counter logic variables
89
- self.curl_counter = 0
90
- self.press_counter = 0
91
- self.squat_counter = 0
92
- self.curl_stage = None
93
- self.press_stage = None
94
- self.squat_stage = None
95
 
96
- @st.cache()
97
  def draw_landmarks(self, image, results):
98
  """
99
  This function draws keypoints and landmarks detected by the human pose estimation model
100
 
101
  """
102
- mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
103
- mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
104
- mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
105
- )
106
- return
107
 
108
  @st.cache()
109
  def extract_keypoints(self, results):
@@ -116,7 +88,7 @@ class VideoProcessor:
116
  return pose
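The vector returned here feeds the 33*4-input model above; for reference, the full body of extract_keypoints (visible commented out later in this diff) flattens the landmarks as in this short sketch.

import numpy as np

def extract_keypoints(results):
    # Flatten 33 pose landmarks x (x, y, z, visibility) into a 132-value vector,
    # or return zeros when no person was detected.
    if results.pose_landmarks:
        return np.array([[lm.x, lm.y, lm.z, lm.visibility]
                         for lm in results.pose_landmarks.landmark]).flatten()
    return np.zeros(33 * 4)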
117
 
118
  @st.cache()
119
- def calculate_angle(self, a,b,c):
120
  """
121
  Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
122
 
@@ -134,18 +106,17 @@ class VideoProcessor:
134
  return angle
135
 
136
  @st.cache()
137
- def get_coordinates(self, landmarks, mp_pose, side, joint):
138
  """
139
  Retrieves x and y coordinates of a particular keypoint from the pose estimation model
140
 
141
  Args:
142
  landmarks: processed keypoints from the pose estimation model
143
- mp_pose: Mediapipe pose estimation model
144
  side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
145
  joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
146
 
147
  """
148
- coord = getattr(mp_pose.PoseLandmark,side.upper()+"_"+joint.upper())
149
  x_coord_val = landmarks[coord.value].x
150
  y_coord_val = landmarks[coord.value].y
151
  return [x_coord_val, y_coord_val]
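As a quick sanity check of the angle helper above (which, despite the "3D" wording in its docstring, operates on the 2D x/y pairs returned by get_coordinates), here is a standalone version with a worked example.

import numpy as np

def calculate_angle(a, b, c):
    # Angle at joint b formed by segments b->a and b->c, folded into [0, 180] degrees.
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    return 360.0 - angle if angle > 180.0 else angle

# Shoulder above the elbow, wrist out to the side -> roughly a right angle.
print(calculate_angle([0.5, 0.3], [0.5, 0.5], [0.7, 0.5]))  # 90.0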
@@ -161,252 +132,461 @@ class VideoProcessor:
161
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
162
  )
163
  return
 
164
  @st.cache()
165
- def process_video(self, video_file):
166
  """
167
- Processes each frame of the input video, performs pose estimation,
168
- and counts repetitions of each exercise.
169
-
170
- Args:
171
- video_file (BytesIO): Input video file.
172
-
173
- Returns:
174
- tuple: A tuple containing the processed video frames with annotations
175
- and the final count of repetitions for each exercise.
176
  """
177
  cap = cv2.VideoCapture(video_file)
178
- out_frames = []
179
- # Initialize repetition counters
180
- self.curl_counter = 0
181
- self.press_counter = 0
182
- self.squat_counter = 0
183
-
184
  while cap.isOpened():
185
  ret, frame = cap.read()
186
  if not ret:
187
  break
188
-
189
  # Convert frame to RGB (Mediapipe requires RGB input)
190
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
191
-
192
  # Pose estimation
193
- results = pose.process(frame_rgb)
194
-
195
  # Draw landmarks
196
  self.draw_landmarks(frame, results)
197
-
198
  # Extract keypoints
199
  keypoints = self.extract_keypoints(results)
200
-
201
- # Count repetitions
202
- self.count_reps(frame, results.pose_landmarks, mp_pose)
203
-
204
  # Visualize probabilities
205
  if len(self.sequence) == self.sequence_length:
206
  sequence = np.array([self.sequence])
207
  res = model.predict(sequence)
208
  frame = self.prob_viz(res[0], frame)
209
-
210
  # Append frame to output frames
211
  out_frames.append(frame)
212
-
213
  # Release video capture
214
  cap.release()
215
 
216
- # Return annotated frames and repetition counts
217
- return out_frames, {'curl': self.curl_counter, 'press': self.press_counter, 'squat': self.squat_counter}
218
- @st.cache()
219
 
220
- def count_reps(self, image, landmarks, mp_pose):
221
- """
222
- Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
223
 
224
- """
225
 
226
- if self.current_action == 'curl':
227
- # Get coords
228
- shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
229
- elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
230
- wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
231
 
232
- # calculate elbow angle
233
- angle = self.calculate_angle(shoulder, elbow, wrist)
234
 
235
- # curl counter logic
236
- if angle < 30:
237
- self.curl_stage = "up"
238
- if angle > 140 and self.curl_stage =='up':
239
- self.curl_stage="down"
240
- self.curl_counter +=1
241
- self.press_stage = None
242
- self.squat_stage = None
243
 
244
- # Viz joint angle
245
- self.viz_joint_angle(image, angle, elbow)
246
 
247
- elif self.current_action == 'press':
248
- # Get coords
249
- shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
250
- elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
251
- wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
252
-
253
- # Calculate elbow angle
254
- elbow_angle = self.calculate_angle(shoulder, elbow, wrist)
255
 
256
- # Compute distances between joints
257
- shoulder2elbow_dist = abs(math.dist(shoulder,elbow))
258
- shoulder2wrist_dist = abs(math.dist(shoulder,wrist))
259
 
260
- # Press counter logic
261
- if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):
262
- self.press_stage = "up"
263
- if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (self.press_stage =='up'):
264
- self.press_stage='down'
265
- self.press_counter += 1
266
- self.curl_stage = None
267
- self.squat_stage = None
268
 
269
- # Viz joint angle
270
- self.viz_joint_angle(image, elbow_angle, elbow)
271
 
272
- elif self.current_action == 'squat':
273
- # Get coords
274
- # left side
275
- left_shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
276
- left_hip = self.get_coordinates(landmarks, mp_pose, 'left', 'hip')
277
- left_knee = self.get_coordinates(landmarks, mp_pose, 'left', 'knee')
278
- left_ankle = self.get_coordinates(landmarks, mp_pose, 'left', 'ankle')
279
- # right side
280
- right_shoulder = self.get_coordinates(landmarks, mp_pose, 'right', 'shoulder')
281
- right_hip = self.get_coordinates(landmarks, mp_pose, 'right', 'hip')
282
- right_knee = self.get_coordinates(landmarks, mp_pose, 'right', 'knee')
283
- right_ankle = self.get_coordinates(landmarks, mp_pose, 'right', 'ankle')
284
 
285
- # Calculate knee angles
286
- left_knee_angle = self.calculate_angle(left_hip, left_knee, left_ankle)
287
- right_knee_angle = self.calculate_angle(right_hip, right_knee, right_ankle)
288
 
289
- # Calculate hip angles
290
- left_hip_angle = self.calculate_angle(left_shoulder, left_hip, left_knee)
291
- right_hip_angle = self.calculate_angle(right_shoulder, right_hip, right_knee)
292
 
293
- # Squat counter logic
294
- thr = 165
295
- if (left_knee_angle < thr) and (right_knee_angle < thr) and (left_hip_angle < thr) and (right_hip_angle < thr):
296
- self.squat_stage = "down"
297
- if (left_knee_angle > thr) and (right_knee_angle > thr) and (left_hip_angle > thr) and (right_hip_angle > thr) and (self.squat_stage =='down'):
298
- self.squat_stage='up'
299
- self.squat_counter += 1
300
- self.curl_stage = None
301
- self.press_stage = None
302
 
303
- # Viz joint angles
304
- self.viz_joint_angle(image, left_knee_angle, left_knee)
305
- self.viz_joint_angle(image, left_hip_angle, left_hip)
306
 
307
- else:
308
- pass
309
- return
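The curl, press, and squat counters removed above all follow the same hysteresis pattern: enter one stage when the joint angle crosses a threshold, and count a rep only when the opposite threshold is crossed from that stage. A minimal, self-contained sketch of the curl case (thresholds taken from the code above):

class CurlCounter:
    def __init__(self):
        self.count = 0
        self.stage = None  # None -> "up" -> "down"; a rep is counted on the up->down transition

    def update(self, elbow_angle):
        if elbow_angle < 30:
            self.stage = "up"
        if elbow_angle > 140 and self.stage == "up":
            self.stage = "down"
            self.count += 1
        return self.count

counter = CurlCounter()
for angle in [170, 25, 150, 20, 160]:  # two full flex/extend cycles
    counter.update(angle)
print(counter.count)  # 2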
310
 
311
- @st.cache()
312
- def prob_viz(self, res, input_frame):
313
- """
314
- This function displays the model prediction probability distribution over the set of exercise classes
315
- as a horizontal bar graph
316
 
317
- """
318
- output_frame = input_frame.copy()
319
- for num, prob in enumerate(res):
320
- cv2.rectangle(output_frame, (0,60+num*40), (int(prob*100), 90+num*40), self.colors[num], -1)
321
- cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
322
 
323
- return output_frame
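For reference, the probability overlay removed here is self-contained enough to run on its own; a sketch with the same geometry (bar width = probability x 100 pixels, one bar per class):

import cv2
import numpy as np

def prob_viz(res, actions, colors, input_frame):
    # Draw one horizontal bar per class, width proportional to its predicted probability.
    output_frame = input_frame.copy()
    for num, prob in enumerate(res):
        cv2.rectangle(output_frame, (0, 60 + num * 40), (int(prob * 100), 90 + num * 40), colors[num], -1)
        cv2.putText(output_frame, actions[num], (0, 85 + num * 40),
                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
    return output_frame

frame = np.zeros((480, 640, 3), dtype=np.uint8)
annotated = prob_viz([0.7, 0.2, 0.1], ['curl', 'press', 'squat'],
                     [(245, 117, 16), (117, 245, 16), (16, 117, 245)], frame)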
324
 
325
 
326
- # Slider widgets
327
- threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
328
- threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
329
- threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
330
 
331
- # Sidebar
332
- st.sidebar.header("Settings")
333
- st.sidebar.write("Adjust the confidence thresholds")
334
 
335
- # Call process_video_input() method from VideoProcessor
336
- video_processor.process_video_input(threshold1, threshold2, threshold3)
337
- # def process_uploaded_file(self, file):
338
- # """
339
- # Function to process an uploaded image or video file and run the fitness trainer AI
340
- # Args:
341
- # file (BytesIO): uploaded image or video file
342
- # Returns:
343
- # numpy array: processed image with keypoint detection and fitness activity classification visualized
344
- # """
345
- # # Initialize an empty list to store processed frames
346
- # processed_frames = []
347
-
348
- # # Check if the uploaded file is a video
349
- # is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
350
-
351
- # if is_video:
352
- # container = av.open(file)
353
- # for frame in container.decode(video=0):
354
- # # Convert the frame to OpenCV format
355
- # image = frame.to_image().convert("RGB")
356
- # image = np.array(image)
357
 
358
- # # Process the frame
359
- # processed_frame = self.process(image)
360
 
361
- # # Append the processed frame to the list
362
- # processed_frames.append(processed_frame)
363
 
364
- # # Close the video file container
365
- # container.close()
366
- # else:
367
- # # If the uploaded file is an image
368
- # # Load the image from the BytesIO object
369
- # image = Image.open(file)
370
- # image = np.array(image)
371
 
372
- # # Process the image
373
- # processed_frame = self.process(image)
374
 
375
- # # Append the processed frame to the list
376
- # processed_frames.append(processed_frame)
377
 
378
- # return processed_frames
379
 
380
- # def recv_uploaded_file(self, file):
381
- # """
382
- # Receive and process an uploaded video file
383
- # Args:
384
- # file (BytesIO): uploaded video file
385
- # Returns:
386
- # List[av.VideoFrame]: list of processed video frames
387
- # """
388
- # # Process the uploaded file
389
- # processed_frames = self.process_uploaded_file(file)
390
 
391
- # # Convert processed frames to av.VideoFrame objects
392
- # av_frames = []
393
- # for frame in processed_frames:
394
- # av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
395
- # av_frames.append(av_frame)
396
 
397
- # return av_frames
398
 
399
- # # Options
400
- # RTC_CONFIGURATION = RTCConfiguration(
401
- # {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
402
- # )
403
-
404
- # # Streamer
405
- # webrtc_ctx = webrtc_streamer(
406
- # key="AI trainer",
407
- # mode=WebRtcMode.SENDRECV,
408
- # rtc_configuration=RTC_CONFIGURATION,
409
- # media_stream_constraints={"video": True, "audio": False},
410
- # video_processor_factory=VideoProcessor,
411
- # async_processing=True,
412
- # )
 
1
  import streamlit as st
2
  import cv2
3
  import mediapipe as mp
4
  import math
5
  from PIL import Image
6
+ import numpy as np
7
+
8
  ## Build and Load Model
9
  def attention_block(inputs, time_steps):
10
  """
 
47
 
48
  return model
49
 
50
  ## Real Time Machine Learning and Computer Vision Processes
51
  class VideoProcessor:
52
  def __init__(self):
 
54
  self.actions = np.array(['curl', 'press', 'squat'])
55
  self.sequence_length = 30
56
  self.colors = [(245,117,16), (117,245,16), (16,117,245)]
57
+ self.threshold = 0.50 # Default threshold for activity classification confidence
58
 
59
  # Detection variables
60
  self.sequence = []
61
  self.current_action = ''
62
+
63
+ # Initialize pose model
64
+ self.mp_pose = mp.solutions.pose
65
+ self.mp_drawing = mp.solutions.drawing_utils
66
+ self.pose = self.mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
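With the pose model now owned by the class, each frame goes through a call like the brief standalone sketch below (the black frame is a stand-in for a real video frame).

import cv2
import mediapipe as mp
import numpy as np

mp_pose = mp.solutions.pose
with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5) as pose:
    frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in for a BGR video frame
    results = pose.process(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    if results.pose_landmarks:                       # None when no person is detected
        print(len(results.pose_landmarks.landmark))  # 33 keypoints -> 33*4 model inputs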
67
 
68
+ @st.cache()
69
  def draw_landmarks(self, image, results):
70
  """
71
  This function draws keypoints and landmarks detected by the human pose estimation model
72
 
73
  """
74
+ self.mp_drawing.draw_landmarks(image, results.pose_landmarks, self.mp_pose.POSE_CONNECTIONS,
75
+ self.mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
76
+ self.mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
77
+ )
78
+ return image
79
 
80
  @st.cache()
81
  def extract_keypoints(self, results):
 
88
  return pose
89
 
90
  @st.cache()
91
+ def calculate_angle(self, a, b, c):
92
  """
93
  Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
94
 
 
106
  return angle
107
 
108
  @st.cache()
109
+ def get_coordinates(self, landmarks, side, joint):
110
  """
111
  Retrieves x and y coordinates of a particular keypoint from the pose estimation model
112
 
113
  Args:
114
  landmarks: processed keypoints from the pose estimation model
 
115
  side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
116
  joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
117
 
118
  """
119
+ coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
120
  x_coord_val = landmarks[coord.value].x
121
  y_coord_val = landmarks[coord.value].y
122
  return [x_coord_val, y_coord_val]
 
132
  cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
133
  )
134
  return
135
+
136
  @st.cache()
137
+ def process_video_input(self, threshold1, threshold2, threshold3):
138
  """
139
+ Processes the video input and performs real-time action recognition and rep counting.
140
+
 
  """
142
+ video_file = st.file_uploader("Upload Video", type=["mp4", "avi"])
143
+ if video_file is None:
144
+ st.warning("Please upload a video file.")
145
+ return
146
+
147
  cap = cv2.VideoCapture(video_file)
148
+ if not cap.isOpened():
149
+ st.error("Error opening video stream or file.")
150
+ return
151
+
 
 
152
  while cap.isOpened():
153
  ret, frame = cap.read()
154
  if not ret:
155
  break
156
+
157
  # Convert frame to RGB (Mediapipe requires RGB input)
158
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
159
+
160
  # Pose estimation
161
+ results = self.pose.process(frame_rgb)
162
+
163
  # Draw landmarks
164
  self.draw_landmarks(frame, results)
165
+
166
  # Extract keypoints
167
  keypoints = self.extract_keypoints(results)
168
+
169
  # Visualize probabilities
170
  if len(self.sequence) == self.sequence_length:
171
  sequence = np.array([self.sequence])
172
  res = model.predict(sequence)
173
  frame = self.prob_viz(res[0], frame)
174
+
175
  # Append frame to output frames
176
  out_frames.append(frame)
177
+
178
  # Release video capture
179
  cap.release()
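One caveat for the new process_video_input: cv2.VideoCapture expects a filename (or camera index), not the file-like object returned by st.file_uploader, so in practice the upload is usually written to a temporary file first. A hedged sketch of that pattern (variable names are illustrative):

import tempfile
import cv2
import streamlit as st

video_file = st.file_uploader("Upload Video", type=["mp4", "avi"])
if video_file is not None:
    # Write the uploaded bytes to a temporary file so OpenCV can open it by path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        tmp.write(video_file.read())
        video_path = tmp.name
    cap = cv2.VideoCapture(video_path)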
180
 
181
+ # import streamlit as st
182
+ # import cv2
183
+
184
+ # from tensorflow.keras.models import Model
185
+ # from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
186
+ # Bidirectional, Permute, multiply)
187
+
188
+ # import numpy as np
189
+ # import mediapipe as mp
190
+ # import math
191
+ # import streamlit as st
192
+ # import cv2
193
+ # import mediapipe as mp
194
+ # import math
195
+
196
+ # # from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
197
+ # import av
198
+ # from io import BytesIO
199
+ # import av
200
+ # from PIL import Image
201
+
202
+ # ## Build and Load Model
203
+ # def attention_block(inputs, time_steps):
204
+ # """
205
+ # Attention layer for deep neural network
206
 
207
+ # """
208
+ # # Attention weights
209
+ # a = Permute((2, 1))(inputs)
210
+ # a = Dense(time_steps, activation='softmax')(a)
211
+
212
+ # # Attention vector
213
+ # a_probs = Permute((2, 1), name='attention_vec')(a)
214
+
215
+ # # Luong's multiplicative score
216
+ # output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
217
+
218
+ # return output_attention_mul
219
+
220
+ # @st.cache(allow_output_mutation=True)
221
+ # def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
222
+
223
+ # # Input
224
+ # inputs = Input(shape=(sequence_length, num_input_values))
225
+ # # Bi-LSTM
226
+ # lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
227
+ # # Attention
228
+ # attention_mul = attention_block(lstm_out, sequence_length)
229
+ # attention_mul = Flatten()(attention_mul)
230
+ # # Fully Connected Layer
231
+ # x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)
232
+ # x = Dropout(0.5)(x)
233
+ # # Output
234
+ # x = Dense(num_classes, activation='softmax')(x)
235
+ # # Bring it all together
236
+ # model = Model(inputs=[inputs], outputs=x)
237
+
238
+ # ## Load Model Weights
239
+ # load_dir = "./models/LSTM_Attention.h5"
240
+ # model.load_weights(load_dir)
241
+
242
+ # return model
243
+
244
+ # HIDDEN_UNITS = 256
245
+ # model = build_model(HIDDEN_UNITS)
246
+ # threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
247
+ # threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
248
+ # threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
249
+
250
+ # ## Mediapipe
251
+ # mp_pose = mp.solutions.pose # Pre-trained pose estimation model from Google Mediapipe
252
+ # mp_drawing = mp.solutions.drawing_utils # Supported Mediapipe visualization tools
253
+ # pose = mp_pose.Pose(min_detection_confidence=threshold1, min_tracking_confidence=threshold2) # mediapipe pose model
254
+
255
+ # ## Real Time Machine Learning and Computer Vision Processes
256
+ # class VideoProcessor:
257
+ # def __init__(self):
258
+ # # Parameters
259
+ # self.actions = np.array(['curl', 'press', 'squat'])
260
+ # self.sequence_length = 30
261
+ # self.colors = [(245,117,16), (117,245,16), (16,117,245)]
262
+ # self.threshold = threshold3
263
 
264
+ # # Detection variables
265
+ # self.sequence = []
266
+ # self.current_action = ''
267
+
268
+ # # Rep counter logic variables
269
+ # self.curl_counter = 0
270
+ # self.press_counter = 0
271
+ # self.squat_counter = 0
272
+ # self.curl_stage = None
273
+ # self.press_stage = None
274
+ # self.squat_stage = None
275
+
276
+ # @st.cache()
277
+ # def draw_landmarks(self, image, results):
278
+ # """
279
+ # This function draws keypoints and landmarks detected by the human pose estimation model
280
 
281
+ # """
282
+ # mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
283
+ # mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
284
+ # mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
285
+ # )
286
+ # return
287
+
288
+ # @st.cache()
289
+ # def extract_keypoints(self, results):
290
+ # """
291
+ # Processes and organizes the keypoints detected from the pose estimation model
292
+ # to be used as inputs for the exercise decoder models
293
+
294
+ # """
295
+ # pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
296
+ # return pose
297
+
298
+ # @st.cache()
299
+ # def calculate_angle(self, a,b,c):
300
+ # """
301
+ # Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
302
+
303
+ # """
304
+ # a = np.array(a) # First
305
+ # b = np.array(b) # Mid
306
+ # c = np.array(c) # End
307
+
308
+ # radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
309
+ # angle = np.abs(radians*180.0/np.pi)
310
+
311
+ # if angle > 180.0:
312
+ # angle = 360-angle
313
+
314
+ # return angle
315
+
316
+ # @st.cache()
317
+ # def get_coordinates(self, landmarks, mp_pose, side, joint):
318
+ # """
319
+ # Retrieves x and y coordinates of a particular keypoint from the pose estimation model
320
+
321
+ # Args:
322
+ # landmarks: processed keypoints from the pose estimation model
323
+ # mp_pose: Mediapipe pose estimation model
324
+ # side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
325
+ # joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
326
+
327
+ # """
328
+ # coord = getattr(mp_pose.PoseLandmark,side.upper()+"_"+joint.upper())
329
+ # x_coord_val = landmarks[coord.value].x
330
+ # y_coord_val = landmarks[coord.value].y
331
+ # return [x_coord_val, y_coord_val]
332
+
333
+ # @st.cache()
334
+ # def viz_joint_angle(self, image, angle, joint):
335
+ # """
336
+ # Displays the joint angle value near the joint within the image frame
337
+
338
+ # """
339
+ # cv2.putText(image, str(int(angle)),
340
+ # tuple(np.multiply(joint, [640, 480]).astype(int)),
341
+ # cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
342
+ # )
343
+ # return
344
+ # @st.cache()
345
+ # def process_video(self, video_file):
346
+ # """
347
+ # Processes each frame of the input video, performs pose estimation,
348
+ # and counts repetitions of each exercise.
349
+
350
+ # Args:
351
+ # video_file (BytesIO): Input video file.
352
+
353
+ # Returns:
354
+ # tuple: A tuple containing the processed video frames with annotations
355
+ # and the final count of repetitions for each exercise.
356
+ # """
357
+ # cap = cv2.VideoCapture(video_file)
358
+ # out_frames = []
359
+ # # Initialize repetition counters
360
+ # self.curl_counter = 0
361
+ # self.press_counter = 0
362
+ # self.squat_counter = 0
363
+
364
+ # while cap.isOpened():
365
+ # ret, frame = cap.read()
366
+ # if not ret:
367
+ # break
368
+
369
+ # # Convert frame to RGB (Mediapipe requires RGB input)
370
+ # frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
371
+
372
+ # # Pose estimation
373
+ # results = pose.process(frame_rgb)
374
+
375
+ # # Draw landmarks
376
+ # self.draw_landmarks(frame, results)
377
+
378
+ # # Extract keypoints
379
+ # keypoints = self.extract_keypoints(results)
380
+
381
+ # # Count repetitions
382
+ # self.count_reps(frame, results.pose_landmarks, mp_pose)
383
+
384
+ # # Visualize probabilities
385
+ # if len(self.sequence) == self.sequence_length:
386
+ # sequence = np.array([self.sequence])
387
+ # res = model.predict(sequence)
388
+ # frame = self.prob_viz(res[0], frame)
389
+
390
+ # # Append frame to output frames
391
+ # out_frames.append(frame)
392
+
393
+ # # Release video capture
394
+ # cap.release()
395
+
396
+ # # Return annotated frames and repetition counts
397
+ # return out_frames, {'curl': self.curl_counter, 'press': self.press_counter, 'squat': self.squat_counter}
398
+ # @st.cache()
399
+
400
+ # def count_reps(self, image, landmarks, mp_pose):
401
+ # """
402
+ # Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
403
+
404
+ # """
405
+
406
+ # if self.current_action == 'curl':
407
+ # # Get coords
408
+ # shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
409
+ # elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
410
+ # wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
411
 
412
+ # # calculate elbow angle
413
+ # angle = self.calculate_angle(shoulder, elbow, wrist)
414
 
415
+ # # curl counter logic
416
+ # if angle < 30:
417
+ # self.curl_stage = "up"
418
+ # if angle > 140 and self.curl_stage =='up':
419
+ # self.curl_stage="down"
420
+ # self.curl_counter +=1
421
+ # self.press_stage = None
422
+ # self.squat_stage = None
423
 
424
+ # # Viz joint angle
425
+ # self.viz_joint_angle(image, angle, elbow)
426
 
427
+ # elif self.current_action == 'press':
428
+ # # Get coords
429
+ # shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
430
+ # elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
431
+ # wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
432
+
433
+ # # Calculate elbow angle
434
+ # elbow_angle = self.calculate_angle(shoulder, elbow, wrist)
435
 
436
+ # # Compute distances between joints
437
+ # shoulder2elbow_dist = abs(math.dist(shoulder,elbow))
438
+ # shoulder2wrist_dist = abs(math.dist(shoulder,wrist))
439
 
440
+ # # Press counter logic
441
+ # if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):
442
+ # self.press_stage = "up"
443
+ # if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (self.press_stage =='up'):
444
+ # self.press_stage='down'
445
+ # self.press_counter += 1
446
+ # self.curl_stage = None
447
+ # self.squat_stage = None
448
 
449
+ # # Viz joint angle
450
+ # self.viz_joint_angle(image, elbow_angle, elbow)
451
 
452
+ # elif self.current_action == 'squat':
453
+ # # Get coords
454
+ # # left side
455
+ # left_shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
456
+ # left_hip = self.get_coordinates(landmarks, mp_pose, 'left', 'hip')
457
+ # left_knee = self.get_coordinates(landmarks, mp_pose, 'left', 'knee')
458
+ # left_ankle = self.get_coordinates(landmarks, mp_pose, 'left', 'ankle')
459
+ # # right side
460
+ # right_shoulder = self.get_coordinates(landmarks, mp_pose, 'right', 'shoulder')
461
+ # right_hip = self.get_coordinates(landmarks, mp_pose, 'right', 'hip')
462
+ # right_knee = self.get_coordinates(landmarks, mp_pose, 'right', 'knee')
463
+ # right_ankle = self.get_coordinates(landmarks, mp_pose, 'right', 'ankle')
464
 
465
+ # # Calculate knee angles
466
+ # left_knee_angle = self.calculate_angle(left_hip, left_knee, left_ankle)
467
+ # right_knee_angle = self.calculate_angle(right_hip, right_knee, right_ankle)
468
 
469
+ # # Calculate hip angles
470
+ # left_hip_angle = self.calculate_angle(left_shoulder, left_hip, left_knee)
471
+ # right_hip_angle = self.calculate_angle(right_shoulder, right_hip, right_knee)
472
 
473
+ # # Squat counter logic
474
+ # thr = 165
475
+ # if (left_knee_angle < thr) and (right_knee_angle < thr) and (left_hip_angle < thr) and (right_hip_angle < thr):
476
+ # self.squat_stage = "down"
477
+ # if (left_knee_angle > thr) and (right_knee_angle > thr) and (left_hip_angle > thr) and (right_hip_angle > thr) and (self.squat_stage =='down'):
478
+ # self.squat_stage='up'
479
+ # self.squat_counter += 1
480
+ # self.curl_stage = None
481
+ # self.press_stage = None
482
 
483
+ # # Viz joint angles
484
+ # self.viz_joint_angle(image, left_knee_angle, left_knee)
485
+ # self.viz_joint_angle(image, left_hip_angle, left_hip)
486
 
487
+ # else:
488
+ # pass
489
+ # return
490
 
491
+ # @st.cache()
492
+ # def prob_viz(self, res, input_frame):
493
+ # """
494
+ # This function displays the model prediction probability distribution over the set of exercise classes
495
+ # as a horizontal bar graph
496
 
497
+ # """
498
+ # output_frame = input_frame.copy()
499
+ # for num, prob in enumerate(res):
500
+ # cv2.rectangle(output_frame, (0,60+num*40), (int(prob*100), 90+num*40), self.colors[num], -1)
501
+ # cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
502
 
503
+ # return output_frame
504
 
505
 
506
+ # # Slider widgets
507
+ # threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
508
+ # threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
509
+ # threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
510
 
511
+ # # Sidebar
512
+ # st.sidebar.header("Settings")
513
+ # st.sidebar.write("Adjust the confidence thresholds")
514
 
515
+ # # Call process_video_input() method from VideoProcessor
516
+ # video_processor.process_video_input(threshold1, threshold2, threshold3)
517
+ # # def process_uploaded_file(self, file):
518
+ # # """
519
+ # # Function to process an uploaded image or video file and run the fitness trainer AI
520
+ # # Args:
521
+ # # file (BytesIO): uploaded image or video file
522
+ # # Returns:
523
+ # # numpy array: processed image with keypoint detection and fitness activity classification visualized
524
+ # # """
525
+ # # # Initialize an empty list to store processed frames
526
+ # # processed_frames = []
527
+
528
+ # # # Check if the uploaded file is a video
529
+ # # is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
530
+
531
+ # # if is_video:
532
+ # # container = av.open(file)
533
+ # # for frame in container.decode(video=0):
534
+ # # # Convert the frame to OpenCV format
535
+ # # image = frame.to_image().convert("RGB")
536
+ # # image = np.array(image)
537
 
538
+ # # # Process the frame
539
+ # # processed_frame = self.process(image)
540
 
541
+ # # # Append the processed frame to the list
542
+ # # processed_frames.append(processed_frame)
543
 
544
+ # # # Close the video file container
545
+ # # container.close()
546
+ # # else:
547
+ # # # If the uploaded file is an image
548
+ # # # Load the image from the BytesIO object
549
+ # # image = Image.open(file)
550
+ # # image = np.array(image)
551
 
552
+ # # # Process the image
553
+ # # processed_frame = self.process(image)
554
 
555
+ # # # Append the processed frame to the list
556
+ # # processed_frames.append(processed_frame)
557
 
558
+ # # return processed_frames
559
 
560
+ # # def recv_uploaded_file(self, file):
561
+ # # """
562
+ # # Receive and process an uploaded video file
563
+ # # Args:
564
+ # # file (BytesIO): uploaded video file
565
+ # # Returns:
566
+ # # List[av.VideoFrame]: list of processed video frames
567
+ # # """
568
+ # # # Process the uploaded file
569
+ # # processed_frames = self.process_uploaded_file(file)
570
 
571
+ # # # Convert processed frames to av.VideoFrame objects
572
+ # # av_frames = []
573
+ # # for frame in processed_frames:
574
+ # # av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
575
+ # # av_frames.append(av_frame)
576
 
577
+ # # return av_frames
578
 
579
+ # # # Options
580
+ # # RTC_CONFIGURATION = RTCConfiguration(
581
+ # # {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
582
+ # # )
583
+
584
+ # # # Streamer
585
+ # # webrtc_ctx = webrtc_streamer(
586
+ # # key="AI trainer",
587
+ # # mode=WebRtcMode.SENDRECV,
588
+ # # rtc_configuration=RTC_CONFIGURATION,
589
+ # # media_stream_constraints={"video": True, "audio": False},
590
+ # # video_processor_factory=VideoProcessor,
591
+ # # async_processing=True,
592
+ # # )