randomshit11 committed on
Commit
8fbb28c
·
verified ·
1 Parent(s): f5e0301

Update app.py

Files changed (1)
  1. app.py +120 -164
app.py CHANGED
@@ -35,18 +35,7 @@ def attention_block(inputs, time_steps):
 
 @st.cache(allow_output_mutation=True)
 def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
-    """
-    Function used to build the deep neural network model on startup
-
-    Args:
-        HIDDEN_UNITS (int, optional): Number of hidden units for each neural network hidden layer. Defaults to 256.
-        sequence_length (int, optional): Input sequence length (i.e., number of frames). Defaults to 30.
-        num_input_values (_type_, optional): Input size of the neural network model. Defaults to 33*4 (i.e., number of keypoints x number of metrics).
-        num_classes (int, optional): Number of classification categories (i.e., model output size). Defaults to 3.
-
-    Returns:
-        keras model: neural network with pre-trained weights
-    """
+
     # Input
     inputs = Input(shape=(sequence_length, num_input_values))
     # Bi-LSTM
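Since the removed docstring was the only in-code description of the network, here is a minimal, self-contained sketch of the kind of classifier the remaining body of build_model assembles: a Bi-LSTM over 30 frames of 33 keypoints x 4 values each, attention pooling over time, and a 3-way softmax. The names build_model_sketch and attention_pooling are illustrative stand-ins; app.py's own attention_block and the rest of build_model are not shown in this hunk, so the exact layers may differ.

from tensorflow.keras.layers import Input, Bidirectional, LSTM, Dense, Permute, Multiply, Flatten
from tensorflow.keras.models import Model

def attention_pooling(inputs, time_steps):
    # Score each time step, softmax over the time axis, and reweight the sequence
    a = Permute((2, 1))(inputs)
    a = Dense(time_steps, activation='softmax')(a)
    a = Permute((2, 1))(a)
    return Flatten()(Multiply()([inputs, a]))

def build_model_sketch(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
    inputs = Input(shape=(sequence_length, num_input_values))
    x = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)  # (batch, 30, 512)
    x = attention_pooling(x, sequence_length)                             # flattened attention-weighted sequence
    outputs = Dense(num_classes, activation='softmax')(x)                 # 3 exercise classes
    return Model(inputs=inputs, outputs=outputs)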
@@ -70,24 +59,10 @@ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num
 
 HIDDEN_UNITS = 256
 model = build_model(HIDDEN_UNITS)
-
-## App
-st.write("# AI Personal Fitness Trainer Web App")
-
-st.markdown("❗❗ **Development Note** ❗❗")
-st.markdown("Currently, the exercise recognition model uses the x, y, and z coordinates of each anatomical landmark from the MediaPipe Pose model. These coordinates are normalized with respect to the image frame (e.g., the top left corner represents (x=0, y=0) and the bottom right corner represents (x=1, y=1)).")
-st.markdown("I'm currently developing and testing two new feature engineering strategies:")
-st.markdown("- Normalizing coordinates by the detected bounding box of the user")
-st.markdown("- Using joint angles rather than keypoint coordinates as features")
-st.write("Stay Tuned!")
-
-st.write("## Settings")
 threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
 threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
 threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
 
-st.write("## Activate the AI 🤖🏋️‍♂️")
-
 ## Mediapipe
 mp_pose = mp.solutions.pose  # Pre-trained pose estimation model from Google Mediapipe
 mp_drawing = mp.solutions.drawing_utils  # Supported Mediapipe visualization tools
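The removed Development Note describes two feature-engineering strategies under test: normalizing landmark coordinates by the user's detected bounding box, and using joint angles instead of raw keypoint coordinates. A minimal NumPy sketch of both follows, assuming MediaPipe-style (x, y) landmarks already scaled to the image frame; the function names are illustrative and not part of app.py.

import numpy as np

def normalize_by_bounding_box(keypoints):
    # keypoints: (33, 2) array of image-normalized (x, y) MediaPipe landmarks.
    # Rescale so the detected person, not the camera frame, spans the unit square.
    mins, maxs = keypoints.min(axis=0), keypoints.max(axis=0)
    return (keypoints - mins) / (maxs - mins + 1e-6)

def joint_angle(a, b, c):
    # Angle (degrees) at joint b formed by points a-b-c, e.g. shoulder-elbow-wrist.
    ba, bc = np.asarray(a) - np.asarray(b), np.asarray(c) - np.asarray(b)
    cosine = np.dot(ba, bc) / (np.linalg.norm(ba) * np.linalg.norm(bc) + 1e-6)
    return np.degrees(np.arccos(np.clip(cosine, -1.0, 1.0)))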
@@ -182,8 +157,62 @@ class VideoProcessor:
                         cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                     )
         return
-
     @st.cache()
+    def process_video(self, video_file):
+        """
+        Processes each frame of the input video, performs pose estimation,
+        and counts repetitions of each exercise.
+
+        Args:
+            video_file (BytesIO): Input video file.
+
+        Returns:
+            tuple: A tuple containing the processed video frames with annotations
+            and the final count of repetitions for each exercise.
+        """
+        cap = cv2.VideoCapture(video_file)
+        out_frames = []
+        # Initialize repetition counters
+        self.curl_counter = 0
+        self.press_counter = 0
+        self.squat_counter = 0
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Convert frame to RGB (Mediapipe requires RGB input)
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Pose estimation
+            results = pose.process(frame_rgb)
+
+            # Draw landmarks
+            self.draw_landmarks(frame, results)
+
+            # Extract keypoints
+            keypoints = self.extract_keypoints(results)
+
+            # Count repetitions
+            self.count_reps(frame, results.pose_landmarks, mp_pose)
+
+            # Visualize probabilities
+            if len(self.sequence) == self.sequence_length:
+                sequence = np.array([self.sequence])
+                res = model.predict(sequence)
+                frame = self.prob_viz(res[0], frame)
+
+            # Append frame to output frames
+            out_frames.append(frame)
+
+        # Release video capture
+        cap.release()
+
+        # Return annotated frames and repetition counts
+        return out_frames, {'curl': self.curl_counter, 'press': self.press_counter, 'squat': self.squat_counter}
+    @st.cache()
+
     def count_reps(self, image, landmarks, mp_pose):
         """
         Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
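One caveat for callers of the new process_video: cv2.VideoCapture expects a file path (or device index) rather than the BytesIO object named in the docstring, so a Streamlit upload would typically be spooled to a temporary file first. A minimal usage sketch under that assumption follows; the processor variable stands in for a VideoProcessor instance, which this hunk does not construct.

import tempfile
import streamlit as st

uploaded = st.file_uploader("Upload a workout video", type=["mp4", "avi", "mov"])
if uploaded is not None:
    # cv2.VideoCapture needs a path, so persist the upload to a temporary file
    with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as tmp:
        tmp.write(uploaded.read())
    frames, counts = processor.process_video(tmp.name)  # processor: a VideoProcessor instance (assumed)
    st.write(counts)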
@@ -288,155 +317,82 @@ class VideoProcessor:
         cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
 
         return output_frame
-
-    # @st.cache()
-    # def process(self, image):
-    #     """
-    #     Function to process the video frame from the user's webcam and run the fitness trainer AI
-
-    #     Args:
-    #         image (numpy array): input image from the webcam
-
-    #     Returns:
-    #         numpy array: processed image with keypoint detection and fitness activity classification visualized
-    #     """
-    #     # Pose detection model
-    #     image.flags.writeable = False
-    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    #     results = pose.process(image)
 
-    #     # Draw the hand annotations on the image.
-    #     image.flags.writeable = True
-    #     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-    #     self.draw_landmarks(image, results)
 
-    #     # Prediction logic
-    #     keypoints = self.extract_keypoints(results)
-    #     self.sequence.append(keypoints.astype('float32',casting='same_kind'))
-    #     self.sequence = self.sequence[-self.sequence_length:]
-
-    #     if len(self.sequence) == self.sequence_length:
-    #         res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
-    #         # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
-    #         # interpreter.invoke()
-    #         # res = interpreter.get_tensor(self.output_details[0]['index'])
-
-    #         self.current_action = self.actions[np.argmax(res)]
-    #         confidence = np.max(res)
-
-    #         # Erase current action variable if no probability is above threshold
-    #         if confidence < self.threshold:
-    #             self.current_action = ''
 
-    #     # Viz probabilities
-    #     image = self.prob_viz(res, image)
-
-    #     # Count reps
-    #     try:
-    #         landmarks = results.pose_landmarks.landmark
-    #         self.count_reps(
-    #             image, landmarks, mp_pose)
-    #     except:
-    #         pass
-
-    #     # Display graphical information
-    #     cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-    #     cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-    #     cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-    #     cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-    #     # return cv2.flip(image, 1)
-    #     return image
-
-    # def recv(self, frame):
-    #     """
-    #     Receive and process video stream from webcam
 
-    #     Args:
-    #         frame: current video frame
-
-    #     Returns:
-    #         av.VideoFrame: processed video frame
-    #     """
-    #     img = frame.to_ndarray(format="bgr24")
-    #     img = self.process(img)
-    #     return av.VideoFrame.from_ndarray(img, format="bgr24")
-    def process_uploaded_file(self, file):
-        """
-        Function to process an uploaded image or video file and run the fitness trainer AI
-        Args:
-            file (BytesIO): uploaded image or video file
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Initialize an empty list to store processed frames
-        processed_frames = []
-
-        # Check if the uploaded file is a video
-        is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
-
-        if is_video:
-            container = av.open(file)
-            for frame in container.decode(video=0):
-                # Convert the frame to OpenCV format
-                image = frame.to_image().convert("RGB")
-                image = np.array(image)
 
-                # Process the frame
-                processed_frame = self.process(image)
 
-                # Append the processed frame to the list
-                processed_frames.append(processed_frame)
 
-            # Close the video file container
-            container.close()
-        else:
-            # If the uploaded file is an image
-            # Load the image from the BytesIO object
-            image = Image.open(file)
-            image = np.array(image)
 
-            # Process the image
-            processed_frame = self.process(image)
 
-            # Append the processed frame to the list
-            processed_frames.append(processed_frame)
 
-        return processed_frames
 
-    def recv_uploaded_file(self, file):
-        """
-        Receive and process an uploaded video file
-        Args:
-            file (BytesIO): uploaded video file
-        Returns:
-            List[av.VideoFrame]: list of processed video frames
-        """
-        # Process the uploaded file
-        processed_frames = self.process_uploaded_file(file)
 
-        # Convert processed frames to av.VideoFrame objects
-        av_frames = []
-        for frame in processed_frames:
-            av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
-            av_frames.append(av_frame)
 
-        return av_frames
 
-# Options
-RTC_CONFIGURATION = RTCConfiguration(
-    {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
-)
 
-# Streamer
-webrtc_ctx = webrtc_streamer(
-    key="AI trainer",
-    mode=WebRtcMode.SENDRECV,
-    rtc_configuration=RTC_CONFIGURATION,
-    media_stream_constraints={"video": True, "audio": False},
-    video_processor_factory=VideoProcessor,
-    async_processing=True,
-)
+video_processor.process_video_input(threshold1, threshold2, threshold3)
+# def process_uploaded_file(self, file):
+#     """
+#     Function to process an uploaded image or video file and run the fitness trainer AI
+#     Args:
+#         file (BytesIO): uploaded image or video file
+#     Returns:
+#         numpy array: processed image with keypoint detection and fitness activity classification visualized
+#     """
+#     # Initialize an empty list to store processed frames
+#     processed_frames = []
 
+#     # Check if the uploaded file is a video
+#     is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
 
+#     if is_video:
+#         container = av.open(file)
+#         for frame in container.decode(video=0):
+#             # Convert the frame to OpenCV format
+#             image = frame.to_image().convert("RGB")
+#             image = np.array(image)
 
+#             # Process the frame
+#             processed_frame = self.process(image)
 
+#             # Append the processed frame to the list
+#             processed_frames.append(processed_frame)
 
+#         # Close the video file container
+#         container.close()
+#     else:
+#         # If the uploaded file is an image
+#         # Load the image from the BytesIO object
+#         image = Image.open(file)
+#         image = np.array(image)
 
+#         # Process the image
+#         processed_frame = self.process(image)
 
+#         # Append the processed frame to the list
+#         processed_frames.append(processed_frame)
 
+#     return processed_frames
 
+# def recv_uploaded_file(self, file):
+#     """
+#     Receive and process an uploaded video file
+#     Args:
+#         file (BytesIO): uploaded video file
+#     Returns:
+#         List[av.VideoFrame]: list of processed video frames
+#     """
+#     # Process the uploaded file
+#     processed_frames = self.process_uploaded_file(file)
 
+#     # Convert processed frames to av.VideoFrame objects
+#     av_frames = []
+#     for frame in processed_frames:
+#         av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
+#         av_frames.append(av_frame)
 
+#     return av_frames
 
+# # Options
+# RTC_CONFIGURATION = RTCConfiguration(
+#     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+# )
 
+# # Streamer
+# webrtc_ctx = webrtc_streamer(
+#     key="AI trainer",
+#     mode=WebRtcMode.SENDRECV,
+#     rtc_configuration=RTC_CONFIGURATION,
+#     media_stream_constraints={"video": True, "audio": False},
+#     video_processor_factory=VideoProcessor,
+#     async_processing=True,
+# )
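The new module-level call, video_processor.process_video_input(threshold1, threshold2, threshold3), and the video_processor instance itself are not defined anywhere in this diff, so how the frames and repetition counts returned by process_video get rendered is left open. Below is a plausible Streamlit rendering helper; it is an assumption about code outside this commit, not part of app.py.

import cv2
import streamlit as st

def show_results(frames, counts):
    # Render the annotated frames and final repetition counts returned by process_video.
    st.write(f"curl: {counts['curl']}  |  press: {counts['press']}  |  squat: {counts['squat']}")
    placeholder = st.empty()
    for frame in frames:
        # OpenCV frames are BGR; Streamlit expects RGB
        placeholder.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))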
 