ombhojane commited on
Commit
75aee9e
·
verified ·
1 Parent(s): 6c6c9e9

Upload 4 files

Browse files
Files changed (4) hide show
  1. app.py +83 -0
  2. dance_generator.py +109 -0
  3. dance_visualizer.py +27 -0
  4. pose_detector.py +22 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import cv2
3
+ import numpy as np
4
+ import tempfile
5
+ import os
6
+ from pose_detector import PoseDetector
7
+ from dance_generator import DanceGenerator
8
+ from dance_visualizer import DanceVisualizer
9
+
10
+ class AIDancePartner:
11
+ def __init__(self):
12
+ self.pose_detector = PoseDetector()
13
+ self.dance_generator = DanceGenerator()
14
+ self.visualizer = DanceVisualizer()
15
+
16
+ def setup_ui(self):
17
+ st.title("AI Dance Partner 💃🤖")
18
+ st.sidebar.header("Controls")
19
+
20
+ # Upload video section
21
+ video_file = st.sidebar.file_uploader(
22
+ "Upload your dance video",
23
+ type=['mp4', 'avi', 'mov']
24
+ )
25
+
26
+ # Mode selection
27
+ mode = st.sidebar.selectbox(
28
+ "Select Mode",
29
+ ["Sync Partner", "Generate New Moves"]
30
+ )
31
+
32
+ if video_file:
33
+ self.process_video(video_file, mode)
34
+
35
+ def process_video(self, video_file, mode):
36
+ # Create temporary file to store uploaded video
37
+ tfile = tempfile.NamedTemporaryFile(delete=False)
38
+ tfile.write(video_file.read())
39
+
40
+ # Process the video
41
+ cap = cv2.VideoCapture(tfile.name)
42
+
43
+ # Display original and AI dance side by side
44
+ col1, col2 = st.columns(2)
45
+
46
+ with col1:
47
+ st.header("Your Dance")
48
+ stframe1 = st.empty()
49
+
50
+ with col2:
51
+ st.header("AI Partner")
52
+ stframe2 = st.empty()
53
+
54
+ while cap.isOpened():
55
+ ret, frame = cap.read()
56
+ if not ret:
57
+ break
58
+
59
+ # Detect pose in original frame
60
+ pose_landmarks = self.pose_detector.detect_pose(frame)
61
+
62
+ if mode == "Sync Partner":
63
+ # Generate synchronized dance moves
64
+ ai_frame = self.dance_generator.sync_moves(pose_landmarks)
65
+ else:
66
+ # Generate new dance moves based on style
67
+ ai_frame = self.dance_generator.generate_new_moves(pose_landmarks)
68
+
69
+ # Visualize both frames
70
+ vis_frame = self.visualizer.draw_pose(frame, pose_landmarks)
71
+ stframe1.image(vis_frame, channels="BGR")
72
+ stframe2.image(ai_frame, channels="BGR")
73
+
74
+ # Cleanup
75
+ cap.release()
76
+ os.unlink(tfile.name)
77
+
78
+ def main():
79
+ app = AIDancePartner()
80
+ app.setup_ui()
81
+
82
+ if __name__ == "__main__":
83
+ main()
dance_generator.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import cv2
3
+
4
+ class DanceGenerator:
5
+ def __init__(self):
6
+ self.prev_moves = []
7
+ self.style_memory = []
8
+
9
+ def sync_moves(self, pose_landmarks):
10
+ """Generate synchronized dance moves based on input pose"""
11
+ if pose_landmarks is None:
12
+ return np.zeros((480, 640, 3), dtype=np.uint8)
13
+
14
+ # Convert landmarks to numpy array for processing
15
+ landmarks_array = self._landmarks_to_array(pose_landmarks)
16
+
17
+ # Mirror the movements for sync mode
18
+ mirrored_moves = self._mirror_movements(landmarks_array)
19
+
20
+ # Generate visualization frame
21
+ return self._create_dance_frame(mirrored_moves)
22
+
23
+ def generate_new_moves(self, pose_landmarks):
24
+ """Generate new dance moves based on learned style"""
25
+ if pose_landmarks is None:
26
+ return np.zeros((480, 640, 3), dtype=np.uint8)
27
+
28
+ # Convert landmarks to array
29
+ landmarks_array = self._landmarks_to_array(pose_landmarks)
30
+
31
+ # Update style memory
32
+ self._update_style_memory(landmarks_array)
33
+
34
+ # Generate new moves based on style
35
+ new_moves = self._generate_style_based_moves()
36
+
37
+ # Create visualization frame
38
+ return self._create_dance_frame(new_moves)
39
+
40
+ def _landmarks_to_array(self, landmarks):
41
+ """Convert MediaPipe landmarks to numpy array"""
42
+ points = []
43
+ for landmark in landmarks.landmark:
44
+ points.append([landmark.x, landmark.y, landmark.z])
45
+ return np.array(points)
46
+
47
+ def _mirror_movements(self, landmarks):
48
+ """Mirror the input movements"""
49
+ mirrored = landmarks.copy()
50
+ mirrored[:, 0] = 1 - mirrored[:, 0] # Flip x coordinates
51
+ return mirrored
52
+
53
+ def _update_style_memory(self, landmarks):
54
+ """Update memory of dance style"""
55
+ self.style_memory.append(landmarks)
56
+ if len(self.style_memory) > 30: # Keep last 30 frames
57
+ self.style_memory.pop(0)
58
+
59
+ def _generate_style_based_moves(self):
60
+ """Generate new moves based on learned style"""
61
+ if not self.style_memory:
62
+ return np.zeros((33, 3)) # Default pose shape
63
+
64
+ # Simple implementation: interpolate between stored poses
65
+ base_pose = self.style_memory[-1]
66
+ if len(self.style_memory) > 1:
67
+ prev_pose = self.style_memory[-2]
68
+ t = np.random.random()
69
+ new_pose = t * base_pose + (1-t) * prev_pose
70
+ else:
71
+ new_pose = base_pose
72
+
73
+ return new_pose
74
+
75
+ def _create_dance_frame(self, pose_array):
76
+ """Create visualization frame from pose array"""
77
+ frame = np.zeros((480, 640, 3), dtype=np.uint8)
78
+
79
+ # Convert normalized coordinates to pixel coordinates
80
+ points = (pose_array[:, :2] * [640, 480]).astype(int)
81
+
82
+ # Draw connections between joints
83
+ connections = self._get_pose_connections()
84
+ for connection in connections:
85
+ start_idx, end_idx = connection
86
+ if start_idx < len(points) and end_idx < len(points):
87
+ cv2.line(frame,
88
+ tuple(points[start_idx]),
89
+ tuple(points[end_idx]),
90
+ (0, 255, 0), 2)
91
+
92
+ # Draw joints
93
+ for point in points:
94
+ cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
95
+
96
+ return frame
97
+
98
+ def _get_pose_connections(self):
99
+ """Define connections between pose landmarks"""
100
+ return [
101
+ (0, 1), (1, 2), (2, 3), (3, 7), # Face
102
+ (0, 4), (4, 5), (5, 6), (6, 8),
103
+ (9, 10), (11, 12), (11, 13), (13, 15), # Arms
104
+ (12, 14), (14, 16),
105
+ (11, 23), (12, 24), # Torso
106
+ (23, 24), (23, 25), (24, 26), # Legs
107
+ (25, 27), (26, 28), (27, 29), (28, 30),
108
+ (29, 31), (30, 32)
109
+ ]
dance_visualizer.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import mediapipe as mp
3
+ import numpy as np
4
+
5
+ class DanceVisualizer:
6
+ def __init__(self):
7
+ self.mp_drawing = mp.solutions.drawing_utils
8
+ self.mp_pose = mp.solutions.pose
9
+
10
+ def draw_pose(self, frame, pose_landmarks):
11
+ """Draw pose landmarks on frame"""
12
+ if pose_landmarks is None:
13
+ return frame
14
+
15
+ # Create copy of frame
16
+ annotated_frame = frame.copy()
17
+
18
+ # Draw the pose landmarks
19
+ self.mp_drawing.draw_landmarks(
20
+ annotated_frame,
21
+ pose_landmarks,
22
+ self.mp_pose.POSE_CONNECTIONS,
23
+ self.mp_drawing.DrawingSpec(color=(245,117,66), thickness=2, circle_radius=2),
24
+ self.mp_drawing.DrawingSpec(color=(245,66,230), thickness=2, circle_radius=2)
25
+ )
26
+
27
+ return annotated_frame
pose_detector.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import mediapipe as mp
2
+ import cv2
3
+ import numpy as np
4
+
5
+ class PoseDetector:
6
+ def __init__(self):
7
+ self.mp_pose = mp.solutions.pose
8
+ self.pose = self.mp_pose.Pose(
9
+ min_detection_confidence=0.5,
10
+ min_tracking_confidence=0.5
11
+ )
12
+
13
+ def detect_pose(self, frame):
14
+ # Convert BGR to RGB
15
+ rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
16
+
17
+ # Process the frame and detect poses
18
+ results = self.pose.process(rgb_frame)
19
+
20
+ if results.pose_landmarks:
21
+ return results.pose_landmarks
22
+ return None