chaojiemao committed on
Commit
06b3202
·
verified ·
1 Parent(s): 202a7a7

Update ace_inference.py

Browse files
Files changed (1) hide show
  1. ace_inference.py +15 -9
ace_inference.py CHANGED
@@ -92,6 +92,9 @@ class RefinerInference(DiffusionInference):
92
  if cfg.MODEL.have('DIFFUSION') else None
93
  self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
94
  assert self.diffusion is not None
 
 
 
95
 
96
  @torch.no_grad()
97
  def encode_first_stage(self, x, **kwargs):
@@ -153,7 +156,7 @@ class RefinerInference(DiffusionInference):
153
  noise.append(noise_)
154
  noise, x_shapes = pack_imagelist_into_tensor(noise)
155
  if reverse_scale > 0:
156
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
157
  x_samples = [x.unsqueeze(0) for x in x_samples]
158
  x_start = self.encode_first_stage(x_samples, **kwargs)
159
  if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
@@ -163,7 +166,7 @@ class RefinerInference(DiffusionInference):
163
  else:
164
  x_start = None
165
  # cond stage
166
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
167
  function_name, dtype = self.get_function_info(self.cond_stage_model)
168
  with torch.autocast('cuda',
169
  enabled=dtype == 'float16',
@@ -176,7 +179,7 @@ class RefinerInference(DiffusionInference):
176
  skip_loaded=True)
177
 
178
 
179
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
180
  # UNet use input n_prompt
181
  function_name, dtype = self.get_function_info(
182
  self.diffusion_model)
@@ -207,7 +210,7 @@ class RefinerInference(DiffusionInference):
207
  if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
208
  'diffusion_model',
209
  skip_loaded=True)
210
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
211
  x_samples = self.decode_first_stage(latent)
212
  if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
213
  'first_stage_model',
@@ -279,6 +282,9 @@ class ACEInference(DiffusionInference):
279
  self.size_factor = cfg.get('SIZE_FACTOR', 8)
280
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
281
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
 
 
 
282
 
283
  @torch.no_grad()
284
  def encode_first_stage(self, x, **kwargs):
@@ -390,7 +396,7 @@ class ACEInference(DiffusionInference):
390
  if use_ace and (not is_txt_image or refiner_scale <= 0):
391
  ctx, null_ctx = {}, {}
392
  # Get Noise Shape
393
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
394
  x = self.encode_first_stage(image)
395
  if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
396
  'first_stage_model',
@@ -408,7 +414,7 @@ class ACEInference(DiffusionInference):
408
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
409
 
410
  # Encode Prompt
411
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
412
  function_name, dtype = self.get_function_info(self.cond_stage_model)
413
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
414
  function_name)(prompt)
@@ -425,7 +431,7 @@ class ACEInference(DiffusionInference):
425
  null_ctx['crossattn'] = null_cont
426
 
427
  # Encode Edit Images
428
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
429
  edit_image = [to_device(i, strict=False) for i in edit_image]
430
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
431
  e_img, e_mask = [], []
@@ -443,7 +449,7 @@ class ACEInference(DiffusionInference):
443
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
444
 
445
  # Diffusion Process
446
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
447
  function_name, dtype = self.get_function_info(self.diffusion_model)
448
  with torch.autocast('cuda',
449
  enabled=dtype in ('float16', 'bfloat16'),
@@ -489,7 +495,7 @@ class ACEInference(DiffusionInference):
489
  skip_loaded=False)
490
 
491
  # Decode to Pixel Space
492
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
493
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
494
  x_samples = self.decode_first_stage(samples)
495
  if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
 
92
  if cfg.MODEL.have('DIFFUSION') else None
93
  self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
94
  assert self.diffusion is not None
95
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
96
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
97
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
98
 
99
  @torch.no_grad()
100
  def encode_first_stage(self, x, **kwargs):
 
156
  noise.append(noise_)
157
  noise, x_shapes = pack_imagelist_into_tensor(noise)
158
  if reverse_scale > 0:
159
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
160
  x_samples = [x.unsqueeze(0) for x in x_samples]
161
  x_start = self.encode_first_stage(x_samples, **kwargs)
162
  if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
 
166
  else:
167
  x_start = None
168
  # cond stage
169
+ if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
170
  function_name, dtype = self.get_function_info(self.cond_stage_model)
171
  with torch.autocast('cuda',
172
  enabled=dtype == 'float16',
 
179
  skip_loaded=True)
180
 
181
 
182
+ if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
183
  # UNet use input n_prompt
184
  function_name, dtype = self.get_function_info(
185
  self.diffusion_model)
 
210
  if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
211
  'diffusion_model',
212
  skip_loaded=True)
213
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
214
  x_samples = self.decode_first_stage(latent)
215
  if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
216
  'first_stage_model',
 
282
  self.size_factor = cfg.get('SIZE_FACTOR', 8)
283
  self.decoder_bias = cfg.get('DECODER_BIAS', 0)
284
  self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
285
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
286
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
287
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
288
 
289
  @torch.no_grad()
290
  def encode_first_stage(self, x, **kwargs):
 
396
  if use_ace and (not is_txt_image or refiner_scale <= 0):
397
  ctx, null_ctx = {}, {}
398
  # Get Noise Shape
399
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
400
  x = self.encode_first_stage(image)
401
  if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
402
  'first_stage_model',
 
414
  ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
415
 
416
  # Encode Prompt
417
+ if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
418
  function_name, dtype = self.get_function_info(self.cond_stage_model)
419
  cont, cont_mask = getattr(get_model(self.cond_stage_model),
420
  function_name)(prompt)
 
431
  null_ctx['crossattn'] = null_cont
432
 
433
  # Encode Edit Images
434
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
435
  edit_image = [to_device(i, strict=False) for i in edit_image]
436
  edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
437
  e_img, e_mask = [], []
 
449
  null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
450
 
451
  # Diffusion Process
452
+ if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
453
  function_name, dtype = self.get_function_info(self.diffusion_model)
454
  with torch.autocast('cuda',
455
  enabled=dtype in ('float16', 'bfloat16'),
 
495
  skip_loaded=False)
496
 
497
  # Decode to Pixel Space
498
+ if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
499
  samples = unpack_tensor_into_imagelist(latent, x_shapes)
500
  x_samples = self.decode_first_stage(samples)
501
  if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,