randomshit11 committed · verified
Commit 8ba1a89 · 1 Parent(s): 59af79a

Upload 14 files

.dockerignore ADDED
@@ -0,0 +1,14 @@
+ .git
+ .ipynb_checkpoints
+ data
+ logs
+ old
+ research
+ tests
+ LICENSE
+ README.md
+ *.mp4
+ *.png
+ *.h5
+ !models/*.h5
+ .gitignore
.github/workflows/CI.yml ADDED
@@ -0,0 +1,24 @@
+ name: Build Docker image and deploy to Heroku
+ on:
+   # Trigger the workflow on push or pull request,
+   # but only for the main branch
+   push:
+     branches:
+       - main
+ jobs:
+   build:
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v1
+       - name: Login to Heroku Container registry
+         env:
+           HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }}
+         run: heroku container:login
+       - name: Build and push
+         env:
+           HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }}
+         run: heroku container:push -a ai-personal-fitness-trainer web
+       - name: Release
+         env:
+           HEROKU_API_KEY: ${{ secrets.HEROKU_API_KEY }}
+         run: heroku container:release -a ai-personal-fitness-trainer web
.gitignore ADDED
@@ -0,0 +1,11 @@
+ *.egg-info
+ *.pyc
+ data
+ old
+ logs
+ .ipynb_checkpoints
+ *.h5
+ !models/*.h5
+ env
+ *.avi
+ *.mp4
Dockerfile ADDED
@@ -0,0 +1,9 @@
+ FROM python:3.8
+ EXPOSE 8501
+ WORKDIR /app
+ COPY requirements.txt ./requirements.txt
+ RUN apt-get update
+ RUN apt-get install ffmpeg libsm6 libxext6 -y
+ RUN pip3 install -r requirements.txt
+ COPY . .
+ CMD streamlit run --server.port $PORT app.py
ExerciseDecoder.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2022 Chris Prasanna
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md CHANGED
@@ -1,10 +1,77 @@
- ---
- title: Yogsss
- emoji: 🐢
- colorFrom: pink
- colorTo: green
- sdk: docker
- pinned: false
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # :robot::video_camera: Building an AI for Real-Time Exercise Recognition using Computer Vision & Deep Learning :weight_lifting_man::muscle:
+
+ ## Description
+ In this project, I designed an AI that uses webcam footage to accurately detect exercises in real time and count reps. OpenCV is used to access the webcam on your machine, a pretrained CNN is implemented for real-time pose estimation, and custom deep learning models are built using TensorFlow/Keras to recognize which exercise is being performed. In addition, this project includes a guided data collection pipeline that is used to generate training data for the custom deep learning models. Using my data, the LSTM model achieved an accuracy of 97.78% and a categorical cross-entropy loss of 1.51e-3 on the validation dataset. The attention-based LSTM achieved an accuracy of 100% and a categorical cross-entropy loss of 2.08e-5 on the validation dataset. From there, joint angles are extracted from the pose estimation coordinates and heuristics are used to track the exercise and count reps. Finally, visualization tools are included that display joint angles, rep counts, and probability distributions.
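For illustration, the rep-counting heuristic reduces to a planar three-point joint angle computed from adjacent keypoints. Below is a minimal sketch of that computation, mirroring the `calculate_angle` logic in `app.py` later in this commit; the coordinate values are hypothetical, normalized image-frame keypoints:

```
import numpy as np

def joint_angle(a, b, c):
    """Angle (degrees) at keypoint b, formed by keypoints a-b-c, from 2D (x, y) coordinates."""
    a, b, c = np.array(a), np.array(b), np.array(c)
    radians = np.arctan2(c[1] - b[1], c[0] - b[0]) - np.arctan2(a[1] - b[1], a[0] - b[0])
    angle = np.abs(np.degrees(radians))
    return 360.0 - angle if angle > 180.0 else angle

# Hypothetical normalized keypoints for a flexed left arm (image origin is the top-left corner).
shoulder, elbow, wrist = [0.45, 0.30], [0.47, 0.45], [0.46, 0.32]
print(joint_angle(shoulder, elbow, wrist))  # small elbow angle -> "up" stage of a curl
```

In `app.py`, a curl rep is registered when this elbow angle passes below 30 degrees ("up") and then back above 140 degrees ("down"); the press and squat counters follow the same pattern with different joints and thresholds.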
+
+ https://user-images.githubusercontent.com/88418264/176807706-960e19dd-4261-46f6-bdc0-cf8a6077cc82.mp4
+
+ ## Web App
+ I made two web interfaces for using and interacting with the Personal Fitness Trainer AI. The links and descriptions are provided below.
+
+ ### Streamlit + Docker + Heroku + GitHub Actions (CI/CD)
+
+ [Web App Link (Heroku)](https://ai-personal-fitness-trainer.herokuapp.com/)
+
+ I used Streamlit, a Python library designed for people who are not expert web developers, to build an application for using the AI. Streamlit lets you build data science applications without worrying too much about UI design, which is handled by the Streamlit API. I then wrote a Dockerfile that provides the instructions to build a Docker image with the running application. The application was then deployed on the web using Heroku and its Docker Container Registry. Finally, I automated the deployment pipeline using GitHub Actions by designing a workflow that builds the Docker image and pushes it to Heroku's registry whenever changes are pushed to the main branch of this GitHub repository. Essentially, the workflow file automatically performs the same commands that I ran on my local machine: log in to the Heroku container registry, build the Docker image, and deploy it to the web.
+
+ ### Streamlit Cloud
+
+ [Web App Link (Streamlit Cloud)](https://chrisprasanna-exercise-recognition-ai-app-app-enjv7a.streamlitapp.com/)
+
+ I also deployed the AI directly from Streamlit to their cloud. This was quick and easy; however, the biggest downside of Streamlit Cloud deployment is speed: the entire Python script is re-run in the browser every time the user interacts with the application. I included this link for documentation purposes, but I recommend using the link from the previous section.
+
+ ## Installation
+ - Download this repository and move it to your desired working directory
+ - Download Anaconda if you haven't already
+ - Open the Anaconda Prompt
+ - Navigate to your working directory using the cd command
+ - Run the following command in the Anaconda prompt:
+ ```
+ conda env create --name NAME --file environment.yml
+ ```
+ where NAME is the name you choose for this project's conda virtual environment. This environment contains all the package installations and dependencies for this project.
+
+ - Run the following command in the Anaconda prompt:
+ ```
+ conda activate NAME
+ ```
+ This activates the conda environment containing all the required packages and their versions.
+
+ - Open Anaconda Navigator
+ - Under the "Applications On" dropdown menu, select the newly created conda environment
+ - Install and open Jupyter Notebook. NOTE: once you complete this step, and if you're on a Windows device, you can launch the installed version of Jupyter Notebook within the conda environment directly from the Start menu.
+ - Navigate to the ExerciseDecoder.ipynb file within the repository
+
+ ## Features
+
+ - Implementation of Google MediaPipe's BlazePose model for real-time human pose estimation
+ - Computer vision tools (i.e., OpenCV) for color conversion, detecting cameras, detecting camera properties, displaying images, and custom graphics/visualization
+ - Inferred 3D joint angle computation according to the relative coordinates of surrounding body landmarks
+ - Guided training data generation
+ - Data preprocessing and callback methods for efficient deep neural network training
+ - Customizable LSTM and attention-based LSTM models (see the inference sketch after this list)
+ - Real-time visualization of joint angles, rep counters, and probability distribution of exercise classification predictions
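As a rough sketch of how these pieces fit together at inference time (shapes and class names follow `app.py` in this commit; `model` stands for the loaded attention-based LSTM): each frame yields a 132-value keypoint vector (33 landmarks × x, y, z, visibility), and the most recent 30 frames form the sequence fed to the classifier.

```
import numpy as np

ACTIONS = np.array(['curl', 'press', 'squat'])
SEQUENCE_LENGTH = 30      # frames per input sequence
NUM_FEATURES = 33 * 4     # 33 landmarks x (x, y, z, visibility)

sequence = []             # rolling buffer of per-frame keypoint vectors

def classify_frame(keypoints, model, threshold=0.5):
    """Append one frame's keypoints; once the buffer is full, predict the current exercise."""
    assert keypoints.shape == (NUM_FEATURES,)
    sequence.append(keypoints.astype('float32'))
    del sequence[:-SEQUENCE_LENGTH]        # keep only the most recent 30 frames
    if len(sequence) < SEQUENCE_LENGTH:
        return None
    probs = model.predict(np.expand_dims(sequence, axis=0), verbose=0)[0]
    return ACTIONS[np.argmax(probs)] if np.max(probs) >= threshold else None
```

Predictions below the confidence threshold are discarded, which is also how the web app decides when to blank out the current action.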
+
+ ## To-Do
+
+ * Higher Priority
+   - [x] Add precision-recall analysis
+   - [x] Deploy the AI and build a web app
+   - [x] Build a Docker Image
+   - [x] Build a CI/CD workflow
+   - [ ] Train networks using angular joint kinematics rather than xyz coordinates
+   - [ ] Translate the AI to a portable embedded system that you can take outdoors or to a commercial gym. Components may include a microcontroller (e.g., Raspberry Pi), external USB camera, LED screen, battery, and 3D-printed case
+ * Back-burner
+   - [ ] Add AI features that can detect poor form (e.g., leaning, fast eccentric motion, knees caving in, poor squat depth, etc.) and offer real-time advice/feedback
+   - [ ] Optimize hyperparameters based on minimizing training time and cross-entropy loss on the validation dataset
+   - [ ] Add more exercise classes
+   - [ ] Add additional models. For instance, even though BlazePose is a type of CNN, there may be benefits to including convolutional layers within the custom deep learning models
+
+ ## Credits
+
+ - [MediaPipe Pose](https://google.github.io/mediapipe/solutions/pose.html) for the pretrained human pose estimation model
+ - [Nicholas Renotte](https://github.com/nicknochnack) for tutorials on real-time action detection and pose estimation
+ - [Philippe Rémy](https://github.com/philipperemy/keras-attention-mechanism) for the attention mechanism implementation for Keras
+
+ ## License
+ [MIT](https://github.com/chrisprasanna/Exercise_Recognition_AI/blob/main/LICENSE)
app.py ADDED
@@ -0,0 +1,378 @@
1
+ import streamlit as st
2
+ import cv2
3
+
4
+ from tensorflow.keras.models import Model
5
+ from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
6
+ Bidirectional, Permute, multiply)
7
+
8
+ import numpy as np
9
+ import mediapipe as mp
10
+ import math
11
+
12
+ from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
13
+ import av
14
+
15
+ ## Build and Load Model
16
+ def attention_block(inputs, time_steps):
17
+ """
18
+ Attention layer for deep neural network
19
+
20
+ """
21
+ # Attention weights
22
+ a = Permute((2, 1))(inputs)
23
+ a = Dense(time_steps, activation='softmax')(a)
24
+
25
+ # Attention vector
26
+ a_probs = Permute((2, 1), name='attention_vec')(a)
27
+
28
+ # Luong's multiplicative score
29
+ output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
30
+
31
+ return output_attention_mul
32
+
33
+ @st.cache(allow_output_mutation=True)
34
+ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
35
+ """
36
+ Function used to build the deep neural network model on startup
37
+
38
+ Args:
39
+ HIDDEN_UNITS (int, optional): Number of hidden units for each neural network hidden layer. Defaults to 256.
40
+ sequence_length (int, optional): Input sequence length (i.e., number of frames). Defaults to 30.
41
+ num_input_values (_type_, optional): Input size of the neural network model. Defaults to 33*4 (i.e., number of keypoints x number of metrics).
42
+ num_classes (int, optional): Number of classification categories (i.e., model output size). Defaults to 3.
43
+
44
+ Returns:
45
+ keras model: neural network with pre-trained weights
46
+ """
47
+ # Input
48
+ inputs = Input(shape=(sequence_length, num_input_values))
49
+ # Bi-LSTM
50
+ lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
51
+ # Attention
52
+ attention_mul = attention_block(lstm_out, sequence_length)
53
+ attention_mul = Flatten()(attention_mul)
54
+ # Fully Connected Layer
55
+ x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)
56
+ x = Dropout(0.5)(x)
57
+ # Output
58
+ x = Dense(num_classes, activation='softmax')(x)
59
+ # Bring it all together
60
+ model = Model(inputs=[inputs], outputs=x)
61
+
62
+ ## Load Model Weights
63
+ load_dir = "./models/LSTM_Attention.h5"
64
+ model.load_weights(load_dir)
65
+
66
+ return model
67
+
68
+ HIDDEN_UNITS = 256
69
+ model = build_model(HIDDEN_UNITS)
70
+
71
+ ## App
72
+ st.write("# AI Personal Fitness Trainer Web App")
73
+
74
+ st.markdown("❗❗ **Development Note** ❗❗")
75
+ st.markdown("Currently, the exercise recognition model uses the the x, y, and z coordinates of each anatomical landmark from the MediaPipe Pose model. These coordinates are normalized with respect to the image frame (e.g., the top left corner represents (x=0,y=0) and the bottom right corner represents(x=1,y=1)).")
76
+ st.markdown("I'm currently developing and testing two new feature engineering strategies:")
77
+ st.markdown("- Normalizing coordinates by the detected bounding box of the user")
78
+ st.markdown("- Using joint angles rather than keypoint coordaintes as features")
79
+ st.write("Stay Tuned!")
80
+
81
+ st.write("## Settings")
82
+ threshold1 = st.slider("Minimum Keypoint Detection Confidence", 0.00, 1.00, 0.50)
83
+ threshold2 = st.slider("Minimum Tracking Confidence", 0.00, 1.00, 0.50)
84
+ threshold3 = st.slider("Minimum Activity Classification Confidence", 0.00, 1.00, 0.50)
85
+
86
+ st.write("## Activate the AI 🤖🏋️‍♂️")
87
+
88
+ ## Mediapipe
89
+ mp_pose = mp.solutions.pose # Pre-trained pose estimation model from Google Mediapipe
90
+ mp_drawing = mp.solutions.drawing_utils # Supported Mediapipe visualization tools
91
+ pose = mp_pose.Pose(min_detection_confidence=threshold1, min_tracking_confidence=threshold2) # mediapipe pose model
92
+
93
+ ## Real Time Machine Learning and Computer Vision Processes
94
+ class VideoProcessor:
95
+ def __init__(self):
96
+ # Parameters
97
+ self.actions = np.array(['curl', 'press', 'squat'])
98
+ self.sequence_length = 30
99
+ self.colors = [(245,117,16), (117,245,16), (16,117,245)]
100
+ self.threshold = threshold3
101
+
102
+ # Detection variables
103
+ self.sequence = []
104
+ self.current_action = ''
105
+
106
+ # Rep counter logic variables
107
+ self.curl_counter = 0
108
+ self.press_counter = 0
109
+ self.squat_counter = 0
110
+ self.curl_stage = None
111
+ self.press_stage = None
112
+ self.squat_stage = None
113
+
114
+ @st.cache()
115
+ def draw_landmarks(self, image, results):
116
+ """
117
+ This function draws keypoints and landmarks detected by the human pose estimation model
118
+
119
+ """
120
+ mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
121
+ mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
122
+ mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
123
+ )
124
+ return
125
+
126
+ @st.cache()
127
+ def extract_keypoints(self, results):
128
+ """
129
+ Processes and organizes the keypoints detected from the pose estimation model
130
+ to be used as inputs for the exercise decoder models
131
+
132
+ """
133
+ pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
134
+ return pose
135
+
136
+ @st.cache()
137
+ def calculate_angle(self, a,b,c):
138
+ """
139
+ Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
140
+
141
+ """
142
+ a = np.array(a) # First
143
+ b = np.array(b) # Mid
144
+ c = np.array(c) # End
145
+
146
+ radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
147
+ angle = np.abs(radians*180.0/np.pi)
148
+
149
+ if angle > 180.0:
150
+ angle = 360-angle
151
+
152
+ return angle
153
+
154
+ @st.cache()
155
+ def get_coordinates(self, landmarks, mp_pose, side, joint):
156
+ """
157
+ Retrieves x and y coordinates of a particular keypoint from the pose estimation model
158
+
159
+ Args:
160
+ landmarks: processed keypoints from the pose estimation model
161
+ mp_pose: Mediapipe pose estimation model
162
+ side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
163
+ joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
164
+
165
+ """
166
+ coord = getattr(mp_pose.PoseLandmark,side.upper()+"_"+joint.upper())
167
+ x_coord_val = landmarks[coord.value].x
168
+ y_coord_val = landmarks[coord.value].y
169
+ return [x_coord_val, y_coord_val]
170
+
171
+ @st.cache()
172
+ def viz_joint_angle(self, image, angle, joint):
173
+ """
174
+ Displays the joint angle value near the joint within the image frame
175
+
176
+ """
177
+ cv2.putText(image, str(int(angle)),
178
+ tuple(np.multiply(joint, [640, 480]).astype(int)),
179
+ cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
180
+ )
181
+ return
182
+
183
+ @st.cache()
184
+ def count_reps(self, image, landmarks, mp_pose):
185
+ """
186
+ Counts repetitions of each exercise. The instance's count and stage (i.e., state) variables are updated within this function.
187
+
188
+ """
189
+
190
+ if self.current_action == 'curl':
191
+ # Get coords
192
+ shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
193
+ elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
194
+ wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
195
+
196
+ # calculate elbow angle
197
+ angle = self.calculate_angle(shoulder, elbow, wrist)
198
+
199
+ # curl counter logic
200
+ if angle < 30:
201
+ self.curl_stage = "up"
202
+ if angle > 140 and self.curl_stage =='up':
203
+ self.curl_stage="down"
204
+ self.curl_counter +=1
205
+ self.press_stage = None
206
+ self.squat_stage = None
207
+
208
+ # Viz joint angle
209
+ self.viz_joint_angle(image, angle, elbow)
210
+
211
+ elif self.current_action == 'press':
212
+ # Get coords
213
+ shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
214
+ elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
215
+ wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
216
+
217
+ # Calculate elbow angle
218
+ elbow_angle = self.calculate_angle(shoulder, elbow, wrist)
219
+
220
+ # Compute distances between joints
221
+ shoulder2elbow_dist = abs(math.dist(shoulder,elbow))
222
+ shoulder2wrist_dist = abs(math.dist(shoulder,wrist))
223
+
224
+ # Press counter logic
225
+ if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):
226
+ self.press_stage = "up"
227
+ if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (self.press_stage =='up'):
228
+ self.press_stage='down'
229
+ self.press_counter += 1
230
+ self.curl_stage = None
231
+ self.squat_stage = None
232
+
233
+ # Viz joint angle
234
+ self.viz_joint_angle(image, elbow_angle, elbow)
235
+
236
+ elif self.current_action == 'squat':
237
+ # Get coords
238
+ # left side
239
+ left_shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
240
+ left_hip = self.get_coordinates(landmarks, mp_pose, 'left', 'hip')
241
+ left_knee = self.get_coordinates(landmarks, mp_pose, 'left', 'knee')
242
+ left_ankle = self.get_coordinates(landmarks, mp_pose, 'left', 'ankle')
243
+ # right side
244
+ right_shoulder = self.get_coordinates(landmarks, mp_pose, 'right', 'shoulder')
245
+ right_hip = self.get_coordinates(landmarks, mp_pose, 'right', 'hip')
246
+ right_knee = self.get_coordinates(landmarks, mp_pose, 'right', 'knee')
247
+ right_ankle = self.get_coordinates(landmarks, mp_pose, 'right', 'ankle')
248
+
249
+ # Calculate knee angles
250
+ left_knee_angle = self.calculate_angle(left_hip, left_knee, left_ankle)
251
+ right_knee_angle = self.calculate_angle(right_hip, right_knee, right_ankle)
252
+
253
+ # Calculate hip angles
254
+ left_hip_angle = self.calculate_angle(left_shoulder, left_hip, left_knee)
255
+ right_hip_angle = self.calculate_angle(right_shoulder, right_hip, right_knee)
256
+
257
+ # Squat counter logic
258
+ thr = 165
259
+ if (left_knee_angle < thr) and (right_knee_angle < thr) and (left_hip_angle < thr) and (right_hip_angle < thr):
260
+ self.squat_stage = "down"
261
+ if (left_knee_angle > thr) and (right_knee_angle > thr) and (left_hip_angle > thr) and (right_hip_angle > thr) and (self.squat_stage =='down'):
262
+ self.squat_stage='up'
263
+ self.squat_counter += 1
264
+ self.curl_stage = None
265
+ self.press_stage = None
266
+
267
+ # Viz joint angles
268
+ self.viz_joint_angle(image, left_knee_angle, left_knee)
269
+ self.viz_joint_angle(image, left_hip_angle, left_hip)
270
+
271
+ else:
272
+ pass
273
+ return
274
+
275
+ @st.cache()
276
+ def prob_viz(self, res, input_frame):
277
+ """
278
+ This function displays the model prediction probability distribution over the set of exercise classes
279
+ as a horizontal bar graph
280
+
281
+ """
282
+ output_frame = input_frame.copy()
283
+ for num, prob in enumerate(res):
284
+ cv2.rectangle(output_frame, (0,60+num*40), (int(prob*100), 90+num*40), self.colors[num], -1)
285
+ cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
286
+
287
+ return output_frame
288
+
289
+ @st.cache()
290
+ def process(self, image):
291
+ """
292
+ Function to process the video frame from the user's webcam and run the fitness trainer AI
293
+
294
+ Args:
295
+ image (numpy array): input image from the webcam
296
+
297
+ Returns:
298
+ numpy array: processed image with keypoint detection and fitness activity classification visualized
299
+ """
300
+ # Pose detection model
301
+ image.flags.writeable = False
302
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
303
+ results = pose.process(image)
304
+
305
+ # Draw the pose annotations on the image.
306
+ image.flags.writeable = True
307
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
308
+ self.draw_landmarks(image, results)
309
+
310
+ # Prediction logic
311
+ keypoints = self.extract_keypoints(results)
312
+ self.sequence.append(keypoints.astype('float32',casting='same_kind'))
313
+ self.sequence = self.sequence[-self.sequence_length:]
314
+
315
+ if len(self.sequence) == self.sequence_length:
316
+ res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
317
+ # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
318
+ # interpreter.invoke()
319
+ # res = interpreter.get_tensor(self.output_details[0]['index'])
320
+
321
+ self.current_action = self.actions[np.argmax(res)]
322
+ confidence = np.max(res)
323
+
324
+ # Erase current action variable if no probability is above threshold
325
+ if confidence < self.threshold:
326
+ self.current_action = ''
327
+
328
+ # Viz probabilities
329
+ image = self.prob_viz(res, image)
330
+
331
+ # Count reps
332
+ try:
333
+ landmarks = results.pose_landmarks.landmark
334
+ self.count_reps(
335
+ image, landmarks, mp_pose)
336
+ except:
337
+ pass
338
+
339
+ # Display graphical information
340
+ cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
341
+ cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
342
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
343
+ cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
344
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
345
+ cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
346
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
347
+
348
+ # return cv2.flip(image, 1)
349
+ return image
350
+
351
+ def recv(self, frame):
352
+ """
353
+ Receive and process video stream from webcam
354
+
355
+ Args:
356
+ frame: current video frame
357
+
358
+ Returns:
359
+ av.VideoFrame: processed video frame
360
+ """
361
+ img = frame.to_ndarray(format="bgr24")
362
+ img = self.process(img)
363
+ return av.VideoFrame.from_ndarray(img, format="bgr24")
364
+
365
+ ## Stream Webcam Video and Run Model
366
+ # Options
367
+ RTC_CONFIGURATION = RTCConfiguration(
368
+ {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
369
+ )
370
+ # Streamer
371
+ webrtc_ctx = webrtc_streamer(
372
+ key="AI trainer",
373
+ mode=WebRtcMode.SENDRECV,
374
+ rtc_configuration=RTC_CONFIGURATION,
375
+ media_stream_constraints={"video": True, "audio": False},
376
+ video_processor_factory=VideoProcessor,
377
+ async_processing=True,
378
+ )
environment.yml ADDED
@@ -0,0 +1,266 @@
1
+ name: AItrainer
2
+ channels:
3
+ - conda-forge
4
+ - anaconda
5
+ - soumith
6
+ - defaults
7
+ dependencies:
8
+ - _tflow_select=2.3.0=gpu
9
+ - aiohttp=3.8.1=py38h294d835_1
10
+ - aiosignal=1.2.0=pyhd8ed1ab_0
11
+ - alabaster=0.7.12=pyhd3eb1b0_0
12
+ - appdirs=1.4.4=pyhd3eb1b0_0
13
+ - argon2-cffi=21.3.0=pyhd3eb1b0_0
14
+ - argon2-cffi-bindings=21.2.0=py38h2bbff1b_0
15
+ - arrow=1.2.2=pyhd3eb1b0_0
16
+ - astor=0.8.1=pyh9f0ad1d_0
17
+ - astroid=2.9.0=py38haa95532_0
18
+ - asttokens=2.0.5=pyhd3eb1b0_0
19
+ - astunparse=1.6.3=pyhd8ed1ab_0
20
+ - async-timeout=4.0.2=pyhd8ed1ab_0
21
+ - atomicwrites=1.4.0=py_0
22
+ - attrs=21.4.0=pyhd3eb1b0_0
23
+ - autopep8=1.5.6=pyhd3eb1b0_0
24
+ - babel=2.9.1=pyhd3eb1b0_0
25
+ - backcall=0.2.0=pyhd3eb1b0_0
26
+ - bcrypt=3.2.0=py38he774522_0
27
+ - beautifulsoup4=4.11.1=py38haa95532_0
28
+ - binaryornot=0.4.4=pyhd3eb1b0_1
29
+ - black=19.10b0=py_0
30
+ - blas=1.0=mkl
31
+ - bleach=4.1.0=pyhd3eb1b0_0
32
+ - blinker=1.4=py_1
33
+ - brotlipy=0.7.0=py38h2bbff1b_1003
34
+ - ca-certificates=2022.6.15=h5b45459_0
35
+ - cachetools=5.0.0=pyhd8ed1ab_0
36
+ - certifi=2022.6.15=py38haa244fe_0
37
+ - cffi=1.15.0=py38h2bbff1b_1
38
+ - chardet=4.0.0=py38haa95532_1003
39
+ - charset-normalizer=2.0.4=pyhd3eb1b0_0
40
+ - click=8.0.4=py38haa95532_0
41
+ - cloudpickle=2.0.0=pyhd3eb1b0_0
42
+ - colorama=0.4.4=pyhd3eb1b0_0
43
+ - cookiecutter=1.7.3=pyhd3eb1b0_0
44
+ - cryptography=37.0.1=py38h21b164f_0
45
+ - debugpy=1.5.1=py38hd77b12b_0
46
+ - decorator=5.1.1=pyhd3eb1b0_0
47
+ - defusedxml=0.7.1=pyhd3eb1b0_0
48
+ - diff-match-patch=20200713=pyhd3eb1b0_0
49
+ - docutils=0.17.1=py38haa95532_1
50
+ - eigen=3.3.7=h59b6b97_1
51
+ - entrypoints=0.4=py38haa95532_0
52
+ - executing=0.8.3=pyhd3eb1b0_0
53
+ - flake8=3.9.0=pyhd3eb1b0_0
54
+ - frozenlist=1.3.0=py38h294d835_1
55
+ - future=0.18.2=py38_1
56
+ - gast=0.4.0=pyh9f0ad1d_0
57
+ - glib=2.69.1=h5dc1a3c_1
58
+ - google-auth=2.8.0=pyh6c4a22f_0
59
+ - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
60
+ - google-pasta=0.2.0=pyh8c360ce_0
61
+ - gst-plugins-base=1.18.5=h9e645db_0
62
+ - gstreamer=1.18.5=hd78058f_0
63
+ - h5py=2.10.0=py38h5e291fa_0
64
+ - hdf5=1.10.4=h7ebc959_0
65
+ - icc_rt=2019.0.0=h0cc432a_1
66
+ - icu=58.2=ha925a31_3
67
+ - idna=3.3=pyhd3eb1b0_0
68
+ - imagesize=1.3.0=pyhd3eb1b0_0
69
+ - importlib-metadata=4.11.3=py38haa95532_0
70
+ - importlib_metadata=4.11.3=hd3eb1b0_0
71
+ - importlib_resources=5.2.0=pyhd3eb1b0_1
72
+ - inflection=0.5.1=py38haa95532_0
73
+ - intel-openmp=2021.4.0=haa95532_3556
74
+ - intervaltree=3.1.0=pyhd3eb1b0_0
75
+ - ipykernel=6.9.1=py38haa95532_0
76
+ - ipython=8.3.0=py38haa95532_0
77
+ - ipython_genutils=0.2.0=pyhd3eb1b0_1
78
+ - isort=5.9.3=pyhd3eb1b0_0
79
+ - jedi=0.17.2=py38haa95532_1
80
+ - jinja2=3.0.3=pyhd3eb1b0_0
81
+ - jinja2-time=0.2.0=pyhd3eb1b0_3
82
+ - joblib=1.1.0=pyhd3eb1b0_0
83
+ - jpeg=9e=h2bbff1b_0
84
+ - jsonschema=4.4.0=py38haa95532_0
85
+ - jupyter_client=7.2.2=py38haa95532_0
86
+ - jupyter_core=4.10.0=py38haa95532_0
87
+ - jupyterlab_pygments=0.1.2=py_0
88
+ - keras-applications=1.0.8=py_1
89
+ - keras-preprocessing=1.1.2=pyhd8ed1ab_0
90
+ - keyring=23.4.0=py38haa95532_0
91
+ - lazy-object-proxy=1.6.0=py38h2bbff1b_0
92
+ - libblas=3.9.0=1_h8933c1f_netlib
93
+ - libcblas=3.9.0=5_hd5c7e75_netlib
94
+ - libffi=3.4.2=hd77b12b_4
95
+ - libiconv=1.16=h2bbff1b_2
96
+ - liblapack=3.9.0=5_hd5c7e75_netlib
97
+ - libogg=1.3.5=h2bbff1b_1
98
+ - libopencv=4.0.1=hbb9e17c_0
99
+ - libpng=1.6.37=h2a8f88b_0
100
+ - libprotobuf=3.20.1=h23ce68f_0
101
+ - libspatialindex=1.9.3=h6c2663c_0
102
+ - libtiff=4.2.0=he0120a3_1
103
+ - libvorbis=1.3.7=he774522_0
104
+ - libwebp-base=1.2.2=h2bbff1b_0
105
+ - lz4-c=1.9.3=h2bbff1b_1
106
+ - m2w64-gcc-libgfortran=5.3.0=6
107
+ - m2w64-gcc-libs=5.3.0=7
108
+ - m2w64-gcc-libs-core=5.3.0=7
109
+ - m2w64-gmp=6.1.0=2
110
+ - m2w64-libwinpthread-git=5.0.0.4634.697f757=2
111
+ - markdown=3.3.7=pyhd8ed1ab_0
112
+ - markupsafe=2.1.1=py38h2bbff1b_0
113
+ - matplotlib-inline=0.1.2=pyhd3eb1b0_2
114
+ - mccabe=0.6.1=py38_1
115
+ - mistune=0.8.4=py38he774522_1000
116
+ - msys2-conda-epoch=20160418=1
117
+ - multidict=6.0.2=py38h294d835_1
118
+ - mypy_extensions=0.4.3=py38haa95532_1
119
+ - nbclient=0.5.13=py38haa95532_0
120
+ - nbconvert=6.4.4=py38haa95532_0
121
+ - nbformat=5.3.0=py38haa95532_0
122
+ - nest-asyncio=1.5.5=py38haa95532_0
123
+ - nomkl=1.0=h5ca1d4c_0
124
+ - notebook=6.4.11=py38haa95532_0
125
+ - numpydoc=1.2=pyhd3eb1b0_0
126
+ - oauthlib=3.2.0=pyhd8ed1ab_0
127
+ - opencv=4.0.1=py38h2a7c758_0
128
+ - openssl=1.1.1p=h8ffe710_0
129
+ - opt_einsum=3.3.0=pyhd8ed1ab_1
130
+ - packaging=21.3=pyhd3eb1b0_0
131
+ - pandas=1.2.4=py38hf11a4ad_0
132
+ - pandocfilters=1.5.0=pyhd3eb1b0_0
133
+ - paramiko=2.8.1=pyhd3eb1b0_0
134
+ - parso=0.7.0=py_0
135
+ - pathspec=0.7.0=py_0
136
+ - pcre=8.45=hd77b12b_0
137
+ - pexpect=4.8.0=pyhd3eb1b0_3
138
+ - pickleshare=0.7.5=pyhd3eb1b0_1003
139
+ - pip=21.2.2=py38haa95532_0
140
+ - platformdirs=2.4.0=pyhd3eb1b0_0
141
+ - pluggy=1.0.0=py38haa95532_1
142
+ - poyo=0.5.0=pyhd3eb1b0_0
143
+ - prometheus_client=0.13.1=pyhd3eb1b0_0
144
+ - prompt-toolkit=3.0.20=pyhd3eb1b0_0
145
+ - psutil=5.8.0=py38h2bbff1b_1
146
+ - ptyprocess=0.7.0=pyhd3eb1b0_2
147
+ - pure_eval=0.2.2=pyhd3eb1b0_0
148
+ - py-opencv=4.0.1=py38he44ac1e_0
149
+ - pyasn1=0.4.8=py_0
150
+ - pyasn1-modules=0.2.7=py_0
151
+ - pycodestyle=2.6.0=pyhd3eb1b0_0
152
+ - pycparser=2.21=pyhd3eb1b0_0
153
+ - pydocstyle=6.1.1=pyhd3eb1b0_0
154
+ - pyflakes=2.2.0=pyhd3eb1b0_0
155
+ - pygments=2.11.2=pyhd3eb1b0_0
156
+ - pyjwt=2.4.0=pyhd8ed1ab_0
157
+ - pylint=2.12.2=py38haa95532_1
158
+ - pyls-black=0.4.6=hd3eb1b0_0
159
+ - pyls-spyder=0.3.2=pyhd3eb1b0_0
160
+ - pynacl=1.4.0=py38h62dcd97_1
161
+ - pyopenssl=22.0.0=pyhd3eb1b0_0
162
+ - pyqt=5.9.2=py38hd77b12b_6
163
+ - pyreadline=2.1=py38haa244fe_1005
164
+ - pyrsistent=0.18.0=py38h196d8e1_0
165
+ - pysocks=1.7.1=py38haa95532_0
166
+ - python=3.8.13=h6244533_0
167
+ - python-dateutil=2.8.2=pyhd3eb1b0_0
168
+ - python-fastjsonschema=2.15.1=pyhd3eb1b0_0
169
+ - python-jsonrpc-server=0.4.0=py_0
170
+ - python-language-server=0.36.2=pyhd3eb1b0_0
171
+ - python-slugify=5.0.2=pyhd3eb1b0_0
172
+ - python_abi=3.8=2_cp38
173
+ - pytz=2022.1=py38haa95532_0
174
+ - pyu2f=0.1.5=pyhd8ed1ab_0
175
+ - pywin32=302=py38h2bbff1b_2
176
+ - pywin32-ctypes=0.2.0=py38_1000
177
+ - pywinpty=2.0.2=py38h5da7b33_0
178
+ - pyyaml=6.0=py38h2bbff1b_1
179
+ - pyzmq=22.3.0=py38hd77b12b_2
180
+ - qdarkstyle=3.0.2=pyhd3eb1b0_0
181
+ - qstylizer=0.1.10=pyhd3eb1b0_0
182
+ - qt=5.9.7=vc14h73c81de_0
183
+ - qtawesome=1.0.3=pyhd3eb1b0_0
184
+ - qtconsole=5.3.0=pyhd3eb1b0_0
185
+ - qtpy=2.0.1=pyhd3eb1b0_0
186
+ - regex=2022.3.15=py38h2bbff1b_0
187
+ - requests=2.27.1=pyhd3eb1b0_0
188
+ - requests-oauthlib=1.3.1=pyhd8ed1ab_0
189
+ - rope=0.22.0=pyhd3eb1b0_0
190
+ - rsa=4.8=pyhd8ed1ab_0
191
+ - rtree=0.9.7=py38h2eaa2aa_1
192
+ - scikit-learn=1.0.2=py38hf11a4ad_1
193
+ - scipy=1.5.3=py38h5f893b4_0
194
+ - send2trash=1.8.0=pyhd3eb1b0_1
195
+ - setuptools=61.2.0=py38haa95532_0
196
+ - sip=4.19.13=py38hd77b12b_0
197
+ - snowballstemmer=2.2.0=pyhd3eb1b0_0
198
+ - sortedcontainers=2.4.0=pyhd3eb1b0_0
199
+ - soupsieve=2.3.1=pyhd3eb1b0_0
200
+ - sphinx=4.4.0=pyhd3eb1b0_0
201
+ - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
202
+ - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
203
+ - sphinxcontrib-htmlhelp=2.0.0=pyhd3eb1b0_0
204
+ - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
205
+ - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
206
+ - sphinxcontrib-serializinghtml=1.1.5=pyhd3eb1b0_0
207
+ - spyder=5.0.5=py38haa95532_2
208
+ - spyder-kernels=2.0.5=py38haa95532_0
209
+ - sqlite=3.38.3=h2bbff1b_0
210
+ - stack_data=0.2.0=pyhd3eb1b0_0
211
+ - tensorboard=2.9.0=pyhd8ed1ab_0
212
+ - tensorboard-data-server=0.6.0=py38haa244fe_2
213
+ - tensorboard-plugin-wit=1.8.1=pyhd8ed1ab_0
214
+ - tensorflow=2.3.0=mkl_py38h8557ec7_0
215
+ - tensorflow-base=2.3.0=eigen_py38h75a453f_0
216
+ - tensorflow-estimator=2.5.0=pyh8a188c0_0
217
+ - tensorflow-gpu=2.3.0=he13fc11_0
218
+ - termcolor=1.1.0=py_2
219
+ - terminado=0.13.1=py38haa95532_0
220
+ - testpath=0.5.0=pyhd3eb1b0_0
221
+ - text-unidecode=1.3=pyhd3eb1b0_0
222
+ - textdistance=4.2.1=pyhd3eb1b0_0
223
+ - threadpoolctl=2.2.0=pyh0d69192_0
224
+ - three-merge=0.1.1=pyhd3eb1b0_0
225
+ - tinycss=0.4=pyhd3eb1b0_1002
226
+ - toml=0.10.2=pyhd3eb1b0_0
227
+ - tornado=6.1=py38h2bbff1b_0
228
+ - traitlets=5.1.1=pyhd3eb1b0_0
229
+ - typed-ast=1.4.3=py38h2bbff1b_1
230
+ - ujson=5.1.0=py38hd77b12b_0
231
+ - unidecode=1.2.0=pyhd3eb1b0_0
232
+ - urllib3=1.26.9=py38haa95532_0
233
+ - vc=14.2=h21ff451_1
234
+ - vs2015_runtime=14.27.29016=h5e58377_2
235
+ - watchdog=2.1.6=py38haa95532_0
236
+ - wcwidth=0.2.5=pyhd3eb1b0_0
237
+ - webencodings=0.5.1=py38_1
238
+ - werkzeug=2.1.2=pyhd8ed1ab_1
239
+ - wheel=0.37.1=pyhd3eb1b0_0
240
+ - win_inet_pton=1.1.0=py38haa95532_0
241
+ - wincertstore=0.2=py38haa95532_2
242
+ - winpty=0.4.3=4
243
+ - xz=5.2.5=h8cc25b3_1
244
+ - yaml=0.2.5=he774522_0
245
+ - yapf=0.31.0=pyhd3eb1b0_0
246
+ - yarl=1.7.2=py38h294d835_2
247
+ - zipp=3.8.0=py38haa95532_0
248
+ - zlib=1.2.12=h8cc25b3_2
249
+ - zstd=1.5.2=h19a0ad4_0
250
+ - pip:
251
+ - absl-py==0.15.0
252
+ - cycler==0.11.0
253
+ - fonttools==4.33.3
254
+ - grpcio==1.32.0
255
+ - kiwisolver==1.4.3
256
+ - matplotlib==3.5.2
257
+ - mediapipe==0.8.10
258
+ - numpy==1.23.0
259
+ - opencv-contrib-python==4.6.0.66
260
+ - pillow==9.1.1
261
+ - protobuf==4.21.1
262
+ - pyparsing==3.0.9
263
+ - six==1.15.0
264
+ - typing-extensions==3.7.4.3
265
+ - wrapt==1.12.1
266
+ prefix: C:\Users\cpras\anaconda3\envs\AItrainer
models/LSTM.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6778664cec93d5e917b064af44e03dfb4344c3a779d453106d68c1d3ea00e560
+ size 9069616
models/LSTM_Attention.h5 ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2395d5eb371bb8221e2cacb7c98dbc336de6775bd2607747f4e1f72d0fa4e915
+ size 104036816
pose_tracking_full_body_landmarks.png ADDED
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ streamlit
+ streamlit_webrtc
+ keras==2.9.0
+ notebook==6.4.11
+ numpy==1.23.0
+ Markdown==3.3.7
+ ipykernel==6.9.1
+ ipython==8.3.0
+ mediapipe==0.8.10
+ pillow==9.1.1
+ opencv-python==4.6.0.66
+ opencv-contrib-python==4.6.0.66
+ tensorflow
tests/feature_engineering.ipynb ADDED
@@ -0,0 +1,295 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 242,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import cv2\n",
10
+ "import numpy as np\n",
11
+ "import os\n",
12
+ "from matplotlib import pyplot as plt\n",
13
+ "import time\n",
14
+ "import mediapipe as mp\n"
15
+ ]
16
+ },
17
+ {
18
+ "cell_type": "code",
19
+ "execution_count": 243,
20
+ "metadata": {},
21
+ "outputs": [],
22
+ "source": [
23
+ "# Pre-trained pose estimation model from Google Mediapipe\n",
24
+ "mp_pose = mp.solutions.pose\n",
25
+ "\n",
26
+ "# Supported Mediapipe visualization tools\n",
27
+ "mp_drawing = mp.solutions.drawing_utils"
28
+ ]
29
+ },
30
+ {
31
+ "cell_type": "code",
32
+ "execution_count": 244,
33
+ "metadata": {},
34
+ "outputs": [],
35
+ "source": [
36
+ "def mediapipe_detection(image, model):\n",
37
+ " \"\"\"\n",
38
+ " This function detects human pose estimation keypoints from webcam footage\n",
39
+ " \n",
40
+ " \"\"\"\n",
41
+ " image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) # COLOR CONVERSION BGR 2 RGB\n",
42
+ " image.flags.writeable = False # Image is no longer writeable\n",
43
+ " results = model.process(image) # Make prediction\n",
44
+ " image.flags.writeable = True # Image is now writeable \n",
45
+ " image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) # COLOR COVERSION RGB 2 BGR\n",
46
+ " return image, results"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "execution_count": 245,
52
+ "metadata": {},
53
+ "outputs": [],
54
+ "source": [
55
+ "def draw_landmarks(image, results):\n",
56
+ " \"\"\"\n",
57
+ " This function draws keypoints and landmarks detected by the human pose estimation model\n",
58
+ " \n",
59
+ " \"\"\"\n",
60
+ " mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,\n",
61
+ " mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2), \n",
62
+ " mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2) \n",
63
+ " )"
64
+ ]
65
+ },
66
+ {
67
+ "cell_type": "code",
68
+ "execution_count": 246,
69
+ "metadata": {},
70
+ "outputs": [],
71
+ "source": [
72
+ "def draw_detection(image, results):\n",
73
+ "\n",
74
+ " h, w, c = image.shape\n",
75
+ " cx_min = w\n",
76
+ " cy_min = h\n",
77
+ " cx_max = cy_max = 0\n",
78
+ " center = [w//2, h//2]\n",
79
+ " try:\n",
80
+ " for id, lm in enumerate(results.pose_landmarks.landmark):\n",
81
+ " cx, cy = int(lm.x * w), int(lm.y * h)\n",
82
+ " if cx < cx_min:\n",
83
+ " cx_min = cx\n",
84
+ " if cy < cy_min:\n",
85
+ " cy_min = cy\n",
86
+ " if cx > cx_max:\n",
87
+ " cx_max = cx\n",
88
+ " if cy > cy_max:\n",
89
+ " cy_max = cy\n",
90
+ " \n",
91
+ " boxW, boxH = cx_max - cx_min, cy_max - cy_min\n",
92
+ " \n",
93
+ " # center\n",
94
+ " cx, cy = cx_min + (boxW // 2), \\\n",
95
+ " cy_min + (boxH // 2) \n",
96
+ " center = [cx, cy]\n",
97
+ " \n",
98
+ " cv2.rectangle(\n",
99
+ " image, (cx_min, cy_min), (cx_max, cy_max), (255, 255, 0), 2\n",
100
+ " )\n",
101
+ " except:\n",
102
+ " pass\n",
103
+ " \n",
104
+ " return [[cx_min, cy_min], [cx_max, cy_max]], center"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": 247,
110
+ "metadata": {},
111
+ "outputs": [],
112
+ "source": [
113
+ "def normalize(image, results, bounding_box, landmark_names):\n",
114
+ " h, w, c = image.shape\n",
115
+ " if results.pose_landmarks:\n",
116
+ " xy = {}\n",
117
+ " xy_norm = {}\n",
118
+ " i = 0\n",
119
+ " for res in results.pose_landmarks.landmark:\n",
120
+ " x = res.x * w\n",
121
+ " y = res.y * h\n",
122
+ " \n",
123
+ " x_norm = (x - bounding_box[0][0]) / (bounding_box[1][0] - bounding_box[0][0])\n",
124
+ " y_norm = (y - bounding_box[0][1]) / (bounding_box[1][1] - bounding_box[0][1])\n",
125
+ " \n",
126
+ " # xy_norm.append([x_norm, y_norm])\n",
127
+ " \n",
128
+ " xy_norm[landmark_names[i]] = [x_norm, y_norm]\n",
129
+ " i += 1\n",
130
+ " else:\n",
131
+ " # xy_norm = np.zeros([0,0] * 33)\n",
132
+ " \n",
133
+ " # xy = {landmark_names: [0,0]}\n",
134
+ " # xy_norm = {landmark_names: [0,0]}\n",
135
+ " \n",
136
+ " xy_norm = dict(zip(landmark_names, [0,0] * 33))\n",
137
+ " \n",
138
+ " return xy_norm"
139
+ ]
140
+ },
141
+ {
142
+ "cell_type": "code",
143
+ "execution_count": 248,
144
+ "metadata": {},
145
+ "outputs": [],
146
+ "source": [
147
+ "def get_coordinates(landmarks, mp_pose, side, joint):\n",
148
+ " \"\"\"\n",
149
+ " Retrieves x and y coordinates of a particular keypoint from the pose estimation model\n",
150
+ " \n",
151
+ " Args:\n",
152
+ " landmarks: processed keypoints from the pose estimation model\n",
153
+ " mp_pose: Mediapipe pose estimation model\n",
154
+ " side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.\n",
155
+ " joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.\n",
156
+ " \n",
157
+ " \"\"\"\n",
158
+ " coord = getattr(mp_pose.PoseLandmark,side.upper()+\"_\"+joint.upper())\n",
159
+ " x_coord_val = landmarks[coord.value].x\n",
160
+ " y_coord_val = landmarks[coord.value].y\n",
161
+ " return [x_coord_val, y_coord_val] "
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 249,
167
+ "metadata": {},
168
+ "outputs": [],
169
+ "source": [
170
+ "def viz_coords(image, norm_coords, landmarks, mp_pose, side, joint):\n",
171
+ " \"\"\"\n",
172
+ " Displays the joint angle value near the joint within the image frame\n",
173
+ " \n",
174
+ " \"\"\"\n",
175
+ " try:\n",
176
+ " point = side.upper()+\"_\"+joint.upper()\n",
177
+ " norm_coords = norm_coords[point]\n",
178
+ " joint = get_coordinates(landmarks, mp_pose, side, joint)\n",
179
+ " \n",
180
+ " coords = [ '%.2f' % elem for elem in joint ]\n",
181
+ " coords = ' '.join(str(coords))\n",
182
+ " norm_coords = [ '%.2f' % elem for elem in norm_coords ]\n",
183
+ " norm_coords = ' '.join(str(norm_coords))\n",
184
+ " cv2.putText(image, coords, \n",
185
+ " tuple(np.multiply(joint, [640, 480]).astype(int)), \n",
186
+ " cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA\n",
187
+ " )\n",
188
+ " cv2.putText(image, norm_coords, \n",
189
+ " tuple(np.multiply(joint, [640, 480]).astype(int) + 20), \n",
190
+ " cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0,0,255), 2, cv2.LINE_AA\n",
191
+ " )\n",
192
+ " except:\n",
193
+ " pass\n",
194
+ " return"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": 250,
200
+ "metadata": {},
201
+ "outputs": [],
202
+ "source": [
203
+ "cap = cv2.VideoCapture(0) # camera object\n",
204
+ "HEIGHT = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # webcam video frame height\n",
205
+ "WIDTH = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)) # webcam video frame width\n",
206
+ "FPS = int(cap.get(cv2.CAP_PROP_FPS)) # webcam video fram rate \n",
207
+ "\n",
208
+ "landmark_names = dir(mp_pose.PoseLandmark)[:-4]\n",
209
+ "\n",
210
+ "# Set and test mediapipe model using webcam\n",
211
+ "with mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5, enable_segmentation=True) as pose:\n",
212
+ " while cap.isOpened():\n",
213
+ "\n",
214
+ " # Read feed\n",
215
+ " ret, frame = cap.read()\n",
216
+ " \n",
217
+ " # Make detection\n",
218
+ " image, results = mediapipe_detection(frame, pose)\n",
219
+ " \n",
220
+ " # Extract landmarks\n",
221
+ " try:\n",
222
+ " landmarks = results.pose_landmarks.landmark\n",
223
+ " except:\n",
224
+ " pass\n",
225
+ " \n",
226
+ " # draw bounding box\n",
227
+ " bounding_box, box_center = draw_detection(image, results)\n",
228
+ " \n",
229
+ " # Render detections\n",
230
+ " draw_landmarks(image, results) \n",
231
+ " \n",
232
+ " # normalize coordinates\n",
233
+ " xy_norm = normalize(image, results, bounding_box, landmark_names) \n",
234
+ " viz_coords(image, xy_norm, landmarks, mp_pose, 'left', 'wrist') \n",
235
+ " viz_coords(image, xy_norm, landmarks, mp_pose, 'right', 'wrist') \n",
236
+ " \n",
237
+ " # Display frame on screen\n",
238
+ " cv2.imshow('OpenCV Feed', image)\n",
239
+ " \n",
240
+ " # Draw segmentation on the image.\n",
241
+ " # To improve segmentation around boundaries, consider applying a joint\n",
242
+ " # bilateral filter to \"results.segmentation_mask\" with \"image\".\n",
243
+ " # tightness = 0.3 # Probability threshold in [0, 1] that says how \"tight\" to make the segmentation. Greater value => tighter.\n",
244
+ " # condition = np.stack((results.segmentation_mask,) * 3, axis=-1) > tightness\n",
245
+ " # bg_image = np.zeros(image.shape, dtype=np.uint8)\n",
246
+ " # bg_image[:] = (192, 192, 192) # gray\n",
247
+ " # image = np.where(condition, image, bg_image)\n",
248
+ " \n",
249
+ " # Exit / break out logic\n",
250
+ " if cv2.waitKey(10) & 0xFF == ord('q'):\n",
251
+ " break\n",
252
+ "\n",
253
+ " cap.release()\n",
254
+ " cv2.destroyAllWindows()"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": 251,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "cap.release()\n",
264
+ "cv2.destroyAllWindows()"
265
+ ]
266
+ }
267
+ ],
268
+ "metadata": {
269
+ "kernelspec": {
270
+ "display_name": "Python 3.8.13 ('AItrainer')",
271
+ "language": "python",
272
+ "name": "python3"
273
+ },
274
+ "language_info": {
275
+ "codemirror_mode": {
276
+ "name": "ipython",
277
+ "version": 3
278
+ },
279
+ "file_extension": ".py",
280
+ "mimetype": "text/x-python",
281
+ "name": "python",
282
+ "nbconvert_exporter": "python",
283
+ "pygments_lexer": "ipython3",
284
+ "version": "3.8.13"
285
+ },
286
+ "orig_nbformat": 4,
287
+ "vscode": {
288
+ "interpreter": {
289
+ "hash": "80aa1d3f3a8cfb37a38c47373cc49a39149184c5fa770d709389b1b8782c1d85"
290
+ }
291
+ }
292
+ },
293
+ "nbformat": 4,
294
+ "nbformat_minor": 2
295
+ }