GarimaPuri01 committed
Commit f497e19 · verified · 1 Parent(s): 3799beb

Create app.py

Files changed (1): app.py (+80, -0)
app.py ADDED
import cv2
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import streamlit as st
import numpy as np
import mediapipe as mp

# Initialize MediaPipe Pose
mp_pose = mp.solutions.pose
pose = mp_pose.Pose(min_detection_confidence=0.5, min_tracking_confidence=0.5)

# Initialize MediaPipe Drawing
mp_drawing = mp.solutions.drawing_utils

# Load the Hugging Face model and tokenizer.
# NOTE: "your-huggingface-model-name" is a placeholder; it must point to a
# checkpoint fine-tuned to classify poses from serialized landmark coordinates.
model_name = "your-huggingface-model-name"
model = AutoModelForSequenceClassification.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Yoga pose classification using the Hugging Face model
def classify_pose(landmarks):
    # Flatten the MediaPipe landmarks into x-, y-, and z-coordinate lists
    landmark_list = (
        [landmark.x for landmark in landmarks]
        + [landmark.y for landmark in landmarks]
        + [landmark.z for landmark in landmarks]
    )
    # Tokenizers expect text, not raw floats, so serialize the coordinates
    # into a single string. This assumes the model was trained on the same
    # serialization format.
    landmark_text = " ".join(f"{value:.4f}" for value in landmark_list)
    inputs = tokenizer(landmark_text, return_tensors="pt", truncation=True)

    # Get model predictions without tracking gradients
    with torch.no_grad():
        outputs = model(**inputs)
    prediction = torch.argmax(outputs.logits, dim=1)

    # Map the predicted class index to a pose name
    # (adjust this mapping according to your model's labels)
    pose_names = ["Mountain Pose", "Tree Pose", "Warrior Pose", "Unknown Pose"]
    return pose_names[prediction.item()]

def main():
    st.title("Live Yoga Pose Detection with Hugging Face")

    # Start video capture from the default webcam
    cap = cv2.VideoCapture(0)

    stframe = st.empty()

    while cap.isOpened():
        success, image = cap.read()
        if not success:
            st.warning("Ignoring empty camera frame.")
            continue

        # MediaPipe expects RGB, while OpenCV captures BGR
        image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Process the frame and detect the pose
        results = pose.process(image_rgb)

        # Draw the pose annotation on the original BGR frame
        if results.pose_landmarks:
            mp_drawing.draw_landmarks(
                image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)

            # Classify the detected pose
            landmarks = results.pose_landmarks.landmark
            pose_name = classify_pose(landmarks)

            # Display the classification result on the frame
            cv2.putText(image, pose_name, (10, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2, cv2.LINE_AA)

        # The frame is still BGR, so tell Streamlit to reorder the channels.
        # (No cv2.imshow window exists here, so a cv2.waitKey 'q' check would
        # never fire; the app is stopped from the Streamlit side instead.)
        stframe.image(image, channels="BGR")

    cap.release()

if __name__ == "__main__":
    main()
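
To try this locally (assuming a webcam is available), the standard Streamlit workflow applies. The package names below are the usual PyPI distributions for the imports above:

    pip install streamlit opencv-python mediapipe torch transformers numpy
    streamlit run app.py

Note that "your-huggingface-model-name" must be replaced with a real Hub checkpoint trained on landmark strings serialized the same way as in classify_pose; an off-the-shelf sequence-classification model will not produce meaningful pose labels.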