Sandesh Bharadwaj committed on
Commit
599e71e
·
unverified ·
2 Parent(s): ae68709 ee4f393

Merge pull request #2 from animikhaich/web-app-dev

Browse files
Files changed (3) hide show
  1. engine/audio_generator.py +30 -16
  2. engine/video_descriptor.py +23 -8
  3. main.py +165 -67
engine/audio_generator.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  import warnings
 
3
 
4
  warnings.simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -30,7 +31,9 @@ class GenerateAudio:
30
  logging.info(f"Loaded model: {model}")
31
  return model
32
  except Exception as e:
33
- logging.error(f"Failed to load model: {e}")
 
 
34
  raise ValueError(f"Failed to load model: {e}")
35
  return
36
 
@@ -39,14 +42,18 @@ class GenerateAudio:
39
  if model_name.startswith("facebook/"):
40
  return model_name
41
  return f"facebook/{model_name}"
42
-
43
  @staticmethod
44
  def duration_sanity_check(duration):
45
  if duration < 1:
46
- logging.warning("Duration is less than 1 second. Setting duration to 1 second.")
 
 
47
  return 1
48
  elif duration > 30:
49
- logging.warning("Duration is greater than 30 seconds. Setting duration to 30 seconds.")
 
 
50
  return 30
51
  return duration
52
 
@@ -60,16 +67,16 @@ class GenerateAudio:
60
  for prompt in prompts:
61
  if not isinstance(prompt, str):
62
  raise ValueError("Prompts should be a string or a list of strings.")
63
- if len(prompts) > 8: # Too many prompts will cause OOM error
64
  raise ValueError("Maximum number of prompts allowed is 8.")
65
  return prompts
66
-
67
 
68
  def generate_audio(self, prompts, duration=10):
69
  duration = self.duration_sanity_check(duration)
70
  prompts = self.prompts_sanity_check(prompts)
71
 
72
  try:
 
73
  if duration <= 30:
74
  self.model.set_generation_params(duration=duration)
75
  result = self.model.generate(prompts, progress=False)
@@ -77,17 +84,23 @@ class GenerateAudio:
77
  self.model.set_generation_params(duration=30)
78
  result = self.model.generate(prompts, progress=False)
79
  self.model.set_generation_params(duration=duration)
80
- result = self.model.generate_with_chroma(prompts, result, melody_sample_rate=self.sampling_rate, progress=False)
 
 
 
 
 
81
  self.result = result.cpu().numpy().T
82
  self.result = self.result.transpose((2, 0, 1))
83
- self.sampling_rate = self.model.sample_rate
84
  logging.info(
85
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
86
  )
87
- print(f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz")
88
  return self.sampling_rate, self.result
89
  except Exception as e:
90
- logging.error(f"Failed to generate audio: {e}")
 
 
91
  raise ValueError(f"Failed to generate audio: {e}")
92
 
93
  def save_audio(self, audio_dir="generated_audio"):
@@ -118,17 +131,18 @@ class GenerateAudio:
118
  buffers.append(buffer)
119
  return buffers
120
 
 
121
  if __name__ == "__main__":
122
  audio_gen = GenerateAudio()
123
  sample_rate, result = audio_gen.generate_audio(
124
  [
125
- "A piano playing a jazz melody",
126
- "A guitar playing a rock riff",
127
- "A LoFi music for coding"
128
- ],
129
- duration=10
130
  )
131
  paths = audio_gen.save_audio()
132
  print(f"Saved audio to: {paths}")
133
  buffers = audio_gen.get_audio_buffer()
134
- print(f"Audio buffers: {buffers}")
 
1
  import os
2
  import warnings
3
+ import traceback
4
 
5
  warnings.simplefilter("ignore")
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
31
  logging.info(f"Loaded model: {model}")
32
  return model
33
  except Exception as e:
34
+ logging.error(
35
+ f"Failed to load model: {e}, Traceback: {traceback.format_exc()}"
36
+ )
37
  raise ValueError(f"Failed to load model: {e}")
38
  return
39
 
 
42
  if model_name.startswith("facebook/"):
43
  return model_name
44
  return f"facebook/{model_name}"
45
+
@staticmethod
def duration_sanity_check(duration):
    """Clamp a requested clip duration to the supported 1-30 second range.

    Args:
        duration: Requested duration in seconds (int or float).

    Returns:
        1 when the request is below one second (or is NaN), 30 when it is
        above thirty seconds, otherwise ``duration`` unchanged.
    """
    # Written as "not >=" rather than "<" so that NaN, which compares False
    # against everything, is clamped to the minimum instead of slipping
    # through both range checks unmodified.
    if not duration >= 1:
        logging.warning(
            "Duration is less than 1 second. Setting duration to 1 second."
        )
        return 1
    if duration > 30:
        logging.warning(
            "Duration is greater than 30 seconds. Setting duration to 30 seconds."
        )
        return 30
    return duration
59
 
 
67
  for prompt in prompts:
68
  if not isinstance(prompt, str):
69
  raise ValueError("Prompts should be a string or a list of strings.")
70
+ if len(prompts) > 8: # Too many prompts will cause OOM error
71
  raise ValueError("Maximum number of prompts allowed is 8.")
72
  return prompts
 
73
 
74
  def generate_audio(self, prompts, duration=10):
75
  duration = self.duration_sanity_check(duration)
76
  prompts = self.prompts_sanity_check(prompts)
77
 
78
  try:
79
+ self.sampling_rate = self.model.sample_rate
80
  if duration <= 30:
81
  self.model.set_generation_params(duration=duration)
82
  result = self.model.generate(prompts, progress=False)
 
84
  self.model.set_generation_params(duration=30)
85
  result = self.model.generate(prompts, progress=False)
86
  self.model.set_generation_params(duration=duration)
87
+ result = self.model.generate_with_chroma(
88
+ prompts,
89
+ result,
90
+ melody_sample_rate=self.sampling_rate,
91
+ progress=False,
92
+ )
93
  self.result = result.cpu().numpy().T
94
  self.result = self.result.transpose((2, 0, 1))
95
+
96
  logging.info(
97
  f"Generated audio with shape: {self.result.shape}, sample rate: {self.sampling_rate} Hz"
98
  )
 
99
  return self.sampling_rate, self.result
100
  except Exception as e:
101
+ logging.error(
102
+ f"Failed to generate audio: {e}, Traceback: {traceback.format_exc()}"
103
+ )
104
  raise ValueError(f"Failed to generate audio: {e}")
105
 
106
  def save_audio(self, audio_dir="generated_audio"):
 
131
  buffers.append(buffer)
132
  return buffers
133
 
134
+
135
  if __name__ == "__main__":
136
  audio_gen = GenerateAudio()
137
  sample_rate, result = audio_gen.generate_audio(
138
  [
139
+ "A piano playing a jazz melody",
140
+ "A guitar playing a rock riff",
141
+ "A LoFi music for coding",
142
+ ],
143
+ duration=10,
144
  )
145
  paths = audio_gen.save_audio()
146
  print(f"Saved audio to: {paths}")
147
  buffers = audio_gen.get_audio_buffer()
148
+ print(f"Audio buffers: {buffers}")
engine/video_descriptor.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  from warnings import simplefilter
 
3
 
4
  simplefilter("ignore")
5
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
@@ -43,26 +44,38 @@ class DescribeVideo:
43
  self.safety_settings = self.get_safety_settings()
44
 
45
  genai.configure(api_key=__api_key)
46
- self.mllm_model = genai.GenerativeModel(self.model, system_instruction=gemini_instructions)
 
 
47
 
48
  logging.info(f"Initialized DescribeVideo with model: {self.model}")
49
 
50
  def describe_video(self, video_path, genre, bpm, user_keywords):
51
  video_file = genai.upload_file(video_path)
52
- logging.info(f"Uploaded video: {video_path}")
53
 
54
  while video_file.state.name == "PROCESSING":
55
  time.sleep(0.25)
56
  video_file = genai.get_file(video_file.name)
57
 
58
  if video_file.state.name == "FAILED":
59
- logging.error(f"Failed to upload video: {video_file.state.name}")
 
 
60
  raise ValueError(f"Failed to upload video: {video_file.state.name}")
61
-
62
- additional_keywords = ", ".join([genre, user_keywords, bpm]) + "bpm"
 
 
 
 
 
 
 
 
 
63
 
64
  response = self.mllm_model.generate_content(
65
- [video_file, f"Explain what is happening in this video. The following keywords are provided by the user for generating the music prompt: {additional_keywords}"],
66
  request_options={"timeout": 600},
67
  safety_settings=self.safety_settings,
68
  )
@@ -116,7 +129,9 @@ class DescribeVideo:
116
 
117
  api_key = creds.get("google_api_key", None)
118
  if api_key is None or not isinstance(api_key, str):
119
- logging.error(f"Google API key not found in {path}")
 
 
120
  raise ValueError(f"Gemini API key not found in {path}")
121
  return api_key
122
 
@@ -129,7 +144,7 @@ class DescribeVideo:
129
 
130
  if model not in models:
131
  logging.error(
132
- f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
133
  )
134
  raise ValueError(
135
  f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
 
1
  import os
2
  from warnings import simplefilter
3
+ import traceback
4
 
5
  simplefilter("ignore")
6
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
 
44
  self.safety_settings = self.get_safety_settings()
45
 
46
  genai.configure(api_key=__api_key)
47
+ self.mllm_model = genai.GenerativeModel(
48
+ self.model, system_instruction=gemini_instructions
49
+ )
50
 
51
  logging.info(f"Initialized DescribeVideo with model: {self.model}")
52
 
53
  def describe_video(self, video_path, genre, bpm, user_keywords):
54
  video_file = genai.upload_file(video_path)
 
55
 
56
  while video_file.state.name == "PROCESSING":
57
  time.sleep(0.25)
58
  video_file = genai.get_file(video_file.name)
59
 
60
  if video_file.state.name == "FAILED":
61
+ logging.error(
62
+ f"Failed to upload video: {video_file.state.name}, Traceback: {traceback.format_exc()}"
63
+ )
64
  raise ValueError(f"Failed to upload video: {video_file.state.name}")
65
+
66
+ additional_keywords = ", ".join(filter(None, [genre, user_keywords])) + (
67
+ f", {bpm} bpm" if bpm else ""
68
+ )
69
+
70
+ logging.info(f"Uploaded video: {video_path} and config: {additional_keywords}")
71
+
72
+ user_prompt = "Explain what is happening in this video."
73
+
74
+ if additional_keywords:
75
+ user_prompt += f" The following keywords are provided by the user for generating the music prompt: {additional_keywords}"
76
 
77
  response = self.mllm_model.generate_content(
78
+ [video_file, user_prompt],
79
  request_options={"timeout": 600},
80
  safety_settings=self.safety_settings,
81
  )
 
129
 
130
  api_key = creds.get("google_api_key", None)
131
  if api_key is None or not isinstance(api_key, str):
132
+ logging.error(
133
+ f"Google API key not found in {path}, Traceback: {traceback.format_exc()}"
134
+ )
135
  raise ValueError(f"Gemini API key not found in {path}")
136
  return api_key
137
 
 
144
 
145
  if model not in models:
146
  logging.error(
147
+ f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}, Traceback: {traceback.format_exc()}"
148
  )
149
  raise ValueError(
150
  f"Invalid model name '{model}'. Valid options are: {', '.join(models.keys())}"
main.py CHANGED
@@ -1,7 +1,9 @@
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
 
 
3
 
4
-
5
  video_model_map = {
6
  "Fast": "flash",
7
  "Quality": "pro",
@@ -13,79 +15,175 @@ music_model_map = {
13
  "Quality": "musicgen-stereo-large",
14
  }
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- st.set_page_config(page_title="VidTune: Where Videos Find Their Melody", layout="centered")
 
 
18
 
19
  # Title and Description
20
  st.title("VidTune: Where Videos Find Their Melody")
21
- st.write("VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video.")
 
 
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Sidebar
25
  st.sidebar.title("Settings")
26
- video_model = st.sidebar.selectbox("Select Video Descriptor", ["Fast", "Balanced", "Quality"], index=0)
27
- music_model = st.sidebar.selectbox("Select Music Generator", ["Fast", "Balanced", "Quality"], index=0)
28
- num_samples = st.sidebar.slider("Number of samples", 1, 8, 3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  generate_button = st.sidebar.button("Generate Music")
30
 
31
- video_descriptor = DescribeVideo(model=video_model_map[video_model])
32
- audio_generator = GenerateAudio(model=music_model_map[music_model])
33
-
34
- video_description = None
35
-
36
- # Main Page (Page 1)
37
- if 'page' not in st.session_state:
38
- st.session_state.page = 'main'
39
-
40
- if st.session_state.page == 'main':
41
- st.header("Video to Music")
42
- uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
43
-
44
- if uploaded_video is not None:
45
- st.session_state.uploaded_video = uploaded_video
46
- with open("temp.mp4", mode='wb') as w:
47
- w.write(uploaded_video.getvalue())
48
- video_description = video_descriptor.describe_video("temp.mp4")
49
-
50
- st.session_state.page = 'video_to_music'
51
-
52
- if st.session_state.page == 'main':
53
- st.header("Prompt to Music")
54
- prompt = st.text_area("Prompt")
55
- if generate_button:
56
- st.session_state.prompt = prompt
57
- st.session_state.page = 'prompt_to_music'
58
-
59
- # Page 2a (If the user uploads a video)
60
- if st.session_state.page == 'video_to_music':
61
- st.video(st.session_state.uploaded_video)
62
-
63
- st.text_area("Video Description", "This is a fixed video description", disabled=True)
64
- st.text_area("Music Description")
65
-
66
- if generate_button:
67
- st.session_state.page = 'result'
68
- st.session_state.device = device
69
- st.session_state.num_samples = num_samples
70
-
71
- # Page 2b (If user selects "Prompt to Music" in Page 1)
72
- if st.session_state.page == 'prompt_to_music':
73
- st.sidebar.title("Settings")
74
- device = st.sidebar.selectbox("Select Device", ["GPU", "CPU"], index=0)
75
- num_samples = st.sidebar.slider("Number of samples", 1, 10, 3)
76
-
77
- if generate_button:
78
- st.session_state.page = 'result'
79
- st.session_state.device = device
80
- st.session_state.num_samples = num_samples
81
-
82
- # Page 3 (Results Page)
83
- if st.session_state.page == 'result':
84
- st.header("Generated Music")
85
- for i in range(st.session_state.num_samples):
86
- st.write(f"Music Sample {i+1}")
87
- st.audio(f"Generated Music {i+1}.mp3", format='audio/mp3')
88
- st.download_button(f"Download Music {i+1}", f"Generated Music {i+1}.mp3")
89
-
90
- if st.button("Start Over"):
91
- st.session_state.page = 'main'
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  from engine import DescribeVideo, GenerateAudio
3
+ import os
4
+ from moviepy.editor import VideoFileClip
5
 
6
+ # Define model maps
7
  video_model_map = {
8
  "Fast": "flash",
9
  "Quality": "pro",
 
15
  "Quality": "musicgen-stereo-large",
16
  }
17
 
18
+ # music_model_map = {
19
+ # "Fast": "facebook/musicgen-melody",
20
+ # "Quality": "facebook/musicgen-melody-large",
21
+ # }
22
+
23
+ genre_map = {
24
+ "None": None,
25
+ "Pop": "Pop",
26
+ "Rock": "Rock",
27
+ "Hip Hop": "Hip-Hop/Rap",
28
+ "Jazz": "Jazz",
29
+ "Classical": "Classical",
30
+ "Blues": "Blues",
31
+ "Country": "Country",
32
+ "EDM": "Electronic/Dance",
33
+ "Metal": "Metal",
34
+ "Disco": "Disco",
35
+ "Lo-Fi": "Lo-Fi",
36
+ }
37
 
38
+ st.set_page_config(
39
+ page_title="VidTune: Where Videos Find Their Melody", layout="centered"
40
+ )
41
 
42
  # Title and Description
43
  st.title("VidTune: Where Videos Find Their Melody")
44
+ st.write(
45
+ "VidTune is a web application that allows users to upload videos and generate melodies matching the mood of the video."
46
+ )
47
 
# Initialize session state for advanced settings and other inputs.
# Each key is seeded only when absent, so values already stored in the
# session survive Streamlit's script reruns.
_session_defaults = {
    "show_advanced": False,
    "video_model": "Fast",
    "music_model": "Fast",
    "num_samples": 3,
    "music_genre": None,
    "music_bpm": 100,
    "user_keywords": None,
}
for _state_key, _state_default in _session_defaults.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
63
 
64
  # Sidebar
65
  st.sidebar.title("Settings")
66
+
67
+ # Basic Settings
68
+ st.session_state.video_model = st.sidebar.selectbox(
69
+ "Select Video Descriptor",
70
+ ["Fast", "Quality"],
71
+ index=["Fast", "Quality"].index(st.session_state.video_model),
72
+ )
73
+ st.session_state.music_model = st.sidebar.selectbox(
74
+ "Select Music Generator",
75
+ ["Fast", "Balanced", "Quality"],
76
+ index=["Fast", "Balanced", "Quality"].index(st.session_state.music_model),
77
+ )
78
+ st.session_state.num_samples = st.sidebar.slider(
79
+ "Number of samples", 1, 5, st.session_state.num_samples
80
+ )
81
+
# Sidebar for advanced settings
with st.sidebar:
    # Placeholder slot that hosts the advanced-settings toggle button
    placeholder = st.empty()

    # Flip the flag on click and rerun so the layout reflects the new state
    if placeholder.button("Advanced"):
        st.session_state.show_advanced = not st.session_state.show_advanced
        st.rerun()  # Refresh the layout after button click

if not st.session_state.show_advanced:
    # Advanced panel hidden: clear the optional hints
    st.session_state.music_genre = None
    st.session_state.music_bpm = None
    st.session_state.user_keywords = None
else:
    # Advanced panel visible: collect tempo, genre and free-form keywords
    st.session_state.music_bpm = st.sidebar.slider("Beats Per Minute", 35, 180, 100)

    genre_options = list(genre_map)
    # Preselect the stored genre when it is a known label, else the first entry
    if st.session_state.music_genre in genre_map:
        genre_index = genre_options.index(st.session_state.music_genre)
    else:
        genre_index = 0
    st.session_state.music_genre = st.sidebar.selectbox(
        "Select Music Genre",
        genre_options,
        index=genre_index,
    )

    st.session_state.user_keywords = st.sidebar.text_input(
        "User Keywords",
        value=st.session_state.user_keywords,
        help="Enter keywords separated by commas.",
    )
114
+
115
+ # Generate Button
116
  generate_button = st.sidebar.button("Generate Music")
117
 
118
+
# Cache the model loading
@st.cache_resource
def load_models(video_model_key, music_model_key):
    """Build the video descriptor and audio generator for the chosen presets.

    Args:
        video_model_key: Key into ``video_model_map`` (sidebar label).
        music_model_key: Key into ``music_model_map`` (sidebar label).

    Returns:
        A ``(DescribeVideo, GenerateAudio)`` tuple. The function is wrapped
        in ``st.cache_resource``.
    """
    describer = DescribeVideo(model=video_model_map[video_model_key])
    generator = GenerateAudio(model=music_model_map[music_model_key])
    return describer, generator
125
+
126
+
127
+ # Load models
128
+ video_descriptor, audio_generator = load_models(
129
+ st.session_state.video_model, st.session_state.music_model
130
+ )
131
+
132
+ # Video Uploader
133
+ uploaded_video = st.file_uploader("Upload Video", type=["mp4"])
134
+ if uploaded_video is not None:
135
+ st.session_state.uploaded_video = uploaded_video
136
+ with open("temp.mp4", mode="wb") as w:
137
+ w.write(uploaded_video.getvalue())
138
+
139
+ # Video Player
140
+ if os.path.exists("temp.mp4") and uploaded_video is not None:
141
+ st.video(uploaded_video)
142
+
143
+ # Submit button if video is not uploaded
144
+ if generate_button and uploaded_video is None:
145
+ st.error("Please upload a video before generating music.")
146
+ st.stop()
147
+
# Submit Button and music generation if video is uploaded
if generate_button and uploaded_video is not None:
    with st.spinner("Analyzing video..."):
        video_description = video_descriptor.describe_video(
            "temp.mp4",
            # The sidebar stores the display label (e.g. "Hip Hop" or the
            # literal string "None"); translate it through genre_map so the
            # descriptor receives the real genre text, or None when no genre
            # was chosen, instead of the keyword "None".
            genre=genre_map.get(st.session_state.music_genre),
            bpm=st.session_state.music_bpm,
            user_keywords=st.session_state.user_keywords,
        )

        # Only the duration is needed; close the clip so the underlying
        # reader is released instead of lingering on every run.
        clip = VideoFileClip("temp.mp4")
        try:
            video_duration = clip.duration
        finally:
            clip.close()

        music_prompt = video_description["Music Prompt"]

        st.success("Video description generated successfully.")

    # Display Video Description and Music Prompt
    st.text_area(
        "Video Description",
        video_description["Content Description"],
        disabled=True,
        height=120,
    )
    music_prompt = st.text_area(
        "Music Prompt",
        music_prompt,
        disabled=False,
        height=120,
    )

    # Generate Music
    with st.spinner("Generating music..."):
        if video_duration > 30:
            st.warning(
                "Due to hardware limitations, the maximum music length is capped at 30 seconds."
            )
        # One copy of the prompt per requested sample
        music_prompt = [music_prompt] * st.session_state.num_samples
        audio_generator.generate_audio(music_prompt, duration=video_duration)
        audio_paths = audio_generator.save_audio()
        st.success("Music generated successfully.")
        for audio_path in audio_paths:
            st.audio(audio_path, format="audio/wav")

    st.balloons()