Upload app.py
app.py CHANGED
@@ -3,6 +3,7 @@ import cv2
 import numpy as np
 import tempfile
 import os
+import time
 from pose_detector import PoseDetector
 from dance_generator import DanceGenerator
 from dance_visualizer import DanceVisualizer
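The only change in the import block is time, which the reworked playback loop below uses to pace frames with time.sleep instead of the old cv2.waitKey call. A minimal sketch of that pacing, with illustrative values standing in for the video's real properties (the app reads them from the opened capture, typically via cv2.CAP_PROP_FPS):

    import time

    fps = 30.0         # illustrative value; app.py takes this from the opened video
    play_speed = 1.0   # value of the "Playback Speed" slider
    time.sleep(1 / (fps * play_speed))   # roughly 33 ms per frame at 1.0x speed

Unlike cv2.waitKey, time.sleep needs no OpenCV window, which suits a Streamlit app that renders frames through st.image.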
@@ -32,13 +33,21 @@ class AIDancePartner:
         # Add playback controls
         play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)

+        # Add play/pause button
+        is_playing = st.sidebar.button("Play/Pause")
+
         if video_file:
-            self.process_video(video_file, mode, play_speed)
+            self.process_video(video_file, mode, play_speed, is_playing)

-    def process_video(self, video_file, mode, play_speed):
+    def process_video(self, video_file, mode, play_speed, is_playing):
+        # Create a placeholder for the video
+        video_placeholder = st.empty()
+
+        # Create temporary file
         tfile = tempfile.NamedTemporaryFile(delete=False)
         tfile.write(video_file.read())

+        # Open video file
         cap = cv2.VideoCapture(tfile.name)

         # Get video properties
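Streamlit's uploader hands process_video an in-memory buffer, so the method writes it to a NamedTemporaryFile before passing a real path to cv2.VideoCapture. A self-contained sketch of that step, assuming an illustrative helper name and .mp4 suffix that are not part of the commit:

    import tempfile

    import cv2

    def open_uploaded_video(video_file):
        # Persist the uploaded buffer so OpenCV can open it by filesystem path.
        tfile = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
        tfile.write(video_file.read())
        tfile.close()  # flush to disk before OpenCV touches the file
        cap = cv2.VideoCapture(tfile.name)
        return cap, tfile.name  # caller should cap.release() and os.unlink(tfile.name)

app.py keeps the temporary file around while reading and removes it with os.unlink during cleanup, which amounts to the same lifecycle.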
@@ -49,59 +58,72 @@ class AIDancePartner:

         # Initialize progress bar
         progress_bar = st.progress(0)
-        frame_placeholder = st.empty()
-
-        # Pre-process video to extract all poses
-        all_poses = []
-        while cap.isOpened():
-            ret, frame = cap.read()
-            if not ret:
-                break
-            pose_landmarks = self.pose_detector.detect_pose(frame)
-            all_poses.append(pose_landmarks)

-        # ...
+        # Store video frames and AI sequences
+        frames = []
+        ai_sequences = []

-        # ...
+        # Pre-process video
+        with st.spinner('Processing video...'):
+            while cap.isOpened():
+                ret, frame = cap.read()
+                if not ret:
+                    break
+
+                pose_landmarks = self.pose_detector.detect_pose(frame)
+                ai_frame = self.dance_generator.generate_dance_sequence(
+                    [pose_landmarks],
+                    mode,
+                    1,
+                    (frame_height, frame_width)
+                )[0]
+
+                frames.append(frame)
+                ai_sequences.append(ai_frame)

         # Playback loop
         frame_count = 0
-        ...
+        play = True if is_playing else False
+
+        while True:
+            if play:
+                if frame_count >= len(frames):
+                    frame_count = 0
+
+                # Get current frames
+                frame = frames[frame_count]
+                ai_frame = ai_sequences[frame_count]
+
+                # Combine frames side by side
+                combined_frame = np.hstack([
+                    frame,
+                    cv2.resize(ai_frame, (frame_width, frame_height))
+                ])
+
+                # Update progress
+                progress = frame_count / total_frames
+                progress_bar.progress(progress)
+
+                # Display frame
+                video_placeholder.image(
+                    combined_frame,
+                    channels="BGR",
+                    use_container_width=True
+                )
+
+                # Control playback speed
+                time.sleep(1 / (fps * play_speed))
+                frame_count += 1
+
+            # Check for play/pause button
+            if st.sidebar.button("Stop"):
                 break

-            # ...
-            ai_frame = ai_sequence[frame_count]
-
-            # Combine frames side by side
-            combined_frame = np.hstack([
-                frame,
-                cv2.resize(ai_frame, (frame_width, frame_height))
-            ])
-
-            # Display combined frame
-            frame_placeholder.image(
-                combined_frame,
-                channels="BGR",
-                use_column_width=True
-            )
-
-            # Control playback speed
-            cv2.waitKey(int(1000 / (fps * play_speed)))
-            frame_count += 1
-
+        # Add replay button
+        if frame_count >= len(frames):
+            if st.sidebar.button("Replay"):
+                frame_count = 0
+
         # Cleanup
         cap.release()
         os.unlink(tfile.name)
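Both the removed and the added playback paths build the split-screen view the same way: the generated frame is resized to the source frame's dimensions and the two are stacked horizontally. A small sketch of that operation on dummy arrays (the shapes are illustrative only):

    import cv2
    import numpy as np

    frame = np.zeros((480, 640, 3), dtype=np.uint8)      # stands in for a decoded video frame
    ai_frame = np.zeros((256, 256, 3), dtype=np.uint8)   # stands in for a generated dance frame

    h, w = frame.shape[:2]
    ai_resized = cv2.resize(ai_frame, (w, h))    # cv2.resize expects (width, height)
    combined = np.hstack([frame, ai_resized])    # heights, dtypes and channel counts must match
    assert combined.shape == (480, 1280, 3)

Because np.hstack concatenates along the width axis, resizing the generated frame to the source frame's size is what keeps the two halves aligned.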
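One Streamlit behavior worth keeping in mind when reading the new controls: st.sidebar.button returns True only during the single script rerun triggered by its click, so is_playing (and the Stop check inside the loop) acts as a one-shot flag rather than a persistent toggle. A common pattern for a sticky play/pause state, sketched here purely as an illustration and not as part of this commit, keeps the flag in st.session_state:

    import streamlit as st

    # Hypothetical sketch, not from app.py: persist the play state across reruns.
    if "playing" not in st.session_state:
        st.session_state.playing = False

    if st.sidebar.button("Play/Pause"):
        st.session_state.playing = not st.session_state.playing

    st.sidebar.caption("Playing" if st.session_state.playing else "Paused")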