randomshit11 committed
Commit 0f13ce7 · verified · 1 Parent(s): e3f6d22

Create app.py

Files changed (1):
  1. app.py +298 -0
app.py ADDED
@@ -0,0 +1,298 @@
import os
import cv2
import mediapipe as mp
import numpy as np
import math
import gradio as gr
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
                                     Bidirectional, Permute, multiply)

# Load the pose estimation model from MediaPipe
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Define the attention block for the LSTM model
def attention_block(inputs, time_steps):
    a = Permute((2, 1))(inputs)
    a = Dense(time_steps, activation='softmax')(a)
    a_probs = Permute((2, 1), name='attention_vec')(a)
    output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
    return output_attention_mul

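# Shape walk-through for the attention block above (a sketch, assuming the
# defaults used below: sequence_length=30 and a Bi-LSTM with 2*256 = 512
# output features):
#   inputs             -> (batch, 30, 512)
#   Permute((2, 1))    -> (batch, 512, 30)
#   Dense(30, softmax) -> (batch, 512, 30)  per-feature weights over time steps
#   Permute((2, 1))    -> (batch, 30, 512)  attention weights, aligned to inputs
#   multiply           -> (batch, 30, 512)  inputs re-weighted element-wise
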
# Build and load the LSTM model
def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
    inputs = Input(shape=(sequence_length, num_input_values))
    lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
    attention_mul = attention_block(lstm_out, sequence_length)
    attention_mul = Flatten()(attention_mul)
    x = Dense(2 * HIDDEN_UNITS, activation='relu')(attention_mul)
    x = Dropout(0.5)(x)
    x = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=[inputs], outputs=x)
    load_dir = "./models/LSTM_Attention.h5"
    model.load_weights(load_dir)
    return model

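# Minimal usage sketch for build_model (hypothetical input; requires the
# pretrained weights file ./models/LSTM_Attention.h5 to be present):
#   model = build_model(HIDDEN_UNITS=256)
#   dummy = np.zeros((1, 30, 33 * 4), dtype='float32')  # (batch, frames, keypoints)
#   probs = model.predict(dummy)                         # shape (1, 3)
#   # probs[0] holds one probability per action: curl, press, squat
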
# Define the VideoProcessor class for real-time video processing
class VideoProcessor:
    def __init__(self):
        # Parameters
        self.actions = np.array(['curl', 'press', 'squat'])
        self.sequence_length = 30
        self.colors = [(245, 117, 16), (117, 245, 16), (16, 117, 245)]
        self.threshold = 0.5

        self.model = build_model(256)

        # Detection variables
        self.sequence = []
        self.current_action = ''

        # Rep counter logic variables
        self.curl_counter = 0
        self.press_counter = 0
        self.squat_counter = 0
        self.curl_stage = None
        self.press_stage = None
        self.squat_stage = None
        self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

    def process_video(self, video_file):
        # Write the contents of the uploaded video file to a temporary file
        filename = "temp_video.mp4"
        with open(filename, 'wb') as temp_file:
            temp_file.write(video_file.read())

        # Process the video and save the result to a new file
        output_filename = "processed_video.mp4"
        cap = cv2.VideoCapture(filename)
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = cap.get(cv2.CAP_PROP_FPS) or 30  # fall back to 30 fps if the source does not report one
        out = cv2.VideoWriter(output_filename, cv2.VideoWriter_fourcc(*'h264'), fps, (frame_width, frame_height))
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            results = self.pose.process(frame_rgb)
            processed_frame = self.process_frame(frame, results)
            out.write(processed_frame)
        cap.release()
        out.release()

        # Remove the temporary input file
        os.remove(filename)

        # Return the path to the processed video file
        return output_filename

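    # NOTE: whether the 'h264' FourCC is available depends on the local
    # OpenCV/FFmpeg build; if the output file comes out empty, 'mp4v' is a
    # commonly available fallback codec.
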
    def process_frame(self, frame, results):
        # Delegate to `process`, which runs its own pose inference on the
        # frame (the `results` passed in here are recomputed there)
        processed_frame = self.process(frame)
        return processed_frame

    def process(self, image):
        # Pose detection
        image.flags.writeable = False
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        results = pose.process(image)

        # Draw the pose annotations on the image
        image.flags.writeable = True
        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        self.draw_landmarks(image, results)

        # Prediction logic: maintain a sliding window of the last 30 frames
        keypoints = self.extract_keypoints(results)
        self.sequence.append(keypoints.astype('float32'))
        self.sequence = self.sequence[-self.sequence_length:]

        if len(self.sequence) == self.sequence_length:
            res = self.model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]

            self.current_action = self.actions[np.argmax(res)]
            confidence = np.max(res)

            # Erase the current action if no class probability is above threshold
            if confidence < self.threshold:
                self.current_action = ''

            # Viz probabilities
            image = self.prob_viz(res, image)

            # Count reps (only possible when a pose was actually detected)
            if results.pose_landmarks:
                landmarks = results.pose_landmarks.landmark
                self.count_reps(image, landmarks, mp_pose)

            # Display graphical information
            cv2.rectangle(image, (0, 0), (640, 40), self.colors[np.argmax(res)], -1)
            cv2.putText(image, 'curl ' + str(self.curl_counter), (3, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            cv2.putText(image, 'press ' + str(self.press_counter), (240, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
            cv2.putText(image, 'squat ' + str(self.squat_counter), (490, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        return image

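    # Prediction only starts once the sliding window holds a full
    # sequence_length (30) frames; until then, `process` just draws landmarks.
    # From that point on, `model.predict` runs on every frame, which is the
    # main per-frame cost.
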
    def draw_landmarks(self, image, results):
        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                  mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
                                  mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))
        return image

    def extract_keypoints(self, results):
        pose = np.array([[res.x, res.y, res.z, res.visibility]
                         for res in results.pose_landmarks.landmark]).flatten() \
            if results.pose_landmarks else np.zeros(33*4)
        return pose

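    # MediaPipe Pose returns 33 landmarks with (x, y, z, visibility) each, so
    # the flattened vector has 33 * 4 = 132 values, matching the model's
    # num_input_values. Frames with no detected pose contribute a zero vector,
    # which keeps the sliding window's shape consistent.
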
    def count_reps(self, image, landmarks, mp_pose):
        """
        Counts repetitions of each exercise. The per-exercise counter and
        stage (i.e., state) instance variables are updated within this function.
        """
        if self.current_action == 'curl':
            # Get coords
            shoulder = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'SHOULDER')
            elbow = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'ELBOW')
            wrist = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'WRIST')

            # Calculate elbow angle
            angle = self.calculate_angle(shoulder, elbow, wrist)

            # Curl counter logic
            if angle < 30:
                self.curl_stage = "up"
            if angle > 140 and self.curl_stage == 'up':
                self.curl_stage = "down"
                self.curl_counter += 1
                self.press_stage = None
                self.squat_stage = None

            # Viz joint angle
            self.viz_joint_angle(image, angle, elbow)

        elif self.current_action == 'press':
            # Get coords
            shoulder = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'SHOULDER')
            elbow = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'ELBOW')
            wrist = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'WRIST')

            # Calculate elbow angle
            elbow_angle = self.calculate_angle(shoulder, elbow, wrist)

            # Compute distances between joints
            shoulder2elbow_dist = abs(math.dist(shoulder, elbow))
            shoulder2wrist_dist = abs(math.dist(shoulder, wrist))

            # Press counter logic
            if (elbow_angle > 130) and (shoulder2elbow_dist < shoulder2wrist_dist):
                self.press_stage = "up"
            if (elbow_angle < 50) and (shoulder2elbow_dist > shoulder2wrist_dist) and (self.press_stage == 'up'):
                self.press_stage = 'down'
                self.press_counter += 1
                self.curl_stage = None
                self.squat_stage = None

            # Viz joint angle
            self.viz_joint_angle(image, elbow_angle, elbow)

        elif self.current_action == 'squat':
            # Get coords: left side
            left_shoulder = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'SHOULDER')
            left_hip = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'HIP')
            left_knee = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'KNEE')
            left_ankle = self.get_coordinates(landmarks, mp_pose, 'LEFT', 'ANKLE')
            # right side
            right_shoulder = self.get_coordinates(landmarks, mp_pose, 'RIGHT', 'SHOULDER')
            right_hip = self.get_coordinates(landmarks, mp_pose, 'RIGHT', 'HIP')
            right_knee = self.get_coordinates(landmarks, mp_pose, 'RIGHT', 'KNEE')
            right_ankle = self.get_coordinates(landmarks, mp_pose, 'RIGHT', 'ANKLE')

            # Calculate knee angles
            left_knee_angle = self.calculate_angle(left_hip, left_knee, left_ankle)
            right_knee_angle = self.calculate_angle(right_hip, right_knee, right_ankle)

            # Calculate hip angles
            left_hip_angle = self.calculate_angle(left_shoulder, left_hip, left_knee)
            right_hip_angle = self.calculate_angle(right_shoulder, right_hip, right_knee)

            # Squat counter logic
            thr = 165
            if (left_knee_angle < thr) and (right_knee_angle < thr) and \
               (left_hip_angle < thr) and (right_hip_angle < thr):
                self.squat_stage = "down"
            if (left_knee_angle > thr) and (right_knee_angle > thr) and \
               (left_hip_angle > thr) and (right_hip_angle > thr) and (self.squat_stage == 'down'):
                self.squat_stage = 'up'
                self.squat_counter += 1
                self.curl_stage = None
                self.press_stage = None

            # Viz joint angles
            self.viz_joint_angle(image, left_knee_angle, left_knee)
            self.viz_joint_angle(image, left_hip_angle, left_hip)

        else:
            pass
        return

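    # Rep counting is a small two-state machine per exercise: a rep is counted
    # on the stage transition ("up" -> "down" for curls and presses,
    # "down" -> "up" for squats), and completing a rep of one exercise resets
    # the other exercises' stages so stale state cannot trigger a spurious count.
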
    def prob_viz(self, res, input_frame):
        """
        Displays the model's prediction probability distribution over the set
        of exercise classes as a horizontal bar graph.
        """
        output_frame = input_frame.copy()
        for num, prob in enumerate(res):
            cv2.rectangle(output_frame, (0, 60 + num * 40), (int(prob * 100), 90 + num * 40), self.colors[num], -1)
            cv2.putText(output_frame, self.actions[num], (0, 85 + num * 40),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)

        return output_frame

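    # Bar lengths are prob * 100 pixels (so a full-confidence class spans
    # 100 px), and rows are spaced 40 px apart starting at y = 60, just below
    # the 40 px status banner drawn in `process`.
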
    def get_coordinates(self, landmarks, mp_pose, side, part):
        # Look up the landmark enum (e.g., PoseLandmark.LEFT_SHOULDER) and
        # return its normalized (x, y) image coordinates
        coord = getattr(mp_pose.PoseLandmark, side.upper() + "_" + part.upper())
        x_coord_val = landmarks[coord.value].x
        y_coord_val = landmarks[coord.value].y
        return [x_coord_val, y_coord_val]

    def calculate_angle(self, a, b, c):
        a = np.array(a)
        b = np.array(b)
        c = np.array(c)
        radians = math.atan2(c[1] - b[1], c[0] - b[0]) - math.atan2(a[1] - b[1], a[0] - b[0])
        angle = np.abs(radians * 180.0 / np.pi)
        if angle > 180.0:
            angle = 360 - angle
        return angle

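    # Worked example (made-up coordinates): for a=(0, 0), b=(1, 0), c=(1, 1)
    # the two atan2 terms are 90 and 180 degrees, so the inner angle at
    # vertex b is |90 - 180| = 90 degrees; the 360-degree wrap keeps results
    # in [0, 180].
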
    def viz_joint_angle(self, image, angle, joint):
        # NOTE: assumes 640x480 frames when mapping the normalized joint
        # coordinates back to pixel positions
        cv2.putText(image, str(round(angle, 2)),
                    tuple(np.multiply(joint, [640, 480]).astype(int)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA)

# Define the Gradio interface
def main(video_file):
    video_processor = VideoProcessor()
    output_video = video_processor.process_video(video_file)
    # Gradio's video output component expects a file path rather than raw bytes
    return output_video

iface = gr.Interface(
    fn=main,
    inputs="file",
    outputs="video",
    title="Real-time Exercise Detection",
    description="Upload a video file for real-time exercise detection.",
    allow_flagging=False
)

iface.launch()
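
# Running this file (a sketch, assuming gradio, opencv-python, mediapipe and
# tensorflow are installed and ./models/LSTM_Attention.h5 exists):
#   python app.py
# launches the Gradio server and prints a local URL to open in a browser.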