chaojiemao commited on
Commit
83266af
·
1 Parent(s): 9cee51c

modify old files

Browse files
__init__.py DELETED
@@ -1 +0,0 @@
1
- from . import models
 
 
ace_flux_inference.py DELETED
@@ -1,329 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import math
4
- import os
5
- import random
6
- import numpy as np
7
- import torch
8
- import torch.nn.functional as F
9
- from PIL import Image
10
- import torchvision.transforms as T
11
- from scepter.modules.model.registry import DIFFUSIONS, BACKBONES
12
- import torchvision.transforms.functional as TF
13
- from scepter.modules.model.utils.basic_utils import check_list_of_list
14
- from scepter.modules.model.utils.basic_utils import \
15
- pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
16
- from scepter.modules.model.utils.basic_utils import (
17
- to_device, unpack_tensor_into_imagelist)
18
- from scepter.modules.utils.distribute import we
19
- from scepter.modules.utils.file_system import FS
20
- from scepter.modules.utils.logger import get_logger
21
- from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
22
-
23
- def process_edit_image(images,
24
- masks,
25
- tasks):
26
-
27
- if not isinstance(images, list):
28
- images = [images]
29
- if not isinstance(masks, list):
30
- masks = [masks]
31
- if not isinstance(tasks, list):
32
- tasks = [tasks]
33
-
34
- img_tensors = []
35
- mask_tensors = []
36
- for img, mask, task in zip(images, masks, tasks):
37
- if mask is None or mask == '':
38
- mask = Image.new('L', img.size, 0)
39
- img = TF.center_crop(img, [512, 512])
40
- mask = TF.center_crop(mask, [512, 512])
41
-
42
- mask = np.asarray(mask)
43
- mask = np.where(mask > 128, 1, 0)
44
- mask = mask.astype(
45
- np.float32) if np.any(mask) else np.ones_like(mask).astype(
46
- np.float32)
47
-
48
- img_tensor = TF.to_tensor(img).to(we.device_id)
49
- img_tensor = TF.normalize(img_tensor,
50
- mean=[0.5, 0.5, 0.5],
51
- std=[0.5, 0.5, 0.5])
52
- mask_tensor = TF.to_tensor(mask).to(we.device_id)
53
- if task in ['inpainting', 'Try On', 'Inpainting']:
54
- mask_indicator = mask_tensor.repeat(3, 1, 1)
55
- img_tensor[mask_indicator == 1] = -1.0
56
- img_tensors.append(img_tensor)
57
- mask_tensors.append(mask_tensor)
58
- return img_tensors, mask_tensors
59
-
60
- class FluxACEInference(DiffusionInference):
61
-
62
- def __init__(self, logger=None):
63
- if logger is None:
64
- logger = get_logger(name='scepter')
65
- self.logger = logger
66
- self.loaded_model = {}
67
- self.loaded_model_name = [
68
- 'diffusion_model', 'first_stage_model', 'cond_stage_model', 'ref_cond_stage_model'
69
- ]
70
-
71
- def init_from_cfg(self, cfg):
72
- self.name = cfg.NAME
73
- self.is_default = cfg.get('IS_DEFAULT', False)
74
- self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
75
- module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
76
- assert cfg.have('MODEL')
77
- self.size_factor = cfg.get('SIZE_FACTOR', 8)
78
- self.diffusion_model = self.infer_model(
79
- cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
80
- 'DIFFUSION_MODEL',
81
- None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
82
- self.first_stage_model = self.infer_model(
83
- cfg.MODEL.FIRST_STAGE_MODEL,
84
- module_paras.get(
85
- 'FIRST_STAGE_MODEL',
86
- None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
87
- self.cond_stage_model = self.infer_model(
88
- cfg.MODEL.COND_STAGE_MODEL,
89
- module_paras.get(
90
- 'COND_STAGE_MODEL',
91
- None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
92
-
93
- self.ref_cond_stage_model = self.infer_model(
94
- cfg.MODEL.REF_COND_STAGE_MODEL,
95
- module_paras.get(
96
- 'REF_COND_STAGE_MODEL',
97
- None)) if cfg.MODEL.have('REF_COND_STAGE_MODEL') else None
98
-
99
- self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
100
- logger=self.logger)
101
- self.interpolate_func = lambda x: (F.interpolate(
102
- x.unsqueeze(0),
103
- scale_factor=1 / self.size_factor,
104
- mode='nearest-exact') if x is not None else None)
105
-
106
- self.max_seq_length = cfg.get("MAX_SEQ_LENGTH", 4096)
107
- if not self.use_dynamic_model:
108
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
109
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
110
- if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
111
- with torch.device("meta"):
112
- pretrained_model = self.diffusion_model['cfg'].PRETRAINED_MODEL
113
- self.diffusion_model['cfg'].PRETRAINED_MODEL = None
114
- diffusers_lora = self.diffusion_model['cfg'].get("DIFFUSERS_LORA_MODEL", None)
115
- self.diffusion_model['cfg'].DIFFUSERS_LORA_MODEL = None
116
- swift_lora = self.diffusion_model['cfg'].get("SWIFT_LORA_MODEL", None)
117
- self.diffusion_model['cfg'].SWIFT_LORA_MODEL = None
118
- pretrain_adapter = self.diffusion_model['cfg'].get("PRETRAIN_ADAPTER", None)
119
- self.diffusion_model['cfg'].PRETRAIN_ADAPTER = None
120
- blackforest_lora = self.diffusion_model['cfg'].get("BLACKFOREST_LORA_MODEL", None)
121
- self.diffusion_model['cfg'].BLACKFOREST_LORA_MODEL = None
122
- self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
123
- # self.dynamic_load(self.diffusion_model, 'diffusion_model')
124
- self.diffusion_model['model'].lora_model = diffusers_lora
125
- self.diffusion_model['model'].swift_lora_model = swift_lora
126
- self.diffusion_model['model'].pretrain_adapter = pretrain_adapter
127
- self.diffusion_model['model'].blackforest_lora_model = blackforest_lora
128
- self.diffusion_model['model'].load_pretrained_model(pretrained_model)
129
- self.diffusion_model['device'] = we.device_id
130
-
131
- def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
132
- c, H, W = image.shape
133
- scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
134
- rH = int(H * scale) // 16 * 16 # ensure divisible by self.d
135
- rW = int(W * scale) // 16 * 16
136
- image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
137
- return image
138
-
139
-
140
- @torch.no_grad()
141
- def encode_first_stage(self, x, **kwargs):
142
- _, dtype = self.get_function_info(self.first_stage_model, 'encode')
143
- with torch.autocast('cuda',
144
- enabled=dtype in ('float16', 'bfloat16'),
145
- dtype=getattr(torch, dtype)):
146
- def run_one_image(u):
147
- zu = get_model(self.first_stage_model).encode(u)
148
- if isinstance(zu, (tuple, list)):
149
- zu = zu[0]
150
- return zu
151
-
152
- z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
153
- return z
154
-
155
-
156
- @torch.no_grad()
157
- def decode_first_stage(self, z):
158
- _, dtype = self.get_function_info(self.first_stage_model, 'decode')
159
- with torch.autocast('cuda',
160
- enabled=dtype in ('float16', 'bfloat16'),
161
- dtype=getattr(torch, dtype)):
162
- return [get_model(self.first_stage_model).decode(zu) for zu in z]
163
-
164
- def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
165
- noise = torch.randn(
166
- num_samples,
167
- 16,
168
- # allow for packing
169
- 2 * math.ceil(h / 16),
170
- 2 * math.ceil(w / 16),
171
- device="cpu",
172
- dtype=dtype,
173
- generator=torch.Generator().manual_seed(seed),
174
- ).to(device)
175
- return noise
176
-
177
- @torch.no_grad()
178
- def __call__(self,
179
- image=None,
180
- mask=None,
181
- prompt='',
182
- task=None,
183
- negative_prompt='',
184
- output_height=1024,
185
- output_width=1024,
186
- sampler='flow_euler',
187
- sample_steps=20,
188
- guide_scale=3.5,
189
- seed=-1,
190
- history_io=None,
191
- tar_index=0,
192
- # align=0,
193
- **kwargs):
194
- input_image, input_mask = image, mask
195
- seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
196
- if input_image is not None:
197
- # assert isinstance(input_image, list) and isinstance(input_mask, list)
198
- if task is None:
199
- task = [''] * len(input_image)
200
- if not isinstance(prompt, list):
201
- prompt = [prompt] * len(input_image)
202
- prompt = [
203
- pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
204
- for i, pp in enumerate(prompt)
205
- ]
206
- edit_image, edit_image_mask = process_edit_image(
207
- input_image, input_mask, task)
208
- image = torch.zeros(
209
- size=[3, int(output_height),
210
- int(output_width)])
211
- image_mask = torch.ones(
212
- size=[1, int(output_height),
213
- int(output_width)])
214
- edit_image, edit_image_mask = [edit_image], [edit_image_mask]
215
- else:
216
- edit_image = edit_image_mask = [[]]
217
- image = torch.zeros(
218
- size=[3, int(output_height),
219
- int(output_width)])
220
- image_mask = torch.ones(
221
- size=[1, int(output_height),
222
- int(output_width)])
223
- if not isinstance(prompt, list):
224
- prompt = [prompt]
225
- align = 0
226
- image, image_mask, prompt = [image], [image_mask], [prompt],
227
- align = [align for p in prompt] if isinstance(align, int) else align
228
-
229
- assert check_list_of_list(prompt) and check_list_of_list(
230
- edit_image) and check_list_of_list(edit_image_mask)
231
- # negative prompt is not used
232
- image = to_device(image)
233
- ctx = {}
234
- # Get Noise Shape
235
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
236
- x = self.encode_first_stage(image)
237
- self.dynamic_unload(self.first_stage_model,
238
- 'first_stage_model',
239
- skip_loaded=not self.use_dynamic_model)
240
-
241
- g = torch.Generator(device=we.device_id).manual_seed(seed)
242
- noise = [
243
- torch.randn((1, 16, i.shape[2], i.shape[3]), device=we.device_id, dtype=torch.bfloat16).normal_(generator=g)
244
- for i in x
245
- ]
246
- # import pdb;pdb.set_trace()
247
- noise, x_shapes = pack_imagelist_into_tensor(noise)
248
- ctx['x_shapes'] = x_shapes
249
- ctx['align'] = align
250
-
251
- image_mask = to_device(image_mask, strict=False)
252
- cond_mask = [self.interpolate_func(i) for i in image_mask
253
- ] if image_mask is not None else [None] * len(image)
254
- ctx['x_mask'] = cond_mask
255
- # Encode Prompt
256
- instruction_prompt = [[pp[-1]] if "{image}" in pp[-1] else ["{image} " + pp[-1]] for pp in prompt]
257
- self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
258
- function_name, dtype = self.get_function_info(self.cond_stage_model)
259
- cont = getattr(get_model(self.cond_stage_model), function_name)(instruction_prompt)
260
- cont["context"] = [ct[-1] for ct in cont["context"]]
261
- cont["y"] = [ct[-1] for ct in cont["y"]]
262
- self.dynamic_unload(self.cond_stage_model,
263
- 'cond_stage_model',
264
- skip_loaded=not self.use_dynamic_model)
265
- ctx.update(cont)
266
-
267
- # Encode Edit Images
268
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
269
- edit_image = [to_device(i, strict=False) for i in edit_image]
270
- edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
271
- e_img, e_mask = [], []
272
- for u, m in zip(edit_image, edit_image_mask):
273
- if u is None:
274
- continue
275
- if m is None:
276
- m = [None] * len(u)
277
- e_img.append(self.encode_first_stage(u, **kwargs))
278
- e_mask.append([self.interpolate_func(i) for i in m])
279
- self.dynamic_unload(self.first_stage_model,
280
- 'first_stage_model',
281
- skip_loaded=not self.use_dynamic_model)
282
- ctx['edit'] = e_img
283
- ctx['edit_mask'] = e_mask
284
- # Encode Ref Images
285
- if guide_scale is not None:
286
- guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype)
287
- else:
288
- guide_scale = None
289
-
290
- # Diffusion Process
291
- self.dynamic_load(self.diffusion_model, 'diffusion_model')
292
- function_name, dtype = self.get_function_info(self.diffusion_model)
293
- with torch.autocast('cuda',
294
- enabled=dtype in ('float16', 'bfloat16'),
295
- dtype=getattr(torch, dtype)):
296
- latent = self.diffusion.sample(
297
- noise=noise,
298
- sampler=sampler,
299
- model=get_model(self.diffusion_model),
300
- model_kwargs={
301
- "cond": ctx, "guidance": guide_scale, "gc_seg": -1
302
- },
303
- steps=sample_steps,
304
- show_progress=True,
305
- guide_scale=guide_scale,
306
- return_intermediate=None,
307
- reverse_scale=-1,
308
- **kwargs).float()
309
- if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
310
- 'diffusion_model',
311
- skip_loaded=not self.use_dynamic_model)
312
-
313
- # Decode to Pixel Space
314
- self.dynamic_load(self.first_stage_model, 'first_stage_model')
315
- samples = unpack_tensor_into_imagelist(latent, x_shapes)
316
- x_samples = self.decode_first_stage(samples)
317
- self.dynamic_unload(self.first_stage_model,
318
- 'first_stage_model',
319
- skip_loaded=not self.use_dynamic_model)
320
- x_samples = [x.squeeze(0) for x in x_samples]
321
-
322
- imgs = [
323
- torch.clamp((x_i.float() + 1.0) / 2.0,
324
- min=0.0,
325
- max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
326
- for x_i in x_samples
327
- ]
328
- imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
329
- return imgs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py DELETED
@@ -1,1428 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import base64
4
- import copy
5
- import glob
6
- import io
7
- import os, csv, sys
8
- import random
9
- import re
10
- import shlex
11
- import string
12
- import subprocess
13
- import threading
14
- import spaces
15
-
16
- subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
17
- #subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'),
18
- # env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
19
- subprocess.run(shlex.split('pip install scepter'))
20
-
21
- import cv2
22
- import gradio as gr
23
- import numpy as np
24
- import torch
25
- import transformers
26
- from PIL import Image
27
- from transformers import AutoModel, AutoTokenizer
28
- from ace_flux_inference import FluxACEInference
29
- from scepter.modules.utils.config import Config
30
- from scepter.modules.utils.directory import get_md5
31
- from scepter.modules.utils.file_system import FS
32
- from scepter.studio.utils.env import init_env
33
- from importlib.metadata import version
34
-
35
- from example import get_examples
36
- from utils import load_image
37
- from huggingface_hub import login
38
-
39
- login(token=os.environ.get("HF_TOKEN", ""))
40
-
41
- csv.field_size_limit(sys.maxsize)
42
-
43
- refresh_sty = '\U0001f504' # 🔄
44
- clear_sty = '\U0001f5d1' # 🗑️
45
- upload_sty = '\U0001f5bc' # 🖼️
46
- sync_sty = '\U0001f4be' # 💾
47
- chat_sty = '\U0001F4AC' # 💬
48
- video_sty = '\U0001f3a5' # 🎥
49
-
50
- lock = threading.Lock()
51
- inference_dict = {
52
- "ACE_FLUX": FluxACEInference,
53
- }
54
-
55
-
56
- class ChatBotUI(object):
57
- def __init__(self,
58
- cfg_general_file,
59
- is_debug=False,
60
- language='en',
61
- root_work_dir='./'):
62
- try:
63
- from diffusers import CogVideoXImageToVideoPipeline
64
- from diffusers.utils import export_to_video
65
- except Exception as e:
66
- print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
67
-
68
- cfg = Config(cfg_file=cfg_general_file)
69
- if cfg.have("FILE_SYSTEM"):
70
- for file_sys in cfg.FILE_SYSTEM:
71
- fs_prefix = FS.init_fs_client(file_sys)
72
- else:
73
- fs_prefix = FS.init_fs_client(cfg)
74
- cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
75
- if not FS.exists(cfg.WORK_DIR):
76
- FS.make_dir(cfg.WORK_DIR)
77
- cfg = init_env(cfg)
78
- self.cache_dir = cfg.WORK_DIR
79
- self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
80
- self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
81
- self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
82
- '*.yaml'))
83
- self.model_choices = dict()
84
- self.default_model_name = ''
85
- for i in self.model_yamls:
86
- model_cfg = Config(load=True, cfg_file=i)
87
- model_name = model_cfg.NAME
88
- if model_cfg.IS_DEFAULT: self.default_model_name = model_name
89
- self.model_choices[model_name] = model_cfg
90
- print('Models: ', self.model_choices.keys())
91
- local_folder = FS.get_dir_to_local_dir("hf://black-forest-labs/FLUX.1-dev")
92
- subprocess.run(shlex.split(f'rm -rf {local_folder}/transformer'))
93
- subprocess.run(shlex.split(f'rm -rf {local_folder}/vae'))
94
- subprocess.run(shlex.split(f'rm -rf {local_folder}/flux1-dev.safetensors'))
95
-
96
- assert len(self.model_choices) > 0
97
- if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
98
- self.model_name = self.default_model_name
99
- pipe_cfg = self.model_choices[self.default_model_name]
100
- infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
101
- self.pipe = inference_dict[infer_name]()
102
- self.pipe.init_from_cfg(pipe_cfg)
103
- self.max_msgs = 20
104
- self.enable_i2v = cfg.get('ENABLE_I2V', False)
105
- self.gradio_version = version('gradio')
106
-
107
- if self.enable_i2v:
108
- self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
109
- self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
110
- if self.i2v_model_name == 'CogVideoX-5b-I2V':
111
- with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
112
- self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
113
- local_dir, torch_dtype=torch.bfloat16).cuda()
114
- else:
115
- raise NotImplementedError
116
-
117
- with FS.get_dir_to_local_dir(
118
- cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
119
- self.captioner = AutoModel.from_pretrained(
120
- local_dir,
121
- torch_dtype=torch.bfloat16,
122
- low_cpu_mem_usage=True,
123
- use_flash_attn=True,
124
- trust_remote_code=True).eval().cuda()
125
- self.llm_tokenizer = AutoTokenizer.from_pretrained(
126
- local_dir, trust_remote_code=True, use_fast=False)
127
- self.llm_generation_config = dict(max_new_tokens=1024,
128
- do_sample=True)
129
- self.llm_prompt = cfg.LLM.PROMPT
130
- self.llm_max_num = 2
131
-
132
- with FS.get_dir_to_local_dir(
133
- cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
134
- self.enhancer = transformers.pipeline(
135
- 'text-generation',
136
- model=local_dir,
137
- model_kwargs={'torch_dtype': torch.bfloat16},
138
- device_map='auto',
139
- )
140
-
141
- sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
142
- For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
143
- There are a few rules to follow:
144
- You will only ever output a single video description per user request.
145
- When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
146
- Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
147
- Video descriptions must have the same num of words as examples below. Extra words will be ignored.
148
- """
149
- self.enhance_ctx = [
150
- {
151
- 'role': 'system',
152
- 'content': sys_prompt
153
- },
154
- {
155
- 'role':
156
- 'user',
157
- 'content':
158
- 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
159
- },
160
- {
161
- 'role':
162
- 'assistant',
163
- 'content':
164
- "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
165
- },
166
- {
167
- 'role':
168
- 'user',
169
- 'content':
170
- 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
171
- },
172
- {
173
- 'role':
174
- 'assistant',
175
- 'content':
176
- "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
177
- },
178
- {
179
- 'role':
180
- 'user',
181
- 'content':
182
- 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
183
- },
184
- {
185
- 'role':
186
- 'assistant',
187
- 'content':
188
- 'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
189
- },
190
- ]
191
-
192
- def create_ui(self):
193
-
194
- css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
195
- with gr.Blocks(css=css,
196
- title='Chatbot',
197
- head='Chatbot',
198
- analytics_enabled=False):
199
- self.history = gr.State(value=[])
200
- self.images = gr.State(value={})
201
- self.history_result = gr.State(value={})
202
- self.retry_msg = gr.State(value='')
203
- with gr.Group():
204
- self.ui_mode = gr.State(value='legacy')
205
- with gr.Row(equal_height=True, visible=False) as self.chat_group:
206
- with gr.Column(visible=True) as self.chat_page:
207
- self.chatbot = gr.Chatbot(
208
- height=600,
209
- value=[],
210
- bubble_full_width=False,
211
- show_copy_button=True,
212
- container=False,
213
- placeholder='<strong>Chat Box</strong>')
214
- with gr.Row():
215
- self.clear_btn = gr.Button(clear_sty +
216
- ' Clear Chat',
217
- size='sm')
218
-
219
- with gr.Column(visible=False) as self.editor_page:
220
- with gr.Tabs(visible=False) as self.upload_tabs:
221
- with gr.Tab(id='ImageUploader',
222
- label='Image Uploader',
223
- visible=True) as self.upload_tab:
224
- self.image_uploader = gr.Image(
225
- height=550,
226
- interactive=True,
227
- type='pil',
228
- image_mode='RGB',
229
- sources=['upload'],
230
- elem_id='image_uploader',
231
- format='png')
232
- with gr.Row():
233
- self.sub_btn_1 = gr.Button(
234
- value='Submit',
235
- elem_id='upload_submit')
236
- self.ext_btn_1 = gr.Button(value='Exit')
237
- with gr.Tabs(visible=False) as self.edit_tabs:
238
- with gr.Tab(id='ImageEditor',
239
- label='Image Editor') as self.edit_tab:
240
- self.mask_type = gr.Dropdown(
241
- label='Mask Type',
242
- choices=[
243
- 'Background', 'Composite',
244
- 'Outpainting'
245
- ],
246
- value='Background')
247
- self.mask_type_info = gr.HTML(
248
- value=
249
- "<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
250
- )
251
- with gr.Accordion(
252
- label='Outpainting Setting',
253
- open=True,
254
- visible=False) as self.outpaint_tab:
255
- with gr.Row(variant='panel'):
256
- self.top_ext = gr.Slider(
257
- show_label=True,
258
- label='Top Extend Ratio',
259
- minimum=0.0,
260
- maximum=2.0,
261
- step=0.1,
262
- value=0.25)
263
- self.bottom_ext = gr.Slider(
264
- show_label=True,
265
- label='Bottom Extend Ratio',
266
- minimum=0.0,
267
- maximum=2.0,
268
- step=0.1,
269
- value=0.25)
270
- with gr.Row(variant='panel'):
271
- self.left_ext = gr.Slider(
272
- show_label=True,
273
- label='Left Extend Ratio',
274
- minimum=0.0,
275
- maximum=2.0,
276
- step=0.1,
277
- value=0.25)
278
- self.right_ext = gr.Slider(
279
- show_label=True,
280
- label='Right Extend Ratio',
281
- minimum=0.0,
282
- maximum=2.0,
283
- step=0.1,
284
- value=0.25)
285
- with gr.Row(variant='panel'):
286
- self.img_pad_btn = gr.Button(
287
- value='Pad Image')
288
-
289
- self.image_editor = gr.ImageMask(
290
- value=None,
291
- sources=[],
292
- layers=False,
293
- label='Edit Image',
294
- elem_id='image_editor',
295
- format='png')
296
- with gr.Row():
297
- self.sub_btn_2 = gr.Button(
298
- value='Submit', elem_id='edit_submit')
299
- self.ext_btn_2 = gr.Button(value='Exit')
300
-
301
- with gr.Tab(id='ImageViewer',
302
- label='Image Viewer') as self.image_view_tab:
303
- if self.gradio_version >= '5.0.0':
304
- self.image_viewer = gr.Image(
305
- label='Image',
306
- type='pil',
307
- show_download_button=True,
308
- elem_id='image_viewer')
309
- else:
310
- try:
311
- from gradio_imageslider import ImageSlider
312
- except Exception as e:
313
- print(f"Import gradio_imageslider failed, please install.")
314
- self.image_viewer = ImageSlider(
315
- label='Image',
316
- type='pil',
317
- show_download_button=True,
318
- elem_id='image_viewer')
319
-
320
- self.ext_btn_3 = gr.Button(value='Exit')
321
-
322
- with gr.Tab(id='VideoViewer',
323
- label='Video Viewer',
324
- visible=False) as self.video_view_tab:
325
- self.video_viewer = gr.Video(
326
- label='Video',
327
- interactive=False,
328
- sources=[],
329
- format='mp4',
330
- show_download_button=True,
331
- elem_id='video_viewer',
332
- loop=True,
333
- autoplay=True)
334
-
335
- self.ext_btn_4 = gr.Button(value='Exit')
336
-
337
- with gr.Row(equal_height=True, visible=True) as self.legacy_group:
338
- with gr.Column():
339
- self.legacy_image_uploader = gr.Image(
340
- height=550,
341
- interactive=True,
342
- type='pil',
343
- image_mode='RGB',
344
- elem_id='legacy_image_uploader',
345
- format='png')
346
- with gr.Column():
347
- self.legacy_image_viewer = gr.Image(
348
- label='Image',
349
- height=550,
350
- type='pil',
351
- interactive=False,
352
- show_download_button=True,
353
- elem_id='image_viewer')
354
-
355
- with gr.Accordion(label='Setting', open=False):
356
- with gr.Row():
357
- self.model_name_dd = gr.Dropdown(
358
- choices=self.model_choices,
359
- value=self.default_model_name,
360
- label='Model Version')
361
-
362
- with gr.Row():
363
- self.negative_prompt = gr.Textbox(
364
- value='',
365
- placeholder=
366
- 'Negative prompt used for Classifier-Free Guidance',
367
- label='Negative Prompt',
368
- container=False)
369
-
370
- with gr.Row():
371
- # REFINER_PROMPT
372
- self.refiner_prompt = gr.Textbox(
373
- value=self.pipe.input.get("refiner_prompt", ""),
374
- visible=self.pipe.input.get("refiner_prompt", None) is not None,
375
- placeholder=
376
- 'Prompt used for refiner',
377
- label='Refiner Prompt',
378
- container=False)
379
-
380
- with gr.Row():
381
- with gr.Column(scale=8, min_width=500):
382
- with gr.Row():
383
- self.step = gr.Slider(minimum=1,
384
- maximum=1000,
385
- value=self.pipe.input.get("sample_steps", 20),
386
- visible=self.pipe.input.get("sample_steps", None) is not None,
387
- label='Sample Step')
388
- self.cfg_scale = gr.Slider(
389
- minimum=1.0,
390
- maximum=20.0,
391
- value=self.pipe.input.get("guide_scale", 4.5),
392
- visible=self.pipe.input.get("guide_scale", None) is not None,
393
- label='Guidance Scale')
394
- self.rescale = gr.Slider(minimum=0.0,
395
- maximum=1.0,
396
- value=self.pipe.input.get("guide_rescale", 0.5),
397
- visible=self.pipe.input.get("guide_rescale", None) is not None,
398
- label='Rescale')
399
- self.refiner_scale = gr.Slider(minimum=-0.1,
400
- maximum=1.0,
401
- value=self.pipe.input.get("refiner_scale", -1),
402
- visible=self.pipe.input.get("refiner_scale",
403
- None) is not None,
404
- label='Refiner Scale')
405
- self.seed = gr.Slider(minimum=-1,
406
- maximum=10000000,
407
- value=-1,
408
- label='Seed')
409
- self.output_height = gr.Slider(
410
- minimum=256,
411
- maximum=1440,
412
- value=self.pipe.input.get("output_height", 1024),
413
- visible=self.pipe.input.get("output_height", None) is not None,
414
- label='Output Height')
415
- self.output_width = gr.Slider(
416
- minimum=256,
417
- maximum=1440,
418
- value=self.pipe.input.get("output_width", 1024),
419
- visible=self.pipe.input.get("output_width", None) is not None,
420
- label='Output Width')
421
- with gr.Column(scale=1, min_width=50):
422
- self.use_history = gr.Checkbox(value=False,
423
- label='Use History')
424
- self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
425
- visible=self.pipe.input.get("use_ace", None) is not None,
426
- label='Use ACE')
427
- self.video_auto = gr.Checkbox(
428
- value=False,
429
- label='Auto Gen Video',
430
- visible=self.enable_i2v)
431
-
432
- with gr.Row(variant='panel',
433
- equal_height=True,
434
- visible=self.enable_i2v):
435
- self.video_fps = gr.Slider(minimum=1,
436
- maximum=16,
437
- value=8,
438
- label='Video FPS',
439
- visible=True)
440
- self.video_frames = gr.Slider(minimum=8,
441
- maximum=49,
442
- value=49,
443
- label='Video Frame Num',
444
- visible=True)
445
- self.video_step = gr.Slider(minimum=1,
446
- maximum=1000,
447
- value=50,
448
- label='Video Sample Step',
449
- visible=True)
450
- self.video_cfg_scale = gr.Slider(
451
- minimum=1.0,
452
- maximum=20.0,
453
- value=6.0,
454
- label='Video Guidance Scale',
455
- visible=True)
456
- self.video_seed = gr.Slider(minimum=-1,
457
- maximum=10000000,
458
- value=-1,
459
- label='Video Seed',
460
- visible=True)
461
-
462
- with gr.Row():
463
- self.chatbot_inst = """
464
- **Instruction**:
465
- 1. Click 'Upload' button to upload one or more images as input images.
466
- 2. Enter '@' in the text box will exhibit all images in the gallery.
467
- 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
468
- 4. Compose the editing instruction for the selected image, incorporating image id '@xxxxxx' into your instruction.
469
- For example, you might say, "Change the girl's skirt in @123456 to blue." The '@xxxxx' token will facilitate the identification of the specific image, and will be automatically replaced by a special token '{image}' in the instruction. Furthermore, it is also possible to engage in text-to-image generation without any initial image input.
470
- 5. Once your instructions are prepared, please click the "Chat" button to view the edited result in the chat window.
471
- 6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
472
- 7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
473
- 8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
474
- """
475
-
476
- self.legacy_inst = """
477
- **Instruction**:
478
- 1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
479
- 2. Enter '@' in the text box will exhibit all images in the gallery.
480
- 3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
481
- 4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
482
- 5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
483
- 6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
484
- """
485
-
486
- self.instruction = gr.Markdown(value=self.legacy_inst)
487
-
488
- with gr.Row(variant='panel',
489
- equal_height=True,
490
- show_progress=False):
491
- with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
492
- self.upload_btn = gr.Button(value=upload_sty +
493
- ' Upload',
494
- variant='secondary')
495
- with gr.Column(scale=5, min_width=500):
496
- self.text = gr.Textbox(
497
- placeholder='Input "@" find history of image',
498
- label='Instruction',
499
- container=False)
500
- with gr.Column(scale=1, min_width=100):
501
- self.chat_btn = gr.Button(value='Generate',
502
- variant='primary')
503
- with gr.Column(scale=1, min_width=100):
504
- self.retry_btn = gr.Button(value=refresh_sty +
505
- ' Retry',
506
- variant='secondary')
507
- with gr.Column(scale=1, min_width=100):
508
- self.mode_checkbox = gr.Checkbox(
509
- value=False,
510
- label='ChatBot')
511
- with gr.Column(scale=(1 if self.enable_i2v else 0),
512
- min_width=0):
513
- self.video_gen_btn = gr.Button(value=video_sty +
514
- ' Gen Video',
515
- variant='secondary',
516
- visible=self.enable_i2v)
517
- with gr.Column(scale=(1 if self.enable_i2v else 0),
518
- min_width=0):
519
- self.extend_prompt = gr.Checkbox(
520
- value=True,
521
- label='Extend Prompt',
522
- visible=self.enable_i2v)
523
-
524
- with gr.Row():
525
- self.gallery = gr.Gallery(visible=False,
526
- label='History',
527
- columns=10,
528
- allow_preview=False,
529
- interactive=False)
530
-
531
- self.eg = gr.Column(visible=True)
532
-
533
- def set_callbacks(self, *args, **kwargs):
534
-
535
- ########################################
536
- # @spaces.GPU(duration=60)
537
- def change_model(model_name):
538
- if model_name not in self.model_choices:
539
- gr.Info('The provided model name is not a valid choice!')
540
- return model_name, gr.update(), gr.update()
541
-
542
- if model_name != self.model_name:
543
- lock.acquire()
544
- del self.pipe
545
- torch.cuda.empty_cache()
546
- torch.cuda.ipc_collect()
547
- pipe_cfg = self.model_choices[model_name]
548
- infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
549
- self.pipe = inference_dict[infer_name]()
550
- self.pipe.init_from_cfg(pipe_cfg)
551
- self.model_name = model_name
552
- lock.release()
553
-
554
- return (model_name, gr.update(), gr.update(),
555
- gr.Slider(
556
- value=self.pipe.input.get("sample_steps", 20),
557
- visible=self.pipe.input.get("sample_steps", None) is not None),
558
- gr.Slider(
559
- value=self.pipe.input.get("guide_scale", 4.5),
560
- visible=self.pipe.input.get("guide_scale", None) is not None),
561
- gr.Slider(
562
- value=self.pipe.input.get("guide_rescale", 0.5),
563
- visible=self.pipe.input.get("guide_rescale", None) is not None),
564
- gr.Slider(
565
- value=self.pipe.input.get("output_height", 1024),
566
- visible=self.pipe.input.get("output_height", None) is not None),
567
- gr.Slider(
568
- value=self.pipe.input.get("output_width", 1024),
569
- visible=self.pipe.input.get("output_width", None) is not None),
570
- gr.Textbox(
571
- value=self.pipe.input.get("refiner_prompt", ""),
572
- visible=self.pipe.input.get("refiner_prompt", None) is not None),
573
- gr.Slider(
574
- value=self.pipe.input.get("refiner_scale", -1),
575
- visible=self.pipe.input.get("refiner_scale", None) is not None
576
- ),
577
- gr.Checkbox(
578
- value=self.pipe.input.get("use_ace", True),
579
- visible=self.pipe.input.get("use_ace", None) is not None
580
- )
581
- )
582
-
583
- self.model_name_dd.change(
584
- change_model,
585
- inputs=[self.model_name_dd],
586
- outputs=[
587
- self.model_name_dd, self.chatbot, self.text,
588
- self.step,
589
- self.cfg_scale, self.rescale, self.output_height,
590
- self.output_width, self.refiner_prompt, self.refiner_scale,
591
- self.use_ace])
592
-
593
- def mode_change(mode_check):
594
- if mode_check:
595
- # ChatBot
596
- return (
597
- gr.Row(visible=False),
598
- gr.Row(visible=True),
599
- gr.Button(value='Generate'),
600
- gr.State(value='chatbot'),
601
- gr.Column(visible=True),
602
- gr.Markdown(value=self.chatbot_inst)
603
- )
604
- else:
605
- # Legacy
606
- return (
607
- gr.Row(visible=True),
608
- gr.Row(visible=False),
609
- gr.Button(value=chat_sty + ' Chat'),
610
- gr.State(value='legacy'),
611
- gr.Column(visible=False),
612
- gr.Markdown(value=self.legacy_inst)
613
- )
614
-
615
- self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
616
- outputs=[self.legacy_group, self.chat_group,
617
- self.chat_btn, self.ui_mode,
618
- self.upload_panel, self.instruction])
619
-
620
- ########################################
621
- def generate_gallery(text, images):
622
- if text.endswith(' '):
623
- return gr.update(), gr.update(visible=False)
624
- elif text.endswith('@'):
625
- gallery_info = []
626
- for image_id, image_meta in images.items():
627
- thumbnail_path = image_meta['thumbnail']
628
- gallery_info.append((thumbnail_path, image_id))
629
- return gr.update(), gr.update(visible=True, value=gallery_info)
630
- else:
631
- gallery_info = []
632
- match = re.search('@([^@ ]+)$', text)
633
- if match:
634
- prefix = match.group(1)
635
- for image_id, image_meta in images.items():
636
- if not image_id.startswith(prefix):
637
- continue
638
- thumbnail_path = image_meta['thumbnail']
639
- gallery_info.append((thumbnail_path, image_id))
640
-
641
- if len(gallery_info) > 0:
642
- return gr.update(), gr.update(visible=True,
643
- value=gallery_info)
644
- else:
645
- return gr.update(), gr.update(visible=False)
646
- else:
647
- return gr.update(), gr.update(visible=False)
648
-
649
- self.text.input(generate_gallery,
650
- inputs=[self.text, self.images],
651
- outputs=[self.text, self.gallery],
652
- show_progress='hidden')
653
-
654
- ########################################
655
- def select_image(text, evt: gr.SelectData):
656
- image_id = evt.value['caption']
657
- text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
658
- return gr.update(value=text), gr.update(visible=False, value=None)
659
-
660
- self.gallery.select(select_image,
661
- inputs=self.text,
662
- outputs=[self.text, self.gallery])
663
-
664
- ########################################
665
- def generate_video(message,
666
- extend_prompt,
667
- history,
668
- images,
669
- num_steps,
670
- num_frames,
671
- cfg_scale,
672
- fps,
673
- seed,
674
- progress=gr.Progress(track_tqdm=True)):
675
-
676
- from diffusers.utils import export_to_video
677
-
678
- generator = torch.Generator(device='cuda').manual_seed(seed)
679
- img_ids = re.findall('@(.*?)[ ,;.?$]', message)
680
- if len(img_ids) == 0:
681
- history.append((
682
- message,
683
- 'Sorry, no images were found in the prompt to be used as the first frame of the video.'
684
- ))
685
- while len(history) >= self.max_msgs:
686
- history.pop(0)
687
- return history, self.get_history(
688
- history), gr.update(), gr.update(visible=False)
689
-
690
- img_id = img_ids[0]
691
- prompt = re.sub(f'@{img_id}\s+', '', message)
692
-
693
- if extend_prompt:
694
- messages = copy.deepcopy(self.enhance_ctx)
695
- messages.append({
696
- 'role':
697
- 'user',
698
- 'content':
699
- f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
700
- })
701
- lock.acquire()
702
- outputs = self.enhancer(
703
- messages,
704
- max_new_tokens=200,
705
- )
706
-
707
- prompt = outputs[0]['generated_text'][-1]['content']
708
- print(prompt)
709
- lock.release()
710
-
711
- img_meta = images[img_id]
712
- img_path = img_meta['image']
713
- image = Image.open(img_path).convert('RGB')
714
-
715
- lock.acquire()
716
- video = self.i2v_pipe(
717
- prompt=prompt,
718
- image=image,
719
- num_videos_per_prompt=1,
720
- num_inference_steps=num_steps,
721
- num_frames=num_frames,
722
- guidance_scale=cfg_scale,
723
- generator=generator,
724
- ).frames[0]
725
- lock.release()
726
-
727
- out_video_path = export_to_video(video, fps=fps)
728
- history.append((
729
- f"Based on first frame @{img_id} and description '{prompt}', generate a video",
730
- 'This is generated video:'))
731
- history.append((None, out_video_path))
732
- while len(history) >= self.max_msgs:
733
- history.pop(0)
734
-
735
- return history, self.get_history(history), gr.update(
736
- value=''), gr.update(visible=False)
737
-
738
- self.video_gen_btn.click(
739
- generate_video,
740
- inputs=[
741
- self.text, self.extend_prompt, self.history, self.images,
742
- self.video_step, self.video_frames, self.video_cfg_scale,
743
- self.video_fps, self.video_seed
744
- ],
745
- outputs=[self.history, self.chatbot, self.text, self.gallery])
746
-
747
- ########################################
748
- @spaces.GPU(duration=120)
749
- def run_chat(
750
- message,
751
- legacy_image,
752
- ui_mode,
753
- use_ace,
754
- extend_prompt,
755
- history,
756
- images,
757
- use_history,
758
- history_result,
759
- negative_prompt,
760
- cfg_scale,
761
- rescale,
762
- refiner_prompt,
763
- refiner_scale,
764
- step,
765
- seed,
766
- output_h,
767
- output_w,
768
- video_auto,
769
- video_steps,
770
- video_frames,
771
- video_cfg_scale,
772
- video_fps,
773
- video_seed,
774
- progress=gr.Progress(track_tqdm=True)):
775
- legacy_img_ids = []
776
- if ui_mode == 'legacy':
777
- if legacy_image is not None:
778
- history, images, img_id = self.add_uploaded_image_to_history(
779
- legacy_image, history, images)
780
- legacy_img_ids.append(img_id)
781
- retry_msg = message
782
- gen_id = get_md5(message)[:12]
783
- save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
784
-
785
- img_ids = re.findall('@(.*?)[ ,;.?$]', message)
786
- history_io = None
787
-
788
- if len(img_ids) < 1:
789
- img_ids = legacy_img_ids
790
- for img_id in img_ids:
791
- if f'@{img_id}' not in message:
792
- message = f'@{img_id} ' + message
793
-
794
- new_message = message
795
-
796
- if len(img_ids) > 0:
797
- edit_image, edit_image_mask, edit_task = [], [], []
798
- for i, img_id in enumerate(img_ids):
799
- if img_id not in images:
800
- gr.Info(
801
- f'The input image ID {img_id} is not exist... Skip loading image.'
802
- )
803
- continue
804
- placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
805
- if placeholder not in new_message:
806
- new_message = re.sub(f'@{img_id}', placeholder,
807
- new_message)
808
- else:
809
- new_message = re.sub(f'@{img_id} ', "",
810
- new_message, 1)
811
- img_meta = images[img_id]
812
- img_path = img_meta['image']
813
- img_mask = img_meta['mask']
814
- img_mask_type = img_meta['mask_type']
815
- if img_mask_type is not None and img_mask_type == 'Composite':
816
- task = 'inpainting'
817
- else:
818
- task = ''
819
- edit_image.append(Image.open(img_path).convert('RGB'))
820
- edit_image_mask.append(
821
- Image.open(img_mask).
822
- convert('L') if img_mask is not None else None)
823
- edit_task.append(task)
824
-
825
- if use_history and (img_id in history_result):
826
- history_io = history_result[img_id]
827
-
828
- buffered = io.BytesIO()
829
- edit_image[0].save(buffered, format='PNG')
830
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
831
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
832
- pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
833
- else:
834
- pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
835
- edit_image = None
836
- edit_image_mask = None
837
- edit_task = ''
838
- if new_message == "":
839
- new_message = "a beautiful girl wear a skirt."
840
- print(new_message)
841
- imgs = self.pipe(
842
- image=edit_image,
843
- mask=edit_image_mask,
844
- task=edit_task,
845
- prompt=[new_message] *
846
- len(edit_image) if edit_image is not None else [new_message],
847
- negative_prompt=[negative_prompt] * len(edit_image)
848
- if edit_image is not None else [negative_prompt],
849
- history_io=history_io,
850
- output_height=output_h,
851
- output_width=output_w,
852
- sampler=self.pipe.input.get("sampler", "ddim"),
853
- sample_steps=step,
854
- guide_scale=cfg_scale,
855
- guide_rescale=rescale,
856
- seed=seed,
857
- refiner_prompt=refiner_prompt,
858
- refiner_scale=refiner_scale,
859
- use_ace=use_ace
860
- )
861
-
862
- img = imgs[0]
863
- img.save(save_path, format='JPEG')
864
-
865
- if history_io:
866
- history_io_new = copy.deepcopy(history_io)
867
- history_io_new['image'] += edit_image[:1]
868
- history_io_new['mask'] += edit_image_mask[:1]
869
- history_io_new['task'] += edit_task[:1]
870
- history_io_new['prompt'] += [new_message]
871
- history_io_new['image'] = history_io_new['image'][-5:]
872
- history_io_new['mask'] = history_io_new['mask'][-5:]
873
- history_io_new['task'] = history_io_new['task'][-5:]
874
- history_io_new['prompt'] = history_io_new['prompt'][-5:]
875
- history_result[gen_id] = history_io_new
876
- elif edit_image is not None and len(edit_image) > 0:
877
- history_io_new = {
878
- 'image': edit_image[:1],
879
- 'mask': edit_image_mask[:1],
880
- 'task': edit_task[:1],
881
- 'prompt': [new_message]
882
- }
883
- history_result[gen_id] = history_io_new
884
-
885
- w, h = img.size
886
- if w > h:
887
- tb_w = 128
888
- tb_h = int(h * tb_w / w)
889
- else:
890
- tb_h = 128
891
- tb_w = int(w * tb_h / h)
892
-
893
- thumbnail_path = os.path.join(self.cache_dir,
894
- f'{gen_id}_thumbnail.jpg')
895
- thumbnail = img.resize((tb_w, tb_h))
896
- thumbnail.save(thumbnail_path, format='JPEG')
897
-
898
- images[gen_id] = {
899
- 'image': save_path,
900
- 'mask': None,
901
- 'mask_type': None,
902
- 'thumbnail': thumbnail_path
903
- }
904
-
905
- buffered = io.BytesIO()
906
- img.convert('RGB').save(buffered, format='JPEG')
907
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
908
- img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
909
-
910
- history.append(
911
- (message,
912
- f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))
913
-
914
- if video_auto:
915
- if video_seed is None or video_seed == -1:
916
- video_seed = random.randint(0, 10000000)
917
-
918
- lock.acquire()
919
- generator = torch.Generator(
920
- device='cuda').manual_seed(video_seed)
921
- pixel_values = load_image(img.convert('RGB'),
922
- max_num=self.llm_max_num).to(
923
- torch.bfloat16).cuda()
924
- prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,
925
- self.llm_prompt,
926
- self.llm_generation_config)
927
- print(prompt)
928
- lock.release()
929
-
930
- if extend_prompt:
931
- messages = copy.deepcopy(self.enhance_ctx)
932
- messages.append({
933
- 'role':
934
- 'user',
935
- 'content':
936
- f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
937
- })
938
- lock.acquire()
939
- outputs = self.enhancer(
940
- messages,
941
- max_new_tokens=200,
942
- )
943
- prompt = outputs[0]['generated_text'][-1]['content']
944
- print(prompt)
945
- lock.release()
946
-
947
- lock.acquire()
948
- video = self.i2v_pipe(
949
- prompt=prompt,
950
- image=img,
951
- num_videos_per_prompt=1,
952
- num_inference_steps=video_steps,
953
- num_frames=video_frames,
954
- guidance_scale=video_cfg_scale,
955
- generator=generator,
956
- ).frames[0]
957
- lock.release()
958
-
959
- out_video_path = export_to_video(video, fps=video_fps)
960
- history.append((
961
- f"Based on first frame @{gen_id} and description '{prompt}', generate a video",
962
- 'This is generated video:'))
963
- history.append((None, out_video_path))
964
-
965
- while len(history) >= self.max_msgs:
966
- history.pop(0)
967
-
968
- return (history, images, gr.Image(value=save_path),
969
- history_result, self.get_history(
970
- history), gr.update(), gr.update(
971
- visible=False), retry_msg)
972
-
973
- chat_inputs = [
974
- self.legacy_image_uploader, self.ui_mode, self.use_ace,
975
- self.extend_prompt, self.history, self.images, self.use_history,
976
- self.history_result, self.negative_prompt, self.cfg_scale,
977
- self.rescale, self.refiner_prompt, self.refiner_scale,
978
- self.step, self.seed, self.output_height,
979
- self.output_width, self.video_auto, self.video_step,
980
- self.video_frames, self.video_cfg_scale, self.video_fps,
981
- self.video_seed
982
- ]
983
-
984
- chat_outputs = [
985
- self.history, self.images, self.legacy_image_viewer,
986
- self.history_result, self.chatbot,
987
- self.text, self.gallery, self.retry_msg
988
- ]
989
-
990
- self.chat_btn.click(run_chat,
991
- inputs=[self.text] + chat_inputs,
992
- outputs=chat_outputs)
993
-
994
- self.text.submit(run_chat,
995
- inputs=[self.text] + chat_inputs,
996
- outputs=chat_outputs)
997
-
998
- def retry_fn(*args):
999
- return run_chat(*args)
1000
-
1001
- self.retry_btn.click(retry_fn,
1002
- inputs=[self.retry_msg] + chat_inputs,
1003
- outputs=chat_outputs)
1004
-
1005
- ########################################
1006
- @spaces.GPU(duration=120)
1007
- def run_example(task, img, img_mask, ref1, prompt, seed):
1008
- edit_image, edit_image_mask, edit_task = [], [], []
1009
- if img is not None:
1010
- w, h = img.size
1011
- if w > 2048:
1012
- ratio = w / 2048.
1013
- w = 2048
1014
- h = int(h / ratio)
1015
- if h > 2048:
1016
- ratio = h / 2048.
1017
- h = 2048
1018
- w = int(w / ratio)
1019
- img = img.resize((w, h))
1020
- edit_image.append(img)
1021
- if img_mask is not None:
1022
- img_mask = img_mask if np.sum(np.array(img_mask)) > 0 else None
1023
- edit_image_mask.append(
1024
- img_mask if img_mask is not None else None)
1025
- edit_task.append(task)
1026
- if ref1 is not None:
1027
- ref1 = ref1 if np.sum(np.array(ref1)) > 0 else None
1028
- if ref1 is not None:
1029
- edit_image.append(ref1)
1030
- edit_image_mask.append(None)
1031
- edit_task.append('')
1032
-
1033
- buffered = io.BytesIO()
1034
- img.save(buffered, format='PNG')
1035
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1036
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1037
- pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
1038
- else:
1039
- pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
1040
- edit_image = None
1041
- edit_image_mask = None
1042
- edit_task = ''
1043
-
1044
- img_num = len(edit_image) if edit_image is not None else 1
1045
- imgs = self.pipe(
1046
- image=edit_image,
1047
- mask=edit_image_mask,
1048
- task=edit_task,
1049
- prompt=[prompt] * img_num,
1050
- negative_prompt=[''] * img_num,
1051
- seed=seed,
1052
- refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
1053
- refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
1054
- )
1055
-
1056
- img = imgs[0]
1057
- buffered = io.BytesIO()
1058
- img.convert('RGB').save(buffered, format='JPEG')
1059
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1060
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1061
- history = [(prompt,
1062
- f'{pre_info} The generated image is:\n {img_str}')]
1063
-
1064
- img_id = get_md5(img_b64)[:12]
1065
- save_path = os.path.join(self.cache_dir, f'{img_id}.jpg')
1066
- img.convert('RGB').save(save_path)
1067
-
1068
- return self.get_history(history), gr.update(value=prompt), gr.update(
1069
- visible=False), gr.update(value=save_path), gr.update(value=-1)
1070
-
1071
- with self.eg:
1072
- self.example_task = gr.Text(label='Task Name',
1073
- value='',
1074
- visible=False)
1075
- self.example_image = gr.Image(label='Edit Image',
1076
- type='pil',
1077
- image_mode='RGB',
1078
- visible=False)
1079
- self.example_mask = gr.Image(label='Edit Image Mask',
1080
- type='pil',
1081
- image_mode='L',
1082
- visible=False)
1083
- self.example_ref_im1 = gr.Image(label='Ref Image',
1084
- type='pil',
1085
- image_mode='RGB',
1086
- visible=False)
1087
-
1088
- self.examples = gr.Examples(
1089
- fn=run_example,
1090
- examples=self.chatbot_examples,
1091
- inputs=[
1092
- self.example_task, self.example_image, self.example_mask,
1093
- self.example_ref_im1, self.text, self.seed
1094
- ],
1095
- outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed],
1096
- examples_per_page=4,
1097
- cache_examples=False,
1098
- run_on_click=True)
1099
-
1100
- ########################################
1101
- def upload_image():
1102
- return (gr.update(visible=True,
1103
- scale=1), gr.update(visible=True, scale=1),
1104
- gr.update(visible=True), gr.update(visible=False),
1105
- gr.update(visible=False), gr.update(visible=False),
1106
- gr.update(visible=True))
1107
-
1108
- self.upload_btn.click(upload_image,
1109
- inputs=[],
1110
- outputs=[
1111
- self.chat_page, self.editor_page,
1112
- self.upload_tab, self.edit_tab,
1113
- self.image_view_tab, self.video_view_tab,
1114
- self.upload_tabs
1115
- ])
1116
-
1117
- ########################################
1118
- def edit_image(evt: gr.SelectData):
1119
- if isinstance(evt.value, str):
1120
- img_b64s = re.findall(
1121
- '<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
1122
- evt.value)
1123
- imgs = [
1124
- Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
1125
- for i in img_b64s
1126
- ]
1127
- if len(imgs) > 0:
1128
- if len(imgs) == 2:
1129
- if self.gradio_version >= '5.0.0':
1130
- view_img = copy.deepcopy(imgs[-1])
1131
- else:
1132
- view_img = copy.deepcopy(imgs)
1133
- edit_img = copy.deepcopy(imgs[-1])
1134
- else:
1135
- if self.gradio_version >= '5.0.0':
1136
- view_img = copy.deepcopy(imgs[-1])
1137
- else:
1138
- view_img = [
1139
- copy.deepcopy(imgs[-1]),
1140
- copy.deepcopy(imgs[-1])
1141
- ]
1142
- edit_img = copy.deepcopy(imgs[-1])
1143
-
1144
- return (gr.update(visible=True,
1145
- scale=1), gr.update(visible=True,
1146
- scale=1),
1147
- gr.update(visible=False), gr.update(visible=True),
1148
- gr.update(visible=True), gr.update(visible=False),
1149
- gr.update(value=edit_img),
1150
- gr.update(value=view_img), gr.update(value=None),
1151
- gr.update(visible=True))
1152
- else:
1153
- return (gr.update(), gr.update(), gr.update(), gr.update(),
1154
- gr.update(), gr.update(), gr.update(), gr.update(),
1155
- gr.update(), gr.update())
1156
- elif isinstance(evt.value, dict) and evt.value.get(
1157
- 'component', '') == 'video':
1158
- value = evt.value['value']['video']['path']
1159
- return (gr.update(visible=True,
1160
- scale=1), gr.update(visible=True, scale=1),
1161
- gr.update(visible=False), gr.update(visible=False),
1162
- gr.update(visible=False), gr.update(visible=True),
1163
- gr.update(), gr.update(), gr.update(value=value),
1164
- gr.update())
1165
- else:
1166
- return (gr.update(), gr.update(), gr.update(), gr.update(),
1167
- gr.update(), gr.update(), gr.update(), gr.update(),
1168
- gr.update(), gr.update())
1169
-
1170
- self.chatbot.select(edit_image,
1171
- outputs=[
1172
- self.chat_page, self.editor_page,
1173
- self.upload_tab, self.edit_tab,
1174
- self.image_view_tab, self.video_view_tab,
1175
- self.image_editor, self.image_viewer,
1176
- self.video_viewer, self.edit_tabs
1177
- ])
1178
-
1179
- if self.gradio_version < '5.0.0':
1180
- self.image_viewer.change(lambda x: x,
1181
- inputs=self.image_viewer,
1182
- outputs=self.image_viewer)
1183
-
1184
- ########################################
1185
- def submit_upload_image(image, history, images):
1186
- history, images, _ = self.add_uploaded_image_to_history(
1187
- image, history, images)
1188
- return gr.update(visible=False), gr.update(
1189
- visible=True), gr.update(
1190
- value=self.get_history(history)), history, images
1191
-
1192
- self.sub_btn_1.click(
1193
- submit_upload_image,
1194
- inputs=[self.image_uploader, self.history, self.images],
1195
- outputs=[
1196
- self.editor_page, self.chat_page, self.chatbot, self.history,
1197
- self.images
1198
- ])
1199
-
1200
- ########################################
1201
- def submit_edit_image(imagemask, mask_type, history, images):
1202
- history, images = self.add_edited_image_to_history(
1203
- imagemask, mask_type, history, images)
1204
- return gr.update(visible=False), gr.update(
1205
- visible=True), gr.update(
1206
- value=self.get_history(history)), history, images
1207
-
1208
- self.sub_btn_2.click(submit_edit_image,
1209
- inputs=[
1210
- self.image_editor, self.mask_type,
1211
- self.history, self.images
1212
- ],
1213
- outputs=[
1214
- self.editor_page, self.chat_page,
1215
- self.chatbot, self.history, self.images
1216
- ])
1217
-
1218
- ########################################
1219
- def exit_edit():
1220
- return gr.update(visible=False), gr.update(visible=True, scale=3)
1221
-
1222
- self.ext_btn_1.click(exit_edit,
1223
- outputs=[self.editor_page, self.chat_page])
1224
- self.ext_btn_2.click(exit_edit,
1225
- outputs=[self.editor_page, self.chat_page])
1226
- self.ext_btn_3.click(exit_edit,
1227
- outputs=[self.editor_page, self.chat_page])
1228
- self.ext_btn_4.click(exit_edit,
1229
- outputs=[self.editor_page, self.chat_page])
1230
-
1231
- ########################################
1232
- def update_mask_type_info(mask_type):
1233
- if mask_type == 'Background':
1234
- info = 'Background mode will not erase the visual content in the mask area'
1235
- visible = False
1236
- elif mask_type == 'Composite':
1237
- info = 'Composite mode will erase the visual content in the mask area'
1238
- visible = False
1239
- elif mask_type == 'Outpainting':
1240
- info = 'Outpaint mode is used for preparing input image for outpainting task'
1241
- visible = True
1242
- return (gr.update(
1243
- visible=True,
1244
- value=
1245
- f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
1246
- ), gr.update(visible=visible))
1247
-
1248
- self.mask_type.change(update_mask_type_info,
1249
- inputs=self.mask_type,
1250
- outputs=[self.mask_type_info, self.outpaint_tab])
1251
-
1252
- ########################################
1253
- def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
1254
- image):
1255
- img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
1256
- h, w = img.shape[:2]
1257
- new_h = int(h * (top_ratio + bottom_ratio + 1))
1258
- new_w = int(w * (left_ratio + right_ratio + 1))
1259
- start_h = int(h * top_ratio)
1260
- start_w = int(w * left_ratio)
1261
- new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
1262
- new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
1263
- new_img[start_h:start_h + h, start_w:start_w + w, :] = img
1264
- new_mask[start_h:start_h + h, start_w:start_w + w] = 0
1265
- layer = np.concatenate([new_img, new_mask], axis=2)
1266
- value = {
1267
- 'background': new_img,
1268
- 'composite': new_img,
1269
- 'layers': [layer]
1270
- }
1271
- return gr.update(value=value)
1272
-
1273
- self.img_pad_btn.click(extend_image,
1274
- inputs=[
1275
- self.top_ext, self.bottom_ext,
1276
- self.left_ext, self.right_ext,
1277
- self.image_editor
1278
- ],
1279
- outputs=self.image_editor)
1280
-
1281
- ########################################
1282
- def clear_chat(history, images, history_result):
1283
- history.clear()
1284
- images.clear()
1285
- history_result.clear()
1286
- return history, images, history_result, self.get_history(history)
1287
-
1288
- self.clear_btn.click(
1289
- clear_chat,
1290
- inputs=[self.history, self.images, self.history_result],
1291
- outputs=[
1292
- self.history, self.images, self.history_result, self.chatbot
1293
- ])
1294
-
1295
- def get_history(self, history):
1296
- info = []
1297
- for item in history:
1298
- new_item = [None, None]
1299
- if isinstance(item[0], str) and item[0].endswith('.mp4'):
1300
- new_item[0] = gr.Video(item[0], format='mp4')
1301
- else:
1302
- new_item[0] = item[0]
1303
- if isinstance(item[1], str) and item[1].endswith('.mp4'):
1304
- new_item[1] = gr.Video(item[1], format='mp4')
1305
- else:
1306
- new_item[1] = item[1]
1307
- info.append(new_item)
1308
- return info
1309
-
1310
- def generate_random_string(self, length=20):
1311
- letters_and_digits = string.ascii_letters + string.digits
1312
- random_string = ''.join(
1313
- random.choice(letters_and_digits) for i in range(length))
1314
- return random_string
1315
-
1316
- def add_edited_image_to_history(self, image, mask_type, history, images):
1317
- if mask_type == 'Composite':
1318
- img = Image.fromarray(image['composite'])
1319
- else:
1320
- img = Image.fromarray(image['background'])
1321
-
1322
- img_id = get_md5(self.generate_random_string())[:12]
1323
- save_path = os.path.join(self.cache_dir, f'{img_id}.png')
1324
- img.convert('RGB').save(save_path)
1325
-
1326
- mask = image['layers'][0][:, :, 3]
1327
- mask = Image.fromarray(mask).convert('RGB')
1328
- mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
1329
- mask.save(mask_path)
1330
-
1331
- w, h = img.size
1332
- if w > h:
1333
- tb_w = 128
1334
- tb_h = int(h * tb_w / w)
1335
- else:
1336
- tb_h = 128
1337
- tb_w = int(w * tb_h / h)
1338
-
1339
- if mask_type == 'Background':
1340
- comp_mask = np.array(mask, dtype=np.uint8)
1341
- mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
1342
- 0.6).astype(np.uint8)
1343
- comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
1344
- thumbnail = Image.alpha_composite(
1345
- img.convert('RGBA'),
1346
- Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
1347
- else:
1348
- thumbnail = img.convert('RGB')
1349
-
1350
- thumbnail_path = os.path.join(self.cache_dir,
1351
- f'{img_id}_thumbnail.jpg')
1352
- thumbnail = thumbnail.resize((tb_w, tb_h))
1353
- thumbnail.save(thumbnail_path, format='JPEG')
1354
-
1355
- buffered = io.BytesIO()
1356
- img.convert('RGB').save(buffered, format='PNG')
1357
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1358
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1359
-
1360
- buffered = io.BytesIO()
1361
- mask.convert('RGB').save(buffered, format='PNG')
1362
- mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1363
- mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
1364
-
1365
- images[img_id] = {
1366
- 'image': save_path,
1367
- 'mask': mask_path,
1368
- 'mask_type': mask_type,
1369
- 'thumbnail': thumbnail_path
1370
- }
1371
- history.append((
1372
- None,
1373
- f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}'
1374
- ))
1375
- return history, images
1376
-
1377
- def add_uploaded_image_to_history(self, img, history, images):
1378
- img_id = get_md5(self.generate_random_string())[:12]
1379
- save_path = os.path.join(self.cache_dir, f'{img_id}.png')
1380
- w, h = img.size
1381
- if w > 2048:
1382
- ratio = w / 2048.
1383
- w = 2048
1384
- h = int(h / ratio)
1385
- if h > 2048:
1386
- ratio = h / 2048.
1387
- h = 2048
1388
- w = int(w / ratio)
1389
- img = img.resize((w, h))
1390
- img.save(save_path)
1391
-
1392
- w, h = img.size
1393
- if w > h:
1394
- tb_w = 128
1395
- tb_h = int(h * tb_w / w)
1396
- else:
1397
- tb_h = 128
1398
- tb_w = int(w * tb_h / h)
1399
- thumbnail_path = os.path.join(self.cache_dir,
1400
- f'{img_id}_thumbnail.jpg')
1401
- thumbnail = img.resize((tb_w, tb_h))
1402
- thumbnail.save(thumbnail_path, format='JPEG')
1403
-
1404
- images[img_id] = {
1405
- 'image': save_path,
1406
- 'mask': None,
1407
- 'mask_type': None,
1408
- 'thumbnail': thumbnail_path
1409
- }
1410
-
1411
- buffered = io.BytesIO()
1412
- img.convert('RGB').save(buffered, format='PNG')
1413
- img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
1414
- img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
1415
-
1416
- history.append(
1417
- (None,
1418
- f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
1419
- return history, images, img_id
1420
-
1421
-
1422
- if __name__ == '__main__':
1423
- cfg = "config/chatbot_ui.yaml"
1424
- with gr.Blocks() as demo:
1425
- chatbot = ChatBotUI(cfg)
1426
- chatbot.create_ui()
1427
- chatbot.set_callbacks()
1428
- demo.launch()
 
 
config/chatbot_ui.yaml DELETED
@@ -1,25 +0,0 @@
- WORK_DIR: chatbot
- FILE_SYSTEM:
-   - NAME: LocalFs
-     TEMP_DIR: ./cache
-   - NAME: ModelscopeFs
-     TEMP_DIR: ./cache
-   - NAME: HuggingfaceFs
-     TEMP_DIR: ./cache
- #
- ENABLE_I2V: False
- SKIP_EXAMPLES: True
- #
- MODEL:
-   EDIT_MODEL:
-     MODEL_CFG_DIR: config/models/
-   I2V:
-     MODEL_NAME: CogVideoX-5b-I2V
-     MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
-   CAPTIONER:
-     MODEL_NAME: InternVL2-2B
-     MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
-     PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
-   ENHANCER:
-     MODEL_NAME: Meta-Llama-3.1-8B-Instruct
-     MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
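Note (not part of the commit): a minimal sketch of inspecting this removed config with plain PyYAML, assuming the file is restored locally and PyYAML is installed; the keys are exactly those shown above.

    import yaml  # hypothetical inspection script, not part of this repo

    with open('config/chatbot_ui.yaml') as f:
        cfg = yaml.safe_load(f)

    # Which auxiliary models the chatbot UI would pull in:
    print(cfg['MODEL']['CAPTIONER']['MODEL_NAME'])   # InternVL2-2B
    print(cfg['MODEL']['ENHANCER']['MODEL_DIR'])     # ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
    print(cfg['ENABLE_I2V'], cfg['SKIP_EXAMPLES'])   # False True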
 
 
config/models/ace_flux_dev.yaml DELETED
@@ -1,187 +0,0 @@
1
- NAME: ACE_FLUX.1_dev
2
- IS_DEFAULT: True
3
- USE_DYNAMIC_MODEL: False
4
- INFERENCE_TYPE: ACE_FLUX
5
- MAX_SEQ_LENGTH: 3072
6
- SRC_MAX_SEQ_LENGTH: 2048
7
- DEFAULT_PARAS:
8
- PARAS:
9
- #
10
- INPUT:
11
- INPUT_IMAGE:
12
- INPUT_MASK:
13
- TASK:
14
- PROMPT: ""
15
- OUTPUT_HEIGHT: 1024
16
- OUTPUT_WIDTH: 1024
17
- SAMPLER: flow_euler
18
- SAMPLE_STEPS: 20
19
- GUIDE_SCALE: 3.5
20
- SEED: -1
21
- TAR_INDEX: 0
22
- ALIGN: False
23
- OUTPUT:
24
- LATENT:
25
- IMAGES:
26
- SEED:
27
- MODULES_PARAS:
28
- FIRST_STAGE_MODEL:
29
- FUNCTION:
30
- - NAME: encode
31
- DTYPE: bfloat16
32
- INPUT: [ "IMAGE" ]
33
- - NAME: decode
34
- DTYPE: bfloat16
35
- INPUT: [ "LATENT" ]
36
- PARAS:
37
- SCALE_FACTOR: 1.5305
38
- SHIFT_FACTOR: 0.0609
39
- SIZE_FACTOR: 8
40
- DIFFUSION_MODEL:
41
- FUNCTION:
42
- - NAME: forward
43
- DTYPE: bfloat16
44
- INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ]
45
- COND_STAGE_MODEL:
46
- FUNCTION:
47
- - NAME: encode_list
48
- DTYPE: bfloat16
49
- INPUT: [ "PROMPT" ]
50
- #
51
- MODEL:
52
- NAME: LatentDiffusionACEFlux
53
- PARAMETERIZATION: rf
54
- PRETRAINED_MODEL:
55
- IGNORE_KEYS: [ ]
56
- SIZE_FACTOR: 8
57
- TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
58
- USE_TEXT_POS_EMBEDDINGS: True
59
- DIFFUSION:
60
- # NAME DESCRIPTION: TYPE: default: 'DiffusionFluxRF'
61
- NAME: DiffusionFluxRF
62
- PREDICTION_TYPE: raw
63
- # NOISE_SCHEDULER DESCRIPTION: TYPE: default: ''
64
- NOISE_SCHEDULER:
65
- NAME: FlowMatchFluxShiftScheduler
66
- SHIFT: True
67
- SIGMOID_SCALE: 1
68
- BASE_SHIFT: 0.5
69
- MAX_SHIFT: 1.15
70
- #
71
- DIFFUSION_MODEL:
72
- # NAME DESCRIPTION: TYPE: default: 'Flux'
73
- NAME: ACEFlux
74
- PRETRAINED_MODEL: hf://black-forest-labs/[email protected]
75
- SWIFT_LORA_MODEL: ["hf://scepter-studio/ACE-FLUX.1-dev@ace_flux.1_dev_lora.bin"]
76
- # IN_CHANNELS DESCRIPTION: model's input channels. TYPE: int default: 64
77
- IN_CHANNELS: 64
78
- # HIDDEN_SIZE DESCRIPTION: model's hidden size. TYPE: int default: 1024
79
- HIDDEN_SIZE: 3072
80
- # NUM_HEADS DESCRIPTION: number of heads in the transformer. TYPE: int default: 16
81
- NUM_HEADS: 24
82
- # AXES_DIM DESCRIPTION: dimensions of the axes of the positional encoding. TYPE: list default: [16, 56, 56]
83
- AXES_DIM: [ 16, 56, 56 ]
84
- # THETA DESCRIPTION: theta for positional encoding. TYPE: int default: 10000
85
- THETA: 10000
86
- # VEC_IN_DIM DESCRIPTION: dimension of the vector input. TYPE: int default: 768
87
- VEC_IN_DIM: 768
88
- # GUIDANCE_EMBED DESCRIPTION: whether to use guidance embedding. TYPE: bool default: False
89
- GUIDANCE_EMBED: True
90
- # CONTEXT_IN_DIM DESCRIPTION: dimension of the context input. TYPE: int default: 4096
91
- CONTEXT_IN_DIM: 4096
92
- # MLP_RATIO DESCRIPTION: ratio of mlp hidden size to hidden size. TYPE: float default: 4.0
93
- MLP_RATIO: 4.0
94
- # QKV_BIAS DESCRIPTION: whether to use bias in qkv projection. TYPE: bool default: True
95
- QKV_BIAS: True
96
- # DEPTH DESCRIPTION: number of transformer blocks. TYPE: int default: 19
97
- DEPTH: 19
98
- # DEPTH_SINGLE_BLOCKS DESCRIPTION: number of transformer blocks in the single stream block. TYPE: int default: 38
99
- DEPTH_SINGLE_BLOCKS: 38
100
- ATTN_BACKEND: pytorch
101
-
102
- #
103
- FIRST_STAGE_MODEL:
104
- NAME: AutoencoderKLFlux
105
- EMBED_DIM: 16
106
- PRETRAINED_MODEL: hf://black-forest-labs/[email protected]
107
- IGNORE_KEYS: [ ]
108
- BATCH_SIZE: 8
109
- USE_CONV: False
110
- SCALE_FACTOR: 0.3611
111
- SHIFT_FACTOR: 0.1159
112
- #
113
- ENCODER:
114
- NAME: Encoder
115
- USE_CHECKPOINT: True
116
- CH: 128
117
- OUT_CH: 3
118
- NUM_RES_BLOCKS: 2
119
- IN_CHANNELS: 3
120
- ATTN_RESOLUTIONS: [ ]
121
- CH_MULT: [ 1, 2, 4, 4 ]
122
- Z_CHANNELS: 16
123
- DOUBLE_Z: True
124
- DROPOUT: 0.0
125
- RESAMP_WITH_CONV: True
126
- #
127
- DECODER:
128
- NAME: Decoder
129
- USE_CHECKPOINT: True
130
- CH: 128
131
- OUT_CH: 3
132
- NUM_RES_BLOCKS: 2
133
- IN_CHANNELS: 3
134
- ATTN_RESOLUTIONS: [ ]
135
- CH_MULT: [ 1, 2, 4, 4 ]
136
- Z_CHANNELS: 16
137
- DROPOUT: 0.0
138
- RESAMP_WITH_CONV: True
139
- GIVE_PRE_END: False
140
- TANH_OUT: False
141
- #
142
- COND_STAGE_MODEL:
143
- # NAME DESCRIPTION: TYPE: default: 'T5PlusClipFluxEmbedder'
144
- NAME: T5ACEPlusClipFluxEmbedder
145
- # T5_MODEL DESCRIPTION: TYPE: default: ''
146
- T5_MODEL:
147
- # NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
148
- NAME: ACEHFEmbedder
149
- # HF_MODEL_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
150
- HF_MODEL_CLS: T5EncoderModel
151
- # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
152
- MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder_2/
153
- # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
154
- HF_TOKENIZER_CLS: T5Tokenizer
155
- # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
156
- TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer_2/
157
- ADDED_IDENTIFIER: [ '<img>','{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
158
- # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
159
- MAX_LENGTH: 512
160
- # OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
161
- OUTPUT_KEY: last_hidden_state
162
- # D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
163
- D_TYPE: bfloat16
164
- # BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
165
- BATCH_INFER: False
166
- CLEAN: whitespace
167
- # CLIP_MODEL DESCRIPTION: TYPE: default: ''
168
- CLIP_MODEL:
169
- # NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
170
- NAME: ACEHFEmbedder
171
- # HF_MODEL_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
172
- HF_MODEL_CLS: CLIPTextModel
173
- # MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
174
- MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder/
175
- # HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
176
- HF_TOKENIZER_CLS: CLIPTokenizer
177
- # TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
178
- TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer/
179
- # MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
180
- MAX_LENGTH: 77
181
- # OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
182
- OUTPUT_KEY: pooler_output
183
- # D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
184
- D_TYPE: bfloat16
185
- # BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
186
- BATCH_INFER: True
187
- CLEAN: whitespace
 
 
example.py DELETED
@@ -1,370 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import os
4
- from PIL import Image
5
- from scepter.modules.utils.file_system import FS
6
-
7
-
8
- def download_image(image, local_path=None):
9
- if not FS.exists(local_path):
10
- local_path = FS.get_from(image, local_path=local_path)
11
- if local_path.split(".")[-1] in ['jpg', 'jpeg']:
12
- im = Image.open(local_path).convert("RGB")
13
- im.save(local_path, format='JPEG')
14
- return local_path
15
-
16
-
17
- def get_examples(cache_dir):
18
- print('Downloading Examples ...')
19
- examples = [
20
- [
21
- 'Facial Editing',
22
- download_image(
23
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',
24
- os.path.join(cache_dir, 'examples/e33edc106953.jpg')), None,
25
- None, '{image} let the man smile', 6666
26
- ],
27
- [
28
- 'Facial Editing',
29
- download_image(
30
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',
31
- os.path.join(cache_dir, 'examples/5d2bcc91a3e9.jpg')), None,
32
- None, 'let the man in {image} wear sunglasses', 9999
33
- ],
34
- [
35
- 'Facial Editing',
36
- download_image(
37
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',
38
- os.path.join(cache_dir, 'examples/3a52eac708bd.jpg')), None,
39
- None, '{image} red hair', 9999
40
- ],
41
- [
42
- 'Facial Editing',
43
- download_image(
44
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',
45
- os.path.join(cache_dir, 'examples/3f4dc464a0ea.jpg')), None,
46
- None, '{image} let the man serious', 99999
47
- ],
48
- [
49
- 'Controllable Generation',
50
- download_image(
51
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
52
- os.path.join(cache_dir,
53
- 'examples/131ca90fd2a9.jpg')), None, None,
54
- '"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
55
- 613725
56
- ],
57
- [
58
- 'Render Text',
59
- download_image(
60
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
61
- os.path.join(cache_dir, 'examples/33e9f27c2c48.jpg')),
62
- download_image(
63
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
64
- os.path.join(cache_dir,
65
- 'examples/33e9f27c2c48_mask.jpg')), None,
66
- 'Put the text "C A T" at the position marked by mask in the {image}',
67
- 6666
68
- ],
69
- [
70
- 'Style Transfer',
71
- download_image(
72
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
73
- os.path.join(cache_dir, 'examples/9e73e7eeef55.jpg')), None,
74
- download_image(
75
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
76
- os.path.join(cache_dir, 'examples/2e02975293d6.jpg')),
77
- 'edit {image} based on the style of {image1} ', 99999
78
- ],
79
- [
80
- 'Outpainting',
81
- download_image(
82
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
83
- os.path.join(cache_dir, 'examples/f2b22c08be3f.jpg')),
84
- download_image(
85
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
86
- os.path.join(cache_dir,
87
- 'examples/f2b22c08be3f_mask.jpg')), None,
88
- 'Could the {image} be widened within the space designated by mask, while retaining the original?',
89
- 6666
90
- ],
91
- [
92
- 'Image Segmentation',
93
- download_image(
94
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
95
- os.path.join(cache_dir, 'examples/db3ebaa81899.jpg')), None,
96
- None, '{image} Segmentation', 6666
97
- ],
98
- [
99
- 'Depth Estimation',
100
- download_image(
101
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
102
- os.path.join(cache_dir, 'examples/f1927c4692ba.jpg')), None,
103
- None, '{image} Depth Estimation', 6666
104
- ],
105
- [
106
- 'Pose Estimation',
107
- download_image(
108
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
109
- os.path.join(cache_dir, 'examples/014e5bf3b4d1.jpg')), None,
110
- None, '{image} distinguish the poses of the figures', 999999
111
- ],
112
- [
113
- 'Scribble Extraction',
114
- download_image(
115
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
116
- os.path.join(cache_dir, 'examples/5f59a202f8ac.jpg')), None,
117
- None, 'Generate a scribble of {image}, please.', 6666
118
- ],
119
- [
120
- 'Mosaic',
121
- download_image(
122
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
123
- os.path.join(cache_dir, 'examples/3a2f52361eea.jpg')), None,
124
- None, 'Adapt {image} into a mosaic representation.', 6666
125
- ],
126
- [
127
- 'Edge map Extraction',
128
- download_image(
129
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
130
- os.path.join(cache_dir, 'examples/b9d1e519d6e5.jpg')), None,
131
- None, 'Get the edge-enhanced result for {image}.', 6666
132
- ],
133
- [
134
- 'Grayscale',
135
- download_image(
136
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
137
- os.path.join(cache_dir, 'examples/c4ebbe2ba29b.jpg')), None,
138
- None, 'transform {image} into a black and white one', 6666
139
- ],
140
- [
141
- 'Contour Extraction',
142
- download_image(
143
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
144
- os.path.join(cache_dir,
145
- 'examples/19652d0f6c4b.jpg')), None, None,
146
- 'Would you be able to make a contour picture from {image} for me?',
147
- 6666
148
- ],
149
- [
150
- 'Controllable Generation',
151
- download_image(
152
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
153
- os.path.join(cache_dir,
154
- 'examples/249cda2844b7.jpg')), None, None,
155
- 'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed”.',
156
- 6666
157
- ],
158
- [
159
- 'Controllable Generation',
160
- download_image(
161
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
162
- os.path.join(cache_dir,
163
- 'examples/411f6c4b8e6c.jpg')), None, None,
164
- 'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
165
- 999999
166
- ],
167
- [
168
- 'Controllable Generation',
169
- download_image(
170
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
171
- os.path.join(cache_dir,
172
- 'examples/a35c96ed137a.jpg')), None, None,
173
- 'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
174
- 3599999
175
- ],
176
- [
177
- 'Controllable Generation',
178
- download_image(
179
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
180
- os.path.join(cache_dir,
181
- 'examples/dcb2fc86f1ce.jpg')), None, None,
182
- 'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
183
- 6666
184
- ],
185
- [
186
- 'Controllable Generation',
187
- download_image(
188
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
189
- os.path.join(cache_dir,
190
- 'examples/4cd4ee494962.jpg')), None, None,
191
- 'make this {image} colorful as per the "beautiful sunflowers"',
192
- 6666
193
- ],
194
- [
195
- 'Controllable Generation',
196
- download_image(
197
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
198
- os.path.join(cache_dir,
199
- 'examples/a47e3a9cd166.jpg')), None, None,
200
- 'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
201
- 613725
202
- ],
203
- [
204
- 'Controllable Generation',
205
- download_image(
206
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
207
- os.path.join(cache_dir,
208
- 'examples/d890ed8a3ac2.jpg')), None, None,
209
- 'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
210
- 6666
211
- ],
212
- [
213
- 'Image Denoising',
214
- download_image(
215
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
216
- os.path.join(cache_dir,
217
- 'examples/0844a686a179.jpg')), None, None,
218
- 'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
219
- 6666
220
- ],
221
- [
222
- 'Inpainting',
223
- download_image(
224
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
225
- os.path.join(cache_dir, 'examples/fa91b6b7e59b.jpg')),
226
- download_image(
227
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
228
- os.path.join(cache_dir,
229
- 'examples/fa91b6b7e59b_mask.jpg')), None,
230
- 'Ensure to overhaul the parts of the {image} indicated by the mask.',
231
- 6666
232
- ],
233
- [
234
- 'Inpainting',
235
- download_image(
236
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
237
- os.path.join(cache_dir, 'examples/632899695b26.jpg')),
238
- download_image(
239
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
240
- os.path.join(cache_dir,
241
- 'examples/632899695b26_mask.jpg')), None,
242
- 'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
243
- 6666
244
- ],
245
- [
246
- 'General Editing',
247
- download_image(
248
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
249
- os.path.join(cache_dir,
250
- 'examples/354d17594afe.jpg')), None, None,
251
- '{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
252
- 6666
253
- ],
254
- [
255
- 'General Editing',
256
- download_image(
257
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
258
- os.path.join(cache_dir,
259
- 'examples/38946455752b.jpg')), None, None,
260
- '{image} change the color of the dress from white to red and the model\'s hair color red brown to blonde.Other parts remain unchanged',
261
- 6669
262
- ],
263
- [
264
- 'Facial Editing',
265
- download_image(
266
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
267
- os.path.join(cache_dir,
268
- 'examples/3ba5202f0cd8.jpg')), None, None,
269
- 'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
270
- 99999
271
- ],
272
- [
273
- 'Facial Editing',
274
- download_image(
275
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
276
- os.path.join(cache_dir, 'examples/369365b94725.jpg')), None,
277
- None, '{image} Make her looking at the camera', 6666
278
- ],
279
- [
280
- 'Facial Editing',
281
- download_image(
282
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
283
- os.path.join(cache_dir, 'examples/92751f2e4a0e.jpg')), None,
284
- None, '{image} Remove the smile from his face', 9899999
285
- ],
286
- [
287
- 'Remove Text',
288
- download_image(
289
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
290
- os.path.join(cache_dir, 'examples/8530a6711b2e.jpg')), None,
291
- None, 'Aim to remove any textual element in {image}', 6666
292
- ],
293
- [
294
- 'Remove Text',
295
- download_image(
296
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
297
- os.path.join(cache_dir, 'examples/c4d7fb28f8f6.jpg')),
298
- download_image(
299
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
300
- os.path.join(cache_dir,
301
- 'examples/c4d7fb28f8f6_mask.jpg')), None,
302
- 'Rub out any text found in the mask sector of the {image}.', 6666
303
- ],
304
- [
305
- 'Remove Object',
306
- download_image(
307
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
308
- os.path.join(cache_dir,
309
- 'examples/e2f318fa5e5b.jpg')), None, None,
310
- 'Remove the unicorn in this {image}, ensuring a smooth edit.',
311
- 99999
312
- ],
313
- [
314
- 'Remove Object',
315
- download_image(
316
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
317
- os.path.join(cache_dir, 'examples/1ae96d8aca00.jpg')),
318
- download_image(
319
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
320
- os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.jpg')),
321
- None, 'Discard the contents of the mask area from {image}.', 99999
322
- ],
323
- [
324
- 'Add Object',
325
- download_image(
326
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
327
- os.path.join(cache_dir, 'examples/80289f48e511.jpg')),
328
- download_image(
329
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
330
- os.path.join(cache_dir,
331
- 'examples/80289f48e511_mask.jpg')), None,
332
- 'add a Hot Air Balloon into the {image}, per the mask', 613725
333
- ],
334
- [
335
- 'Style Transfer',
336
- download_image(
337
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
338
- os.path.join(cache_dir, 'examples/d725cb2009e8.jpg')), None,
339
- None, 'Change the style of {image} to colored pencil style', 99999
340
- ],
341
- [
342
- 'Style Transfer',
343
- download_image(
344
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
345
- os.path.join(cache_dir, 'examples/e0f48b3fd010.jpg')), None,
346
- None, 'make {image} to Walt Disney Animation style', 99999
347
- ],
348
- [
349
- 'Try On',
350
- download_image(
351
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
352
- os.path.join(cache_dir, 'examples/ee4ca60b8c96.jpg')),
353
- download_image(
354
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
355
- os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.jpg')),
356
- download_image(
357
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
358
- os.path.join(cache_dir, 'examples/ebe825bbfe3c.jpg')),
359
- 'Change the cloth in {image} to the one in {image1}', 99999
360
- ],
361
- [
362
- 'Workflow',
363
- download_image(
364
- 'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
365
- os.path.join(cache_dir, 'examples/cb85353c004b.jpg')), None,
366
- None, '<workflow> ice cream {image}', 99999
367
- ],
368
- ]
369
- print('Finish. Start building UI ...')
370
- return examples
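Note (not part of the commit): each row returned by get_examples(cache_dir) is [task, image, mask, ref_image, prompt, seed], matching the inputs of the gr.Examples block in the UI code above. A minimal usage sketch, assuming the repository root is on PYTHONPATH:

    from example import get_examples

    rows = get_examples('./cache')   # downloads the sample assets into ./cache
    task, image, mask, ref_image, prompt, seed = rows[0]
    print(task, prompt, seed)        # e.g. Facial Editing, '{image} let the man smile', 6666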
 
 
models/__init__.py DELETED
@@ -1,2 +0,0 @@
- from .flux import Flux, ACEFlux
- from .embedder import ACETextEmbedder, T5ACEPlusClipFluxEmbedder, ACEHFEmbedder
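Note (not part of the commit): importing this package is what registers the classes, since models/flux.py and models/embedder.py decorate them with BACKBONES.register_class() / EMBEDDERS.register_class(). A minimal sketch, assuming the repository root is the working directory:

    # importing the (now removed) package runs the register_class decorators
    import models
    from scepter.modules.model.registry import BACKBONES, EMBEDDERS
    # after this import, config names such as ACEFlux and T5ACEPlusClipFluxEmbedder
    # (see config/models/ace_flux_dev.yaml above) resolve against these registries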
 
 
 
models/embedder.py DELETED
@@ -1,383 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- # Copyright (c) Alibaba, Inc. and its affiliates.
3
- import warnings
4
- from contextlib import nullcontext
5
-
6
- import torch
7
- import torch.nn.functional as F
8
- import torch.utils.dlpack
9
- import transformers
10
- from scepter.modules.model.embedder.base_embedder import BaseEmbedder
11
- from scepter.modules.model.registry import EMBEDDERS
12
- from scepter.modules.model.tokenizer.tokenizer_component import (
13
- basic_clean, canonicalize, heavy_clean, whitespace_clean)
14
- from scepter.modules.utils.config import dict_to_yaml
15
- from scepter.modules.utils.distribute import we
16
- from scepter.modules.utils.file_system import FS
17
-
18
- try:
19
- from transformers import AutoTokenizer, T5EncoderModel
20
- except Exception as e:
21
- warnings.warn(
22
- f'Import transformers error, please deal with this problem: {e}')
23
-
24
-
25
- @EMBEDDERS.register_class()
26
- class ACETextEmbedder(BaseEmbedder):
27
- """
28
- Uses the OpenCLIP transformer encoder for text
29
- """
30
- """
31
- Uses the OpenCLIP transformer encoder for text
32
- """
33
- para_dict = {
34
- 'PRETRAINED_MODEL': {
35
- 'value':
36
- 'google/umt5-small',
37
- 'description':
38
- 'Pretrained Model for umt5, modelcard path or local path.'
39
- },
40
- 'TOKENIZER_PATH': {
41
- 'value': 'google/umt5-small',
42
- 'description':
43
- 'Tokenizer Path for umt5, modelcard path or local path.'
44
- },
45
- 'FREEZE': {
46
- 'value': True,
47
- 'description': ''
48
- },
49
- 'USE_GRAD': {
50
- 'value': False,
51
- 'description': 'Compute grad or not.'
52
- },
53
- 'CLEAN': {
54
- 'value':
55
- 'whitespace',
56
- 'description':
57
- 'Set the clean strtegy for tokenizer, used when TOKENIZER_PATH is not None.'
58
- },
59
- 'LAYER': {
60
- 'value': 'last',
61
- 'description': ''
62
- },
63
- 'LEGACY': {
64
- 'value':
65
- True,
66
- 'description':
67
- 'Whether use legacy returnd feature or not ,default True.'
68
- }
69
- }
70
-
71
- def __init__(self, cfg, logger=None):
72
- super().__init__(cfg, logger=logger)
73
- pretrained_path = cfg.get('PRETRAINED_MODEL', None)
74
- self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
75
- assert pretrained_path
76
- with FS.get_dir_to_local_dir(pretrained_path,
77
- wait_finish=True) as local_path:
78
- self.model = T5EncoderModel.from_pretrained(
79
- local_path,
80
- torch_dtype=getattr(
81
- torch,
82
- 'float' if self.t5_dtype == 'float32' else self.t5_dtype))
83
- tokenizer_path = cfg.get('TOKENIZER_PATH', None)
84
- self.length = cfg.get('LENGTH', 77)
85
-
86
- self.use_grad = cfg.get('USE_GRAD', False)
87
- self.clean = cfg.get('CLEAN', 'whitespace')
88
- self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
89
- if tokenizer_path:
90
- self.tokenize_kargs = {'return_tensors': 'pt'}
91
- with FS.get_dir_to_local_dir(tokenizer_path,
92
- wait_finish=True) as local_path:
93
- if self.added_identifier is not None and isinstance(
94
- self.added_identifier, list):
95
- self.tokenizer = AutoTokenizer.from_pretrained(local_path)
96
- else:
97
- self.tokenizer = AutoTokenizer.from_pretrained(local_path)
98
- if self.length is not None:
99
- self.tokenize_kargs.update({
100
- 'padding': 'max_length',
101
- 'truncation': True,
102
- 'max_length': self.length
103
- })
104
- self.eos_token = self.tokenizer(
105
- self.tokenizer.eos_token)['input_ids'][0]
106
- else:
107
- self.tokenizer = None
108
- self.tokenize_kargs = {}
109
-
110
- self.use_grad = cfg.get('USE_GRAD', False)
111
- self.clean = cfg.get('CLEAN', 'whitespace')
112
-
113
- def freeze(self):
114
- self.model = self.model.eval()
115
- for param in self.parameters():
116
- param.requires_grad = False
117
-
118
- # encode && encode_text
119
- def forward(self, tokens, return_mask=False, use_mask=True):
120
- # tokenization
121
- embedding_context = nullcontext if self.use_grad else torch.no_grad
122
- with embedding_context():
123
- if use_mask:
124
- x = self.model(tokens.input_ids.to(we.device_id),
125
- tokens.attention_mask.to(we.device_id))
126
- else:
127
- x = self.model(tokens.input_ids.to(we.device_id))
128
- x = x.last_hidden_state
129
-
130
- if return_mask:
131
- return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
132
- else:
133
- return x.detach() + 0.0, None
134
-
135
- def _clean(self, text):
136
- if self.clean == 'whitespace':
137
- text = whitespace_clean(basic_clean(text))
138
- elif self.clean == 'lower':
139
- text = whitespace_clean(basic_clean(text)).lower()
140
- elif self.clean == 'canonicalize':
141
- text = canonicalize(basic_clean(text))
142
- elif self.clean == 'heavy':
143
- text = heavy_clean(basic_clean(text))
144
- return text
145
-
146
- def encode(self, text, return_mask=False, use_mask=True):
147
- if isinstance(text, str):
148
- text = [text]
149
- if self.clean:
150
- text = [self._clean(u) for u in text]
151
- assert self.tokenizer is not None
152
- cont, mask = [], []
153
- with torch.autocast(device_type='cuda',
154
- enabled=self.t5_dtype in ('float16', 'bfloat16'),
155
- dtype=getattr(torch, self.t5_dtype)):
156
- for tt in text:
157
- tokens = self.tokenizer([tt], **self.tokenize_kargs)
158
- one_cont, one_mask = self(tokens,
159
- return_mask=return_mask,
160
- use_mask=use_mask)
161
- cont.append(one_cont)
162
- mask.append(one_mask)
163
- if return_mask:
164
- return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
165
- else:
166
- return torch.cat(cont, dim=0)
167
-
168
- def encode_list(self, text_list, return_mask=True):
169
- cont_list = []
170
- mask_list = []
171
- for pp in text_list:
172
- cont, cont_mask = self.encode(pp, return_mask=return_mask)
173
- cont_list.append(cont)
174
- mask_list.append(cont_mask)
175
- if return_mask:
176
- return cont_list, mask_list
177
- else:
178
- return cont_list
179
-
180
- @staticmethod
181
- def get_config_template():
182
- return dict_to_yaml('MODELS',
183
- __class__.__name__,
184
- ACETextEmbedder.para_dict,
185
- set_name=True)
186
-
187
- @EMBEDDERS.register_class()
188
- class ACEHFEmbedder(BaseEmbedder):
189
- para_dict = {
190
- "HF_MODEL_CLS": {
191
- "value": None,
192
- "description": "huggingface cls in transfomer"
193
- },
194
- "MODEL_PATH": {
195
- "value": None,
196
- "description": "model folder path"
197
- },
198
- "HF_TOKENIZER_CLS": {
199
- "value": None,
200
- "description": "huggingface cls in transfomer"
201
- },
202
-
203
- "TOKENIZER_PATH": {
204
- "value": None,
205
- "description": "tokenizer folder path"
206
- },
207
- "MAX_LENGTH": {
208
- "value": 77,
209
- "description": "max length of input"
210
- },
211
- "OUTPUT_KEY": {
212
- "value": "last_hidden_state",
213
- "description": "output key"
214
- },
215
- "D_TYPE": {
216
- "value": "float",
217
- "description": "dtype"
218
- },
219
- "BATCH_INFER": {
220
- "value": False,
221
- "description": "batch infer"
222
- }
223
- }
224
- para_dict.update(BaseEmbedder.para_dict)
225
- def __init__(self, cfg, logger=None):
226
- super().__init__(cfg, logger=logger)
227
- hf_model_cls = cfg.get('HF_MODEL_CLS', None)
228
- model_path = cfg.get("MODEL_PATH", None)
229
- hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
230
- tokenizer_path = cfg.get('TOKENIZER_PATH', None)
231
- self.max_length = cfg.get('MAX_LENGTH', 77)
232
- self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
233
- self.d_type = cfg.get("D_TYPE", "float")
234
- self.clean = cfg.get("CLEAN", "whitespace")
235
- self.batch_infer = cfg.get("BATCH_INFER", False)
236
- self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
237
- torch_dtype = getattr(torch, self.d_type)
238
-
239
- assert hf_model_cls is not None and hf_tokenizer_cls is not None
240
- assert model_path is not None and tokenizer_path is not None
241
- with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
242
- self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
243
- max_length = self.max_length,
244
- torch_dtype = torch_dtype,
245
- additional_special_tokens=self.added_identifier)
246
-
247
- with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
248
- self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
249
-
250
-
251
- self.hf_module = self.hf_module.eval().requires_grad_(False)
252
-
253
- def forward(self, text: list[str], return_mask = False):
254
- batch_encoding = self.tokenizer(
255
- text,
256
- truncation=True,
257
- max_length=self.max_length,
258
- return_length=False,
259
- return_overflowing_tokens=False,
260
- padding="max_length",
261
- return_tensors="pt",
262
- )
263
-
264
- outputs = self.hf_module(
265
- input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
266
- attention_mask=None,
267
- output_hidden_states=False,
268
- )
269
- if return_mask:
270
- return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
271
- else:
272
- return outputs[self.output_key], None
273
-
274
- def encode(self, text, return_mask = False):
275
- if isinstance(text, str):
276
- text = [text]
277
- if self.clean:
278
- text = [self._clean(u) for u in text]
279
- if not self.batch_infer:
280
- cont, mask = [], []
281
- for tt in text:
282
- one_cont, one_mask = self([tt], return_mask=return_mask)
283
- cont.append(one_cont)
284
- mask.append(one_mask)
285
- if return_mask:
286
- return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
287
- else:
288
- return torch.cat(cont, dim=0)
289
- else:
290
- ret_data = self(text, return_mask = return_mask)
291
- if return_mask:
292
- return ret_data
293
- else:
294
- return ret_data[0]
295
-
296
- def encode_list(self, text_list, return_mask=True):
297
- cont_list = []
298
- mask_list = []
299
- for pp in text_list:
300
- cont = self.encode(pp, return_mask=return_mask)
301
- cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
302
- mask_list.append(cont[1]) if return_mask else mask_list.append(None)
303
- if return_mask:
304
- return cont_list, mask_list
305
- else:
306
- return cont_list
307
-
308
- def encode_list_of_list(self, text_list, return_mask=True):
309
- cont_list = []
310
- mask_list = []
311
- for pp in text_list:
312
- cont = self.encode_list(pp, return_mask=return_mask)
313
- cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
314
- mask_list.append(cont[1]) if return_mask else mask_list.append(None)
315
- if return_mask:
316
- return cont_list, mask_list
317
- else:
318
- return cont_list
319
-
320
- def _clean(self, text):
321
- if self.clean == 'whitespace':
322
- text = whitespace_clean(basic_clean(text))
323
- elif self.clean == 'lower':
324
- text = whitespace_clean(basic_clean(text)).lower()
325
- elif self.clean == 'canonicalize':
326
- text = canonicalize(basic_clean(text))
327
- return text
328
- @staticmethod
329
- def get_config_template():
330
- return dict_to_yaml('EMBEDDER',
331
- __class__.__name__,
332
- ACEHFEmbedder.para_dict,
333
- set_name=True)
334
-
335
- @EMBEDDERS.register_class()
336
- class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
337
- """
338
- Uses the OpenCLIP transformer encoder for text
339
- """
340
- para_dict = {
341
- 'T5_MODEL': {},
342
- 'CLIP_MODEL': {}
343
- }
344
-
345
- def __init__(self, cfg, logger=None):
346
- super().__init__(cfg, logger=logger)
347
- self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
348
- self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
349
-
350
- def encode(self, text, return_mask = False):
351
- t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
352
- clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
353
- # change embedding strategy here
354
- return {
355
- 'context': t5_embeds,
356
- 'y': clip_embeds,
357
- }
358
-
359
- def encode_list(self, text, return_mask = False):
360
- t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
361
- clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
362
- # change embedding strategy here
363
- return {
364
- 'context': t5_embeds,
365
- 'y': clip_embeds,
366
- }
367
-
368
- def encode_list_of_list(self, text, return_mask = False):
369
- t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
370
- clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
371
- # change embedding strategy here
372
- return {
373
- 'context': t5_embeds,
374
- 'y': clip_embeds,
375
- }
376
-
377
-
378
- @staticmethod
379
- def get_config_template():
380
- return dict_to_yaml('EMBEDDER',
381
- __class__.__name__,
382
- T5ACEPlusClipFluxEmbedder.para_dict,
383
- set_name=True)
 
 
models/flux.py DELETED
@@ -1,798 +0,0 @@
1
- import math, torch
2
- from collections import OrderedDict
3
- from functools import partial
4
- from einops import rearrange, repeat
5
- from scepter.modules.model.base_model import BaseModel
6
- from scepter.modules.model.registry import BACKBONES
7
- from scepter.modules.utils.config import dict_to_yaml
8
- from scepter.modules.utils.distribute import we
9
- from scepter.modules.utils.file_system import FS
10
- from torch import Tensor, nn
11
- from torch.nn.utils.rnn import pad_sequence
12
- from torch.utils.checkpoint import checkpoint_sequential
13
-
14
- from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
15
- MLPEmbedder, SingleStreamBlock,
16
- timestep_embedding, DoubleStreamBlockACE, SingleStreamBlockACE)
17
-
18
- @BACKBONES.register_class()
19
- class Flux(BaseModel):
20
- """
21
- Transformer backbone Diffusion model with RoPE.
22
- """
23
- para_dict = {
24
- "IN_CHANNELS": {
25
- "value": 64,
26
- "description": "model's input channels."
27
- },
28
- "OUT_CHANNELS": {
29
- "value": 64,
30
- "description": "model's output channels."
31
- },
32
- "HIDDEN_SIZE": {
33
- "value": 1024,
34
- "description": "model's hidden size."
35
- },
36
- "NUM_HEADS": {
37
- "value": 16,
38
- "description": "number of heads in the transformer."
39
- },
40
- "AXES_DIM": {
41
- "value": [16, 56, 56],
42
- "description": "dimensions of the axes of the positional encoding."
43
- },
44
- "THETA": {
45
- "value": 10_000,
46
- "description": "theta for positional encoding."
47
- },
48
- "VEC_IN_DIM": {
49
- "value": 768,
50
- "description": "dimension of the vector input."
51
- },
52
- "GUIDANCE_EMBED": {
53
- "value": False,
54
- "description": "whether to use guidance embedding."
55
- },
56
- "CONTEXT_IN_DIM": {
57
- "value": 4096,
58
- "description": "dimension of the context input."
59
- },
60
- "MLP_RATIO": {
61
- "value": 4.0,
62
- "description": "ratio of mlp hidden size to hidden size."
63
- },
64
- "QKV_BIAS": {
65
- "value": True,
66
- "description": "whether to use bias in qkv projection."
67
- },
68
- "DEPTH": {
69
- "value": 19,
70
- "description": "number of transformer blocks."
71
- },
72
- "DEPTH_SINGLE_BLOCKS": {
73
- "value": 38,
74
- "description": "number of transformer blocks in the single stream block."
75
- },
76
- "USE_GRAD_CHECKPOINT": {
77
- "value": False,
78
- "description": "whether to use gradient checkpointing."
79
- },
80
- "ATTN_BACKEND": {
81
- "value": "pytorch",
82
- "description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
83
- }
84
- }
85
- def __init__(
86
- self,
87
- cfg,
88
- logger = None
89
- ):
90
- super().__init__(cfg, logger=logger)
91
- self.in_channels = cfg.IN_CHANNELS
92
- self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
93
- hidden_size = cfg.get("HIDDEN_SIZE", 1024)
94
- num_heads = cfg.get("NUM_HEADS", 16)
95
- axes_dim = cfg.AXES_DIM
96
- theta = cfg.THETA
97
- vec_in_dim = cfg.VEC_IN_DIM
98
- self.guidance_embed = cfg.GUIDANCE_EMBED
99
- context_in_dim = cfg.CONTEXT_IN_DIM
100
- mlp_ratio = cfg.MLP_RATIO
101
- qkv_bias = cfg.QKV_BIAS
102
- depth = cfg.DEPTH
103
- depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
104
- self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
105
- self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
106
- self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
107
- self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
108
- self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
109
- self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
110
-
111
- if hidden_size % num_heads != 0:
112
- raise ValueError(
113
- f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
114
- )
115
- pe_dim = hidden_size // num_heads
116
- if sum(axes_dim) != pe_dim:
117
- raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
118
- self.hidden_size = hidden_size
119
- self.num_heads = num_heads
120
- self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim= axes_dim)
121
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
122
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
123
- self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
124
- self.guidance_in = (
125
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
126
- )
127
- self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
128
-
129
- self.double_blocks = nn.ModuleList(
130
- [
131
- DoubleStreamBlock(
132
- self.hidden_size,
133
- self.num_heads,
134
- mlp_ratio=mlp_ratio,
135
- qkv_bias=qkv_bias,
136
- backend=self.attn_backend
137
- )
138
- for _ in range(depth)
139
- ]
140
- )
141
-
142
- self.single_blocks = nn.ModuleList(
143
- [
144
- SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
145
- for _ in range(depth_single_blocks)
146
- ]
147
- )
148
-
149
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
150
-
151
- def prepare_input(self, x, context, y, x_shape=None):
152
- # x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
153
- bs, c, h, w = x.shape
154
- x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
155
- x_id = torch.zeros(h // 2, w // 2, 3)
156
- x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
157
- x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
158
- x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
159
- txt_ids = torch.zeros(bs, context.shape[1], 3)
160
- return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
161
-
162
- def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
163
- return rearrange(
164
- x,
165
- "b (h w) (c ph pw) -> b c (h ph) (w pw)",
166
- h=math.ceil(height/2),
167
- w=math.ceil(width/2),
168
- ph=2,
169
- pw=2,
170
- )
171
-
172
- # def merge_diffuser_lora(self, ori_sd, lora_sd, scale = 1.0):
173
- # key_map = {
174
- # "single_blocks.{}.linear1.weight": {"key_list": [
175
- # ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
176
- # "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
177
- # ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
178
- # "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
179
- # ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
180
- # "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
181
- # ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
182
- # "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
183
- # ], "num": 38},
184
- # "single_blocks.{}.modulation.lin.weight": {"key_list": [
185
- # ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
186
- # "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
187
- # ], "num": 38},
188
- # "single_blocks.{}.linear2.weight": {"key_list": [
189
- # ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
190
- # "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
191
- # ], "num": 38},
192
- # "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
193
- # ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
194
- # "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
195
- # ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
196
- # "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
197
- # ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
198
- # "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
199
- # ], "num": 19},
200
- # "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
201
- # ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
202
- # "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
203
- # ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
204
- # "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
205
- # ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
206
- # "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
207
- # ], "num": 19},
208
- # "double_blocks.{}.img_attn.proj.weight": {"key_list": [
209
- # ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
210
- # "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
211
- # ], "num": 19},
212
- # "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
213
- # ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
214
- # "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
215
- # ], "num": 19},
216
- # "double_blocks.{}.img_mlp.0.weight": {"key_list": [
217
- # ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
218
- # "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
219
- # ], "num": 19},
220
- # "double_blocks.{}.img_mlp.2.weight": {"key_list": [
221
- # ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
222
- # "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
223
- # ], "num": 19},
224
- # "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
225
- # ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
226
- # "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
227
- # ], "num": 19},
228
- # "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
229
- # ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
230
- # "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
231
- # ], "num": 19},
232
- # "double_blocks.{}.img_mod.lin.weight": {"key_list": [
233
- # ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
234
- # "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
235
- # ], "num": 19},
236
- # "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
237
- # ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
238
- # "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
239
- # ], "num": 19}
240
- # }
241
- # have_lora_keys = 0
242
- # for k, v in key_map.items():
243
- # key_list = v["key_list"]
244
- # block_num = v["num"]
245
- # for block_id in range(block_num):
246
- # current_weight_list = []
247
- # for k_list in key_list:
248
- # current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
249
- # lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
250
- # current_weight_list.append(current_weight)
251
- # current_weight = torch.cat(current_weight_list, dim=0)
252
- # ori_sd[k.format(block_id)] += scale*current_weight
253
- # have_lora_keys += 1
254
- # self.logger.info(f"merge_swift_lora loads lora'parameters {have_lora_keys}")
255
- # return ori_sd
256
-
257
- def merge_diffuser_lora(self, ori_sd, lora_sd, scale=1.0):
258
- key_map = {
259
- "single_blocks.{}.linear1.weight": {"key_list": [
260
- ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
261
- "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
262
- ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
263
- "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
264
- ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
265
- "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
266
- ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
267
- "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight", [9216, 21504]]
268
- ], "num": 38},
269
- "single_blocks.{}.modulation.lin.weight": {"key_list": [
270
- ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
271
- "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight", [0, 9216]],
272
- ], "num": 38},
273
- "single_blocks.{}.linear2.weight": {"key_list": [
274
- ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
275
- "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight", [0, 3072]],
276
- ], "num": 38},
277
- "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
278
- ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
279
- "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight", [0, 3072]],
280
- ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
281
- "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight", [3072, 6144]],
282
- ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
283
- "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight", [6144, 9216]],
284
- ], "num": 19},
285
- "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
286
- ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
287
- "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
288
- ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
289
- "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
290
- ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
291
- "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
292
- ], "num": 19},
293
- "double_blocks.{}.img_attn.proj.weight": {"key_list": [
294
- ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
295
- "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight", [0, 3072]]
296
- ], "num": 19},
297
- "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
298
- ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
299
- "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight", [0, 3072]]
300
- ], "num": 19},
301
- "double_blocks.{}.img_mlp.0.weight": {"key_list": [
302
- ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
303
- "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight", [0, 12288]]
304
- ], "num": 19},
305
- "double_blocks.{}.img_mlp.2.weight": {"key_list": [
306
- ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
307
- "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight", [0, 3072]]
308
- ], "num": 19},
309
- "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
310
- ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
311
- "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight", [0, 12288]]
312
- ], "num": 19},
313
- "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
314
- ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
315
- "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight", [0, 3072]]
316
- ], "num": 19},
317
- "double_blocks.{}.img_mod.lin.weight": {"key_list": [
318
- ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
319
- "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight", [0, 18432]]
320
- ], "num": 19},
321
- "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
322
- ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
323
- "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight", [0, 18432]]
324
- ], "num": 19}
325
- }
326
- cover_lora_keys = set()
327
- cover_ori_keys = set()
328
- for k, v in key_map.items():
329
- key_list = v["key_list"]
330
- block_num = v["num"]
331
- for block_id in range(block_num):
332
- for k_list in key_list:
333
- if k_list[0].format(block_id) in lora_sd and k_list[1].format(block_id) in lora_sd:
334
- cover_lora_keys.add(k_list[0].format(block_id))
335
- cover_lora_keys.add(k_list[1].format(block_id))
336
- current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
337
- lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
338
- ori_sd[k.format(block_id)][k_list[2][0]:k_list[2][1], ...] += scale * current_weight
339
- cover_ori_keys.add(k.format(block_id))
340
- # lora_sd.pop(k_list[0].format(block_id))
341
- # lora_sd.pop(k_list[1].format(block_id))
342
- self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
343
- f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
344
- f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
345
- return ori_sd
346
-
347
- def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
348
- have_lora_keys = {}
349
- for k, v in lora_sd.items():
350
- k = k[len("model."):] if k.startswith("model.") else k
351
- ori_key = k.split("lora")[0] + "weight"
352
- if ori_key not in ori_sd:
353
- raise f"{ori_key} should in the original statedict"
354
- if ori_key not in have_lora_keys:
355
- have_lora_keys[ori_key] = {}
356
- if "lora_A" in k:
357
- have_lora_keys[ori_key]["lora_A"] = v
358
- elif "lora_B" in k:
359
- have_lora_keys[ori_key]["lora_B"] = v
360
- else:
361
- raise NotImplementedError
362
- self.logger.info(f"merge_swift_lora loads lora'parameters {len(have_lora_keys)}")
363
- for key, v in have_lora_keys.items():
364
- current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
365
- ori_sd[key] += scale * current_weight
366
- return ori_sd
367
-
368
-
369
- def merge_blackforest_lora(self, ori_sd, lora_sd, scale = 1.0):
370
- have_lora_keys = {}
371
- cover_lora_keys = set()
372
- cover_ori_keys = set()
373
- for k, v in lora_sd.items():
374
- if "lora" in k:
375
- ori_key = k.split("lora")[0] + "weight"
376
- if ori_key not in ori_sd:
377
- raise f"{ori_key} should in the original statedict"
378
- if ori_key not in have_lora_keys:
379
- have_lora_keys[ori_key] = {}
380
- if "lora_A" in k:
381
- have_lora_keys[ori_key]["lora_A"] = v
382
- cover_lora_keys.add(k)
383
- cover_ori_keys.add(ori_key)
384
- elif "lora_B" in k:
385
- have_lora_keys[ori_key]["lora_B"] = v
386
- cover_lora_keys.add(k)
387
- cover_ori_keys.add(ori_key)
388
- else:
389
- if k in ori_sd:
390
- ori_sd[k] = v
391
- cover_lora_keys.add(k)
392
- cover_ori_keys.add(k)
393
- else:
394
- print("unsurpport keys: ", k)
395
- self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
396
- f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
397
- f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
398
-
399
- for key, v in have_lora_keys.items():
400
- current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
401
- # print(key, ori_sd[key].shape, current_weight.shape)
402
- ori_sd[key] += scale * current_weight
403
- return ori_sd
404
-
405
- def load_pretrained_model(self, pretrained_model):
406
- if next(self.parameters()).device.type == 'meta':
407
- map_location = torch.device(we.device_id)
408
- safe_device = we.device_id
409
- else:
410
- map_location = "cpu"
411
- safe_device = "cpu"
412
-
413
- if pretrained_model is not None:
414
- with FS.get_from(pretrained_model, wait_finish=True) as local_model:
415
- if local_model.endswith('safetensors'):
416
- from safetensors.torch import load_file as load_safetensors
417
- sd = load_safetensors(local_model, device=safe_device)
418
- else:
419
- sd = torch.load(local_model, map_location=map_location, weights_only=True)
420
- if "state_dict" in sd:
421
- sd = sd["state_dict"]
422
- if "model" in sd:
423
- sd = sd["model"]["model"]
424
-
425
-
426
- new_ckpt = OrderedDict()
427
- for k, v in sd.items():
428
- if k in ("img_in.weight"):
429
- model_p = self.state_dict()[k]
430
- if v.shape != model_p.shape:
431
- expanded_state_dict_weight = torch.zeros_like(model_p, device=v.device)
432
- slices = tuple(slice(0, dim) for dim in v.shape)
433
- expanded_state_dict_weight[slices] = v
434
- new_ckpt[k] = expanded_state_dict_weight
435
- else:
436
- new_ckpt[k] = v
437
- else:
438
- new_ckpt[k] = v
439
-
440
-
441
- if self.lora_model is not None:
442
- with FS.get_from(self.lora_model, wait_finish=True) as local_model:
443
- if local_model.endswith('safetensors'):
444
- from safetensors.torch import load_file as load_safetensors
445
- lora_sd = load_safetensors(local_model, device=safe_device)
446
- else:
447
- lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
448
- new_ckpt = self.merge_diffuser_lora(new_ckpt, lora_sd)
449
- if self.swift_lora_model is not None:
450
- if not isinstance(self.swift_lora_model, list):
451
- self.swift_lora_model = [self.swift_lora_model]
452
- for lora_model in self.swift_lora_model:
453
- self.logger.info(f"load swift lora model: {lora_model}")
454
- with FS.get_from(lora_model, wait_finish=True) as local_model:
455
- if local_model.endswith('safetensors'):
456
- from safetensors.torch import load_file as load_safetensors
457
- lora_sd = load_safetensors(local_model, device=safe_device)
458
- else:
459
- lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
460
- new_ckpt = self.merge_swift_lora(new_ckpt, lora_sd)
461
- if self.blackforest_lora_model is not None:
462
-
463
- with FS.get_from(self.blackforest_lora_model, wait_finish=True) as local_model:
464
- if local_model.endswith('safetensors'):
465
- from safetensors.torch import load_file as load_safetensors
466
- lora_sd = load_safetensors(local_model, device=safe_device)
467
- else:
468
- lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
469
- new_ckpt = self.merge_blackforest_lora(new_ckpt, lora_sd)
470
-
471
-
472
- adapter_ckpt = {}
473
- if self.pretrain_adapter is not None:
474
- with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
475
- if local_adapter.endswith('safetensors'):
476
- from safetensors.torch import load_file as load_safetensors
477
- adapter_ckpt = load_safetensors(local_adapter, device=safe_device)
478
- else:
479
- adapter_ckpt = torch.load(local_adapter, map_location=map_location, weights_only=True)
480
- new_ckpt.update(adapter_ckpt)
481
-
482
- missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
483
- self.logger.info(
484
- f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
485
- )
486
- if len(missing) > 0:
487
- self.logger.info(f'Missing Keys:\n {missing}')
488
- if len(unexpected) > 0:
489
- self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
490
-
491
- def forward(
492
- self,
493
- x: Tensor,
494
- t: Tensor,
495
- cond: dict = {},
496
- guidance: Tensor | None = None,
497
- gc_seg: int = 0
498
- ) -> Tensor:
499
- x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
500
- # running on sequences img
501
- x = self.img_in(x)
502
- vec = self.time_in(timestep_embedding(t, 256))
503
- if self.guidance_embed:
504
- if guidance is None:
505
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
506
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
507
- vec = vec + self.vector_in(y)
508
- txt = self.txt_in(txt)
509
- ids = torch.cat((txt_ids, x_ids), dim=1)
510
- pe = self.pe_embedder(ids)
511
- kwargs = dict(
512
- vec=vec,
513
- pe=pe,
514
- txt_length=txt.shape[1],
515
- )
516
- x = torch.cat((txt, x), 1)
517
- if self.use_grad_checkpoint and gc_seg >= 0:
518
- x = checkpoint_sequential(
519
- functions=[partial(block, **kwargs) for block in self.double_blocks],
520
- segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
521
- input=x,
522
- use_reentrant=False
523
- )
524
- else:
525
- for block in self.double_blocks:
526
- x = block(x, **kwargs)
527
-
528
- kwargs = dict(
529
- vec=vec,
530
- pe=pe,
531
- )
532
-
533
- if self.use_grad_checkpoint and gc_seg >= 0:
534
- x = checkpoint_sequential(
535
- functions=[partial(block, **kwargs) for block in self.single_blocks],
536
- segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
537
- input=x,
538
- use_reentrant=False
539
- )
540
- else:
541
- for block in self.single_blocks:
542
- x = block(x, **kwargs)
543
- x = x[:, txt.shape[1] :, ...]
544
- x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
545
- x = self.unpack(x, h, w)
546
- return x
547
-
548
- @staticmethod
549
- def get_config_template():
550
- return dict_to_yaml('MODEL',
551
- __class__.__name__,
552
- Flux.para_dict,
553
- set_name=True)
554
- @BACKBONES.register_class()
555
- class ACEFlux(Flux):
556
- '''
557
- cat[x_seq, edit_seq]
558
- pe[x_seq] pe[edit_seq]
559
- '''
560
-
561
- def __init__(
562
- self,
563
- cfg,
564
- logger=None
565
- ):
566
- super().__init__(cfg, logger=logger)
567
- self.in_channels = cfg.IN_CHANNELS
568
- self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
569
- hidden_size = cfg.get("HIDDEN_SIZE", 1024)
570
- num_heads = cfg.get("NUM_HEADS", 16)
571
- axes_dim = cfg.AXES_DIM
572
- theta = cfg.THETA
573
- vec_in_dim = cfg.VEC_IN_DIM
574
- self.guidance_embed = cfg.GUIDANCE_EMBED
575
- context_in_dim = cfg.CONTEXT_IN_DIM
576
- mlp_ratio = cfg.MLP_RATIO
577
- qkv_bias = cfg.QKV_BIAS
578
- depth = cfg.DEPTH
579
- depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
580
- self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
581
- self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
582
- self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
583
- self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
584
- self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
585
- self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
586
-
587
- if hidden_size % num_heads != 0:
588
- raise ValueError(
589
- f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
590
- )
591
- pe_dim = hidden_size // num_heads
592
- if sum(axes_dim) != pe_dim:
593
- raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
594
- self.hidden_size = hidden_size
595
- self.num_heads = num_heads
596
- self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
597
- self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
598
- self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
599
- self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
600
- self.guidance_in = (
601
- MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
602
- )
603
- self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
604
-
605
- self.double_blocks = nn.ModuleList(
606
- [
607
- DoubleStreamBlockACE(
608
- self.hidden_size,
609
- self.num_heads,
610
- mlp_ratio=mlp_ratio,
611
- qkv_bias=qkv_bias,
612
- backend=self.attn_backend
613
- )
614
- for _ in range(depth)
615
- ]
616
- )
617
-
618
- self.single_blocks = nn.ModuleList(
619
- [
620
- SingleStreamBlockACE(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
621
- for _ in range(depth_single_blocks)
622
- ]
623
- )
624
-
625
- self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
626
-
627
- def prepare_input(self, x, cond, *args, **kwargs):
628
- context, y = cond["context"], cond["y"]
629
- # import pdb;pdb.set_trace()
630
- batch_shift = []
631
- x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
632
- for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
633
- # unpack image from sequence
634
- ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
635
- c, h, w = ix.shape
636
- ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
637
- ix_id = torch.zeros(h // 2, w // 2, 3)
638
- ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
639
- ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
640
- batch_shift.append(w // 2 if is_align < 1 else 0)
641
- ix_id = rearrange(ix_id, "h w c -> (h w) c")
642
- ix = self.img_in(ix)
643
- x_list.append(ix)
644
- x_id_list.append(ix_id)
645
- mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
646
- x_seq_length.append(ix.shape[0])
647
-
648
- x = pad_sequence(tuple(x_list), batch_first=True)
649
- x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
650
- mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
651
-
652
- if 'edit' in cond and sum(len(e) for e in cond['edit']) > 0:
653
- batch_frames, batch_frames_ids = [], []
654
- for i, edit in enumerate(cond['edit']):
655
- batch_frames.append([])
656
- batch_frames_ids.append([])
657
- for ie in edit:
658
- ie = ie.squeeze(0)
659
- c, h, w = ie.shape
660
- ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
661
- ie_id = torch.zeros(h // 2, w // 2, 3)
662
- ie_id[..., 1] = ie_id[..., 1] + torch.arange(h // 2)[:, None]
663
- ie_id[..., 2] = ie_id[..., 2] + torch.arange(batch_shift[i], batch_shift[i] + w // 2)[None, :]
664
- ie_id = rearrange(ie_id, "h w c -> (h w) c")
665
- batch_frames[i].append(ie)
666
- batch_frames_ids[i].append(ie_id)
667
- edit_list, edit_id_list, edit_mask_x_list = [], [], []
668
- for frames, frame_ids in zip(batch_frames, batch_frames_ids):
669
- proj_frames = []
670
- for idx, one_frame in enumerate(frames):
671
- one_frame = self.img_in(one_frame)
672
- proj_frames.append(one_frame)
673
- ie = torch.cat(proj_frames, dim=0)
674
- ie_id = torch.cat(frame_ids, dim=0)
675
- edit_list.append(ie)
676
- edit_id_list.append(ie_id)
677
- edit_mask_x_list.append(torch.ones(ie.shape[0]).to(ie.device, non_blocking=True).bool())
678
- edit = pad_sequence(tuple(edit_list), batch_first=True)
679
- edit_ids = pad_sequence(tuple(edit_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
680
- edit_mask_x = pad_sequence(tuple(edit_mask_x_list), batch_first=True)
681
- else:
682
- edit, edit_ids, edit_mask_x = None, None, None
683
-
684
- txt_list, mask_txt_list, y_list = [], [], []
685
- for sample_id, (ctx, yy) in enumerate(zip(context, y)):
686
- txt_list.append(self.txt_in(ctx.to(x)))
687
- mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
688
- y_list.append(yy.to(x))
689
- txt = pad_sequence(tuple(txt_list), batch_first=True)
690
- txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
691
- mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
692
- y = torch.cat(y_list, dim=0)
693
- return x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, x_seq_length
694
-
695
- def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
696
- x_list = []
697
- image_shapes = cond["x_shapes"]
698
- for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
699
- height, width = shape
700
- h, w = math.ceil(height / 2), math.ceil(width / 2)
701
- u = rearrange(
702
- u[:h * w, ...],
703
- "(h w) (c ph pw) -> (h ph w pw) c",
704
- h=h,
705
- w=w,
706
- ph=2,
707
- pw=2,
708
- )
709
- x_list.append(u)
710
- x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
711
- return x
712
-
713
- def forward(
714
- self,
715
- x: Tensor,
716
- t: Tensor,
717
- cond: dict = {},
718
- guidance: Tensor | None = None,
719
- gc_seg: int = 0,
720
- **kwargs
721
- ) -> Tensor:
722
- x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
723
- # running on sequences img
724
- # condition use zero t
725
- x_length = x.shape[1]
726
- vec = self.time_in(timestep_embedding(t, 256))
727
-
728
- if edit is not None:
729
- edit_vec = self.time_in(timestep_embedding(t * 0, 256))
730
- # print("edit_vec", torch.sum(edit_vec))
731
- else:
732
- edit_vec = None
733
-
734
- if self.guidance_embed:
735
- if guidance is None:
736
- raise ValueError("Didn't get guidance strength for guidance distilled model.")
737
- vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
738
- if edit is not None:
739
- edit_vec = edit_vec + self.guidance_in(timestep_embedding(guidance, 256))
740
-
741
- vec = vec + self.vector_in(y)
742
- if edit is not None:
743
- edit_vec = edit_vec + self.vector_in(y)
744
- ids = torch.cat((txt_ids, x_ids, edit_ids), dim=1)
745
- mask_aside = torch.cat((mask_txt, mask_x, edit_mask_x), dim=1)
746
- x = torch.cat((txt, x, edit), 1)
747
- else:
748
- ids = torch.cat((txt_ids, x_ids), dim=1)
749
- mask_aside = torch.cat((mask_txt, mask_x), dim=1)
750
- x = torch.cat((txt, x), 1)
751
-
752
- pe = self.pe_embedder(ids)
753
- mask = mask_aside[:, None, :] * mask_aside[:, :, None]
754
-
755
- kwargs = dict(
756
- vec=vec,
757
- pe=pe,
758
- mask=mask,
759
- txt_length=txt.shape[1],
760
- x_length=x_length,
761
- edit_vec=edit_vec,
762
-
763
- )
764
-
765
- if self.use_grad_checkpoint and gc_seg >= 0:
766
- x = checkpoint_sequential(
767
- functions=[partial(block, **kwargs) for block in self.double_blocks],
768
- segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
769
- input=x,
770
- use_reentrant=False
771
- )
772
- else:
773
- for idx, block in enumerate(self.double_blocks):
774
- # print("double block", idx)
775
- x = block(x, **kwargs)
776
-
777
- if self.use_grad_checkpoint and gc_seg >= 0:
778
- x = checkpoint_sequential(
779
- functions=[partial(block, **kwargs) for block in self.single_blocks],
780
- segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
781
- input=x,
782
- use_reentrant=False
783
- )
784
- else:
785
- for idx, block in enumerate(self.single_blocks):
786
- # print("single block", idx)
787
- x = block(x, **kwargs)
788
- x = x[:, txt.shape[1]:txt.shape[1] + x_length, ...]
789
- x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
790
- x = self.unpack(x, cond, seq_length_list)
791
- return x
792
-
793
- @staticmethod
794
- def get_config_template():
795
- return dict_to_yaml('MODEL',
796
- __class__.__name__,
797
- ACEFlux.para_dict,
798
- set_name=True)
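
For reference, the merge_* helpers above all fold a LoRA pair back into the base weight as W += scale * (lora_B @ lora_A). A minimal standalone sketch of that update, with illustrative shapes (the names and sizes here are only an example, not taken from the checkpoint):

import torch

rank, d_in, d_out, scale = 4, 64, 64, 1.0
base_w = torch.randn(d_out, d_in)   # base nn.Linear weight, shape (out_features, in_features)
lora_A = torch.randn(rank, d_in)    # lora_A.weight
lora_B = torch.randn(d_out, rank)   # lora_B.weight

# same permute/matmul/permute pattern as merge_swift_lora, equivalent to lora_B @ lora_A
delta = torch.matmul(lora_A.permute(1, 0), lora_B.permute(1, 0)).permute(1, 0)
merged = base_w + scale * delta
assert torch.allclose(delta, lora_B @ lora_A)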
 
models/layers.py DELETED
@@ -1,497 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import math
4
- from dataclasses import dataclass
5
- from torch import Tensor, nn
6
- import torch
7
- from einops import rearrange, repeat
8
- from torch import Tensor
9
- from torch.nn.utils.rnn import pad_sequence
10
-
11
- try:
12
- from flash_attn import (
13
- flash_attn_varlen_func
14
- )
15
- FLASHATTN_IS_AVAILABLE = True
16
- except ImportError:
17
- FLASHATTN_IS_AVAILABLE = False
18
- flash_attn_varlen_func = None
19
-
20
- def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, backend = 'pytorch') -> Tensor:
21
- q, k = apply_rope(q, k, pe)
22
- if backend == 'pytorch':
23
- if mask is not None and mask.dtype == torch.bool:
24
- mask = torch.zeros_like(mask).to(q).masked_fill_(mask.logical_not(), -1e20)
25
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
26
- # x = torch.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
27
- x = rearrange(x, "B H L D -> B L (H D)")
28
- elif backend == 'flash_attn':
29
- # q: (B, H, L, D)
30
- # k: (B, H, S, D) now L = S
31
- # v: (B, H, S, D)
32
- b, h, lq, d = q.shape
33
- _, _, lk, _ = k.shape
34
- q = rearrange(q, "B H L D -> B L H D")
35
- k = rearrange(k, "B H S D -> B S H D")
36
- v = rearrange(v, "B H S D -> B S H D")
37
- if mask is None:
38
- q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(q.device, non_blocking=True)
39
- k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(k.device, non_blocking=True)
40
- else:
41
- q_lens = torch.sum(mask[:, 0, :, 0], dim=1).int()
42
- k_lens = torch.sum(mask[:, 0, 0, :], dim=1).int()
43
- q = torch.cat([q_v[:q_l] for q_v, q_l in zip(q, q_lens)])
44
- k = torch.cat([k_v[:k_l] for k_v, k_l in zip(k, k_lens)])
45
- v = torch.cat([v_v[:v_l] for v_v, v_l in zip(v, k_lens)])
46
- cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
47
- cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
48
- max_seqlen_q = q_lens.max()
49
- max_seqlen_k = k_lens.max()
50
-
51
- x = flash_attn_varlen_func(
52
- q,
53
- k,
54
- v,
55
- cu_seqlens_q=cu_seqlens_q,
56
- cu_seqlens_k=cu_seqlens_k,
57
- max_seqlen_q=max_seqlen_q,
58
- max_seqlen_k=max_seqlen_k
59
- )
60
- x_list = [x[cu_seqlens_q[i]:cu_seqlens_q[i+1]] for i in range(b)]
61
- x = pad_sequence(tuple(x_list), batch_first=True)
62
- x = rearrange(x, "B L H D -> B L (H D)")
63
- else:
64
- raise NotImplementedError
65
- return x
66
-
67
-
68
- def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
69
- assert dim % 2 == 0
70
- scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
71
- omega = 1.0 / (theta**scale)
72
- out = torch.einsum("...n,d->...nd", pos, omega)
73
- out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
74
- out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
75
- return out.float()
76
-
77
-
78
- def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
79
- xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
80
- xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
81
- xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
82
- xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
83
- return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
84
-
85
- class EmbedND(nn.Module):
86
- def __init__(self, dim: int, theta: int, axes_dim: list[int]):
87
- super().__init__()
88
- self.dim = dim
89
- self.theta = theta
90
- self.axes_dim = axes_dim
91
-
92
- def forward(self, ids: Tensor) -> Tensor:
93
- n_axes = ids.shape[-1]
94
- emb = torch.cat(
95
- [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
96
- dim=-3,
97
- )
98
-
99
- return emb.unsqueeze(1)
100
-
101
-
102
- def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
103
- """
104
- Create sinusoidal timestep embeddings.
105
- :param t: a 1-D Tensor of N indices, one per batch element.
106
- These may be fractional.
107
- :param dim: the dimension of the output.
108
- :param max_period: controls the minimum frequency of the embeddings.
109
- :return: an (N, D) Tensor of positional embeddings.
110
- """
111
- t = time_factor * t
112
- half = dim // 2
113
- freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
114
- t.device
115
- )
116
-
117
- args = t[:, None].float() * freqs[None]
118
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
119
- if dim % 2:
120
- embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
121
- if torch.is_floating_point(t):
122
- embedding = embedding.to(t)
123
- return embedding
124
-
125
-
126
- class MLPEmbedder(nn.Module):
127
- def __init__(self, in_dim: int, hidden_dim: int):
128
- super().__init__()
129
- self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
130
- self.silu = nn.SiLU()
131
- self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
132
-
133
- def forward(self, x: Tensor) -> Tensor:
134
- return self.out_layer(self.silu(self.in_layer(x)))
135
-
136
-
137
- class RMSNorm(torch.nn.Module):
138
- def __init__(self, dim: int):
139
- super().__init__()
140
- self.scale = nn.Parameter(torch.ones(dim))
141
-
142
- def forward(self, x: Tensor):
143
- x_dtype = x.dtype
144
- x = x.float()
145
- rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
146
- return (x * rrms).to(dtype=x_dtype) * self.scale
147
-
148
-
149
- class QKNorm(torch.nn.Module):
150
- def __init__(self, dim: int):
151
- super().__init__()
152
- self.query_norm = RMSNorm(dim)
153
- self.key_norm = RMSNorm(dim)
154
-
155
- def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
156
- q = self.query_norm(q)
157
- k = self.key_norm(k)
158
- return q.to(v), k.to(v)
159
-
160
-
161
- class SelfAttention(nn.Module):
162
- def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
163
- super().__init__()
164
- self.num_heads = num_heads
165
- head_dim = dim // num_heads
166
-
167
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
168
- self.norm = QKNorm(head_dim)
169
- self.proj = nn.Linear(dim, dim)
170
-
171
- def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
172
- qkv = self.qkv(x)
173
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
174
- q, k = self.norm(q, k, v)
175
- x = attention(q, k, v, pe=pe, mask=mask)
176
- x = self.proj(x)
177
- return x
178
-
179
- class CrossAttention(nn.Module):
180
- def __init__(self, dim: int, context_dim: int, num_heads: int = 8, qkv_bias: bool = False):
181
- super().__init__()
182
- self.num_heads = num_heads
183
- head_dim = dim // num_heads
184
- self.q = nn.Linear(dim, dim, bias=qkv_bias)
185
- self.kv = nn.Linear(context_dim, dim * 2, bias=qkv_bias)
186
- self.norm = QKNorm(head_dim)
187
- self.proj = nn.Linear(dim, dim)
188
-
189
- def forward(self, x: Tensor, context: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
190
- q = rearrange(self.q(x), "B L (H D) -> B H L D", H=self.num_heads)
191
- k, v = rearrange(self.kv(context), "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
192
- q, k = self.norm(q, k, v)
193
- x = attention(q, k, v, pe=pe, mask=mask)
194
- x = self.proj(x)
195
- return x
196
-
197
-
198
- @dataclass
199
- class ModulationOut:
200
- shift: Tensor
201
- scale: Tensor
202
- gate: Tensor
203
-
204
-
205
- class Modulation(nn.Module):
206
- def __init__(self, dim: int, double: bool):
207
- super().__init__()
208
- self.is_double = double
209
- self.multiplier = 6 if double else 3
210
- self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
211
-
212
- def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
213
- out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
214
-
215
- return (
216
- ModulationOut(*out[:3]),
217
- ModulationOut(*out[3:]) if self.is_double else None,
218
- )
219
-
220
-
221
- class DoubleStreamBlock(nn.Module):
222
- def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, backend = 'pytorch'):
223
- super().__init__()
224
-
225
- mlp_hidden_dim = int(hidden_size * mlp_ratio)
226
- self.num_heads = num_heads
227
- self.hidden_size = hidden_size
228
- self.img_mod = Modulation(hidden_size, double=True)
229
- self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
230
- self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
231
-
232
- self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
233
- self.img_mlp = nn.Sequential(
234
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
235
- nn.GELU(approximate="tanh"),
236
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
237
- )
238
-
239
- self.backend = backend
240
-
241
- self.txt_mod = Modulation(hidden_size, double=True)
242
- self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
243
- self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
244
-
245
- self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
246
- self.txt_mlp = nn.Sequential(
247
- nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
248
- nn.GELU(approximate="tanh"),
249
- nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
250
- )
251
-
252
-
253
-
254
-
255
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None, txt_length = None):
256
- img_mod1, img_mod2 = self.img_mod(vec)
257
- txt_mod1, txt_mod2 = self.txt_mod(vec)
258
-
259
- txt, img = x[:, :txt_length], x[:, txt_length:]
260
-
261
- # prepare image for attention
262
- img_modulated = self.img_norm1(img)
263
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
264
- img_qkv = self.img_attn.qkv(img_modulated)
265
- img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
266
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
267
- # prepare txt for attention
268
- txt_modulated = self.txt_norm1(txt)
269
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
270
- txt_qkv = self.txt_attn.qkv(txt_modulated)
271
- txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
272
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
273
-
274
- # run actual attention
275
- q = torch.cat((txt_q, img_q), dim=2)
276
- k = torch.cat((txt_k, img_k), dim=2)
277
- v = torch.cat((txt_v, img_v), dim=2)
278
- if mask is not None:
279
- mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
280
- attn = attention(q, k, v, pe=pe, mask = mask, backend = self.backend)
281
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
282
-
283
- # calculate the img blocks
284
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
285
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
286
-
287
- # calculate the txt blocks
288
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
289
- txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
290
- x = torch.cat((txt, img), 1)
291
- return x
292
-
293
-
294
- class SingleStreamBlock(nn.Module):
295
- """
296
- A DiT block with parallel linear layers as described in
297
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
298
- """
299
-
300
- def __init__(
301
- self,
302
- hidden_size: int,
303
- num_heads: int,
304
- mlp_ratio: float = 4.0,
305
- qk_scale: float | None = None,
306
- backend='pytorch'
307
- ):
308
- super().__init__()
309
- self.hidden_dim = hidden_size
310
- self.num_heads = num_heads
311
- head_dim = hidden_size // num_heads
312
- self.scale = qk_scale or head_dim**-0.5
313
-
314
- self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
315
- # qkv and mlp_in
316
- self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
317
- # proj and mlp_out
318
- self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
319
-
320
- self.norm = QKNorm(head_dim)
321
-
322
- self.hidden_size = hidden_size
323
- self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
324
-
325
- self.mlp_act = nn.GELU(approximate="tanh")
326
- self.modulation = Modulation(hidden_size, double=False)
327
- self.backend = backend
328
-
329
- def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None) -> Tensor:
330
- mod, _ = self.modulation(vec)
331
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
332
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
333
-
334
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
335
- q, k = self.norm(q, k, v)
336
- if mask is not None:
337
- mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
338
- # compute attention
339
- attn = attention(q, k, v, pe=pe, mask = mask, backend=self.backend)
340
- # compute activation in mlp stream, cat again and run second linear layer
341
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
342
- return x + mod.gate * output
343
-
344
-
345
- class DoubleStreamBlockACE(DoubleStreamBlock):
346
- def forward(self,
347
- x: Tensor,
348
- vec: Tensor,
349
- pe: Tensor,
350
- edit_vec: Tensor | None = None,
351
- mask: Tensor = None,
352
- txt_length = None,
353
- x_length = None):
354
- img_mod1, img_mod2 = self.img_mod(vec)
355
- txt_mod1, txt_mod2 = self.txt_mod(vec)
356
- if edit_vec is not None:
357
- edit_mod1, edit_mod2 = self.img_mod(edit_vec)
358
- txt, img, edit = x[:, :txt_length], x[:, txt_length:txt_length+x_length], x[:, txt_length+x_length:]
359
- else:
360
- edit_mod1, edit_mod2 = None, None
361
- txt, img = x[:, :txt_length], x[:, txt_length:]
362
- edit = None
363
-
364
-
365
- # prepare image for attention
366
- img_modulated = self.img_norm1(img)
367
- img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
368
- img_qkv = self.img_attn.qkv(img_modulated)
369
- img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
370
- img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
371
- # prepare txt for attention
372
- txt_modulated = self.txt_norm1(txt)
373
- txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
374
- txt_qkv = self.txt_attn.qkv(txt_modulated)
375
- txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
376
- txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
377
- # prepare edit for attention
378
- if edit_vec is not None:
379
- edit_modulated = self.img_norm1(edit)
380
- edit_modulated = (1 + edit_mod1.scale) * edit_modulated + edit_mod1.shift
381
- edit_qkv = self.img_attn.qkv(edit_modulated)
382
- edit_q, edit_k, edit_v = rearrange(edit_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
383
- edit_q, edit_k = self.img_attn.norm(edit_q, edit_k, edit_v)
384
- q = torch.cat((txt_q, img_q, edit_q), dim=2)
385
- k = torch.cat((txt_k, img_k, edit_k), dim=2)
386
- v = torch.cat((txt_v, img_v, edit_v), dim=2)
387
- else:
388
- q = torch.cat((txt_q, img_q), dim=2)
389
- k = torch.cat((txt_k, img_k), dim=2)
390
- v = torch.cat((txt_v, img_v), dim=2)
391
-
392
- # run actual attention
393
- if mask is not None:
394
- mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
395
- attn = attention(q, k, v, pe=pe, mask = mask, backend = "pytorch")
396
- if edit_vec is not None:
397
- txt_attn, img_attn, edit_attn = (attn[:, : txt.shape[1]],
398
- attn[:, txt.shape[1] : txt.shape[1]+img.shape[1]],
399
- attn[:, txt.shape[1]+img.shape[1]:])
400
- # calculate the img blocks
401
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
402
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
403
-
404
- # calculate the edit blocks
405
- edit = edit + edit_mod1.gate * self.img_attn.proj(edit_attn)
406
- edit = edit + edit_mod2.gate * self.img_mlp((1 + edit_mod2.scale) * self.img_norm2(edit) + edit_mod2.shift)
407
-
408
- # calculate the txt blocks
409
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
410
- txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
411
-
412
- x = torch.cat((txt, img, edit), 1)
413
- else:
414
- txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
415
- # calculate the img blocks
416
- img = img + img_mod1.gate * self.img_attn.proj(img_attn)
417
- img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
418
-
419
- # calculate the txt blocks
420
- txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
421
- txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
422
- x = torch.cat((txt, img), 1)
423
- return x
424
-
425
-
426
- class SingleStreamBlockACE(SingleStreamBlock):
427
- """
428
- A DiT block with parallel linear layers as described in
429
- https://arxiv.org/abs/2302.05442 and adapted modulation interface.
430
- """
431
-
432
- def forward(self, x: Tensor, vec: Tensor,
433
- pe: Tensor, mask: Tensor = None,
434
- edit_vec: Tensor | None = None,
435
- txt_length=None,
436
- x_length=None
437
- ) -> Tensor:
438
- mod, _ = self.modulation(vec)
439
- if edit_vec is not None:
440
- x, edit = x[:, :txt_length + x_length], x[:, txt_length + x_length:]
441
- e_mod, _ = self.modulation(edit_vec)
442
- edit_mod = (1 + e_mod.scale) * self.pre_norm(edit) + e_mod.shift
443
- edit_qkv, edit_mlp = torch.split(self.linear1(edit_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
444
-
445
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
446
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
447
- qkv, mlp = torch.cat([qkv, edit_qkv], 1), torch.cat([mlp, edit_mlp], 1)
448
- else:
449
- x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
450
- qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
451
-
452
- q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
453
- q, k = self.norm(q, k, v)
454
- if mask is not None:
455
- mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
456
- # compute attention
457
- attn = attention(q, k, v, pe=pe, mask = mask, backend="pytorch")
458
- # compute activation in mlp stream, cat again and run second linear layer
459
- output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
460
-
461
- if edit_vec is not None:
462
- x_output, edit_output = output.split([x.shape[1], edit.shape[1]], dim = 1)
463
- x = x + mod.gate * x_output
464
- edit = edit + e_mod.gate * edit_output
465
- x = torch.cat((x, edit), 1)
466
- return x
467
- else:
468
- return x + mod.gate * output
469
-
470
-
471
- class LastLayer(nn.Module):
472
- def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
473
- super().__init__()
474
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
475
- self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
476
- self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
477
-
478
- def forward(self, x: Tensor, vec: Tensor) -> Tensor:
479
- shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
480
- x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
481
- x = self.linear(x)
482
- return x
483
-
484
-
485
- if __name__ == '__main__':
486
- pe = EmbedND(dim=64, theta=10000, axes_dim=[16, 56, 56])
487
-
488
- ix_id = torch.zeros(64 // 2, 64 // 2, 3)
489
- ix_id[..., 1] = ix_id[..., 1] + torch.arange(64 // 2)[:, None]
490
- ix_id[..., 2] = ix_id[..., 2] + torch.arange(64 // 2)[None, :]
491
- ix_id = rearrange(ix_id, "h w c -> 1 (h w) c")
492
- pos = torch.cat([ix_id, ix_id], dim = 1)
493
- a = pe(pos)
494
-
495
- b = torch.cat([pe(ix_id), pe(ix_id)], dim = 2)
496
-
497
- print(a - b)
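
The printed difference above should be numerically zero: rope/EmbedND act independently per position id, so embedding the concatenated ids equals concatenating the per-sequence embeddings along the sequence axis. As a further small sanity sketch for timestep_embedding defined in this file (shapes only, assuming the function is in scope):

import torch

t = torch.tensor([0.0, 0.5, 1.0])
emb = timestep_embedding(t, dim=256)
print(emb.shape)  # torch.Size([3, 256]); first 128 features are cos terms, last 128 are sin terms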
 
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- huggingface_hub
2
- diffusers
3
- transformers
4
- gradio_imageslider
5
- torch==2.4.0
6
- xformers==0.0.27.post2
7
- torchvision
8
- gradio==4.44.1
 
utils.py DELETED
@@ -1,95 +0,0 @@
1
- # Copyright (c) Alibaba, Inc. and its affiliates.
2
- import torch
3
- import torchvision.transforms as T
4
- from PIL import Image
5
- from torchvision.transforms.functional import InterpolationMode
6
-
7
- IMAGENET_MEAN = (0.485, 0.456, 0.406)
8
- IMAGENET_STD = (0.229, 0.224, 0.225)
9
-
10
-
11
- def build_transform(input_size):
12
- MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
13
- transform = T.Compose([
14
- T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
15
- T.Resize((input_size, input_size),
16
- interpolation=InterpolationMode.BICUBIC),
17
- T.ToTensor(),
18
- T.Normalize(mean=MEAN, std=STD)
19
- ])
20
- return transform
21
-
22
-
23
- def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
24
- image_size):
25
- best_ratio_diff = float('inf')
26
- best_ratio = (1, 1)
27
- area = width * height
28
- for ratio in target_ratios:
29
- target_aspect_ratio = ratio[0] / ratio[1]
30
- ratio_diff = abs(aspect_ratio - target_aspect_ratio)
31
- if ratio_diff < best_ratio_diff:
32
- best_ratio_diff = ratio_diff
33
- best_ratio = ratio
34
- elif ratio_diff == best_ratio_diff:
35
- if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
36
- best_ratio = ratio
37
- return best_ratio
38
-
39
-
40
- def dynamic_preprocess(image,
41
- min_num=1,
42
- max_num=12,
43
- image_size=448,
44
- use_thumbnail=False):
45
- orig_width, orig_height = image.size
46
- aspect_ratio = orig_width / orig_height
47
-
48
- # calculate the existing image aspect ratio
49
- target_ratios = set((i, j) for n in range(min_num, max_num + 1)
50
- for i in range(1, n + 1) for j in range(1, n + 1)
51
- if i * j <= max_num and i * j >= min_num)
52
- target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
53
-
54
- # find the closest aspect ratio to the target
55
- target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
56
- target_ratios, orig_width,
57
- orig_height, image_size)
58
-
59
- # calculate the target width and height
60
- target_width = image_size * target_aspect_ratio[0]
61
- target_height = image_size * target_aspect_ratio[1]
62
- blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
63
-
64
- # resize the image
65
- resized_img = image.resize((target_width, target_height))
66
- processed_images = []
67
- for i in range(blocks):
68
- box = ((i % (target_width // image_size)) * image_size,
69
- (i // (target_width // image_size)) * image_size,
70
- ((i % (target_width // image_size)) + 1) * image_size,
71
- ((i // (target_width // image_size)) + 1) * image_size)
72
- # split the image
73
- split_img = resized_img.crop(box)
74
- processed_images.append(split_img)
75
- assert len(processed_images) == blocks
76
- if use_thumbnail and len(processed_images) != 1:
77
- thumbnail_img = image.resize((image_size, image_size))
78
- processed_images.append(thumbnail_img)
79
- return processed_images
80
-
81
-
82
- def load_image(image_file, input_size=448, max_num=12):
83
- if isinstance(image_file, str):
84
- image = Image.open(image_file).convert('RGB')
85
- else:
86
- image = image_file
87
- transform = build_transform(input_size=input_size)
88
- images = dynamic_preprocess(image,
89
- image_size=input_size,
90
- use_thumbnail=True,
91
- max_num=max_num)
92
- pixel_values = [transform(image) for image in images]
93
- pixel_values = torch.stack(pixel_values)
94
- return pixel_values
95
-
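
A usage sketch for the helpers above: load_image tiles the input into up to max_num crops of input_size x input_size plus a thumbnail, so the stacked tensor has shape (num_tiles, 3, input_size, input_size). The input image below is synthetic and purely illustrative:

from PIL import Image

img = Image.new('RGB', (896, 448))            # 2:1 aspect ratio -> two 448x448 tiles
pixel_values = load_image(img, input_size=448, max_num=12)
print(pixel_values.shape)                     # torch.Size([3, 3, 448, 448]): 2 tiles + thumbnail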