# aidancer / app.py
# Uploaded by ombhojane (commit bf4a6e3, verified)
import streamlit as st
import cv2
import numpy as np
import tempfile
import os
import time
from pose_detector import PoseDetector
from dance_generator import DanceGenerator
from dance_visualizer import DanceVisualizer
class AIDancePartner:
    """Streamlit app that plays an uploaded dance video side by side with an
    AI-generated dance sequence.

    Modes:
        "Sync Partner"       -- AI follows the detected poses.
        "Generate New Moves" -- AI produces new choreography.
    """

    def __init__(self):
        # Project-local helpers: pose estimation, move generation, rendering.
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the title and sidebar controls; process the upload if any."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        # Upload video section
        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov']
        )
        # Mode selection
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"]
        )
        # Playback controls; slider minimum 0.1 keeps the speed divisor nonzero.
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
        # Unique key avoids Streamlit duplicate-widget errors.
        is_playing = st.sidebar.button("Play/Pause", key="play_pause_button")

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Decode the upload, precompute AI frames, then play both side by side.

        Args:
            video_file: Streamlit UploadedFile containing the dance video.
            mode: "Sync Partner" or "Generate New Moves".
            play_speed: playback-speed multiplier in [0.1, 2.0].
            is_playing: True when the Play/Pause button was pressed this rerun.
        """
        video_placeholder = st.empty()

        # Persist the upload and CLOSE the handle before cv2 opens the path --
        # an open NamedTemporaryFile cannot be reopened on Windows. Keep a
        # video suffix so backends that sniff the extension can decode it.
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
        try:
            tfile.write(video_file.read())
            tfile.close()

            cap = cv2.VideoCapture(tfile.name)
            fps = cap.get(cv2.CAP_PROP_FPS)
            if not fps or fps <= 0:
                # Some containers report 0 fps; avoid ZeroDivisionError below.
                fps = 30.0
            frame_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
            frame_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

            progress_bar = st.progress(0)

            # Toggle the persisted play state exactly once per rerun. (Toggling
            # inside the playback loop would flip it every frame, since
            # is_playing is constant for the whole rerun.) Default to playing
            # so a fresh upload starts immediately.
            if 'play_state' not in st.session_state:
                st.session_state.play_state = True
            elif is_playing:
                st.session_state.play_state = not st.session_state.play_state

            frames = []
            ai_sequences = []

            # Pre-process: run pose detection and AI generation up front so
            # playback is not bottlenecked by per-frame inference.
            with st.spinner('Processing video...'):
                while cap.isOpened():
                    ret, frame = cap.read()
                    if not ret:
                        break
                    pose_landmarks = self.pose_detector.detect_pose(frame)
                    ai_frame = self.dance_generator.generate_dance_sequence(
                        [pose_landmarks],
                        mode,
                        1,
                        (frame_height, frame_width)
                    )[0]
                    frames.append(frame)
                    ai_sequences.append(ai_frame)
            cap.release()

            if not frames:
                st.error("Could not read any frames from the uploaded video.")
                return

            # Playback controls in sidebar.
            col1, col2 = st.sidebar.columns(2)
            with col1:
                stop_button = st.button("Stop", key="stop_button")
            with col2:
                replay_button = st.button("Replay", key="replay_button")

            # Handle Replay once, BEFORE the loop: inside the loop a pressed
            # button stays True for the whole rerun, so frame_count would be
            # reset every iteration and playback would never terminate.
            if replay_button:
                st.session_state.play_state = True

            # Use the decoded frame count for progress; the container's
            # CAP_PROP_FRAME_COUNT metadata can be zero or wrong.
            total = len(frames)
            frame_count = 0
            while not stop_button and frame_count < total:
                if st.session_state.play_state:
                    frame = frames[frame_count]
                    ai_frame = ai_sequences[frame_count]

                    # Original video on the left, AI partner on the right.
                    combined_frame = np.hstack([
                        frame,
                        cv2.resize(ai_frame, (frame_width, frame_height))
                    ])

                    # Clamp so st.progress never receives a value above 1.0.
                    progress_bar.progress(min(1.0, (frame_count + 1) / total))

                    video_placeholder.image(
                        combined_frame,
                        channels="BGR",
                        use_container_width=True
                    )

                    # Control playback speed (fps and play_speed both nonzero).
                    time.sleep(1 / (fps * play_speed))
                    frame_count += 1
                else:
                    # Paused: yield the CPU instead of busy-spinning.
                    time.sleep(0.1)
        finally:
            # Remove the temp file even if decoding or playback raised.
            os.unlink(tfile.name)
def main():
    """Entry point: construct the app and render its UI."""
    AIDancePartner().setup_ui()


if __name__ == "__main__":
    main()