ducdatit2002 commited on
Commit
c20c60b
·
verified ·
1 Parent(s): 83999b1

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +122 -0
app.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import mediapipe as mp
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from tensorflow.keras.layers import LSTM
6
+ import streamlit as st
7
+
8
+ labels = np.array(['FALL', 'LYING', 'SIT', 'STAND', 'MOVE'])
9
+
10
+ n_time_steps = 25
11
+ mpPose = mp.solutions.pose
12
+ pose = mpPose.Pose()
13
+ mpDraw = mp.solutions.drawing_utils
14
+
15
+ def custom_lstm(*args, **kwargs):
16
+ kwargs.pop('time_major', None)
17
+ return LSTM(*args, **kwargs)
18
+
19
+ model = tf.keras.models.load_model('bro.h5', custom_objects={'LSTM': custom_lstm})
20
+
21
+ def make_landmark_timestep(results):
22
+ c_lm = []
23
+ for id, lm in enumerate(results.pose_landmarks.landmark):
24
+ c_lm.append(lm.x)
25
+ c_lm.append(lm.y)
26
+ c_lm.append(lm.z)
27
+ c_lm.append(lm.visibility)
28
+ return c_lm
29
+
30
+ def draw_landmark_on_image(mpDraw, results, img, label):
31
+ mpDraw.draw_landmarks(img, results.pose_landmarks, mpPose.POSE_CONNECTIONS)
32
+ for id, lm in enumerate(results.pose_landmarks.landmark):
33
+ h, w, c = img.shape
34
+ cx, cy = int(lm.x * w), int(lm.y * h)
35
+ if label != "FALL":
36
+ cv2.circle(img, (cx, cy), 5, (0, 255, 0), cv2.FILLED)
37
+ else:
38
+ cv2.circle(img, (cx, cy), 5, (0, 0, 255), cv2.FILLED)
39
+ return img
40
+
41
+ def draw_class_on_image(label, img):
42
+ font = cv2.FONT_HERSHEY_SIMPLEX
43
+ bottomLeftCornerOfText = (10, 30)
44
+ fontScale = 1
45
+ fontColor = (0, 255, 0)
46
+ thickness = 2
47
+ lineType = 2
48
+ cv2.putText(img, label,
49
+ bottomLeftCornerOfText,
50
+ font,
51
+ fontScale,
52
+ fontColor,
53
+ thickness,
54
+ lineType)
55
+ return img
56
+
57
+ def detect(model, lm_list):
58
+ lm_list = np.array(lm_list)
59
+ lm_list = np.expand_dims(lm_list, axis=0)
60
+ results = model.predict(lm_list)
61
+ if results[0][0] >= 0.5:
62
+ label = labels[0]
63
+ elif results[0][1] >= 0.5:
64
+ label = labels[1]
65
+ elif results[0][2] >= 0.5:
66
+ label = labels[2]
67
+ elif results[0][3] >= 0.5:
68
+ label = labels[3]
69
+ elif results[0][4] >= 0.5:
70
+ label = labels[4]
71
+ else:
72
+ label = "NONE DETECTION"
73
+ return label
74
+
75
+ def main():
76
+ st.title("Pose Detection and Classification")
77
+
78
+ run_type = st.sidebar.selectbox("Select input type", ("Camera", "Video File"))
79
+
80
+ if run_type == "Camera":
81
+ cap = cv2.VideoCapture(0)
82
+ else:
83
+ video_file = st.sidebar.file_uploader("Upload a video", type=["mp4", "mov", "avi"])
84
+ if video_file is not None:
85
+ # Temporarily save the uploaded video to disk to pass to cv2.VideoCapture
86
+ with open("temp_video.mp4", "wb") as f:
87
+ f.write(video_file.read())
88
+ cap = cv2.VideoCapture("temp_video.mp4")
89
+ else:
90
+ st.write("Please upload a video file.")
91
+ return
92
+
93
+ stframe = st.empty()
94
+ label = 'Starting...'
95
+ lm_list = []
96
+
97
+ while cap.isOpened():
98
+ success, img = cap.read()
99
+ if not success:
100
+ break
101
+
102
+ imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
+ results = pose.process(imgRGB)
104
+
105
+ if results.pose_landmarks:
106
+ c_lm = make_landmark_timestep(results)
107
+ img = draw_landmark_on_image(mpDraw, results, img, label)
108
+ img = draw_class_on_image(label, img)
109
+ lm_list.append(c_lm)
110
+ if len(lm_list) == n_time_steps:
111
+ label = detect(model, lm_list)
112
+ lm_list = []
113
+
114
+ stframe.image(img, channels="BGR")
115
+
116
+ if cv2.waitKey(1) == ord('q'):
117
+ break
118
+
119
+ cap.release()
120
+
121
+ if __name__ == '__main__':
122
+ main()