sandesh-bharadwaj committed on
Commit
0f442d9
·
1 Parent(s): d715645

Updated README for Gradio and removed streamlit main

Browse files
Files changed (2) hide show
  1. README.md +16 -14
  2. main.py +0 -343
README.md CHANGED
@@ -65,6 +65,7 @@ short_description: Generate tailored soundtracks for your videos.
65
  </p>
66
  </div>
67
 
 
68
 
69
 
70
  <!-- TABLE OF CONTENTS -->
@@ -105,6 +106,7 @@ short_description: Generate tailored soundtracks for your videos.
105
  1. [**Google Gemini**](https://ai.google.dev/gemini-api) - Google's largest and most capable multimodal AI model.
106
  2. [**MusicGen**](https://huggingface.co/facebook/musicgen-large) - Meta's text-to-music model, capable of generating high-quality music conditioned on text or audio prompts.
107
 
 
108
  <p align="right">(<a href="#readme-top">back to top</a>)</p>
109
 
110
 
@@ -145,34 +147,34 @@ While VidTune is supported on CPU-only machines, we recommend using a GPU with m
145
  [![Watch the video](https://img.youtube.com/vi/knbQjWZtL3Y/maxresdefault.jpg)](https://youtu.be/knbQjWZtL3Y)
146
 
147
  ## Running VidTune
148
- First, clone the repository:
149
  ```sh
150
  git clone https://github.com/sandesh-bharadwaj/VidTune.git
151
  cd VidTune
 
152
  ```
153
  ### Using conda
154
  If you're using conda as your virtual environment manager, do the following:
155
- ```
156
  conda env create -f environment.yml
157
  conda activate vidtune
158
 
159
- streamlit run main.py
 
 
 
 
160
  ```
161
 
162
  ### Using python / pip
163
- ```
164
  pip install -r requirements.txt
165
- streamlit run main.py
166
- ```
167
 
168
- ### Using Docker
169
- - [Docker](https://docs.docker.com/engine/install/)
170
- - [Nvidia Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#installing-with-apt)
171
 
172
- Docker Hub Image: https://hub.docker.com/r/animikhaich/vidtune
173
-
174
- ```
175
- docker run --rm -it --gpus all -p 8003:8003 animikhaich/vidtune
176
  ```
177
 
178
 
@@ -183,7 +185,7 @@ docker run --rm -it --gpus all -p 8003:8003 animikhaich/vidtune
183
  - [x] Flutter version of app for proof-of-concept
184
  - [x] MusicGen integration
185
  - [x] Audio Mixing
186
- - [x] Streamlit app
187
  - [x] Docker image
188
  - [ ] OpenVINO-optimized versions of MusicGen for CPU-Only use.
189
  - [ ] Support for music generation duration > 30 seconds.
 
65
  </p>
66
  </div>
67
 
68
+ # Gradio implementation of VidTune for Hugging Face ZeroGPU Spaces.
69
 
70
 
71
  <!-- TABLE OF CONTENTS -->
 
106
  1. [**Google Gemini**](https://ai.google.dev/gemini-api) - Google's largest and most capable multimodal AI model.
107
  2. [**MusicGen**](https://huggingface.co/facebook/musicgen-large) - Meta's text-to-music model, capable of generating high-quality music conditioned on text or audio prompts.
108
 
109
+
110
  <p align="right">(<a href="#readme-top">back to top</a>)</p>
111
 
112
 
 
147
  [![Watch the video](https://img.youtube.com/vi/knbQjWZtL3Y/maxresdefault.jpg)](https://youtu.be/knbQjWZtL3Y)
148
 
149
  ## Running VidTune
150
+ First, clone the repository and switch to the `gradio-dev` branch:
151
  ```sh
152
  git clone https://github.com/sandesh-bharadwaj/VidTune.git
153
  cd VidTune
154
+ git switch gradio-dev
155
  ```
156
  ### Using conda
157
  If you're using conda as your virtual environment manager, do the following:
158
+ ```sh
159
  conda env create -f environment.yml
160
  conda activate vidtune
161
 
162
+ # Hot Reload enabled
163
+ gradio app.py
164
+
165
+ # w/o Hot Reload
166
+ python app.py
167
  ```
168
 
169
  ### Using python / pip
170
+ ```sh
171
  pip install -r requirements.txt
 
 
172
 
173
+ # Hot Reload enabled
174
+ gradio app.py
 
175
 
176
+ # w/o Hot Reload
177
+ python app.py
 
 
178
  ```
179
 
180
 
 
185
  - [x] Flutter version of app for proof-of-concept
186
  - [x] MusicGen integration
187
  - [x] Audio Mixing
188
+ - [x] Gradio app
189
  - [x] Docker image
190
  - [ ] OpenVINO-optimized versions of MusicGen for CPU-Only use.
191
  - [ ] Support for music generation duration > 30 seconds.
main.py DELETED
@@ -1,343 +0,0 @@
1
- import streamlit as st
2
- from engine import DescribeVideo, GenerateAudio
3
- import os
4
- from moviepy.editor import VideoFileClip, AudioFileClip, CompositeAudioClip
5
- from moviepy.audio.fx.volumex import volumex
6
- from streamlit.runtime.scriptrunner import get_script_run_ctx
7
-
8
-
9
- def get_session_id():
10
- session_id = get_script_run_ctx().session_id
11
- session_id = session_id.replace("-", "_")
12
- session_id = "_id_" + session_id
13
- return session_id
14
-
15
-
16
- user_session_id = get_session_id()
17
- os.makedirs(user_session_id, exist_ok=True)
18
- # Define model maps
19
- video_model_map = {
20
- "Fast": "flash",
21
- "Quality": "pro",
22
- }
23
-
24
- music_model_map = {
25
- "Fast": "musicgen-stereo-small",
26
- "Balanced": "musicgen-stereo-medium",
27
- "Quality": "musicgen-stereo-large",
28
- }
29
-
30
- # music_model_map = {
31
- # "Fast": "facebook/musicgen-melody",
32
- # "Quality": "facebook/musicgen-melody-large",
33
- # }
34
-
35
- genre_map = {
36
- "None": None,
37
- "Pop": "Pop",
38
- "Rock": "Rock",
39
- "Hip Hop": "Hip-Hop/Rap",
40
- "Jazz": "Jazz",
41
- "Classical": "Classical",
42
- "Blues": "Blues",
43
- "Country": "Country",
44
- "EDM": "Electronic/Dance",
45
- "Metal": "Metal",
46
- "Disco": "Disco",
47
- "Lo-Fi": "Lo-Fi",
48
- }
49
-
50
- # Streamlit page configuration
51
- st.set_page_config(
52
- page_title="VidTune: Where Videos Find Their Melody",
53
- layout="centered",
54
- page_icon="assets/favicon.png",
55
- )
56
-
57
- left_co, cent_co, last_co = st.columns(3)
58
- with cent_co:
59
- st.image("assets/VidTune-Logo-Without-BG.png", use_column_width=False, width=200)
60
-
61
- # Title and Description
62
- st.markdown(
63
- """
64
- <style>
65
- h2, p, div, img {
66
- text-align: center;
67
- }
68
- </style>
69
- <div style="font-size: 35px; font-weight: bold;">VidTune: Where Videos Find Their Melody</div>
70
- <p>VidTune is a web application to effortlessly tailor perfect soundtracks for your videos with AI.</p>
71
- """,
72
- unsafe_allow_html=True,
73
- )
74
-
75
- # Initialize session state for advanced settings and other inputs
76
- if "show_advanced" not in st.session_state:
77
- st.session_state.show_advanced = False
78
- if "video_model" not in st.session_state:
79
- st.session_state.video_model = "Fast"
80
- if "music_model" not in st.session_state:
81
- st.session_state.music_model = "Fast"
82
- if "num_samples" not in st.session_state:
83
- st.session_state.num_samples = 3
84
- if "music_genre" not in st.session_state:
85
- st.session_state.music_genre = None
86
- if "music_bpm" not in st.session_state:
87
- st.session_state.music_bpm = 100
88
- if "user_keywords" not in st.session_state:
89
- st.session_state.user_keywords = None
90
- if "selected_audio" not in st.session_state:
91
- st.session_state.selected_audio = "None"
92
- if "audio_paths" not in st.session_state:
93
- st.session_state.audio_paths = []
94
- if "selected_audio_path" not in st.session_state:
95
- st.session_state.selected_audio_path = None
96
- if "orig_audio_vol" not in st.session_state:
97
- st.session_state.orig_audio_vol = 100
98
- if "generated_audio_vol" not in st.session_state:
99
- st.session_state.generated_audio_vol = 100
100
- if "generate_button_flag" not in st.session_state:
101
- st.session_state.generate_button_flag = False
102
- if "video_description_content" not in st.session_state:
103
- st.session_state.video_description_content = ""
104
- if "music_prompt" not in st.session_state:
105
- st.session_state.music_prompt = ""
106
- if "audio_mix_flag" not in st.session_state:
107
- st.session_state.audio_mix_flag = False
108
- if "google_api_key" not in st.session_state:
109
- st.session_state.google_api_key = ""
110
-
111
- # Sidebar
112
- st.sidebar.title("Configuration")
113
-
114
- # Google API Key
115
- st.session_state.google_api_key = st.sidebar.text_input(
116
- "Enter your [Google API Key](https://ai.google.dev/gemini-api/docs/api-key) to get started :",
117
- st.session_state.google_api_key,
118
- type="password",
119
- )
120
-
121
- if not st.session_state.google_api_key:
122
- st.warning("Please enter your Google API Key to proceed.")
123
- st.stop()
124
-
125
- # Basic Settings
126
- st.session_state.video_model = st.sidebar.selectbox(
127
- "Select Video Descriptor",
128
- ["Fast", "Quality"],
129
- index=["Fast", "Quality"].index(st.session_state.video_model),
130
- )
131
- st.session_state.music_model = st.sidebar.selectbox(
132
- "Select Music Generator",
133
- ["Fast", "Balanced", "Quality"],
134
- index=["Fast", "Balanced", "Quality"].index(st.session_state.music_model),
135
- )
136
- st.session_state.num_samples = st.sidebar.slider(
137
- "Number of samples", 1, 5, st.session_state.num_samples
138
- )
139
-
140
- # Sidebar for advanced settings
141
- with st.sidebar:
142
- # Create a placeholder for the advanced settings button
143
- placeholder = st.empty()
144
-
145
- # Button to toggle advanced settings
146
- if placeholder.button("Advanced"):
147
- st.session_state.show_advanced = not st.session_state.show_advanced
148
- st.rerun() # Refresh the layout after button click
149
-
150
- # Display advanced settings if enabled
151
- if st.session_state.show_advanced:
152
- # Advanced settings
153
- st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)
154
- st.session_state.music_genre = st.sidebar.selectbox(
155
- "Select Music Genre",
156
- list(genre_map.keys()),
157
- index=(
158
- list(genre_map.keys()).index(st.session_state.music_genre)
159
- if st.session_state.music_genre in genre_map.keys()
160
- else 0
161
- ),
162
- )
163
- st.session_state.user_keywords = st.sidebar.text_input(
164
- "User Keywords",
165
- value=st.session_state.user_keywords,
166
- help="Enter keywords separated by commas.",
167
- )
168
- else:
169
- st.session_state.music_genre = None
170
- st.session_state.music_bpm = None
171
- st.session_state.user_keywords = None
172
-
173
- # Generate Button
174
- generate_button = st.sidebar.button("Generate Music")
175
-
176
-
177
- # Cache the model loading
178
- @st.cache_resource
179
- def load_models(video_model_key, music_model_key, google_api_key):
180
- video_descriptor = DescribeVideo(
181
- model=video_model_map[video_model_key], google_api_key=google_api_key
182
- )
183
- audio_generator = GenerateAudio(model=music_model_map[music_model_key])
184
- if audio_generator.device == "cpu":
185
- st.warning(
186
- "The music generator model is running on CPU. For faster results, consider using a GPU."
187
- )
188
- return video_descriptor, audio_generator
189
-
190
-
191
- # Load models
192
- video_descriptor, audio_generator = load_models(
193
- st.session_state.video_model,
194
- st.session_state.music_model,
195
- st.session_state.google_api_key,
196
- )
197
-
198
- # Video Uploader
199
- uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
200
- if uploaded_video is not None:
201
- st.session_state.uploaded_video = uploaded_video
202
- with open(f"{user_session_id}/temp.mp4", mode="wb") as w:
203
- w.write(uploaded_video.getvalue())
204
-
205
- # Video Player
206
- if os.path.exists(f"{user_session_id}/temp.mp4") and uploaded_video is not None:
207
- st.video(uploaded_video)
208
-
209
- # Submit button if video is not uploaded
210
- if generate_button:
211
- if uploaded_video is None:
212
- st.error("Please upload a video before generating music.")
213
- st.stop()
214
-
215
- with st.spinner("Analyzing video..."):
216
- video_description = video_descriptor.describe_video(
217
- f"{user_session_id}/temp.mp4",
218
- genre=st.session_state.music_genre,
219
- bpm=st.session_state.music_bpm,
220
- user_keywords=st.session_state.user_keywords,
221
- )
222
- video_duration = VideoFileClip(f"{user_session_id}/temp.mp4").duration
223
- st.session_state.video_description_content = video_description[
224
- "Content Description"
225
- ]
226
- st.session_state.music_prompt = video_description["Music Prompt"]
227
-
228
- st.success("Video description generated successfully.")
229
- st.session_state.generate_button_flag = True
230
-
231
- # Display Video Description and Music Prompt
232
- if st.session_state.generate_button_flag:
233
- st.text_area(
234
- "Video Description",
235
- st.session_state.video_description_content,
236
- disabled=True,
237
- height=120,
238
- )
239
- music_prompt = st.text_area(
240
- "Music Prompt",
241
- st.session_state.music_prompt,
242
- disabled=True,
243
- height=120,
244
- )
245
-
246
- if generate_button:
247
- # Generate Music
248
- with st.spinner("Generating music..."):
249
- if video_duration > 30:
250
- st.warning(
251
- "Due to hardware limitations, the maximum music length is capped at 30 seconds."
252
- )
253
- music_prompt = [st.session_state.music_prompt] * st.session_state.num_samples
254
- audio_generator.generate_audio(music_prompt, duration=video_duration)
255
- st.session_state.audio_paths = audio_generator.save_audio()
256
- st.success("Music generated successfully.")
257
- st.balloons()
258
-
259
-
260
- # Callback function for radio button selection change
261
- def on_audio_selection_change():
262
- st.session_state.audio_mix_flag = False
263
- selected_audio_index = st.session_state.selected_audio
264
- if selected_audio_index > 0:
265
- st.session_state.selected_audio_path = st.session_state.audio_paths[
266
- selected_audio_index - 1
267
- ]
268
- else:
269
- st.session_state.selected_audio_path = None
270
-
271
-
272
- if st.session_state.audio_paths:
273
- # Dropdown to select one of the generated audio files
274
- audio_options = ["None"] + [
275
- f"Generated Music {i+1}" for i in range(len(st.session_state.audio_paths))
276
- ]
277
-
278
- # Display the audio files
279
- for i, audio_path in enumerate(st.session_state.audio_paths):
280
- st.audio(audio_path, format="audio/wav")
281
-
282
- selected_audio_index = st.selectbox(
283
- "Select one of the generated audio files for further processing:",
284
- range(len(audio_options)),
285
- format_func=lambda x: audio_options[x],
286
- index=0,
287
- key="selected_audio",
288
- on_change=on_audio_selection_change,
289
- )
290
-
291
- # Button to confirm the selection
292
- if st.button("Add Generated Music to Video"):
293
- st.session_state.audio_mix_flag = True
294
-
295
- # Handle Audio Mixing and Export
296
- if st.session_state.selected_audio_path is not None and st.session_state.audio_mix_flag:
297
- with st.spinner("Mixing Audio..."):
298
- orig_clip = VideoFileClip(f"{user_session_id}/temp.mp4")
299
- orig_clip_audio = orig_clip.audio
300
- generated_audio = AudioFileClip(st.session_state.selected_audio_path)
301
-
302
- st.session_state.orig_audio_vol = st.slider(
303
- "Original Audio Volume",
304
- 0,
305
- 200,
306
- st.session_state.orig_audio_vol,
307
- format="%d%%",
308
- )
309
-
310
- st.session_state.generated_audio_vol = st.slider(
311
- "Generated Music Volume",
312
- 0,
313
- 200,
314
- st.session_state.generated_audio_vol,
315
- format="%d%%",
316
- )
317
-
318
- orig_clip_audio = volumex(
319
- orig_clip_audio, float(st.session_state.orig_audio_vol / 100)
320
- )
321
- generated_audio = volumex(
322
- generated_audio, float(st.session_state.generated_audio_vol / 100)
323
- )
324
-
325
- orig_clip.audio = CompositeAudioClip([orig_clip_audio, generated_audio])
326
-
327
- final_video_path = f"{user_session_id}/out_tmp.mp4"
328
- orig_clip.write_videofile(final_video_path)
329
-
330
- orig_clip.close()
331
- generated_audio.close()
332
-
333
- st.session_state.final_video_path = final_video_path
334
-
335
- st.video(final_video_path)
336
- if st.session_state.final_video_path:
337
- with open(st.session_state.final_video_path, "rb") as video_file:
338
- st.download_button(
339
- label="Download final video",
340
- data=video_file,
341
- file_name="final_video.mp4",
342
- mime="video/mp4",
343
- )