File size: 3,511 Bytes
75aee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
 
 
75aee9e
96a0264
75aee9e
96a0264
75aee9e
 
 
 
 
96a0264
 
 
 
 
75aee9e
96a0264
 
 
 
 
 
75aee9e
 
 
 
 
96a0264
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75aee9e
96a0264
 
 
 
 
 
75aee9e
96a0264
 
 
75aee9e
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import streamlit as st
import cv2
import numpy as np
import tempfile
import os
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

class AIDancePartner:
    """Streamlit app that plays an uploaded dance video side by side with an
    AI-generated partner sequence built from the detected poses."""

    def __init__(self):
        # Project-local components: pose extraction, sequence generation, rendering.
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the title and sidebar controls; start processing once a video is uploaded."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video", 
            type=['mp4', 'avi', 'mov']
        )

        # Mode selection (strings are consumed by DanceGenerator — do not rename).
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )

        # Playback speed multiplier: 0.1x–2.0x, default 1.0, step 0.1.
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)

        if video_file:
            self.process_video(video_file, mode, play_speed)

    def process_video(self, video_file, mode, play_speed):
        """Extract poses from the uploaded video, generate the AI dance
        sequence, and play original + AI frames side by side.

        Args:
            video_file: Streamlit UploadedFile containing the dance video.
            mode: Generation mode string from the sidebar selectbox.
            play_speed: Playback speed multiplier (> 0).
        """
        # Persist the upload to a real file so OpenCV can open it by path.
        tfile = tempfile.NamedTemporaryFile(delete=False)
        try:
            tfile.write(video_file.read())
        finally:
            # Close BEFORE cv2 opens it: an open handle is locked on Windows
            # and buffered bytes may not yet be flushed to disk.
            tfile.close()

        try:
            cap = cv2.VideoCapture(tfile.name)
            if not cap.isOpened():
                st.error("Could not open the uploaded video.")
                return
            try:
                # Video metadata; FPS/frame-count metadata can be 0 or wrong
                # for some containers, so both are guarded below.
                fps = cap.get(cv2.CAP_PROP_FPS) or 30.0
                frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
                frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
                total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

                # Initialize progress bar and the image slot reused each frame.
                progress_bar = st.progress(0)
                frame_placeholder = st.empty()

                # First pass: extract one pose result per decoded frame.
                all_poses = []
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    all_poses.append(self.pose_detector.detect_pose(frame))

                # Trust the actual decoded count over unreliable metadata.
                if total_frames <= 0:
                    total_frames = len(all_poses)
                if total_frames == 0:
                    st.error("No frames could be read from the video.")
                    return

                # Rewind for the playback pass.
                cap.set(cv2.CAP_PROP_POS_FRAMES, 0)

                # Generate AI dance sequence for the whole clip.
                ai_sequence = self.dance_generator.generate_dance_sequence(
                    all_poses, 
                    mode, 
                    total_frames,
                    (frame_height, frame_width)
                )
                if not len(ai_sequence):
                    st.error("The AI dance generator returned no frames.")
                    return

                # Per-frame delay is loop-invariant; >= 1 ms so waitKey still polls.
                delay_ms = max(1, int(1000 / (fps * play_speed)))

                # Second pass: playback loop.
                frame_count = 0
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Clamp: decoded frames may exceed total_frames metadata.
                    progress_bar.progress(min(frame_count / total_frames, 1.0))

                    # Clamp in case the AI sequence is shorter than the video.
                    ai_frame = ai_sequence[min(frame_count, len(ai_sequence) - 1)]

                    # Combine frames side by side (AI frame resized to match).
                    combined_frame = np.hstack([
                        frame,
                        cv2.resize(ai_frame, (frame_width, frame_height))
                    ])

                    # Display combined frame (OpenCV frames are BGR).
                    frame_placeholder.image(
                        combined_frame,
                        channels="BGR",
                        use_column_width=True
                    )

                    # Control playback speed.
                    cv2.waitKey(delay_ms)
                    frame_count += 1
            finally:
                # Always release the capture, even if processing raised.
                cap.release()
        finally:
            # Always remove the temp file (delete=False disables auto-cleanup).
            os.unlink(tfile.name)

def main():
    """Entry point: build the app and render its Streamlit UI."""
    AIDancePartner().setup_ui()


if __name__ == "__main__":
    main()