ombhojane committed (verified)
Commit 46b86ce · 1 Parent(s): cb32870

Upload colab.py

Files changed (1):
  1. colab.py +469 -0
colab.py ADDED
@@ -0,0 +1,469 @@
# Import necessary libraries
import cv2
import mediapipe as mp
import numpy as np
from google.colab import files
from IPython.display import HTML, display
import base64
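
# NOTE: mediapipe is not preinstalled on a fresh Colab runtime; a setup cell
# along these lines is usually needed first (assumed, not part of this commit):
#
#     !pip install mediapipe opencv-python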

class PoseDetector:
    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def detect_pose(self, frame):
        # MediaPipe expects RGB input, while OpenCV decodes frames as BGR
        rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        results = self.pose.process(rgb_frame)
        # pose_landmarks is already None when no pose is detected
        return results.pose_landmarks

class DanceGenerator:
    def __init__(self):
        self.prev_moves = []
        self.style_memory = []
        self.rhythm_patterns = []

    def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
        if mode == "Sync Partner":
            sequence = self._generate_sync_sequence(all_poses, total_frames, frame_size)
        else:
            sequence = self._generate_creative_sequence(all_poses, total_frames, frame_size)

        return sequence

    def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
        height, width = frame_size
        sequence = []

        # Rhythm analysis over chunks of frames
        rhythm_window = 10
        beat_positions = self._detect_dance_beats(all_poses, rhythm_window)

        pose_arrays = []
        for pose in all_poses:
            if pose is not None:
                pose_arrays.append(self._landmarks_to_array(pose))
            else:
                pose_arrays.append(None)

        prev_mirrored = None
        for i in range(min(total_frames, len(pose_arrays))):
            frame = np.zeros((height, width, 3), dtype=np.uint8)

            if pose_arrays[i] is not None:
                # Mirror the dancer's movements
                mirrored = self._mirror_movements(pose_arrays[i])

                # Apply rhythm-based movement enhancement on detected beats
                if i in beat_positions:
                    mirrored = self._enhance_movement_on_beat(mirrored)

                # Smooth against the previous *mirrored* pose so both poses
                # live in the same (flipped) coordinate space
                if prev_mirrored is not None:
                    mirrored = self._smooth_transition(prev_mirrored, mirrored, 0.3)
                prev_mirrored = mirrored

                frame = self._create_enhanced_dance_frame(
                    mirrored,
                    frame_size,
                    add_effects=True
                )

            sequence.append(frame)

        return sequence

    def _detect_dance_beats(self, poses, window_size):
        """Detect main beats in the dance sequence"""
        beat_positions = []

        if len(poses) < window_size:
            return beat_positions

        # First pass: collect per-frame movement magnitudes, so the beat
        # threshold is derived from this sequence itself rather than from
        # self.rhythm_patterns, which may still be empty at this point
        movements = {}
        for i in range(window_size, len(poses)):
            if poses[i] is not None and poses[i - 1] is not None:
                curr_pose = self._landmarks_to_array(poses[i])
                prev_pose = self._landmarks_to_array(poses[i - 1])
                movements[i] = np.mean(np.abs(curr_pose - prev_pose))

        if not movements:
            return beat_positions

        values = np.array(list(movements.values()))
        threshold = values.mean() + values.std()

        # Second pass: frames with unusually large movement count as beats
        for i, movement in movements.items():
            if movement > threshold:
                beat_positions.append(i)

        return beat_positions

    def _enhance_movement_on_beat(self, pose):
        """Amplify movements slightly during detected beats"""
        # Scale every landmark away from the pose's center by 20%
        center = np.mean(pose, axis=0)
        return center + (pose - center) * 1.2

    def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
        """Generate a new dance sequence in the style of the input"""
        height, width = frame_size
        sequence = []

        # Analyze style from all poses
        style_patterns = self._analyze_style_patterns(all_poses)

        # Generate a new sequence using the learned style patterns
        for i in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)

            # Interpolate a pose for this point in the sequence
            new_pose = self._generate_style_based_pose(style_patterns, i / total_frames)

            if new_pose is not None:
                frame = self._create_enhanced_dance_frame(
                    new_pose,
                    frame_size,
                    add_effects=True
                )

            sequence.append(frame)

        return sequence

    def _analyze_style_patterns(self, poses):
        """Enhanced style analysis including rhythm and movement patterns"""
        patterns = []
        rhythm_data = []
        prev_velocity = None

        for i in range(1, len(poses)):
            if poses[i] is not None and poses[i - 1] is not None:
                curr_pose = self._landmarks_to_array(poses[i])
                prev_pose = self._landmarks_to_array(poses[i - 1])

                # Mean per-axis movement between consecutive frames
                velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
                rhythm_data.append(velocity)

                # Store enhanced pattern data; acceleration needs a previous
                # velocity, which the first valid pair does not have
                pattern_info = {
                    'pose': curr_pose,
                    'velocity': velocity,
                    'acceleration': (np.zeros_like(velocity) if prev_velocity is None
                                     else velocity - prev_velocity)
                }
                patterns.append(pattern_info)
                prev_velocity = velocity

        self.rhythm_patterns = rhythm_data
        return patterns

    def _generate_style_based_pose(self, patterns, progress):
        """Generate a pose for the given progress (0..1) through the sequence"""
        if not patterns:
            return None

        # Map progress onto the pattern list and interpolate between neighbors
        num_patterns = len(patterns)
        pattern_idx = int(progress * (num_patterns - 1))

        if pattern_idx < num_patterns - 1:
            t = progress * (num_patterns - 1) - pattern_idx
            # Interpolate the stored pose arrays, not the pattern dicts
            pose = self._interpolate_poses(
                patterns[pattern_idx]['pose'],
                patterns[pattern_idx + 1]['pose'],
                t
            )
        else:
            pose = patterns[-1]['pose']

        return pose

    def _interpolate_poses(self, pose1, pose2, t):
        """Linearly interpolate between two pose arrays"""
        return pose1 * (1 - t) + pose2 * t

    def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
        """Create an enhanced visualization frame with optional glow effects"""
        height, width = frame_size
        frame = np.zeros((height, width, 3), dtype=np.uint8)

        # Convert normalized coordinates to pixel coordinates
        points = (pose_array[:, :2] * [width, height]).astype(int)

        # Draw skeleton
        connections = self._get_pose_connections()
        for start_idx, end_idx in connections:
            if start_idx < len(points) and end_idx < len(points):
                if add_effects:
                    self._draw_glowing_line(
                        frame,
                        points[start_idx],
                        points[end_idx],
                        (0, 255, 0)
                    )
                else:
                    cv2.line(frame,
                             tuple(points[start_idx]),
                             tuple(points[end_idx]),
                             (0, 255, 0), 2)

        # Draw joints
        for point in points:
            if add_effects:
                self._draw_glowing_point(frame, point, (0, 0, 255))
            else:
                cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)

        return frame

    def _draw_glowing_line(self, frame, start, end, color, thickness=2):
        """Draw a line with a simple glow effect"""
        # Draw the dimmer, thicker halo layers first so the bright core
        # line is not painted over
        for i in range(3, 0, -1):
            alpha = 0.3 - (i - 1) * 0.1
            halo_thickness = thickness + i * 2
            cv2.line(frame, tuple(start), tuple(end),
                     tuple(int(c * alpha) for c in color),
                     halo_thickness)

        # Draw the main line on top
        cv2.line(frame, tuple(start), tuple(end), color, thickness)

    def _draw_glowing_point(self, frame, point, color, radius=4):
        """Draw a point with a simple glow effect"""
        # Halo circles from largest/dimmest to smallest/brightest
        for i in range(3, 0, -1):
            alpha = 0.3 - (i - 1) * 0.1
            r = radius + i * 2
            cv2.circle(frame, tuple(point), r,
                       tuple(int(c * alpha) for c in color),
                       -1)

        # Draw the main point on top
        cv2.circle(frame, tuple(point), radius, color, -1)

    def _landmarks_to_array(self, landmarks):
        """Convert MediaPipe landmarks to a (33, 3) numpy array"""
        points = []
        for landmark in landmarks.landmark:
            points.append([landmark.x, landmark.y, landmark.z])
        return np.array(points)

    def _mirror_movements(self, landmarks):
        """Mirror the input movements around the vertical axis"""
        # Note: flipping x alone mirrors positions; a stricter mirror would
        # also swap left/right landmark indices
        mirrored = landmarks.copy()
        mirrored[:, 0] = 1 - mirrored[:, 0]  # Flip normalized x coordinates
        return mirrored

    def _update_style_memory(self, landmarks):
        """Update memory of the dancer's style"""
        self.style_memory.append(landmarks)
        if len(self.style_memory) > 30:  # Keep the last 30 frames
            self.style_memory.pop(0)

    def _generate_style_based_moves(self):
        """Generate new moves based on the remembered style"""
        if not self.style_memory:
            return np.zeros((33, 3))  # MediaPipe Pose has 33 landmarks

        # Simple implementation: randomly interpolate between stored poses
        base_pose = self.style_memory[-1]
        if len(self.style_memory) > 1:
            prev_pose = self.style_memory[-2]
            t = np.random.random()
            new_pose = t * base_pose + (1 - t) * prev_pose
        else:
            new_pose = base_pose

        return new_pose

    def _create_dance_frame(self, pose_array):
        """Create a basic visualization frame from a pose array"""
        frame = np.zeros((480, 640, 3), dtype=np.uint8)

        # Convert normalized coordinates to pixel coordinates
        points = (pose_array[:, :2] * [640, 480]).astype(int)

        # Draw connections between joints
        connections = self._get_pose_connections()
        for start_idx, end_idx in connections:
            if start_idx < len(points) and end_idx < len(points):
                cv2.line(frame,
                         tuple(points[start_idx]),
                         tuple(points[end_idx]),
                         (0, 255, 0), 2)

        # Draw joints
        for point in points:
            cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)

        return frame

    def _get_pose_connections(self):
        """Define connections between MediaPipe pose landmarks"""
        return [
            (0, 1), (1, 2), (2, 3), (3, 7),          # Face
            (0, 4), (4, 5), (5, 6), (6, 8),
            (9, 10), (11, 12), (11, 13), (13, 15),   # Arms
            (12, 14), (14, 16),
            (11, 23), (12, 24),                      # Torso
            (23, 24), (23, 25), (24, 26),            # Legs
            (25, 27), (26, 28), (27, 29), (28, 30),
            (29, 31), (30, 32)
        ]
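
    # Aside: MediaPipe also ships a canonical edge list as
    # mp.solutions.pose.POSE_CONNECTIONS; the hand-written subset above keeps
    # the drawing code self-contained.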

    def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
        """Create a smooth transition between poses"""
        if prev_pose is None or current_pose is None:
            return current_pose

        # Interpolate between previous and current pose
        smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose

        # Keep proportions stable by re-anchoring joints to the hip.
        # Landmark 23 is MediaPipe's left hip, used here as the anchor point.
        hip_anchor_idx = 23

        prev_hip = prev_pose[hip_anchor_idx]
        current_hip = current_pose[hip_anchor_idx]
        smoothed_hip = smoothed_pose[hip_anchor_idx]

        # Interpolate each joint's position relative to the hip anchor
        for i in range(len(smoothed_pose)):
            if i != hip_anchor_idx:
                prev_relative = prev_pose[i] - prev_hip
                current_relative = current_pose[i] - current_hip

                smoothed_relative = ((1 - smoothing_factor) * prev_relative +
                                     smoothing_factor * current_relative)

                smoothed_pose[i] = smoothed_hip + smoothed_relative

        return smoothed_pose
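
    # For intuition: with smoothing_factor=0.3, a joint at x=0.50 in the
    # previous frame and x=0.60 in the current frame lands at
    # 0.7 * 0.50 + 0.3 * 0.60 = 0.53, so the generated partner trails the
    # dancer slightly instead of snapping to each new pose.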

class AIDancePartner:
    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()

    def process_video(self, video_path, mode="Sync Partner"):
        cap = cv2.VideoCapture(video_path)

        # Get video properties (CAP_PROP_FRAME_COUNT is only an estimate for
        # some containers, so the real count comes from the read loop below)
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30  # Fall back if FPS is unreadable
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Create output video writer (original and AI partner side by side)
        output_path = 'output_dance.mp4'
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, fps,
                              (frame_width * 2, frame_height))

        # Pre-process video to extract all poses
        print("Processing video...")
        all_poses = []
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break

            pose_landmarks = self.pose_detector.detect_pose(frame)
            all_poses.append(pose_landmarks)

            frame_count += 1
            if frame_count % 30 == 0:
                print(f"Processed {frame_count}/{total_frames} frames")

        # Generate the AI dance sequence from the frames actually read
        print("Generating AI dance sequence...")
        ai_sequence = self.dance_generator.generate_dance_sequence(
            all_poses,
            mode,
            len(all_poses),
            (frame_height, frame_width)
        )

        # Reset video capture and create the final video
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        frame_count = 0
        print("Creating final video...")

        while cap.isOpened() and frame_count < len(ai_sequence):
            ret, frame = cap.read()
            if not ret:
                break

            # Get the corresponding AI frame
            ai_frame = ai_sequence[frame_count]

            # Combine frames side by side and write to the output video
            combined_frame = np.hstack([frame, ai_frame])
            out.write(combined_frame)

            frame_count += 1
            if frame_count % 30 == 0:
                print(f"Processed {frame_count}/{total_frames} frames")

        # Release resources
        cap.release()
        out.release()

        print("Video processing complete!")
        return output_path

def upload_and_process_video():
    """Handle video upload and processing in Colab"""
    print("Please upload a video file...")
    uploaded = files.upload()

    if uploaded:
        video_path = list(uploaded.keys())[0]
        print(f"Processing video: {video_path}")

        # Create an AI Dance Partner instance and process the video
        dance_partner = AIDancePartner()
        output_path = dance_partner.process_video(video_path)

        # Display the output video inline as a base64-encoded <video> tag
        def create_video_player(path):
            # Read the file the caller passed in (previously this helper
            # ignored its parameter and read output_path from the closure)
            with open(path, 'rb') as f:
                encoded = base64.b64encode(f.read()).decode()
            video_html = f'''
            <video width="100%" height="480" controls>
                <source src="data:video/mp4;base64,{encoded}" type="video/mp4">
                Your browser does not support the video tag.
            </video>
            '''
            return HTML(video_html)

        print("Displaying processed video...")
        display(create_video_player(output_path))

        # Option to download the processed video
        files.download(output_path)

# Run the application
if __name__ == "__main__":
    print("AI Dance Partner - Video Processing")
    print("==================================")
    upload_and_process_video()
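
# Aside: a minimal sketch of driving the pipeline outside Colab (assumes a
# local file named dance.mp4; not part of the original commit). Skip
# files.upload() and files.download() and call process_video directly:
#
#     partner = AIDancePartner()
#     result = partner.process_video("dance.mp4", mode="Sync Partner")
#     print(f"Side-by-side video written to {result}")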