randomshit11 commited on
Commit
af36780
·
verified ·
1 Parent(s): 7d291aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +177 -38
app.py CHANGED
@@ -47,74 +47,213 @@ class VideoProcessor:
47
 
48
  def process_video(self, video_file):
49
  # Get the filename from the file object
50
- filename = video_file.name
51
  # Create a temporary file to write the contents of the uploaded video file
52
- temp_file = open(filename, 'wb')
53
- temp_file.write(video_file.read())
54
- temp_file.close()
55
  # Now we can open the video file using cv2.VideoCapture()
56
  cap = cv2.VideoCapture(filename)
57
- out_frames = []
 
 
 
58
  while cap.isOpened():
59
  ret, frame = cap.read()
60
  if not ret:
61
  break
62
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
63
  results = self.pose.process(frame_rgb)
64
- frame = self.draw_landmarks(frame, results)
65
- out_frames.append(frame)
66
  cap.release()
 
67
  # Remove the temporary file
68
  os.remove(filename)
69
- return out_frames
70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  def draw_landmarks(self, image, results):
72
  mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
73
  mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
74
  mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
75
  return image
76
 
77
- @st.cache()
78
  def extract_keypoints(self, results):
79
  pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
80
  return pose
81
-
82
- @st.cache()
83
- def calculate_angle(self, a, b, c):
84
- a = np.array(a) # First
85
- b = np.array(b) # Mid
86
- c = np.array(c) # End
87
- radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
88
- angle = np.abs(radians*180.0/np.pi)
89
- if angle > 180.0:
90
- angle = 360-angle
91
- return angle
92
-
93
- @st.cache()
94
- def get_coordinates(self, landmarks, side, joint):
95
- coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
96
- x_coord_val = landmarks[coord.value].x
97
- y_coord_val = landmarks[coord.value].y
98
- return [x_coord_val, y_coord_val]
99
-
100
- @st.cache()
101
- def viz_joint_angle(self, image, angle, joint):
102
- cv2.putText(image, str(int(angle)),
103
- tuple(np.multiply(joint, [640, 480]).astype(int)),
104
- cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
105
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  return
107
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  # Define Streamlit app
109
  def main():
110
  st.title("Real-time Exercise Detection")
111
  video_file = st.file_uploader("Upload a video file", type=["mp4", "avi"])
112
  if video_file is not None:
113
- st.video(video_file)
114
  video_processor = VideoProcessor()
115
- frames = video_processor.process_video(video_file)
116
- for frame in frames:
117
- st.image(frame, channels="BGR")
118
 
119
  if __name__ == "__main__":
120
  main()
 
47
 
48
  def process_video(self, video_file):
49
  # Get the filename from the file object
50
+ filename = "temp_video.mp4"
51
  # Create a temporary file to write the contents of the uploaded video file
52
+ with open(filename, 'wb') as temp_file:
53
+ temp_file.write(video_file.read())
 
54
  # Now we can open the video file using cv2.VideoCapture()
55
  cap = cv2.VideoCapture(filename)
56
+ output_filename = "processed_video.mp4"
57
+ frame_width = int(cap.get(3))
58
+ frame_height = int(cap.get(4))
59
+ out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'mp4v'), 30, (frame_width,frame_height))
60
  while cap.isOpened():
61
  ret, frame = cap.read()
62
  if not ret:
63
  break
64
  frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
65
  results = self.pose.process(frame_rgb)
66
+ processed_frame = self.process_frame(frame, results)
67
+ out.write(processed_frame)
68
  cap.release()
69
+ out.release()
70
  # Remove the temporary file
71
  os.remove(filename)
72
+ return output_filename
73
 
74
+ def process_frame(self, frame, results):
75
+ # Process the frame using the `process` function
76
+ processed_frame = self.process(frame)
77
+ return processed_frame
78
+
79
+ def process(self, image):
80
+ """
81
+ Function to process the video frame and run the fitness trainer AI
82
+
83
+ Args:
84
+ image (numpy array): input image from the video
85
+
86
+ Returns:
87
+ numpy array: processed image with keypoint detection and fitness activity classification visualized
88
+ """
89
+ # Pose detection model
90
+ image.flags.writeable = False
91
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
92
+ results = pose.process(image)
93
+
94
+ # Draw the hand annotations on the image.
95
+ image.flags.writeable = True
96
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
97
+ self.draw_landmarks(image, results)
98
+
99
+ # Prediction logic
100
+ keypoints = self.extract_keypoints(results)
101
+ self.sequence.append(keypoints.astype('float32',casting='same_kind'))
102
+ self.sequence = self.sequence[-self.sequence_length:]
103
+
104
+ if len(self.sequence) == self.sequence_length:
105
+ res = self.model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
106
+
107
+ self.current_action = self.actions[np.argmax(res)]
108
+ confidence = np.max(res)
109
+
110
+ # Erase current action variable if no probability is above threshold
111
+ if confidence < self.threshold:
112
+ self.current_action = ''
113
+
114
+ # Viz probabilities
115
+ image = self.prob_viz(res, image)
116
+
117
+ # Count reps
118
+ try:
119
+ landmarks = results.pose_landmarks.landmark
120
+ self.count_reps(image, landmarks, mp_pose)
121
+ except:
122
+ pass
123
+
124
+ # Display graphical information
125
+ cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
126
+ cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
127
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
128
+ cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
129
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
130
+ cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
131
+ cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
132
+
133
+ return image
134
+
135
  def draw_landmarks(self, image, results):
136
  mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
137
  mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
138
  mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
139
  return image
140
 
 
141
  def extract_keypoints(self, results):
142
  pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
143
  return pose
144
+
145
+ def count_reps(self, image, landmarks, mp_pose):
146
+ """
147
+ Counts repetitions of each exercise. Global count and stage (i.e., state) variables are updated within this function.
148
+
149
+ """
150
+
151
+ if self.current_action == 'curl':
152
+ # Get coords
153
+ shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
154
+ elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
155
+ wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
156
+
157
+ # calculate elbow angle
158
+ angle = self.calculate_angle(shoulder, elbow, wrist)
159
+
160
+ # curl counter logic
161
+ if angle < 30:
162
+ self.curl_stage = "up"
163
+ if angle > 140 and self.curl_stage =='up':
164
+ self.curl_stage="down"
165
+ self.curl_counter +=1
166
+ self.press_stage = None
167
+ self.squat_stage = None
168
+
169
+ # Viz joint angle
170
+ self.viz_joint_angle(image, angle, elbow)
171
+
172
+ elif self.current_action == 'press':
173
+ # Get coords
174
+ shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
175
+ elbow = self.get_coordinates(landmarks, mp_pose, 'left', 'elbow')
176
+ wrist = self.get_coordinates(landmarks, mp_pose, 'left', 'wrist')
177
+
178
+ # Calculate elbow angle
179
+ elbow_angle = self.calculate_angle(shoulder, elbow, wrist)
180
+
181
+ # Compute distances between joints
182
+ shoulder2elbow_dist = abs(math.dist(shoulder,elbow))
183
+ shoulder2wrist_dist = abs(math.dist(shoulder,wrist))
184
+
185
+ # Press counter logic
186
+ if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):
187
+ self.press_stage = "up"
188
+ if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (self.press_stage =='up'):
189
+ self.press_stage='down'
190
+ self.press_counter += 1
191
+ self.curl_stage = None
192
+ self.squat_stage = None
193
+
194
+ # Viz joint angle
195
+ self.viz_joint_angle(image, elbow_angle, elbow)
196
+
197
+ elif self.current_action == 'squat':
198
+ # Get coords
199
+ # left side
200
+ left_shoulder = self.get_coordinates(landmarks, mp_pose, 'left', 'shoulder')
201
+ left_hip = self.get_coordinates(landmarks, mp_pose, 'left', 'hip')
202
+ left_knee = self.get_coordinates(landmarks, mp_pose, 'left', 'knee')
203
+ left_ankle = self.get_coordinates(landmarks, mp_pose, 'left', 'ankle')
204
+ # right side
205
+ right_shoulder = self.get_coordinates(landmarks, mp_pose, 'right', 'shoulder')
206
+ right_hip = self.get_coordinates(landmarks, mp_pose, 'right', 'hip')
207
+ right_knee = self.get_coordinates(landmarks, mp_pose, 'right', 'knee')
208
+ right_ankle = self.get_coordinates(landmarks, mp_pose, 'right', 'ankle')
209
+
210
+ # Calculate knee angles
211
+ left_knee_angle = self.calculate_angle(left_hip, left_knee, left_ankle)
212
+ right_knee_angle = self.calculate_angle(right_hip, right_knee, right_ankle)
213
+
214
+ # Calculate hip angles
215
+ left_hip_angle = self.calculate_angle(left_shoulder, left_hip, left_knee)
216
+ right_hip_angle = self.calculate_angle(right_shoulder, right_hip, right_knee)
217
+
218
+ # Squat counter logic
219
+ thr = 165
220
+ if (left_knee_angle < thr) and (right_knee_angle < thr) and (left_hip_angle < thr) and (right_hip_angle < thr):
221
+ self.squat_stage = "down"
222
+ if (left_knee_angle > thr) and (right_knee_angle > thr) and (left_hip_angle > thr) and (right_hip_angle > thr) and (self.squat_stage =='down'):
223
+ self.squat_stage='up'
224
+ self.squat_counter += 1
225
+ self.curl_stage = None
226
+ self.press_stage = None
227
+
228
+ # Viz joint angles
229
+ self.viz_joint_angle(image, left_knee_angle, left_knee)
230
+ self.viz_joint_angle(image, left_hip_angle, left_hip)
231
+
232
+ else:
233
+ pass
234
  return
235
 
236
+ def prob_viz(self, res, input_frame):
237
+ """
238
+ This function displays the model prediction probability distribution over the set of exercise classes
239
+ as a horizontal bar graph
240
+
241
+ """
242
+ output_frame = input_frame.copy()
243
+ for num, prob in enumerate(res):
244
+ cv2.rectangle(output_frame, (0,60+num*40), (int(prob*100), 90+num*40), self.colors[num], -1)
245
+ cv2.putText(output_frame, self.actions[num], (0, 85+num*40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255,255,255), 2, cv2.LINE_AA)
246
+
247
+ return output_frame
248
+
249
  # Define Streamlit app
250
  def main():
251
  st.title("Real-time Exercise Detection")
252
  video_file = st.file_uploader("Upload a video file", type=["mp4", "avi"])
253
  if video_file is not None:
 
254
  video_processor = VideoProcessor()
255
+ processed_video_file = video_processor.process_video(video_file)
256
+ st.video(processed_video_file)
 
257
 
258
  if __name__ == "__main__":
259
  main()