Upload 4 files
- app.py +83 -0
- dance_generator.py +109 -0
- dance_visualizer.py +27 -0
- pose_detector.py +22 -0
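
Judging from the imports across the four files, the app expects streamlit, opencv-python (cv2), mediapipe, and numpy to be installed; the commit itself ships no requirements file. A hypothetical requirements.txt for it would be:

streamlit
opencv-python
mediapipe
numpy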
app.py
ADDED
@@ -0,0 +1,83 @@
import streamlit as st
import cv2
import numpy as np
import tempfile
import os
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

class AIDancePartner:
    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov']
        )

        # Mode selection
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )

        if video_file:
            self.process_video(video_file, mode)

    def process_video(self, video_file, mode):
        # Write the upload to a temporary file; close it so the bytes are
        # flushed to disk before OpenCV opens the path
        tfile = tempfile.NamedTemporaryFile(delete=False)
        tfile.write(video_file.read())
        tfile.close()

        # Process the video
        cap = cv2.VideoCapture(tfile.name)

        # Display original and AI dance side by side
        col1, col2 = st.columns(2)

        with col1:
            st.header("Your Dance")
            stframe1 = st.empty()

        with col2:
            st.header("AI Partner")
            stframe2 = st.empty()

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            # Detect pose in original frame
            pose_landmarks = self.pose_detector.detect_pose(frame)

            if mode == "Sync Partner":
                # Generate synchronized dance moves
                ai_frame = self.dance_generator.sync_moves(pose_landmarks)
            else:
                # Generate new dance moves based on learned style
                ai_frame = self.dance_generator.generate_new_moves(pose_landmarks)

            # Visualize both frames
            vis_frame = self.visualizer.draw_pose(frame, pose_landmarks)
            stframe1.image(vis_frame, channels="BGR")
            stframe2.image(ai_frame, channels="BGR")

        # Cleanup
        cap.release()
        os.unlink(tfile.name)

def main():
    app = AIDancePartner()
    app.setup_ui()

if __name__ == "__main__":
    main()
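
One caveat worth noting (an observation about the loop above, not part of the committed code): the while loop pushes frames to Streamlit as fast as they can be decoded and processed, so playback speed will not match the source video's frame rate. A minimal sketch of pacing the loop to the file's FPS, using only cv2 and the standard library ("dance.mp4" is a hypothetical input path):

import time
import cv2

cap = cv2.VideoCapture("dance.mp4")  # hypothetical input path
fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fall back to 30 if the file reports none
frame_interval = 1.0 / fps

while cap.isOpened():
    start = time.time()
    ret, frame = cap.read()
    if not ret:
        break
    # ... pose detection and Streamlit rendering would go here ...
    # Sleep off whatever time remains in this frame's time slot
    time.sleep(max(0.0, frame_interval - (time.time() - start)))
cap.release()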
dance_generator.py
ADDED
@@ -0,0 +1,109 @@
import numpy as np
import cv2

class DanceGenerator:
    def __init__(self):
        self.prev_moves = []
        self.style_memory = []

    def sync_moves(self, pose_landmarks):
        """Generate synchronized dance moves based on input pose"""
        if pose_landmarks is None:
            return np.zeros((480, 640, 3), dtype=np.uint8)

        # Convert landmarks to numpy array for processing
        landmarks_array = self._landmarks_to_array(pose_landmarks)

        # Mirror the movements for sync mode
        mirrored_moves = self._mirror_movements(landmarks_array)

        # Generate visualization frame
        return self._create_dance_frame(mirrored_moves)

    def generate_new_moves(self, pose_landmarks):
        """Generate new dance moves based on learned style"""
        if pose_landmarks is None:
            return np.zeros((480, 640, 3), dtype=np.uint8)

        # Convert landmarks to array
        landmarks_array = self._landmarks_to_array(pose_landmarks)

        # Update style memory
        self._update_style_memory(landmarks_array)

        # Generate new moves based on style
        new_moves = self._generate_style_based_moves()

        # Create visualization frame
        return self._create_dance_frame(new_moves)

    def _landmarks_to_array(self, landmarks):
        """Convert MediaPipe landmarks to numpy array"""
        points = []
        for landmark in landmarks.landmark:
            points.append([landmark.x, landmark.y, landmark.z])
        return np.array(points)

    def _mirror_movements(self, landmarks):
        """Mirror the input movements"""
        mirrored = landmarks.copy()
        mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip normalized x coordinates
        return mirrored

    def _update_style_memory(self, landmarks):
        """Update memory of dance style"""
        self.style_memory.append(landmarks)
        if len(self.style_memory) > 30:  # Keep only the last 30 frames
            self.style_memory.pop(0)

    def _generate_style_based_moves(self):
        """Generate new moves based on learned style"""
        if not self.style_memory:
            return np.zeros((33, 3))  # MediaPipe Pose has 33 landmarks

        # Simple implementation: randomly interpolate between the two
        # most recent stored poses
        base_pose = self.style_memory[-1]
        if len(self.style_memory) > 1:
            prev_pose = self.style_memory[-2]
            t = np.random.random()
            new_pose = t * base_pose + (1 - t) * prev_pose
        else:
            new_pose = base_pose

        return new_pose

    def _create_dance_frame(self, pose_array):
        """Create visualization frame from pose array"""
        frame = np.zeros((480, 640, 3), dtype=np.uint8)

        # Convert normalized coordinates to pixel coordinates
        points = (pose_array[:, :2] * [640, 480]).astype(int)

        # Draw connections between joints
        connections = self._get_pose_connections()
        for connection in connections:
            start_idx, end_idx = connection
            if start_idx < len(points) and end_idx < len(points):
                cv2.line(frame,
                         tuple(points[start_idx]),
                         tuple(points[end_idx]),
                         (0, 255, 0), 2)

        # Draw joints
        for point in points:
            cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)

        return frame

    def _get_pose_connections(self):
        """Define connections between MediaPipe pose landmarks"""
        return [
            (0, 1), (1, 2), (2, 3), (3, 7),          # Face (left side)
            (0, 4), (4, 5), (5, 6), (6, 8),          # Face (right side)
            (9, 10),                                  # Mouth
            (11, 12), (11, 13), (13, 15),             # Shoulders, left arm
            (12, 14), (14, 16),                       # Right arm
            (11, 23), (12, 24), (23, 24),             # Torso
            (23, 25), (24, 26),                       # Hips to knees
            (25, 27), (26, 28), (27, 29), (28, 30),   # Lower legs, heels
            (29, 31), (30, 32)                        # Feet
        ]
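
Since MediaPipe landmark coordinates are normalized to [0, 1], the mirroring in _mirror_movements reduces to x -> 1 - x. A quick, self-contained illustration (the three-point pose below is made up for the example):

import numpy as np

# Hypothetical pose: three landmarks in normalized [0, 1] coordinates
pose = np.array([[0.2, 0.5, 0.0],
                 [0.8, 0.5, 0.0],
                 [0.5, 0.9, 0.0]])

mirrored = pose.copy()
mirrored[:, 0] = 1 - mirrored[:, 0]  # x -> 1 - x flips left/right

print(mirrored[:, 0])  # [0.8 0.2 0.5]: the AI partner mirrors the dancer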
dance_visualizer.py
ADDED
@@ -0,0 +1,27 @@
import cv2
import mediapipe as mp
import numpy as np

class DanceVisualizer:
    def __init__(self):
        self.mp_drawing = mp.solutions.drawing_utils
        self.mp_pose = mp.solutions.pose

    def draw_pose(self, frame, pose_landmarks):
        """Draw pose landmarks on frame"""
        if pose_landmarks is None:
            return frame

        # Create copy of frame
        annotated_frame = frame.copy()

        # Draw the pose landmarks
        self.mp_drawing.draw_landmarks(
            annotated_frame,
            pose_landmarks,
            self.mp_pose.POSE_CONNECTIONS,
            self.mp_drawing.DrawingSpec(color=(245, 117, 66), thickness=2, circle_radius=2),  # landmark style
            self.mp_drawing.DrawingSpec(color=(245, 66, 230), thickness=2, circle_radius=2)   # connection style
        )

        return annotated_frame
pose_detector.py
ADDED
@@ -0,0 +1,22 @@
import mediapipe as mp
import cv2
import numpy as np

class PoseDetector:
    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def detect_pose(self, frame):
        """Detect pose landmarks in a BGR frame; return None if no pose is found"""
        # Convert BGR to RGB (MediaPipe expects RGB input)
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)

        # Process the frame and detect poses
        results = self.pose.process(rgb_frame)

        if results.pose_landmarks:
            return results.pose_landmarks
        return None
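
For completeness, a minimal smoke test tying the four modules together on a single image, outside Streamlit. This is a hypothetical usage sketch, not part of the commit; "sample.jpg" is a stand-in for any photo containing a person:

import cv2
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

frame = cv2.imread("sample.jpg")  # hypothetical test image
landmarks = PoseDetector().detect_pose(frame)

# Overlay the detected skeleton and render the mirrored AI partner
vis_frame = DanceVisualizer().draw_pose(frame, landmarks)
ai_frame = DanceGenerator().sync_moves(landmarks)

cv2.imwrite("pose_overlay.jpg", vis_frame)
cv2.imwrite("ai_partner.jpg", ai_frame)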