ombhojane committed
Commit f7b842d · verified · 1 Parent(s): e402479

Update colab.py

Files changed (1)
  1. colab.py +422 -469
colab.py CHANGED
@@ -1,469 +1,422 @@
-
-# Import necessary libraries
-import cv2
-import mediapipe as mp
-import numpy as np
-from scipy.interpolate import interp1d
-from google.colab import files
-from IPython.display import HTML, display
-from IPython.display import clear_output
-import time
-import base64
-
-class PoseDetector:
-    def __init__(self):
-        self.mp_pose = mp.solutions.pose
-        self.pose = self.mp_pose.Pose(
-            min_detection_confidence=0.5,
-            min_tracking_confidence=0.5
-        )
-
-    def detect_pose(self, frame):
-        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
-        results = self.pose.process(rgb_frame)
-        return results.pose_landmarks if results.pose_landmarks else None
-
-class DanceGenerator:
-    def __init__(self):
-        self.prev_moves = []
-        self.style_memory = []
-        self.rhythm_patterns = []
-
-    def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
-        height, width = frame_size
-        sequence = []
-
-        if mode == "Sync Partner":
-            sequence = self._generate_sync_sequence(all_poses, total_frames, frame_size)
-        else:
-            sequence = self._generate_creative_sequence(all_poses, total_frames, frame_size)
-
-        return sequence
-
-    def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
-        height, width = frame_size
-        sequence = []
-
-        # Enhanced rhythm analysis
-        rhythm_window = 10  # Analyze chunks of frames for rhythm
-        beat_positions = self._detect_dance_beats(all_poses, rhythm_window)
-
-        pose_arrays = []
-        for pose in all_poses:
-            if pose is not None:
-                pose_arrays.append(self._landmarks_to_array(pose))
-            else:
-                pose_arrays.append(None)
-
-        for i in range(total_frames):
-            frame = np.zeros((height, width, 3), dtype=np.uint8)
-
-            if pose_arrays[i] is not None:
-                # Enhanced mirroring with rhythm awareness
-                mirrored = self._mirror_movements(pose_arrays[i])
-
-                # Apply rhythm-based movement enhancement
-                if i in beat_positions:
-                    mirrored = self._enhance_movement_on_beat(mirrored)
-
-                if i > 0 and pose_arrays[i-1] is not None:
-                    mirrored = self._smooth_transition(pose_arrays[i-1], mirrored, 0.3)
-
-                frame = self._create_enhanced_dance_frame(
-                    mirrored,
-                    frame_size,
-                    add_effects=True
-                )
-
-            sequence.append(frame)
-
-        return sequence
-
-    def _detect_dance_beats(self, poses, window_size):
-        """Detect main beats in the dance sequence"""
-        beat_positions = []
-
-        if len(poses) < window_size:
-            return beat_positions
-
-        for i in range(window_size, len(poses)):
-            if poses[i] is not None and poses[i-1] is not None:
-                curr_pose = self._landmarks_to_array(poses[i])
-                prev_pose = self._landmarks_to_array(poses[i-1])
-
-                # Calculate movement magnitude
-                movement = np.mean(np.abs(curr_pose - prev_pose))
-
-                # Detect significant movements as beats
-                if movement > np.mean(self.rhythm_patterns) + np.std(self.rhythm_patterns):
-                    beat_positions.append(i)
-
-        return beat_positions
-
-    def _enhance_movement_on_beat(self, pose):
-        """Enhance movements during detected beats"""
-        # Amplify movements slightly on beats
-        center = np.mean(pose, axis=0)
-        enhanced_pose = pose.copy()
-
-        for i in range(len(pose)):
-            # Amplify movement relative to center
-            vector = pose[i] - center
-            enhanced_pose[i] = center + vector * 1.2
-
-        return enhanced_pose
-
-    def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
-        """Generate creative dance sequence based on style"""
-        height, width = frame_size
-        sequence = []
-
-        # Analyze style from all poses
-        style_patterns = self._analyze_style_patterns(all_poses)
-
-        # Generate new sequence using style patterns
-        for i in range(total_frames):
-            frame = np.zeros((height, width, 3), dtype=np.uint8)
-
-            # Generate new pose based on style
-            new_pose = self._generate_style_based_pose(style_patterns, i/total_frames)
-
-            if new_pose is not None:
-                frame = self._create_enhanced_dance_frame(
-                    new_pose,
-                    frame_size,
-                    add_effects=True
-                )
-
-            sequence.append(frame)
-
-        return sequence
-
-    def _analyze_style_patterns(self, poses):
-        """Enhanced style analysis including rhythm and movement patterns"""
-        patterns = []
-        rhythm_data = []
-
-        for i in range(1, len(poses)):
-            if poses[i] is not None and poses[i-1] is not None:
-                # Calculate movement speed and direction
-                curr_pose = self._landmarks_to_array(poses[i])
-                prev_pose = self._landmarks_to_array(poses[i-1])
-
-                # Analyze movement velocity
-                velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
-                rhythm_data.append(velocity)
-
-                # Store enhanced pattern data
-                pattern_info = {
-                    'pose': curr_pose,
-                    'velocity': velocity,
-                    'acceleration': velocity if i == 1 else velocity - prev_velocity
-                }
-                patterns.append(pattern_info)
-                prev_velocity = velocity
-
-        self.rhythm_patterns = rhythm_data
-        return patterns
-
-    def _generate_style_based_pose(self, patterns, progress):
-        """Generate new pose based on style patterns and progress"""
-        if not patterns:
-            return None
-
-        # Create smooth interpolation between poses
-        num_patterns = len(patterns)
-        pattern_idx = int(progress * (num_patterns - 1))
-
-        if pattern_idx < num_patterns - 1:
-            t = progress * (num_patterns - 1) - pattern_idx
-            pose = self._interpolate_poses(
-                patterns[pattern_idx],
-                patterns[pattern_idx + 1],
-                t
-            )
-        else:
-            pose = patterns[-1]
-
-        return pose
-
-    def _interpolate_poses(self, pose1, pose2, t):
-        """Smoothly interpolate between two poses"""
-        return pose1 * (1 - t) + pose2 * t
-
-    def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
-        """Create enhanced visualization frame with effects"""
-        height, width = frame_size
-        frame = np.zeros((height, width, 3), dtype=np.uint8)
-
-        # Convert coordinates
-        points = (pose_array[:, :2] * [width, height]).astype(int)
-
-        # Draw enhanced skeleton
-        connections = self._get_pose_connections()
-        for connection in connections:
-            start_idx, end_idx = connection
-            if start_idx < len(points) and end_idx < len(points):
-                # Draw glowing lines
-                if add_effects:
-                    self._draw_glowing_line(
-                        frame,
-                        points[start_idx],
-                        points[end_idx],
-                        (0, 255, 0)
-                    )
-                else:
-                    cv2.line(frame,
-                             tuple(points[start_idx]),
-                             tuple(points[end_idx]),
-                             (0, 255, 0), 2)
-
-        # Draw enhanced joints
-        for point in points:
-            if add_effects:
-                self._draw_glowing_point(frame, point, (0, 0, 255))
-            else:
-                cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
-
-        return frame
-
-    def _draw_glowing_line(self, frame, start, end, color, thickness=2):
-        """Draw a line with glow effect"""
-        # Draw main line
-        cv2.line(frame, tuple(start), tuple(end), color, thickness)
-
-        # Draw glow
-        for i in range(3):
-            alpha = 0.3 - i * 0.1
-            thickness = thickness + 2
-            cv2.line(frame, tuple(start), tuple(end),
-                     tuple([int(c * alpha) for c in color]),
-                     thickness)
-
-    def _draw_glowing_point(self, frame, point, color, radius=4):
-        """Draw a point with glow effect"""
-        # Draw main point
-        cv2.circle(frame, tuple(point), radius, color, -1)
-
-        # Draw glow
-        for i in range(3):
-            alpha = 0.3 - i * 0.1
-            r = radius + i * 2
-            cv2.circle(frame, tuple(point), r,
-                       tuple([int(c * alpha) for c in color]),
-                       -1)
-
-    def _landmarks_to_array(self, landmarks):
-        """Convert MediaPipe landmarks to numpy array"""
-        points = []
-        for landmark in landmarks.landmark:
-            points.append([landmark.x, landmark.y, landmark.z])
-        return np.array(points)
-
-    def _mirror_movements(self, landmarks):
-        """Mirror the input movements"""
-        mirrored = landmarks.copy()
-        mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip x coordinates
-        return mirrored
-
-    def _update_style_memory(self, landmarks):
-        """Update memory of dance style"""
-        self.style_memory.append(landmarks)
-        if len(self.style_memory) > 30:  # Keep last 30 frames
-            self.style_memory.pop(0)
-
-    def _generate_style_based_moves(self):
-        """Generate new moves based on learned style"""
-        if not self.style_memory:
-            return np.zeros((33, 3))  # Default pose shape
-
-        # Simple implementation: interpolate between stored poses
-        base_pose = self.style_memory[-1]
-        if len(self.style_memory) > 1:
-            prev_pose = self.style_memory[-2]
-            t = np.random.random()
-            new_pose = t * base_pose + (1-t) * prev_pose
-        else:
-            new_pose = base_pose
-
-        return new_pose
-
-    def _create_dance_frame(self, pose_array):
-        """Create visualization frame from pose array"""
-        frame = np.zeros((480, 640, 3), dtype=np.uint8)
-
-        # Convert normalized coordinates to pixel coordinates
-        points = (pose_array[:, :2] * [640, 480]).astype(int)
-
-        # Draw connections between joints
-        connections = self._get_pose_connections()
-        for connection in connections:
-            start_idx, end_idx = connection
-            if start_idx < len(points) and end_idx < len(points):
-                cv2.line(frame,
-                         tuple(points[start_idx]),
-                         tuple(points[end_idx]),
-                         (0, 255, 0), 2)
-
-        # Draw joints
-        for point in points:
-            cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
-
-        return frame
-
-    def _get_pose_connections(self):
-        """Define connections between pose landmarks"""
-        return [
-            (0, 1), (1, 2), (2, 3), (3, 7),  # Face
-            (0, 4), (4, 5), (5, 6), (6, 8),
-            (9, 10), (11, 12), (11, 13), (13, 15),  # Arms
-            (12, 14), (14, 16),
-            (11, 23), (12, 24),  # Torso
-            (23, 24), (23, 25), (24, 26),  # Legs
-            (25, 27), (26, 28), (27, 29), (28, 30),
-            (29, 31), (30, 32)
-        ]
-
-    def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
-        """Create smooth transition between poses"""
-        if prev_pose is None or current_pose is None:
-            return current_pose
-
-        # Interpolate between previous and current pose
-        smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose
-
-        # Ensure the smoothed pose maintains proper proportions
-        # Normalize joint positions relative to hip center
-        hip_center_idx = 23  # Index for hip center landmark
-
-        prev_hip = prev_pose[hip_center_idx]
-        current_hip = current_pose[hip_center_idx]
-        smoothed_hip = smoothed_pose[hip_center_idx]
-
-        # Adjust positions relative to hip center
-        for i in range(len(smoothed_pose)):
-            if i != hip_center_idx:
-                # Calculate relative positions
-                prev_relative = prev_pose[i] - prev_hip
-                current_relative = current_pose[i] - current_hip
-
-                # Interpolate relative positions
-                smoothed_relative = (1 - smoothing_factor) * prev_relative + smoothing_factor * current_relative
-
-                # Update smoothed pose
-                smoothed_pose[i] = smoothed_hip + smoothed_relative
-
-        return smoothed_pose
-
-class AIDancePartner:
-    def __init__(self):
-        self.pose_detector = PoseDetector()
-        self.dance_generator = DanceGenerator()
-
-    def process_video(self, video_path, mode="Sync Partner"):
-        cap = cv2.VideoCapture(video_path)
-
-        # Get video properties
-        fps = int(cap.get(cv2.CAP_PROP_FPS))
-        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
-        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
-        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
-
-        # Create output video writer
-        output_path = 'output_dance.mp4'
-        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
-        out = cv2.VideoWriter(output_path, fourcc, fps,
-                              (frame_width * 2, frame_height))
-
-        # Pre-process video to extract all poses
-        print("Processing video...")
-        all_poses = []
-        frame_count = 0
-
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            pose_landmarks = self.pose_detector.detect_pose(frame)
-            all_poses.append(pose_landmarks)
-
-            frame_count += 1
-            if frame_count % 30 == 0:
-                print(f"Processed {frame_count}/{total_frames} frames")
-
-        # Generate AI dance sequence
-        print("Generating AI dance sequence...")
-        ai_sequence = self.dance_generator.generate_dance_sequence(
-            all_poses,
-            mode,
-            total_frames,
-            (frame_height, frame_width)
-        )
-
-        # Reset video capture and create final video
-        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
-        frame_count = 0
-        print("Creating final video...")
-
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-
-            # Get corresponding AI frame
-            ai_frame = ai_sequence[frame_count]
-
-            # Combine frames side by side
-            combined_frame = np.hstack([frame, ai_frame])
-
-            # Write frame to output video
-            out.write(combined_frame)
-
-            frame_count += 1
-            if frame_count % 30 == 0:
-                print(f"Processed {frame_count}/{total_frames} frames")
-
-        # Release resources
-        cap.release()
-        out.release()
-
-        print("Video processing complete!")
-        return output_path
-
-def upload_and_process_video():
-    """Function to handle video upload and processing in Colab"""
-    print("Please upload a video file...")
-    uploaded = files.upload()
-
-    if uploaded:
-        video_path = list(uploaded.keys())[0]
-        print(f"Processing video: {video_path}")
-
-        # Create AI Dance Partner instance
-        dance_partner = AIDancePartner()
-
-        # Process video
-        output_path = dance_partner.process_video(video_path)
-
-        # Display the output video
-        def create_video_player(video_path):
-            video_html = f'''
-            <video width="100%" height="480" controls>
-                <source src="data:video/mp4;base64,{base64.b64encode(open(output_path, 'rb').read()).decode()}" type="video/mp4">
-                Your browser does not support the video tag.
-            </video>
-            '''
-            return HTML(video_html)
-
-        print("Displaying processed video...")
-        display(create_video_player(output_path))
-
-        # Option to download the processed video
-        files.download(output_path)
-
-# Run the application
-if __name__ == "__main__":
-    print("AI Dance Partner - Video Processing")
-    print("==================================")
-    upload_and_process_video()
+# Import necessary libraries
+import cv2
+import mediapipe as mp
+import numpy as np
+from scipy.interpolate import interp1d
+import time
+import os
+import tempfile
+
+class PoseDetector:
+    def __init__(self):
+        self.mp_pose = mp.solutions.pose
+        self.pose = self.mp_pose.Pose(
+            min_detection_confidence=0.5,
+            min_tracking_confidence=0.5
+        )
+
+    def detect_pose(self, frame):
+        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+        results = self.pose.process(rgb_frame)
+        return results.pose_landmarks if results.pose_landmarks else None
+
+class DanceGenerator:
+    def __init__(self):
+        self.prev_moves = []
+        self.style_memory = []
+        self.rhythm_patterns = []
+
+    def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
+        height, width = frame_size
+        sequence = []
+
+        if mode == "Sync Partner":
+            sequence = self._generate_sync_sequence(all_poses, total_frames, frame_size)
+        else:
+            sequence = self._generate_creative_sequence(all_poses, total_frames, frame_size)
+
+        return sequence
+
+    def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
+        height, width = frame_size
+        sequence = []
+
+        # Enhanced rhythm analysis
+        rhythm_window = 10  # Analyze chunks of frames for rhythm
+        beat_positions = self._detect_dance_beats(all_poses, rhythm_window)
+
+        pose_arrays = []
+        for pose in all_poses:
+            if pose is not None:
+                pose_arrays.append(self._landmarks_to_array(pose))
+            else:
+                pose_arrays.append(None)
+
+        for i in range(total_frames):
+            frame = np.zeros((height, width, 3), dtype=np.uint8)
+
+            if pose_arrays[i] is not None:
+                # Enhanced mirroring with rhythm awareness
+                mirrored = self._mirror_movements(pose_arrays[i])
+
+                # Apply rhythm-based movement enhancement
+                if i in beat_positions:
+                    mirrored = self._enhance_movement_on_beat(mirrored)
+
+                if i > 0 and pose_arrays[i-1] is not None:
+                    mirrored = self._smooth_transition(pose_arrays[i-1], mirrored, 0.3)
+
+                frame = self._create_enhanced_dance_frame(
+                    mirrored,
+                    frame_size,
+                    add_effects=True
+                )
+
+            sequence.append(frame)
+
+        return sequence
+
+    def _detect_dance_beats(self, poses, window_size):
+        """Detect main beats in the dance sequence"""
+        beat_positions = []
+
+        if len(poses) < window_size:
+            return beat_positions
+
+        for i in range(window_size, len(poses)):
+            if poses[i] is not None and poses[i-1] is not None:
+                curr_pose = self._landmarks_to_array(poses[i])
+                prev_pose = self._landmarks_to_array(poses[i-1])
+
+                # Calculate movement magnitude
+                movement = np.mean(np.abs(curr_pose - prev_pose))
+
+                # Detect significant movements as beats (skip an empty rhythm history, where np.mean([]) is NaN)
+                if self.rhythm_patterns and movement > np.mean(self.rhythm_patterns) + np.std(self.rhythm_patterns):
+                    beat_positions.append(i)
+
+        return beat_positions
+
+    def _enhance_movement_on_beat(self, pose):
+        """Enhance movements during detected beats"""
+        # Amplify movements slightly on beats
+        center = np.mean(pose, axis=0)
+        enhanced_pose = pose.copy()
+
+        for i in range(len(pose)):
+            # Amplify movement relative to center
+            vector = pose[i] - center
+            enhanced_pose[i] = center + vector * 1.2
+
+        return enhanced_pose
+
+    def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
+        """Generate creative dance sequence based on style"""
+        height, width = frame_size
+        sequence = []
+
+        # Analyze style from all poses
+        style_patterns = self._analyze_style_patterns(all_poses)
+
+        # Generate new sequence using style patterns
+        for i in range(total_frames):
+            frame = np.zeros((height, width, 3), dtype=np.uint8)
+
+            # Generate new pose based on style
+            new_pose = self._generate_style_based_pose(style_patterns, i/total_frames)
+
+            if new_pose is not None:
+                frame = self._create_enhanced_dance_frame(
+                    new_pose,
+                    frame_size,
+                    add_effects=True
+                )
+
+            sequence.append(frame)
+
+        return sequence
+
+    def _analyze_style_patterns(self, poses):
+        """Enhanced style analysis including rhythm and movement patterns"""
+        patterns = []
+        rhythm_data = []
+
+        for i in range(1, len(poses)):
+            if poses[i] is not None and poses[i-1] is not None:
+                # Calculate movement speed and direction
+                curr_pose = self._landmarks_to_array(poses[i])
+                prev_pose = self._landmarks_to_array(poses[i-1])
+
+                # Analyze movement velocity
+                velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
+                rhythm_data.append(velocity)
+
+                # Store enhanced pattern data
+                pattern_info = {
+                    'pose': curr_pose,
+                    'velocity': velocity,
+                    'acceleration': velocity if i == 1 else velocity - prev_velocity
+                }
+                patterns.append(pattern_info)
+                prev_velocity = velocity
+
+        self.rhythm_patterns = rhythm_data
+        return patterns
+
+    def _generate_style_based_pose(self, patterns, progress):
+        """Generate new pose based on style patterns and progress"""
+        if not patterns:
+            return None
+
+        # Create smooth interpolation between poses
+        num_patterns = len(patterns)
+        pattern_idx = int(progress * (num_patterns - 1))
+
+        if pattern_idx < num_patterns - 1:
+            t = progress * (num_patterns - 1) - pattern_idx
+            pose = self._interpolate_poses(
+                patterns[pattern_idx]['pose'],
+                patterns[pattern_idx + 1]['pose'],
+                t
+            )
+        else:
+            pose = patterns[-1]['pose']
+
+        return pose
+
+    def _interpolate_poses(self, pose1, pose2, t):
+        """Smoothly interpolate between two poses"""
+        return pose1 * (1 - t) + pose2 * t
+
+    def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
+        """Create enhanced visualization frame with effects"""
+        height, width = frame_size
+        frame = np.zeros((height, width, 3), dtype=np.uint8)
+
+        # Convert coordinates
+        points = (pose_array[:, :2] * [width, height]).astype(int)
+
+        # Draw enhanced skeleton
+        connections = self._get_pose_connections()
+        for connection in connections:
+            start_idx, end_idx = connection
+            if start_idx < len(points) and end_idx < len(points):
+                # Draw glowing lines
+                if add_effects:
+                    self._draw_glowing_line(
+                        frame,
+                        points[start_idx],
+                        points[end_idx],
+                        (0, 255, 0)
+                    )
+                else:
+                    cv2.line(frame,
+                             tuple(points[start_idx]),
+                             tuple(points[end_idx]),
+                             (0, 255, 0), 2)
+
+        # Draw enhanced joints
+        for point in points:
+            if add_effects:
+                self._draw_glowing_point(frame, point, (0, 0, 255))
+            else:
+                cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
+
+        return frame
+
+    def _draw_glowing_line(self, frame, start, end, color, thickness=2):
+        """Draw a line with glow effect"""
+        # Draw main line
+        cv2.line(frame, tuple(start), tuple(end), color, thickness)
+
+        # Draw glow
+        for i in range(3):
+            alpha = 0.3 - i * 0.1
+            thickness = thickness + 2
+            cv2.line(frame, tuple(start), tuple(end),
+                     tuple([int(c * alpha) for c in color]),
+                     thickness)
+
+    def _draw_glowing_point(self, frame, point, color, radius=4):
+        """Draw a point with glow effect"""
+        # Draw main point
+        cv2.circle(frame, tuple(point), radius, color, -1)
+
+        # Draw glow
+        for i in range(3):
+            alpha = 0.3 - i * 0.1
+            r = radius + i * 2
+            cv2.circle(frame, tuple(point), r,
+                       tuple([int(c * alpha) for c in color]),
+                       -1)
+
+    def _landmarks_to_array(self, landmarks):
+        """Convert MediaPipe landmarks to numpy array"""
+        points = []
+        for landmark in landmarks.landmark:
+            points.append([landmark.x, landmark.y, landmark.z])
+        return np.array(points)
+
+    def _mirror_movements(self, landmarks):
+        """Mirror the input movements"""
+        mirrored = landmarks.copy()
+        mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip x coordinates
+        return mirrored
+
+    def _update_style_memory(self, landmarks):
+        """Update memory of dance style"""
+        self.style_memory.append(landmarks)
+        if len(self.style_memory) > 30:  # Keep last 30 frames
+            self.style_memory.pop(0)
+
+    def _generate_style_based_moves(self):
+        """Generate new moves based on learned style"""
+        if not self.style_memory:
+            return np.zeros((33, 3))  # Default pose shape
+
+        # Simple implementation: interpolate between stored poses
+        base_pose = self.style_memory[-1]
+        if len(self.style_memory) > 1:
+            prev_pose = self.style_memory[-2]
+            t = np.random.random()
+            new_pose = t * base_pose + (1-t) * prev_pose
+        else:
+            new_pose = base_pose
+
+        return new_pose
+
+    def _create_dance_frame(self, pose_array):
+        """Create visualization frame from pose array"""
+        frame = np.zeros((480, 640, 3), dtype=np.uint8)
+
+        # Convert normalized coordinates to pixel coordinates
+        points = (pose_array[:, :2] * [640, 480]).astype(int)
+
+        # Draw connections between joints
+        connections = self._get_pose_connections()
+        for connection in connections:
+            start_idx, end_idx = connection
+            if start_idx < len(points) and end_idx < len(points):
+                cv2.line(frame,
+                         tuple(points[start_idx]),
+                         tuple(points[end_idx]),
+                         (0, 255, 0), 2)
+
+        # Draw joints
+        for point in points:
+            cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)
+
+        return frame
+
+    def _get_pose_connections(self):
+        """Define connections between pose landmarks"""
+        return [
+            (0, 1), (1, 2), (2, 3), (3, 7),  # Face
+            (0, 4), (4, 5), (5, 6), (6, 8),
+            (9, 10), (11, 12), (11, 13), (13, 15),  # Arms
+            (12, 14), (14, 16),
+            (11, 23), (12, 24),  # Torso
+            (23, 24), (23, 25), (24, 26),  # Legs
+            (25, 27), (26, 28), (27, 29), (28, 30),
+            (29, 31), (30, 32)
+        ]
+
+    def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
+        """Create smooth transition between poses"""
+        if prev_pose is None or current_pose is None:
+            return current_pose
+
+        # Interpolate between previous and current pose
+        smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose
+
+        # Ensure the smoothed pose maintains proper proportions
+        # Normalize joint positions relative to hip center
+        hip_center_idx = 23  # Index for hip center landmark
+
+        prev_hip = prev_pose[hip_center_idx]
+        current_hip = current_pose[hip_center_idx]
+        smoothed_hip = smoothed_pose[hip_center_idx]
+
+        # Adjust positions relative to hip center
+        for i in range(len(smoothed_pose)):
+            if i != hip_center_idx:
+                # Calculate relative positions
+                prev_relative = prev_pose[i] - prev_hip
+                current_relative = current_pose[i] - current_hip
+
+                # Interpolate relative positions
+                smoothed_relative = (1 - smoothing_factor) * prev_relative + smoothing_factor * current_relative
+
+                # Update smoothed pose
+                smoothed_pose[i] = smoothed_hip + smoothed_relative
+
+        return smoothed_pose
+
+class AIDancePartner:
+    def __init__(self):
+        self.pose_detector = PoseDetector()
+        self.dance_generator = DanceGenerator()
+
+    def process_video(self, video_path, mode="Sync Partner"):
+        # Create a temporary directory for output
+        temp_dir = tempfile.mkdtemp()
+        output_path = os.path.join(temp_dir, 'output_dance.mp4')
+
+        cap = cv2.VideoCapture(video_path)
+
+        # Get video properties
+        fps = int(cap.get(cv2.CAP_PROP_FPS))
+        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        # Create output video writer
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(output_path, fourcc, fps,
+                              (frame_width * 2, frame_height))
+
+        # Pre-process video to extract all poses
+        all_poses = []
+        frame_count = 0
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            pose_landmarks = self.pose_detector.detect_pose(frame)
+            all_poses.append(pose_landmarks)
+            frame_count += 1
+
+        # Generate AI dance sequence
+        ai_sequence = self.dance_generator.generate_dance_sequence(
+            all_poses,
+            mode,
+            total_frames,
+            (frame_height, frame_width)
+        )
+
+        # Reset video capture and create final video
+        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
+        frame_count = 0
+
+        while cap.isOpened():
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Get corresponding AI frame
+            ai_frame = ai_sequence[frame_count]
+
+            # Combine frames side by side
+            combined_frame = np.hstack([frame, ai_frame])
+
+            # Write frame to output video
+            out.write(combined_frame)
+            frame_count += 1
+
+        # Release resources
+        cap.release()
+        out.release()
+
+        return output_path
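
A minimal usage sketch of the refactored entry point, since the Colab upload/display helpers were removed and callers now supply a video path directly. This assumes colab.py is importable as a module; the input file name and copy destination are placeholders, not part of the commit:

    # Hypothetical driver script; module name and file paths are assumptions.
    import shutil
    from colab import AIDancePartner

    partner = AIDancePartner()

    # "Sync Partner" mirrors the dancer frame by frame; any other mode
    # string falls through to the creative, style-based generator.
    result = partner.process_video("my_dance.mp4", mode="Sync Partner")

    # process_video writes the side-by-side video under tempfile.mkdtemp(),
    # so copy it somewhere durable before the temp directory goes away.
    shutil.copy(result, "output_dance.mp4")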