File size: 4,370 Bytes
75aee9e
 
 
 
 
ab74dbe
75aee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
 
 
ab74dbe
 
 
75aee9e
ab74dbe
75aee9e
ab74dbe
 
 
 
 
75aee9e
 
 
ab74dbe
75aee9e
 
96a0264
 
 
 
 
75aee9e
96a0264
 
 
ab74dbe
 
 
96a0264
ab74dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
 
 
ab74dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
 
ab74dbe
 
 
 
 
75aee9e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import streamlit as st
import cv2
import numpy as np
import tempfile
import os
import time
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

class AIDancePartner:
    """Streamlit front end pairing an uploaded dance video with an AI partner.

    The user's frames are shown on the left and, on the right, the frame that
    ``DanceGenerator.generate_dance_sequence`` produced for the pose detected
    in the same frame.
    """

    # session_state key for the Play/Pause toggle. st.button only returns True
    # on the single rerun triggered by the click, so the toggle must persist here.
    _PLAYING_KEY = "is_playing"
    # session_state key for the current playback position, so Pause/Play resumes
    # where playback was interrupted instead of restarting.
    _FRAME_KEY = "frame_index"

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the sidebar controls and, once a video is uploaded, process and play it."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov']
        )

        # Mode selection (forwarded unchanged to the dance generator)
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )

        # Multiplier applied to the video's frame rate during playback
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)

        # Play/Pause flips a persisted flag; a bare st.button would be True for
        # one rerun only and could never express "paused".
        if self._PLAYING_KEY not in st.session_state:
            st.session_state[self._PLAYING_KEY] = False
        if st.sidebar.button("Play/Pause"):
            st.session_state[self._PLAYING_KEY] = not st.session_state[self._PLAYING_KEY]
        is_playing = st.session_state[self._PLAYING_KEY]

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Generate an AI partner frame for every video frame, then display them.

        Args:
            video_file: Uploaded file object returned by ``st.file_uploader``.
            mode: Mode string forwarded to ``generate_dance_sequence``.
            play_speed: Multiplier on the video's frame rate (the slider keeps it > 0).
            is_playing: When falsy, only the current frame is shown (paused);
                when truthy, playback loops until a widget click triggers a rerun.
        """
        # Create a placeholder for the video
        video_placeholder = st.empty()
        progress_bar = st.progress(0)

        with st.spinner('Processing video...'):
            loaded = self._load_and_generate(video_file, mode)
        if loaded is None:
            st.error("Could not open the uploaded video.")
            return
        frames, ai_sequences, fps = loaded
        if not frames:
            st.warning("The uploaded video contains no readable frames.")
            return

        start = st.session_state.get(self._FRAME_KEY, 0) % len(frames)

        # Widgets must be created once per script run: creating them inside the
        # playback loop raises a duplicate-element error. Clicking either one
        # triggers a rerun, which interrupts the loop below at its next st.* call.
        if st.sidebar.button("Stop"):
            st.session_state[self._PLAYING_KEY] = False
            is_playing = False
            start = 0
        if st.sidebar.button("Replay"):
            st.session_state[self._PLAYING_KEY] = True
            is_playing = True
            start = 0

        delay = self._frame_delay(fps, play_speed)
        frame_count = start
        while True:
            # Remember the position so a Pause (rerun) resumes from here.
            st.session_state[self._FRAME_KEY] = frame_count

            frame = frames[frame_count]
            frame_height, frame_width = frame.shape[:2]
            # Combine frames side by side; np.hstack assumes both images share
            # dtype and channel count — NOTE(review): confirm the generator
            # returns 3-channel uint8 BGR frames.
            combined_frame = np.hstack([
                frame,
                cv2.resize(ai_sequences[frame_count], (frame_width, frame_height)),
            ])

            progress_bar.progress(self._progress_fraction(frame_count, len(frames)))
            video_placeholder.image(
                combined_frame,
                channels="BGR",
                use_container_width=True
            )

            if not is_playing:
                # Paused: show the current frame once instead of busy-looping.
                break

            time.sleep(delay)
            # Loop the video continuously until a click triggers a rerun.
            frame_count = (frame_count + 1) % len(frames)

    def _load_and_generate(self, video_file, mode):
        """Decode the upload and pair each frame with a generated partner frame.

        Returns ``(frames, ai_sequences, fps)``, or ``None`` when OpenCV cannot
        open the video. The temporary copy of the upload is always removed.
        """
        suffix = os.path.splitext(getattr(video_file, "name", "") or "")[1]
        # Closing the temp file flushes it before OpenCV opens it by path; an
        # unflushed buffer could truncate the end of the container.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tmp_path = tfile.name
            # getvalue() returns the whole upload regardless of its read position.
            tfile.write(video_file.getvalue())

        cap = cv2.VideoCapture(tmp_path)
        try:
            if not cap.isOpened():
                return None
            fps = cap.get(cv2.CAP_PROP_FPS)
            frames = []
            ai_sequences = []
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # Use the decoded frame's real size; the CAP_PROP_* values can be 0.
                frame_height, frame_width = frame.shape[:2]
                pose_landmarks = self.pose_detector.detect_pose(frame)
                ai_frame = self.dance_generator.generate_dance_sequence(
                    [pose_landmarks],
                    mode,
                    1,
                    (frame_height, frame_width)
                )[0]
                frames.append(frame)
                ai_sequences.append(ai_frame)
            return frames, ai_sequences, fps
        finally:
            cap.release()
            os.unlink(tmp_path)

    @staticmethod
    def _frame_delay(fps, play_speed, default_fps=30.0):
        """Return the seconds to sleep between frames.

        Falls back to *default_fps* when *fps* is 0, negative, or NaN (OpenCV
        reports such values for some containers), avoiding ZeroDivisionError.
        """
        if not fps > 0:  # also catches NaN, for which every comparison is False
            fps = default_fps
        return 1.0 / (fps * play_speed)

    @staticmethod
    def _progress_fraction(index, total):
        """Return playback progress after showing 0-based frame *index* of *total*, in [0, 1]."""
        if total <= 0:
            return 0.0
        return min(max((index + 1) / total, 0.0), 1.0)

def main():
    """Entry point: construct the dance-partner app and render its Streamlit UI."""
    AIDancePartner().setup_ui()


if __name__ == "__main__":
    main()