randomshit11 committed
Commit a5ac526 · verified · 1 Parent(s): 7b618e5

Create main.py

Files changed (1)
  1. main.py +89 -0
main.py ADDED
@@ -0,0 +1,89 @@
+import os
+
+import streamlit as st
+import cv2
+import mediapipe as mp
+import numpy as np
+from tensorflow.keras.models import Model
+from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
+                                     Bidirectional, Permute, multiply)
+
+# Load the pose estimation model from MediaPipe
+mp_pose = mp.solutions.pose
+mp_drawing = mp.solutions.drawing_utils
+pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
+
+# Define the attention block for the LSTM model: a softmax weight is learned
+# per time step and used to rescale the LSTM outputs
+def attention_block(inputs, time_steps):
+    a = Permute((2, 1))(inputs)
+    a = Dense(time_steps, activation='softmax')(a)
+    a_probs = Permute((2, 1), name='attention_vec')(a)
+    output_attention_mul = multiply([inputs, a_probs], name='attention_mul')
+    return output_attention_mul
+
+# Build and load the LSTM model. Note: st.cache is deprecated in newer
+# Streamlit releases, where st.cache_resource is its replacement.
+@st.cache(allow_output_mutation=True)
+def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3):
+    # 33 MediaPipe pose landmarks x 4 values (x, y, z, visibility) per frame
+    inputs = Input(shape=(sequence_length, num_input_values))
+    lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
+    attention_mul = attention_block(lstm_out, sequence_length)
+    attention_mul = Flatten()(attention_mul)
+    x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)
+    x = Dropout(0.5)(x)
+    x = Dense(num_classes, activation='softmax')(x)
+    model = Model(inputs=[inputs], outputs=x)
+    # The pretrained weights must exist at this path for the app to start
+    load_dir = "./models/LSTM_Attention.h5"
+    model.load_weights(load_dir)
+    return model
+
+# Define the VideoProcessor class for processing uploaded videos
+class VideoProcessor:
+    def __init__(self):
+        self.actions = np.array(['curl', 'press', 'squat'])
+        self.sequence_length = 30
+        self.colors = [(245, 117, 16), (117, 245, 16), (16, 117, 245)]
+        self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
+        # The classifier is loaded here but not yet used in process_video()
+        self.model = build_model()
+
+    def process_video(self, video_file):
+        # Write the uploaded file's bytes to a temporary file on disk so that
+        # cv2.VideoCapture() can open it by filename
+        filename = video_file.name
+        with open(filename, 'wb') as temp_file:
+            temp_file.write(video_file.read())
+        cap = cv2.VideoCapture(filename)
+        out_frames = []
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+            # MediaPipe expects RGB input; OpenCV decodes frames as BGR
+            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            results = self.pose.process(frame_rgb)
+            frame = self.draw_landmarks(frame, results)
+            out_frames.append(frame)
+        cap.release()
+        # Remove the temporary file
+        os.remove(filename)
+        return out_frames
+
+    def draw_landmarks(self, image, results):
+        # draw_landmarks is a no-op when no pose was detected in the frame
+        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
+                                  mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),
+                                  mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2))
+        return image
+
+# Define the Streamlit app
+def main():
+    st.title("Real-time Exercise Detection")
+    video_file = st.file_uploader("Upload a video file", type=["mp4", "avi"])
+    if video_file is not None:
+        st.video(video_file)
+        video_processor = VideoProcessor()
+        frames = video_processor.process_video(video_file)
+        for frame in frames:
+            st.image(frame, channels="BGR")
+
+if __name__ == "__main__":
+    main()
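
As committed, build_model() loads the classifier but process_video() never calls it, so the app only overlays pose landmarks. The sketch below shows one way the landmark sequence could be fed to the model; extract_keypoints and the buffering logic are illustrative assumptions, not part of this commit.

import numpy as np

def extract_keypoints(results):
    # Illustrative helper (assumption, not in this commit): flatten the 33
    # detected landmarks into a 33*4 vector (x, y, z, visibility), matching
    # num_input_values=33*4 in build_model()
    if results.pose_landmarks:
        return np.array([[lm.x, lm.y, lm.z, lm.visibility]
                         for lm in results.pose_landmarks.landmark]).flatten()
    return np.zeros(33 * 4)

# Inside the frame loop of process_video(), keypoints could be buffered and
# classified once sequence_length frames have accumulated (assumption):
#     sequence.append(extract_keypoints(results))
#     if len(sequence) == self.sequence_length:
#         probs = self.model.predict(np.expand_dims(sequence, axis=0))[0]
#         action = self.actions[np.argmax(probs)]
#         sequence = []

With the weights present at ./models/LSTM_Attention.h5, the app itself is launched in the usual Streamlit way: streamlit run main.py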