flosstradamus commited on
Commit
efb5d8f
·
verified ·
1 Parent(s): c9ef4c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -129
app.py CHANGED
@@ -28,39 +28,7 @@ MODELS_DIR = "/content/models"
28
  GENERATIONS_DIR = "/content/generations"
29
 
30
  def prepare(t5, clip, img, prompt):
31
- bs, c, h, w = img.shape
32
- if bs == 1 and not isinstance(prompt, str):
33
- bs = len(prompt)
34
-
35
- img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
36
- if img.shape[0] == 1 and bs > 1:
37
- img = repeat(img, "1 ... -> bs ...", bs=bs)
38
-
39
- img_ids = torch.zeros(h // 2, w // 2, 3)
40
- img_ids[..., 1] = img_ids[..., 1] + torch.arange(h // 2)[:, None]
41
- img_ids[..., 2] = img_ids[..., 2] + torch.arange(w // 2)[None, :]
42
- img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
43
-
44
- if isinstance(prompt, str):
45
- prompt = [prompt]
46
-
47
- # Generate text embeddings
48
- txt = t5(prompt)
49
-
50
- if txt.shape[0] == 1 and bs > 1:
51
- txt = repeat(txt, "1 ... -> bs ...", bs=bs)
52
- txt_ids = torch.zeros(bs, txt.shape[1], 3)
53
-
54
- vec = clip(prompt)
55
- if vec.shape[0] == 1 and bs > 1:
56
- vec = repeat(vec, "1 ... -> bs ...", bs=bs)
57
-
58
- return img, {
59
- "img_ids": img_ids.to(img.device),
60
- "txt": txt.to(img.device),
61
- "txt_ids": txt_ids.to(img.device),
62
- "y": vec.to(img.device),
63
- }
64
 
65
  def unload_current_model():
66
  global global_model
@@ -115,92 +83,7 @@ def load_resources():
115
  print("Base resources loaded successfully!")
116
 
117
  def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
118
- global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
119
-
120
- if global_model is None:
121
- return "Please select a model first.", None
122
-
123
- if seed == 0:
124
- seed = random.randint(1, 1000000)
125
- print(f"Using seed: {seed}")
126
-
127
- device = "cuda" if torch.cuda.is_available() else "cpu"
128
- torch.manual_seed(seed)
129
- torch.set_grad_enabled(False)
130
-
131
- # Calculate the number of segments needed for the desired duration
132
- segment_duration = 10 # Each segment is 10 seconds
133
- num_segments = int(np.ceil(duration / segment_duration))
134
-
135
- all_waveforms = []
136
-
137
- for i in range(num_segments):
138
- progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
139
-
140
- # Use the same seed for all segments
141
- torch.manual_seed(seed + i) # Add i to slightly vary each segment while maintaining consistency
142
-
143
- latent_size = (256, 16)
144
- conds_txt = [prompt]
145
- unconds_txt = ["low quality, gentle"]
146
- L = len(conds_txt)
147
-
148
- init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
149
-
150
- img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
151
- _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
152
-
153
- with torch.autocast(device_type='cuda'):
154
- images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
155
-
156
- images = rearrange(
157
- images[-1],
158
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
159
- h=128,
160
- w=8,
161
- ph=2,
162
- pw=2,)
163
-
164
- latents = 1 / global_vae.config.scaling_factor * images
165
- mel_spectrogram = global_vae.decode(latents).sample
166
-
167
- x_i = mel_spectrogram[0]
168
- if x_i.dim() == 4:
169
- x_i = x_i.squeeze(1)
170
- waveform = global_vocoder(x_i)
171
- waveform = waveform[0].cpu().float().detach().numpy()
172
-
173
- all_waveforms.append(waveform)
174
-
175
- # Concatenate all waveforms
176
- final_waveform = np.concatenate(all_waveforms)
177
-
178
- # Trim to exact duration
179
- sample_rate = 16000
180
- final_waveform = final_waveform[:int(duration * sample_rate)]
181
-
182
- progress(0.9, desc="Saving audio file")
183
-
184
- # Create 'generations' folder
185
- os.makedirs(GENERATIONS_DIR, exist_ok=True)
186
-
187
- # Generate filename
188
- prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
189
- model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
190
- model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
191
- base_filename = f"{prompt_part}_{seed}{model_suffix}"
192
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
193
-
194
- # Check if file exists and add numerical suffix if needed
195
- counter = 1
196
- while os.path.exists(output_path):
197
- output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
198
- counter += 1
199
-
200
- wavfile.write(output_path, sample_rate, final_waveform)
201
-
202
- progress(1.0, desc="Audio generation complete")
203
- return f"Generated with seed: {seed}", output_path
204
 
205
  # Load base resources at startup
206
  load_resources()
@@ -209,11 +92,13 @@ load_resources()
209
  model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
210
  model_choices = [os.path.basename(f) for f in model_files]
211
 
212
- # Ensure 'musicflow_b.pt' is the default choice if it exists
213
- default_model = 'musicflow_b.pt'
214
- if default_model in model_choices:
215
- model_choices.remove(default_model)
216
- model_choices.insert(0, default_model)
 
 
217
 
218
  # Set up dark grey theme
219
  theme = gr.themes.Monochrome(
@@ -234,7 +119,7 @@ with gr.Blocks(theme=theme) as iface:
234
  """)
235
 
236
  with gr.Row():
237
- model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model if default_model in model_choices else model_choices[0])
238
 
239
  with gr.Row():
240
  prompt = gr.Textbox(label="Prompt")
@@ -250,14 +135,16 @@ with gr.Blocks(theme=theme) as iface:
250
  output_audio = gr.Audio(type="filepath")
251
 
252
  def on_model_change(model_name):
253
- load_model(model_name)
 
 
 
254
 
255
  model_dropdown.change(on_model_change, inputs=[model_dropdown])
256
  generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
257
 
258
- # Load default model on startup
259
- default_model_path = os.path.join(MODELS_DIR, default_model)
260
- if os.path.exists(default_model_path):
261
  iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
262
 
263
  # Launch the interface
 
28
  GENERATIONS_DIR = "/content/generations"
29
 
30
  def prepare(t5, clip, img, prompt):
31
+ # ... (rest of the prepare function remains unchanged)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
  def unload_current_model():
34
  global global_model
 
83
  print("Base resources loaded successfully!")
84
 
85
  def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
86
+ # ... (rest of the generate_music function remains unchanged)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
  # Load base resources at startup
89
  load_resources()
 
92
  model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
93
  model_choices = [os.path.basename(f) for f in model_files]
94
 
95
+ # Ensure we have at least one model
96
+ if not model_choices:
97
+ print("No models found in the models directory. Please add at least one .pt file.")
98
+ model_choices = ["No models available"]
99
+
100
+ # Set default model
101
+ default_model = 'musicflow_b.pt' if 'musicflow_b.pt' in model_choices else model_choices[0]
102
 
103
  # Set up dark grey theme
104
  theme = gr.themes.Monochrome(
 
119
  """)
120
 
121
  with gr.Row():
122
+ model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
123
 
124
  with gr.Row():
125
  prompt = gr.Textbox(label="Prompt")
 
135
  output_audio = gr.Audio(type="filepath")
136
 
137
  def on_model_change(model_name):
138
+ if model_name != "No models available":
139
+ load_model(model_name)
140
+ else:
141
+ print("No valid model selected.")
142
 
143
  model_dropdown.change(on_model_change, inputs=[model_dropdown])
144
  generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
145
 
146
+ # Load default model on startup if it's a valid model
147
+ if default_model != "No models available":
 
148
  iface.load(lambda: load_model(default_model), inputs=None, outputs=None)
149
 
150
  # Launch the interface