ombhojane committed on
Commit
ec27c1f
·
verified ·
1 Parent(s): ddad4bf

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +174 -201
  2. colab.py +424 -425
app.py CHANGED
@@ -1,202 +1,175 @@
1
- import streamlit as st
2
- from colab import AIDancePartner
3
- import tempfile
4
- import os
5
- import time
6
- import cv2
7
- from PIL import Image
8
- import io
9
-
10
- # Set page configuration
11
- st.set_page_config(
12
- page_title="AI Dance Partner",
13
- page_icon="💃",
14
- layout="wide",
15
- initial_sidebar_state="expanded"
16
- )
17
-
18
- # Custom CSS for better styling
19
- def local_css():
20
- st.markdown("""
21
- <style>
22
- .main {
23
- padding: 2rem;
24
- }
25
- .stButton>button {
26
- background-color: #FF4B4B;
27
- color: white;
28
- border-radius: 20px;
29
- padding: 0.5rem 2rem;
30
- font-weight: bold;
31
- }
32
- .stButton>button:hover {
33
- background-color: #FF6B6B;
34
- border-color: #FF4B4B;
35
- }
36
- .upload-text {
37
- font-size: 1.2rem;
38
- color: #666;
39
- margin-bottom: 1rem;
40
- }
41
- .title-container {
42
- background: linear-gradient(90deg, #FF4B4B, #FF8C8C);
43
- padding: 2rem;
44
- border-radius: 10px;
45
- margin-bottom: 2rem;
46
- color: white;
47
- text-align: center;
48
- }
49
- .info-box {
50
- background-color: #f0f2f6;
51
- padding: 1rem;
52
- border-radius: 10px;
53
- margin-bottom: 1rem;
54
- }
55
- </style>
56
- """, unsafe_allow_html=True)
57
-
58
- def get_video_preview(video_path):
59
- """Generate a preview frame from the video"""
60
- cap = cv2.VideoCapture(video_path)
61
- ret, frame = cap.read()
62
- cap.release()
63
-
64
- if ret:
65
- # Convert BGR to RGB
66
- frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
67
- return Image.fromarray(frame)
68
- return None
69
-
70
- def main():
71
- local_css()
72
-
73
- # Title section with gradient background
74
- st.markdown("""
75
- <div class="title-container">
76
- <h1>🕺 AI Dance Partner 💃</h1>
77
- <p style="font-size: 1.2rem;">Transform your solo dance into a dynamic duet!</p>
78
- </div>
79
- """, unsafe_allow_html=True)
80
-
81
- # Create two columns for layout
82
- col1, col2 = st.columns([2, 1])
83
-
84
- with col1:
85
- st.markdown('<p class="upload-text">Upload your dance video and watch the magic happen!</p>', unsafe_allow_html=True)
86
- uploaded_file = st.file_uploader("", type=['mp4', 'avi', 'mov'])
87
-
88
- # Add video preview
89
- if uploaded_file is not None:
90
- # Create a temporary file for the uploaded video
91
- with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
92
- tfile.write(uploaded_file.read())
93
- temp_input_path = tfile.name
94
-
95
- # Show video preview
96
- st.markdown("### 📽️ Preview")
97
- preview_image = get_video_preview(temp_input_path)
98
- if preview_image:
99
- st.image(preview_image, use_column_width=True, caption="Video Preview")
100
-
101
- # Add video player for original
102
- st.markdown("### 🎥 Original Video")
103
- st.video(temp_input_path)
104
-
105
- with col2:
106
- st.markdown('<div class="info-box">', unsafe_allow_html=True)
107
- st.markdown("### How it works")
108
- st.markdown("""
109
- 1. Upload your solo dance video
110
- 2. Choose your preferred dance style
111
- 3. Watch as AI creates your perfect dance partner!
112
- """)
113
- st.markdown('</div>', unsafe_allow_html=True)
114
-
115
- if uploaded_file is not None:
116
- # Style selection with custom design
117
- st.markdown("### 🎭 Choose Your Dance Partner Style")
118
- style = st.select_slider(
119
- "",
120
- options=["Sync Partner", "Creative Partner"],
121
- value="Sync Partner"
122
- )
123
-
124
- # Add description based on selected style
125
- if style == "Sync Partner":
126
- st.info("💫 Sync Partner will mirror your movements in perfect harmony.")
127
- else:
128
- st.info("🎨 Creative Partner will add its own artistic flair to your dance.")
129
-
130
- if st.button("Generate Dance Partner 🎬"):
131
- try:
132
- # Create a progress bar
133
- progress_bar = st.progress(0)
134
- status_text = st.empty()
135
-
136
- # Processing steps with more detailed progress
137
- steps = [
138
- "Analyzing dance moves...",
139
- "Detecting pose landmarks...",
140
- "Generating partner movements...",
141
- "Applying style patterns...",
142
- "Creating final video..."
143
- ]
144
-
145
- for i, step in enumerate(steps):
146
- status_text.text(step)
147
- progress_bar.progress((i + 1) * 20)
148
- time.sleep(0.5)
149
-
150
- # Process video
151
- dance_partner = AIDancePartner()
152
- output_path = dance_partner.process_video(temp_input_path, mode=style)
153
-
154
- # Update progress
155
- progress_bar.progress(100)
156
- status_text.text("Done! 🎉")
157
-
158
- # Display result in a nice container
159
- st.markdown("### 🎥 Your Dance Duet")
160
-
161
- # Show preview of the output
162
- preview_output = get_video_preview(output_path)
163
- if preview_output:
164
- st.image(preview_output, use_column_width=True, caption="Dance Duet Preview")
165
-
166
- # Display the video
167
- st.video(output_path)
168
-
169
- # Download button with custom styling
170
- with open(output_path, 'rb') as file:
171
- st.download_button(
172
- label="Download Your Dance Duet 📥",
173
- data=file,
174
- file_name="ai_dance_partner.mp4",
175
- mime="video/mp4"
176
- )
177
-
178
- # Cleanup temporary files
179
- os.unlink(temp_input_path)
180
- os.unlink(output_path)
181
-
182
- except Exception as e:
183
- st.error(f"Oops! Something went wrong: {str(e)}")
184
- if os.path.exists(temp_input_path):
185
- os.unlink(temp_input_path)
186
-
187
- # Add footer with additional information
188
- st.markdown("""
189
- ---
190
- <div style="text-align: center;">
191
- <h3>🌟 Features</h3>
192
- <p>• Real-time pose detection</p>
193
- <p>• Synchronized movement matching</p>
194
- <p>• Creative dance style generation</p>
195
- <p>• High-quality video output</p>
196
- <br>
197
- <p style="color: #666;">Made with ❤️ by AI Dance Partner Team</p>
198
- </div>
199
- """, unsafe_allow_html=True)
200
-
201
- if __name__ == "__main__":
202
  main()
 
1
+ import streamlit as st
2
+ from colab import AIDancePartner
3
+ import tempfile
4
+ import os
5
+ import time
6
+ import cv2
7
+ from PIL import Image
8
+ import io
9
+
10
# Page configuration — must run before any other Streamlit UI call.
PAGE_CONFIG = {
    "page_title": "AI Dance Partner",
    "page_icon": "💃",
    "layout": "wide",
    "initial_sidebar_state": "expanded",
}
st.set_page_config(**PAGE_CONFIG)
17
+
18
+ # Custom CSS for better styling
19
def local_css():
    """Inject the app's custom CSS (buttons, banner, upload text, info box) into the page."""
    css = """
    <style>
    .main {
        padding: 2rem;
    }
    .stButton>button {
        background-color: #FF4B4B;
        color: white;
        border-radius: 20px;
        padding: 0.5rem 2rem;
        font-weight: bold;
    }
    .stButton>button:hover {
        background-color: #FF6B6B;
        border-color: #FF4B4B;
    }
    .upload-text {
        font-size: 1.2rem;
        color: #666;
        margin-bottom: 1rem;
    }
    .title-container {
        background: linear-gradient(90deg, #FF4B4B, #FF8C8C);
        padding: 2rem;
        border-radius: 10px;
        margin-bottom: 2rem;
        color: white;
        text-align: center;
    }
    .info-box {
        background-color: #f0f2f6;
        padding: 1rem;
        border-radius: 10px;
        margin-bottom: 1rem;
    }
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
57
+
58
def get_video_preview(video_path):
    """Return the first frame of the video at *video_path* as a PIL Image.

    Returns None when the file cannot be opened or no frame can be decoded.
    """
    cap = cv2.VideoCapture(video_path)
    try:
        # Guard against paths OpenCV cannot open at all (fix: the original
        # called read() unconditionally and never checked isOpened()).
        if not cap.isOpened():
            return None
        ret, frame = cap.read()
    finally:
        # Release the capture even if read() raises.
        cap.release()

    if not ret:
        return None
    # OpenCV decodes frames as BGR; PIL expects RGB.
    return Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
69
+
70
def main():
    """Render the AI Dance Partner app: upload, style choice, processing, download."""
    local_css()

    # Title banner with gradient background.
    st.markdown("""
    <div class="title-container">
    <h1>🕺 AI Dance Partner 💃</h1>
    <p style="font-size: 1.2rem;">Transform your solo dance into a dynamic duet!</p>
    </div>
    """, unsafe_allow_html=True)

    col1, col2 = st.columns([2, 1])

    temp_input_path = None  # set once a video has been uploaded

    with col1:
        st.markdown('<p class="upload-text">Upload your dance video and watch the magic happen!</p>', unsafe_allow_html=True)
        # Non-empty label hidden via label_visibility (fix: an empty label is
        # deprecated in Streamlit and breaks accessibility).
        uploaded_file = st.file_uploader(
            "Dance video",
            type=['mp4', 'avi', 'mov'],
            label_visibility="collapsed",
        )

        if uploaded_file is not None:
            # Persist the upload to disk so OpenCV can read it by path.
            with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as tfile:
                tfile.write(uploaded_file.read())
                temp_input_path = tfile.name

            st.markdown("### 📽️ Preview")
            preview_image = get_video_preview(temp_input_path)
            if preview_image:
                st.image(preview_image, use_container_width=True, caption="Video Preview")

            st.markdown("### 🎥 Original Video")
            st.video(temp_input_path)

    with col2:
        st.markdown('<div class="info-box">', unsafe_allow_html=True)
        st.markdown("### How it works")
        st.markdown("""
        1. Upload your solo dance video
        2. Choose your preferred dance style
        3. Watch as AI creates your perfect dance partner!
        """)
        st.markdown('</div>', unsafe_allow_html=True)

    if uploaded_file is not None:
        st.markdown("### 🎭 Choose Your Dance Partner Style")
        style = st.select_slider(
            "Partner style",
            options=["Sync Partner", "Creative Partner"],
            value="Sync Partner",
            label_visibility="collapsed",
        )

        if style == "Sync Partner":
            st.info("💫 Sync Partner will mirror your movements in perfect harmony.")
        else:
            st.info("🎨 Creative Partner will add its own artistic flair to your dance.")

        if st.button("Generate Dance Partner 🎬"):
            output_path = None
            try:
                progress_bar = st.progress(0)
                status_text = st.empty()

                # Cosmetic staged progress while the real work happens below.
                steps = [
                    "Analyzing dance moves...",
                    "Detecting pose landmarks...",
                    "Generating partner movements...",
                    "Creating final video...",
                ]
                for i, step in enumerate(steps):
                    status_text.text(step)
                    progress_bar.progress((i + 1) * 25)
                    time.sleep(0.5)

                # Heavy lifting: pose detection + partner rendering.
                dance_partner = AIDancePartner()
                output_path = dance_partner.process_video(temp_input_path, mode=style)

                progress_bar.progress(100)
                status_text.text("Done! 🎉")

                st.markdown("### 🎥 Your Dance Duet")
                st.video(output_path)

                with open(output_path, 'rb') as file:
                    st.download_button(
                        label="Download Video 📥",
                        data=file,
                        file_name="ai_dance_partner.mp4",
                        mime="video/mp4"
                    )
            except Exception as e:
                st.error(f"Oops! Something went wrong: {str(e)}")
            finally:
                # Remove temp artifacts on both success and failure (fix: the
                # original leaked output_path when an exception occurred).
                for path in (temp_input_path, output_path):
                    if path and os.path.exists(path):
                        os.unlink(path)

if __name__ == "__main__":
    main()
colab.py CHANGED
@@ -1,426 +1,425 @@
1
- # Import necessary libraries
2
- import cv2
3
- import mediapipe as mp
4
- import numpy as np
5
- from scipy.interpolate import interp1d
6
- import time
7
- import os
8
- import tempfile
9
-
10
- class PoseDetector:
11
- def __init__(self):
12
- self.mp_pose = mp.solutions.pose
13
- self.pose = self.mp_pose.Pose(
14
- min_detection_confidence=0.5,
15
- min_tracking_confidence=0.5
16
- )
17
-
18
- def detect_pose(self, frame):
19
- rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
20
- results = self.pose.process(rgb_frame)
21
- return results.pose_landmarks if results.pose_landmarks else None
22
-
23
class DanceGenerator:
    """Generates a partner's dance frames from per-frame MediaPipe pose landmarks."""

    def __init__(self):
        self.prev_moves = []       # reserved: history of generated moves (unused here)
        self.style_memory = []     # rolling window of recent poses for style generation
        self.rhythm_patterns = []  # per-frame velocity vectors (set by _analyze_style_patterns)

    def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
        """Dispatch to the sync or creative generator.

        all_poses: list of MediaPipe landmark objects (or None) per input frame.
        mode: "Sync Partner" mirrors the dancer; any other value is creative.
        total_frames: number of frames to emit.
        frame_size: (height, width) of each output frame.
        Returns a list of BGR uint8 frames.
        """
        if mode == "Sync Partner":
            return self._generate_sync_sequence(all_poses, total_frames, frame_size)
        return self._generate_creative_sequence(all_poses, total_frames, frame_size)

    def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
        """Mirror the dancer frame-by-frame, with beat emphasis and smoothing."""
        height, width = frame_size
        sequence = []

        rhythm_window = 10  # analyze chunks of frames for rhythm
        beat_positions = set(self._detect_dance_beats(all_poses, rhythm_window))

        pose_arrays = [
            self._landmarks_to_array(pose) if pose is not None else None
            for pose in all_poses
        ]

        for i in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)

            # Fix: CAP_PROP_FRAME_COUNT can exceed the number of frames actually
            # decoded, so guard the index before accessing pose_arrays[i].
            if i < len(pose_arrays) and pose_arrays[i] is not None:
                mirrored = self._mirror_movements(pose_arrays[i])

                if i in beat_positions:
                    mirrored = self._enhance_movement_on_beat(mirrored)

                if i > 0 and pose_arrays[i - 1] is not None:
                    # Fix: smooth against the *mirrored* previous pose; the
                    # original blended mirrored and unmirrored coordinates.
                    prev_mirrored = self._mirror_movements(pose_arrays[i - 1])
                    mirrored = self._smooth_transition(prev_mirrored, mirrored, 0.3)

                frame = self._create_enhanced_dance_frame(
                    mirrored,
                    frame_size,
                    add_effects=True
                )

            sequence.append(frame)

        return sequence

    def _detect_dance_beats(self, poses, window_size):
        """Return frame indices whose movement magnitude exceeds mean + std.

        Fix: the threshold is computed from this sequence's own movement
        magnitudes; the original compared against self.rhythm_patterns, which
        is empty in sync mode and produced a NaN threshold (RuntimeWarning).
        """
        if len(poses) < window_size:
            return []

        movements = {}  # frame index -> mean absolute landmark displacement
        for i in range(window_size, len(poses)):
            if poses[i] is not None and poses[i - 1] is not None:
                curr_pose = self._landmarks_to_array(poses[i])
                prev_pose = self._landmarks_to_array(poses[i - 1])
                movements[i] = float(np.mean(np.abs(curr_pose - prev_pose)))

        if not movements:
            return []

        values = np.array(list(movements.values()))
        threshold = values.mean() + values.std()
        return [i for i, magnitude in movements.items() if magnitude > threshold]

    def _enhance_movement_on_beat(self, pose):
        """Amplify the pose 20% away from its centroid to emphasise a beat."""
        center = np.mean(pose, axis=0)
        # Vectorized equivalent of the original per-landmark loop.
        return center + (pose - center) * 1.2

    def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
        """Generate a new sequence from learned style patterns rather than mirroring."""
        height, width = frame_size
        sequence = []

        style_patterns = self._analyze_style_patterns(all_poses)

        for i in range(total_frames):
            frame = np.zeros((height, width, 3), dtype=np.uint8)

            # Interpolate through the learned patterns as progress goes 0 -> 1.
            new_pose = self._generate_style_based_pose(style_patterns, i / total_frames)

            if new_pose is not None:
                frame = self._create_enhanced_dance_frame(
                    new_pose,
                    frame_size,
                    add_effects=True
                )

            sequence.append(frame)

        return sequence

    def _analyze_style_patterns(self, poses):
        """Collect per-frame pose/velocity/acceleration patterns.

        Side effect: stores the velocity series in self.rhythm_patterns.
        """
        patterns = []
        rhythm_data = []
        prev_velocity = None  # fix: was unbound when the first valid pair had i > 1

        for i in range(1, len(poses)):
            if poses[i] is not None and poses[i - 1] is not None:
                curr_pose = self._landmarks_to_array(poses[i])
                prev_pose = self._landmarks_to_array(poses[i - 1])

                # Mean absolute displacement per coordinate axis.
                velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
                rhythm_data.append(velocity)

                patterns.append({
                    'pose': curr_pose,
                    'velocity': velocity,
                    # First sample has no predecessor; use the velocity itself,
                    # matching the original's i == 1 special case.
                    'acceleration': velocity if prev_velocity is None else velocity - prev_velocity,
                })
                prev_velocity = velocity

        self.rhythm_patterns = rhythm_data
        return patterns

    def _generate_style_based_pose(self, patterns, progress):
        """Interpolate a pose from *patterns* at normalized position *progress* (0..1).

        Returns None when no patterns are available.
        """
        if not patterns:
            return None

        num_patterns = len(patterns)
        pattern_idx = int(progress * (num_patterns - 1))

        if pattern_idx < num_patterns - 1:
            # Fractional position between the two neighbouring patterns.
            t = progress * (num_patterns - 1) - pattern_idx
            pose1 = patterns[pattern_idx]['pose']
            pose2 = patterns[pattern_idx + 1]['pose']
            return self._interpolate_poses(pose1, pose2, t)
        return patterns[-1]['pose']

    def _interpolate_poses(self, pose1, pose2, t):
        """Linearly blend two poses; accepts raw arrays or pattern dicts."""
        if isinstance(pose1, dict):
            pose1 = pose1['pose']
        if isinstance(pose2, dict):
            pose2 = pose2['pose']
        return pose1 * (1 - t) + pose2 * t

    def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
        """Render a pose as a skeleton on a light-gray canvas, optionally with glow."""
        height, width = frame_size
        # Light gray background for better skeleton visibility than black.
        frame = np.ones((height, width, 3), dtype=np.uint8) * 240

        # Normalized [0, 1] landmark coordinates -> pixel coordinates.
        points = (pose_array[:, :2] * [width, height]).astype(int)

        skeleton_color = (0, 100, 255)  # orange in BGR
        joint_color = (255, 0, 0)       # blue in BGR

        for start_idx, end_idx in self._get_pose_connections():
            if start_idx < len(points) and end_idx < len(points):
                if add_effects:
                    self._draw_glowing_line(
                        frame,
                        points[start_idx],
                        points[end_idx],
                        skeleton_color,
                        thickness=4
                    )
                else:
                    cv2.line(frame,
                             tuple(points[start_idx]),
                             tuple(points[end_idx]),
                             skeleton_color, 4)

        for point in points:
            if add_effects:
                self._draw_glowing_point(frame, point, joint_color, radius=6)
            else:
                cv2.circle(frame, tuple(point), 6, joint_color, -1)

        return frame

    def _draw_glowing_line(self, frame, start, end, color, thickness=4):
        """Draw a line with a halo: three fading, thickening passes, then the core."""
        for i in range(3):
            alpha = 0.5 - i * 0.15
            thick = thickness + (i * 4)
            cv2.line(frame, tuple(start), tuple(end),
                     tuple([int(c * alpha) for c in color]),
                     thick)
        cv2.line(frame, tuple(start), tuple(end), color, thickness)

    def _draw_glowing_point(self, frame, point, color, radius=6):
        """Draw a filled circle with a halo of fading, growing rings."""
        for i in range(3):
            alpha = 0.5 - i * 0.15
            r = radius + (i * 3)
            cv2.circle(frame, tuple(point), r,
                       tuple([int(c * alpha) for c in color]),
                       -1)
        cv2.circle(frame, tuple(point), radius, color, -1)

    def _landmarks_to_array(self, landmarks):
        """Convert a MediaPipe landmark list to an (N, 3) array of (x, y, z)."""
        return np.array([[lm.x, lm.y, lm.z] for lm in landmarks.landmark])

    def _mirror_movements(self, landmarks):
        """Mirror a pose horizontally by flipping normalized x coordinates."""
        mirrored = landmarks.copy()
        mirrored[:, 0] = 1 - mirrored[:, 0]
        return mirrored

    def _update_style_memory(self, landmarks):
        """Append a pose to the style memory, keeping only the last 30 frames."""
        self.style_memory.append(landmarks)
        if len(self.style_memory) > 30:
            self.style_memory.pop(0)

    def _generate_style_based_moves(self):
        """Generate a new pose from the style memory (simple two-pose interpolation)."""
        if not self.style_memory:
            return np.zeros((33, 3))  # MediaPipe Pose has 33 landmarks

        base_pose = self.style_memory[-1]
        if len(self.style_memory) > 1:
            prev_pose = self.style_memory[-2]
            # Random blend introduces slight variation between frames.
            t = np.random.random()
            return t * base_pose + (1 - t) * prev_pose
        return base_pose

    def _create_dance_frame(self, pose_array):
        """Render a pose as a plain skeleton on a fixed 640x480 black canvas."""
        frame = np.zeros((480, 640, 3), dtype=np.uint8)

        points = (pose_array[:, :2] * [640, 480]).astype(int)

        for start_idx, end_idx in self._get_pose_connections():
            if start_idx < len(points) and end_idx < len(points):
                cv2.line(frame,
                         tuple(points[start_idx]),
                         tuple(points[end_idx]),
                         (0, 255, 0), 2)

        for point in points:
            cv2.circle(frame, tuple(point), 4, (0, 0, 255), -1)

        return frame

    def _get_pose_connections(self):
        """Landmark index pairs forming the drawn skeleton (MediaPipe Pose topology)."""
        return [
            (0, 1), (1, 2), (2, 3), (3, 7),  # Face
            (0, 4), (4, 5), (5, 6), (6, 8),
            (9, 10), (11, 12), (11, 13), (13, 15),  # Arms
            (12, 14), (14, 16),
            (11, 23), (12, 24),  # Torso
            (23, 24), (23, 25), (24, 26),  # Legs
            (25, 27), (26, 28), (27, 29), (28, 30),
            (29, 31), (30, 32)
        ]

    def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
        """Blend two poses, anchoring limb positions relative to the hip landmark."""
        if prev_pose is None or current_pose is None:
            return current_pose

        smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose

        # Landmark 23 is MediaPipe's LEFT_HIP, used here as the body anchor
        # (NOTE(review): a true hip center would average landmarks 23 and 24).
        hip_center_idx = 23

        prev_hip = prev_pose[hip_center_idx]
        current_hip = current_pose[hip_center_idx]
        smoothed_hip = smoothed_pose[hip_center_idx]

        # Re-express every other landmark relative to the hip, interpolate the
        # relative offsets, then re-anchor to the smoothed hip.
        for i in range(len(smoothed_pose)):
            if i != hip_center_idx:
                prev_relative = prev_pose[i] - prev_hip
                current_relative = current_pose[i] - current_hip
                smoothed_relative = (1 - smoothing_factor) * prev_relative + smoothing_factor * current_relative
                smoothed_pose[i] = smoothed_hip + smoothed_relative

        return smoothed_pose
358
-
359
class AIDancePartner:
    """End-to-end pipeline: detect poses in a video and render a side-by-side duet."""

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()

    def process_video(self, video_path, mode="Sync Partner"):
        """Render a side-by-side video of the input and the generated partner.

        video_path: path to the source dance video.
        mode: "Sync Partner" or "Creative Partner" (forwarded to DanceGenerator).
        Returns the path of the rendered MP4 inside a fresh temporary directory.
        """
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output_dance.mp4')

        cap = cv2.VideoCapture(video_path)
        out = None
        try:
            # Fix: fall back to 30 FPS when the container reports 0, which
            # previously produced an unplayable output file.
            fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
            total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

            # Output is twice as wide: original frame | generated partner frame.
            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
            out = cv2.VideoWriter(output_path, fourcc, fps,
                                  (frame_width * 2, frame_height))

            # Pass 1: extract a pose (or None) for every decodable frame.
            all_poses = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    break
                all_poses.append(self.pose_detector.detect_pose(frame))

            ai_sequence = self.dance_generator.generate_dance_sequence(
                all_poses,
                mode,
                total_frames,
                (frame_height, frame_width)
            )

            # Pass 2: re-read the source and pair each frame with its AI frame.
            cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
            frame_count = 0
            while cap.isOpened():
                ret, frame = cap.read()
                # Fix: stop when the AI sequence is exhausted — the container's
                # frame count can underestimate, which raised IndexError before.
                if not ret or frame_count >= len(ai_sequence):
                    break
                combined_frame = np.hstack([frame, ai_sequence[frame_count]])
                out.write(combined_frame)
                frame_count += 1
        finally:
            # Release the decoder/encoder even when detection or generation fails.
            cap.release()
            if out is not None:
                out.release()

        return output_path
 
1
+ # Import necessary libraries
2
+ import cv2
3
+ import mediapipe as mp
4
+ import numpy as np
5
+ from scipy.interpolate import interp1d
6
+ import time
7
+ import os
8
+ import tempfile
9
+
10
class PoseDetector:
    """Thin wrapper around MediaPipe Pose for single-frame landmark detection."""

    def __init__(self):
        self.mp_pose = mp.solutions.pose
        self.pose = self.mp_pose.Pose(
            min_detection_confidence=0.5,
            min_tracking_confidence=0.5
        )

    def detect_pose(self, frame):
        """Run pose estimation on a BGR frame; return landmarks or None."""
        # MediaPipe expects RGB input, while OpenCV frames arrive as BGR.
        rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        detection = self.pose.process(rgb)
        return detection.pose_landmarks or None
22
+
23
+ class DanceGenerator:
24
+ def __init__(self):
25
+ self.prev_moves = []
26
+ self.style_memory = []
27
+ self.rhythm_patterns = []
28
+
29
+ def generate_dance_sequence(self, all_poses, mode, total_frames, frame_size):
30
+ height, width = frame_size
31
+ sequence = []
32
+
33
+ if mode == "Sync Partner":
34
+ sequence = self._generate_sync_sequence(all_poses, total_frames, frame_size)
35
+ else:
36
+ sequence = self._generate_creative_sequence(all_poses, total_frames, frame_size)
37
+
38
+ return sequence
39
+
40
+ def _generate_sync_sequence(self, all_poses, total_frames, frame_size):
41
+ height, width = frame_size
42
+ sequence = []
43
+
44
+ # Enhanced rhythm analysis
45
+ rhythm_window = 10 # Analyze chunks of frames for rhythm
46
+ beat_positions = self._detect_dance_beats(all_poses, rhythm_window)
47
+
48
+ pose_arrays = []
49
+ for pose in all_poses:
50
+ if pose is not None:
51
+ pose_arrays.append(self._landmarks_to_array(pose))
52
+ else:
53
+ pose_arrays.append(None)
54
+
55
+ for i in range(total_frames):
56
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
57
+
58
+ if pose_arrays[i] is not None:
59
+ # Enhanced mirroring with rhythm awareness
60
+ mirrored = self._mirror_movements(pose_arrays[i])
61
+
62
+ # Apply rhythm-based movement enhancement
63
+ if i in beat_positions:
64
+ mirrored = self._enhance_movement_on_beat(mirrored)
65
+
66
+ if i > 0 and pose_arrays[i-1] is not None:
67
+ mirrored = self._smooth_transition(pose_arrays[i-1], mirrored, 0.3)
68
+
69
+ frame = self._create_enhanced_dance_frame(
70
+ mirrored,
71
+ frame_size,
72
+ add_effects=True
73
+ )
74
+
75
+ sequence.append(frame)
76
+
77
+ return sequence
78
+
79
+ def _detect_dance_beats(self, poses, window_size):
80
+ """Detect main beats in the dance sequence"""
81
+ beat_positions = []
82
+
83
+ if len(poses) < window_size:
84
+ return beat_positions
85
+
86
+ for i in range(window_size, len(poses)):
87
+ if poses[i] is not None and poses[i-1] is not None:
88
+ curr_pose = self._landmarks_to_array(poses[i])
89
+ prev_pose = self._landmarks_to_array(poses[i-1])
90
+
91
+ # Calculate movement magnitude
92
+ movement = np.mean(np.abs(curr_pose - prev_pose))
93
+
94
+ # Detect significant movements as beats
95
+ if movement > np.mean(self.rhythm_patterns) + np.std(self.rhythm_patterns):
96
+ beat_positions.append(i)
97
+
98
+ return beat_positions
99
+
100
+ def _enhance_movement_on_beat(self, pose):
101
+ """Enhance movements during detected beats"""
102
+ # Amplify movements slightly on beats
103
+ center = np.mean(pose, axis=0)
104
+ enhanced_pose = pose.copy()
105
+
106
+ for i in range(len(pose)):
107
+ # Amplify movement relative to center
108
+ vector = pose[i] - center
109
+ enhanced_pose[i] = center + vector * 1.2
110
+
111
+ return enhanced_pose
112
+
113
+ def _generate_creative_sequence(self, all_poses, total_frames, frame_size):
114
+ """Generate creative dance sequence based on style"""
115
+ height, width = frame_size
116
+ sequence = []
117
+
118
+ # Analyze style from all poses
119
+ style_patterns = self._analyze_style_patterns(all_poses)
120
+
121
+ # Generate new sequence using style patterns
122
+ for i in range(total_frames):
123
+ frame = np.zeros((height, width, 3), dtype=np.uint8)
124
+
125
+ # Generate new pose based on style
126
+ new_pose = self._generate_style_based_pose(style_patterns, i/total_frames)
127
+
128
+ if new_pose is not None:
129
+ frame = self._create_enhanced_dance_frame(
130
+ new_pose,
131
+ frame_size,
132
+ add_effects=True
133
+ )
134
+
135
+ sequence.append(frame)
136
+
137
+ return sequence
138
+
139
+ def _analyze_style_patterns(self, poses):
140
+ """Enhanced style analysis including rhythm and movement patterns"""
141
+ patterns = []
142
+ rhythm_data = []
143
+
144
+ for i in range(1, len(poses)):
145
+ if poses[i] is not None and poses[i-1] is not None:
146
+ # Calculate movement speed and direction
147
+ curr_pose = self._landmarks_to_array(poses[i])
148
+ prev_pose = self._landmarks_to_array(poses[i-1])
149
+
150
+ # Analyze movement velocity
151
+ velocity = np.mean(np.abs(curr_pose - prev_pose), axis=0)
152
+ rhythm_data.append(velocity)
153
+
154
+ # Store enhanced pattern data
155
+ pattern_info = {
156
+ 'pose': curr_pose,
157
+ 'velocity': velocity,
158
+ 'acceleration': velocity if i == 1 else velocity - prev_velocity
159
+ }
160
+ patterns.append(pattern_info)
161
+ prev_velocity = velocity
162
+
163
+ self.rhythm_patterns = rhythm_data
164
+ return patterns
165
+
166
+ def _generate_style_based_pose(self, patterns, progress):
167
+ """Generate new pose based on style patterns and progress"""
168
+ if not patterns:
169
+ return None
170
+
171
+ # Create smooth interpolation between poses
172
+ num_patterns = len(patterns)
173
+ pattern_idx = int(progress * (num_patterns - 1))
174
+
175
+ if pattern_idx < num_patterns - 1:
176
+ t = progress * (num_patterns - 1) - pattern_idx
177
+ # Extract pose arrays from pattern dictionaries
178
+ pose1 = patterns[pattern_idx]['pose']
179
+ pose2 = patterns[pattern_idx + 1]['pose']
180
+ pose = self._interpolate_poses(pose1, pose2, t)
181
+ else:
182
+ pose = patterns[-1]['pose']
183
+
184
+ return pose
185
+
186
+ def _interpolate_poses(self, pose1, pose2, t):
187
+ """Smoothly interpolate between two poses"""
188
+ if isinstance(pose1, dict):
189
+ pose1 = pose1['pose']
190
+ if isinstance(pose2, dict):
191
+ pose2 = pose2['pose']
192
+ return pose1 * (1 - t) + pose2 * t
193
+
194
def _create_enhanced_dance_frame(self, pose_array, frame_size, add_effects=True):
    """Render a pose skeleton on a light-gray canvas, optionally with glow effects.

    pose_array holds normalized [0,1] landmark coordinates; frame_size is
    (height, width) in pixels.
    """
    height, width = frame_size
    # Light-gray (240) background keeps the skeleton easy to see.
    frame = np.full((height, width, 3), 240, dtype=np.uint8)

    # Scale normalized landmark coordinates into pixel space.
    points = (pose_array[:, :2] * [width, height]).astype(int)
    count = len(points)

    bone_color = (0, 100, 255)  # orange in BGR
    for start_idx, end_idx in self._get_pose_connections():
        if start_idx >= count or end_idx >= count:
            continue
        if add_effects:
            self._draw_glowing_line(
                frame,
                points[start_idx],
                points[end_idx],
                bone_color,
                thickness=4
            )
        else:
            cv2.line(frame,
                     tuple(points[start_idx]),
                     tuple(points[end_idx]),
                     bone_color, 4)

    joint_color = (255, 0, 0)  # blue in BGR
    for point in points:
        if add_effects:
            self._draw_glowing_point(frame, point, joint_color, radius=6)
        else:
            cv2.circle(frame, tuple(point), 6, joint_color, -1)

    return frame
def _draw_glowing_line(self, frame, start, end, color, thickness=4):
    """Draw a line with a soft glow: three fading halo passes, then the core line."""
    p0, p1 = tuple(start), tuple(end)

    # Halo passes: progressively wider and dimmer.
    for layer in range(3):
        fade = 0.5 - layer * 0.15
        halo = tuple(int(c * fade) for c in color)
        cv2.line(frame, p0, p1, halo, thickness + layer * 4)

    # Core line at full color on top.
    cv2.line(frame, p0, p1, color, thickness)
def _draw_glowing_point(self, frame, point, color, radius=6):
    """Draw a filled dot with a soft glow: three fading halo passes, then the core dot."""
    center = tuple(point)

    # Halo passes: progressively larger and dimmer filled circles.
    for layer in range(3):
        fade = 0.5 - layer * 0.15
        halo = tuple(int(c * fade) for c in color)
        cv2.circle(frame, center, radius + layer * 3, halo, -1)

    # Core dot at full color on top.
    cv2.circle(frame, center, radius, color, -1)
+ def _landmarks_to_array(self, landmarks):
258
+ """Convert MediaPipe landmarks to numpy array"""
259
+ points = []
260
+ for landmark in landmarks.landmark:
261
+ points.append([landmark.x, landmark.y, landmark.z])
262
+ return np.array(points)
263
+
264
+ def _mirror_movements(self, landmarks):
265
+ """Mirror the input movements"""
266
+ mirrored = landmarks.copy()
267
+ mirrored[:, 0] = 1 - mirrored[:, 0] # Flip x coordinates
268
+ return mirrored
269
+
270
+ def _update_style_memory(self, landmarks):
271
+ """Update memory of dance style"""
272
+ self.style_memory.append(landmarks)
273
+ if len(self.style_memory) > 30: # Keep last 30 frames
274
+ self.style_memory.pop(0)
275
+
276
+ def _generate_style_based_moves(self):
277
+ """Generate new moves based on learned style"""
278
+ if not self.style_memory:
279
+ return np.zeros((33, 3)) # Default pose shape
280
+
281
+ # Simple implementation: interpolate between stored poses
282
+ base_pose = self.style_memory[-1]
283
+ if len(self.style_memory) > 1:
284
+ prev_pose = self.style_memory[-2]
285
+ t = np.random.random()
286
+ new_pose = t * base_pose + (1-t) * prev_pose
287
+ else:
288
+ new_pose = base_pose
289
+
290
+ return new_pose
291
+
292
def _create_dance_frame(self, pose_array):
    """Render the pose skeleton on a fixed 640x480 black canvas."""
    canvas = np.zeros((480, 640, 3), dtype=np.uint8)

    # Normalized [0,1] coordinates -> pixel coordinates.
    pts = (pose_array[:, :2] * [640, 480]).astype(int)
    count = len(pts)

    # Green bones between connected landmarks (BGR).
    for a, b in self._get_pose_connections():
        if a < count and b < count:
            cv2.line(canvas, tuple(pts[a]), tuple(pts[b]), (0, 255, 0), 2)

    # Red joint dots (BGR).
    for pt in pts:
        cv2.circle(canvas, tuple(pt), 4, (0, 0, 255), -1)

    return canvas
+ def _get_pose_connections(self):
316
+ """Define connections between pose landmarks"""
317
+ return [
318
+ (0, 1), (1, 2), (2, 3), (3, 7), # Face
319
+ (0, 4), (4, 5), (5, 6), (6, 8),
320
+ (9, 10), (11, 12), (11, 13), (13, 15), # Arms
321
+ (12, 14), (14, 16),
322
+ (11, 23), (12, 24), # Torso
323
+ (23, 24), (23, 25), (24, 26), # Legs
324
+ (25, 27), (26, 28), (27, 29), (28, 30),
325
+ (29, 31), (30, 32)
326
+ ]
327
+
328
+ def _smooth_transition(self, prev_pose, current_pose, smoothing_factor=0.3):
329
+ """Create smooth transition between poses"""
330
+ if prev_pose is None or current_pose is None:
331
+ return current_pose
332
+
333
+ # Interpolate between previous and current pose
334
+ smoothed_pose = (1 - smoothing_factor) * prev_pose + smoothing_factor * current_pose
335
+
336
+ # Ensure the smoothed pose maintains proper proportions
337
+ # Normalize joint positions relative to hip center
338
+ hip_center_idx = 23 # Index for hip center landmark
339
+
340
+ prev_hip = prev_pose[hip_center_idx]
341
+ current_hip = current_pose[hip_center_idx]
342
+ smoothed_hip = smoothed_pose[hip_center_idx]
343
+
344
+ # Adjust positions relative to hip center
345
+ for i in range(len(smoothed_pose)):
346
+ if i != hip_center_idx:
347
+ # Calculate relative positions
348
+ prev_relative = prev_pose[i] - prev_hip
349
+ current_relative = current_pose[i] - current_hip
350
+
351
+ # Interpolate relative positions
352
+ smoothed_relative = (1 - smoothing_factor) * prev_relative + smoothing_factor * current_relative
353
+
354
+ # Update smoothed pose
355
+ smoothed_pose[i] = smoothed_hip + smoothed_relative
356
+
357
+ return smoothed_pose
358
+
359
class AIDancePartner:
    """Generates a side-by-side video of the input dancer and an AI partner."""

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()

    def process_video(self, video_path, mode="Sync Partner"):
        """Process a dance video and return the path of the combined output.

        Args:
            video_path: path to the input video file.
            mode: dance-generation mode forwarded to the generator
                (e.g. "Sync Partner").

        Returns:
            Path to an MP4 whose frames show the original video on the left
            and the generated AI dancer on the right.
        """
        # Fresh temp directory so repeated runs never clobber each other.
        temp_dir = tempfile.mkdtemp()
        output_path = os.path.join(temp_dir, 'output_dance.mp4')

        cap = cv2.VideoCapture(video_path)

        # Source properties. CAP_PROP_FPS can report 0 on some containers,
        # which would produce a broken writer — fall back to 30 in that case.
        fps = int(cap.get(cv2.CAP_PROP_FPS)) or 30
        frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        # Output canvas is twice as wide: original | AI partner.
        fourcc = cv2.VideoWriter_fourcc(*'avc1')
        out = cv2.VideoWriter(output_path, fourcc, fps,
                              (frame_width * 2, frame_height))

        # Pass 1: extract the pose from every frame of the input.
        all_poses = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            all_poses.append(self.pose_detector.detect_pose(frame))

        # Generate the full AI dance sequence in one shot.
        ai_sequence = self.dance_generator.generate_dance_sequence(
            all_poses,
            mode,
            total_frames,
            (frame_height, frame_width)
        )

        # Pass 2: rewind and stitch each input frame next to its AI frame.
        cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
        frame_count = 0

        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            # CAP_PROP_FRAME_COUNT is only an estimate; stop cleanly if the
            # real stream turns out longer than the generated sequence
            # (previously an unguarded index raised IndexError).
            if frame_count >= len(ai_sequence):
                break

            combined_frame = np.hstack([frame, ai_sequence[frame_count]])
            out.write(combined_frame)
            frame_count += 1

        # Release resources so the file is fully flushed before returning.
        cap.release()
        out.release()

        return output_path