# frrf / app.py
# (Repository page header captured together with the file; not part of the program.)
# randomshit11's picture
# Update app.py
# 7d291aa verified
# raw
# history blame
# 4.84 kB
import os
import streamlit as st
import cv2
import mediapipe as mp
import numpy as np
import math
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (LSTM, Dense, Dropout, Input, Flatten,
Bidirectional, Permute, multiply)
# Load the pose estimation solution and the drawing helpers from MediaPipe.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils
# Module-level pose estimator built with 0.5 detection/tracking confidence.
# NOTE(review): nothing else in this file appears to read `pose` (VideoProcessor
# builds its own instance in __init__) — confirm before removing.
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
# Define the attention block for the LSTM model
def attention_block(inputs, time_steps):
    """Weight ``inputs`` with a learned softmax attention mask.

    The input is permuted with ``Permute((2, 1))``, passed through a
    ``Dense(time_steps, activation='softmax')`` layer, permuted back, and
    multiplied element-wise with the original input.

    Args:
        inputs: Keras tensor to attend over. Presumably shaped
            (batch, time_steps, features) — the only caller in this file
            passes the sequence output of a bidirectional LSTM; TODO confirm.
        time_steps: Number of units of the softmax ``Dense`` layer.

    Returns:
        The element-wise product of ``inputs`` and the attention weights
        (layer name ``attention_mul``).
    """
    # Swap the last two axes so the Dense layer acts along the other axis.
    transposed = Permute((2, 1))(inputs)
    scores = Dense(time_steps, activation='softmax')(transposed)
    # Swap back so the weights line up with `inputs` again.
    attention_weights = Permute((2, 1), name='attention_vec')(scores)
    return multiply([inputs, attention_weights], name='attention_mul')
# Build and load the LSTM model
# NOTE(review): newer Streamlit releases deprecate `st.cache` in favour of
# `st.cache_resource` for objects such as models — confirm the pinned Streamlit
# version before switching the decorator.
@st.cache(allow_output_mutation=True)
def build_model(HIDDEN_UNITS=256, sequence_length=30, num_input_values=33*4, num_classes=3,
                weights_path="./models/LSTM_Attention.h5"):
    """Build the bidirectional-LSTM + attention classifier and load its weights.

    Architecture, in order: ``Input`` -> ``Bidirectional(LSTM)`` returning
    sequences -> ``attention_block`` -> ``Flatten`` -> ``Dense`` (ReLU) ->
    ``Dropout(0.5)`` -> ``Dense`` (softmax).

    Args:
        HIDDEN_UNITS: Units of the LSTM layer; the first dense layer uses
            twice this many units.
        sequence_length: Length of the first (time) dimension of the input.
        num_input_values: Length of the second (feature) dimension of the
            input; defaults to 33*4.
        num_classes: Number of units of the softmax output layer.
        weights_path: Path of the weights file passed to
            ``model.load_weights``. Defaults to the location that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        The Keras ``Model`` with the weights from ``weights_path`` loaded.
    """
    inputs = Input(shape=(sequence_length, num_input_values))
    lstm_out = Bidirectional(LSTM(HIDDEN_UNITS, return_sequences=True))(inputs)
    attention_mul = attention_block(lstm_out, sequence_length)
    attention_mul = Flatten()(attention_mul)
    x = Dense(2*HIDDEN_UNITS, activation='relu')(attention_mul)
    x = Dropout(0.5)(x)
    x = Dense(num_classes, activation='softmax')(x)
    model = Model(inputs=[inputs], outputs=x)
    # The layer stack above must match the architecture the weights were saved
    # from, otherwise load_weights fails.
    model.load_weights(weights_path)
    return model
# Define the VideoProcessor class for real-time video processing
class VideoProcessor:
    """Run MediaPipe pose estimation over an uploaded video and draw the result.

    The instance owns its own ``Pose`` estimator and the classifier returned
    by ``build_model()``. The helper methods are plain (uncached) methods:
    Streamlit's ``st.cache`` memoises on hashed arguments, which would include
    ``self`` (holding the model and the MediaPipe graph) and would skip the
    in-place drawing side effect of ``viz_joint_angle`` on a cache hit.
    """

    def __init__(self):
        # Class labels, window length and per-class colours. The methods of
        # this class do not read them; they are kept as public attributes.
        self.actions = np.array(['curl', 'press', 'squat'])
        self.sequence_length = 30
        self.colors = [(245,117,16), (117,245,16), (16,117,245)]
        # get_coordinates() resolves landmark names through self.mp_pose, so
        # the MediaPipe pose module has to be stored on the instance.
        self.mp_pose = mp_pose
        self.pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)
        self.model = build_model()

    def process_video(self, video_file):
        """Return every frame of ``video_file`` with pose landmarks drawn on it.

        Args:
            video_file: Readable binary file object (e.g. a Streamlit upload).
                Its ``name`` attribute, when present, is used only to pick
                the file extension of the temporary copy.

        Returns:
            list: The annotated frames, in order, as returned by
            ``cv2.VideoCapture.read`` (BGR) after drawing.
        """
        import tempfile

        # cv2.VideoCapture needs a path on disk. Never reuse the uploaded
        # file's name as that path: it is client-controlled and would be
        # written to, and then deleted from, the working directory.
        suffix = os.path.splitext(getattr(video_file, "name", "") or "")[1]
        tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
        temp_path = tmp.name
        cap = None
        out_frames = []
        try:
            with tmp:
                tmp.write(video_file.read())
            cap = cv2.VideoCapture(temp_path)
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                # MediaPipe expects RGB; OpenCV decodes to BGR.
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                results = self.pose.process(frame_rgb)
                out_frames.append(self.draw_landmarks(frame, results))
        finally:
            # Release the capture and remove the temporary copy even when
            # decoding or pose estimation raises.
            if cap is not None:
                cap.release()
            os.remove(temp_path)
        return out_frames

    def draw_landmarks(self, image, results):
        """Draw the detected pose landmarks and connections onto ``image``.

        Args:
            image: Frame to draw on; it is modified in place.
            results: Object returned by ``Pose.process``; its
                ``pose_landmarks`` attribute is passed to the drawing helper.

        Returns:
            The same ``image`` object.
        """
        mp_drawing.draw_landmarks(image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS,
                                  mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
                                  mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2))
        return image

    def extract_keypoints(self, results):
        """Flatten the pose landmarks of ``results`` into a 1-D array.

        Args:
            results: Object with a ``pose_landmarks`` attribute.

        Returns:
            numpy.ndarray: ``[x, y, z, visibility]`` of every landmark,
            concatenated; ``np.zeros(33*4)`` when no pose was detected.
        """
        if not results.pose_landmarks:
            return np.zeros(33*4)
        return np.array([[lm.x, lm.y, lm.z, lm.visibility]
                         for lm in results.pose_landmarks.landmark]).flatten()

    def calculate_angle(self, a, b, c):
        """Return the angle at ``b`` formed by the points ``a``-``b``-``c``.

        Only the first two components of each point are used.

        Args:
            a: First point.
            b: Middle point (the vertex).
            c: End point.

        Returns:
            The angle in degrees, folded into the range [0, 180].
        """
        a = np.array(a)  # First
        b = np.array(b)  # Mid
        c = np.array(c)  # End
        radians = np.arctan2(c[1]-b[1], c[0]-b[0]) - np.arctan2(a[1]-b[1], a[0]-b[0])
        angle = np.abs(radians*180.0/np.pi)
        # Report the smaller of the two angles between the segments.
        if angle > 180.0:
            angle = 360-angle
        return angle

    def get_coordinates(self, landmarks, side, joint):
        """Return ``[x, y]`` of the landmark named ``<SIDE>_<JOINT>``.

        Args:
            landmarks: Sequence of landmarks indexable by the enum value.
            side: Side prefix, e.g. ``"left"``; upper-cased for the lookup.
            joint: Joint name, e.g. ``"elbow"``; upper-cased for the lookup.

        Returns:
            list: ``[x, y]`` of the selected landmark.

        Raises:
            AttributeError: If ``PoseLandmark`` has no such member.
        """
        coord = getattr(self.mp_pose.PoseLandmark, side.upper() + "_" + joint.upper())
        landmark = landmarks[coord.value]
        return [landmark.x, landmark.y]

    def viz_joint_angle(self, image, angle, joint, frame_size=(640, 480)):
        """Write ``angle`` (truncated to an int) onto ``image`` at ``joint``.

        Args:
            image: Frame to draw on; it is modified in place.
            angle: Angle value to print.
            joint: ``[x, y]`` position that is scaled by ``frame_size`` to
                get the pixel position of the text.
            frame_size: ``(width, height)`` used for that scaling. Defaults
                to the previously hard-coded 640x480.

        Returns:
            None.
        """
        cv2.putText(image, str(int(angle)),
                    tuple(np.multiply(joint, frame_size).astype(int)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 255), 2, cv2.LINE_AA
                    )
        return
# Define Streamlit app
def main():
    """Render the Streamlit page.

    Shows a title and a video uploader. Once a file is uploaded it is played
    back, processed with ``VideoProcessor.process_video`` and every returned
    frame is displayed as a BGR image.
    """
    st.title("Real-time Exercise Detection")
    uploaded = st.file_uploader("Upload a video file", type=["mp4", "avi"])
    if uploaded is None:
        # Nothing uploaded yet: only the title and the uploader are shown.
        return

    st.video(uploaded)
    processor = VideoProcessor()
    annotated_frames = processor.process_video(uploaded)
    for annotated in annotated_frames:
        st.image(annotated, channels="BGR")


# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    main()