randomshit11 committed on
Commit f5e0301 · verified · 1 Parent(s): 4fba2a2

Update app.py

Files changed (1): app.py (+123 -59)

app.py CHANGED
@@ -11,6 +11,9 @@ import math
 
 from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
 import av
+from io import BytesIO
+import av
+from PIL import Image
 
 ## Build and Load Model
 def attention_block(inputs, time_steps):
@@ -286,87 +289,148 @@ class VideoProcessor:
 
         return output_frame
 
-    @st.cache()
-    def process(self, image):
-        """
-        Function to process the video frame from the user's webcam and run the fitness trainer AI
+    # @st.cache()
+    # def process(self, image):
+    #     """
+    #     Function to process the video frame from the user's webcam and run the fitness trainer AI
 
-        Args:
-            image (numpy array): input image from the webcam
+    #     Args:
+    #         image (numpy array): input image from the webcam
 
-        Returns:
-            numpy array: processed image with keypoint detection and fitness activity classification visualized
-        """
-        # Pose detection model
-        image.flags.writeable = False
-        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-        results = pose.process(image)
+    #     Returns:
+    #         numpy array: processed image with keypoint detection and fitness activity classification visualized
+    #     """
+    #     # Pose detection model
+    #     image.flags.writeable = False
+    #     image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+    #     results = pose.process(image)
 
-        # Draw the hand annotations on the image.
-        image.flags.writeable = True
-        image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
-        self.draw_landmarks(image, results)
+    #     # Draw the hand annotations on the image.
+    #     image.flags.writeable = True
+    #     image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+    #     self.draw_landmarks(image, results)
 
-        # Prediction logic
-        keypoints = self.extract_keypoints(results)
-        self.sequence.append(keypoints.astype('float32',casting='same_kind'))
-        self.sequence = self.sequence[-self.sequence_length:]
+    #     # Prediction logic
+    #     keypoints = self.extract_keypoints(results)
+    #     self.sequence.append(keypoints.astype('float32',casting='same_kind'))
+    #     self.sequence = self.sequence[-self.sequence_length:]
 
-        if len(self.sequence) == self.sequence_length:
-            res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
-            # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
-            # interpreter.invoke()
-            # res = interpreter.get_tensor(self.output_details[0]['index'])
+    #     if len(self.sequence) == self.sequence_length:
+    #         res = model.predict(np.expand_dims(self.sequence, axis=0), verbose=0)[0]
+    #         # interpreter.set_tensor(self.input_details[0]['index'], np.expand_dims(self.sequence, axis=0))
+    #         # interpreter.invoke()
+    #         # res = interpreter.get_tensor(self.output_details[0]['index'])
 
-            self.current_action = self.actions[np.argmax(res)]
-            confidence = np.max(res)
+    #         self.current_action = self.actions[np.argmax(res)]
+    #         confidence = np.max(res)
 
-            # Erase current action variable if no probability is above threshold
-            if confidence < self.threshold:
-                self.current_action = ''
+    #         # Erase current action variable if no probability is above threshold
+    #         if confidence < self.threshold:
+    #             self.current_action = ''
 
-            # Viz probabilities
-            image = self.prob_viz(res, image)
+    #         # Viz probabilities
+    #         image = self.prob_viz(res, image)
 
-            # Count reps
-            try:
-                landmarks = results.pose_landmarks.landmark
-                self.count_reps(
-                    image, landmarks, mp_pose)
-            except:
-                pass
+    #         # Count reps
+    #         try:
+    #             landmarks = results.pose_landmarks.landmark
+    #             self.count_reps(
+    #                 image, landmarks, mp_pose)
+    #         except:
+    #             pass
 
-            # Display graphical information
-            cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
-            cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
-                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-            cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
-                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
-            cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
-                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #         # Display graphical information
+    #         cv2.rectangle(image, (0,0), (640, 40), self.colors[np.argmax(res)], -1)
+    #         cv2.putText(image, 'curl ' + str(self.curl_counter), (3,30),
+    #                     cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #         cv2.putText(image, 'press ' + str(self.press_counter), (240,30),
+    #                     cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
+    #         cv2.putText(image, 'squat ' + str(self.squat_counter), (490,30),
+    #                     cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 255, 255), 2, cv2.LINE_AA)
 
-        # return cv2.flip(image, 1)
-        return image
+    #     # return cv2.flip(image, 1)
+    #     return image
 
-    def recv(self, frame):
-        """
-        Receive and process video stream from webcam
+    # def recv(self, frame):
+    #     """
+    #     Receive and process video stream from webcam
+
+    #     Args:
+    #         frame: current video frame
 
-        Args:
-            frame: current video frame
+    #     Returns:
+    #         av.VideoFrame: processed video frame
+    #     """
+    #     img = frame.to_ndarray(format="bgr24")
+    #     img = self.process(img)
+    #     return av.VideoFrame.from_ndarray(img, format="bgr24")
+    def process_uploaded_file(self, file):
+        """
+        Function to process an uploaded image or video file and run the fitness trainer AI
+        Args:
+            file (BytesIO): uploaded image or video file
+        Returns:
+            numpy array: processed image with keypoint detection and fitness activity classification visualized
+        """
+        # Initialize an empty list to store processed frames
+        processed_frames = []
+
+        # Check if the uploaded file is a video
+        is_video = hasattr(file, 'name') and file.name.endswith(('.mp4', '.avi', '.mov'))
 
+        if is_video:
+            container = av.open(file)
+            for frame in container.decode(video=0):
+                # Convert the frame to OpenCV format
+                image = frame.to_image().convert("RGB")
+                image = np.array(image)
+
+                # Process the frame
+                processed_frame = self.process(image)
+
+                # Append the processed frame to the list
+                processed_frames.append(processed_frame)
+
+            # Close the video file container
+            container.close()
+        else:
+            # If the uploaded file is an image
+            # Load the image from the BytesIO object
+            image = Image.open(file)
+            image = np.array(image)
+
+            # Process the image
+            processed_frame = self.process(image)
+
+            # Append the processed frame to the list
+            processed_frames.append(processed_frame)
+
+        return processed_frames
+
+    def recv_uploaded_file(self, file):
+        """
+        Receive and process an uploaded video file
        Args:
+            file (BytesIO): uploaded video file
         Returns:
-            av.VideoFrame: processed video frame
+            List[av.VideoFrame]: list of processed video frames
         """
-        img = frame.to_ndarray(format="bgr24")
-        img = self.process(img)
-        return av.VideoFrame.from_ndarray(img, format="bgr24")
+        # Process the uploaded file
+        processed_frames = self.process_uploaded_file(file)
+
+        # Convert processed frames to av.VideoFrame objects
+        av_frames = []
+        for frame in processed_frames:
+            av_frame = av.VideoFrame.from_ndarray(frame, format="bgr24")
+            av_frames.append(av_frame)
+
+        return av_frames
 
-## Stream Webcam Video and Run Model
 # Options
 RTC_CONFIGURATION = RTCConfiguration(
     {"iceServers": [{"urls": ["stun:stun.l.google.com:19302"]}]}
 )
+
 # Streamer
 webrtc_ctx = webrtc_streamer(
     key="AI trainer",
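
The diff adds the upload-handling methods but never shows how they are invoked from the Streamlit UI. Below is a minimal sketch of one possible hookup via st.file_uploader; the label, accepted types, and the VideoProcessor() wiring are assumptions, not part of this commit. Note also that the commit comments out process() while process_uploaded_file() still calls self.process(image), so the sketch only works if process() remains defined on the class.

```python
import streamlit as st

# Hypothetical wiring (not in this commit): route an uploaded file through
# VideoProcessor.process_uploaded_file and preview the first processed frame.
uploaded_file = st.file_uploader(
    "Upload a workout image or video",  # label is an assumption
    type=["jpg", "jpeg", "png", "mp4", "avi", "mov"],
)

if uploaded_file is not None:
    processor = VideoProcessor()  # class defined earlier in app.py
    frames = processor.process_uploaded_file(uploaded_file)
    if frames:
        # process() returns BGR arrays, so tell Streamlit the channel order.
        st.image(frames[0], channels="BGR", caption="First processed frame")
```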
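Similarly, recv_uploaded_file() returns a list of av.VideoFrame objects, but the commit stops short of encoding them to a file. A sketch of how they could be written out with PyAV, assuming an H.264/yuv420p target; the output path and frame rate are made up:

```python
import av

def write_video(av_frames, path="processed.mp4", fps=30):
    """Encode a list of av.VideoFrame objects into an H.264 MP4 file."""
    container = av.open(path, mode="w")
    stream = container.add_stream("h264", rate=fps)
    if av_frames:
        stream.width = av_frames[0].width
        stream.height = av_frames[0].height
    stream.pix_fmt = "yuv420p"
    for frame in av_frames:
        # Frames arrive as bgr24; reformat to the encoder's pixel format.
        for packet in stream.encode(frame.reformat(format="yuv420p")):
            container.mux(packet)
    for packet in stream.encode():  # flush any buffered packets
        container.mux(packet)
    container.close()
```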