ombhojane commited on
Commit
ab74dbe
·
verified ·
1 Parent(s): b446ed6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -48
app.py CHANGED
@@ -3,6 +3,7 @@ import cv2
3
  import numpy as np
4
  import tempfile
5
  import os
 
6
  from pose_detector import PoseDetector
7
  from dance_generator import DanceGenerator
8
  from dance_visualizer import DanceVisualizer
@@ -32,13 +33,21 @@ class AIDancePartner:
32
  # Add playback controls
33
  play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
34
 
 
 
 
35
  if video_file:
36
- self.process_video(video_file, mode, play_speed)
37
 
38
- def process_video(self, video_file, mode, play_speed):
 
 
 
 
39
  tfile = tempfile.NamedTemporaryFile(delete=False)
40
  tfile.write(video_file.read())
41
 
 
42
  cap = cv2.VideoCapture(tfile.name)
43
 
44
  # Get video properties
@@ -49,59 +58,72 @@ class AIDancePartner:
49
 
50
  # Initialize progress bar
51
  progress_bar = st.progress(0)
52
- frame_placeholder = st.empty()
53
-
54
- # Pre-process video to extract all poses
55
- all_poses = []
56
- while cap.isOpened():
57
- ret, frame = cap.read()
58
- if not ret:
59
- break
60
- pose_landmarks = self.pose_detector.detect_pose(frame)
61
- all_poses.append(pose_landmarks)
62
 
63
- # Reset video capture
64
- cap.set(cv2.CAP_PROP_POS_FRAMES, 0)
 
65
 
66
- # Generate AI dance sequence
67
- ai_sequence = self.dance_generator.generate_dance_sequence(
68
- all_poses,
69
- mode,
70
- total_frames,
71
- (frame_height, frame_width)
72
- )
 
 
 
 
 
 
 
 
 
 
73
 
74
  # Playback loop
75
  frame_count = 0
76
- while cap.isOpened():
77
- ret, frame = cap.read()
78
- if not ret:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  break
80
 
81
- # Update progress
82
- progress = frame_count / total_frames
83
- progress_bar.progress(progress)
84
-
85
- # Get corresponding AI frame
86
- ai_frame = ai_sequence[frame_count]
87
-
88
- # Combine frames side by side
89
- combined_frame = np.hstack([
90
- frame,
91
- cv2.resize(ai_frame, (frame_width, frame_height))
92
- ])
93
-
94
- # Display combined frame
95
- frame_placeholder.image(
96
- combined_frame,
97
- channels="BGR",
98
- use_column_width=True
99
- )
100
-
101
- # Control playback speed
102
- cv2.waitKey(int(1000 / (fps * play_speed)))
103
- frame_count += 1
104
-
105
  # Cleanup
106
  cap.release()
107
  os.unlink(tfile.name)
 
3
  import numpy as np
4
  import tempfile
5
  import os
6
+ import time
7
  from pose_detector import PoseDetector
8
  from dance_generator import DanceGenerator
9
  from dance_visualizer import DanceVisualizer
 
33
  # Add playback controls
34
  play_speed = st.sidebar.slider("Playback Speed", 0.1, 2.0, 1.0, 0.1)
35
 
36
+ # Add play/pause button
37
+ is_playing = st.sidebar.button("Play/Pause")
38
+
39
  if video_file:
40
+ self.process_video(video_file, mode, play_speed, is_playing)
41
 
42
+ def process_video(self, video_file, mode, play_speed, is_playing):
43
+ # Create a placeholder for the video
44
+ video_placeholder = st.empty()
45
+
46
+ # Create temporary file
47
  tfile = tempfile.NamedTemporaryFile(delete=False)
48
  tfile.write(video_file.read())
49
 
50
+ # Open video file
51
  cap = cv2.VideoCapture(tfile.name)
52
 
53
  # Get video properties
 
58
 
59
  # Initialize progress bar
60
  progress_bar = st.progress(0)
 
 
 
 
 
 
 
 
 
 
61
 
62
+ # Store video frames and AI sequences
63
+ frames = []
64
+ ai_sequences = []
65
 
66
+ # Pre-process video
67
+ with st.spinner('Processing video...'):
68
+ while cap.isOpened():
69
+ ret, frame = cap.read()
70
+ if not ret:
71
+ break
72
+
73
+ pose_landmarks = self.pose_detector.detect_pose(frame)
74
+ ai_frame = self.dance_generator.generate_dance_sequence(
75
+ [pose_landmarks],
76
+ mode,
77
+ 1,
78
+ (frame_height, frame_width)
79
+ )[0]
80
+
81
+ frames.append(frame)
82
+ ai_sequences.append(ai_frame)
83
 
84
  # Playback loop
85
  frame_count = 0
86
+ play = True if is_playing else False
87
+
88
+ while True:
89
+ if play:
90
+ if frame_count >= len(frames):
91
+ frame_count = 0
92
+
93
+ # Get current frames
94
+ frame = frames[frame_count]
95
+ ai_frame = ai_sequences[frame_count]
96
+
97
+ # Combine frames side by side
98
+ combined_frame = np.hstack([
99
+ frame,
100
+ cv2.resize(ai_frame, (frame_width, frame_height))
101
+ ])
102
+
103
+ # Update progress
104
+ progress = frame_count / total_frames
105
+ progress_bar.progress(progress)
106
+
107
+ # Display frame
108
+ video_placeholder.image(
109
+ combined_frame,
110
+ channels="BGR",
111
+ use_container_width=True
112
+ )
113
+
114
+ # Control playback speed
115
+ time.sleep(1 / (fps * play_speed))
116
+ frame_count += 1
117
+
118
+ # Check for play/pause button
119
+ if st.sidebar.button("Stop"):
120
  break
121
 
122
+ # Add replay button
123
+ if frame_count >= len(frames):
124
+ if st.sidebar.button("Replay"):
125
+ frame_count = 0
126
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  # Cleanup
128
  cap.release()
129
  os.unlink(tfile.name)