randomshit11 committed on
Commit 7d291aa · verified · 1 Parent(s): 04bfe32

Update app.py

Files changed (1)
  1. app.py +30 -130
app.py CHANGED
@@ -36,74 +36,62 @@ def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num
     model.load_weights(load_dir)
     return model
 
+# Define the VideoProcessor class for real-time video processing
 class VideoProcessor:
     def __init__(self):
         self.actions = np.array(['curl', 'press', 'squat'])
         self.sequence_length = 30
         self.colors = [(245,117,16), (117,245,16), (16,117,245)]
-        self.threshold = 0.50 # Default threshold for activity classification confidence
-
-        # Detection variables
-        self.sequence = []
-        self.current_action = ''
-
-        # Initialize pose model
-        self.mp_pose = mp.solutions.pose
-        self.mp_drawing = mp.solutions.drawing_utils
-        self.pose = self.mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
-        self.model = build_model() # Load the LSTM model
+        self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
+        self.model = build_model()
+
+    def process_video(self, video_file):
+        # Get the filename from the file object
+        filename = video_file.name
+        # Create a temporary file to write the contents of the uploaded video file
+        temp_file = open(filename, 'wb')
+        temp_file.write(video_file.read())
+        temp_file.close()
+        # Now we can open the video file using cv2.VideoCapture()
+        cap = cv2.VideoCapture(filename)
+        out_frames = []
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            results = self.pose.process(frame_rgb)
+            frame = self.draw_landmarks(frame, results)
+            out_frames.append(frame)
+        cap.release()
+        # Remove the temporary file
+        os.remove(filename)
+        return out_frames
 
-    @st.cache()
     def draw_landmarks(self, image, results):
-        """
-        This function draws keypoints and landmarks detected by the human pose estimation model
-        """
-        self.mp_drawing.draw_landmarks(image, results.pose_landmarks, self.mp_pose.POSE_CONNECTIONS,
-                                       self.mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
-                                       self.mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
-                                       )
+        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
+                                  mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
+                                  mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
         return image
-
+
     @st.cache()
     def extract_keypoints(self, results):
-        """
-        Processes and organizes the keypoints detected from the pose estimation model
-        to be used as inputs for the exercise decoder models
-        """
         pose = np.array([[res.x, res.y, res.z, res.visibility] for res in results.pose_landmarks.landmark]).flatten() if results.pose_landmarks else np.zeros(33*4)
         return pose
 
     @st.cache()
     def calculate_angle(self, a, b, c):
-        """
-        Computes 3D joint angle inferred by 3 keypoints and their relative positions to one another
-        """
         a = np.array(a) # First
         b = np.array(b) # Mid
         c = np.array(c) # End
-
         radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
         angle = np.abs(radians*180.0/np.pi)
-
         if angle > 180.0:
             angle = 360-angle
-
         return angle
 
     @st.cache()
     def get_coordinates(self, landmarks, side, joint):
-        """
-        Retrieves x and y coordinates of a particular keypoint from the pose estimation model
-
-        Args:
-            landmarks: processed keypoints from the pose estimation model
-            side: 'left' or 'right'. Denotes the side of the body of the landmark of interest.
-            joint: 'shoulder', 'elbow', 'wrist', 'hip', 'knee', or 'ankle'. Denotes which body joint is associated with the landmark of interest.
-        """
         coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
         x_coord_val = landmarks[coord.value].x
         y_coord_val = landmarks[coord.value].y
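
A note on calculate_angle above: despite the removed docstring's mention of 3D, the computation is planar, using only the x and y components of each keypoint. A minimal self-contained sketch with hypothetical normalized coordinates, confirming a right angle at the middle joint comes out as 90 degrees:

    import numpy as np

    def calculate_angle(a, b, c):
        # Angle at b formed by the segments b->a and b->c, in degrees.
        a, b, c = np.array(a), np.array(b), np.array(c)
        radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
        angle = np.abs(radians * 180.0 / np.pi)
        return 360 - angle if angle > 180.0 else angle

    # Hypothetical shoulder, elbow, wrist as normalized (x, y) pairs.
    print(calculate_angle((0.5, 0.2), (0.5, 0.5), (0.8, 0.5)))  # 90.0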
@@ -111,100 +99,12 @@ class VideoProcessor:
 
     @st.cache()
     def viz_joint_angle(self, image, angle, joint):
-        """
-        Displays the joint angle value near the joint within the image frame
-        """
         cv2.putText(image, str(int(angle)),
                     tuple(np.multiply(joint, [640, 480]).astype(int)),
                     cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                     )
         return
 
-    @st.cache()
-    def process(self, image):
-        """
-        Function to process the video frame from the user's webcam and run the fitness trainer AI
-
-        Args:
-            image (numpy array): input image from the webcam
-
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Pose detection model
-        image.flags.writeable = False
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        results = pose.process(image)
-
-        # Draw the hand annotations on the image.
-        image.flags.writeable = True
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        self.draw_landmarks(image, results)
-
-        # Prediction logic
-        keypoints = self.extract_keypoints(results)
-        self.sequence.append(keypoints.astype('float32', casting='same_kind'))
-        self.sequence = self.sequence[-self.sequence_length:]
-
-        if len(self.sequence) == self.sequence_length:
-            res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
-            # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
-            # interpreter.invoke()
-            # res = interpreter.get_tensor(self.output_details[0]['index'])
-
-            self.current_action = self.actions[np.argmax(res)]
-            confidence = np.max(res)
-
-            # Erase current action variable if no probability is above threshold
-            if confidence < self.threshold:
-                self.current_action = ''
-
-            # Viz probabilities
-            image = self.prob_viz(res, image)
-
-        # Count reps
-        try:
-            landmarks = results.pose_landmarks.landmark
-            self.count_reps(image, landmarks, mp_pose)
-        except:
-            pass
-
-        # Display graphical information
-        cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-        cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-        cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-        cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-                    cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-
-        # return cv2.flip(image, 1)
-        return image
-
-    def process_video(self, video_file):
-        # Get the filename from the file object
-        filename = video_file.name
-        # Create a temporary file to write the contents of the uploaded video file
-        temp_file = open(filename, 'wb')
-        temp_file.write(video_file.read())
-        temp_file.close()
-        # Now we can open the video file using cv2.VideoCapture()
-        cap = cv2.VideoCapture(filename)
-        out_frames = []
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-            frame_processed = self.process(frame)
-            out_frames.append(frame_processed)
-        cap.release()
-        # Remove the temporary file
-        os.remove(filename)
-        return out_frames
-
-
 # Define Streamlit app
 def main():
     st.title("Real-time Exercise Detection")
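
The new process_video returns a list of annotated BGR frames rather than streaming from a webcam. A minimal sketch of how main() might consume it; the uploader widget and display loop here are assumptions, since the rest of main() falls outside this hunk:

    import cv2
    import streamlit as st

    def main():
        st.title("Real-time Exercise Detection")
        video_file = st.file_uploader("Upload a workout video", type=["mp4", "mov", "avi"])
        if video_file is not None:
            processor = VideoProcessor()  # the class defined in app.py
            frames = processor.process_video(video_file)
            # Streamlit expects RGB images, while OpenCV frames are BGR.
            for frame in frames:
                st.image(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))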
 
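One detail worth noting from extract_keypoints: the model consumes a flattened vector of 33 MediaPipe pose landmarks × 4 values (x, y, z, visibility) = 132 floats per frame, matching num_input_values=33*4 in build_model. A standalone sketch of that shape contract, using a dummy all-zeros frame as extract_keypoints does when no pose is detected:

    import numpy as np

    SEQUENCE_LENGTH = 30
    NUM_INPUT_VALUES = 33 * 4  # 33 landmarks x (x, y, z, visibility)

    # One frame with no detection falls back to zeros, as in extract_keypoints.
    frame_keypoints = np.zeros(NUM_INPUT_VALUES, dtype=np.float32)

    # Thirty frames stack into one LSTM input window, plus a batch dimension.
    window = np.stack([frame_keypoints] * SEQUENCE_LENGTH)
    batch = np.expand_dims(window, axis=0)
    print(batch.shape)  # (1, 30, 132)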
 
 
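For reference, the removed process() method implemented the sliding-window classification this commit drops: each frame's keypoint vector was pushed into a 30-frame buffer, and once full, the LSTM scored the window, discarding predictions below the 0.50 confidence threshold. A compact standalone sketch of that pattern, where model stands for the Keras model returned by build_model:

    import numpy as np
    from collections import deque

    SEQUENCE_LENGTH = 30
    THRESHOLD = 0.50
    ACTIONS = np.array(['curl', 'press', 'squat'])
    window = deque(maxlen=SEQUENCE_LENGTH)  # rolling keypoint buffer

    def classify_frame(model, keypoints):
        # Push one frame's 132-value keypoint vector; classify once 30 frames accumulate.
        window.append(keypoints.astype(np.float32))
        if len(window) < SEQUENCE_LENGTH:
            return ''
        res = model.predict(np.expand_dims(np.stack(window), axis=0), verbose=0)[0]
        # Discard low-confidence predictions, as the removed code did.
        return ACTIONS[np.argmax(res)] if np.max(res) >= THRESHOLD else ''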