flosstradamus committed
Commit 923a82b · verified · 1 Parent(s): 4f4644c

Update app.py

Files changed (1):
  1. app.py +135 -38
app.py CHANGED
@@ -39,54 +39,158 @@ def unload_current_model():
     global_model = None
 
 def download_model(url):
-    response = requests.get(url)
+    response = requests.get(url, stream=True)
     if response.status_code == 200:
         filename = os.path.basename(urlparse(url).path)
         model_path = os.path.join("/tmp", filename)
         with open(model_path, "wb") as f:
-            f.write(response.content)
+            for chunk in response.iter_content(chunk_size=8192):
+                f.write(chunk)
         return model_path
     else:
         raise Exception(f"Failed to download model from {url}")
 
 def load_model(url):
     global global_model
+    try:
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        unload_current_model()
+
+        print(f"Downloading model from {url}")
+        model_path = download_model(url)
+
+        # Determine model size from filename
+        filename = os.path.basename(model_path)
+        if 'musicflow_b' in filename:
+            model_size = "base"
+        elif 'musicflow_g' in filename:
+            model_size = "giant"
+        elif 'musicflow_l' in filename:
+            model_size = "large"
+        elif 'musicflow_s' in filename:
+            model_size = "small"
+        else:
+            model_size = "base"  # Default to base if unrecognized
+
+        print(f"Loading {model_size} model: {filename}")
+
+        global_model = build_model(model_size).to(device)
+        state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
+        global_model.load_state_dict(state_dict['ema'])
+        global_model.eval()
+        global_model.model_path = model_path
+        print("Model loaded successfully")
+        return "Model loaded successfully"
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        return f"Error loading model: {str(e)}"
+
+def load_resources():
+    global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
+
     device = "cuda" if torch.cuda.is_available() else "cpu"
 
-    unload_current_model()
+    print("Loading T5 and CLAP models...")
+    global_t5 = load_t5(device, max_length=256)
+    global_clap = load_clap(device, max_length=256)
 
-    print(f"Downloading model from {url}")
-    model_path = download_model(url)
+    print("Loading VAE and vocoder...")
+    global_vae = AutoencoderKL.from_pretrained('cvssp/audioldm2', subfolder="vae").to(device)
+    global_vocoder = SpeechT5HifiGan.from_pretrained('cvssp/audioldm2', subfolder="vocoder").to(device)
 
-    # Determine model size from filename
-    filename = os.path.basename(model_path)
-    if 'musicflow_b' in filename:
-        model_size = "base"
-    elif 'musicflow_g' in filename:
-        model_size = "giant"
-    elif 'musicflow_l' in filename:
-        model_size = "large"
-    elif 'musicflow_s' in filename:
-        model_size = "small"
-    else:
-        model_size = "base"  # Default to base if unrecognized
+    print("Initializing diffusion...")
+    global_diffusion = RF()
+
+    print("Base resources loaded successfully!")
+
+def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
+    global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
+
+    if global_model is None:
+        return "Please load a model first.", None
 
-    print(f"Loading {model_size} model: {filename}")
+    if seed == 0:
+        seed = random.randint(1, 1000000)
+    print(f"Using seed: {seed}")
 
-    global_model = build_model(model_size).to(device)
-    state_dict = torch.load(model_path, map_location=lambda storage, loc: storage, weights_only=True)
-    global_model.load_state_dict(state_dict['ema'])
-    global_model.eval()
-    global_model.model_path = model_path
-    print("Model loaded successfully")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    torch.manual_seed(seed)
+    torch.set_grad_enabled(False)
 
-def load_resources():
-    # ... [The load_resources function remains unchanged]
-    pass
+    # Calculate the number of segments needed for the desired duration
+    segment_duration = 10  # Each segment is 10 seconds
+    num_segments = int(np.ceil(duration / segment_duration))
 
-def generate_music(prompt, seed, cfg_scale, steps, duration, progress=gr.Progress()):
-    # ... [The generate_music function remains unchanged]
-    pass
+    all_waveforms = []
+
+    for i in range(num_segments):
+        progress(i / num_segments, desc=f"Generating segment {i+1}/{num_segments}")
+
+        # Use the same seed for all segments
+        torch.manual_seed(seed + i)  # Add i to slightly vary each segment while maintaining consistency
+
+        latent_size = (256, 16)
+        conds_txt = [prompt]
+        unconds_txt = ["low quality, gentle"]
+        L = len(conds_txt)
+
+        init_noise = torch.randn(L, 8, latent_size[0], latent_size[1]).to(device)
+
+        img, conds = prepare(global_t5, global_clap, init_noise, conds_txt)
+        _, unconds = prepare(global_t5, global_clap, init_noise, unconds_txt)
+
+        with torch.autocast(device_type='cuda'):
+            images = global_diffusion.sample_with_xps(global_model, img, conds=conds, null_cond=unconds, sample_steps=steps, cfg=cfg_scale)
+
+        images = rearrange(
+            images[-1],
+            "b (h w) (c ph pw) -> b c (h ph) (w pw)",
+            h=128,
+            w=8,
+            ph=2,
+            pw=2,)
+
+        latents = 1 / global_vae.config.scaling_factor * images
+        mel_spectrogram = global_vae.decode(latents).sample
+
+        x_i = mel_spectrogram[0]
+        if x_i.dim() == 4:
+            x_i = x_i.squeeze(1)
+        waveform = global_vocoder(x_i)
+        waveform = waveform[0].cpu().float().detach().numpy()
+
+        all_waveforms.append(waveform)
+
+    # Concatenate all waveforms
+    final_waveform = np.concatenate(all_waveforms)
+
+    # Trim to exact duration
+    sample_rate = 16000
+    final_waveform = final_waveform[:int(duration * sample_rate)]
+
+    progress(0.9, desc="Saving audio file")
+
+    # Create 'generations' folder
+    os.makedirs(GENERATIONS_DIR, exist_ok=True)
+
+    # Generate filename
+    prompt_part = re.sub(r'[^\w\s-]', '', prompt)[:10].strip().replace(' ', '_')
+    model_name = os.path.splitext(os.path.basename(global_model.model_path))[0]
+    model_suffix = '_mf_b' if model_name == 'musicflow_b' else f'_{model_name}'
+    base_filename = f"{prompt_part}_{seed}{model_suffix}"
+    output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}.wav")
+
+    # Check if file exists and add numerical suffix if needed
+    counter = 1
+    while os.path.exists(output_path):
+        output_path = os.path.join(GENERATIONS_DIR, f"{base_filename}_{counter}.wav")
+        counter += 1
+
+    wavfile.write(output_path, sample_rate, final_waveform)
+
+    progress(1.0, desc="Audio generation complete")
+    return f"Generated with seed: {seed}", output_path
 
 # Load base resources at startup
 load_resources()
@@ -126,14 +230,7 @@ with gr.Blocks(theme=theme) as iface:
     output_status = gr.Textbox(label="Status")
     output_audio = gr.Audio(type="filepath")
 
-    def on_load_model(url):
-        try:
-            load_model(url)
-            return "Model loaded successfully"
-        except Exception as e:
-            return f"Error loading model: {str(e)}"
-
-    load_model_button.click(on_load_model, inputs=[model_url], outputs=[output_status])
+    load_model_button.click(load_model, inputs=[model_url], outputs=[output_status])
     generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
 
 # Launch the interface