Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -11,6 +11,7 @@ import numpy as np
|
|
| 11 |
import re
|
| 12 |
import requests
|
| 13 |
import time
|
|
|
|
| 14 |
|
| 15 |
# Import necessary functions and classes
|
| 16 |
from utils import load_t5, load_clap
|
|
@@ -69,9 +70,10 @@ def unload_current_model():
|
|
| 69 |
global global_model, current_model_name
|
| 70 |
if global_model is not None:
|
| 71 |
del global_model
|
| 72 |
-
torch.cuda.empty_cache()
|
| 73 |
global_model = None
|
| 74 |
current_model_name = None
|
|
|
|
|
|
|
| 75 |
|
| 76 |
def load_model(model_name, device, model_url=None):
|
| 77 |
global global_model, current_model_name
|
|
@@ -121,8 +123,7 @@ def load_model(model_name, device, model_url=None):
|
|
| 121 |
load_time = end_time - start_time
|
| 122 |
return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
|
| 123 |
except Exception as e:
|
| 124 |
-
|
| 125 |
-
current_model_name = None
|
| 126 |
print(f"Error loading model {model_name}: {str(e)}")
|
| 127 |
return f"Failed to load model: {model_name}. Error: {str(e)}"
|
| 128 |
|
|
@@ -229,6 +230,11 @@ def generate_music(prompt, seed, cfg_scale, steps, duration, device, batch_size=
|
|
| 229 |
|
| 230 |
all_waveforms.append(waveform)
|
| 231 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 232 |
# Concatenate all waveforms
|
| 233 |
final_waveform = np.concatenate(all_waveforms)
|
| 234 |
|
|
|
|
| 11 |
import re
|
| 12 |
import requests
|
| 13 |
import time
|
| 14 |
+
import gc
|
| 15 |
|
| 16 |
# Import necessary functions and classes
|
| 17 |
from utils import load_t5, load_clap
|
|
|
|
| 70 |
global global_model, current_model_name
|
| 71 |
if global_model is not None:
|
| 72 |
del global_model
|
|
|
|
| 73 |
global_model = None
|
| 74 |
current_model_name = None
|
| 75 |
+
torch.cuda.empty_cache()
|
| 76 |
+
gc.collect()
|
| 77 |
|
| 78 |
def load_model(model_name, device, model_url=None):
|
| 79 |
global global_model, current_model_name
|
|
|
|
| 123 |
load_time = end_time - start_time
|
| 124 |
return f"Successfully loaded model: {model_name} in {load_time:.2f} seconds"
|
| 125 |
except Exception as e:
|
| 126 |
+
unload_current_model()
|
|
|
|
| 127 |
print(f"Error loading model {model_name}: {str(e)}")
|
| 128 |
return f"Failed to load model: {model_name}. Error: {str(e)}"
|
| 129 |
|
|
|
|
| 230 |
|
| 231 |
all_waveforms.append(waveform)
|
| 232 |
|
| 233 |
+
# Clear some memory after each segment
|
| 234 |
+
del images, latents, mel_spectrogram, x_i
|
| 235 |
+
torch.cuda.empty_cache()
|
| 236 |
+
gc.collect()
|
| 237 |
+
|
| 238 |
# Concatenate all waveforms
|
| 239 |
final_waveform = np.concatenate(all_waveforms)
|
| 240 |
|