randomshit11 committed
Commit 1d15de2 · verified · 1 Parent(s): 54b36c1

Update app.py

Files changed (1):
  1. app.py +128 -68
app.py CHANGED
@@ -73,100 +73,74 @@
  # mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
  # return image
 
- import os
- import streamlit as st
- import cv2
- import mediapipe as mp
- import numpy as np
- import math
- from tensorflow.keras.models import Model
- from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
-                                      Bidirectional, Permute, multiply)
-
- # Load the pose estimation model from Mediapipe
- mp_pose = mp.solutions.pose
- mp_drawing = mp.solutions.drawing_utils
- pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
-
- # Define the attention block for the LSTM model
- def attention_block(inputs, time_steps):
-     a = Permute((2, 1))(inputs)
-     a = Dense(time_steps, activation='softmax')(a)
-     a_probs = Permute((2, 1), name='attention_vec')(a)
-     output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
-     return output_attention_mul
-
- # Build and load the LSTM model
- @st.cache(allow_output_mutation=True)
- def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
-     inputs = Input(shape=(sequence_length, num_input_values))
-     lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
-     attention_mul = attention_block(lstm_out, sequence_length)
-     attention_mul = Flatten()(attention_mul)
-     x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)
-     x = Dropout(0.5)(x)
-     x = Dense(num_classes, activation='softmax')(x)
-     model = Model(inputs=[inputs], outputs=x)
-     load_dir = "./models/LSTM_Attention.h5"
-     model.load_weights(load_dir)
-     return model
-
- # Define the VideoProcessor class for real-time video processing
  class VideoProcessor:
      def __init__(self):
          self.actions = np.array(['curl', 'press', 'squat'])
          self.sequence_length = 30
          self.colors = [(245,117,16), (117,245,16), (16,117,245)]
-         self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
-         self.model = build_model()
-
-     def process_video(self, video_file):
-         # Get the filename from the file object
-         filename = video_file.name
-         # Create a temporary file to write the contents of the uploaded video file
-         temp_file = open(filename, 'wb')
-         temp_file.write(video_file.read())
-         temp_file.close()
-         # Now we can open the video file using cv2.VideoCapture()
-         cap = cv2.VideoCapture(filename)
-         out_frames = []
-         while cap.isOpened():
-             ret, frame = cap.read()
-             if not ret:
-                 break
-             frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-             results = self.pose.process(frame_rgb)
-             frame = self.draw_landmarks(frame, results)
-             out_frames.append(frame)
-         cap.release()
-         # Remove the temporary file
-         os.remove(filename)
-         return out_frames
 
      def draw_landmarks(self, image, results):
-         mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
-                                   mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
-                                   mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
          return image
-
      @st.cache()
      def extract_keypoints(self, results):
          pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
          return pose
 
      @st.cache()
      def calculate_angle(self, a, b, c):
          a = np.array(a)  # First
          b = np.array(b)  # Mid
          c = np.array(c)  # End
          radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
          angle = np.abs(radians*180.0/np.pi)
          if angle > 180.0:
              angle = 360-angle
          return angle
 
      @st.cache()
      def get_coordinates(self, landmarks, side, joint):
          coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
          x_coord_val = landmarks[coord.value].x
          y_coord_val = landmarks[coord.value].y
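For reference, the attention-BiLSTM builder deleted above is still what the new VideoProcessor loads through build_model(): a bidirectional LSTM over a 30-step window of 33*4 pose keypoint values, a softmax attention layer over the time steps, and a 3-way softmax over ['curl', 'press', 'squat']. A minimal sketch to sanity-check those shapes, assuming TensorFlow is installed and ./models/LSTM_Attention.h5 is present (the dummy input below is not from the app):

    import numpy as np

    model = build_model()                                        # loads ./models/LSTM_Attention.h5
    dummy_window = np.zeros((1, 30, 33 * 4), dtype=np.float32)   # one 30-frame sequence of 132 keypoint values
    probs = model.predict(dummy_window, verbose=0)[0]            # softmax over ['curl', 'press', 'squat']
    print(probs.shape)                                           # (3,)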
@@ -174,12 +148,98 @@ class VideoProcessor:
 
      @st.cache()
      def viz_joint_angle(self, image, angle, joint):
          cv2.putText(image, str(int(angle)),
                      tuple(np.multiply(joint, [640, 480]).astype(int)),
                      cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                      )
          return
 
  # Define Streamlit app
  def main():
      st.title("Real-time Exercise Detection")
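viz_joint_angle assumes a 640x480 frame: MediaPipe landmark coordinates are normalized to [0, 1], so the joint position is scaled by [640, 480] and cast to int before cv2.putText can use it as a pixel anchor. A small sketch of that conversion with a made-up landmark value:

    import numpy as np

    joint = (0.25, 0.5)                                          # normalized (x, y) of a landmark (made-up)
    anchor = tuple(np.multiply(joint, [640, 480]).astype(int))   # pixel coordinates for cv2.putText
    print(anchor)                                                # x=160, y=240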
 
  # mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
  # return image
 
  class VideoProcessor:
      def __init__(self):
          self.actions = np.array(['curl', 'press', 'squat'])
          self.sequence_length = 30
          self.colors = [(245,117,16), (117,245,16), (16,117,245)]
+         self.threshold = 0.50  # Default threshold for activity classification confidence
+
+         # Detection variables
+         self.sequence = []
+         self.current_action = ''
+
+         # Initialize pose model
+         self.mp_pose = mp.solutions.pose
+         self.mp_drawing = mp.solutions.drawing_utils
+         self.pose = self.mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
+         self.model = build_model()  # Load the LSTM model
 
+     @st.cache()
      def draw_landmarks(self, image, results):
+         """
+         This function draws keypoints and landmarks detected by the human pose estimation model
+         """
+         self.mp_drawing.draw_landmarks(image, results.pose_landmarks, self.mp_pose.POSE_CONNECTIONS,
+                                        self.mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
+                                        self.mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
+                                        )
          return image
+
      @st.cache()
      def extract_keypoints(self, results):
+         """
+         Processes and organizes the keypoints detected from the pose estimation model
+         to be used as inputs for the exercise decoder models
+         """
          pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
          return pose
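extract_keypoints flattens the 33 MediaPipe pose landmarks into one 132-value vector (x, y, z, visibility per landmark) and falls back to zeros when no pose is detected; sequence_length of these vectors form one model input. A quick shape check using the zero fallback (the stacking below is illustrative, not app code):

    import numpy as np

    frame_vector = np.zeros(33 * 4)           # fallback when results.pose_landmarks is None
    window = np.stack([frame_vector] * 30)    # 30 consecutive frame vectors, as fed to the LSTM
    print(frame_vector.shape, window.shape)   # (132,) (30, 132)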
 
      @st.cache()
      def calculate_angle(self, a, b, c):
+         """
+         Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
+         """
          a = np.array(a)  # First
          b = np.array(b)  # Mid
          c = np.array(c)  # End
+
          radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
          angle = np.abs(radians*180.0/np.pi)
+
          if angle > 180.0:
              angle = 360-angle
+
          return angle
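calculate_angle uses only the x and y components, so the angle is measured in the image plane: it takes the difference of the two atan2 headings around the middle keypoint b and folds the result into [0, 180] degrees. A worked example with hypothetical keypoints where the joint forms a right angle:

    import numpy as np

    a, b, c = (0.0, 1.0), (0.0, 0.0), (1.0, 0.0)   # e.g. shoulder, elbow, wrist (made-up coordinates)
    radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
    angle = np.abs(radians * 180.0 / np.pi)
    if angle > 180.0:
        angle = 360 - angle
    print(angle)                                   # 90.0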
 
      @st.cache()
      def get_coordinates(self, landmarks, side, joint):
+         """
+         Retrieves x and y coordinates of a particular keypoint from the pose estimation model
+
+         Args:
+             landmarks: processed keypoints from the pose estimation model
+             side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
+             joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
+         """
          coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
          x_coord_val = landmarks[coord.value].x
          y_coord_val = landmarks[coord.value].y
 
      @st.cache()
      def viz_joint_angle(self, image, angle, joint):
+         """
+         Displays the joint angle value near the joint within the image frame
+         """
          cv2.putText(image, str(int(angle)),
                      tuple(np.multiply(joint, [640, 480]).astype(int)),
                      cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                      )
          return
 
+     @st.cache()
+     def process(self, image):
+         """
+         Function to process the video frame from the user's webcam and run the fitness trainer AI
+
+         Args:
+             image (numpy array): input image from the webcam
+
+         Returns:
+             numpy array: processed image with keypoint detection and fitness activity classification visualized
+         """
+         # Pose detection model
+         image.flags.writeable = False
+         image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+         results = pose.process(image)
+
+         # Draw the hand annotations on the image.
+         image.flags.writeable = True
+         image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+         self.draw_landmarks(image, results)
+
+         # Prediction logic
+         keypoints = self.extract_keypoints(results)
+         self.sequence.append(keypoints.astype('float32', casting='same_kind'))
+         self.sequence = self.sequence[-self.sequence_length:]
+
+         if len(self.sequence) == self.sequence_length:
+             res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
+             # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
+             # interpreter.invoke()
+             # res = interpreter.get_tensor(self.output_details[0]['index'])
+
+             self.current_action = self.actions[np.argmax(res)]
+             confidence = np.max(res)
+
+             # Erase current action variable if no probability is above threshold
+             if confidence < self.threshold:
+                 self.current_action = ''
+
+             # Viz probabilities
+             image = self.prob_viz(res, image)
+
+             # Count reps
+             try:
+                 landmarks = results.pose_landmarks.landmark
+                 self.count_reps(
+                     image, landmarks, mp_pose)
+             except:
+                 pass
+
+             # Display graphical information
+             cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
+             cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
+                         cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+             cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
+                         cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+             cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
+                         cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+
+         # return cv2.flip(image, 1)
+         return image
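The prediction logic in process keeps a rolling buffer: each frame's keypoint vector is appended and the list is then trimmed to the last sequence_length entries, so the model only ever sees the most recent 30 frames. A stripped-down sketch of that buffering with stand-in frame vectors:

    sequence_length = 30
    sequence = []
    for frame_idx in range(100):                 # pretend 100 frames have been processed
        keypoints = [frame_idx]                  # stand-in for the real 132-value keypoint vector
        sequence.append(keypoints)
        sequence = sequence[-sequence_length:]   # keep only the newest 30 frames
    print(len(sequence), sequence[0], sequence[-1])   # 30 [70] [99]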
+
+     def process_video(self, video_file):
+         # Get the filename from the file object
+         filename = video_file.name
+         # Create a temporary file to write the contents of the uploaded video file
+         temp_file = open(filename, 'wb')
+         temp_file.write(video_file.read())
+         temp_file.close()
+         # Now we can open the video file using cv2.VideoCapture()
+         cap = cv2.VideoCapture(filename)
+         out_frames = []
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+             frame_processed = self.process(frame)
+             out_frames.append(frame_processed)
+         cap.release()
+         # Remove the temporary file
+         os.remove(filename)
+         return out_frames
  # Define Streamlit app
  def main():
      st.title("Real-time Exercise Detection")
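The diff ends inside main() right after st.title, so the following is only a hedged sketch of how the refactored process_video could be driven from the Streamlit app; the uploader label, file types, and the st.image playback loop are assumptions, not part of this commit:

    def main():
        st.title("Real-time Exercise Detection")
        video_file = st.file_uploader("Upload a workout video", type=["mp4", "mov", "avi"])  # hypothetical widget
        if video_file is not None:
            processor = VideoProcessor()
            frames = processor.process_video(video_file)   # frames with landmarks and activity overlay drawn
            placeholder = st.empty()
            for frame in frames:
                placeholder.image(frame, channels="BGR")   # frames come back in OpenCV BGR order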