Spaces:
Running on Zero
Update ace_inference.py
Browse files — ace_inference.py (+15 lines, −9 lines)
ace_inference.py
CHANGED
@@ -92,6 +92,9 @@ class RefinerInference(DiffusionInference):
|
|
92 |
if cfg.MODEL.have('DIFFUSION') else None
|
93 |
self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
|
94 |
assert self.diffusion is not None
|
|
|
|
|
|
|
95 |
|
96 |
@torch.no_grad()
|
97 |
def encode_first_stage(self, x, **kwargs):
|
@@ -153,7 +156,7 @@ class RefinerInference(DiffusionInference):
|
|
153 |
noise.append(noise_)
|
154 |
noise, x_shapes = pack_imagelist_into_tensor(noise)
|
155 |
if reverse_scale > 0:
|
156 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
157 |
x_samples = [x.unsqueeze(0) for x in x_samples]
|
158 |
x_start = self.encode_first_stage(x_samples, **kwargs)
|
159 |
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
@@ -163,7 +166,7 @@ class RefinerInference(DiffusionInference):
|
|
163 |
else:
|
164 |
x_start = None
|
165 |
# cond stage
|
166 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
167 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
168 |
with torch.autocast('cuda',
|
169 |
enabled=dtype == 'float16',
|
@@ -176,7 +179,7 @@ class RefinerInference(DiffusionInference):
|
|
176 |
skip_loaded=True)
|
177 |
|
178 |
|
179 |
-
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
180 |
# UNet use input n_prompt
|
181 |
function_name, dtype = self.get_function_info(
|
182 |
self.diffusion_model)
|
@@ -207,7 +210,7 @@ class RefinerInference(DiffusionInference):
|
|
207 |
if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
|
208 |
'diffusion_model',
|
209 |
skip_loaded=True)
|
210 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
211 |
x_samples = self.decode_first_stage(latent)
|
212 |
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
213 |
'first_stage_model',
|
@@ -279,6 +282,9 @@ class ACEInference(DiffusionInference):
|
|
279 |
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
280 |
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
|
281 |
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
|
|
|
|
|
|
|
282 |
|
283 |
@torch.no_grad()
|
284 |
def encode_first_stage(self, x, **kwargs):
|
@@ -390,7 +396,7 @@ class ACEInference(DiffusionInference):
|
|
390 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
391 |
ctx, null_ctx = {}, {}
|
392 |
# Get Noise Shape
|
393 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
394 |
x = self.encode_first_stage(image)
|
395 |
if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
396 |
'first_stage_model',
|
@@ -408,7 +414,7 @@ class ACEInference(DiffusionInference):
|
|
408 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
409 |
|
410 |
# Encode Prompt
|
411 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
412 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
413 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
414 |
function_name)(prompt)
|
@@ -425,7 +431,7 @@ class ACEInference(DiffusionInference):
|
|
425 |
null_ctx['crossattn'] = null_cont
|
426 |
|
427 |
# Encode Edit Images
|
428 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
429 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
430 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
431 |
e_img, e_mask = [], []
|
@@ -443,7 +449,7 @@ class ACEInference(DiffusionInference):
|
|
443 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
444 |
|
445 |
# Diffusion Process
|
446 |
-
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
447 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
448 |
with torch.autocast('cuda',
|
449 |
enabled=dtype in ('float16', 'bfloat16'),
|
@@ -489,7 +495,7 @@ class ACEInference(DiffusionInference):
|
|
489 |
skip_loaded=False)
|
490 |
|
491 |
# Decode to Pixel Space
|
492 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
493 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
494 |
x_samples = self.decode_first_stage(samples)
|
495 |
if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
|
|
92 |
if cfg.MODEL.have('DIFFUSION') else None
|
93 |
self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
|
94 |
assert self.diffusion is not None
|
95 |
+
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
96 |
+
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
97 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
98 |
|
99 |
@torch.no_grad()
|
100 |
def encode_first_stage(self, x, **kwargs):
|
|
|
156 |
noise.append(noise_)
|
157 |
noise, x_shapes = pack_imagelist_into_tensor(noise)
|
158 |
if reverse_scale > 0:
|
159 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
160 |
x_samples = [x.unsqueeze(0) for x in x_samples]
|
161 |
x_start = self.encode_first_stage(x_samples, **kwargs)
|
162 |
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
|
|
166 |
else:
|
167 |
x_start = None
|
168 |
# cond stage
|
169 |
+
if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
170 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
171 |
with torch.autocast('cuda',
|
172 |
enabled=dtype == 'float16',
|
|
|
179 |
skip_loaded=True)
|
180 |
|
181 |
|
182 |
+
if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
183 |
# UNet use input n_prompt
|
184 |
function_name, dtype = self.get_function_info(
|
185 |
self.diffusion_model)
|
|
|
210 |
if use_dynamic_model: self.dynamic_unload(self.diffusion_model,
|
211 |
'diffusion_model',
|
212 |
skip_loaded=True)
|
213 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
214 |
x_samples = self.decode_first_stage(latent)
|
215 |
if use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
216 |
'first_stage_model',
|
|
|
282 |
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
283 |
self.decoder_bias = cfg.get('DECODER_BIAS', 0)
|
284 |
self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
|
285 |
+
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
286 |
+
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
287 |
+
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
288 |
|
289 |
@torch.no_grad()
|
290 |
def encode_first_stage(self, x, **kwargs):
|
|
|
396 |
if use_ace and (not is_txt_image or refiner_scale <= 0):
|
397 |
ctx, null_ctx = {}, {}
|
398 |
# Get Noise Shape
|
399 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
400 |
x = self.encode_first_stage(image)
|
401 |
if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|
402 |
'first_stage_model',
|
|
|
414 |
ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
|
415 |
|
416 |
# Encode Prompt
|
417 |
+
if use_dynamic_model: self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
418 |
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
419 |
cont, cont_mask = getattr(get_model(self.cond_stage_model),
|
420 |
function_name)(prompt)
|
|
|
431 |
null_ctx['crossattn'] = null_cont
|
432 |
|
433 |
# Encode Edit Images
|
434 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
435 |
edit_image = [to_device(i, strict=False) for i in edit_image]
|
436 |
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
437 |
e_img, e_mask = [], []
|
|
|
449 |
null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
|
450 |
|
451 |
# Diffusion Process
|
452 |
+
if use_dynamic_model: self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
453 |
function_name, dtype = self.get_function_info(self.diffusion_model)
|
454 |
with torch.autocast('cuda',
|
455 |
enabled=dtype in ('float16', 'bfloat16'),
|
|
|
495 |
skip_loaded=False)
|
496 |
|
497 |
# Decode to Pixel Space
|
498 |
+
if use_dynamic_model: self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
499 |
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
500 |
x_samples = self.decode_first_stage(samples)
|
501 |
if self.use_dynamic_model: self.dynamic_unload(self.first_stage_model,
|