File size: 4,884 Bytes
75aee9e ab74dbe 75aee9e 96a0264 bf4a6e3 ab74dbe 75aee9e ab74dbe 75aee9e ab74dbe 75aee9e ab74dbe 75aee9e 96a0264 75aee9e 96a0264 bf4a6e3 ab74dbe 96a0264 ab74dbe 96a0264 bf4a6e3 96a0264 ab74dbe bf4a6e3 ab74dbe bf4a6e3 ab74dbe 75aee9e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 |
import streamlit as st
import cv2
import numpy as np
import tempfile
import os
import time
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer
class AIDancePartner:
    """Streamlit app that plays an uploaded video beside an AI-generated dance partner.

    Every decoded frame is passed to ``PoseDetector.detect_pose``; the result is
    handed to ``DanceGenerator.generate_dance_sequence`` to produce the matching
    AI frame, and the two frames are displayed side by side.
    """

    # Frame rate assumed when OpenCV reports none (``CAP_PROP_FPS`` reads 0 when
    # the backend cannot determine it), so playback never divides by zero.
    _FALLBACK_FPS = 30.0

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the title and sidebar controls; start playback once a video is uploaded."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov']
        )

        # Mode selection
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )

        # Playback controls
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
        # Play/pause button with a unique key
        is_playing = st.sidebar.button("Play/Pause", key="play_pause_button")

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Play ``video_file`` next to its AI partner, honouring the playback controls.

        The current position and play/pause state are kept in
        ``st.session_state`` (``frame_index`` / ``play_state``), so a pause or a
        speed change resumes from the same frame instead of restarting.

        Args:
            video_file: The uploaded file returned by ``st.sidebar.file_uploader``.
            mode: Mode selected in the sidebar; forwarded to the dance generator.
            play_speed: Playback speed multiplier from the sidebar slider.
            is_playing: Value of the Play/Pause button for this script run.
        """
        video_placeholder = st.empty()
        progress_bar = st.progress(0)

        clip = self._get_clip(video_file, mode)
        if clip is None:
            st.error("Could not read any frames from the uploaded video.")
            return
        frames = clip['frames']
        ai_sequences = clip['ai_sequences']
        total_frames = len(frames)

        # Playback controls in sidebar
        col1, col2 = st.sidebar.columns(2)
        with col1:
            stop_button = st.button("Stop", key="stop_button")
        with col2:
            replay_button = st.button("Replay", key="replay_button")

        state = st.session_state
        if 'play_state' not in state:
            state.play_state = False
        if 'frame_index' not in state:
            state.frame_index = 0

        # A button reads True only on the run triggered by its click, so each
        # control is applied exactly once here -- never inside the playback
        # loop, where it would re-fire on every frame.
        if is_playing:
            state.play_state = not state.play_state
        if replay_button:
            state.frame_index = 0
            state.play_state = True
        if stop_button:
            state.frame_index = 0
            state.play_state = False

        delay = self._frame_delay(clip['fps'], play_speed)
        frame_index = min(state.frame_index, total_frames - 1)
        while True:
            self._show_frame(
                video_placeholder,
                progress_bar,
                frames[frame_index],
                ai_sequences[frame_index],
                frame_index,
                total_frames,
            )
            if not state.play_state:
                # Paused/stopped: leave the current frame on screen and return
                # rather than spinning in a busy loop.
                break
            time.sleep(delay)
            frame_index += 1
            if frame_index >= total_frames:
                # Reached the end: rewind so the next Play starts from the top.
                state.frame_index = 0
                state.play_state = False
                break
            state.frame_index = frame_index

    def _get_clip(self, video_file, mode):
        """Return the processed clip for ``video_file``/``mode``, reusing the session cache.

        Processing runs pose detection and generation on every frame, so it is
        done once per (file, mode) instead of on every widget interaction.
        Returns None when no frame could be decoded.
        """
        cache_key = (video_file.name, video_file.size, mode)
        clip = st.session_state.get('processed_clip')
        if clip is None or clip['key'] != cache_key:
            clip = self._load_clip(video_file, mode)
            if clip is None:
                return None
            clip['key'] = cache_key
            st.session_state.processed_clip = clip
            # A different video or mode starts playback from its first frame.
            st.session_state.frame_index = 0
        return clip

    def _load_clip(self, video_file, mode):
        """Decode every frame of ``video_file`` and generate its AI counterpart.

        Returns a dict with ``frames``, ``ai_sequences`` and ``fps``, or None
        when no frame could be decoded.
        """
        # Keep the upload's extension to help OpenCV pick a demuxer. The file is
        # closed (and therefore fully flushed) before OpenCV reopens it by name;
        # delete=False lets it survive that close until we unlink it below.
        suffix = os.path.splitext(video_file.name)[1]
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tfile:
            # getvalue() returns the whole upload regardless of read position.
            tfile.write(video_file.getvalue())
            temp_path = tfile.name

        frames = []
        ai_sequences = []
        cap = None
        try:
            cap = cv2.VideoCapture(temp_path)
            fps = cap.get(cv2.CAP_PROP_FPS)
            with st.spinner('Processing video...'):
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    pose_landmarks = self.pose_detector.detect_pose(frame)
                    ai_frame = self.dance_generator.generate_dance_sequence(
                        [pose_landmarks],
                        mode,
                        1,
                        frame.shape[:2]  # (height, width) of the decoded frame
                    )[0]
                    frames.append(frame)
                    ai_sequences.append(ai_frame)
        finally:
            # Always release the capture and remove the temp file, even on error.
            if cap is not None:
                cap.release()
            os.unlink(temp_path)

        if not frames:
            return None
        return {'frames': frames, 'ai_sequences': ai_sequences, 'fps': fps}

    @classmethod
    def _frame_delay(cls, fps, play_speed):
        """Return the seconds to wait between frames at ``fps`` scaled by ``play_speed``."""
        # `fps > 0` is also False for NaN, so any unusable value falls back.
        effective_fps = fps if fps > 0 else cls._FALLBACK_FPS
        return 1 / (effective_fps * play_speed)

    @staticmethod
    def _show_frame(video_placeholder, progress_bar, frame, ai_frame,
                    frame_index, total_frames):
        """Display ``frame`` beside ``ai_frame`` and update the progress bar."""
        frame_height, frame_width = frame.shape[:2]
        # Resize the AI frame to the decoded frame's own size so np.hstack
        # receives two images of equal height.
        combined_frame = np.hstack([
            frame,
            cv2.resize(ai_frame, (frame_width, frame_height))
        ])
        # Based on the decoded frame count: stays within (0, 1] and reaches 1.0
        # on the last frame.
        progress_bar.progress((frame_index + 1) / total_frames)
        video_placeholder.image(
            combined_frame,
            channels="BGR",
            use_container_width=True
        )
def main():
    """Create the app and render its Streamlit UI."""
    app = AIDancePartner()
    app.setup_ui()


if __name__ == "__main__":
    main()