fffiloni committed
Commit b04fd93 · verified · 1 Parent(s): bc12e9f

Update app.py

Files changed (1):
  app.py +12 -3
app.py CHANGED
@@ -149,6 +149,10 @@ models_rbm = core.Models(
 models_rbm.generator.eval().requires_grad_(False)
 
 def infer(style_description, ref_style_file, caption):
+    # Ensure all models are moved back to the correct device
+    core_b.generator.to(device)
+    models_rbm.generator.to(device)
+
     clear_gpu_cache() # Clear cache before inference
 
     height=1024
@@ -181,10 +185,10 @@ def infer(style_description, ref_style_file, caption):
     unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
 
     if low_vram:
-        # The sampling process uses more vram, so we offload everything except two modules to the cpu.
+        # Offload non-essential models to CPU for memory savings
         models_to(models_rbm, device="cpu", excepts=["generator", "previewer"])
 
-    # Stage C reverse process.
+    # Stage C reverse process
     with torch.cuda.amp.autocast(): # Use mixed precision
         sampling_c = extras.gdf.sample(
             models_rbm.generator, conditions, stage_c_latent_shape,
@@ -202,7 +206,10 @@ def infer(style_description, ref_style_file, caption):
 
     clear_gpu_cache() # Clear cache between stages
 
-    # Stage B reverse process.
+    # Ensure all models are on the right device again
+    models_b.generator.to(device)
+
+    # Stage B reverse process
     with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
         conditions_b['effnet'] = sampled_c
         unconditions_b['effnet'] = torch.zeros_like(sampled_c)
@@ -215,6 +222,7 @@ def infer(style_description, ref_style_file, caption):
     sampled_b = sampled_b
     sampled = models_b.stage_a.decode(sampled_b).float()
 
+    # Post-process and save the image
     sampled = torch.cat([
         torch.nn.functional.interpolate(ref_style.cpu(), size=(height, width)),
         sampled.cpu(),
@@ -234,6 +242,7 @@ def infer(style_description, ref_style_file, caption):
 
     return output_file # Return the path to the saved image
 
+
 import gradio as gr
 
 gr.Interface(
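Note on the fix: on low-VRAM runs, models_to() leaves most modules on the CPU after a request finishes, so a second call to infer() would feed CUDA inputs into CPU weights and fail with a device mismatch. The new .to(device) calls restore the generators at the start of each request. A minimal sketch of that pattern, reusing the module-level names from the diff (core_b, models_rbm, models_b, device); this is an illustration, not the repository's exact code:

def restore_generators(device):
    # Undo a previous low-VRAM offload before sampling. Without this,
    # the run after a low_vram pass raises "Expected all tensors to be
    # on the same device".
    core_b.generator.to(device)      # Stage B conditioning path
    models_rbm.generator.to(device)  # Stage C generator
    models_b.generator.to(device)    # Stage B generator (also restored just before Stage B sampling)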
 
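The offload side of that round trip is models_to(models_rbm, device="cpu", excepts=["generator", "previewer"]). A sketch of what such a helper typically looks like, assuming the models container exposes its submodules as attributes (the actual helper used by this Space may differ):

import torch.nn as nn

def models_to(models, device="cpu", excepts=None):
    # Move every nn.Module attribute of the container to `device`,
    # skipping the names in `excepts` so the modules still needed for
    # sampling (here: generator and previewer) stay on the GPU.
    excepts = excepts or []
    for name, value in vars(models).items():
        if name in excepts:
            continue
        if isinstance(value, nn.Module):
            value.to(device)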
 
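clear_gpu_cache() is called once before inference and again between Stage C and Stage B. A common minimal implementation of such a helper (an assumption about app.py's version, which is not shown in this diff):

import gc
import torch

def clear_gpu_cache():
    # Collect unreachable Python objects first so their CUDA tensors are
    # freed, then return PyTorch's cached blocks to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()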
 
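The diff cuts off where the gr.Interface( call begins. For orientation only, a hypothetical wiring of infer() into Gradio that matches its three parameters; the actual components and labels in app.py are not visible here:

import gradio as gr

demo = gr.Interface(
    fn=infer,
    inputs=[
        gr.Textbox(label="Style description"),
        gr.Image(type="filepath", label="Reference style image"),
        gr.Textbox(label="Caption"),
    ],
    outputs=gr.Image(label="Result"),  # infer() returns the saved image path
)

if __name__ == "__main__":
    demo.launch()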