|
import numpy as np
|
|
import cv2
|
|
|
|
class DanceGenerator:
    """Render a stick-figure "dance partner" from detected pose landmarks.

    Two modes are offered:

    * ``sync_moves`` draws the incoming pose mirrored left-to-right.
    * ``generate_new_moves`` keeps a rolling history of recent poses and
      draws a random blend of the two most recent ones.

    Landmark ``x``/``y`` values are presumably normalised to ``[0, 1]`` (as
    MediaPipe reports them); they are multiplied by the frame size when drawn.
    """

    # Class-level defaults so instances created without ``__init__`` still work.
    width = 640
    height = 480
    memory_size = 30
    _rng = None

    def __init__(self, *, width=640, height=480, memory_size=30, rng=None):
        """Create a generator.

        Args:
            width: Output frame width in pixels (default 640).
            height: Output frame height in pixels (default 480).
            memory_size: Maximum number of poses kept in ``style_memory``
                (default 30).
            rng: Optional object with a ``random()`` method, e.g.
                ``numpy.random.default_rng(seed)``, used to pick the blend
                weight. When ``None`` the global ``numpy.random`` state is used.

        Raises:
            ValueError: If ``width``, ``height`` or ``memory_size`` is below 1.
        """
        if width < 1 or height < 1:
            raise ValueError("width and height must be positive")
        if memory_size < 1:
            raise ValueError("memory_size must be at least 1")
        self.width = int(width)
        self.height = int(height)
        self.memory_size = int(memory_size)
        self._rng = rng
        # Not read anywhere in this class; kept for callers that may rely on it.
        self.prev_moves = []
        # Rolling window of recent landmark arrays, oldest first.
        self.style_memory = []

    def _blank_frame(self):
        """Return an all-black ``(height, width, 3)`` uint8 image."""
        return np.zeros((self.height, self.width, 3), dtype=np.uint8)

    def sync_moves(self, pose_landmarks):
        """Draw the input pose mirrored horizontally.

        Args:
            pose_landmarks: Object exposing a ``.landmark`` sequence whose items
                have ``x``, ``y`` and ``z`` attributes, or ``None``.

        Returns:
            A ``(height, width, 3)`` uint8 image; all black when
            ``pose_landmarks`` is ``None``.
        """
        if pose_landmarks is None:
            return self._blank_frame()
        landmarks_array = self._landmarks_to_array(pose_landmarks)
        mirrored_moves = self._mirror_movements(landmarks_array)
        return self._create_dance_frame(mirrored_moves)

    def generate_new_moves(self, pose_landmarks):
        """Record the input pose, then draw a blend of the latest two poses.

        Args:
            pose_landmarks: Object exposing a ``.landmark`` sequence whose items
                have ``x``, ``y`` and ``z`` attributes, or ``None``.

        Returns:
            A ``(height, width, 3)`` uint8 image; all black when
            ``pose_landmarks`` is ``None`` (in which case nothing is recorded).
        """
        if pose_landmarks is None:
            return self._blank_frame()
        landmarks_array = self._landmarks_to_array(pose_landmarks)
        self._update_style_memory(landmarks_array)
        new_moves = self._generate_style_based_moves()
        return self._create_dance_frame(new_moves)

    def _landmarks_to_array(self, landmarks):
        """Convert ``landmarks.landmark`` into an ``(N, 3)`` float array of x, y, z.

        An empty landmark list yields shape ``(0, 3)`` rather than ``(0,)`` so
        the column indexing done by the other methods keeps working.
        """
        points = [(lm.x, lm.y, lm.z) for lm in landmarks.landmark]
        return np.asarray(points, dtype=float).reshape(-1, 3)

    def _mirror_movements(self, landmarks):
        """Return a copy of ``landmarks`` with x reflected as ``1 - x``.

        The input array is not modified. The reflection assumes x is
        normalised to ``[0, 1]``.
        """
        mirrored = landmarks.copy()
        mirrored[:, 0] = 1 - mirrored[:, 0]
        return mirrored

    def _update_style_memory(self, landmarks):
        """Append ``landmarks`` and drop the oldest entries beyond ``memory_size``."""
        self.style_memory.append(landmarks)
        overflow = len(self.style_memory) - self.memory_size
        if overflow > 0:
            del self.style_memory[:overflow]

    def _generate_style_based_moves(self):
        """Return a random convex blend of the two most recent stored poses.

        Returns:
            * ``np.zeros((33, 3))`` when nothing has been stored yet.
            * The newest pose itself when only one is stored, or when the two
              newest poses have different shapes and cannot be blended.
            * Otherwise ``t * newest + (1 - t) * previous`` with ``t`` drawn
              uniformly from ``[0, 1)``.
        """
        if not self.style_memory:
            # 33 matches the landmark count of MediaPipe Pose (assumed source).
            return np.zeros((33, 3))
        base_pose = self.style_memory[-1]
        if len(self.style_memory) < 2:
            return base_pose
        prev_pose = self.style_memory[-2]
        if prev_pose.shape != base_pose.shape:
            # e.g. an empty detection followed by a full one: no sane blend.
            return base_pose
        source = np.random if self._rng is None else self._rng
        t = source.random()
        return t * base_pose + (1 - t) * prev_pose

    def _create_dance_frame(self, pose_array):
        """Draw the pose as a skeleton on a black frame.

        Bones are green lines (2 px) and joints are filled red circles
        (radius 4), with colours given in OpenCV's BGR order. Connections that
        reference a landmark index beyond ``len(pose_array)`` are skipped.
        Points landing outside the canvas are clipped by OpenCV.

        Args:
            pose_array: ``(N, >=2)`` array whose first two columns are x and y.

        Returns:
            A ``(height, width, 3)`` uint8 image.
        """
        frame = self._blank_frame()
        scale = np.array([self.width, self.height])
        # .tolist() yields plain Python ints; some OpenCV 4.x releases reject
        # tuples of NumPy integer scalars as point arguments.
        pixel_coords = (pose_array[:, :2] * scale).astype(int).tolist()
        points = [tuple(p) for p in pixel_coords]
        n_points = len(points)

        for start_idx, end_idx in self._get_pose_connections():
            if start_idx < n_points and end_idx < n_points:
                cv2.line(frame, points[start_idx], points[end_idx], (0, 255, 0), 2)

        for point in points:
            cv2.circle(frame, point, 4, (0, 0, 255), -1)

        return frame

    def _get_pose_connections(self):
        """Return the landmark index pairs drawn as skeleton bones.

        NOTE(review): the indices appear to follow MediaPipe Pose's 33-landmark
        numbering, but the hand edges (15-22) and the heel-to-toe closing edges
        (27, 31)/(28, 32) from MediaPipe's own POSE_CONNECTIONS are not listed.
        Confirm whether that omission is intentional.
        """
        return [
            (0, 1), (1, 2), (2, 3), (3, 7),
            (0, 4), (4, 5), (5, 6), (6, 8),
            (9, 10), (11, 12), (11, 13), (13, 15),
            (12, 14), (14, 16),
            (11, 23), (12, 24),
            (23, 24), (23, 25), (24, 26),
            (25, 27), (26, 28), (27, 29), (28, 30),
            (29, 31), (30, 32),
        ]