Update app.py
app.py (CHANGED)
@@ -35,18 +35,7 @@ def attention_block(inputs, time_steps):
 
 @st.cache(allow_output_mutation=True)
 def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
-    """
-    Function used to build the deep neural network model on startup
-
-    Args:
-        HIDDEN_UNITS (int, optional): Number of hidden units for each neural network hidden layer. Defaults to 256.
-        sequence_length (int, optional): Input sequence length (i.e., number of frames). Defaults to 30.
-        num_input_values (_type_, optional): Input size of the neural network model. Defaults to 33*4 (i.e., number of keypoints x number of metrics).
-        num_classes (int, optional): Number of classification categories (i.e., model output size). Defaults to 3.
-
-    Returns:
-        keras model: neural network with pre-trained weights
-    """
+
     # Input
     inputs = Input(shape=(sequence_length, num_input_values))
     # Bi-LSTM
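For context on this hunk: the deleted docstring was the only place the model's shape assumptions were written down (30-frame sequences of 33 keypoints x 4 values, classified into 3 exercises), and the hunk header shows that app.py defines an attention_block(inputs, time_steps) helper just above. The following is only a hedged sketch of the kind of Keras graph build_model appears to assemble from those pieces; the attention implementation, layer arrangement, and names ending in _sketch are assumptions, not this app's actual code.

import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LSTM, Bidirectional, Permute, Multiply, Lambda
from tensorflow.keras.models import Model

def attention_block_sketch(inputs, time_steps):
    # Assumed attention: score each frame, softmax over the time axis, return the weighted sum of frames.
    a = Permute((2, 1))(inputs)                     # (batch, features, time)
    a = Dense(time_steps, activation='softmax')(a)  # attention weights over time
    a = Permute((2, 1))(a)                          # back to (batch, time, features)
    weighted = Multiply()([inputs, a])
    return Lambda(lambda t: tf.reduce_sum(t, axis=1))(weighted)  # collapse the time axis

def build_model_sketch(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
    inputs = Input(shape=(sequence_length, num_input_values))             # 30 frames x (33 keypoints * 4 values)
    x = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)  # Bi-LSTM over the frame sequence
    x = attention_block_sketch(x, sequence_length)
    outputs = Dense(num_classes, activation='softmax')(x)                 # one probability per exercise class
    return Model(inputs=inputs, outputs=outputs)

A side note on the decorator this hunk keeps: recent Streamlit releases deprecate st.cache; st.cache_resource is the current idiom for caching a loaded model like this one.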
@@ -70,24 +59,10 @@ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num
 
 HIDDEN_UNITS = 256
 model = build_model(HIDDEN_UNITS)
-
-## App
-st.write("# AI Personal Fitness Trainer Web App")
-
-st.markdown("ββ **Development Note** ββ")
-st.markdown("Currently, the exercise recognition model uses the the x, y, and z coordinates of each anatomical landmark from the MediaPipe Pose model. These coordinates are normalized with respect to the image frame (e.g., the top left corner represents (x=0,y=0) and the bottom right corner represents(x=1,y=1)).")
-st.markdown("I'm currently developing and testing two new feature engineering strategies:")
-st.markdown("- Normalizing coordinates by the detected bounding box of the user")
-st.markdown("- Using joint angles rather than keypoint coordaintes as features")
-st.write("Stay Tuned!")
-
-st.write("## Settings")
 threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
 threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
 threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
 
-st.write("## Activate the AI π€ποΈββοΈ")
-
 ## Mediapipe
 mp_pose = mp.solutions.pose # Pre-trained pose estimation model from Google Mediapipe
 mp_drawing = mp.solutions.drawing_utils # Supported Mediapipe visualization tools
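The removed development note describes two planned feature-engineering strategies: normalizing keypoints by the person's bounding box instead of the full frame, and using joint angles instead of raw coordinates. As a rough illustration of what those features look like when computed from MediaPipe landmarks, here is a small sketch; the function names and the assumption that each landmark exposes x and y attributes are illustrative, not code from this Space.

import numpy as np

def normalize_by_bounding_box(landmarks):
    # Rescale landmark x/y into the bounding box of the detected person rather than the full image frame.
    xs = np.array([lm.x for lm in landmarks])
    ys = np.array([lm.y for lm in landmarks])
    xs_norm = (xs - xs.min()) / (xs.max() - xs.min() + 1e-6)
    ys_norm = (ys - ys.min()) / (ys.max() - ys.min() + 1e-6)
    return np.stack([xs_norm, ys_norm], axis=1)  # shape (33, 2)

def joint_angle(a, b, c):
    # Angle (degrees) at joint b formed by points a-b-c, e.g. shoulder-elbow-wrist for a curl.
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    return 360.0 - angle if angle > 180.0 else angle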
@@ -182,8 +157,62 @@ class VideoProcessor:
                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                 )
         return
-
     @st.cache()
+    def process_video(self, video_file):
+        """
+        Processes each frame of the input video, performs pose estimation,
+        and counts repetitions of each exercise.
+
+        Args:
+            video_file (BytesIO): Input video file.
+
+        Returns:
+            tuple: A tuple containing the processed video frames with annotations
+            and the final count of repetitions for each exercise.
+        """
+        cap = cv2.VideoCapture(video_file)
+        out_frames = []
+        # Initialize repetition counters
+        self.curl_counter = 0
+        self.press_counter = 0
+        self.squat_counter = 0
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Convert frame to RGB (Mediapipe requires RGB input)
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+
+            # Pose estimation
+            results = pose.process(frame_rgb)
+
+            # Draw landmarks
+            self.draw_landmarks(frame, results)
+
+            # Extract keypoints
+            keypoints = self.extract_keypoints(results)
+
+            # Count repetitions
+            self.count_reps(frame, results.pose_landmarks, mp_pose)
+
+            # Visualize probabilities
+            if len(self.sequence) == self.sequence_length:
+                sequence = np.array([self.sequence])
+                res = model.predict(sequence)
+                frame = self.prob_viz(res[0], frame)
+
+            # Append frame to output frames
+            out_frames.append(frame)
+
+        # Release video capture
+        cap.release()
+
+        # Return annotated frames and repetition counts
+        return out_frames, {'curl': self.curl_counter, 'press': self.press_counter, 'squat': self.squat_counter}
+    @st.cache()
+
     def count_reps(self, image, landmarks, mp_pose):
         """
         Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
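A usage note on the new process_video method added above: cv2.VideoCapture accepts a filename (or camera index), not a BytesIO object, so an uploaded Streamlit file generally has to be written to a temporary file before this method can read it. The snippet below is a hedged sketch of that wiring; the uploader label, the no-argument VideoProcessor() construction, and the way results are displayed are assumptions for illustration, not taken from this diff.

import tempfile
import streamlit as st

uploaded = st.file_uploader("Upload a workout video", type=["mp4", "avi", "mov"])
if uploaded is not None:
    # Persist the upload to disk so cv2.VideoCapture can open it by path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
        tmp.write(uploaded.read())
        video_path = tmp.name

    processor = VideoProcessor()                      # assumed no-argument construction
    frames, counts = processor.process_video(video_path)

    st.write(counts)                                  # e.g. {'curl': 5, 'press': 3, 'squat': 8}
    if frames:
        st.image(frames[-1], channels="BGR", caption="Last annotated frame")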
@@ -288,155 +317,82 @@ class VideoProcessor:
         cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
 
         return output_frame
-
-    # @st.cache()
-    # def process(self, image):
-    #     """
-    #     Function to process the video frame from the user's webcam and run the fitness trainer AI
-
-    #     Args:
-    #         image (numpy array): input image from the webcam
-
-    #     Returns:
-    #         numpy array: processed image with keypoint detection and fitness activity classification visualized
-    #     """
-    #     # Pose detection model
-    #     image.flags.writeable = False
-    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-    #     results = pose.process(image)
-
-    #     # Draw the hand annotations on the image.
-    #     image.flags.writeable = True
-    #     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-    #     self.draw_landmarks(image, results)
-
-    #     …
-    #     self.current_action = self.actions[np.argmax(res)]
-    #     confidence = np.max(res)
-
-    #     # Erase current action variable if no probability is above threshold
-    #     if confidence < self.threshold:
-    #         self.current_action = ''
-
-    #     …
-    #     # Count reps
-    #     try:
-    #         landmarks = results.pose_landmarks.landmark
-    #         self.count_reps(
-    #             image, landmarks, mp_pose)
-    #     except:
-    #         pass
-
-    #     # Display graphical information
-    #     cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-    #     cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-    #     cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-    #     cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-    #                 cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-    #     # return cv2.flip(image, 1)
-    #     return image
-
-    # def recv(self, frame):
-    #     """
-    #     Receive and process video stream from webcam
-
-    #     …
-    #     img = frame.to_ndarray(format="bgr24")
-    #     img = self.process(img)
-    #     return av.VideoFrame.from_ndarray(img, format="bgr24")
-    def process_uploaded_file(self, file):
-        """
-        Function to process an uploaded image or video file and run the fitness trainer AI
-        Args:
-            file (BytesIO): uploaded image or video file
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Initialize an empty list to store processed frames
-        processed_frames = []
-
-        # Check if the uploaded file is a video
-        is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
-
-        if is_video:
-            container = av.open(file)
-            for frame in container.decode(video=0):
-                # Convert the frame to OpenCV format
-                image = frame.to_image().convert("RGB")
-                image = np.array(image)
-
-        …
-
-# Options
-RTC_CONFIGURATION = RTCConfiguration(
-    …
-)
-
-# Streamer
-webrtc_ctx = webrtc_streamer(
-    …
-)
+
+video_processor.process_video_input(threshold1, threshold2, threshold3)
+    # def process_uploaded_file(self, file):
+    #     """
+    #     Function to process an uploaded image or video file and run the fitness trainer AI
+    #     Args:
+    #         file (BytesIO): uploaded image or video file
+    #     Returns:
+    #         numpy array: processed image with keypoint detection and fitness activity classification visualized
+    #     """
+    #     # Initialize an empty list to store processed frames
+    #     processed_frames = []
+
+    #     # Check if the uploaded file is a video
+    #     is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
+
+    #     if is_video:
+    #         container = av.open(file)
+    #         for frame in container.decode(video=0):
+    #             # Convert the frame to OpenCV format
+    #             image = frame.to_image().convert("RGB")
+    #             image = np.array(image)
+
+    #             # Process the frame
+    #             processed_frame = self.process(image)
+
+    #             # Append the processed frame to the list
+    #             processed_frames.append(processed_frame)
+
+    #         # Close the video file container
+    #         container.close()
+    #     else:
+    #         # If the uploaded file is an image
+    #         # Load the image from the BytesIO object
+    #         image = Image.open(file)
+    #         image = np.array(image)
+
+    #         # Process the image
+    #         processed_frame = self.process(image)
+
+    #         # Append the processed frame to the list
+    #         processed_frames.append(processed_frame)
+
+    #     return processed_frames
+
+    # def recv_uploaded_file(self, file):
+    #     """
+    #     Receive and process an uploaded video file
+    #     Args:
+    #         file (BytesIO): uploaded video file
+    #     Returns:
+    #         List[av.VideoFrame]: list of processed video frames
+    #     """
+    #     # Process the uploaded file
+    #     processed_frames = self.process_uploaded_file(file)
+
+    #     # Convert processed frames to av.VideoFrame objects
+    #     av_frames = []
+    #     for frame in processed_frames:
+    #         av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
+    #         av_frames.append(av_frame)
+
+    #     return av_frames
+
+# # Options
+# RTC_CONFIGURATION = RTCConfiguration(
+#     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
+# )
+
+# # Streamer
+# webrtc_ctx = webrtc_streamer(
+#     key="AI trainer",
+#     mode=WebRtcMode.SENDRECV,
+#     rtc_configuration=RTC_CONFIGURATION,
+#     media_stream_constraints={"video": True, "audio": False},
+#     video_processor_factory=VideoProcessor,
+#     async_processing=True,
+# )
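count_reps is referenced throughout this diff (it survives unchanged as context), and its docstring says it updates global count and stage variables per exercise. For readers unfamiliar with the pattern, the sketch below shows generic angle-threshold rep counting with a stage variable; the threshold values and attribute names are illustrative assumptions, not this app's tuned values.

class RepCounterSketch:
    """Illustrative angle-threshold repetition counter (not the app's count_reps)."""

    def __init__(self, down_threshold=160.0, up_threshold=30.0):
        self.stage = None                       # 'down' (joint extended) or 'up' (joint flexed)
        self.count = 0
        self.down_threshold = down_threshold    # joint counts as extended above this angle
        self.up_threshold = up_threshold        # joint counts as flexed below this angle

    def update(self, joint_angle_deg):
        # Count one rep each time the joint goes from extended ('down') to flexed ('up').
        if joint_angle_deg > self.down_threshold:
            self.stage = "down"
        elif joint_angle_deg < self.up_threshold and self.stage == "down":
            self.stage = "up"
            self.count += 1
        return self.count

For a curl, for example, the elbow angle from a helper like joint_angle() above would be fed into update() on every frame.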