|
import streamlit as st
|
|
import cv2
|
|
import numpy as np
|
|
import tempfile
|
|
import os
|
|
import time
|
|
from pose_detector import PoseDetector
|
|
from dance_generator import DanceGenerator
|
|
from dance_visualizer import DanceVisualizer
|
|
|
|
class AIDancePartner:
    """Streamlit app that pairs an uploaded dance video with AI-generated partner frames.

    Each decoded video frame is passed to ``PoseDetector.detect_pose`` and the
    result to ``DanceGenerator.generate_dance_sequence``; the original and the
    generated frames are then shown side by side with Play/Pause, Stop and
    Replay controls.
    """

    def __init__(self):
        self.pose_detector = PoseDetector()
        self.dance_generator = DanceGenerator()
        self.visualizer = DanceVisualizer()

    def setup_ui(self):
        """Render the sidebar controls and, once a video is uploaded, the player."""
        st.title("AI Dance Partner 💃🤖")
        st.sidebar.header("Controls")

        video_file = st.sidebar.file_uploader(
            "Upload your dance video",
            type=['mp4', 'avi', 'mov'],
        )
        mode = st.sidebar.selectbox(
            "Select Mode",
            ["Sync Partner", "Generate New Moves"],
        )
        play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
        # st.button is only True on the rerun triggered by the click, so this is
        # a one-shot "toggle" signal, not a persistent "is playing" flag.
        is_playing = st.sidebar.button("Play/Pause", key="play_pause_button")

        if video_file:
            self.process_video(video_file, mode, play_speed, is_playing)

    def process_video(self, video_file, mode, play_speed, is_playing):
        """Analyse the uploaded video and play it next to the generated partner.

        Args:
            video_file: The ``UploadedFile`` returned by ``st.file_uploader``.
            mode: Generation mode forwarded to ``generate_dance_sequence``.
            play_speed: Playback-speed multiplier applied to the video FPS.
            is_playing: True on the rerun caused by a Play/Pause click; toggles
                playback rather than meaning "currently playing".
        """
        video_placeholder = st.empty()
        progress_bar = st.progress(0)

        loaded = self._get_sequences(video_file, mode)
        if loaded is None:
            st.error("Could not open the uploaded video.")
            return
        frames, ai_frames, fps = loaded

        col1, col2 = st.sidebar.columns(2)
        with col1:
            stop_button = st.button("Stop", key="stop_button")
        with col2:
            replay_button = st.button("Replay", key="replay_button")

        if not frames:
            st.warning("No frames could be decoded from the uploaded video.")
            return

        state = st.session_state
        # Resolve this run's clicks BEFORE playing; the original applied the
        # toggle after the loop and re-applied Replay on every iteration.
        playing, start = self._update_play_state(
            state,
            toggle=is_playing,
            replay=replay_button,
            stop=stop_button,
            n_frames=len(frames),
        )

        if not playing:
            # Paused/stopped: show the frame playback would resume from instead
            # of spinning in a loop that never advances.
            self._render_frame(video_placeholder, progress_bar, frames, ai_frames, start)
            return

        delay = self._frame_delay(fps, play_speed)
        for idx in range(start, len(frames)):
            # Persist the position before rendering: a widget click starts a
            # rerun that interrupts this loop, and the next run resumes here.
            state["frame_pos"] = idx
            self._render_frame(video_placeholder, progress_bar, frames, ai_frames, idx)
            time.sleep(delay)

        # Reached the end: stop and rewind so the next Play starts from frame 0.
        state["play_state"] = False
        state["frame_pos"] = 0

    def _get_sequences(self, video_file, mode):
        """Return ``(frames, ai_frames, fps)`` for this upload and mode, cached in session state.

        Streamlit reruns the script on every widget interaction; without the
        cache each Play/Pause/Stop click would re-run pose detection on every
        frame. Returns None when the video cannot be opened (not cached).
        """
        key = (getattr(video_file, "file_id", None), video_file.name, video_file.size, mode)
        cached = st.session_state.get("dance_cache")
        if cached is not None and cached[0] == key:
            return cached[1]

        loaded = self._load_sequences(video_file, mode)
        if loaded is not None:
            st.session_state["dance_cache"] = (key, loaded)
            # New video or mode: playback must start from the first frame.
            st.session_state["frame_pos"] = 0
        return loaded

    def _load_sequences(self, video_file, mode):
        """Decode every frame and generate its AI partner frame.

        Returns:
            ``(frames, ai_frames, fps)``, or None if OpenCV cannot open the file.
        """
        # Keep the original extension; some OpenCV backends use it to pick a demuxer.
        suffix = os.path.splitext(video_file.name or "")[1]
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=suffix)
        try:
            # Close (and thus flush) before OpenCV reopens the file by path:
            # unflushed bytes would be missing, and Windows forbids the reopen.
            with tfile:
                # getvalue() is independent of the stream position, unlike read().
                tfile.write(video_file.getvalue())

            cap = cv2.VideoCapture(tfile.name)
            try:
                if not cap.isOpened():
                    return None
                fps = cap.get(cv2.CAP_PROP_FPS)
                frames, ai_frames = [], []
                with st.spinner('Processing video...'):
                    while True:
                        ret, frame = cap.read()
                        if not ret:
                            break
                        # Use the decoded frame's real size; container metadata can be 0/wrong.
                        height, width = frame.shape[:2]
                        pose_landmarks = self.pose_detector.detect_pose(frame)
                        ai_frame = self.dance_generator.generate_dance_sequence(
                            [pose_landmarks],
                            mode,
                            1,
                            (height, width),
                        )[0]
                        frames.append(frame)
                        ai_frames.append(ai_frame)
                return frames, ai_frames, fps
            finally:
                cap.release()
        finally:
            os.unlink(tfile.name)

    @staticmethod
    def _update_play_state(state, *, toggle, replay, stop, n_frames):
        """Apply this run's button clicks to the persisted playback state.

        Args:
            state: Mapping that survives reruns (``st.session_state`` in the app).
            toggle: Play/Pause was clicked on this run.
            replay: Replay was clicked: rewind and play.
            stop: Stop was clicked: rewind and pause (applied last, so it wins).
            n_frames: Number of available frames; out-of-range positions rewind to 0.

        Returns:
            ``(playing, start_index)`` for this run.
        """
        if "play_state" not in state:
            state["play_state"] = False
        position = state.get("frame_pos", 0)
        if not 0 <= position < n_frames:
            position = 0
        if toggle:
            state["play_state"] = not state["play_state"]
        if replay:
            state["play_state"] = True
            position = 0
        if stop:
            state["play_state"] = False
            position = 0
        state["frame_pos"] = position
        return state["play_state"], position

    @staticmethod
    def _frame_delay(fps, play_speed, fallback_fps=30.0):
        """Seconds to sleep between frames; uses ``fallback_fps`` when FPS is unusable."""
        # OpenCV reports 0 when it cannot determine the FPS; NaN also fails this test.
        if not (fps > 0):
            fps = fallback_fps
        return 1.0 / (fps * play_speed)

    @staticmethod
    def _render_frame(placeholder, progress_bar, frames, ai_frames, idx):
        """Show frame ``idx`` beside its AI counterpart and update the progress bar."""
        frame = frames[idx]
        height, width = frame.shape[:2]
        # The generated frame's size may differ from the video; match it for hstack.
        partner = cv2.resize(ai_frames[idx], (width, height))
        # Progress is based on decoded frames, so it is always within [0, 1].
        progress_bar.progress((idx + 1) / len(frames))
        placeholder.image(
            np.hstack([frame, partner]),
            channels="BGR",
            use_container_width=True,
        )
|
|
|
|
def main():
    """Create the AI Dance Partner app and draw its Streamlit interface."""
    AIDancePartner().setup_ui()


# Run the app when this file is executed directly as a script.
if __name__ == "__main__":
    main()