flosstradamus commited on
Commit
6a91f5a
·
verified ·
1 Parent(s): 0b5c0f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -17
app.py CHANGED
@@ -9,6 +9,7 @@ import glob
9
  import random
10
  import numpy as np
11
  import re
 
12
 
13
  # Import necessary functions and classes
14
  from utils import load_t5, load_clap
@@ -71,12 +72,24 @@ def unload_current_model():
71
  global_model = None
72
  current_model_name = None
73
 
74
- def load_model(model_name):
75
  global global_model, current_model_name
76
- device = "cpu" # Force CPU usage
77
 
78
  unload_current_model()
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  # Determine model size from filename
81
  if 'musicflow_b' in model_name:
82
  model_size = "base"
@@ -91,7 +104,6 @@ def load_model(model_name):
91
 
92
  print(f"Loading {model_size} model: {model_name}")
93
 
94
- model_path = os.path.join(MODELS_DIR, model_name)
95
  global_model = build_model(model_size).to(device)
96
 
97
  try:
@@ -106,11 +118,9 @@ def load_model(model_name):
106
  print(f"Error loading model {model_name}: {str(e)}")
107
  return f"Failed to load model: {model_name}. Error: {str(e)}"
108
 
109
- def load_resources():
110
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
111
 
112
- device = "cpu"
113
-
114
  print("Loading T5 and CLAP models...")
115
  global_t5 = load_t5(device, max_length=256)
116
  global_clap = load_clap(device, max_length=256)
@@ -124,7 +134,7 @@ def load_resources():
124
 
125
  print("Base resources loaded successfully!")
126
 
127
- def generate_music(prompt, seed, cfg_scale, steps, duration, batch_size=4, progress=gr.Progress()):
128
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
129
 
130
  if global_model is None:
@@ -134,7 +144,6 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, batch_size=4, progr
134
  seed = random.randint(1, 1000000)
135
  print(f"Using seed: {seed}")
136
 
137
- device = "cpu"
138
  torch.manual_seed(seed)
139
  torch.set_grad_enabled(False)
140
 
@@ -226,9 +235,6 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, batch_size=4, progr
226
  progress(1.0, desc="Audio generation complete")
227
  return f"Generated with seed: {seed}", output_path
228
 
229
- # Load base resources at startup
230
- load_resources()
231
-
232
  # Get list of .pt files in the models directory
233
  model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
234
  model_choices = [os.path.basename(f) for f in model_files]
@@ -258,11 +264,14 @@ with gr.Blocks(theme=theme) as iface:
258
  <div style="text-align: center;">
259
  <h1>FluxMusic Generator</h1>
260
  <p>Generate music based on text prompts using FluxMusic model.</p>
 
261
  </div>
262
  """)
263
 
264
  with gr.Row():
265
  model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
 
 
266
  load_model_button = gr.Button("Load Model")
267
  model_status = gr.Textbox(label="Model Status", value="No model loaded")
268
 
@@ -279,15 +288,18 @@ with gr.Blocks(theme=theme) as iface:
279
  output_status = gr.Textbox(label="Generation Status")
280
  output_audio = gr.Audio(type="filepath")
281
 
282
- def on_load_model_click(model_name):
283
- result = load_model(model_name)
 
 
 
284
  return result
285
 
286
- load_model_button.click(on_load_model_click, inputs=[model_dropdown], outputs=[model_status])
287
- generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration], outputs=[output_status, output_audio])
288
 
289
- # Load default model on startup
290
- iface.load(lambda: on_load_model_click(default_model), inputs=None, outputs=None)
291
 
292
  # Launch the interface
293
  iface.launch()
 
9
  import random
10
  import numpy as np
11
  import re
12
+ import requests
13
 
14
  # Import necessary functions and classes
15
  from utils import load_t5, load_clap
 
72
  global_model = None
73
  current_model_name = None
74
 
75
+ def load_model(model_name, device, model_url=None):
76
  global global_model, current_model_name
 
77
 
78
  unload_current_model()
79
 
80
+ if model_url:
81
+ print(f"Downloading model from URL: {model_url}")
82
+ response = requests.get(model_url)
83
+ if response.status_code == 200:
84
+ model_path = os.path.join(MODELS_DIR, "downloaded_model.pt")
85
+ with open(model_path, 'wb') as f:
86
+ f.write(response.content)
87
+ model_name = "downloaded_model.pt"
88
+ else:
89
+ return f"Failed to download model from URL: {model_url}"
90
+ else:
91
+ model_path = os.path.join(MODELS_DIR, model_name)
92
+
93
  # Determine model size from filename
94
  if 'musicflow_b' in model_name:
95
  model_size = "base"
 
104
 
105
  print(f"Loading {model_size} model: {model_name}")
106
 
 
107
  global_model = build_model(model_size).to(device)
108
 
109
  try:
 
118
  print(f"Error loading model {model_name}: {str(e)}")
119
  return f"Failed to load model: {model_name}. Error: {str(e)}"
120
 
121
+ def load_resources(device):
122
  global global_t5, global_clap, global_vae, global_vocoder, global_diffusion
123
 
 
 
124
  print("Loading T5 and CLAP models...")
125
  global_t5 = load_t5(device, max_length=256)
126
  global_clap = load_clap(device, max_length=256)
 
134
 
135
  print("Base resources loaded successfully!")
136
 
137
+ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=4, progress=gr.Progress()):
138
  global global_model, global_t5, global_clap, global_vae, global_vocoder, global_diffusion
139
 
140
  if global_model is None:
 
144
  seed = random.randint(1, 1000000)
145
  print(f"Using seed: {seed}")
146
 
 
147
  torch.manual_seed(seed)
148
  torch.set_grad_enabled(False)
149
 
 
235
  progress(1.0, desc="Audio generation complete")
236
  return f"Generated with seed: {seed}", output_path
237
 
 
 
 
238
  # Get list of .pt files in the models directory
239
  model_files = glob.glob(os.path.join(MODELS_DIR, "*.pt"))
240
  model_choices = [os.path.basename(f) for f in model_files]
 
264
  <div style="text-align: center;">
265
  <h1>FluxMusic Generator</h1>
266
  <p>Generate music based on text prompts using FluxMusic model.</p>
267
+ <p>Feel free to clone this space and run on GPU locally or on Hugging Face.</p>
268
  </div>
269
  """)
270
 
271
  with gr.Row():
272
  model_dropdown = gr.Dropdown(choices=model_choices, label="Select Model", value=default_model)
273
+ model_url = gr.Textbox(label="Or enter model URL")
274
+ device_choice = gr.Radio(["cpu", "cuda"], label="Device", value="cpu")
275
  load_model_button = gr.Button("Load Model")
276
  model_status = gr.Textbox(label="Model Status", value="No model loaded")
277
 
 
288
  output_status = gr.Textbox(label="Generation Status")
289
  output_audio = gr.Audio(type="filepath")
290
 
291
+ def on_load_model_click(model_name, device, url):
292
+ if url:
293
+ result = load_model(None, device, model_url=url)
294
+ else:
295
+ result = load_model(model_name, device)
296
  return result
297
 
298
+ load_model_button.click(on_load_model_click, inputs=[model_dropdown, device_choice, model_url], outputs=[model_status])
299
+ generate_button.click(generate_music, inputs=[prompt, seed, cfg_scale, steps, duration, device_choice], outputs=[output_status, output_audio])
300
 
301
+ # Load default model and resources on startup
302
+ iface.load(lambda: (load_resources("cpu"), on_load_model_click(default_model, "cpu", None)), inputs=None, outputs=None)
303
 
304
  # Launch the interface
305
  iface.launch()