ombhojane committed
Commit bbed0a3 · verified · 1 Parent(s): ec27c1f

Update colab.py

Files changed (1)
  1. colab.py +425 -424
colab.py CHANGED
@@ -1,425 +1,426 @@
-         fourcc = cv2.VideoWriter_fourcc(*'avc1')
-         out = cv2.VideoWriter(output_path, fourcc, fps, (frame_width * 2, frame_height))
+ # Import necessary libraries
+ import cv2
+ import mediapipe as mp
+ import numpy as np
+ from scipy.interpolate import interp1d
+ import time
+ import os
+ import tempfile
+
+ class PoseDetector:
+     def __init__(self):
+         self.mp_pose = mp.solutions.pose
+         self.pose = self.mp_pose.Pose(
+             min_detection_confidence=0.5,
+             min_tracking_confidence=0.5
+         )
+
+     def detect_pose(self, frame):
+         rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+         results = self.pose.process(rgb_frame)
+         return results.pose_landmarks if results.pose_landmarks else None
+
+ class DanceGenerator:
+     def __init__(self):
+         self.prev_moves = []
+         self.style_memory = []
+         self.rhythm_patterns = []
+
+     def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
+         height, width = frame_size
+         sequence = []
+
+         if mode == "Sync Partner":
+             sequence = self._generate_sync_sequence(all_poses, total_frames, frame_size)
+         else:
+             sequence = self._generate_creative_sequence(all_poses, total_frames, frame_size)
+
+         return sequence
+
+     def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
+         height, width = frame_size
+         sequence = []
+
+         # Enhanced rhythm analysis
+         rhythm_window = 10  # Analyze chunks of frames for rhythm
+         beat_positions = self._detect_dance_beats(all_poses, rhythm_window)
+
+         pose_arrays = []
+         for pose in all_poses:
+             if pose is not None:
+                 pose_arrays.append(self._landmarks_to_array(pose))
+             else:
+                 pose_arrays.append(None)
+
+         for i in range(total_frames):
+             frame = np.zeros((height, width, 3), dtype=np.uint8)
+
+             if pose_arrays[i] is not None:
+                 # Enhanced mirroring with rhythm awareness
+                 mirrored = self._mirror_movements(pose_arrays[i])
+
+                 # Apply rhythm-based movement enhancement
+                 if i in beat_positions:
+                     mirrored = self._enhance_movement_on_beat(mirrored)
+
+                 if i > 0 and pose_arrays[i-1] is not None:
+                     mirrored = self._smooth_transition(pose_arrays[i-1], mirrored, 0.3)
+
+                 frame = self._create_enhanced_dance_frame(
+                     mirrored,
+                     frame_size,
+                     add_effects=True
+                 )
+
+             sequence.append(frame)
+
+         return sequence
+
+     def _detect_dance_beats(self, poses, window_size):
+         """Detect main beats in the dance sequence"""
+         beat_positions = []
+
+         if len(poses) < window_size:
+             return beat_positions
+
+         for i in range(window_size, len(poses)):
+             if poses[i] is not None and poses[i-1] is not None:
+                 curr_pose = self._landmarks_to_array(poses[i])
+                 prev_pose = self._landmarks_to_array(poses[i-1])
+
+                 # Calculate movement magnitude
+                 movement = np.mean(np.abs(curr_pose - prev_pose))
+
+                 # Detect significant movements as beats
+                 if movement > np.mean(self.rhythm_patterns) + np.std(self.rhythm_patterns):
+                     beat_positions.append(i)
+
+         return beat_positions
+
+     def _enhance_movement_on_beat(self, pose):
+         """Enhance movements during detected beats"""
+         # Amplify movements slightly on beats
+         center = np.mean(pose, axis=0)
+         enhanced_pose = pose.copy()
+
+         for i in range(len(pose)):
+             # Amplify movement relative to center
+             vector = pose[i] - center
+             enhanced_pose[i] = center + vector * 1.2
+
+         return enhanced_pose
+
+     def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
+         """Generate creative dance sequence based on style"""
+         height, width = frame_size
+         sequence = []
+
+         # Analyze style from all poses
+         style_patterns = self._analyze_style_patterns(all_poses)
+
+         # Generate new sequence using style patterns
+         for i in range(total_frames):
+             frame = np.zeros((height, width, 3), dtype=np.uint8)
+
+             # Generate new pose based on style
+             new_pose = self._generate_style_based_pose(style_patterns, i/total_frames)
+
+             if new_pose is not None:
+                 frame = self._create_enhanced_dance_frame(
+                     new_pose,
+                     frame_size,
+                     add_effects=True
+                 )
+
+             sequence.append(frame)
+
+         return sequence
+
+     def _analyze_style_patterns(self, poses):
+         """Enhanced style analysis including rhythm and movement patterns"""
+         patterns = []
+         rhythm_data = []
+
+         for i in range(1, len(poses)):
+             if poses[i] is not None and poses[i-1] is not None:
+                 # Calculate movement speed and direction
+                 curr_pose = self._landmarks_to_array(poses[i])
+                 prev_pose = self._landmarks_to_array(poses[i-1])
+
+                 # Analyze movement velocity
+                 velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
+                 rhythm_data.append(velocity)
+
+                 # Store enhanced pattern data
+                 pattern_info = {
+                     'pose': curr_pose,
+                     'velocity': velocity,
+                     'acceleration': velocity if i == 1 else velocity - prev_velocity
+                 }
+                 patterns.append(pattern_info)
+                 prev_velocity = velocity
+
+         self.rhythm_patterns = rhythm_data
+         return patterns
+
+     def _generate_style_based_pose(self, patterns, progress):
+         """Generate new pose based on style patterns and progress"""
+         if not patterns:
+             return None
+
+         # Create smooth interpolation between poses
+         num_patterns = len(patterns)
+         pattern_idx = int(progress * (num_patterns - 1))
+
+         if pattern_idx < num_patterns - 1:
+             t = progress * (num_patterns - 1) - pattern_idx
+             # Extract pose arrays from pattern dictionaries
+             pose1 = patterns[pattern_idx]['pose']
+             pose2 = patterns[pattern_idx + 1]['pose']
+             pose = self._interpolate_poses(pose1, pose2, t)
+         else:
+             pose = patterns[-1]['pose']
+
+         return pose
+
+     def _interpolate_poses(self, pose1, pose2, t):
+         """Smoothly interpolate between two poses"""
+         if isinstance(pose1, dict):
+             pose1 = pose1['pose']
+         if isinstance(pose2, dict):
+             pose2 = pose2['pose']
+         return pose1 * (1 - t) + pose2 * t
+
+     def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
+         """Create enhanced visualization frame with effects"""
+         height, width = frame_size
+         # Change background from black to light gray for better visibility
+         frame = np.ones((height, width, 3), dtype=np.uint8) * 240  # Light gray background
+
+         # Convert coordinates
+         points = (pose_array[:, :2] * [width, height]).astype(int)
+
+         # Draw enhanced skeleton with thicker lines and more visible colors
+         connections = self._get_pose_connections()
+         for connection in connections:
+             start_idx, end_idx = connection
+             if start_idx < len(points) and end_idx < len(points):
+                 if add_effects:
+                     self._draw_glowing_line(
+                         frame,
+                         points[start_idx],
+                         points[end_idx],
+                         (0, 100, 255),  # Orange color for skeleton
+                         thickness=4
+                     )
+                 else:
+                     cv2.line(frame,
+                              tuple(points[start_idx]),
+                              tuple(points[end_idx]),
+                              (0, 100, 255), 4)
+
+         # Draw enhanced joints with larger radius
+         for point in points:
+             if add_effects:
+                 self._draw_glowing_point(frame, point, (255, 0, 0), radius=6)  # Blue joints
+             else:
+                 cv2.circle(frame, tuple(point), 6, (255, 0, 0), -1)
+
+         return frame
+
+     def _draw_glowing_line(self, frame, start, end, color, thickness=4):
+         """Draw a line with enhanced glow effect"""
+         # Draw outer glow
+         for i in range(3):
+             alpha = 0.5 - i * 0.15
+             thick = thickness + (i * 4)
+             cv2.line(frame, tuple(start), tuple(end),
+                      tuple([int(c * alpha) for c in color]),
+                      thick)
+
+         # Draw main line
+         cv2.line(frame, tuple(start), tuple(end), color, thickness)
+
+     def _draw_glowing_point(self, frame, point, color, radius=6):
+         """Draw a point with enhanced glow effect"""
+         # Draw outer glow
+         for i in range(3):
+             alpha = 0.5 - i * 0.15
+             r = radius + (i * 3)
+             cv2.circle(frame, tuple(point), r,
+                        tuple([int(c * alpha) for c in color]),
+                        -1)
+
+         # Draw main point
+         cv2.circle(frame, tuple(point), radius, color, -1)
+
+     def _landmarks_to_array(self, landmarks):
+         """Convert MediaPipe landmarks to numpy array"""
+         points = []
+         for landmark in landmarks.landmark:
+             points.append([landmark.x, landmark.y, landmark.z])
+         return np.array(points)
+
+     def _mirror_movements(self, landmarks):
+         """Mirror the input movements"""
+         mirrored = landmarks.copy()
+         mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip x coordinates
+         return mirrored
+
+     def _update_style_memory(self, landmarks):
+         """Update memory of dance style"""
+         self.style_memory.append(landmarks)
+         if len(self.style_memory) > 30:  # Keep last 30 frames
+             self.style_memory.pop(0)
+
+     def _generate_style_based_moves(self):
+         """Generate new moves based on learned style"""
+         if not self.style_memory:
+             return np.zeros((33, 3))  # Default pose shape
+
+         # Simple implementation: interpolate between stored poses
+         base_pose = self.style_memory[-1]
+         if len(self.style_memory) > 1:
+             prev_pose = self.style_memory[-2]
+             t = np.random.random()
+             new_pose = t * base_pose + (1-t) * prev_pose
+         else:
+             new_pose = base_pose
+
+         return new_pose
+
+     def _create_dance_frame(self, pose_array):
+         """Create visualization frame from pose array"""
+         frame = np.zeros((480, 640, 3), dtype=np.uint8)
+
+         # Convert normalized coordinates to pixel coordinates
+         points = (pose_array[:, :2] * [640, 480]).astype(int)
+
+         # Draw connections between joints
+         connections = self._get_pose_connections()
+         for connection in connections:
+             start_idx, end_idx = connection
+             if start_idx < len(points) and end_idx < len(points):
+                 cv2.line(frame,
+                          tuple(points[start_idx]),
+                          tuple(points[end_idx]),
+                          (0, 255, 0), 2)
+
+         # Draw joints
+         for point in points:
+             cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
+
+         return frame
+
+     def _get_pose_connections(self):
+         """Define connections between pose landmarks"""
+         return [
+             (0, 1), (1, 2), (2, 3), (3, 7),  # Face
+             (0, 4), (4, 5), (5, 6), (6, 8),
+             (9, 10), (11, 12), (11, 13), (13, 15),  # Arms
+             (12, 14), (14, 16),
+             (11, 23), (12, 24),  # Torso
+             (23, 24), (23, 25), (24, 26),  # Legs
+             (25, 27), (26, 28), (27, 29), (28, 30),
+             (29, 31), (30, 32)
+         ]
+
+     def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
+         """Create smooth transition between poses"""
+         if prev_pose is None or current_pose is None:
+             return current_pose
+
+         # Interpolate between previous and current pose
+         smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose
+
+         # Ensure the smoothed pose maintains proper proportions
+         # Normalize joint positions relative to the hip anchor
+         hip_center_idx = 23  # MediaPipe left-hip landmark, used as the hip anchor
+
+         prev_hip = prev_pose[hip_center_idx]
+         current_hip = current_pose[hip_center_idx]
+         smoothed_hip = smoothed_pose[hip_center_idx]
+
+         # Adjust positions relative to the hip anchor
+         for i in range(len(smoothed_pose)):
+             if i != hip_center_idx:
+                 # Calculate relative positions
+                 prev_relative = prev_pose[i] - prev_hip
+                 current_relative = current_pose[i] - current_hip
+
+                 # Interpolate relative positions
+                 smoothed_relative = (1 - smoothing_factor) * prev_relative + smoothing_factor * current_relative
+
+                 # Update smoothed pose
+                 smoothed_pose[i] = smoothed_hip + smoothed_relative
+
+         return smoothed_pose
+
+ class AIDancePartner:
+     def __init__(self):
+         self.pose_detector = PoseDetector()
+         self.dance_generator = DanceGenerator()
+
+     def process_video(self, video_path, mode="Sync Partner"):
+         # Create a temporary directory for output
+         temp_dir = tempfile.mkdtemp()
+         output_path = os.path.join(temp_dir, 'output_dance.mp4')
+
+         cap = cv2.VideoCapture(video_path)
+
+         # Get video properties
+         fps = int(cap.get(cv2.CAP_PROP_FPS))
+         frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+         frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+         total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+         # Create output video writer
+         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+         out = cv2.VideoWriter(output_path, fourcc, fps,
+                               (frame_width * 2, frame_height))
+
+         # Pre-process video to extract all poses
+         all_poses = []
+         frame_count = 0
+
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             pose_landmarks = self.pose_detector.detect_pose(frame)
+             all_poses.append(pose_landmarks)
+             frame_count += 1
+
+         # Generate AI dance sequence
+         ai_sequence = self.dance_generator.generate_dance_sequence(
+             all_poses,
+             mode,
+             total_frames,
+             (frame_height, frame_width)
+         )
+
+         # Reset video capture and create final video
+         cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+         frame_count = 0
+
+         while cap.isOpened():
+             ret, frame = cap.read()
+             if not ret:
+                 break
+
+             # Get corresponding AI frame
+             ai_frame = ai_sequence[frame_count]
+
+             # Combine frames side by side
+             combined_frame = np.hstack([frame, ai_frame])
+
+             # Write frame to output video
+             out.write(combined_frame)
+             frame_count += 1
+
+         # Release resources
+         cap.release()
+         out.release()
+
+         return output_path
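The committed module only defines the classes; it does not include a driver. As a rough usage sketch (importing the file as a module named colab and the input path 'dance.mp4' are assumptions for illustration, not part of the commit), the pipeline could be exercised like this:

    # Usage sketch (illustrative; assumes colab.py is importable and a local
    # input video 'dance.mp4' exists -- neither is defined by this commit).
    from colab import AIDancePartner

    partner = AIDancePartner()
    # "Sync Partner" mirrors the dancer; any other mode string uses the creative sequence.
    result = partner.process_video('dance.mp4', mode='Sync Partner')
    print('Side-by-side output written to:', result)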