iDrops commited on
Commit
c70fcbd
·
verified ·
1 Parent(s): b38d09c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -31
app.py CHANGED
@@ -5,27 +5,32 @@ import streamlit as st
5
  import cv2
6
  import tempfile
7
 
8
- dir = os.path.abspath(os.path.join(__file__, '../../'))
9
- sys.path.append(dir)
 
 
10
 
11
  from utils import get_mediapipe_pose
12
  from process_frame import ProcessFrame
13
  from thresholds import get_thresholds_beginner, get_thresholds_pro
14
 
15
- st.title('Exercise tracking Demo')
16
 
17
- diffuculty = st.radio('Select Mode', ['Beginners Squat', 'Pro Squat', 'Push up'], horizontal=True)
 
 
 
 
 
18
 
19
  thresholds = None
20
 
21
- if diffuculty == 'Beginners':
22
  thresholds = get_thresholds_beginner()
23
 
24
- elif diffuculty == 'Professionals':
25
  thresholds = get_thresholds_pro()
26
 
27
- #elif difficulty == 'Push up':
28
- #thresholds = get_pushups()
29
 
30
  upload_process_frame = ProcessFrame(thresholds=thresholds)
31
 
@@ -39,29 +44,29 @@ if 'download' not in st.session_state:
39
  st.session_state['download'] = False
40
 
41
 
42
- Video_Output = f'output_recorded.mp4'
43
 
44
- if os.path.exists(Video_Output):
45
- os.remove(Video_Output)
46
 
47
 
48
  with st.form('Upload', clear_on_submit=True):
49
  up_file = st.file_uploader("Upload a Video", ['mp4','mov', 'avi'])
50
- uploaded_File = st.form_submit_button("Upload")
51
 
52
  stframe = st.empty()
53
 
54
  ip_vid_str = '<p style="font-family:Helvetica; font-weight: bold; font-size: 16px;">Input Video</p>'
55
- warn_str = '<p style="font-family:Helvetica; font-weight: bold; color: Red; font-size: 17px;">Please Upload a Video first!!!</p>'
56
 
57
  warn = st.empty()
58
 
59
 
60
- download_btn = st.empty()
61
 
62
- if up_file and uploaded_File:
63
 
64
- download_btn.empty()
65
  tfile = tempfile.NamedTemporaryFile(delete=False)
66
 
67
  try:
@@ -71,16 +76,17 @@ if up_file and uploaded_File:
71
  vf = cv2.VideoCapture(tfile.name)
72
 
73
  # --------------------- Write the processed video frame. --------------------
74
- Frames_per_sec = int(vf.get(cv2.cap_FPS))
75
- width = int(vf.get(cv2.cap_FrameWidth))
76
- height = int(vf.get(cv2.cap_FrameHeight))
77
- Frame_Size = (width, height)
78
- four_cc = cv2.VideoWriter_four_cc(*'mp4v')
79
- video_output = cv2.VideoWriter(Video_Output, four_cc, Frames_per_sec, Frame_Size)
80
  # -----------------------------------------------------------------------------
 
81
 
82
  txt = st.sidebar.markdown(ip_vid_str, unsafe_allow_html=True)
83
- ip_vid = st.sidebar.video(tfile.name)
84
 
85
  while vf.isOpened():
86
  ret, frame = vf.read()
@@ -92,25 +98,30 @@ if up_file and uploaded_File:
92
  out_frame, _ = upload_process_frame.process(frame, pose)
93
  stframe.image(out_frame)
94
  video_output.write(out_frame[...,::-1])
 
95
 
96
  vf.release()
97
  video_output.release()
98
  stframe.empty()
99
- ip_vid.empty()
100
  txt.empty()
101
  tfile.close()
102
 
103
  except AttributeError:
104
- warn.markdown(warn_str, unsafe_allow_html=True)
 
105
 
106
- if os.path.exists(Video_Output):
107
- with open(Video_Output, 'rb') as op_vid:
108
- download = download_btn.download_btn('Download Video', data = op_vid, file_name='output.mp4')
 
109
 
110
  if download:
111
  st.session_state['download'] = True
112
 
113
- if os.path.exists(Video_Output) and st.session_state['download']:
114
- os.remove(Video_Output)
 
 
115
  st.session_state['download'] = False
116
- download_btn.empty()
 
5
  import cv2
6
  import tempfile
7
 
8
+
9
+ BASE_DIR = os.path.abspath(os.path.join(__file__, '../../'))
10
+ sys.path.append(BASE_DIR)
11
+
12
 
13
  from utils import get_mediapipe_pose
14
  from process_frame import ProcessFrame
15
  from thresholds import get_thresholds_beginner, get_thresholds_pro
16
 
 
17
 
18
+
19
+ st.title('AI Fitness Trainer: Squats Analysis')
20
+
21
+ mode = st.radio('Select Mode', ['Beginner', 'Pro'], horizontal=True)
22
+
23
+
24
 
25
  thresholds = None
26
 
27
+ if mode == 'Beginner':
28
  thresholds = get_thresholds_beginner()
29
 
30
+ elif mode == 'Pro':
31
  thresholds = get_thresholds_pro()
32
 
33
+
 
34
 
35
  upload_process_frame = ProcessFrame(thresholds=thresholds)
36
 
 
44
  st.session_state['download'] = False
45
 
46
 
47
+ output_video_file = f'output_recorded.mp4'
48
 
49
+ if os.path.exists(output_video_file):
50
+ os.remove(output_video_file)
51
 
52
 
53
  with st.form('Upload', clear_on_submit=True):
54
  up_file = st.file_uploader("Upload a Video", ['mp4','mov', 'avi'])
55
+ uploaded = st.form_submit_button("Upload")
56
 
57
  stframe = st.empty()
58
 
59
  ip_vid_str = '<p style="font-family:Helvetica; font-weight: bold; font-size: 16px;">Input Video</p>'
60
+ warning_str = '<p style="font-family:Helvetica; font-weight: bold; color: Red; font-size: 17px;">Please Upload a Video first!!!</p>'
61
 
62
  warn = st.empty()
63
 
64
 
65
+ download_button = st.empty()
66
 
67
+ if up_file and uploaded:
68
 
69
+ download_button.empty()
70
  tfile = tempfile.NamedTemporaryFile(delete=False)
71
 
72
  try:
 
76
  vf = cv2.VideoCapture(tfile.name)
77
 
78
  # --------------------- Write the processed video frame. --------------------
79
+ fps = int(vf.get(cv2.CAP_PROP_FPS))
80
+ width = int(vf.get(cv2.CAP_PROP_FRAME_WIDTH))
81
+ height = int(vf.get(cv2.CAP_PROP_FRAME_HEIGHT))
82
+ frame_size = (width, height)
83
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
84
+ video_output = cv2.VideoWriter(output_video_file, fourcc, fps, frame_size)
85
  # -----------------------------------------------------------------------------
86
+
87
 
88
  txt = st.sidebar.markdown(ip_vid_str, unsafe_allow_html=True)
89
+ ip_video = st.sidebar.video(tfile.name)
90
 
91
  while vf.isOpened():
92
  ret, frame = vf.read()
 
98
  out_frame, _ = upload_process_frame.process(frame, pose)
99
  stframe.image(out_frame)
100
  video_output.write(out_frame[...,::-1])
101
+
102
 
103
  vf.release()
104
  video_output.release()
105
  stframe.empty()
106
+ ip_video.empty()
107
  txt.empty()
108
  tfile.close()
109
 
110
  except AttributeError:
111
+ warn.markdown(warning_str, unsafe_allow_html=True)
112
+
113
 
114
+
115
+ if os.path.exists(output_video_file):
116
+ with open(output_video_file, 'rb') as op_vid:
117
+ download = download_button.download_button('Download Video', data = op_vid, file_name='output_recorded.mp4')
118
 
119
  if download:
120
  st.session_state['download'] = True
121
 
122
+
123
+
124
+ if os.path.exists(output_video_file) and st.session_state['download']:
125
+ os.remove(output_video_file)
126
  st.session_state['download'] = False
127
+ download_button.empty()