ombhojane committed on
Commit
96a0264
·
verified ·
1 Parent(s): 9cebfe8

Upload 4 files

Browse files
Files changed (2) hide show
  1. app.py +57 -26
  2. dance_generator.py +165 -22
app.py CHANGED
@@ -29,47 +29,78 @@ class AIDancePartner:
29
  ["Sync Partner", "Generate New Moves"]
30
  )
31
 
 
 
 
32
  if video_file:
33
- self.process_video(video_file, mode)
34
 
35
- def process_video(self, video_file, mode):
36
- # Create temporary file to store uploaded video
37
  tfile = tempfile.NamedTemporaryFile(delete=False)
38
  tfile.write(video_file.read())
39
 
40
- # Process the video
41
  cap = cv2.VideoCapture(tfile.name)
42
 
43
- # Display original and AI dance side by side
44
- col1, col2 = st.columns(2)
 
 
 
45
 
46
- with col1:
47
- st.header("Your Dance")
48
- stframe1 = st.empty()
49
-
50
- with col2:
51
- st.header("AI Partner")
52
- stframe2 = st.empty()
53
-
54
  while cap.isOpened():
55
  ret, frame = cap.read()
56
  if not ret:
57
  break
58
-
59
- # Detect pose in original frame
60
  pose_landmarks = self.pose_detector.detect_pose(frame)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if mode == "Sync Partner":
63
- # Generate synchronized dance moves
64
- ai_frame = self.dance_generator.sync_moves(pose_landmarks)
65
- else:
66
- # Generate new dance moves based on style
67
- ai_frame = self.dance_generator.generate_new_moves(pose_landmarks)
68
 
69
- # Visualize both frames
70
- vis_frame = self.visualizer.draw_pose(frame, pose_landmarks)
71
- stframe1.image(vis_frame, channels="BGR")
72
- stframe2.image(ai_frame, channels="BGR")
73
 
74
  # Cleanup
75
  cap.release()
 
29
  ["Sync Partner", "Generate New Moves"]
30
  )
31
 
32
+ # Add playback controls
33
+ play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
34
+
35
  if video_file:
36
+ self.process_video(video_file, mode, play_speed)
37
 
38
def process_video(self, video_file, mode, play_speed):
    """Process an uploaded dance video and play it back beside an AI partner.

    Parameters
    ----------
    video_file : Streamlit uploaded-file object
        The user's dance video; its bytes are spooled to a temp file because
        cv2.VideoCapture needs a real file path.
    mode : str
        "Sync Partner" or "Generate New Moves"; forwarded to the dance generator.
    play_speed : float
        Playback-speed multiplier from the sidebar slider; scales the per-frame delay.
    """
    import os  # local import: only needed for temp-file cleanup

    # delete=False so OpenCV can reopen the path; we own the cleanup below.
    tfile = tempfile.NamedTemporaryFile(delete=False)
    try:
        tfile.write(video_file.read())
        tfile.close()  # release the write handle before OpenCV opens the file

        cap = cv2.VideoCapture(tfile.name)
        try:
            # Video properties (CAP_PROP_* metadata can be 0/unreliable for some codecs).
            fps = cap.get(cv2.CAP_PROP_FPS) or 30.0  # fallback avoids division by zero
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            progress_bar = st.progress(0)
            frame_placeholder = st.empty()

            # First pass: extract a pose (or None) for every decodable frame.
            all_poses = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                all_poses.append(self.pose_detector.detect_pose(frame))

            # Trust the decoded count, not CAP_PROP_FRAME_COUNT: the metadata can
            # overestimate and would make ai_sequence[frame_count] raise IndexError.
            total_frames = len(all_poses)
            if total_frames == 0:
                st.warning("Could not read any frames from the uploaded video.")
                return

            # Rewind for the playback pass.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

            # Generate the complete AI dance sequence up front.
            ai_sequence = self.dance_generator.generate_dance_sequence(
                all_poses,
                mode,
                total_frames,
                (frame_height, frame_width),
            )

            # Per-frame delay approximating the source frame rate, scaled by the
            # user's speed choice; clamped so waitKey always gets a positive int.
            delay_ms = max(1, int(1000 / (fps * max(play_speed, 0.1))))

            # Second pass: show the original and AI frames side by side.
            frame_count = 0
            while cap.isOpened() and frame_count < total_frames:
                ret, frame = cap.read()
                if not ret:
                    break

                progress_bar.progress(min(frame_count / total_frames, 1.0))

                ai_frame = ai_sequence[frame_count]
                combined_frame = np.hstack([
                    frame,
                    cv2.resize(ai_frame, (frame_width, frame_height)),
                ])

                frame_placeholder.image(
                    combined_frame,
                    channels="BGR",
                    use_column_width=True,
                )

                cv2.waitKey(delay_ms)
                frame_count += 1
        finally:
            # Cleanup
            cap.release()
    finally:
        # Remove the spooled temp file (delete=False means it is our job).
        os.unlink(tfile.name)
dance_generator.py CHANGED
@@ -1,42 +1,185 @@
1
  import numpy as np
2
  import cv2
 
3
 
4
  class DanceGenerator:
5
  def __init__(self):
6
  self.prev_moves = []
7
  self.style_memory = []
 
8
 
9
- def sync_moves(self, pose_landmarks):
10
- """Generate synchronized dance moves based on input pose"""
11
- if pose_landmarks is None:
12
- return np.zeros((480, 640, 3), dtype=np.uint8)
 
 
 
 
 
13
 
14
- # Convert landmarks to numpy array for processing
15
- landmarks_array = self._landmarks_to_array(pose_landmarks)
 
 
 
 
16
 
17
- # Mirror the movements for sync mode
18
- mirrored_moves = self._mirror_movements(landmarks_array)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Generate visualization frame
21
- return self._create_dance_frame(mirrored_moves)
 
 
22
 
23
- def generate_new_moves(self, pose_landmarks):
24
- """Generate new dance moves based on learned style"""
25
- if pose_landmarks is None:
26
- return np.zeros((480, 640, 3), dtype=np.uint8)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
- # Convert landmarks to array
29
- landmarks_array = self._landmarks_to_array(pose_landmarks)
30
 
31
- # Update style memory
32
- self._update_style_memory(landmarks_array)
 
33
 
34
- # Generate new moves based on style
35
- new_moves = self._generate_style_based_moves()
 
 
 
 
36
 
37
- # Create visualization frame
38
- return self._create_dance_frame(new_moves)
 
 
 
 
 
 
39
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  def _landmarks_to_array(self, landmarks):
41
  """Convert MediaPipe landmarks to numpy array"""
42
  points = []
 
1
  import numpy as np
2
  import cv2
3
+ from scipy.interpolate import interp1d
4
 
5
  class DanceGenerator:
6
def __init__(self):
    """Initialize empty move/style buffers and load the dancer avatar image."""
    self.prev_moves = []    # previously generated moves
    self.style_memory = []  # accumulated pose samples for style analysis
    # NOTE(review): cv2.imread silently returns None when the file is missing —
    # confirm 'assets/dancer_avatar.png' ships with the app.
    self.avatar = cv2.imread('assets/dancer_avatar.png')  # Add a dancer avatar image
10
 
11
def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
    """Generate the complete AI dance sequence for a whole video.

    Parameters
    ----------
    all_poses : list
        Per-frame pose landmarks; entries may be None for frames with no detection.
    mode : str
        "Sync Partner" mirrors the dancer; anything else generates style-based moves.
    total_frames : int
        Number of frames the returned sequence should contain.
    frame_size : tuple[int, int]
        (height, width) of the output frames.

    Returns
    -------
    list
        One BGR frame (height x width x 3 uint8) per video frame.
    """
    # Pure dispatch — the mode-specific generators do all the work.
    # (Dropped the dead `height, width` unpack and the `sequence = []`
    # placeholder that was immediately overwritten.)
    if mode == "Sync Partner":
        return self._generate_sync_sequence(all_poses, total_frames, frame_size)
    return self._generate_creative_sequence(all_poses, total_frames, frame_size)
22
+
23
def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
    """Generate a mirrored ("sync") dance sequence, one frame per input frame.

    Frames whose pose is None come out black. Always returns exactly
    `total_frames` frames so callers can index by frame number safely.
    """
    height, width = frame_size
    sequence = []

    # Convert detected poses to arrays up front (None stays None).
    pose_arrays = [
        self._landmarks_to_array(pose) if pose is not None else None
        for pose in all_poses
    ]

    prev_mirrored = None
    # Clamp to the poses we actually have: CAP_PROP_FRAME_COUNT metadata can
    # exceed len(all_poses), and pose_arrays[i] would then raise IndexError.
    n = min(total_frames, len(pose_arrays))
    for i in range(n):
        frame = np.zeros((height, width, 3), dtype=np.uint8)

        if pose_arrays[i] is not None:
            # Mirror the pose
            mirrored = self._mirror_movements(pose_arrays[i])

            # Smooth against the previous *mirrored* pose so both operands are
            # in the same coordinate space (the original blended the previous
            # unmirrored pose with the current mirrored one).
            if prev_mirrored is not None:
                mirrored = self._smooth_transition(prev_mirrored, mirrored, 0.3)
            prev_mirrored = mirrored

            # Create dance frame
            frame = self._create_enhanced_dance_frame(
                mirrored,
                frame_size,
                add_effects=True,
            )
        else:
            # Lost the dancer: restart smoothing when the pose reappears.
            prev_mirrored = None

        sequence.append(frame)

    # Pad with black frames if metadata promised more frames than we decoded.
    while len(sequence) < total_frames:
        sequence.append(np.zeros((height, width, 3), dtype=np.uint8))

    return sequence
58
 
59
def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
    """Produce a new dance sequence in the user's style (non-mirroring mode).

    Extracts style patterns from the observed poses once, then sweeps a
    progress value across them to synthesize one frame per output frame;
    frames with no generated pose are left black.
    """
    height, width = frame_size

    # Learn the style once, up front.
    style_patterns = self._analyze_style_patterns(all_poses)

    frames = []
    for idx in range(total_frames):
        # Progress runs from 0 toward 1 across the whole clip.
        generated = self._generate_style_based_pose(style_patterns, idx / total_frames)

        if generated is None:
            frames.append(np.zeros((height, width, 3), dtype=np.uint8))
        else:
            frames.append(
                self._create_enhanced_dance_frame(
                    generated,
                    frame_size,
                    add_effects=True,
                )
            )

    return frames
 
84
 
85
def _analyze_style_patterns(self, poses):
    """Collect a landmark array for every frame that has a detected pose.

    Frames without a detection (None entries) are skipped; the result is the
    ordered list of pose arrays used as style patterns.
    """
    return [
        self._landmarks_to_array(pose)
        for pose in poses
        if pose is not None
    ]
95
 
96
+ def _generate_style_based_pose(self, patterns, progress):
97
+ """Generate new pose based on style patterns and progress"""
98
+ if not patterns:
99
+ return None
100
+
101
+ # Create smooth interpolation between poses
102
+ num_patterns = len(patterns)
103
+ pattern_idx = int(progress * (num_patterns - 1))
104
 
105
+ if pattern_idx < num_patterns - 1:
106
+ t = progress * (num_patterns - 1) - pattern_idx
107
+ pose = self._interpolate_poses(
108
+ patterns[pattern_idx],
109
+ patterns[pattern_idx + 1],
110
+ t
111
+ )
112
+ else:
113
+ pose = patterns[-1]
114
+
115
+ return pose
116
+
117
+ def _interpolate_poses(self, pose1, pose2, t):
118
+ """Smoothly interpolate between two poses"""
119
+ return pose1 * (1 - t) + pose2 * t
120
+
121
def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
    """Render a pose as a skeleton on a black canvas.

    Columns 0-1 of pose_array are treated as normalized coordinates and scaled
    to pixel space. With add_effects=True, bones and joints get a glow effect;
    otherwise plain lines and dots are drawn.
    """
    height, width = frame_size
    canvas = np.zeros((height, width, 3), dtype=np.uint8)

    # Normalized coords -> integer pixel coordinates.
    pixel_pts = (pose_array[:, :2] * [width, height]).astype(int)

    bone_color = (0, 255, 0)
    joint_color = (0, 0, 255)

    # Bones: one line per landmark connection (skip indices beyond the array).
    for start_idx, end_idx in self._get_pose_connections():
        if start_idx >= len(pixel_pts) or end_idx >= len(pixel_pts):
            continue
        if add_effects:
            self._draw_glowing_line(
                canvas,
                pixel_pts[start_idx],
                pixel_pts[end_idx],
                bone_color,
            )
        else:
            cv2.line(
                canvas,
                tuple(pixel_pts[start_idx]),
                tuple(pixel_pts[end_idx]),
                bone_color,
                2,
            )

    # Joints: one dot per landmark.
    for pt in pixel_pts:
        if add_effects:
            self._draw_glowing_point(canvas, pt, joint_color)
        else:
            cv2.circle(canvas, tuple(pt), 4, joint_color, -1)

    return canvas
156
+
157
def _draw_glowing_line(self, frame, start, end, color, thickness=2):
    """Draw a line with a glow effect: dim, thicker strokes under a bright core.

    The previous version painted the dim glow layers *over* the bright core
    line while rebinding `thickness` ever wider, so the core ended up buried
    under the dimmest, thickest stroke. Draw the glow outermost/dimmest first
    and the core last so it stays visible.
    """
    start_pt = tuple(start)
    end_pt = tuple(end)

    # Glow layers: i = 3 is widest/dimmest (alpha 0.1, +6 px), i = 1 is the
    # innermost halo (alpha 0.3, +2 px) — same layers as before, reordered.
    for i in range(3, 0, -1):
        alpha = 0.4 - 0.1 * i
        glow_color = tuple(int(c * alpha) for c in color)
        cv2.line(frame, start_pt, end_pt, glow_color, thickness + 2 * i)

    # Bright core on top.
    cv2.line(frame, start_pt, end_pt, color, thickness)
169
+
170
def _draw_glowing_point(self, frame, point, color, radius=4):
    """Draw a point with a glow effect: dim filled halos under a bright core.

    The previous version drew the filled glow circles *after* the core dot —
    and the first glow pass even shared the core's radius — so the bright
    center was completely painted over. Draw the halos largest/dimmest first,
    then the core on top.
    """
    center = tuple(point)

    # Halos: i = 2 is largest/dimmest (alpha 0.1, radius +4), i = 0 matches
    # the core radius at alpha 0.3 — same layers as before, reordered.
    for i in range(2, -1, -1):
        alpha = 0.3 - i * 0.1
        halo_color = tuple(int(c * alpha) for c in color)
        cv2.circle(frame, center, radius + i * 2, halo_color, -1)

    # Bright core last so it stays visible.
    cv2.circle(frame, center, radius, color, -1)
182
+
183
  def _landmarks_to_array(self, landmarks):
184
  """Convert MediaPipe landmarks to numpy array"""
185
  points = []