iDrops commited on
Commit
5177211
·
verified ·
1 Parent(s): c70fcbd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -38
app.py CHANGED
@@ -5,123 +5,138 @@ import streamlit as st
5
  import cv2
6
  import tempfile
7
 
8
-
9
  BASE_DIR = os.path.abspath(os.path.join(__file__, '../../'))
10
  sys.path.append(BASE_DIR)
11
 
12
-
13
  from utils import get_mediapipe_pose
14
  from process_frame import ProcessFrame
15
  from thresholds import get_thresholds_beginner, get_thresholds_pro
16
 
 
17
 
18
-
19
- st.title('AI Fitness Trainer: Squats Analysis')
20
-
21
  mode = st.radio('Select Mode', ['Beginner', 'Pro'], horizontal=True)
22
 
23
-
24
-
25
- thresholds = None
26
-
27
  if mode == 'Beginner':
28
  thresholds = get_thresholds_beginner()
29
 
30
  elif mode == 'Pro':
31
  thresholds = get_thresholds_pro()
32
 
33
-
34
-
35
  upload_process_frame = ProcessFrame(thresholds=thresholds)
36
 
37
- # Initialize face mesh solution
38
  pose = get_mediapipe_pose()
39
 
40
-
41
  download = None
42
 
 
43
  if 'download' not in st.session_state:
44
  st.session_state['download'] = False
45
 
46
-
47
  output_video_file = f'output_recorded.mp4'
48
 
 
49
  if os.path.exists(output_video_file):
50
  os.remove(output_video_file)
51
 
52
-
53
  with st.form('Upload', clear_on_submit=True):
54
  up_file = st.file_uploader("Upload a Video", ['mp4','mov', 'avi'])
55
  uploaded = st.form_submit_button("Upload")
56
 
 
57
  stframe = st.empty()
58
 
 
59
  ip_vid_str = '<p style="font-family:Helvetica; font-weight: bold; font-size: 16px;">Input Video</p>'
60
  warning_str = '<p style="font-family:Helvetica; font-weight: bold; color: Red; font-size: 17px;">Please Upload a Video first!!!</p>'
61
 
 
62
  warn = st.empty()
63
 
64
-
65
  download_button = st.empty()
66
 
 
67
  if up_file and uploaded:
68
-
69
  download_button.empty()
 
 
70
  tfile = tempfile.NamedTemporaryFile(delete=False)
71
-
72
  try:
73
  warn.empty()
 
74
  tfile.write(up_file.read())
75
-
 
76
  vf = cv2.VideoCapture(tfile.name)
77
-
78
- # --------------------- Write the processed video frame. --------------------
79
  fps = int(vf.get(cv2.CAP_PROP_FPS))
80
  width = int(vf.get(cv2.CAP_PROP_FRAME_WIDTH))
81
  height = int(vf.get(cv2.CAP_PROP_FRAME_HEIGHT))
82
  frame_size = (width, height)
 
 
83
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
 
 
84
  video_output = cv2.VideoWriter(output_video_file, fourcc, fps, frame_size)
85
- # -----------------------------------------------------------------------------
86
-
87
 
88
- txt = st.sidebar.markdown(ip_vid_str, unsafe_allow_html=True)
89
- ip_video = st.sidebar.video(tfile.name)
90
-
 
 
91
  while vf.isOpened():
92
  ret, frame = vf.read()
93
  if not ret:
94
  break
95
-
96
- # convert frame from BGR to RGB before processing it.
97
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
 
 
 
98
  out_frame, _ = upload_process_frame.process(frame, pose)
99
  stframe.image(out_frame)
100
- video_output.write(out_frame[...,::-1])
101
-
102
 
 
103
  vf.release()
104
  video_output.release()
 
 
105
  stframe.empty()
106
  ip_video.empty()
107
  txt.empty()
108
  tfile.close()
109
 
110
  except AttributeError:
111
- warn.markdown(warning_str, unsafe_allow_html=True)
112
-
113
-
114
 
 
115
  if os.path.exists(output_video_file):
116
  with open(output_video_file, 'rb') as op_vid:
117
- download = download_button.download_button('Download Video', data = op_vid, file_name='output_recorded.mp4')
118
-
 
119
  if download:
120
  st.session_state['download'] = True
121
 
122
-
123
-
124
  if os.path.exists(output_video_file) and st.session_state['download']:
125
  os.remove(output_video_file)
126
  st.session_state['download'] = False
127
- download_button.empty()
 
5
  import cv2
6
  import tempfile
7
 
8
+ # Define the base directory path
9
  BASE_DIR = os.path.abspath(os.path.join(__file__, '../../'))
10
  sys.path.append(BASE_DIR)
11
 
12
+ # Import functions from other python files
13
  from utils import get_mediapipe_pose
14
  from process_frame import ProcessFrame
15
  from thresholds import get_thresholds_beginner, get_thresholds_pro
16
 
17
+ st.title('Mediapipe Exercise Sample')
18
 
19
+ # Create a button to select mode (Beginner or Pro)
 
 
20
  mode = st.radio('Select Mode', ['Beginner', 'Pro'], horizontal=True)
21
 
22
+ # Initialize thresholds based on the selected mode
23
+ thresholds = None
 
 
24
  if mode == 'Beginner':
25
  thresholds = get_thresholds_beginner()
26
 
27
  elif mode == 'Pro':
28
  thresholds = get_thresholds_pro()
29
 
30
+ # Create a ProcessFrame object with the loaded thresholds
 
31
  upload_process_frame = ProcessFrame(thresholds=thresholds)
32
 
33
+ # Initialize Mediapipe pose solution
34
  pose = get_mediapipe_pose()
35
 
36
+ # Initialize download flag
37
  download = None
38
 
39
+ # Set initial state for download in Streamlit session
40
  if 'download' not in st.session_state:
41
  st.session_state['download'] = False
42
 
43
+ # Define the output video file name
44
  output_video_file = f'output_recorded.mp4'
45
 
46
+ # Remove the output video file if it exists
47
  if os.path.exists(output_video_file):
48
  os.remove(output_video_file)
49
 
50
+ # Create a form for uploading a video
51
  with st.form('Upload', clear_on_submit=True):
52
  up_file = st.file_uploader("Upload a Video", ['mp4','mov', 'avi'])
53
  uploaded = st.form_submit_button("Upload")
54
 
55
+ # Create an empty element to display the video frame
56
  stframe = st.empty()
57
 
58
+ # Define HTML strings for displaying input video and warning message
59
  ip_vid_str = '<p style="font-family:Helvetica; font-weight: bold; font-size: 16px;">Input Video</p>'
60
  warning_str = '<p style="font-family:Helvetica; font-weight: bold; color: Red; font-size: 17px;">Please Upload a Video first!!!</p>'
61
 
62
+ # Create an empty element to display the warning message
63
  warn = st.empty()
64
 
65
+ # Create an empty element for the download button
66
  download_button = st.empty()
67
 
68
+ # Process uploaded video if a video is uploaded and submit button is clicked
69
  if up_file and uploaded:
70
+ # Clear previous download button and warning message
71
  download_button.empty()
72
+
73
+ # Create a temporary file to store the uploaded video
74
  tfile = tempfile.NamedTemporaryFile(delete=False)
75
+
76
  try:
77
  warn.empty()
78
+ # Write the uploaded video content to the temporary file
79
  tfile.write(up_file.read())
80
+
81
+ # Open the temporary file using OpenCV VideoCapture
82
  vf = cv2.VideoCapture(tfile.name)
83
+
84
+ # Get video properties (FPS, width, height)
85
  fps = int(vf.get(cv2.CAP_PROP_FPS))
86
  width = int(vf.get(cv2.CAP_PROP_FRAME_WIDTH))
87
  height = int(vf.get(cv2.CAP_PROP_FRAME_HEIGHT))
88
  frame_size = (width, height)
89
+
90
+ # Define video writer fourcc code for mp4 format
91
  fourcc = cv2.VideoWriter_fourcc(*'mp4v')
92
+
93
+ # Create video writer object for the output video
94
  video_output = cv2.VideoWriter(output_video_file, fourcc, fps, frame_size)
 
 
95
 
96
+ # Display the uploaded video on the sidebar
97
+ txt = st.sidebar.markdown(ip_vid_str, unsafe_allow_html=True)
98
+ ip_video = st.sidebar.video(tfile.name)
99
+
100
+ # Process each frame of the video
101
  while vf.isOpened():
102
  ret, frame = vf.read()
103
  if not ret:
104
  break
105
+
106
+ # Convert frame from BGR to RGB for Mediapipe processing
107
  frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
108
+
109
+ # Process the frame using the ProcessFrame object and
110
+ # (continued from previous code block)
111
  out_frame, _ = upload_process_frame.process(frame, pose)
112
  stframe.image(out_frame)
113
+ video_output.write(out_frame[...,::-1]) # Write the processed frame to output video
 
114
 
115
+ # Release video capture and writer resources
116
  vf.release()
117
  video_output.release()
118
+
119
+ # Clear elements that displayed video and temporary file
120
  stframe.empty()
121
  ip_video.empty()
122
  txt.empty()
123
  tfile.close()
124
 
125
  except AttributeError:
126
+ # Handle errors during processing (e.g., invalid video format)
127
+ warn.markdown(warning_str, unsafe_allow_html=True)
 
128
 
129
+ # Check if output video exists and offer download button
130
  if os.path.exists(output_video_file):
131
  with open(output_video_file, 'rb') as op_vid:
132
+ download = download_button.download_button('Download Video', data=op_vid, file_name='output_recorded.mp4')
133
+
134
+ # If video is downloaded, update download flag in Streamlit session state
135
  if download:
136
  st.session_state['download'] = True
137
 
138
+ # Remove output video and reset download flag if video downloaded
 
139
  if os.path.exists(output_video_file) and st.session_state['download']:
140
  os.remove(output_video_file)
141
  st.session_state['download'] = False
142
+ download_button.empty()