chaojiemao committed
Commit fa93807 · verified · 1 Parent(s): 031645f

Create ace_inference.py

Files changed (1)
  1. ace_inference.py +543 -0
ace_inference.py ADDED
@@ -0,0 +1,543 @@
+ # -*- coding: utf-8 -*-
+ # Copyright (c) Alibaba, Inc. and its affiliates.
+ import copy
+ import math
+ import random
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torchvision.transforms as T
+ import torchvision.transforms.functional as TF
+ from PIL import Image
+ from scepter.modules.model.registry import DIFFUSIONS
+ from scepter.modules.model.utils.basic_utils import check_list_of_list
+ from scepter.modules.model.utils.basic_utils import \
+     pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
+ from scepter.modules.model.utils.basic_utils import (
+     to_device, unpack_tensor_into_imagelist)
+ from scepter.modules.utils.distribute import we
+ from scepter.modules.utils.logger import get_logger
+
+ from .diffusion_inference import DiffusionInference, get_model
+
+
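+ # Normalize edit inputs for ACE: bound each image's aspect ratio, resize so the
+ # latent token count (H/d * W/d) stays within max_seq_len, scale pixels to
+ # [-1, 1], and blank out masked pixels for inpainting-style tasks.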
+ def process_edit_image(images,
+                        masks,
+                        tasks,
+                        max_seq_len=1024,
+                        max_aspect_ratio=4,
+                        d=16,
+                        **kwargs):
+
+     if not isinstance(images, list):
+         images = [images]
+     if not isinstance(masks, list):
+         masks = [masks]
+     if not isinstance(tasks, list):
+         tasks = [tasks]
+
+     img_tensors = []
+     mask_tensors = []
+     for img, mask, task in zip(images, masks, tasks):
+         if mask is None or mask == '':
+             mask = Image.new('L', img.size, 0)
+         W, H = img.size
+         if H / W > max_aspect_ratio:
+             img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
+             mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
+         elif W / H > max_aspect_ratio:
+             img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
+             mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
+
+         H, W = img.height, img.width
+         scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
+         rH = int(H * scale) // d * d  # ensure divisible by d
+         rW = int(W * scale) // d * d
+
+         img = TF.resize(img, (rH, rW),
+                         interpolation=TF.InterpolationMode.BICUBIC)
+         mask = TF.resize(mask, (rH, rW),
+                          interpolation=TF.InterpolationMode.NEAREST_EXACT)
+
+         mask = np.asarray(mask)
+         mask = np.where(mask > 128, 1, 0)
+         mask = mask.astype(np.float32) if np.any(mask) \
+             else np.ones_like(mask).astype(np.float32)
+
+         img_tensor = TF.to_tensor(img).to(we.device_id)
+         img_tensor = TF.normalize(img_tensor,
+                                   mean=[0.5, 0.5, 0.5],
+                                   std=[0.5, 0.5, 0.5])
+         mask_tensor = TF.to_tensor(mask).to(we.device_id)
+         if task in ['inpainting', 'Try On', 'Inpainting']:
+             mask_indicator = mask_tensor.repeat(3, 1, 1)
+             img_tensor[mask_indicator == 1] = -1.0
+         img_tensors.append(img_tensor)
+         mask_tensors.append(mask_tensor)
+     return img_tensors, mask_tensors
+
+
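+ # Holds the positional-embedding table for the text identifiers; ACEInference
+ # fills it lazily in cond_stage_embeddings() from the encoded identifier tokens.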
+ class TextEmbedding(nn.Module):
+     def __init__(self, embedding_shape):
+         super().__init__()
+         self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
+
+
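+ # Optional second-stage refiner: upscales ACE outputs and re-runs diffusion
+ # sampling on them (fully, or partially when a reverse_scale is given).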
+ class RefinerInference(DiffusionInference):
+     def init_from_cfg(self, cfg):
+         super().init_from_cfg(cfg)
+         self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION, logger=self.logger) \
+             if cfg.MODEL.have('DIFFUSION') else None
+         self.max_seq_length = cfg.MODEL.get("MAX_SEQ_LENGTH", 4096)
+         assert self.diffusion is not None
+
+     @torch.no_grad()
+     def encode_first_stage(self, x, **kwargs):
+         _, dtype = self.get_function_info(self.first_stage_model, 'encode')
+         with torch.autocast('cuda',
+                             enabled=dtype in ('float16', 'bfloat16'),
+                             dtype=getattr(torch, dtype)):
+             def run_one_image(u):
+                 zu = get_model(self.first_stage_model).encode(u)
+                 if isinstance(zu, (tuple, list)):
+                     zu = zu[0]
+                 return zu
+
+             z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
+             return z
+
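+     # Upscale so the 16x16 patch count approaches self.max_seq_length; never
+     # downscales, and keeps both sides divisible by 16.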
+     def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
+         c, H, W = image.shape
+         scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
+         rH = int(H * scale) // 16 * 16  # ensure divisible by 16
+         rW = int(W * scale) // 16 * 16
+         image = T.Resize((rH, rW), interpolation=interpolation,
+                          antialias=True)(image)
+         return image
+
+     @torch.no_grad()
+     def decode_first_stage(self, z):
+         _, dtype = self.get_function_info(self.first_stage_model, 'decode')
+         with torch.autocast('cuda',
+                             enabled=dtype in ('float16', 'bfloat16'),
+                             dtype=getattr(torch, dtype)):
+             return [get_model(self.first_stage_model).decode(zu) for zu in z]
+
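+     # Seeded latent noise with 16 channels and spatial dims 2 * ceil(h / 16)
+     # by 2 * ceil(w / 16), sized to allow sequence packing.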
+     def noise_sample(self, num_samples, h, w, seed, device=None,
+                      dtype=torch.bfloat16):
+         noise = torch.randn(
+             num_samples,
+             16,
+             # allow for packing
+             2 * math.ceil(h / 16),
+             2 * math.ceil(w / 16),
+             device=device,
+             dtype=dtype,
+             generator=torch.Generator(device=device).manual_seed(seed),
+         )
+         return noise
+
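+     # Refinement pass: upscale the samples, draw fresh noise, optionally
+     # re-encode the inputs (reverse_scale > 0) so sampling can start from the
+     # existing image instead of pure noise, then denoise and decode again.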
+     def refine(self,
+                x_samples=None,
+                prompt=None,
+                reverse_scale=-1.,
+                seed=2024,
+                use_dynamic_model=False,
+                **kwargs):
+         print(prompt)
+         value_input = copy.deepcopy(self.input)
+         x_samples = [self.upscale_resize(x) for x in x_samples]
+
+         noise = []
+         for i, x in enumerate(x_samples):
+             noise_ = self.noise_sample(1, x.shape[1], x.shape[2], seed,
+                                        device=x.device)
+             noise.append(noise_)
+         noise, x_shapes = pack_imagelist_into_tensor(noise)
+         if reverse_scale > 0:
+             self.dynamic_load(self.first_stage_model, 'first_stage_model')
+             x_samples = [x.unsqueeze(0) for x in x_samples]
+             x_start = self.encode_first_stage(x_samples, **kwargs)
+             if use_dynamic_model:
+                 self.dynamic_unload(self.first_stage_model,
+                                     'first_stage_model',
+                                     skip_loaded=True)
+             x_start, _ = pack_imagelist_into_tensor(x_start)
+         else:
+             x_start = None
+
+         # cond stage
+         self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+         function_name, dtype = self.get_function_info(self.cond_stage_model)
+         with torch.autocast('cuda',
+                             enabled=dtype == 'float16',
+                             dtype=getattr(torch, dtype)):
+             ctx = getattr(get_model(self.cond_stage_model),
+                           function_name)(prompt)
+             ctx["x_shapes"] = x_shapes
+         if use_dynamic_model:
+             self.dynamic_unload(self.cond_stage_model,
+                                 'cond_stage_model',
+                                 skip_loaded=True)
+
+         # diffusion model sampling
+         self.dynamic_load(self.diffusion_model, 'diffusion_model')
+         function_name, dtype = self.get_function_info(self.diffusion_model)
+         with torch.autocast('cuda',
+                             enabled=dtype in ('float16', 'bfloat16'),
+                             dtype=getattr(torch, dtype)):
+             solver_sample = value_input.get('sample', 'flow_euler')
+             sample_steps = value_input.get('sample_steps', 20)
+             guide_scale = value_input.get('guide_scale', 3.5)
+             if guide_scale is not None:
+                 guide_scale = torch.full((noise.shape[0], ),
+                                          guide_scale,
+                                          device=noise.device,
+                                          dtype=noise.dtype)
+             else:
+                 guide_scale = None
+             latent = self.diffusion.sample(
+                 noise=noise,
+                 sampler=solver_sample,
+                 model=get_model(self.diffusion_model),
+                 model_kwargs={"cond": ctx, "guidance": guide_scale},
+                 steps=sample_steps,
+                 show_progress=True,
+                 guide_scale=guide_scale,
+                 return_intermediate=None,
+                 reverse_scale=reverse_scale,
+                 x=x_start,
+                 **kwargs).float()
+         latent = unpack_tensor_into_imagelist(latent, x_shapes)
+         if use_dynamic_model:
+             self.dynamic_unload(self.diffusion_model,
+                                 'diffusion_model',
+                                 skip_loaded=True)
+         self.dynamic_load(self.first_stage_model, 'first_stage_model')
+         x_samples = self.decode_first_stage(latent)
+         if use_dynamic_model:
+             self.dynamic_unload(self.first_stage_model,
+                                 'first_stage_model',
+                                 skip_loaded=True)
+         return x_samples
+
+
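+ # Main ACE inference wrapper: prepares edit images, masks and prompts, encodes
+ # them with the first-stage and condition models, runs diffusion sampling with
+ # classifier-free guidance, and optionally hands the result to the refiner.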
+ class ACEInference(DiffusionInference):
+     def __init__(self, logger=None):
+         if logger is None:
+             logger = get_logger(name='scepter')
+         self.logger = logger
+         self.loaded_model = {}
+         self.loaded_model_name = [
+             'diffusion_model', 'first_stage_model', 'cond_stage_model'
+         ]
+
+     def init_from_cfg(self, cfg):
+         self.name = cfg.NAME
+         self.is_default = cfg.get('IS_DEFAULT', False)
+         self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
+         module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
+         assert cfg.have('MODEL')
+
+         self.diffusion_model = self.infer_model(
+             cfg.MODEL.DIFFUSION_MODEL,
+             module_paras.get(
+                 'DIFFUSION_MODEL',
+                 None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
+         self.first_stage_model = self.infer_model(
+             cfg.MODEL.FIRST_STAGE_MODEL,
+             module_paras.get(
+                 'FIRST_STAGE_MODEL',
+                 None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
+         self.cond_stage_model = self.infer_model(
+             cfg.MODEL.COND_STAGE_MODEL,
+             module_paras.get(
+                 'COND_STAGE_MODEL',
+                 None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
+
+         self.refiner_model_cfg = cfg.get('REFINER_MODEL', None)
+         # self.refiner_scale = cfg.get('REFINER_SCALE', 0.)
+         # self.refiner_prompt = cfg.get('REFINER_PROMPT', "")
+         self.ace_prompt = cfg.get("ACE_PROMPT", [])
+         if self.refiner_model_cfg:
+             self.refiner_module = RefinerInference(self.logger)
+             self.refiner_module.init_from_cfg(self.refiner_model_cfg)
+         else:
+             self.refiner_module = None
+
+         self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
+                                           logger=self.logger)
+
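+         # Downsample masks to the latent grid (1 / size_factor) with
+         # nearest-exact interpolation.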
+         self.interpolate_func = lambda x: (F.interpolate(
+             x.unsqueeze(0),
+             scale_factor=1 / self.size_factor,
+             mode='nearest-exact') if x is not None else None)
+         self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
+         self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
+                                                      False)
+         if self.use_text_pos_embeddings:
+             self.text_position_embeddings = TextEmbedding(
+                 (10, 4096)).eval().requires_grad_(False).to(we.device_id)
+         else:
+             self.text_position_embeddings = None
+
+         self.max_seq_len = cfg.MODEL.DIFFUSION_MODEL.MAX_SEQ_LEN
+         self.scale_factor = cfg.get('SCALE_FACTOR', 0.18215)
+         self.size_factor = cfg.get('SIZE_FACTOR', 8)
+         self.decoder_bias = cfg.get('DECODER_BIAS', 0)
+         self.default_n_prompt = cfg.get('DEFAULT_N_PROMPT', '')
+
+     @torch.no_grad()
+     def encode_first_stage(self, x, **kwargs):
+         _, dtype = self.get_function_info(self.first_stage_model, 'encode')
+         with torch.autocast('cuda',
+                             enabled=(dtype != 'float32'),
+                             dtype=getattr(torch, dtype)):
+             z = [
+                 self.scale_factor * get_model(self.first_stage_model)._encode(
+                     i.unsqueeze(0).to(getattr(torch, dtype))) for i in x
+             ]
+         return z
+
+     @torch.no_grad()
+     def decode_first_stage(self, z):
+         _, dtype = self.get_function_info(self.first_stage_model, 'decode')
+         with torch.autocast('cuda',
+                             enabled=(dtype != 'float32'),
+                             dtype=getattr(torch, dtype)):
+             x = [
+                 get_model(self.first_stage_model)._decode(
+                     1. / self.scale_factor * i.to(getattr(torch, dtype)))
+                 for i in z
+             ]
+         return x
+
+     @torch.no_grad()
+     def __call__(self,
+                  image=None,
+                  mask=None,
+                  prompt='',
+                  task=None,
+                  negative_prompt='',
+                  output_height=512,
+                  output_width=512,
+                  sampler='ddim',
+                  sample_steps=20,
+                  guide_scale=4.5,
+                  guide_rescale=0.5,
+                  seed=-1,
+                  history_io=None,
+                  tar_index=0,
+                  **kwargs):
+         input_image, input_mask = image, mask
+         g = torch.Generator(device=we.device_id)
+         seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
+         g.manual_seed(int(seed))
+         if input_image is not None:
+             # assert isinstance(input_image, list) and isinstance(input_mask, list)
+             if task is None:
+                 task = [''] * len(input_image)
+             if not isinstance(prompt, list):
+                 prompt = [prompt] * len(input_image)
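+             # Multi-turn editing: prepend history images, masks, tasks and
+             # prompts, and re-index `{image}` placeholders in later prompts
+             # to {image1}, {image2}, ...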
+             if history_io is not None and len(history_io) > 0:
+                 his_image, his_masks, his_prompt, his_task = history_io[
+                     'image'], history_io['mask'], history_io[
+                         'prompt'], history_io['task']
+                 assert len(his_image) == len(his_masks) == len(
+                     his_prompt) == len(his_task)
+                 input_image = his_image + input_image
+                 input_mask = his_masks + input_mask
+                 task = his_task + task
+                 prompt = his_prompt + [prompt[-1]]
+                 prompt = [
+                     pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
+                     for i, pp in enumerate(prompt)
+                 ]
+
+             edit_image, edit_image_mask = process_edit_image(
+                 input_image, input_mask, task, max_seq_len=self.max_seq_len)
+
+             image, image_mask = edit_image[tar_index], edit_image_mask[
+                 tar_index]
+             edit_image, edit_image_mask = [edit_image], [edit_image_mask]
+         else:
+             edit_image = edit_image_mask = [[]]
+             image = torch.zeros(
+                 size=[3, int(output_height),
+                       int(output_width)])
+             image_mask = torch.ones(
+                 size=[1, int(output_height),
+                       int(output_width)])
+             if not isinstance(prompt, list):
+                 prompt = [prompt]
+
+         image, image_mask, prompt = [image], [image_mask], [prompt]
+         assert check_list_of_list(prompt) and check_list_of_list(
+             edit_image) and check_list_of_list(edit_image_mask)
+         # Assign Negative Prompt
+         if isinstance(negative_prompt, list):
+             negative_prompt = negative_prompt[0]
+         assert isinstance(negative_prompt, str)
+
+         n_prompt = copy.deepcopy(prompt)
+         for nn_p_id, nn_p in enumerate(n_prompt):
+             assert isinstance(nn_p, list)
+             n_prompt[nn_p_id][-1] = negative_prompt
+
+         is_txt_image = sum([len(e_i) for e_i in edit_image]) < 1
+         image = to_device(image)
+
+         refiner_scale = kwargs.pop("refiner_scale", 0.0)
+         refiner_prompt = kwargs.pop("refiner_prompt", "")
+         use_ace = kwargs.pop("use_ace", True)
+         # refiner_scale <= 0: use ACE itself as the text-to-image generator.
+         if use_ace and (not is_txt_image or refiner_scale <= 0):
+             ctx, null_ctx = {}, {}
+             # Get Noise Shape
+             self.dynamic_load(self.first_stage_model, 'first_stage_model')
+             x = self.encode_first_stage(image)
+             if self.use_dynamic_model:
+                 self.dynamic_unload(self.first_stage_model,
+                                     'first_stage_model',
+                                     skip_loaded=True)
+             noise = [
+                 torch.empty(*i.shape, device=we.device_id).normal_(generator=g)
+                 for i in x
+             ]
+             noise, x_shapes = pack_imagelist_into_tensor(noise)
+             ctx['x_shapes'] = null_ctx['x_shapes'] = x_shapes
+
+             image_mask = to_device(image_mask, strict=False)
+             cond_mask = [self.interpolate_func(i) for i in image_mask
+                          ] if image_mask is not None else [None] * len(image)
+             ctx['x_mask'] = null_ctx['x_mask'] = cond_mask
+
+             # Encode Prompt
+             self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
+             function_name, dtype = self.get_function_info(self.cond_stage_model)
+             cont, cont_mask = getattr(get_model(self.cond_stage_model),
+                                       function_name)(prompt)
+             cont, cont_mask = self.cond_stage_embeddings(prompt, edit_image,
+                                                          cont, cont_mask)
+             null_cont, null_cont_mask = getattr(get_model(self.cond_stage_model),
+                                                 function_name)(n_prompt)
+             null_cont, null_cont_mask = self.cond_stage_embeddings(
+                 prompt, edit_image, null_cont, null_cont_mask)
+             if self.use_dynamic_model:
+                 self.dynamic_unload(self.cond_stage_model,
+                                     'cond_stage_model',
+                                     skip_loaded=False)
+             ctx['crossattn'] = cont
+             null_ctx['crossattn'] = null_cont
+
+             # Encode Edit Images
+             self.dynamic_load(self.first_stage_model, 'first_stage_model')
+             edit_image = [to_device(i, strict=False) for i in edit_image]
+             edit_image_mask = [
+                 to_device(i, strict=False) for i in edit_image_mask
+             ]
+             e_img, e_mask = [], []
+             for u, m in zip(edit_image, edit_image_mask):
+                 if u is None:
+                     continue
+                 if m is None:
+                     m = [None] * len(u)
+                 e_img.append(self.encode_first_stage(u, **kwargs))
+                 e_mask.append([self.interpolate_func(i) for i in m])
+             if self.use_dynamic_model:
+                 self.dynamic_unload(self.first_stage_model,
+                                     'first_stage_model',
+                                     skip_loaded=True)
+             null_ctx['edit'] = ctx['edit'] = e_img
+             null_ctx['edit_mask'] = ctx['edit_mask'] = e_mask
+
+             # Diffusion Process
+             self.dynamic_load(self.diffusion_model, 'diffusion_model')
+             function_name, dtype = self.get_function_info(self.diffusion_model)
+             with torch.autocast('cuda',
+                                 enabled=dtype in ('float16', 'bfloat16'),
+                                 dtype=getattr(torch, dtype)):
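+                 # With guide_scale > 1 the sampler receives (cond, uncond)
+                 # context pairs for classifier-free guidance; otherwise a
+                 # single context is used.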
+                 latent = self.diffusion.sample(
+                     noise=noise,
+                     sampler=sampler,
+                     model=get_model(self.diffusion_model),
+                     model_kwargs=[{
+                         'cond': ctx,
+                         'mask': cont_mask,
+                         'text_position_embeddings':
+                         self.text_position_embeddings.pos if hasattr(
+                             self.text_position_embeddings, 'pos') else None
+                     }, {
+                         'cond': null_ctx,
+                         'mask': null_cont_mask,
+                         'text_position_embeddings':
+                         self.text_position_embeddings.pos if hasattr(
+                             self.text_position_embeddings, 'pos') else None
+                     }] if guide_scale is not None and guide_scale > 1 else {
+                         'cond': null_ctx,
+                         'mask': cont_mask,
+                         'text_position_embeddings':
+                         self.text_position_embeddings.pos if hasattr(
+                             self.text_position_embeddings, 'pos') else None
+                     },
+                     steps=sample_steps,
+                     show_progress=True,
+                     seed=seed,
+                     guide_scale=guide_scale,
+                     guide_rescale=guide_rescale,
+                     return_intermediate=None,
+                     **kwargs)
+             if self.use_dynamic_model:
+                 self.dynamic_unload(self.diffusion_model,
+                                     'diffusion_model',
+                                     skip_loaded=False)
+
+             # Decode to Pixel Space
+             self.dynamic_load(self.first_stage_model, 'first_stage_model')
+             samples = unpack_tensor_into_imagelist(latent, x_shapes)
+             x_samples = self.decode_first_stage(samples)
+             if self.use_dynamic_model:
+                 self.dynamic_unload(self.first_stage_model,
+                                     'first_stage_model',
+                                     skip_loaded=False)
+             x_samples = [x.squeeze(0) for x in x_samples]
+         else:
+             x_samples = image
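+         # Optional refinement: pure text-to-image requests are generated by the
+         # refiner from scratch (reverse_scale = -1) using a sampled ACE prompt,
+         # while edited outputs are partially re-denoised at refiner_scale.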
+         if self.refiner_module and refiner_scale > 0:
+             if is_txt_image:
+                 random.shuffle(self.ace_prompt)
+                 input_refine_prompt = [
+                     self.ace_prompt[0] + refiner_prompt if p[0] == "" else p[0]
+                     for p in prompt
+                 ]
+                 input_refine_scale = -1.
+             else:
+                 input_refine_prompt = [
+                     p[0].replace("{image}", "") + " " + refiner_prompt
+                     for p in prompt
+                 ]
+                 input_refine_scale = refiner_scale
+             print(input_refine_prompt)
+
+             x_samples = self.refiner_module.refine(
+                 x_samples,
+                 reverse_scale=input_refine_scale,
+                 prompt=input_refine_prompt,
+                 seed=seed,
+                 use_dynamic_model=self.use_dynamic_model)
+
+         imgs = [
+             torch.clamp((x_i.float() + 1.0) / 2.0 + self.decoder_bias / 255,
+                         min=0.0,
+                         max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
+             for x_i in x_samples
+         ]
+         imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
+         return imgs
+
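+     # Assemble per-sample embedding lists: lazily fill the text position
+     # embeddings from the identifier tokens, and duplicate the last (target)
+     # prompt embedding at the head of each list when edit images are present.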
+     def cond_stage_embeddings(self, prompt, edit_image, cont, cont_mask):
+         if self.use_text_pos_embeddings and not torch.sum(
+                 self.text_position_embeddings.pos) > 0:
+             identifier_cont, _ = getattr(get_model(self.cond_stage_model),
+                                          'encode')(self.text_indentifers,
+                                                    return_mask=True)
+             self.text_position_embeddings.load_state_dict(
+                 {'pos': identifier_cont[:, 0, :]})
+
+         cont_, cont_mask_ = [], []
+         for pp, edit, c, cm in zip(prompt, edit_image, cont, cont_mask):
+             if isinstance(pp, list):
+                 cont_.append([c[-1], *c] if len(edit) > 0 else [c[-1]])
+                 cont_mask_.append([cm[-1], *cm] if len(edit) > 0 else [cm[-1]])
+             else:
+                 raise NotImplementedError
+
+         return cont_, cont_mask_