ducdatit2002 commited on
Commit
83999b1
·
verified ·
1 Parent(s): 12e4ea8

Delete run.py

Browse files
Files changed (1) hide show
  1. run.py +0 -122
run.py DELETED
@@ -1,122 +0,0 @@
1
- import cv2
2
- import mediapipe as mp
3
- import numpy as np
4
- import tensorflow as tf
5
- from tensorflow.keras.layers import LSTM
6
- import streamlit as st
7
-
8
- labels = np.array(['FALL', 'LYING', 'SIT', 'STAND', 'MOVE'])
9
-
10
- n_time_steps = 25
11
- mpPose = mp.solutions.pose
12
- pose = mpPose.Pose()
13
- mpDraw = mp.solutions.drawing_utils
14
-
15
- def custom_lstm(*args, **kwargs):
16
- kwargs.pop('time_major', None)
17
- return LSTM(*args, **kwargs)
18
-
19
- model = tf.keras.models.load_model('bro.h5', custom_objects={'LSTM': custom_lstm})
20
-
21
- def make_landmark_timestep(results):
22
- c_lm = []
23
- for id, lm in enumerate(results.pose_landmarks.landmark):
24
- c_lm.append(lm.x)
25
- c_lm.append(lm.y)
26
- c_lm.append(lm.z)
27
- c_lm.append(lm.visibility)
28
- return c_lm
29
-
30
- def draw_landmark_on_image(mpDraw, results, img, label):
31
- mpDraw.draw_landmarks(img, results.pose_landmarks, mpPose.POSE_CONNECTIONS)
32
- for id, lm in enumerate(results.pose_landmarks.landmark):
33
- h, w, c = img.shape
34
- cx, cy = int(lm.x * w), int(lm.y * h)
35
- if label != "FALL":
36
- cv2.circle(img, (cx, cy), 5, (0, 255, 0), cv2.FILLED)
37
- else:
38
- cv2.circle(img, (cx, cy), 5, (0, 0, 255), cv2.FILLED)
39
- return img
40
-
41
- def draw_class_on_image(label, img):
42
- font = cv2.FONT_HERSHEY_SIMPLEX
43
- bottomLeftCornerOfText = (10, 30)
44
- fontScale = 1
45
- fontColor = (0, 255, 0)
46
- thickness = 2
47
- lineType = 2
48
- cv2.putText(img, label,
49
- bottomLeftCornerOfText,
50
- font,
51
- fontScale,
52
- fontColor,
53
- thickness,
54
- lineType)
55
- return img
56
-
57
- def detect(model, lm_list):
58
- lm_list = np.array(lm_list)
59
- lm_list = np.expand_dims(lm_list, axis=0)
60
- results = model.predict(lm_list)
61
- if results[0][0] >= 0.5:
62
- label = labels[0]
63
- elif results[0][1] >= 0.5:
64
- label = labels[1]
65
- elif results[0][2] >= 0.5:
66
- label = labels[2]
67
- elif results[0][3] >= 0.5:
68
- label = labels[3]
69
- elif results[0][4] >= 0.5:
70
- label = labels[4]
71
- else:
72
- label = "NONE DETECTION"
73
- return label
74
-
75
- def main():
76
- st.title("Pose Detection and Classification")
77
-
78
- run_type = st.sidebar.selectbox("Select input type", ("Camera", "Video File"))
79
-
80
- if run_type == "Camera":
81
- cap = cv2.VideoCapture(0)
82
- else:
83
- video_file = st.sidebar.file_uploader("Upload a video", type=["mp4", "mov", "avi"])
84
- if video_file is not None:
85
- # Temporarily save the uploaded video to disk to pass to cv2.VideoCapture
86
- with open("temp_video.mp4", "wb") as f:
87
- f.write(video_file.read())
88
- cap = cv2.VideoCapture("temp_video.mp4")
89
- else:
90
- st.write("Please upload a video file.")
91
- return
92
-
93
- stframe = st.empty()
94
- label = 'Starting...'
95
- lm_list = []
96
-
97
- while cap.isOpened():
98
- success, img = cap.read()
99
- if not success:
100
- break
101
-
102
- imgRGB = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
103
- results = pose.process(imgRGB)
104
-
105
- if results.pose_landmarks:
106
- c_lm = make_landmark_timestep(results)
107
- img = draw_landmark_on_image(mpDraw, results, img, label)
108
- img = draw_class_on_image(label, img)
109
- lm_list.append(c_lm)
110
- if len(lm_list) == n_time_steps:
111
- label = detect(model, lm_list)
112
- lm_list = []
113
-
114
- stframe.image(img, channels="BGR")
115
-
116
- if cv2.waitKey(1) == ord('q'):
117
- break
118
-
119
- cap.release()
120
-
121
- if __name__ == '__main__':
122
- main()