rohitashva commited on
Commit
d0f6bea
·
verified ·
1 Parent(s): 97fdf95

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -30
app.py CHANGED
@@ -1,11 +1,13 @@
1
- import streamlit as st
2
- import av
3
  import joblib
4
  import mediapipe as mp
5
  import numpy as np
6
- from streamlit_webrtc import webrtc_streamer, VideoTransformerBase
7
 
8
- # Load trained model and label encoder
 
 
9
  model = joblib.load("pose_classifier.joblib")
10
  label_encoder = joblib.load("label_encoder.joblib")
11
 
@@ -13,38 +15,47 @@ label_encoder = joblib.load("label_encoder.joblib")
13
  mp_pose = mp.solutions.pose
14
  pose = mp_pose.Pose()
15
 
16
- # Streamlit UI
17
- st.title("Live Pose Classification on Hugging Face Spaces")
18
- st.write("Using Streamlit WebRTC, OpenCV, and MediaPipe.")
 
 
 
 
 
 
 
19
 
20
- class PoseClassification(VideoTransformerBase):
21
- def transform(self, frame):
22
- img = frame.to_ndarray(format="bgr24")
23
 
24
- # Convert frame to RGB
25
- img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
 
 
26
 
27
- # Process frame with MediaPipe Pose
28
- results = pose.process(img_rgb)
 
 
 
 
29
 
30
- if results.pose_landmarks:
31
- landmarks = results.pose_landmarks.landmark
32
- pose_data = [j.x for j in landmarks] + [j.y for j in landmarks] + \
33
- [j.z for j in landmarks] + [j.visibility for j in landmarks]
34
 
35
- pose_data = np.array(pose_data).reshape(1, -1)
 
36
 
37
- try:
38
- y_pred = model.predict(pose_data)
39
- predicted_label = label_encoder.inverse_transform(y_pred)[0]
40
 
41
- # Draw label on frame
42
- cv2.putText(img, f"Pose: {predicted_label}", (20, 50),
43
- cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 3)
44
- except Exception as e:
45
- st.warning(f"⚠️ Prediction Error: {e}")
46
 
47
- return av.VideoFrame.from_ndarray(img, format="bgr24")
 
48
 
49
- # Start WebRTC streamer
50
- webrtc_streamer(key="pose-classification", video_transformer_factory=PoseClassification)
 
1
+ from flask import Flask, request, jsonify
2
+ import cv2
3
  import joblib
4
  import mediapipe as mp
5
  import numpy as np
6
+ import tempfile
7
 
8
+ app = Flask(__name__)
9
+
10
+ # Load model and label encoder
11
  model = joblib.load("pose_classifier.joblib")
12
  label_encoder = joblib.load("label_encoder.joblib")
13
 
 
15
  mp_pose = mp.solutions.pose
16
  pose = mp_pose.Pose()
17
 
18
+ def predict_pose_from_image(image_bytes):
19
+ # Convert image bytes to numpy array
20
+ nparr = np.frombuffer(image_bytes, np.uint8)
21
+ frame = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
22
+
23
+ if frame is None:
24
+ return None, "Invalid image"
25
+
26
+ # Convert to RGB
27
+ img_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
28
 
29
+ # Run MediaPipe Pose
30
+ results = pose.process(img_rgb)
 
31
 
32
+ if results.pose_landmarks:
33
+ landmarks = results.pose_landmarks.landmark
34
+ pose_data = [j.x for j in landmarks] + [j.y for j in landmarks] + \
35
+ [j.z for j in landmarks] + [j.visibility for j in landmarks]
36
 
37
+ pose_data = np.array(pose_data).reshape(1, -1)
38
+ y_pred = model.predict(pose_data)
39
+ predicted_label = label_encoder.inverse_transform(y_pred)[0]
40
+ return predicted_label, None
41
+ else:
42
+ return None, "No pose detected"
43
 
44
+ @app.route('/predict-pose', methods=['POST'])
45
+ def predict_pose():
46
+ if 'frame' not in request.files:
47
+ return jsonify({"error": "No image frame uploaded"}), 400
48
 
49
+ file = request.files['frame']
50
+ img_bytes = file.read()
51
 
52
+ label, error = predict_pose_from_image(img_bytes)
53
+ if error:
54
+ return jsonify({"error": error}), 400
55
 
56
+ return jsonify({"predicted_pose": label})
 
 
 
 
57
 
58
+ if __name__ == "__main__":
59
+ app.run(debug=True, port=5007)
60
 
61
+ # curl -X POST -F "frame=@your_image.jpg" http://localhost:5007/predict-pose