chaojiemao committed on
Commit 154c805 · 1 Parent(s): 86c729f

modify app.py

Files changed (6)
  1. ace_inference.py +356 -0
  2. example.py +370 -0
  3. model/__init__.py +1 -0
  4. model/flux.py +1064 -0
  5. model/layers.py +356 -0
  6. utils.py +95 -0
ace_inference.py ADDED
@@ -0,0 +1,356 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import copy
4
+ import math
5
+ import random
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ import torchvision.transforms.functional as TF
12
+ from PIL import Image
13
+ import torchvision.transforms as T
14
+ from scepter.modules.model.registry import DIFFUSIONS
15
+ from scepter.modules.model.utils.basic_utils import check_list_of_list
16
+ from scepter.modules.model.utils.basic_utils import \
17
+ pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
18
+ from scepter.modules.model.utils.basic_utils import (
19
+ to_device, unpack_tensor_into_imagelist)
20
+ from scepter.modules.utils.distribute import we
21
+ from scepter.modules.utils.logger import get_logger
22
+
23
+ from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
24
+
25
+
26
+ def process_edit_image(images,
27
+ masks,
28
+ tasks,
29
+ max_seq_len=1024,
30
+ max_aspect_ratio=4,
31
+ d=16,
32
+ **kwargs):
33
+
34
+ if not isinstance(images, list):
35
+ images = [images]
36
+ if not isinstance(masks, list):
37
+ masks = [masks]
38
+ if not isinstance(tasks, list):
39
+ tasks = [tasks]
40
+
41
+ img_tensors = []
42
+ mask_tensors = []
43
+ for img, mask, task in zip(images, masks, tasks):
44
+ if mask is None or mask == '':
45
+ mask = Image.new('L', img.size, 0)
46
+ W, H = img.size
47
+ if H / W > max_aspect_ratio:
48
+ img = TF.center_crop(img, [int(max_aspect_ratio * W), W])
49
+ mask = TF.center_crop(mask, [int(max_aspect_ratio * W), W])
50
+ elif W / H > max_aspect_ratio:
51
+ img = TF.center_crop(img, [H, int(max_aspect_ratio * H)])
52
+ mask = TF.center_crop(mask, [H, int(max_aspect_ratio * H)])
53
+
54
+ H, W = img.height, img.width
55
+ scale = min(1.0, math.sqrt(max_seq_len / ((H / d) * (W / d))))
56
+ rH = int(H * scale) // d * d # ensure divisible by d
57
+ rW = int(W * scale) // d * d
58
+
59
+ img = TF.resize(img, (rH, rW),
60
+ interpolation=TF.InterpolationMode.BICUBIC)
61
+ mask = TF.resize(mask, (rH, rW),
62
+ interpolation=TF.InterpolationMode.NEAREST_EXACT)
63
+
64
+ mask = np.asarray(mask)
65
+ mask = np.where(mask > 128, 1, 0)
66
+ mask = mask.astype(
67
+ np.float32) if np.any(mask) else np.ones_like(mask).astype(
68
+ np.float32)
69
+
70
+ img_tensor = TF.to_tensor(img).to(we.device_id)
71
+ img_tensor = TF.normalize(img_tensor,
72
+ mean=[0.5, 0.5, 0.5],
73
+ std=[0.5, 0.5, 0.5])
74
+ mask_tensor = TF.to_tensor(mask).to(we.device_id)
75
+ if task in ['inpainting', 'Try On', 'Inpainting']:
76
+ mask_indicator = mask_tensor.repeat(3, 1, 1)
77
+ img_tensor[mask_indicator == 1] = -1.0
78
+ img_tensors.append(img_tensor)
79
+ mask_tensors.append(mask_tensor)
80
+ return img_tensors, mask_tensors
81
+
82
+ class TextEmbedding(nn.Module):
83
+ def __init__(self, embedding_shape):
84
+ super().__init__()
85
+ self.pos = nn.Parameter(data=torch.zeros(embedding_shape))
86
+
87
+ class ACEFluxLCInference(DiffusionInference):
88
+ def __init__(self, logger=None):
89
+ if logger is None:
90
+ logger = get_logger(name='scepter')
91
+ self.logger = logger
92
+ self.loaded_model = {}
93
+ self.loaded_model_name = [
94
+ 'diffusion_model', 'first_stage_model', 'cond_stage_model', 'ref_cond_stage_model'
95
+ ]
96
+
97
+ def init_from_cfg(self, cfg):
98
+ self.name = cfg.NAME
99
+ self.is_default = cfg.get('IS_DEFAULT', False)
100
+ self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
101
+ module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
102
+ assert cfg.have('MODEL')
103
+ self.size_factor = cfg.get('SIZE_FACTOR', 8)
104
+ self.diffusion_model = self.infer_model(
105
+ cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
106
+ 'DIFFUSION_MODEL',
107
+ None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
108
+ self.first_stage_model = self.infer_model(
109
+ cfg.MODEL.FIRST_STAGE_MODEL,
110
+ module_paras.get(
111
+ 'FIRST_STAGE_MODEL',
112
+ None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
113
+ self.cond_stage_model = self.infer_model(
114
+ cfg.MODEL.COND_STAGE_MODEL,
115
+ module_paras.get(
116
+ 'COND_STAGE_MODEL',
117
+ None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
118
+
119
+ self.ref_cond_stage_model = self.infer_model(
120
+ cfg.MODEL.REF_COND_STAGE_MODEL,
121
+ module_paras.get(
122
+ 'REF_COND_STAGE_MODEL',
123
+ None)) if cfg.MODEL.have('REF_COND_STAGE_MODEL') else None
124
+
125
+ self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
126
+ logger=self.logger)
127
+ self.interpolate_func = lambda x: (F.interpolate(
128
+ x.unsqueeze(0),
129
+ scale_factor=1 / self.size_factor,
130
+ mode='nearest-exact') if x is not None else None)
131
+
132
+ self.max_seq_length = cfg.get("MAX_SEQ_LENGTH", 4096)
133
+ self.src_max_seq_length = cfg.get("SRC_MAX_SEQ_LENGTH", 1024)
134
+ self.image_token = cfg.MODEL.get("IMAGE_TOKEN", "<img>")
135
+
136
+ self.text_indentifers = cfg.MODEL.get('TEXT_IDENTIFIER', [])
137
+ self.use_text_pos_embeddings = cfg.MODEL.get('USE_TEXT_POS_EMBEDDINGS',
138
+ False)
139
+ if self.use_text_pos_embeddings:
140
+ self.text_position_embeddings = TextEmbedding(
141
+ (10, 4096)).eval().requires_grad_(False).to(we.device_id)
142
+ else:
143
+ self.text_position_embeddings = None
144
+
145
+ if not self.use_dynamic_model:
146
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
147
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
148
+ if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
149
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
150
+
151
+ def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
152
+ c, H, W = image.shape
153
+ scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
154
+ rH = int(H * scale) // 16 * 16 # ensure divisible by 16
155
+ rW = int(W * scale) // 16 * 16
156
+ image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
157
+ return image
158
+
159
+
160
+ @torch.no_grad()
161
+ def encode_first_stage(self, x, **kwargs):
162
+ _, dtype = self.get_function_info(self.first_stage_model, 'encode')
163
+ with torch.autocast('cuda',
164
+ enabled=dtype in ('float16', 'bfloat16'),
165
+ dtype=getattr(torch, dtype)):
166
+ def run_one_image(u):
167
+ zu = get_model(self.first_stage_model).encode(u)
168
+ if isinstance(zu, (tuple, list)):
169
+ zu = zu[0]
170
+ return zu
171
+
172
+ z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
173
+ return z
174
+
175
+
176
+ @torch.no_grad()
177
+ def decode_first_stage(self, z):
178
+ _, dtype = self.get_function_info(self.first_stage_model, 'decode')
179
+ with torch.autocast('cuda',
180
+ enabled=dtype in ('float16', 'bfloat16'),
181
+ dtype=getattr(torch, dtype)):
182
+ return [get_model(self.first_stage_model).decode(zu) for zu in z]
183
+
184
+ def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
185
+ noise = torch.randn(
186
+ num_samples,
187
+ 16,
188
+ # allow for packing
189
+ 2 * math.ceil(h / 16),
190
+ 2 * math.ceil(w / 16),
191
+ device=device,
192
+ dtype=dtype,
193
+ generator=torch.Generator(device=device).manual_seed(seed),
194
+ )
195
+ return noise
196
+
197
+ # def preprocess_prompt(self, prompt):
198
+ # prompt_ = [[pp] if isinstance(pp, str) else pp for pp in prompt]
199
+ # for pp_id, pp in enumerate(prompt_):
200
+ # prompt_[pp_id] = [""] + pp
201
+ # for p_id, p in enumerate(prompt_[pp_id]):
202
+ # prompt_[pp_id][p_id] = self.image_token + self.text_indentifers[p_id] + " " + p
203
+ # prompt_[pp_id] = [f";".join(prompt_[pp_id])]
204
+ # return prompt_
205
+
206
+ @torch.no_grad()
207
+ def __call__(self,
208
+ image=None,
209
+ mask=None,
210
+ prompt='',
211
+ task=None,
212
+ negative_prompt='',
213
+ output_height=1024,
214
+ output_width=1024,
215
+ sampler='flow_euler',
216
+ sample_steps=20,
217
+ guide_scale=3.5,
218
+ seed=-1,
219
+ history_io=None,
220
+ tar_index=0,
221
+ align=0,
222
+ **kwargs):
223
+ input_image, input_mask = image, mask
224
+ seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
225
+ if input_image is not None:
226
+ # assert isinstance(input_image, list) and isinstance(input_mask, list)
227
+ if task is None:
228
+ task = [''] * len(input_image)
229
+ if not isinstance(prompt, list):
230
+ prompt = [prompt] * len(input_image)
231
+ prompt = [
232
+ pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
233
+ for i, pp in enumerate(prompt)
234
+ ]
235
+ edit_image, edit_image_mask = process_edit_image(
236
+ input_image, input_mask, task, max_seq_len=self.src_max_seq_length)
237
+ image, image_mask = self.upscale_resize(edit_image[tar_index]), self.upscale_resize(edit_image_mask[
238
+ tar_index])
239
+ # edit_image, edit_image_mask = [[self.upscale_resize(i) for i in edit_image]], [[self.upscale_resize(i) for i in edit_image_mask]]
240
+ # image, image_mask = edit_image[tar_index], edit_image_mask[tar_index]
241
+ edit_image, edit_image_mask = [edit_image], [edit_image_mask]
242
+ else:
243
+ edit_image = edit_image_mask = [[]]
244
+ image = torch.zeros(
245
+ size=[3, int(output_height),
246
+ int(output_width)])
247
+ image_mask = torch.ones(
248
+ size=[1, int(output_height),
249
+ int(output_width)])
250
+ if not isinstance(prompt, list):
251
+ prompt = [prompt]
252
+
253
+ image, image_mask, prompt = [image], [image_mask], [prompt]
254
+ align = [align for p in prompt] if isinstance(align, int) else align
255
+
256
+ assert check_list_of_list(prompt) and check_list_of_list(
257
+ edit_image) and check_list_of_list(edit_image_mask)
258
+ # negative prompt is not used
259
+ image = to_device(image)
260
+ ctx = {}
261
+ # Get Noise Shape
262
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
263
+ x = self.encode_first_stage(image)
264
+ self.dynamic_unload(self.first_stage_model,
265
+ 'first_stage_model',
266
+ skip_loaded=not self.use_dynamic_model)
267
+
268
+ g = torch.Generator(device=we.device_id).manual_seed(seed)
269
+
270
+ noise = [
271
+ torch.randn((1, 16, i.shape[2], i.shape[3]), device=we.device_id, dtype=torch.bfloat16).normal_(generator=g)
272
+ for i in x
273
+ ]
274
+ noise, x_shapes = pack_imagelist_into_tensor(noise)
275
+ ctx['x_shapes'] = x_shapes
276
+ ctx['align'] = align
277
+
278
+ image_mask = to_device(image_mask, strict=False)
279
+ cond_mask = [self.interpolate_func(i) for i in image_mask
280
+ ] if image_mask is not None else [None] * len(image)
281
+ ctx['x_mask'] = cond_mask
282
+ # Encode Prompt
283
+ instruction_prompt = [[pp[-1]] if "{image}" in pp[-1] else ["{image} " + pp[-1]] for pp in prompt]
284
+ self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
285
+ function_name, dtype = self.get_function_info(self.cond_stage_model)
286
+ cont = getattr(get_model(self.cond_stage_model), function_name)(instruction_prompt)
287
+ cont["context"] = [ct[-1] for ct in cont["context"]]
288
+ cont["y"] = [ct[-1] for ct in cont["y"]]
289
+ self.dynamic_unload(self.cond_stage_model,
290
+ 'cond_stage_model',
291
+ skip_loaded=not self.use_dynamic_model)
292
+ ctx.update(cont)
293
+
294
+ # Encode Edit Images
295
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
296
+ edit_image = [to_device(i, strict=False) for i in edit_image]
297
+ edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
298
+ e_img, e_mask = [], []
299
+ for u, m in zip(edit_image, edit_image_mask):
300
+ if u is None:
301
+ continue
302
+ if m is None:
303
+ m = [None] * len(u)
304
+ e_img.append(self.encode_first_stage(u, **kwargs))
305
+ e_mask.append([self.interpolate_func(i) for i in m])
306
+ self.dynamic_unload(self.first_stage_model,
307
+ 'first_stage_model',
308
+ skip_loaded=not self.use_dynamic_model)
309
+ ctx['edit_x'] = e_img
310
+ ctx['edit_mask'] = e_mask
311
+ # Encode Ref Images
312
+ if guide_scale is not None:
313
+ guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype)
314
+ else:
315
+ guide_scale = None
316
+
317
+ # Diffusion Process
318
+ self.dynamic_load(self.diffusion_model, 'diffusion_model')
319
+ function_name, dtype = self.get_function_info(self.diffusion_model)
320
+ with torch.autocast('cuda',
321
+ enabled=dtype in ('float16', 'bfloat16'),
322
+ dtype=getattr(torch, dtype)):
323
+ latent = self.diffusion.sample(
324
+ noise=noise,
325
+ sampler=sampler,
326
+ model=get_model(self.diffusion_model),
327
+ model_kwargs={
328
+ "cond": ctx, "guidance": guide_scale, "gc_seg": -1
329
+ },
330
+ steps=sample_steps,
331
+ show_progress=True,
332
+ guide_scale=guide_scale,
333
+ return_intermediate=None,
334
+ reverse_scale=-1,
335
+ **kwargs).float()
336
+ if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
337
+ 'diffusion_model',
338
+ skip_loaded=not self.use_dynamic_model)
339
+
340
+ # Decode to Pixel Space
341
+ self.dynamic_load(self.first_stage_model, 'first_stage_model')
342
+ samples = unpack_tensor_into_imagelist(latent, x_shapes)
343
+ x_samples = self.decode_first_stage(samples)
344
+ self.dynamic_unload(self.first_stage_model,
345
+ 'first_stage_model',
346
+ skip_loaded=not self.use_dynamic_model)
347
+ x_samples = [x.squeeze(0) for x in x_samples]
348
+
349
+ imgs = [
350
+ torch.clamp((x_i.float() + 1.0) / 2.0,
351
+ min=0.0,
352
+ max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
353
+ for x_i in x_samples
354
+ ]
355
+ imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
356
+ return imgs
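
For orientation, here is a minimal usage sketch of the ACEFluxLCInference class added above. It is a sketch under stated assumptions, not part of this commit: the config file name and its keys are placeholders (the real configuration is loaded in app.py, which this diff does not show), while the call signature follows __call__ as defined above.

# Hypothetical usage sketch; 'config/ace_flux.yaml' and its contents are assumed.
from PIL import Image
from scepter.modules.utils.config import Config

from ace_inference import ACEFluxLCInference

cfg = Config(load=True, cfg_file='config/ace_flux.yaml')  # assumed scepter-style config
pipe = ACEFluxLCInference()
pipe.init_from_cfg(cfg)

src = Image.open('input.jpg').convert('RGB')
msk = Image.open('mask.png').convert('L')    # optional; pass None to edit the whole image
out = pipe(image=[src],
           mask=[msk],
           prompt=['{image} let the man smile'],
           task=['inpainting'],
           sample_steps=20,
           guide_scale=3.5,
           seed=6666)
out[0].save('output.jpg')

The call returns a list of PIL images, one per decoded latent.
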
example.py ADDED
@@ -0,0 +1,370 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import os
4
+ from PIL import Image
5
+ from scepter.modules.utils.file_system import FS
6
+
7
+
8
+ def download_image(image, local_path=None):
9
+ if not FS.exists(local_path):
10
+ local_path = FS.get_from(image, local_path=local_path)
11
+ if local_path.split(".")[-1] in ['jpg', 'jpeg']:
12
+ im = Image.open(local_path).convert("RGB")
13
+ im.save(local_path, format='JPEG')
14
+ return local_path
15
+
16
+
17
+ def get_examples(cache_dir):
18
+ print('Downloading Examples ...')
19
+ examples = [
20
+ [
21
+ 'Facial Editing',
22
+ download_image(
23
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',
24
+ os.path.join(cache_dir, 'examples/e33edc106953.jpg')), None,
25
+ None, '{image} let the man smile', 6666
26
+ ],
27
+ [
28
+ 'Facial Editing',
29
+ download_image(
30
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',
31
+ os.path.join(cache_dir, 'examples/5d2bcc91a3e9.jpg')), None,
32
+ None, 'let the man in {image} wear sunglasses', 9999
33
+ ],
34
+ [
35
+ 'Facial Editing',
36
+ download_image(
37
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',
38
+ os.path.join(cache_dir, 'examples/3a52eac708bd.jpg')), None,
39
+ None, '{image} red hair', 9999
40
+ ],
41
+ [
42
+ 'Facial Editing',
43
+ download_image(
44
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',
45
+ os.path.join(cache_dir, 'examples/3f4dc464a0ea.jpg')), None,
46
+ None, '{image} let the man serious', 99999
47
+ ],
48
+ [
49
+ 'Controllable Generation',
50
+ download_image(
51
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
52
+ os.path.join(cache_dir,
53
+ 'examples/131ca90fd2a9.jpg')), None, None,
54
+ '"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
55
+ 613725
56
+ ],
57
+ [
58
+ 'Render Text',
59
+ download_image(
60
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
61
+ os.path.join(cache_dir, 'examples/33e9f27c2c48.jpg')),
62
+ download_image(
63
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
64
+ os.path.join(cache_dir,
65
+ 'examples/33e9f27c2c48_mask.jpg')), None,
66
+ 'Put the text "C A T" at the position marked by mask in the {image}',
67
+ 6666
68
+ ],
69
+ [
70
+ 'Style Transfer',
71
+ download_image(
72
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
73
+ os.path.join(cache_dir, 'examples/9e73e7eeef55.jpg')), None,
74
+ download_image(
75
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
76
+ os.path.join(cache_dir, 'examples/2e02975293d6.jpg')),
77
+ 'edit {image} based on the style of {image1} ', 99999
78
+ ],
79
+ [
80
+ 'Outpainting',
81
+ download_image(
82
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
83
+ os.path.join(cache_dir, 'examples/f2b22c08be3f.jpg')),
84
+ download_image(
85
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
86
+ os.path.join(cache_dir,
87
+ 'examples/f2b22c08be3f_mask.jpg')), None,
88
+ 'Could the {image} be widened within the space designated by mask, while retaining the original?',
89
+ 6666
90
+ ],
91
+ [
92
+ 'Image Segmentation',
93
+ download_image(
94
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
95
+ os.path.join(cache_dir, 'examples/db3ebaa81899.jpg')), None,
96
+ None, '{image} Segmentation', 6666
97
+ ],
98
+ [
99
+ 'Depth Estimation',
100
+ download_image(
101
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
102
+ os.path.join(cache_dir, 'examples/f1927c4692ba.jpg')), None,
103
+ None, '{image} Depth Estimation', 6666
104
+ ],
105
+ [
106
+ 'Pose Estimation',
107
+ download_image(
108
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
109
+ os.path.join(cache_dir, 'examples/014e5bf3b4d1.jpg')), None,
110
+ None, '{image} distinguish the poses of the figures', 999999
111
+ ],
112
+ [
113
+ 'Scribble Extraction',
114
+ download_image(
115
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
116
+ os.path.join(cache_dir, 'examples/5f59a202f8ac.jpg')), None,
117
+ None, 'Generate a scribble of {image}, please.', 6666
118
+ ],
119
+ [
120
+ 'Mosaic',
121
+ download_image(
122
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
123
+ os.path.join(cache_dir, 'examples/3a2f52361eea.jpg')), None,
124
+ None, 'Adapt {image} into a mosaic representation.', 6666
125
+ ],
126
+ [
127
+ 'Edge map Extraction',
128
+ download_image(
129
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
130
+ os.path.join(cache_dir, 'examples/b9d1e519d6e5.jpg')), None,
131
+ None, 'Get the edge-enhanced result for {image}.', 6666
132
+ ],
133
+ [
134
+ 'Grayscale',
135
+ download_image(
136
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
137
+ os.path.join(cache_dir, 'examples/c4ebbe2ba29b.jpg')), None,
138
+ None, 'transform {image} into a black and white one', 6666
139
+ ],
140
+ [
141
+ 'Contour Extraction',
142
+ download_image(
143
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
144
+ os.path.join(cache_dir,
145
+ 'examples/19652d0f6c4b.jpg')), None, None,
146
+ 'Would you be able to make a contour picture from {image} for me?',
147
+ 6666
148
+ ],
149
+ [
150
+ 'Controllable Generation',
151
+ download_image(
152
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
153
+ os.path.join(cache_dir,
154
+ 'examples/249cda2844b7.jpg')), None, None,
155
+ 'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed".',
156
+ 6666
157
+ ],
158
+ [
159
+ 'Controllable Generation',
160
+ download_image(
161
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
162
+ os.path.join(cache_dir,
163
+ 'examples/411f6c4b8e6c.jpg')), None, None,
164
+ 'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
165
+ 999999
166
+ ],
167
+ [
168
+ 'Controllable Generation',
169
+ download_image(
170
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
171
+ os.path.join(cache_dir,
172
+ 'examples/a35c96ed137a.jpg')), None, None,
173
+ 'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
174
+ 3599999
175
+ ],
176
+ [
177
+ 'Controllable Generation',
178
+ download_image(
179
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
180
+ os.path.join(cache_dir,
181
+ 'examples/dcb2fc86f1ce.jpg')), None, None,
182
+ 'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
183
+ 6666
184
+ ],
185
+ [
186
+ 'Controllable Generation',
187
+ download_image(
188
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
189
+ os.path.join(cache_dir,
190
+ 'examples/4cd4ee494962.jpg')), None, None,
191
+ 'make this {image} colorful as per the "beautiful sunflowers"',
192
+ 6666
193
+ ],
194
+ [
195
+ 'Controllable Generation',
196
+ download_image(
197
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
198
+ os.path.join(cache_dir,
199
+ 'examples/a47e3a9cd166.jpg')), None, None,
200
+ 'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
201
+ 613725
202
+ ],
203
+ [
204
+ 'Controllable Generation',
205
+ download_image(
206
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
207
+ os.path.join(cache_dir,
208
+ 'examples/d890ed8a3ac2.jpg')), None, None,
209
+ 'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
210
+ 6666
211
+ ],
212
+ [
213
+ 'Image Denoising',
214
+ download_image(
215
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
216
+ os.path.join(cache_dir,
217
+ 'examples/0844a686a179.jpg')), None, None,
218
+ 'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
219
+ 6666
220
+ ],
221
+ [
222
+ 'Inpainting',
223
+ download_image(
224
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
225
+ os.path.join(cache_dir, 'examples/fa91b6b7e59b.jpg')),
226
+ download_image(
227
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
228
+ os.path.join(cache_dir,
229
+ 'examples/fa91b6b7e59b_mask.jpg')), None,
230
+ 'Ensure to overhaul the parts of the {image} indicated by the mask.',
231
+ 6666
232
+ ],
233
+ [
234
+ 'Inpainting',
235
+ download_image(
236
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
237
+ os.path.join(cache_dir, 'examples/632899695b26.jpg')),
238
+ download_image(
239
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
240
+ os.path.join(cache_dir,
241
+ 'examples/632899695b26_mask.jpg')), None,
242
+ 'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
243
+ 6666
244
+ ],
245
+ [
246
+ 'General Editing',
247
+ download_image(
248
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
249
+ os.path.join(cache_dir,
250
+ 'examples/354d17594afe.jpg')), None, None,
251
+ '{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
252
+ 6666
253
+ ],
254
+ [
255
+ 'General Editing',
256
+ download_image(
257
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
258
+ os.path.join(cache_dir,
259
+ 'examples/38946455752b.jpg')), None, None,
260
+ '{image} change the color of the dress from white to red and the model\'s hair color red brown to blonde.Other parts remain unchanged',
261
+ 6669
262
+ ],
263
+ [
264
+ 'Facial Editing',
265
+ download_image(
266
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
267
+ os.path.join(cache_dir,
268
+ 'examples/3ba5202f0cd8.jpg')), None, None,
269
+ 'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
270
+ 99999
271
+ ],
272
+ [
273
+ 'Facial Editing',
274
+ download_image(
275
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
276
+ os.path.join(cache_dir, 'examples/369365b94725.jpg')), None,
277
+ None, '{image} Make her looking at the camera', 6666
278
+ ],
279
+ [
280
+ 'Facial Editing',
281
+ download_image(
282
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
283
+ os.path.join(cache_dir, 'examples/92751f2e4a0e.jpg')), None,
284
+ None, '{image} Remove the smile from his face', 9899999
285
+ ],
286
+ [
287
+ 'Remove Text',
288
+ download_image(
289
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
290
+ os.path.join(cache_dir, 'examples/8530a6711b2e.jpg')), None,
291
+ None, 'Aim to remove any textual element in {image}', 6666
292
+ ],
293
+ [
294
+ 'Remove Text',
295
+ download_image(
296
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
297
+ os.path.join(cache_dir, 'examples/c4d7fb28f8f6.jpg')),
298
+ download_image(
299
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
300
+ os.path.join(cache_dir,
301
+ 'examples/c4d7fb28f8f6_mask.jpg')), None,
302
+ 'Rub out any text found in the mask sector of the {image}.', 6666
303
+ ],
304
+ [
305
+ 'Remove Object',
306
+ download_image(
307
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
308
+ os.path.join(cache_dir,
309
+ 'examples/e2f318fa5e5b.jpg')), None, None,
310
+ 'Remove the unicorn in this {image}, ensuring a smooth edit.',
311
+ 99999
312
+ ],
313
+ [
314
+ 'Remove Object',
315
+ download_image(
316
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
317
+ os.path.join(cache_dir, 'examples/1ae96d8aca00.jpg')),
318
+ download_image(
319
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
320
+ os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.jpg')),
321
+ None, 'Discard the contents of the mask area from {image}.', 99999
322
+ ],
323
+ [
324
+ 'Add Object',
325
+ download_image(
326
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
327
+ os.path.join(cache_dir, 'examples/80289f48e511.jpg')),
328
+ download_image(
329
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
330
+ os.path.join(cache_dir,
331
+ 'examples/80289f48e511_mask.jpg')), None,
332
+ 'add a Hot Air Balloon into the {image}, per the mask', 613725
333
+ ],
334
+ [
335
+ 'Style Transfer',
336
+ download_image(
337
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
338
+ os.path.join(cache_dir, 'examples/d725cb2009e8.jpg')), None,
339
+ None, 'Change the style of {image} to colored pencil style', 99999
340
+ ],
341
+ [
342
+ 'Style Transfer',
343
+ download_image(
344
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
345
+ os.path.join(cache_dir, 'examples/e0f48b3fd010.jpg')), None,
346
+ None, 'make {image} to Walt Disney Animation style', 99999
347
+ ],
348
+ [
349
+ 'Try On',
350
+ download_image(
351
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
352
+ os.path.join(cache_dir, 'examples/ee4ca60b8c96.jpg')),
353
+ download_image(
354
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
355
+ os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.jpg')),
356
+ download_image(
357
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
358
+ os.path.join(cache_dir, 'examples/ebe825bbfe3c.jpg')),
359
+ 'Change the cloth in {image} to the one in {image1}', 99999
360
+ ],
361
+ [
362
+ 'Workflow',
363
+ download_image(
364
+ 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
365
+ os.path.join(cache_dir, 'examples/cb85353c004b.jpg')), None,
366
+ None, '<workflow> ice cream {image}', 99999
367
+ ],
368
+ ]
369
+ print('Finish. Start building UI ...')
370
+ return examples
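
The rows returned by get_examples feed the Gradio demo in the modified app.py, which this diff does not include. The sketch below shows one plausible wiring; the component names and layout are assumptions, but the row order [task, image, mask, reference image, prompt, seed] matches the lists built above.

# Hypothetical wiring; the actual components are defined in app.py (not shown here).
import gradio as gr
from example import get_examples

examples = get_examples('./cache')   # downloads the sample assets into ./cache/examples

with gr.Blocks() as demo:
    task = gr.Textbox(label='Task')
    image = gr.Image(type='filepath', label='Input Image')
    mask = gr.Image(type='filepath', label='Mask')
    ref = gr.Image(type='filepath', label='Reference Image')
    prompt = gr.Textbox(label='Instruction')
    seed = gr.Number(label='Seed', precision=0)
    gr.Examples(examples=examples,
                inputs=[task, image, mask, ref, prompt, seed],
                examples_per_page=6)

demo.launch()
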
model/__init__.py ADDED
@@ -0,0 +1 @@
1
+ #from .flux import Flux, FluxMR, FluxEdit
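
The re-export in model/__init__.py is left commented out. Since the classes in model/flux.py register themselves with scepter's EMBEDDERS and BACKBONES registries at import time, the application presumably only needs to import the module for its side effects. A minimal sketch, assuming the package layout introduced by this commit:

# Importing the module runs the @EMBEDDERS.register_class() / @BACKBONES.register_class()
# decorators, so ACEHFEmbedder, T5ACEPlusClipFluxEmbedder, Flux, etc. become buildable from config.
import model.flux  # noqa: F401  (imported for its registration side effects)
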
model/flux.py ADDED
@@ -0,0 +1,1064 @@
1
+ # -*- coding: utf-8 -*-
2
+ # Copyright (c) Alibaba, Inc. and its affiliates.
3
+ import math
4
+ from collections import OrderedDict
5
+ from functools import partial
6
+ import warnings
7
+ from contextlib import nullcontext
8
+ import torch
9
+ from einops import rearrange, repeat
10
+ from scepter.modules.model.base_model import BaseModel
11
+ from scepter.modules.model.registry import BACKBONES
12
+ from scepter.modules.utils.config import dict_to_yaml
13
+ from scepter.modules.utils.distribute import we
14
+ from scepter.modules.utils.file_system import FS
15
+ from torch import Tensor, nn
16
+ from torch.nn.utils.rnn import pad_sequence
17
+ from torch.utils.checkpoint import checkpoint_sequential
18
+ import torch.nn.functional as F
19
+ import torch.utils.dlpack
20
+ import transformers
21
+ from scepter.modules.model.embedder.base_embedder import BaseEmbedder
22
+ from scepter.modules.model.registry import EMBEDDERS
23
+ from scepter.modules.model.tokenizer.tokenizer_component import (
24
+ basic_clean, canonicalize, heavy_clean, whitespace_clean)
25
+ try:
26
+ from transformers import AutoTokenizer, T5EncoderModel
27
+ except Exception as e:
28
+ warnings.warn(
29
+ f'Import transformers error, please deal with this problem: {e}')
30
+
31
+ from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
32
+ MLPEmbedder, SingleStreamBlock,
33
+ timestep_embedding)
34
+
35
+
36
+
37
+ @EMBEDDERS.register_class()
38
+ class ACETextEmbedder(BaseEmbedder):
39
+ """
40
+ Uses the OpenCLIP transformer encoder for text
41
+ """
42
+ """
43
+ Uses the OpenCLIP transformer encoder for text
44
+ """
45
+ para_dict = {
46
+ 'PRETRAINED_MODEL': {
47
+ 'value':
48
+ 'google/umt5-small',
49
+ 'description':
50
+ 'Pretrained Model for umt5, modelcard path or local path.'
51
+ },
52
+ 'TOKENIZER_PATH': {
53
+ 'value': 'google/umt5-small',
54
+ 'description':
55
+ 'Tokenizer Path for umt5, modelcard path or local path.'
56
+ },
57
+ 'FREEZE': {
58
+ 'value': True,
59
+ 'description': ''
60
+ },
61
+ 'USE_GRAD': {
62
+ 'value': False,
63
+ 'description': 'Compute grad or not.'
64
+ },
65
+ 'CLEAN': {
66
+ 'value':
67
+ 'whitespace',
68
+ 'description':
69
+ 'Set the clean strategy for the tokenizer, used when TOKENIZER_PATH is not None.'
70
+ },
71
+ 'LAYER': {
72
+ 'value': 'last',
73
+ 'description': ''
74
+ },
75
+ 'LEGACY': {
76
+ 'value':
77
+ True,
78
+ 'description':
79
+ 'Whether to use the legacy returned feature or not, default True.'
80
+ }
81
+ }
82
+
83
+ def __init__(self, cfg, logger=None):
84
+ super().__init__(cfg, logger=logger)
85
+ pretrained_path = cfg.get('PRETRAINED_MODEL', None)
86
+ self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
87
+ assert pretrained_path
88
+ with FS.get_dir_to_local_dir(pretrained_path,
89
+ wait_finish=True) as local_path:
90
+ self.model = T5EncoderModel.from_pretrained(
91
+ local_path,
92
+ torch_dtype=getattr(
93
+ torch,
94
+ 'float' if self.t5_dtype == 'float32' else self.t5_dtype))
95
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
96
+ self.length = cfg.get('LENGTH', 77)
97
+
98
+ self.use_grad = cfg.get('USE_GRAD', False)
99
+ self.clean = cfg.get('CLEAN', 'whitespace')
100
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
101
+ if tokenizer_path:
102
+ self.tokenize_kargs = {'return_tensors': 'pt'}
103
+ with FS.get_dir_to_local_dir(tokenizer_path,
104
+ wait_finish=True) as local_path:
105
+ if self.added_identifier is not None and isinstance(
106
+ self.added_identifier, list):
107
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path)
108
+ else:
109
+ self.tokenizer = AutoTokenizer.from_pretrained(local_path)
110
+ if self.length is not None:
111
+ self.tokenize_kargs.update({
112
+ 'padding': 'max_length',
113
+ 'truncation': True,
114
+ 'max_length': self.length
115
+ })
116
+ self.eos_token = self.tokenizer(
117
+ self.tokenizer.eos_token)['input_ids'][0]
118
+ else:
119
+ self.tokenizer = None
120
+
124
+
125
+ def freeze(self):
126
+ self.model = self.model.eval()
127
+ for param in self.parameters():
128
+ param.requires_grad = False
129
+
130
+ # encode && encode_text
131
+ def forward(self, tokens, return_mask=False, use_mask=True):
132
+ # tokenization
133
+ embedding_context = nullcontext if self.use_grad else torch.no_grad
134
+ with embedding_context():
135
+ if use_mask:
136
+ x = self.model(tokens.input_ids.to(we.device_id),
137
+ tokens.attention_mask.to(we.device_id))
138
+ else:
139
+ x = self.model(tokens.input_ids.to(we.device_id))
140
+ x = x.last_hidden_state
141
+
142
+ if return_mask:
143
+ return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
144
+ else:
145
+ return x.detach() + 0.0, None
146
+
147
+ def _clean(self, text):
148
+ if self.clean == 'whitespace':
149
+ text = whitespace_clean(basic_clean(text))
150
+ elif self.clean == 'lower':
151
+ text = whitespace_clean(basic_clean(text)).lower()
152
+ elif self.clean == 'canonicalize':
153
+ text = canonicalize(basic_clean(text))
154
+ elif self.clean == 'heavy':
155
+ text = heavy_clean(basic_clean(text))
156
+ return text
157
+
158
+ def encode(self, text, return_mask=False, use_mask=True):
159
+ if isinstance(text, str):
160
+ text = [text]
161
+ if self.clean:
162
+ text = [self._clean(u) for u in text]
163
+ assert self.tokenizer is not None
164
+ cont, mask = [], []
165
+ with torch.autocast(device_type='cuda',
166
+ enabled=self.t5_dtype in ('float16', 'bfloat16'),
167
+ dtype=getattr(torch, self.t5_dtype)):
168
+ for tt in text:
169
+ tokens = self.tokenizer([tt], **self.tokenize_kargs)
170
+ one_cont, one_mask = self(tokens,
171
+ return_mask=return_mask,
172
+ use_mask=use_mask)
173
+ cont.append(one_cont)
174
+ mask.append(one_mask)
175
+ if return_mask:
176
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
177
+ else:
178
+ return torch.cat(cont, dim=0)
179
+
180
+ def encode_list(self, text_list, return_mask=True):
181
+ cont_list = []
182
+ mask_list = []
183
+ for pp in text_list:
184
+ cont, cont_mask = self.encode(pp, return_mask=return_mask)
185
+ cont_list.append(cont)
186
+ mask_list.append(cont_mask)
187
+ if return_mask:
188
+ return cont_list, mask_list
189
+ else:
190
+ return cont_list
191
+
192
+ @staticmethod
193
+ def get_config_template():
194
+ return dict_to_yaml('MODELS',
195
+ __class__.__name__,
196
+ ACETextEmbedder.para_dict,
197
+ set_name=True)
198
+
199
+ @EMBEDDERS.register_class()
200
+ class ACEHFEmbedder(BaseEmbedder):
201
+ para_dict = {
202
+ "HF_MODEL_CLS": {
203
+ "value": None,
204
+ "description": "huggingface cls in transfomer"
205
+ },
206
+ "MODEL_PATH": {
207
+ "value": None,
208
+ "description": "model folder path"
209
+ },
210
+ "HF_TOKENIZER_CLS": {
211
+ "value": None,
212
+ "description": "huggingface cls in transfomer"
213
+ },
214
+
215
+ "TOKENIZER_PATH": {
216
+ "value": None,
217
+ "description": "tokenizer folder path"
218
+ },
219
+ "MAX_LENGTH": {
220
+ "value": 77,
221
+ "description": "max length of input"
222
+ },
223
+ "OUTPUT_KEY": {
224
+ "value": "last_hidden_state",
225
+ "description": "output key"
226
+ },
227
+ "D_TYPE": {
228
+ "value": "float",
229
+ "description": "dtype"
230
+ },
231
+ "BATCH_INFER": {
232
+ "value": False,
233
+ "description": "batch infer"
234
+ }
235
+ }
236
+ para_dict.update(BaseEmbedder.para_dict)
237
+ def __init__(self, cfg, logger=None):
238
+ super().__init__(cfg, logger=logger)
239
+ hf_model_cls = cfg.get('HF_MODEL_CLS', None)
240
+ model_path = cfg.get("MODEL_PATH", None)
241
+ hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
242
+ tokenizer_path = cfg.get('TOKENIZER_PATH', None)
243
+ self.max_length = cfg.get('MAX_LENGTH', 77)
244
+ self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
245
+ self.d_type = cfg.get("D_TYPE", "float")
246
+ self.clean = cfg.get("CLEAN", "whitespace")
247
+ self.batch_infer = cfg.get("BATCH_INFER", False)
248
+ self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
249
+ torch_dtype = getattr(torch, self.d_type)
250
+
251
+ assert hf_model_cls is not None and hf_tokenizer_cls is not None
252
+ assert model_path is not None and tokenizer_path is not None
253
+ with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
254
+ self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
255
+ max_length = self.max_length,
256
+ torch_dtype = torch_dtype,
257
+ additional_special_tokens=self.added_identifier)
258
+
259
+ with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
260
+ self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
261
+
262
+
263
+ self.hf_module = self.hf_module.eval().requires_grad_(False)
264
+
265
+ def forward(self, text: list[str], return_mask = False):
266
+ batch_encoding = self.tokenizer(
267
+ text,
268
+ truncation=True,
269
+ max_length=self.max_length,
270
+ return_length=False,
271
+ return_overflowing_tokens=False,
272
+ padding="max_length",
273
+ return_tensors="pt",
274
+ )
275
+
276
+ outputs = self.hf_module(
277
+ input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
278
+ attention_mask=None,
279
+ output_hidden_states=False,
280
+ )
281
+ if return_mask:
282
+ return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
283
+ else:
284
+ return outputs[self.output_key], None
285
+
286
+ def encode(self, text, return_mask = False):
287
+ if isinstance(text, str):
288
+ text = [text]
289
+ if self.clean:
290
+ text = [self._clean(u) for u in text]
291
+ if not self.batch_infer:
292
+ cont, mask = [], []
293
+ for tt in text:
294
+ one_cont, one_mask = self([tt], return_mask=return_mask)
295
+ cont.append(one_cont)
296
+ mask.append(one_mask)
297
+ if return_mask:
298
+ return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
299
+ else:
300
+ return torch.cat(cont, dim=0)
301
+ else:
302
+ ret_data = self(text, return_mask = return_mask)
303
+ if return_mask:
304
+ return ret_data
305
+ else:
306
+ return ret_data[0]
307
+
308
+ def encode_list(self, text_list, return_mask=True):
309
+ cont_list = []
310
+ mask_list = []
311
+ for pp in text_list:
312
+ cont = self.encode(pp, return_mask=return_mask)
313
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
314
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
315
+ if return_mask:
316
+ return cont_list, mask_list
317
+ else:
318
+ return cont_list
319
+
320
+ def encode_list_of_list(self, text_list, return_mask=True):
321
+ cont_list = []
322
+ mask_list = []
323
+ for pp in text_list:
324
+ cont = self.encode_list(pp, return_mask=return_mask)
325
+ cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
326
+ mask_list.append(cont[1]) if return_mask else mask_list.append(None)
327
+ if return_mask:
328
+ return cont_list, mask_list
329
+ else:
330
+ return cont_list
331
+
332
+ def _clean(self, text):
333
+ if self.clean == 'whitespace':
334
+ text = whitespace_clean(basic_clean(text))
335
+ elif self.clean == 'lower':
336
+ text = whitespace_clean(basic_clean(text)).lower()
337
+ elif self.clean == 'canonicalize':
338
+ text = canonicalize(basic_clean(text))
339
+ return text
340
+ @staticmethod
341
+ def get_config_template():
342
+ return dict_to_yaml('EMBEDDER',
343
+ __class__.__name__,
344
+ ACEHFEmbedder.para_dict,
345
+ set_name=True)
346
+
347
+ @EMBEDDERS.register_class()
348
+ class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
349
+ """
350
+ Combines the T5 and CLIP text encoders for Flux conditioning.
351
+ """
352
+ para_dict = {
353
+ 'T5_MODEL': {},
354
+ 'CLIP_MODEL': {}
355
+ }
356
+
357
+ def __init__(self, cfg, logger=None):
358
+ super().__init__(cfg, logger=logger)
359
+ self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
360
+ self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
361
+
362
+ def encode(self, text, return_mask = False):
363
+ t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
364
+ clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
365
+ # change embedding strategy here
366
+ return {
367
+ 'context': t5_embeds,
368
+ 'y': clip_embeds,
369
+ }
370
+
371
+ def encode_list(self, text, return_mask = False):
372
+ t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
373
+ clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
374
+ # change embedding strategy here
375
+ return {
376
+ 'context': t5_embeds,
377
+ 'y': clip_embeds,
378
+ }
379
+
380
+ def encode_list_of_list(self, text, return_mask = False):
381
+ t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
382
+ clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
383
+ # change embedding strategy here
384
+ return {
385
+ 'context': t5_embeds,
386
+ 'y': clip_embeds,
387
+ }
388
+
389
+
390
+ @staticmethod
391
+ def get_config_template():
392
+ return dict_to_yaml('EMBEDDER',
393
+ __class__.__name__,
394
+ T5ACEPlusClipFluxEmbedder.para_dict,
395
+ set_name=True)
396
+
397
+ @BACKBONES.register_class()
398
+ class Flux(BaseModel):
399
+ """
400
+ Transformer backbone Diffusion model with RoPE.
401
+ """
402
+ para_dict = {
403
+ "IN_CHANNELS": {
404
+ "value": 64,
405
+ "description": "model's input channels."
406
+ },
407
+ "OUT_CHANNELS": {
408
+ "value": 64,
409
+ "description": "model's output channels."
410
+ },
411
+ "HIDDEN_SIZE": {
412
+ "value": 1024,
413
+ "description": "model's hidden size."
414
+ },
415
+ "NUM_HEADS": {
416
+ "value": 16,
417
+ "description": "number of heads in the transformer."
418
+ },
419
+ "AXES_DIM": {
420
+ "value": [16, 56, 56],
421
+ "description": "dimensions of the axes of the positional encoding."
422
+ },
423
+ "THETA": {
424
+ "value": 10_000,
425
+ "description": "theta for positional encoding."
426
+ },
427
+ "VEC_IN_DIM": {
428
+ "value": 768,
429
+ "description": "dimension of the vector input."
430
+ },
431
+ "GUIDANCE_EMBED": {
432
+ "value": False,
433
+ "description": "whether to use guidance embedding."
434
+ },
435
+ "CONTEXT_IN_DIM": {
436
+ "value": 4096,
437
+ "description": "dimension of the context input."
438
+ },
439
+ "MLP_RATIO": {
440
+ "value": 4.0,
441
+ "description": "ratio of mlp hidden size to hidden size."
442
+ },
443
+ "QKV_BIAS": {
444
+ "value": True,
445
+ "description": "whether to use bias in qkv projection."
446
+ },
447
+ "DEPTH": {
448
+ "value": 19,
449
+ "description": "number of transformer blocks."
450
+ },
451
+ "DEPTH_SINGLE_BLOCKS": {
452
+ "value": 38,
453
+ "description": "number of transformer blocks in the single stream block."
454
+ },
455
+ "USE_GRAD_CHECKPOINT": {
456
+ "value": False,
457
+ "description": "whether to use gradient checkpointing."
458
+ },
459
+ "ATTN_BACKEND": {
460
+ "value": "pytorch",
461
+ "description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
462
+ }
463
+ }
464
+ def __init__(
465
+ self,
466
+ cfg,
467
+ logger = None
468
+ ):
469
+ super().__init__(cfg, logger=logger)
470
+ self.in_channels = cfg.IN_CHANNELS
471
+ self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
472
+ hidden_size = cfg.get("HIDDEN_SIZE", 1024)
473
+ num_heads = cfg.get("NUM_HEADS", 16)
474
+ axes_dim = cfg.AXES_DIM
475
+ theta = cfg.THETA
476
+ vec_in_dim = cfg.VEC_IN_DIM
477
+ self.guidance_embed = cfg.GUIDANCE_EMBED
478
+ context_in_dim = cfg.CONTEXT_IN_DIM
479
+ mlp_ratio = cfg.MLP_RATIO
480
+ qkv_bias = cfg.QKV_BIAS
481
+ depth = cfg.DEPTH
482
+ depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
483
+ self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
484
+ self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
485
+ self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
486
+ self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
487
+ self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
488
+
489
+ if hidden_size % num_heads != 0:
490
+ raise ValueError(
491
+ f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
492
+ )
493
+ pe_dim = hidden_size // num_heads
494
+ if sum(axes_dim) != pe_dim:
495
+ raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
496
+ self.hidden_size = hidden_size
497
+ self.num_heads = num_heads
498
+ self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim= axes_dim)
499
+ self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
500
+ self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
501
+ self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
502
+ self.guidance_in = (
503
+ MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
504
+ )
505
+ self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
506
+
507
+ self.double_blocks = nn.ModuleList(
508
+ [
509
+ DoubleStreamBlock(
510
+ self.hidden_size,
511
+ self.num_heads,
512
+ mlp_ratio=mlp_ratio,
513
+ qkv_bias=qkv_bias,
514
+ backend=self.attn_backend
515
+ )
516
+ for _ in range(depth)
517
+ ]
518
+ )
519
+
520
+ self.single_blocks = nn.ModuleList(
521
+ [
522
+ SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
523
+ for _ in range(depth_single_blocks)
524
+ ]
525
+ )
526
+
527
+ self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
528
+
529
+ def prepare_input(self, x, context, y, x_shape=None):
530
+ # x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
531
+ bs, c, h, w = x.shape
532
+ x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
533
+ x_id = torch.zeros(h // 2, w // 2, 3)
534
+ x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
535
+ x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
536
+ x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
537
+ txt_ids = torch.zeros(bs, context.shape[1], 3)
538
+ return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
539
+
540
+ def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
541
+ return rearrange(
542
+ x,
543
+ "b (h w) (c ph pw) -> b c (h ph) (w pw)",
544
+ h=math.ceil(height/2),
545
+ w=math.ceil(width/2),
546
+ ph=2,
547
+ pw=2,
548
+ )
549
+
550
+ def merge_diffuser_lora(self, ori_sd, lora_sd, scale = 1.0):
551
+ key_map = {
552
+ "single_blocks.{}.linear1.weight": {"key_list": [
553
+ ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
554
+ "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
555
+ ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
556
+ "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
557
+ ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
558
+ "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
559
+ ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
560
+ "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
561
+ ], "num": 38},
562
+ "single_blocks.{}.modulation.lin.weight": {"key_list": [
563
+ ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
564
+ "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
565
+ ], "num": 38},
566
+ "single_blocks.{}.linear2.weight": {"key_list": [
567
+ ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
568
+ "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
569
+ ], "num": 38},
570
+ "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
571
+ ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
572
+ "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
573
+ ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
574
+ "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
575
+ ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
576
+ "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
577
+ ], "num": 19},
578
+ "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
579
+ ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
580
+ "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
581
+ ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
582
+ "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
583
+ ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
584
+ "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
585
+ ], "num": 19},
586
+ "double_blocks.{}.img_attn.proj.weight": {"key_list": [
587
+ ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
588
+ "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
589
+ ], "num": 19},
590
+ "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
591
+ ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
592
+ "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
593
+ ], "num": 19},
594
+ "double_blocks.{}.img_mlp.0.weight": {"key_list": [
595
+ ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
596
+ "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
597
+ ], "num": 19},
598
+ "double_blocks.{}.img_mlp.2.weight": {"key_list": [
599
+ ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
600
+ "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
601
+ ], "num": 19},
602
+ "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
603
+ ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
604
+ "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
605
+ ], "num": 19},
606
+ "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
607
+ ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
608
+ "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
609
+ ], "num": 19},
610
+ "double_blocks.{}.img_mod.lin.weight": {"key_list": [
611
+ ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
612
+ "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
613
+ ], "num": 19},
614
+ "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
615
+ ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
616
+ "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
617
+ ], "num": 19}
618
+ }
619
+ for k, v in key_map.items():
620
+ key_list = v["key_list"]
621
+ block_num = v["num"]
622
+ for block_id in range(block_num):
623
+ current_weight_list = []
624
+ for k_list in key_list:
625
+ current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
626
+ lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
627
+ current_weight_list.append(current_weight)
628
+ current_weight = torch.cat(current_weight_list, dim=0)
629
+ ori_sd[k.format(block_id)] += scale*current_weight
630
+ return ori_sd
631
+
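For context on the mapping above: each LoRA pair contributes a delta weight equal to `B @ A` (the `permute`/`matmul`/`permute` sequence is exactly that), and the per-projection deltas are concatenated along the output dimension to match FLUX's fused `qkv` layout. A minimal sketch, not part of this commit; `dim`, `rank`, `scale` and the random tensors are illustrative:

import torch

dim, rank, scale = 64, 4, 1.0
base_qkv = torch.zeros(3 * dim, dim)            # fused [q; k; v] weight, laid out (out, in)
deltas = []
for _ in range(3):                              # one LoRA (A, B) pair per projection
    lora_A = torch.randn(rank, dim)             # (rank, in)
    lora_B = torch.randn(dim, rank)             # (out, rank)
    # matmul(A^T, B^T)^T is the usual LoRA update B @ A with shape (out, in)
    delta = torch.matmul(lora_A.permute(1, 0), lora_B.permute(1, 0)).permute(1, 0)
    assert torch.allclose(delta, lora_B @ lora_A, atol=1e-5)
    deltas.append(delta)
base_qkv += scale * torch.cat(deltas, dim=0)    # same concatenation as the key_list loop above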
632
+ def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
633
+ have_lora_keys = {}
634
+ for k, v in lora_sd.items():
635
+ k = k[len("model."):] if k.startswith("model.") else k
636
+ ori_key = k.split("lora")[0] + "weight"
637
+ if ori_key not in ori_sd:
638
+ raise f"{ori_key} should in the original statedict"
639
+ if ori_key not in have_lora_keys:
640
+ have_lora_keys[ori_key] = {}
641
+ if "lora_A" in k:
642
+ have_lora_keys[ori_key]["lora_A"] = v
643
+ elif "lora_B" in k:
644
+ have_lora_keys[ori_key]["lora_B"] = v
645
+ else:
646
+ raise NotImplementedError
647
+ for key, v in have_lora_keys.items():
648
+ current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
649
+ ori_sd[key] += scale * current_weight
650
+ return ori_sd
651
+
652
+
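merge_swift_lora above relies on the SWIFT naming convention: the base parameter name is everything before the literal substring "lora", with "weight" appended. A small sketch of that key mapping, using a hypothetical key:

lora_key = "model.double_blocks.0.img_attn.qkv.lora_A.weight"   # hypothetical SWIFT key
k = lora_key[len("model."):] if lora_key.startswith("model.") else lora_key
ori_key = k.split("lora")[0] + "weight"
print(ori_key)   # -> double_blocks.0.img_attn.qkv.weight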
653
+ def load_pretrained_model(self, pretrained_model):
654
+ if next(self.parameters()).device.type == 'meta':
655
+ map_location = we.device_id
656
+ else:
657
+ map_location = "cpu"
658
+ if self.lora_model is not None:
659
+ map_location = we.device_id
660
+ if pretrained_model is not None:
661
+ with FS.get_from(pretrained_model, wait_finish=True) as local_model:
662
+ if local_model.endswith('safetensors'):
663
+ from safetensors.torch import load_file as load_safetensors
664
+ sd = load_safetensors(local_model, device=map_location)
665
+ else:
666
+ sd = torch.load(local_model, map_location=map_location)
667
+ if "state_dict" in sd:
668
+ sd = sd["state_dict"]
669
+ if "model" in sd:
670
+ sd = sd["model"]["model"]
671
+
672
+ if self.lora_model is not None:
673
+ with FS.get_from(self.lora_model, wait_finish=True) as local_model:
674
+ if local_model.endswith('safetensors'):
675
+ from safetensors.torch import load_file as load_safetensors
676
+ lora_sd = load_safetensors(local_model, device=map_location)
677
+ else:
678
+ lora_sd = torch.load(local_model, map_location=map_location)
679
+ sd = self.merge_diffuser_lora(sd, lora_sd)
680
+ if self.swift_lora_model is not None:
681
+ with FS.get_from(self.swift_lora_model, wait_finish=True) as local_model:
682
+ if local_model.endswith('safetensors'):
683
+ from safetensors.torch import load_file as load_safetensors
684
+ lora_sd = load_safetensors(local_model, device=map_location)
685
+ else:
686
+ lora_sd = torch.load(local_model, map_location=map_location)
687
+ sd = self.merge_swift_lora(sd, lora_sd)
688
+
689
+ adapter_ckpt = {}
690
+ if self.pretrain_adapter is not None:
691
+ with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
692
+ if local_adapter.endswith('safetensors'):
693
+ from safetensors.torch import load_file as load_safetensors
694
+ adapter_ckpt = load_safetensors(local_adapter, device=map_location)
695
+ else:
696
+ adapter_ckpt = torch.load(local_adapter, map_location=map_location)
697
+ sd.update(adapter_ckpt)
698
+
699
+
700
+ new_ckpt = OrderedDict()
701
+ for k, v in sd.items():
702
+ if k in ("img_in.weight"):
703
+ model_p = self.state_dict()[k]
704
+ if v.shape != model_p.shape:
705
+ model_p.zero_()
706
+ model_p[:, :64].copy_(v[:, :64])
707
+ new_ckpt[k] = torch.nn.parameter.Parameter(model_p)
708
+ else:
709
+ new_ckpt[k] = v
710
+ else:
711
+ new_ckpt[k] = v
712
+
713
+
714
+ missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
715
+ self.logger.info(
716
+ f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
717
+ )
718
+ if len(missing) > 0:
719
+ self.logger.info(f'Missing Keys:\n {missing}')
720
+ if len(unexpected) > 0:
721
+ self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
722
+
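The `img_in.weight` special case in `load_pretrained_model` covers checkpoints whose input projection has fewer input channels than the current model (the edit variants concatenate extra latent channels): the pretrained first 64 columns are copied and the new columns stay zero. A minimal sketch; the 3072/384 shapes are illustrative, not taken from the commit:

import torch

ckpt_w = torch.randn(3072, 64)          # pretrained img_in.weight: (hidden, base latent channels)
model_w = torch.zeros(3072, 384)        # current model expects more input channels
model_w[:, :64].copy_(ckpt_w[:, :64])   # reuse pretrained columns, leave the rest at zero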
723
+ def forward(
724
+ self,
725
+ x: Tensor,
726
+ t: Tensor,
727
+ cond: dict = {},
728
+ guidance: Tensor | None = None,
729
+ gc_seg: int = 0
730
+ ) -> Tensor:
731
+ x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
732
+ # running on sequences img
733
+ x = self.img_in(x)
734
+ vec = self.time_in(timestep_embedding(t, 256))
735
+ if self.guidance_embed:
736
+ if guidance is None:
737
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
738
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
739
+ vec = vec + self.vector_in(y)
740
+ txt = self.txt_in(txt)
741
+ ids = torch.cat((txt_ids, x_ids), dim=1)
742
+ pe = self.pe_embedder(ids)
743
+ kwargs = dict(
744
+ vec=vec,
745
+ pe=pe,
746
+ txt_length=txt.shape[1],
747
+ )
748
+ x = torch.cat((txt, x), 1)
749
+ if self.use_grad_checkpoint and gc_seg >= 0:
750
+ x = checkpoint_sequential(
751
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
752
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
753
+ input=x,
754
+ use_reentrant=False
755
+ )
756
+ else:
757
+ for block in self.double_blocks:
758
+ x = block(x, **kwargs)
759
+
760
+ kwargs = dict(
761
+ vec=vec,
762
+ pe=pe,
763
+ )
764
+
765
+ if self.use_grad_checkpoint and gc_seg >= 0:
766
+ x = checkpoint_sequential(
767
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
768
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
769
+ input=x,
770
+ use_reentrant=False
771
+ )
772
+ else:
773
+ for block in self.single_blocks:
774
+ x = block(x, **kwargs)
775
+ x = x[:, txt.shape[1] :, ...]
776
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
777
+ x = self.unpack(x, h, w)
778
+ return x
779
+
780
+ @staticmethod
781
+ def get_config_template():
782
+ return dict_to_yaml('MODEL',
783
+ __class__.__name__,
784
+ Flux.para_dict,
785
+ set_name=True)
786
+
787
+ @BACKBONES.register_class()
788
+ class FluxMR(Flux):
789
+ def prepare_input(self, x, cond):
790
+ if isinstance(cond['context'], list):
791
+ context, y = torch.cat(cond["context"], dim=0).to(x), torch.cat(cond["y"], dim=0).to(x)
792
+ else:
793
+ context, y = cond['context'].to(x), cond['y'].to(x)
794
+ batch_frames, batch_frames_ids = [], []
795
+ for ix, shape in zip(x, cond["x_shapes"]):
796
+ # unpack image from sequence
797
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
798
+ c, h, w = ix.shape
799
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
800
+ ix_id = torch.zeros(h // 2, w // 2, 3)
801
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
802
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
803
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
804
+ batch_frames.append([ix])
805
+ batch_frames_ids.append([ix_id])
806
+
807
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
808
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
809
+ proj_frames = []
810
+ for idx, one_frame in enumerate(frames):
811
+ one_frame = self.img_in(one_frame)
812
+ proj_frames.append(one_frame)
813
+ ix = torch.cat(proj_frames, dim=0)
814
+ if_id = torch.cat(frame_ids, dim=0)
815
+ x_list.append(ix)
816
+ x_id_list.append(if_id)
817
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
818
+ x_seq_length.append(ix.shape[0])
819
+ x = pad_sequence(tuple(x_list), batch_first=True)
820
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
821
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
822
+
823
+ txt = self.txt_in(context)
824
+ txt_ids = torch.zeros(context.shape[0], context.shape[1], 3).to(x)
825
+ mask_txt = torch.ones(context.shape[0], context.shape[1]).to(x.device, non_blocking=True).bool()
826
+
827
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
828
+
829
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
830
+ x_list = []
831
+ image_shapes = cond["x_shapes"]
832
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
833
+ height, width = shape
834
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
835
+ u = rearrange(
836
+ u[seq_length-h*w:seq_length, ...],
837
+ "(h w) (c ph pw) -> (h ph w pw) c",
838
+ h=h,
839
+ w=w,
840
+ ph=2,
841
+ pw=2,
842
+ )
843
+ x_list.append(u)
844
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
845
+ return x
846
+
847
+ def forward(
848
+ self,
849
+ x: Tensor,
850
+ t: Tensor,
851
+ cond: dict = {},
852
+ guidance: Tensor | None = None,
853
+ gc_seg: int = 0,
854
+ **kwargs
855
+ ) -> Tensor:
856
+ x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
857
+ # running on sequences img
858
+ vec = self.time_in(timestep_embedding(t, 256))
859
+ if self.guidance_embed:
860
+ if guidance is None:
861
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
862
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
863
+ vec = vec + self.vector_in(y)
864
+ ids = torch.cat((txt_ids, x_ids), dim=1)
865
+ pe = self.pe_embedder(ids)
866
+
867
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
868
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
869
+
870
+ kwargs = dict(
871
+ vec=vec,
872
+ pe=pe,
873
+ mask=mask,
874
+ txt_length = txt.shape[1],
875
+ )
876
+ x = torch.cat((txt, x), 1)
877
+ if self.use_grad_checkpoint and gc_seg >= 0:
878
+ x = checkpoint_sequential(
879
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
880
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
881
+ input=x,
882
+ use_reentrant=False
883
+ )
884
+ else:
885
+ for block in self.double_blocks:
886
+ x = block(x, **kwargs)
887
+
888
+ kwargs = dict(
889
+ vec=vec,
890
+ pe=pe,
891
+ mask=mask,
892
+ )
893
+
894
+ if self.use_grad_checkpoint and gc_seg >= 0:
895
+ x = checkpoint_sequential(
896
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
897
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
898
+ input=x,
899
+ use_reentrant=False
900
+ )
901
+ else:
902
+ for block in self.single_blocks:
903
+ x = block(x, **kwargs)
904
+ x = x[:, txt.shape[1]:, ...]
905
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
906
+ x = self.unpack(x, cond, seq_length_list)
907
+ return x
908
+
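The joint attention mask in FluxMR.forward is the outer product of per-token validity masks over the concatenated [txt, img] sequence, so padded tokens neither attend nor get attended to. A toy sketch:

import torch

mask_txt = torch.tensor([[True, True, False]])           # (B, L_txt); last txt token is padding
mask_x = torch.tensor([[True, True, True, False]])       # (B, L_img); last img token is padding
mask_aside = torch.cat((mask_txt, mask_x), dim=1)         # (B, L)
mask = mask_aside[:, None, :] * mask_aside[:, :, None]    # (B, L, L); True only for valid-valid pairs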
909
+ @staticmethod
910
+ def get_config_template():
911
+ return dict_to_yaml('MODEL',
912
+ __class__.__name__,
913
+ FluxMR.para_dict,
914
+ set_name=True)
915
+ @BACKBONES.register_class()
916
+ class FluxEdit(FluxMR):
917
+ def prepare_input(self, x, cond, *args, **kwargs):
918
+ context, y = cond["context"], cond["y"]
919
+ batch_frames, batch_frames_ids, batch_shift = [], [], []
920
+
921
+ for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
922
+ # unpack image from sequence
923
+ ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
924
+ c, h, w = ix.shape
925
+ ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
926
+ ix_id = torch.zeros(h // 2, w // 2, 3)
927
+ ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
928
+ ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
929
+ batch_shift.append(h // 2) #if is_align < 1 else batch_shift.append(0)
930
+ ix_id = rearrange(ix_id, "h w c -> (h w) c")
931
+ batch_frames.append([ix])
932
+ batch_frames_ids.append([ix_id])
933
+ if 'edit_x' in cond:
934
+ for i, edit in enumerate(cond['edit_x']):
935
+ if edit is None:
936
+ continue
937
+ for ie in edit:
938
+ ie = ie.squeeze(0)
939
+ c, h, w = ie.shape
940
+ ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
941
+ ie_id = torch.zeros(h // 2, w // 2, 3)
942
+ ie_id[..., 1] = ie_id[..., 1] + torch.arange(batch_shift[i], h // 2 + batch_shift[i])[:, None]
943
+ ie_id[..., 2] = ie_id[..., 2] + torch.arange(w // 2)[None, :]
944
+ ie_id = rearrange(ie_id, "h w c -> (h w) c")
945
+ batch_frames[i].append(ie)
946
+ batch_frames_ids[i].append(ie_id)
947
+
948
+ x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
949
+ for frames, frame_ids in zip(batch_frames, batch_frames_ids):
950
+ proj_frames = []
951
+ for idx, one_frame in enumerate(frames):
952
+ one_frame = self.img_in(one_frame)
953
+ proj_frames.append(one_frame)
954
+ ix = torch.cat(proj_frames, dim=0)
955
+ if_id = torch.cat(frame_ids, dim=0)
956
+ x_list.append(ix)
957
+ x_id_list.append(if_id)
958
+ mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
959
+ x_seq_length.append(ix.shape[0])
960
+ x = pad_sequence(tuple(x_list), batch_first=True)
961
+ x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
962
+ mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
963
+
964
+ txt_list, mask_txt_list, y_list = [], [], []
965
+ for sample_id, (ctx, yy) in enumerate(zip(context, y)):
966
+ ctx_batch = []
967
+ for frame_id, one_ctx in enumerate(ctx):
968
+ one_ctx = self.txt_in(one_ctx.to(x))
969
+ ctx_batch.append(one_ctx)
970
+ txt_list.append(torch.cat(ctx_batch, dim=0))
971
+ mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
972
+ y_list.append(yy.mean(dim = 0, keepdim=True))
973
+ txt = pad_sequence(tuple(txt_list), batch_first=True)
974
+ txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
975
+ mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
976
+ y = torch.cat(y_list, dim=0)
977
+ return x, x_ids, txt, txt_ids, y, mask_x, mask_txt, x_seq_length
978
+
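Worth noting in FluxEdit.prepare_input: `batch_shift` offsets the row ids of the edit/reference latents by the target's packed height, so the target and its references occupy disjoint rows in RoPE space while sharing column ids. A minimal sketch with illustrative sizes:

import torch

h_target, w, h_ref = 4, 4, 3
shift = h_target                                    # plays the role of batch_shift[i]
ref_ids = torch.zeros(h_ref, w, 3)
ref_ids[..., 1] += torch.arange(shift, shift + h_ref)[:, None]   # rows start below the target
ref_ids[..., 2] += torch.arange(w)[None, :]                      # columns are shared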
979
+ def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
980
+ x_list = []
981
+ image_shapes = cond["x_shapes"]
982
+ for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
983
+ height, width = shape
984
+ h, w = math.ceil(height / 2), math.ceil(width / 2)
985
+ u = rearrange(
986
+ u[:h*w, ...],
987
+ "(h w) (c ph pw) -> (h ph w pw) c",
988
+ h=h,
989
+ w=w,
990
+ ph=2,
991
+ pw=2,
992
+ )
993
+ x_list.append(u)
994
+ x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
995
+ return x
996
+
997
+ def forward(
998
+ self,
999
+ x: Tensor,
1000
+ t: Tensor,
1001
+ cond: dict = {},
1002
+ guidance: Tensor | None = None,
1003
+ gc_seg: int = 0,
1004
+ text_position_embeddings = None
1005
+ ) -> Tensor:
1006
+ x, x_ids, txt, txt_ids, y, mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond, text_position_embeddings)
1007
+ # running on sequences img
1008
+ vec = self.time_in(timestep_embedding(t, 256))
1009
+ if self.guidance_embed:
1010
+ if guidance is None:
1011
+ raise ValueError("Didn't get guidance strength for guidance distilled model.")
1012
+ vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
1013
+ vec = vec + self.vector_in(y)
1014
+ ids = torch.cat((txt_ids, x_ids), dim=1)
1015
+ pe = self.pe_embedder(ids)
1016
+
1017
+ mask_aside = torch.cat((mask_txt, mask_x), dim=1)
1018
+ mask = mask_aside[:, None, :] * mask_aside[:, :, None]
1019
+
1020
+ kwargs = dict(
1021
+ vec=vec,
1022
+ pe=pe,
1023
+ mask=mask,
1024
+ txt_length = txt.shape[1],
1025
+ )
1026
+ x = torch.cat((txt, x), 1)
1027
+
1028
+ if self.use_grad_checkpoint and gc_seg >= 0:
1029
+ x = checkpoint_sequential(
1030
+ functions=[partial(block, **kwargs) for block in self.double_blocks],
1031
+ segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
1032
+ input=x,
1033
+ use_reentrant=False
1034
+ )
1035
+ else:
1036
+ for block in self.double_blocks:
1037
+ x = block(x, **kwargs)
1038
+
1039
+ kwargs = dict(
1040
+ vec=vec,
1041
+ pe=pe,
1042
+ mask=mask,
1043
+ )
1044
+
1045
+ if self.use_grad_checkpoint and gc_seg >= 0:
1046
+ x = checkpoint_sequential(
1047
+ functions=[partial(block, **kwargs) for block in self.single_blocks],
1048
+ segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
1049
+ input=x,
1050
+ use_reentrant=False
1051
+ )
1052
+ else:
1053
+ for block in self.single_blocks:
1054
+ x = block(x, **kwargs)
1055
+ x = x[:, txt.shape[1]:, ...]
1056
+ x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
1057
+ x = self.unpack(x, cond, seq_length_list)
1058
+ return x
1059
+ @staticmethod
1060
+ def get_config_template():
1061
+ return dict_to_yaml('MODEL',
1062
+ __class__.__name__,
1063
+ FluxEdit.para_dict,
1064
+ set_name=True)
model/layers.py ADDED
@@ -0,0 +1,356 @@
1
+ from __future__ import annotations
2
+
3
+ import math
4
+ from dataclasses import dataclass
5
+ from torch import Tensor, nn
6
+ import torch
7
+ from einops import rearrange, repeat
8
+ from torch import Tensor
9
+ from torch.nn.utils.rnn import pad_sequence
10
+
11
+ try:
12
+ from flash_attn import (
13
+ flash_attn_varlen_func
14
+ )
15
+ FLASHATTN_IS_AVAILABLE = True
16
+ except ImportError:
17
+ FLASHATTN_IS_AVAILABLE = False
18
+ flash_attn_varlen_func = None
19
+
20
+ def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, backend = 'pytorch') -> Tensor:
21
+ q, k = apply_rope(q, k, pe)
22
+ if backend == 'pytorch':
23
+ if mask is not None and mask.dtype == torch.bool:
24
+ mask = torch.zeros_like(mask).to(q).masked_fill_(mask.logical_not(), -1e20)
25
+ x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
26
+ # x = torch.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
27
+ x = rearrange(x, "B H L D -> B L (H D)")
28
+ elif backend == 'flash_attn':
29
+ # q: (B, H, L, D)
30
+ # k: (B, H, S, D) now L = S
31
+ # v: (B, H, S, D)
32
+ b, h, lq, d = q.shape
33
+ _, _, lk, _ = k.shape
34
+ q = rearrange(q, "B H L D -> B L H D")
35
+ k = rearrange(k, "B H S D -> B S H D")
36
+ v = rearrange(v, "B H S D -> B S H D")
37
+ if mask is None:
38
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(q.device, non_blocking=True)
39
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(k.device, non_blocking=True)
40
+ else:
41
+ q_lens = torch.sum(mask[:, 0, :, 0], dim=1).int()
42
+ k_lens = torch.sum(mask[:, 0, 0, :], dim=1).int()
43
+ q = torch.cat([q_v[:q_l] for q_v, q_l in zip(q, q_lens)])
44
+ k = torch.cat([k_v[:k_l] for k_v, k_l in zip(k, k_lens)])
45
+ v = torch.cat([v_v[:v_l] for v_v, v_l in zip(v, k_lens)])
46
+ cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
47
+ cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
48
+ max_seqlen_q = q_lens.max()
49
+ max_seqlen_k = k_lens.max()
50
+
51
+ x = flash_attn_varlen_func(
52
+ q,
53
+ k,
54
+ v,
55
+ cu_seqlens_q=cu_seqlens_q,
56
+ cu_seqlens_k=cu_seqlens_k,
57
+ max_seqlen_q=max_seqlen_q,
58
+ max_seqlen_k=max_seqlen_k
59
+ )
60
+ x_list = [x[cu_seqlens_q[i]:cu_seqlens_q[i+1]] for i in range(b)]
61
+ x = pad_sequence(tuple(x_list), batch_first=True)
62
+ x = rearrange(x, "B L H D -> B L (H D)")
63
+ else:
64
+ raise NotImplementedError
65
+ return x
66
+
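In the 'pytorch' branch above, a boolean attention mask is turned into an additive bias (-1e20 at disallowed positions) before scaled_dot_product_attention. A self-contained sketch of that conversion with toy shapes:

import torch
import torch.nn.functional as F

B, H, L, D = 1, 2, 5, 8
q, k, v = (torch.randn(B, H, L, D) for _ in range(3))
bool_mask = torch.ones(B, H, L, L, dtype=torch.bool)
bool_mask[..., -1] = False                                    # forbid attending to the last token
bias = torch.zeros_like(bool_mask, dtype=q.dtype).masked_fill_(bool_mask.logical_not(), -1e20)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)  # (B, H, L, D)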
67
+
68
+ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
69
+ assert dim % 2 == 0
70
+ scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
71
+ omega = 1.0 / (theta**scale)
72
+ out = torch.einsum("...n,d->...nd", pos, omega)
73
+ out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
74
+ out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
75
+ return out.float()
76
+
77
+
78
+ def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
79
+ xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
80
+ xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
81
+ xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
82
+ xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
83
+ return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
84
+
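A quick sanity sketch for `rope`/`apply_rope` as defined above (assumes both functions are in scope): at position 0 every per-pair rotation matrix is the identity, so queries and keys pass through unchanged:

import torch

pos = torch.zeros(1, 4)                     # (batch, tokens), all positions zero
pe = rope(pos, dim=8, theta=10000)          # (1, 4, 4, 2, 2) rotation matrices
xq, xk = torch.randn(1, 4, 8), torch.randn(1, 4, 8)
xq_out, xk_out = apply_rope(xq, xk, pe)
assert torch.allclose(xq_out, xq) and torch.allclose(xk_out, xk)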
85
+ class EmbedND(nn.Module):
86
+ def __init__(self, dim: int, theta: int, axes_dim: list[int]):
87
+ super().__init__()
88
+ self.dim = dim
89
+ self.theta = theta
90
+ self.axes_dim = axes_dim
91
+
92
+ def forward(self, ids: Tensor) -> Tensor:
93
+ n_axes = ids.shape[-1]
94
+ emb = torch.cat(
95
+ [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
96
+ dim=-3,
97
+ )
98
+
99
+ return emb.unsqueeze(1)
100
+
101
+
102
+ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
103
+ """
104
+ Create sinusoidal timestep embeddings.
105
+ :param t: a 1-D Tensor of N indices, one per batch element.
106
+ These may be fractional.
107
+ :param dim: the dimension of the output.
108
+ :param max_period: controls the minimum frequency of the embeddings.
109
+ :return: an (N, D) Tensor of positional embeddings.
110
+ """
111
+ t = time_factor * t
112
+ half = dim // 2
113
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
114
+ t.device
115
+ )
116
+
117
+ args = t[:, None].float() * freqs[None]
118
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
119
+ if dim % 2:
120
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
121
+ if torch.is_floating_point(t):
122
+ embedding = embedding.to(t)
123
+ return embedding
124
+
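A short usage sketch for `timestep_embedding` (assumes the function above is in scope): fractional timesteps are scaled by `time_factor` and mapped to concatenated cosine/sine features:

import torch

t = torch.tensor([0.25, 0.5, 1.0])          # one fractional timestep per batch element
emb = timestep_embedding(t, 256)            # -> shape (3, 256), [cos | sin] halves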
125
+
126
+ class MLPEmbedder(nn.Module):
127
+ def __init__(self, in_dim: int, hidden_dim: int):
128
+ super().__init__()
129
+ self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
130
+ self.silu = nn.SiLU()
131
+ self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
132
+
133
+ def forward(self, x: Tensor) -> Tensor:
134
+ return self.out_layer(self.silu(self.in_layer(x)))
135
+
136
+
137
+ class RMSNorm(torch.nn.Module):
138
+ def __init__(self, dim: int):
139
+ super().__init__()
140
+ self.scale = nn.Parameter(torch.ones(dim))
141
+
142
+ def forward(self, x: Tensor):
143
+ x_dtype = x.dtype
144
+ x = x.float()
145
+ rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
146
+ return (x * rrms).to(dtype=x_dtype) * self.scale
147
+
148
+
149
+ class QKNorm(torch.nn.Module):
150
+ def __init__(self, dim: int):
151
+ super().__init__()
152
+ self.query_norm = RMSNorm(dim)
153
+ self.key_norm = RMSNorm(dim)
154
+
155
+ def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
156
+ q = self.query_norm(q)
157
+ k = self.key_norm(k)
158
+ return q.to(v), k.to(v)
159
+
160
+
161
+ class SelfAttention(nn.Module):
162
+ def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
163
+ super().__init__()
164
+ self.num_heads = num_heads
165
+ head_dim = dim // num_heads
166
+
167
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
168
+ self.norm = QKNorm(head_dim)
169
+ self.proj = nn.Linear(dim, dim)
170
+
171
+ def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
172
+ qkv = self.qkv(x)
173
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
+ q, k = self.norm(q, k, v)
175
+ x = attention(q, k, v, pe=pe, mask=mask)
176
+ x = self.proj(x)
177
+ return x
178
+
179
+ class CrossAttention(nn.Module):
180
+ def __init__(self, dim: int, context_dim: int, num_heads: int = 8, qkv_bias: bool = False):
181
+ super().__init__()
182
+ self.num_heads = num_heads
183
+ head_dim = dim // num_heads
184
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
185
+ self.kv = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
186
+ self.norm = QKNorm(head_dim)
187
+ self.proj = nn.Linear(dim, dim)
188
+
189
+ def forward(self, x: Tensor, context: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
190
+ q = rearrange(self.q(x), "B L (H D) -> B H L D", H=self.num_heads)
191
+ k, v = rearrange(self.kv(context), "B S (K H D) -> K B H S D", K=2, H=self.num_heads)
192
+ q, k = self.norm(q, k, v)
193
+ x = attention(q, k, v, pe=pe, mask=mask)
194
+ x = self.proj(x)
195
+ return x
196
+
197
+
198
+ @dataclass
199
+ class ModulationOut:
200
+ shift: Tensor
201
+ scale: Tensor
202
+ gate: Tensor
203
+
204
+
205
+ class Modulation(nn.Module):
206
+ def __init__(self, dim: int, double: bool):
207
+ super().__init__()
208
+ self.is_double = double
209
+ self.multiplier = 6 if double else 3
210
+ self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
211
+
212
+ def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
213
+ out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
214
+
215
+ return (
216
+ ModulationOut(*out[:3]),
217
+ ModulationOut(*out[3:]) if self.is_double else None,
218
+ )
219
+
220
+
221
+ class DoubleStreamBlock(nn.Module):
222
+ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, backend = 'pytorch'):
223
+ super().__init__()
224
+
225
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
226
+ self.num_heads = num_heads
227
+ self.hidden_size = hidden_size
228
+ self.img_mod = Modulation(hidden_size, double=True)
229
+ self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
230
+ self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
231
+
232
+ self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
233
+ self.img_mlp = nn.Sequential(
234
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
235
+ nn.GELU(approximate="tanh"),
236
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
237
+ )
238
+
239
+ self.backend = backend
240
+
241
+ self.txt_mod = Modulation(hidden_size, double=True)
242
+ self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
243
+ self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
244
+
245
+ self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
+ self.txt_mlp = nn.Sequential(
247
+ nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
248
+ nn.GELU(approximate="tanh"),
249
+ nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
250
+ )
251
+
252
+
253
+
254
+
255
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None, txt_length = None):
256
+ img_mod1, img_mod2 = self.img_mod(vec)
257
+ txt_mod1, txt_mod2 = self.txt_mod(vec)
258
+
259
+ txt, img = x[:, :txt_length], x[:, txt_length:]
260
+
261
+ # prepare image for attention
262
+ img_modulated = self.img_norm1(img)
263
+ img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
264
+ img_qkv = self.img_attn.qkv(img_modulated)
265
+ img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
266
+ img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
267
+ # prepare txt for attention
268
+ txt_modulated = self.txt_norm1(txt)
269
+ txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
270
+ txt_qkv = self.txt_attn.qkv(txt_modulated)
271
+ txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
272
+ txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
273
+
274
+ # run actual attention
275
+ q = torch.cat((txt_q, img_q), dim=2)
276
+ k = torch.cat((txt_k, img_k), dim=2)
277
+ v = torch.cat((txt_v, img_v), dim=2)
278
+ if mask is not None:
279
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
280
+ attn = attention(q, k, v, pe=pe, mask = mask, backend = self.backend)
281
+ txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
282
+
283
+ # calculate the img blocks
284
+ img = img + img_mod1.gate * self.img_attn.proj(img_attn)
285
+ img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
286
+
287
+ # calculate the txt blocks
288
+ txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
289
+ txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
290
+ x = torch.cat((txt, img), 1)
291
+ return x
292
+
293
+
294
+ class SingleStreamBlock(nn.Module):
295
+ """
296
+ A DiT block with parallel linear layers as described in
297
+ https://arxiv.org/abs/2302.05442 and adapted modulation interface.
298
+ """
299
+
300
+ def __init__(
301
+ self,
302
+ hidden_size: int,
303
+ num_heads: int,
304
+ mlp_ratio: float = 4.0,
305
+ qk_scale: float | None = None,
306
+ backend='pytorch'
307
+ ):
308
+ super().__init__()
309
+ self.hidden_dim = hidden_size
310
+ self.num_heads = num_heads
311
+ head_dim = hidden_size // num_heads
312
+ self.scale = qk_scale or head_dim**-0.5
313
+
314
+ self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
315
+ # qkv and mlp_in
316
+ self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
317
+ # proj and mlp_out
318
+ self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
319
+
320
+ self.norm = QKNorm(head_dim)
321
+
322
+ self.hidden_size = hidden_size
323
+ self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
324
+
325
+ self.mlp_act = nn.GELU(approximate="tanh")
326
+ self.modulation = Modulation(hidden_size, double=False)
327
+ self.backend = backend
328
+
329
+ def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None) -> Tensor:
330
+ mod, _ = self.modulation(vec)
331
+ x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
332
+ qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
333
+
334
+ q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
335
+ q, k = self.norm(q, k, v)
336
+ if mask is not None:
337
+ mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
338
+ # compute attention
339
+ attn = attention(q, k, v, pe=pe, mask = mask, backend=self.backend)
340
+ # compute activation in mlp stream, cat again and run second linear layer
341
+ output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
342
+ return x + mod.gate * output
343
+
344
+
345
+ class LastLayer(nn.Module):
346
+ def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
347
+ super().__init__()
348
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
349
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
350
+ self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
351
+
352
+ def forward(self, x: Tensor, vec: Tensor) -> Tensor:
353
+ shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
354
+ x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
355
+ x = self.linear(x)
356
+ return x
utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Alibaba, Inc. and its affiliates.
2
+ import torch
3
+ import torchvision.transforms as T
4
+ from PIL import Image
5
+ from torchvision.transforms.functional import InterpolationMode
6
+
7
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
8
+ IMAGENET_STD = (0.229, 0.224, 0.225)
9
+
10
+
11
+ def build_transform(input_size):
12
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
13
+ transform = T.Compose([
14
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
15
+ T.Resize((input_size, input_size),
16
+ interpolation=InterpolationMode.BICUBIC),
17
+ T.ToTensor(),
18
+ T.Normalize(mean=MEAN, std=STD)
19
+ ])
20
+ return transform
21
+
22
+
23
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
24
+ image_size):
25
+ best_ratio_diff = float('inf')
26
+ best_ratio = (1, 1)
27
+ area = width * height
28
+ for ratio in target_ratios:
29
+ target_aspect_ratio = ratio[0] / ratio[1]
30
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
31
+ if ratio_diff < best_ratio_diff:
32
+ best_ratio_diff = ratio_diff
33
+ best_ratio = ratio
34
+ elif ratio_diff == best_ratio_diff:
35
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
36
+ best_ratio = ratio
37
+ return best_ratio
38
+
39
+
40
+ def dynamic_preprocess(image,
41
+ min_num=1,
42
+ max_num=12,
43
+ image_size=448,
44
+ use_thumbnail=False):
45
+ orig_width, orig_height = image.size
46
+ aspect_ratio = orig_width / orig_height
47
+
48
+ # calculate the existing image aspect ratio
49
+ target_ratios = set((i, j) for n in range(min_num, max_num + 1)
50
+ for i in range(1, n + 1) for j in range(1, n + 1)
51
+ if i * j <= max_num and i * j >= min_num)
52
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
53
+
54
+ # find the closest aspect ratio to the target
55
+ target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
56
+ target_ratios, orig_width,
57
+ orig_height, image_size)
58
+
59
+ # calculate the target width and height
60
+ target_width = image_size * target_aspect_ratio[0]
61
+ target_height = image_size * target_aspect_ratio[1]
62
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
63
+
64
+ # resize the image
65
+ resized_img = image.resize((target_width, target_height))
66
+ processed_images = []
67
+ for i in range(blocks):
68
+ box = ((i % (target_width // image_size)) * image_size,
69
+ (i // (target_width // image_size)) * image_size,
70
+ ((i % (target_width // image_size)) + 1) * image_size,
71
+ ((i // (target_width // image_size)) + 1) * image_size)
72
+ # split the image
73
+ split_img = resized_img.crop(box)
74
+ processed_images.append(split_img)
75
+ assert len(processed_images) == blocks
76
+ if use_thumbnail and len(processed_images) != 1:
77
+ thumbnail_img = image.resize((image_size, image_size))
78
+ processed_images.append(thumbnail_img)
79
+ return processed_images
80
+
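A brief usage sketch for `dynamic_preprocess` above: a 1024x512 input is matched to the 2:1 tiling, yielding two 448x448 crops plus a thumbnail when use_thumbnail=True:

from PIL import Image

img = Image.new('RGB', (1024, 512))
tiles = dynamic_preprocess(img, image_size=448, use_thumbnail=True, max_num=12)
print(len(tiles))   # expected: 3 (two tiles + one thumbnail)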
81
+
82
+ def load_image(image_file, input_size=448, max_num=12):
83
+ if isinstance(image_file, str):
84
+ image = Image.open(image_file).convert('RGB')
85
+ else:
86
+ image = image_file
87
+ transform = build_transform(input_size=input_size)
88
+ images = dynamic_preprocess(image,
89
+ image_size=input_size,
90
+ use_thumbnail=True,
91
+ max_num=max_num)
92
+ pixel_values = [transform(image) for image in images]
93
+ pixel_values = torch.stack(pixel_values)
94
+ return pixel_values
95
+