File size: 4,884 Bytes
75aee9e
 
 
 
 
ab74dbe
75aee9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
 
 
bf4a6e3
 
ab74dbe
75aee9e
ab74dbe
75aee9e
ab74dbe
 
 
 
 
75aee9e
 
 
ab74dbe
75aee9e
 
96a0264
 
 
 
 
75aee9e
96a0264
 
 
bf4a6e3
 
 
 
ab74dbe
 
 
96a0264
ab74dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96a0264
bf4a6e3
 
 
 
 
 
 
96a0264
 
ab74dbe
bf4a6e3
 
ab74dbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bf4a6e3
 
 
 
 
 
 
 
 
ab74dbe
75aee9e
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import hashlib
import os
import tempfile
import time

import cv2
import numpy as np
import streamlit as st

from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

class AIDancePartner:
    """Streamlit app that plays an uploaded dance video beside an AI partner.

    Pose detection and move generation are delegated to the project-local
    ``PoseDetector`` and ``DanceGenerator``; ``DanceVisualizer`` is
    constructed but not used by the code in this class.
    """

    # Fallback frame rate used when OpenCV reports no usable FPS (it returns
    # 0 for some containers), so the playback delay never divides by zero.
    DEFAULT_FPS = 30.0

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Draw the title and sidebar controls, then handle an uploaded video."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov']
        )

        # Mode selection
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )

        # Playback speed multiplier; the slider's minimum of 0.1 keeps it > 0.
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)

        # st.button is True only on the rerun triggered by its click, so this
        # is a one-shot "toggle requested" flag, not the current play state.
        is_playing = st.sidebar.button("Play/Pause", key="play_pause_button")

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Show the uploaded video side by side with its generated AI frames.

        The play state and current frame index live in ``st.session_state``
        so Play/Pause, Stop and Replay (each of which triggers a Streamlit
        rerun) can pause and resume playback instead of restarting it.

        Args:
            video_file: Streamlit uploaded file holding the video bytes.
            mode: Mode string forwarded to
                ``DanceGenerator.generate_dance_sequence``.
            play_speed: Playback-speed multiplier (expected > 0).
            is_playing: True on the rerun caused by the Play/Pause button;
                toggles the persisted play state exactly once.
        """
        video_placeholder = st.empty()
        progress_bar = st.progress(0)

        # Playback controls in sidebar
        col1, col2 = st.sidebar.columns(2)
        with col1:
            stop_button = st.button("Stop", key="stop_button")
        with col2:
            replay_button = st.button("Replay", key="replay_button")

        frames, ai_sequences, fps = self._get_processed_video(video_file, mode)
        if not frames:
            st.error("Could not read any frames from the uploaded video.")
            return

        total = len(frames)
        playing, position = self._apply_controls(
            st.session_state.get('play_state', False),
            st.session_state.get('frame_pos', 0),
            toggle=is_playing,
            stop=stop_button,
            replay=replay_button,
        )
        if position >= total:
            position = 0
        st.session_state.play_state = playing
        st.session_state.frame_pos = position

        if not playing:
            # Paused/stopped: draw the current frame once and return instead
            # of busy-waiting; the next button click reruns the script.
            self._show_frame(video_placeholder, progress_bar,
                             frames, ai_sequences, position)
            return

        delay = self._frame_delay(fps, play_speed)
        while position < total:
            self._show_frame(video_placeholder, progress_bar,
                             frames, ai_sequences, position)
            time.sleep(delay)
            position += 1
            # Persist progress every frame: a button click interrupts this
            # loop at the next st call, and the rerun resumes from here.
            st.session_state.frame_pos = position

        # Finished: stop and rewind so the next Play starts from the top.
        st.session_state.play_state = False
        st.session_state.frame_pos = 0

    @staticmethod
    def _apply_controls(playing, position, *, toggle, stop, replay):
        """Return the new ``(playing, position)`` after this rerun's button clicks.

        Stop takes precedence over Replay, which takes precedence over
        Play/Pause.
        """
        if stop:
            return False, 0
        if replay:
            return True, 0
        if toggle:
            return not playing, position
        return playing, position

    @classmethod
    def _frame_delay(cls, fps, play_speed):
        """Return the seconds to sleep between frames at *play_speed*.

        Falls back to ``DEFAULT_FPS`` when *fps* is 0, negative or NaN.
        """
        if not (fps > 0):  # also rejects NaN, for which every comparison is False
            fps = cls.DEFAULT_FPS
        return 1.0 / (fps * play_speed)

    def _get_processed_video(self, video_file, mode):
        """Return ``(frames, ai_sequences, fps)``, reusing results across reruns.

        Only the most recent video/mode pair is cached, keyed by a SHA-256
        of the uploaded bytes, so memory use is bounded to one video.
        """
        data = video_file.getvalue()
        cache_key = (hashlib.sha256(data).hexdigest(), mode)
        cached = st.session_state.get('processed_video')
        if cached is not None and cached[0] == cache_key:
            return cached[1]

        suffix = os.path.splitext(video_file.name)[1]
        result = self._decode_and_generate(data, suffix, mode)
        st.session_state.processed_video = (cache_key, result)
        # A different video or mode starts playback from the beginning.
        st.session_state.frame_pos = 0
        return result

    def _decode_and_generate(self, data, suffix, mode):
        """Decode *data* with OpenCV and build one AI frame per video frame.

        Returns ``(frames, ai_sequences, fps)``; ``frames`` is empty when
        OpenCV cannot open or decode the file.
        """
        # Close (and thereby flush) the temp file before OpenCV opens it by
        # name; an open NamedTemporaryFile cannot be reopened on Windows.
        # The suffix helps OpenCV pick the right demuxer.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
            tfile.write(data)

        frames = []
        ai_sequences = []
        cap = cv2.VideoCapture(tfile.name)
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            with st.spinner('Processing video...'):
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break

                    # Use the decoded frame's real size; container metadata
                    # (CAP_PROP_FRAME_WIDTH/HEIGHT) can be missing.
                    height, width = frame.shape[:2]
                    pose_landmarks = self.pose_detector.detect_pose(frame)
                    ai_frame = self.dance_generator.generate_dance_sequence(
                        [pose_landmarks],
                        mode,
                        1,
                        (height, width)
                    )[0]

                    frames.append(frame)
                    ai_sequences.append(ai_frame)
        finally:
            cap.release()
            os.unlink(tfile.name)
        return frames, ai_sequences, fps

    @staticmethod
    def _show_frame(video_placeholder, progress_bar, frames, ai_sequences, index):
        """Render frame *index* next to its AI frame and update the progress bar."""
        frame = frames[index]
        height, width = frame.shape[:2]
        # NOTE(review): np.hstack needs matching dtype and channel count;
        # assumes DanceGenerator returns BGR uint8 frames like OpenCV — confirm.
        combined_frame = np.hstack([
            frame,
            cv2.resize(ai_sequences[index], (width, height))
        ])
        # Based on decoded frame count, so the value always stays in (0, 1].
        progress_bar.progress((index + 1) / len(frames))
        video_placeholder.image(
            combined_frame,
            channels="BGR",
            use_container_width=True
        )

def main():
    """Script entry point: build the dance-partner app and render its UI."""
    AIDancePartner().setup_ui()


if __name__ == "__main__":
    main()