thecollabagepatch commited on
Commit
3e6fdea
Β·
1 Parent(s): 108943b

trying to solve model loading

Browse files
Files changed (1) hide show
  1. app.py +99 -50
app.py CHANGED
@@ -25,20 +25,21 @@ model_lock = threading.Lock()
25
 
26
  @contextmanager
27
  def resource_cleanup():
28
- """Context manager to ensure proper cleanup of GPU resources."""
29
  try:
30
  yield
31
  finally:
 
32
  if torch.cuda.is_available():
33
  torch.cuda.synchronize()
34
- torch.cuda.empty_cache()
35
- gc.collect()
36
 
37
  def load_stable_audio_model():
38
  """Load stable-audio-open-small model if not already loaded."""
39
  with model_lock:
40
  if 'stable_audio_model' not in model_cache:
41
  print("πŸ”„ Loading stable-audio-open-small model...")
 
42
 
43
  # Authenticate with HF
44
  hf_token = os.getenv('HF_TOKEN')
@@ -53,10 +54,36 @@ def load_stable_audio_model():
53
  if device == "cuda":
54
  model = model.half()
55
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  model_cache['stable_audio_model'] = model
57
  model_cache['stable_audio_config'] = config
58
  model_cache['stable_audio_device'] = device
59
- print(f"βœ… Stable Audio model loaded on {device}")
 
 
60
 
61
  return (model_cache['stable_audio_model'],
62
  model_cache['stable_audio_config'],
@@ -66,7 +93,12 @@ def load_stable_audio_model():
66
  def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
67
  """Generate a BPM-aware loop using stable-audio-open-small"""
68
  try:
 
 
 
 
69
  model, config, device = load_stable_audio_model()
 
70
 
71
  # Calculate loop duration based on BPM and bars
72
  seconds_per_beat = 60.0 / bpm
@@ -95,6 +127,7 @@ def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
95
  print(f" Seed: {seed}")
96
 
97
  # Prepare conditioning
 
98
  conditioning = [{
99
  "prompt": enhanced_prompt,
100
  "seconds_total": 12 # Model generates 12s max
@@ -104,54 +137,70 @@ def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
104
  "prompt": negative_prompt,
105
  "seconds_total": 12
106
  }]
 
107
 
108
- start_time = time.time()
 
109
 
110
- with resource_cleanup():
111
- if device == "cuda":
112
- torch.cuda.empty_cache()
113
-
114
- with torch.cuda.amp.autocast(enabled=(device == "cuda")):
115
- output = generate_diffusion_cond(
116
- model,
117
- steps=8, # Fast generation
118
- cfg_scale=1.0, # Good balance for loops
119
- conditioning=conditioning,
120
- negative_conditioning=negative_conditioning,
121
- sample_size=config["sample_size"],
122
- sampler_type="pingpong",
123
- device=device,
124
- seed=seed
125
- )
126
-
127
- generation_time = time.time() - start_time
128
-
129
- # Post-process audio
130
- output = rearrange(output, "b d n -> d (b n)") # (2, N) stereo
131
- output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1)
132
-
133
- # Extract the loop portion
134
- sample_rate = config["sample_rate"]
135
- loop_samples = int(target_loop_duration * sample_rate)
136
- available_samples = output.shape[1]
137
-
138
- if loop_samples > available_samples:
139
- loop_samples = available_samples
140
- actual_duration = available_samples / sample_rate
141
- print(f"⚠️ Requested {target_loop_duration:.2f}s, got {actual_duration:.2f}s")
142
-
143
- # Extract loop from beginning (cleanest beat alignment)
144
- loop_output = output[:, :loop_samples]
145
- loop_output_int16 = loop_output.mul(32767).to(torch.int16).cpu()
146
-
147
- # Save to temporary file
148
- loop_filename = f"loop_{loop_type}_{bpm}bpm_{bars}bars_{seed}.wav"
149
- torchaudio.save(loop_filename, loop_output_int16, sample_rate)
150
-
151
- actual_duration = loop_samples / sample_rate
152
- print(f"βœ… {loop_type.title()} loop generated: {actual_duration:.2f}s in {generation_time:.2f}s")
153
-
154
- return loop_filename, f"Generated {actual_duration:.2f}s {loop_type} loop at {bpm}bpm ({bars} bars)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
  except Exception as e:
157
  print(f"❌ Generation error: {str(e)}")
 
25
 
26
  @contextmanager
27
  def resource_cleanup():
28
+ """Lightweight context manager - let zerogpu handle memory management"""
29
  try:
30
  yield
31
  finally:
32
+ # Minimal cleanup - let zerogpu handle the heavy lifting
33
  if torch.cuda.is_available():
34
  torch.cuda.synchronize()
35
+ # Removed aggressive empty_cache() and gc.collect() calls
 
36
 
37
  def load_stable_audio_model():
38
  """Load stable-audio-open-small model if not already loaded."""
39
  with model_lock:
40
  if 'stable_audio_model' not in model_cache:
41
  print("πŸ”„ Loading stable-audio-open-small model...")
42
+ load_start = time.time()
43
 
44
  # Authenticate with HF
45
  hf_token = os.getenv('HF_TOKEN')
 
54
  if device == "cuda":
55
  model = model.half()
56
 
57
+ load_time = time.time() - load_start
58
+ print(f"βœ… Model loaded on {device} in {load_time:.2f}s")
59
+
60
+ # Aggressive model persistence - warm up with dummy generation
61
+ print("πŸ”₯ Warming up model...")
62
+ warmup_start = time.time()
63
+ try:
64
+ dummy_conditioning = [{"prompt": "test", "seconds_total": 12}]
65
+ with torch.no_grad():
66
+ _ = generate_diffusion_cond(
67
+ model,
68
+ steps=1, # Minimal steps for warmup
69
+ cfg_scale=1.0,
70
+ conditioning=dummy_conditioning,
71
+ sample_size=config["sample_size"],
72
+ sampler_type="pingpong",
73
+ device=device,
74
+ seed=42
75
+ )
76
+ warmup_time = time.time() - warmup_start
77
+ print(f"πŸ”₯ Model warmed up in {warmup_time:.2f}s")
78
+ except Exception as e:
79
+ print(f"⚠️ Warmup failed (but continuing): {e}")
80
+
81
  model_cache['stable_audio_model'] = model
82
  model_cache['stable_audio_config'] = config
83
  model_cache['stable_audio_device'] = device
84
+ print(f"βœ… Stable Audio model ready for fast generation!")
85
+ else:
86
+ print("♻️ Using cached model (should be fast!)")
87
 
88
  return (model_cache['stable_audio_model'],
89
  model_cache['stable_audio_config'],
 
93
  def generate_stable_audio_loop(prompt, loop_type, bpm, bars, seed=-1):
94
  """Generate a BPM-aware loop using stable-audio-open-small"""
95
  try:
96
+ total_start = time.time()
97
+
98
+ # Model loading timing
99
+ load_start = time.time()
100
  model, config, device = load_stable_audio_model()
101
+ load_time = time.time() - load_start
102
 
103
  # Calculate loop duration based on BPM and bars
104
  seconds_per_beat = 60.0 / bpm
 
127
  print(f" Seed: {seed}")
128
 
129
  # Prepare conditioning
130
+ conditioning_start = time.time()
131
  conditioning = [{
132
  "prompt": enhanced_prompt,
133
  "seconds_total": 12 # Model generates 12s max
 
137
  "prompt": negative_prompt,
138
  "seconds_total": 12
139
  }]
140
+ conditioning_time = time.time() - conditioning_start
141
 
142
+ # Generation timing
143
+ generation_start = time.time()
144
 
145
+ # Removed aggressive resource cleanup wrapper
146
+ # Clear GPU cache once before generation (not after)
147
+ if device == "cuda":
148
+ torch.cuda.empty_cache()
149
+
150
+ with torch.cuda.amp.autocast(enabled=(device == "cuda")):
151
+ output = generate_diffusion_cond(
152
+ model,
153
+ steps=8, # Fast generation
154
+ cfg_scale=1.0, # Good balance for loops
155
+ conditioning=conditioning,
156
+ negative_conditioning=negative_conditioning,
157
+ sample_size=config["sample_size"],
158
+ sampler_type="pingpong",
159
+ device=device,
160
+ seed=seed
161
+ )
162
+
163
+ generation_time = time.time() - generation_start
164
+
165
+ # Post-processing timing
166
+ postproc_start = time.time()
167
+
168
+ # Post-process audio
169
+ output = rearrange(output, "b d n -> d (b n)") # (2, N) stereo
170
+ output = output.to(torch.float32).div(torch.max(torch.abs(output))).clamp(-1, 1)
171
+
172
+ # Extract the loop portion
173
+ sample_rate = config["sample_rate"]
174
+ loop_samples = int(target_loop_duration * sample_rate)
175
+ available_samples = output.shape[1]
176
+
177
+ if loop_samples > available_samples:
178
+ loop_samples = available_samples
179
+ actual_duration = available_samples / sample_rate
180
+ print(f"⚠️ Requested {target_loop_duration:.2f}s, got {actual_duration:.2f}s")
181
+
182
+ # Extract loop from beginning (cleanest beat alignment)
183
+ loop_output = output[:, :loop_samples]
184
+ loop_output_int16 = loop_output.mul(32767).to(torch.int16).cpu()
185
+
186
+ # Save to temporary file
187
+ loop_filename = f"loop_{loop_type}_{bpm}bpm_{bars}bars_{seed}.wav"
188
+ torchaudio.save(loop_filename, loop_output_int16, sample_rate)
189
+
190
+ postproc_time = time.time() - postproc_start
191
+ total_time = time.time() - total_start
192
+ actual_duration = loop_samples / sample_rate
193
+
194
+ # Detailed timing breakdown
195
+ print(f"⏱️ Timing breakdown:")
196
+ print(f" Model load: {load_time:.2f}s")
197
+ print(f" Conditioning: {conditioning_time:.3f}s")
198
+ print(f" Generation: {generation_time:.2f}s")
199
+ print(f" Post-processing: {postproc_time:.3f}s")
200
+ print(f" Total: {total_time:.2f}s")
201
+ print(f"βœ… {loop_type.title()} loop: {actual_duration:.2f}s audio in {total_time:.2f}s")
202
+
203
+ return loop_filename, f"Generated {actual_duration:.2f}s {loop_type} loop at {bpm}bpm ({bars} bars) in {total_time:.2f}s"
204
 
205
  except Exception as e:
206
  print(f"❌ Generation error: {str(e)}")