import os
import tempfile
import time

import cv2
import numpy as np
import streamlit as st

from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer

# Fallback frame rate used when OpenCV cannot report one (CAP_PROP_FPS gives 0).
DEFAULT_FPS = 30.0


class AIDancePartner:
    """Streamlit app that plays an uploaded dance video beside an AI-generated partner."""

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        # NOTE(review): not referenced anywhere in this file.
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the title and sidebar controls; process the video once one is uploaded."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov'],
        )

        # Mode selection
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"],
        )

        # Playback speed multiplier (slider guarantees >= 0.1).
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)

        # st.button is True only on the rerun triggered by the click, so this is
        # a one-shot "toggle requested" signal, not a persistent playing flag.
        is_playing = st.sidebar.button("Play/Pause", key="play_pause_button")

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Pre-process *video_file* and play it side by side with the AI partner.

        Args:
            video_file: Uploaded file object (as returned by ``st.file_uploader``).
            mode: Generation mode forwarded to ``DanceGenerator``.
            play_speed: Playback-speed multiplier; frames are shown every
                ``1 / (fps * play_speed)`` seconds.
            is_playing: True on the rerun in which "Play/Pause" was clicked;
                it toggles the play state persisted in ``st.session_state``.
        """
        # Create a placeholder for the video
        video_placeholder = st.empty()

        # Play state must survive reruns; toggle it exactly once per click
        # (toggling inside the playback loop made it flip on every frame).
        if 'play_state' not in st.session_state:
            st.session_state.play_state = False
        if is_playing:
            st.session_state.play_state = not st.session_state.play_state

        tmp_path = self._write_temp_video(video_file)
        try:
            cap = cv2.VideoCapture(tmp_path)
            try:
                if not cap.isOpened():
                    st.error("Could not open the uploaded video.")
                    return
                fps = cap.get(cv2.CAP_PROP_FPS)
                progress_bar = st.progress(0)
                # NOTE(review): this re-runs on every Streamlit rerun (every
                # button click); consider caching the result in session_state.
                frames, ai_sequences = self._preprocess(cap, mode)
            finally:
                # Release before unlinking: some platforms refuse to delete open files.
                cap.release()
        finally:
            os.unlink(tmp_path)

        if not frames:
            st.warning("No frames could be read from the uploaded video.")
            return

        # `not fps > 0` also rejects NaN; avoids ZeroDivisionError in the delay.
        if not fps > 0:
            fps = DEFAULT_FPS

        # Playback controls in sidebar
        col1, col2 = st.sidebar.columns(2)
        with col1:
            stop_button = st.button("Stop", key="stop_button")
        with col2:
            replay_button = st.button("Replay", key="replay_button")

        self._playback(
            frames, ai_sequences, fps, play_speed,
            stop_button, replay_button, video_placeholder, progress_bar,
        )

    @staticmethod
    def _write_temp_video(video_file):
        """Copy the upload to a closed named temp file OpenCV can open; return its path."""
        # Keep the original extension: some OpenCV backends pick a demuxer from it.
        suffix = os.path.splitext(getattr(video_file, "name", "") or "")[1]
        # getvalue() returns the full content regardless of the stream position.
        if hasattr(video_file, "getvalue"):
            data = video_file.getvalue()
        else:
            data = video_file.read()
        # Closing the file (end of `with`) flushes it before OpenCV reads it by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tfile:
            tfile.write(data)
        return tfile.name

    def _preprocess(self, cap, mode):
        """Read every frame from *cap* and generate one AI frame per video frame."""
        frames = []
        ai_sequences = []
        with st.spinner('Processing video...'):
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                pose_landmarks = self.pose_detector.detect_pose(frame)
                # Pass the actual (height, width) of the decoded frame.
                ai_frame = self.dance_generator.generate_dance_sequence(
                    [pose_landmarks], mode, 1, frame.shape[:2]
                )[0]
                frames.append(frame)
                ai_sequences.append(ai_frame)
        return frames, ai_sequences

    @staticmethod
    def _show_frame(frame, ai_frame, index, total, placeholder, progress_bar):
        """Display frame *index* next to its AI frame and update the progress bar."""
        height, width = frame.shape[:2]
        # Resize to the decoded frame's size so np.hstack sees equal heights.
        # NOTE(review): assumes ai_frame has the same channel count as frame (BGR).
        combined_frame = np.hstack([frame, cv2.resize(ai_frame, (width, height))])
        # (index + 1) / total stays within [0, 1] and reaches 1.0 on the last frame.
        progress_bar.progress((index + 1) / total)
        placeholder.image(combined_frame, channels="BGR", use_container_width=True)

    def _playback(self, frames, ai_sequences, fps, play_speed,
                  stop_button, replay_button, placeholder, progress_bar):
        """Apply Stop/Replay for this rerun, then either show one frame or play on."""
        state = st.session_state
        total = len(frames)
        # Resume where the previous rerun stopped; clamp in case the video changed.
        index = min(state.get('frame_index', 0), total - 1)

        # Buttons are one-shot per rerun, so handle them once, before the loop.
        if stop_button:
            state.play_state = False
            index = 0
        elif replay_button:
            state.play_state = True
            index = 0

        if not state.play_state:
            # Paused/stopped: show the current frame and return instead of spinning.
            state.frame_index = index
            self._show_frame(frames[index], ai_sequences[index], index, total,
                             placeholder, progress_bar)
            return

        delay = 1.0 / (fps * max(play_speed, 1e-3))
        while index < total:
            self._show_frame(frames[index], ai_sequences[index], index, total,
                             placeholder, progress_bar)
            index += 1
            # Persist position so a Pause click (which reruns the script) resumes here.
            state.frame_index = index
            time.sleep(delay)

        # Finished: rewind and stop so unrelated reruns don't restart playback.
        state.frame_index = 0
        state.play_state = False


def main():
    app = AIDancePartner()
    app.setup_ui()


if __name__ == "__main__":
    main()