rohitashva committed · verified
Commit 662b753 · 1 Parent(s): d5b1e39

Update app.py

Files changed (1):
  1. app.py +27 -37
app.py CHANGED
@@ -1,8 +1,9 @@
import streamlit as st
- import cv2
+ import av
import joblib
import mediapipe as mp
import numpy as np
+ from streamlit_webrtc import webrtc_streamer, VideoTransformerBase

# Load trained model and label encoder
model = joblib.load("pose_classifier.joblib")
@@ -13,48 +14,37 @@ mp_pose = mp.solutions.pose
pose = mp_pose.Pose()

# Streamlit UI
- st.title("Live Pose Classification")
- st.write("Real-time pose detection using OpenCV and MediaPipe.")
+ st.title("Live Pose Classification on Hugging Face Spaces")
+ st.write("Using Streamlit WebRTC, OpenCV, and MediaPipe.")

- # OpenCV Video Capture
- cap = cv2.VideoCapture(0)
+ class PoseClassification(VideoTransformerBase):
+     def transform(self, frame):
+         img = frame.to_ndarray(format="bgr24")

- # Streamlit Image Display
- frame_placeholder = st.empty()
+         # Convert frame to RGB
+         img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

- while cap.isOpened():
-     ret, frame = cap.read()
-     if not ret:
-         st.warning("Failed to capture video. Check your camera.")
-         break
+         # Process frame with MediaPipe Pose
+         results = pose.process(img_rgb)

-     # Convert frame to RGB
-     img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         if results.pose_landmarks:
+             landmarks = results.pose_landmarks.landmark
+             pose_data = [j.x for j in landmarks] + [j.y for j in landmarks] + \
+                         [j.z for j in landmarks] + [j.visibility for j in landmarks]

-     # Process frame with MediaPipe Pose
-     results = pose.process(img_rgb)
+             pose_data = np.array(pose_data).reshape(1, -1)

-     if results.pose_landmarks:
-         landmarks = results.pose_landmarks.landmark
-         pose_data = [j.x for j in landmarks] + [j.y for j in landmarks] + \
-                     [j.z for j in landmarks] + [j.visibility for j in landmarks]
+             try:
+                 y_pred = model.predict(pose_data)
+                 predicted_label = label_encoder.inverse_transform(y_pred)[0]

-         pose_data = np.array(pose_data).reshape(1, -1)
+                 # Draw label on frame
+                 cv2.putText(img, f"Pose: {predicted_label}", (20, 50),
+                             cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)
+             except Exception as e:
+                 st.warning(f"⚠️ Prediction Error: {e}")

-         # Predict pose
-         y_pred = model.predict(pose_data)
-         predicted_label = label_encoder.inverse_transform(y_pred)[0]
+         return av.VideoFrame.from_ndarray(img, format="bgr24")

-         # Display predicted label
-         cv2.putText(frame, f"Pose: {predicted_label}", (20, 50),
-                     cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)
-
-     # Display frame in Streamlit
-     frame_placeholder.image(frame, channels="BGR")
-
-     # Break loop if user stops execution
-     if st.button("Stop Camera"):
-         break
-
- cap.release()
- cv2.destroyAllWindows()
+ # Start WebRTC streamer
+ webrtc_streamer(key="pose-classification", video_transformer_factory=PoseClassification)
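
One thing worth flagging about the updated app.py: the commit removes import cv2, but the new PoseClassification.transform still calls cv2.cvtColor and cv2.putText, so the rewritten app would raise a NameError on the first processed frame. Below is a minimal sketch (not part of commit 662b753) of the import block the new code appears to need, keeping cv2 alongside the WebRTC additions; on a Hugging Face Space the extra packages (streamlit-webrtc, av, and an OpenCV build such as opencv-python-headless) would presumably also need to be listed in the Space's requirements.txt.

# Sketch only (assumption, not in commit 662b753): keep cv2 imported, since
# the new transform() still relies on it for color conversion and drawing.
import streamlit as st
import av
import cv2  # still needed by cv2.cvtColor and cv2.putText in transform()
import joblib
import mediapipe as mp
import numpy as np
from streamlit_webrtc import webrtc_streamer, VideoTransformerBase

Two smaller caveats: st.warning inside transform() runs on the WebRTC worker thread rather than the Streamlit script thread, so the message may never show up in the UI (logging the exception or drawing it onto the frame is a more reliable pattern). Also, recent streamlit-webrtc releases favor VideoProcessorBase with a recv() method passed via video_processor_factory=, keeping the VideoTransformerBase / transform / video_transformer_factory names used here only as deprecated aliases.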