flosstradamus commited on
Commit
771145b
·
verified ·
1 Parent(s): 923a82b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -91
app.py CHANGED
@@ -10,6 +10,10 @@ import numpy as np
10
  import re
11
  import requests
12
  from urllib.parse import urlparse
 
 
 
 
13
 
14
  # Import necessary functions and classes
15
  from utils import load_t5, load_clap
@@ -39,16 +43,20 @@ def unload_current_model():
39
  global_model = None
40
 
41
  def download_model(url):
42
- response = requests.get(url, stream=True)
43
- if response.status_code == 200:
44
- filename = os.path.basename(urlparse(url).path)
45
- model_path = os.path.join("/tmp", filename)
46
- with open(model_path, "wb") as f:
47
- for chunk in response.iter_content(chunk_size=8192):
48
- f.write(chunk)
49
- return model_path
50
- else:
51
- raise Exception(f"Failed to download model from {url}")
 
 
 
 
52
 
53
  def load_model(url):
54
  global global_model
@@ -57,7 +65,7 @@ def load_model(url):
57
 
58
  unload_current_model()
59
 
60
- print(f"Downloading model from {url}")
61
  model_path = download_model(url)
62
 
63
  # Determine model size from filename
@@ -73,17 +81,17 @@ def load_model(url):
73
  else:
74
  model_size = "base" # Default to base if unrecognized
75
 
76
- print(f"Loading {model_size} model: {filename}")
77
 
78
  global_model = build_model(model_size).to(device)
79
  state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
80
  global_model.load_state_dict(state_dict['ema'])
81
  global_model.eval()
82
  global_model.model_path = model_path
83
- print("Model loaded successfully")
84
  return "Model loaded successfully"
85
  except Exception as e:
86
- print(f"Error loading model: {str(e)}")
87
  return f"Error loading model: {str(e)}"
88
 
89
  def load_resources():
@@ -91,20 +99,20 @@ def load_resources():
91
 
92
  device = "cuda" if torch.cuda.is_available() else "cpu"
93
 
94
- print("Loading T5 and CLAP models...")
95
  global_t5 = load_t5(device, max_length=256)
96
  global_clap = load_clap(device, max_length=256)
97
 
98
- print("Loading VAE and vocoder...")
99
  global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
100
  global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
101
 
102
- print("Initializing diffusion...")
103
  global_diffusion = RF()
104
 
105
- print("Base resources loaded successfully!")
106
 
107
- def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
108
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
109
 
110
  if global_model is None:
@@ -112,85 +120,89 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progres
112
 
113
  if seed == 0:
114
  seed = random.randint(1, 1000000)
115
- print(f"Using seed: {seed}")
116
 
117
  device = "cuda" if torch.cuda.is_available() else "cpu"
118
  torch.manual_seed(seed)
119
  torch.set_grad_enabled(False)
120
 
121
- # Calculate the number of segments needed for the desired duration
122
- segment_duration = 10 # Each segment is 10 seconds
123
- num_segments = int(np.ceil(duration / segment_duration))
 
124
 
125
- all_waveforms = []
126
 
127
- for i in range(num_segments):
128
- progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
129
 
130
- # Use the same seed for all segments
131
- torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
132
 
133
- latent_size = (256, 16)
134
- conds_txt = [prompt]
135
- unconds_txt = ["low quality, gentle"]
136
- L = len(conds_txt)
137
 
138
- init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
139
 
140
- img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
141
- _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
142
 
143
- with torch.autocast(device_type='cuda'):
144
- images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
145
 
146
- images = rearrange(
147
- images[-1],
148
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
149
- h=128,
150
- w=8,
151
- ph=2,
152
- pw=2,)
153
 
154
- latents = 1 / global_vae.config.scaling_factor * images
155
- mel_spectrogram = global_vae.decode(latents).sample
156
 
157
- x_i = mel_spectrogram[0]
158
- if x_i.dim() == 4:
159
- x_i = x_i.squeeze(1)
160
- waveform = global_vocoder(x_i)
161
- waveform = waveform[0].cpu().float().detach().numpy()
162
 
163
- all_waveforms.append(waveform)
164
 
165
- # Concatenate all waveforms
166
- final_waveform = np.concatenate(all_waveforms)
167
 
168
- # Trim to exact duration
169
- sample_rate = 16000
170
- final_waveform = final_waveform[:int(duration * sample_rate)]
171
 
172
- progress(0.9, desc="Saving audio file")
173
-
174
- # Create 'generations' folder
175
- os.makedirs(GENERATIONS_DIR, exist_ok=True)
176
-
177
- # Generate filename
178
- prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
179
- model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
180
- model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
181
- base_filename = f"{prompt_part}_{seed}{model_suffix}"
182
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
183
-
184
- # Check if file exists and add numerical suffix if needed
185
- counter = 1
186
- while os.path.exists(output_path):
187
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
188
- counter += 1
189
 
190
- wavfile.write(output_path, sample_rate, final_waveform)
191
 
192
- progress(1.0, desc="Audio generation complete")
193
- return f"Generated with seed: {seed}", output_path
 
 
 
194
 
195
  # Load base resources at startup
196
  load_resources()
@@ -213,25 +225,29 @@ with gr.Blocks(theme=theme) as iface:
213
  </div>
214
  """)
215
 
216
- with gr.Row():
217
- model_url = gr.Textbox(label="Model URL", placeholder="Enter the URL of the model file (.pt)")
218
- load_model_button = gr.Button("Load Model")
219
 
220
- with gr.Row():
221
- prompt = gr.Textbox(label="Prompt")
222
- seed = gr.Number(label="Seed", value=0)
223
-
224
- with gr.Row():
225
- cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
226
- steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
227
- duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
228
 
229
  generate_button = gr.Button("Generate Music")
230
- output_status = gr.Textbox(label="Status")
231
  output_audio = gr.Audio(type="filepath")
232
 
233
- load_model_button.click(load_model, inputs=[model_url], outputs=[output_status])
234
- generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
 
 
 
 
 
 
 
235
 
236
  # Launch the interface
237
  iface.launch()
 
10
  import re
11
  import requests
12
  from urllib.parse import urlparse
13
+ import logging
14
+
15
+ # Set up logging
16
+ logging.basicConfig(level=logging.INFO)
17
 
18
  # Import necessary functions and classes
19
  from utils import load_t5, load_clap
 
43
  global_model = None
44
 
45
  def download_model(url):
46
+ try:
47
+ response = requests.get(url, stream=True)
48
+ if response.status_code == 200:
49
+ filename = os.path.basename(urlparse(url).path)
50
+ model_path = os.path.join("/tmp", filename)
51
+ with open(model_path, "wb") as f:
52
+ for chunk in response.iter_content(chunk_size=8192):
53
+ f.write(chunk)
54
+ return model_path
55
+ else:
56
+ raise Exception(f"Failed to download model from {url}")
57
+ except Exception as e:
58
+ logging.error(f"Error downloading model: {str(e)}")
59
+ raise
60
 
61
  def load_model(url):
62
  global global_model
 
65
 
66
  unload_current_model()
67
 
68
+ logging.info(f"Downloading model from {url}")
69
  model_path = download_model(url)
70
 
71
  # Determine model size from filename
 
81
  else:
82
  model_size = "base" # Default to base if unrecognized
83
 
84
+ logging.info(f"Loading {model_size} model: {filename}")
85
 
86
  global_model = build_model(model_size).to(device)
87
  state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
88
  global_model.load_state_dict(state_dict['ema'])
89
  global_model.eval()
90
  global_model.model_path = model_path
91
+ logging.info("Model loaded successfully")
92
  return "Model loaded successfully"
93
  except Exception as e:
94
+ logging.error(f"Error loading model: {str(e)}")
95
  return f"Error loading model: {str(e)}"
96
 
97
  def load_resources():
 
99
 
100
  device = "cuda" if torch.cuda.is_available() else "cpu"
101
 
102
+ logging.info("Loading T5 and CLAP models...")
103
  global_t5 = load_t5(device, max_length=256)
104
  global_clap = load_clap(device, max_length=256)
105
 
106
+ logging.info("Loading VAE and vocoder...")
107
  global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
108
  global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
109
 
110
+ logging.info("Initializing diffusion...")
111
  global_diffusion = RF()
112
 
113
+ logging.info("Base resources loaded successfully!")
114
 
115
+ def generate_music(prompt, seed, cfg_scale, steps, duration):
116
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
117
 
118
  if global_model is None:
 
120
 
121
  if seed == 0:
122
  seed = random.randint(1, 1000000)
123
+ logging.info(f"Using seed: {seed}")
124
 
125
  device = "cuda" if torch.cuda.is_available() else "cpu"
126
  torch.manual_seed(seed)
127
  torch.set_grad_enabled(False)
128
 
129
+ try:
130
+ # Calculate the number of segments needed for the desired duration
131
+ segment_duration = 10 # Each segment is 10 seconds
132
+ num_segments = int(np.ceil(duration / segment_duration))
133
 
134
+ all_waveforms = []
135
 
136
+ for i in range(num_segments):
137
+ logging.info(f"Generating segment {i+1}/{num_segments}")
138
 
139
+ # Use the same seed for all segments
140
+ torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
141
 
142
+ latent_size = (256, 16)
143
+ conds_txt = [prompt]
144
+ unconds_txt = ["low quality, gentle"]
145
+ L = len(conds_txt)
146
 
147
+ init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
148
 
149
+ img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
150
+ _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
151
 
152
+ with torch.autocast(device_type='cuda'):
153
+ images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
154
 
155
+ images = rearrange(
156
+ images[-1],
157
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
158
+ h=128,
159
+ w=8,
160
+ ph=2,
161
+ pw=2,)
162
 
163
+ latents = 1 / global_vae.config.scaling_factor * images
164
+ mel_spectrogram = global_vae.decode(latents).sample
165
 
166
+ x_i = mel_spectrogram[0]
167
+ if x_i.dim() == 4:
168
+ x_i = x_i.squeeze(1)
169
+ waveform = global_vocoder(x_i)
170
+ waveform = waveform[0].cpu().float().detach().numpy()
171
 
172
+ all_waveforms.append(waveform)
173
 
174
+ # Concatenate all waveforms
175
+ final_waveform = np.concatenate(all_waveforms)
176
 
177
+ # Trim to exact duration
178
+ sample_rate = 16000
179
+ final_waveform = final_waveform[:int(duration * sample_rate)]
180
 
181
+ logging.info("Saving audio file")
182
+
183
+ # Create 'generations' folder
184
+ os.makedirs(GENERATIONS_DIR, exist_ok=True)
185
+
186
+ # Generate filename
187
+ prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
188
+ model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
189
+ model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
190
+ base_filename = f"{prompt_part}_{seed}{model_suffix}"
191
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
192
+
193
+ # Check if file exists and add numerical suffix if needed
194
+ counter = 1
195
+ while os.path.exists(output_path):
196
+ output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
197
+ counter += 1
198
 
199
+ wavfile.write(output_path, sample_rate, final_waveform)
200
 
201
+ logging.info("Audio generation complete")
202
+ return f"Generated with seed: {seed}", output_path
203
+ except Exception as e:
204
+ logging.error(f"Error generating music: {str(e)}")
205
+ return f"Error generating music: {str(e)}", None
206
 
207
  # Load base resources at startup
208
  load_resources()
 
225
  </div>
226
  """)
227
 
228
+ model_url = gr.Textbox(label="Model URL", placeholder="Enter the URL of the model file (.pt)")
229
+ load_model_button = gr.Button("Load Model")
230
+ model_status = gr.Textbox(label="Model Status")
231
 
232
+ prompt = gr.Textbox(label="Prompt")
233
+ seed = gr.Number(label="Seed", value=0)
234
+ cfg_scale = gr.Slider(minimum=1, maximum=40, step=0.1, label="CFG Scale", value=20)
235
+ steps = gr.Slider(minimum=10, maximum=200, step=1, label="Steps", value=100)
236
+ duration = gr.Number(label="Duration (seconds)", value=10, minimum=10, maximum=300, step=1)
 
 
 
237
 
238
  generate_button = gr.Button("Generate Music")
239
+ output_status = gr.Textbox(label="Generation Status")
240
  output_audio = gr.Audio(type="filepath")
241
 
242
+ def load_model_wrapper(url):
243
+ return load_model(url)
244
+
245
+ def generate_music_wrapper(prompt, seed, cfg_scale, steps, duration):
246
+ status, audio_path = generate_music(prompt, seed, cfg_scale, steps, duration)
247
+ return status, audio_path if audio_path else None
248
+
249
+ load_model_button.click(load_model_wrapper, inputs=[model_url], outputs=[model_status])
250
+ generate_button.click(generate_music_wrapper, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
251
 
252
  # Launch the interface
253
  iface.launch()