Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
83266af
1
Parent(s):
9cee51c
modify old files
Browse files- __init__.py +0 -1
- ace_flux_inference.py +0 -329
- app.py +0 -1428
- config/chatbot_ui.yaml +0 -25
- config/models/ace_flux_dev.yaml +0 -187
- example.py +0 -370
- models/__init__.py +0 -2
- models/embedder.py +0 -383
- models/flux.py +0 -798
- models/layers.py +0 -497
- requirements.txt +0 -8
- utils.py +0 -95
__init__.py
DELETED
@@ -1 +0,0 @@
|
|
1 |
-
from . import models
|
|
|
|
ace_flux_inference.py
DELETED
@@ -1,329 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import math
|
4 |
-
import os
|
5 |
-
import random
|
6 |
-
import numpy as np
|
7 |
-
import torch
|
8 |
-
import torch.nn.functional as F
|
9 |
-
from PIL import Image
|
10 |
-
import torchvision.transforms as T
|
11 |
-
from scepter.modules.model.registry import DIFFUSIONS, BACKBONES
|
12 |
-
import torchvision.transforms.functional as TF
|
13 |
-
from scepter.modules.model.utils.basic_utils import check_list_of_list
|
14 |
-
from scepter.modules.model.utils.basic_utils import \
|
15 |
-
pack_imagelist_into_tensor_v2 as pack_imagelist_into_tensor
|
16 |
-
from scepter.modules.model.utils.basic_utils import (
|
17 |
-
to_device, unpack_tensor_into_imagelist)
|
18 |
-
from scepter.modules.utils.distribute import we
|
19 |
-
from scepter.modules.utils.file_system import FS
|
20 |
-
from scepter.modules.utils.logger import get_logger
|
21 |
-
from scepter.modules.inference.diffusion_inference import DiffusionInference, get_model
|
22 |
-
|
23 |
-
def process_edit_image(images,
|
24 |
-
masks,
|
25 |
-
tasks):
|
26 |
-
|
27 |
-
if not isinstance(images, list):
|
28 |
-
images = [images]
|
29 |
-
if not isinstance(masks, list):
|
30 |
-
masks = [masks]
|
31 |
-
if not isinstance(tasks, list):
|
32 |
-
tasks = [tasks]
|
33 |
-
|
34 |
-
img_tensors = []
|
35 |
-
mask_tensors = []
|
36 |
-
for img, mask, task in zip(images, masks, tasks):
|
37 |
-
if mask is None or mask == '':
|
38 |
-
mask = Image.new('L', img.size, 0)
|
39 |
-
img = TF.center_crop(img, [512, 512])
|
40 |
-
mask = TF.center_crop(mask, [512, 512])
|
41 |
-
|
42 |
-
mask = np.asarray(mask)
|
43 |
-
mask = np.where(mask > 128, 1, 0)
|
44 |
-
mask = mask.astype(
|
45 |
-
np.float32) if np.any(mask) else np.ones_like(mask).astype(
|
46 |
-
np.float32)
|
47 |
-
|
48 |
-
img_tensor = TF.to_tensor(img).to(we.device_id)
|
49 |
-
img_tensor = TF.normalize(img_tensor,
|
50 |
-
mean=[0.5, 0.5, 0.5],
|
51 |
-
std=[0.5, 0.5, 0.5])
|
52 |
-
mask_tensor = TF.to_tensor(mask).to(we.device_id)
|
53 |
-
if task in ['inpainting', 'Try On', 'Inpainting']:
|
54 |
-
mask_indicator = mask_tensor.repeat(3, 1, 1)
|
55 |
-
img_tensor[mask_indicator == 1] = -1.0
|
56 |
-
img_tensors.append(img_tensor)
|
57 |
-
mask_tensors.append(mask_tensor)
|
58 |
-
return img_tensors, mask_tensors
|
59 |
-
|
60 |
-
class FluxACEInference(DiffusionInference):
|
61 |
-
|
62 |
-
def __init__(self, logger=None):
|
63 |
-
if logger is None:
|
64 |
-
logger = get_logger(name='scepter')
|
65 |
-
self.logger = logger
|
66 |
-
self.loaded_model = {}
|
67 |
-
self.loaded_model_name = [
|
68 |
-
'diffusion_model', 'first_stage_model', 'cond_stage_model', 'ref_cond_stage_model'
|
69 |
-
]
|
70 |
-
|
71 |
-
def init_from_cfg(self, cfg):
|
72 |
-
self.name = cfg.NAME
|
73 |
-
self.is_default = cfg.get('IS_DEFAULT', False)
|
74 |
-
self.use_dynamic_model = cfg.get('USE_DYNAMIC_MODEL', True)
|
75 |
-
module_paras = self.load_default(cfg.get('DEFAULT_PARAS', None))
|
76 |
-
assert cfg.have('MODEL')
|
77 |
-
self.size_factor = cfg.get('SIZE_FACTOR', 8)
|
78 |
-
self.diffusion_model = self.infer_model(
|
79 |
-
cfg.MODEL.DIFFUSION_MODEL, module_paras.get(
|
80 |
-
'DIFFUSION_MODEL',
|
81 |
-
None)) if cfg.MODEL.have('DIFFUSION_MODEL') else None
|
82 |
-
self.first_stage_model = self.infer_model(
|
83 |
-
cfg.MODEL.FIRST_STAGE_MODEL,
|
84 |
-
module_paras.get(
|
85 |
-
'FIRST_STAGE_MODEL',
|
86 |
-
None)) if cfg.MODEL.have('FIRST_STAGE_MODEL') else None
|
87 |
-
self.cond_stage_model = self.infer_model(
|
88 |
-
cfg.MODEL.COND_STAGE_MODEL,
|
89 |
-
module_paras.get(
|
90 |
-
'COND_STAGE_MODEL',
|
91 |
-
None)) if cfg.MODEL.have('COND_STAGE_MODEL') else None
|
92 |
-
|
93 |
-
self.ref_cond_stage_model = self.infer_model(
|
94 |
-
cfg.MODEL.REF_COND_STAGE_MODEL,
|
95 |
-
module_paras.get(
|
96 |
-
'REF_COND_STAGE_MODEL',
|
97 |
-
None)) if cfg.MODEL.have('REF_COND_STAGE_MODEL') else None
|
98 |
-
|
99 |
-
self.diffusion = DIFFUSIONS.build(cfg.MODEL.DIFFUSION,
|
100 |
-
logger=self.logger)
|
101 |
-
self.interpolate_func = lambda x: (F.interpolate(
|
102 |
-
x.unsqueeze(0),
|
103 |
-
scale_factor=1 / self.size_factor,
|
104 |
-
mode='nearest-exact') if x is not None else None)
|
105 |
-
|
106 |
-
self.max_seq_length = cfg.get("MAX_SEQ_LENGTH", 4096)
|
107 |
-
if not self.use_dynamic_model:
|
108 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
109 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
110 |
-
if self.ref_cond_stage_model is not None: self.dynamic_load(self.ref_cond_stage_model, 'ref_cond_stage_model')
|
111 |
-
with torch.device("meta"):
|
112 |
-
pretrained_model = self.diffusion_model['cfg'].PRETRAINED_MODEL
|
113 |
-
self.diffusion_model['cfg'].PRETRAINED_MODEL = None
|
114 |
-
diffusers_lora = self.diffusion_model['cfg'].get("DIFFUSERS_LORA_MODEL", None)
|
115 |
-
self.diffusion_model['cfg'].DIFFUSERS_LORA_MODEL = None
|
116 |
-
swift_lora = self.diffusion_model['cfg'].get("SWIFT_LORA_MODEL", None)
|
117 |
-
self.diffusion_model['cfg'].SWIFT_LORA_MODEL = None
|
118 |
-
pretrain_adapter = self.diffusion_model['cfg'].get("PRETRAIN_ADAPTER", None)
|
119 |
-
self.diffusion_model['cfg'].PRETRAIN_ADAPTER = None
|
120 |
-
blackforest_lora = self.diffusion_model['cfg'].get("BLACKFOREST_LORA_MODEL", None)
|
121 |
-
self.diffusion_model['cfg'].BLACKFOREST_LORA_MODEL = None
|
122 |
-
self.diffusion_model['model'] = BACKBONES.build(self.diffusion_model['cfg'], logger=self.logger).eval()
|
123 |
-
# self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
124 |
-
self.diffusion_model['model'].lora_model = diffusers_lora
|
125 |
-
self.diffusion_model['model'].swift_lora_model = swift_lora
|
126 |
-
self.diffusion_model['model'].pretrain_adapter = pretrain_adapter
|
127 |
-
self.diffusion_model['model'].blackforest_lora_model = blackforest_lora
|
128 |
-
self.diffusion_model['model'].load_pretrained_model(pretrained_model)
|
129 |
-
self.diffusion_model['device'] = we.device_id
|
130 |
-
|
131 |
-
def upscale_resize(self, image, interpolation=T.InterpolationMode.BILINEAR):
|
132 |
-
c, H, W = image.shape
|
133 |
-
scale = max(1.0, math.sqrt(self.max_seq_length / ((H / 16) * (W / 16))))
|
134 |
-
rH = int(H * scale) // 16 * 16 # ensure divisible by self.d
|
135 |
-
rW = int(W * scale) // 16 * 16
|
136 |
-
image = T.Resize((rH, rW), interpolation=interpolation, antialias=True)(image)
|
137 |
-
return image
|
138 |
-
|
139 |
-
|
140 |
-
@torch.no_grad()
|
141 |
-
def encode_first_stage(self, x, **kwargs):
|
142 |
-
_, dtype = self.get_function_info(self.first_stage_model, 'encode')
|
143 |
-
with torch.autocast('cuda',
|
144 |
-
enabled=dtype in ('float16', 'bfloat16'),
|
145 |
-
dtype=getattr(torch, dtype)):
|
146 |
-
def run_one_image(u):
|
147 |
-
zu = get_model(self.first_stage_model).encode(u)
|
148 |
-
if isinstance(zu, (tuple, list)):
|
149 |
-
zu = zu[0]
|
150 |
-
return zu
|
151 |
-
|
152 |
-
z = [run_one_image(u.unsqueeze(0) if u.dim() == 3 else u) for u in x]
|
153 |
-
return z
|
154 |
-
|
155 |
-
|
156 |
-
@torch.no_grad()
|
157 |
-
def decode_first_stage(self, z):
|
158 |
-
_, dtype = self.get_function_info(self.first_stage_model, 'decode')
|
159 |
-
with torch.autocast('cuda',
|
160 |
-
enabled=dtype in ('float16', 'bfloat16'),
|
161 |
-
dtype=getattr(torch, dtype)):
|
162 |
-
return [get_model(self.first_stage_model).decode(zu) for zu in z]
|
163 |
-
|
164 |
-
def noise_sample(self, num_samples, h, w, seed, device = None, dtype = torch.bfloat16):
|
165 |
-
noise = torch.randn(
|
166 |
-
num_samples,
|
167 |
-
16,
|
168 |
-
# allow for packing
|
169 |
-
2 * math.ceil(h / 16),
|
170 |
-
2 * math.ceil(w / 16),
|
171 |
-
device="cpu",
|
172 |
-
dtype=dtype,
|
173 |
-
generator=torch.Generator().manual_seed(seed),
|
174 |
-
).to(device)
|
175 |
-
return noise
|
176 |
-
|
177 |
-
@torch.no_grad()
|
178 |
-
def __call__(self,
|
179 |
-
image=None,
|
180 |
-
mask=None,
|
181 |
-
prompt='',
|
182 |
-
task=None,
|
183 |
-
negative_prompt='',
|
184 |
-
output_height=1024,
|
185 |
-
output_width=1024,
|
186 |
-
sampler='flow_euler',
|
187 |
-
sample_steps=20,
|
188 |
-
guide_scale=3.5,
|
189 |
-
seed=-1,
|
190 |
-
history_io=None,
|
191 |
-
tar_index=0,
|
192 |
-
# align=0,
|
193 |
-
**kwargs):
|
194 |
-
input_image, input_mask = image, mask
|
195 |
-
seed = seed if seed >= 0 else random.randint(0, 2**32 - 1)
|
196 |
-
if input_image is not None:
|
197 |
-
# assert isinstance(input_image, list) and isinstance(input_mask, list)
|
198 |
-
if task is None:
|
199 |
-
task = [''] * len(input_image)
|
200 |
-
if not isinstance(prompt, list):
|
201 |
-
prompt = [prompt] * len(input_image)
|
202 |
-
prompt = [
|
203 |
-
pp.replace('{image}', f'{{image{i}}}') if i > 0 else pp
|
204 |
-
for i, pp in enumerate(prompt)
|
205 |
-
]
|
206 |
-
edit_image, edit_image_mask = process_edit_image(
|
207 |
-
input_image, input_mask, task)
|
208 |
-
image = torch.zeros(
|
209 |
-
size=[3, int(output_height),
|
210 |
-
int(output_width)])
|
211 |
-
image_mask = torch.ones(
|
212 |
-
size=[1, int(output_height),
|
213 |
-
int(output_width)])
|
214 |
-
edit_image, edit_image_mask = [edit_image], [edit_image_mask]
|
215 |
-
else:
|
216 |
-
edit_image = edit_image_mask = [[]]
|
217 |
-
image = torch.zeros(
|
218 |
-
size=[3, int(output_height),
|
219 |
-
int(output_width)])
|
220 |
-
image_mask = torch.ones(
|
221 |
-
size=[1, int(output_height),
|
222 |
-
int(output_width)])
|
223 |
-
if not isinstance(prompt, list):
|
224 |
-
prompt = [prompt]
|
225 |
-
align = 0
|
226 |
-
image, image_mask, prompt = [image], [image_mask], [prompt],
|
227 |
-
align = [align for p in prompt] if isinstance(align, int) else align
|
228 |
-
|
229 |
-
assert check_list_of_list(prompt) and check_list_of_list(
|
230 |
-
edit_image) and check_list_of_list(edit_image_mask)
|
231 |
-
# negative prompt is not used
|
232 |
-
image = to_device(image)
|
233 |
-
ctx = {}
|
234 |
-
# Get Noise Shape
|
235 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
236 |
-
x = self.encode_first_stage(image)
|
237 |
-
self.dynamic_unload(self.first_stage_model,
|
238 |
-
'first_stage_model',
|
239 |
-
skip_loaded=not self.use_dynamic_model)
|
240 |
-
|
241 |
-
g = torch.Generator(device=we.device_id).manual_seed(seed)
|
242 |
-
noise = [
|
243 |
-
torch.randn((1, 16, i.shape[2], i.shape[3]), device=we.device_id, dtype=torch.bfloat16).normal_(generator=g)
|
244 |
-
for i in x
|
245 |
-
]
|
246 |
-
# import pdb;pdb.set_trace()
|
247 |
-
noise, x_shapes = pack_imagelist_into_tensor(noise)
|
248 |
-
ctx['x_shapes'] = x_shapes
|
249 |
-
ctx['align'] = align
|
250 |
-
|
251 |
-
image_mask = to_device(image_mask, strict=False)
|
252 |
-
cond_mask = [self.interpolate_func(i) for i in image_mask
|
253 |
-
] if image_mask is not None else [None] * len(image)
|
254 |
-
ctx['x_mask'] = cond_mask
|
255 |
-
# Encode Prompt
|
256 |
-
instruction_prompt = [[pp[-1]] if "{image}" in pp[-1] else ["{image} " + pp[-1]] for pp in prompt]
|
257 |
-
self.dynamic_load(self.cond_stage_model, 'cond_stage_model')
|
258 |
-
function_name, dtype = self.get_function_info(self.cond_stage_model)
|
259 |
-
cont = getattr(get_model(self.cond_stage_model), function_name)(instruction_prompt)
|
260 |
-
cont["context"] = [ct[-1] for ct in cont["context"]]
|
261 |
-
cont["y"] = [ct[-1] for ct in cont["y"]]
|
262 |
-
self.dynamic_unload(self.cond_stage_model,
|
263 |
-
'cond_stage_model',
|
264 |
-
skip_loaded=not self.use_dynamic_model)
|
265 |
-
ctx.update(cont)
|
266 |
-
|
267 |
-
# Encode Edit Images
|
268 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
269 |
-
edit_image = [to_device(i, strict=False) for i in edit_image]
|
270 |
-
edit_image_mask = [to_device(i, strict=False) for i in edit_image_mask]
|
271 |
-
e_img, e_mask = [], []
|
272 |
-
for u, m in zip(edit_image, edit_image_mask):
|
273 |
-
if u is None:
|
274 |
-
continue
|
275 |
-
if m is None:
|
276 |
-
m = [None] * len(u)
|
277 |
-
e_img.append(self.encode_first_stage(u, **kwargs))
|
278 |
-
e_mask.append([self.interpolate_func(i) for i in m])
|
279 |
-
self.dynamic_unload(self.first_stage_model,
|
280 |
-
'first_stage_model',
|
281 |
-
skip_loaded=not self.use_dynamic_model)
|
282 |
-
ctx['edit'] = e_img
|
283 |
-
ctx['edit_mask'] = e_mask
|
284 |
-
# Encode Ref Images
|
285 |
-
if guide_scale is not None:
|
286 |
-
guide_scale = torch.full((noise.shape[0],), guide_scale, device=noise.device, dtype=noise.dtype)
|
287 |
-
else:
|
288 |
-
guide_scale = None
|
289 |
-
|
290 |
-
# Diffusion Process
|
291 |
-
self.dynamic_load(self.diffusion_model, 'diffusion_model')
|
292 |
-
function_name, dtype = self.get_function_info(self.diffusion_model)
|
293 |
-
with torch.autocast('cuda',
|
294 |
-
enabled=dtype in ('float16', 'bfloat16'),
|
295 |
-
dtype=getattr(torch, dtype)):
|
296 |
-
latent = self.diffusion.sample(
|
297 |
-
noise=noise,
|
298 |
-
sampler=sampler,
|
299 |
-
model=get_model(self.diffusion_model),
|
300 |
-
model_kwargs={
|
301 |
-
"cond": ctx, "guidance": guide_scale, "gc_seg": -1
|
302 |
-
},
|
303 |
-
steps=sample_steps,
|
304 |
-
show_progress=True,
|
305 |
-
guide_scale=guide_scale,
|
306 |
-
return_intermediate=None,
|
307 |
-
reverse_scale=-1,
|
308 |
-
**kwargs).float()
|
309 |
-
if self.use_dynamic_model: self.dynamic_unload(self.diffusion_model,
|
310 |
-
'diffusion_model',
|
311 |
-
skip_loaded=not self.use_dynamic_model)
|
312 |
-
|
313 |
-
# Decode to Pixel Space
|
314 |
-
self.dynamic_load(self.first_stage_model, 'first_stage_model')
|
315 |
-
samples = unpack_tensor_into_imagelist(latent, x_shapes)
|
316 |
-
x_samples = self.decode_first_stage(samples)
|
317 |
-
self.dynamic_unload(self.first_stage_model,
|
318 |
-
'first_stage_model',
|
319 |
-
skip_loaded=not self.use_dynamic_model)
|
320 |
-
x_samples = [x.squeeze(0) for x in x_samples]
|
321 |
-
|
322 |
-
imgs = [
|
323 |
-
torch.clamp((x_i.float() + 1.0) / 2.0,
|
324 |
-
min=0.0,
|
325 |
-
max=1.0).squeeze(0).permute(1, 2, 0).cpu().numpy()
|
326 |
-
for x_i in x_samples
|
327 |
-
]
|
328 |
-
imgs = [Image.fromarray((img * 255).astype(np.uint8)) for img in imgs]
|
329 |
-
return imgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
DELETED
@@ -1,1428 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import base64
|
4 |
-
import copy
|
5 |
-
import glob
|
6 |
-
import io
|
7 |
-
import os, csv, sys
|
8 |
-
import random
|
9 |
-
import re
|
10 |
-
import shlex
|
11 |
-
import string
|
12 |
-
import subprocess
|
13 |
-
import threading
|
14 |
-
import spaces
|
15 |
-
|
16 |
-
subprocess.run("rm -rf /data-nvme/zerogpu-offload/*", env={}, shell=True)
|
17 |
-
#subprocess.run(shlex.split('pip install flash-attn --no-build-isolation'),
|
18 |
-
# env=os.environ | {'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"})
|
19 |
-
subprocess.run(shlex.split('pip install scepter'))
|
20 |
-
|
21 |
-
import cv2
|
22 |
-
import gradio as gr
|
23 |
-
import numpy as np
|
24 |
-
import torch
|
25 |
-
import transformers
|
26 |
-
from PIL import Image
|
27 |
-
from transformers import AutoModel, AutoTokenizer
|
28 |
-
from ace_flux_inference import FluxACEInference
|
29 |
-
from scepter.modules.utils.config import Config
|
30 |
-
from scepter.modules.utils.directory import get_md5
|
31 |
-
from scepter.modules.utils.file_system import FS
|
32 |
-
from scepter.studio.utils.env import init_env
|
33 |
-
from importlib.metadata import version
|
34 |
-
|
35 |
-
from example import get_examples
|
36 |
-
from utils import load_image
|
37 |
-
from huggingface_hub import login
|
38 |
-
|
39 |
-
login(token=os.environ.get("HF_TOKEN", ""))
|
40 |
-
|
41 |
-
csv.field_size_limit(sys.maxsize)
|
42 |
-
|
43 |
-
refresh_sty = '\U0001f504' # 🔄
|
44 |
-
clear_sty = '\U0001f5d1' # 🗑️
|
45 |
-
upload_sty = '\U0001f5bc' # 🖼️
|
46 |
-
sync_sty = '\U0001f4be' # 💾
|
47 |
-
chat_sty = '\U0001F4AC' # 💬
|
48 |
-
video_sty = '\U0001f3a5' # 🎥
|
49 |
-
|
50 |
-
lock = threading.Lock()
|
51 |
-
inference_dict = {
|
52 |
-
"ACE_FLUX": FluxACEInference,
|
53 |
-
}
|
54 |
-
|
55 |
-
|
56 |
-
class ChatBotUI(object):
|
57 |
-
def __init__(self,
|
58 |
-
cfg_general_file,
|
59 |
-
is_debug=False,
|
60 |
-
language='en',
|
61 |
-
root_work_dir='./'):
|
62 |
-
try:
|
63 |
-
from diffusers import CogVideoXImageToVideoPipeline
|
64 |
-
from diffusers.utils import export_to_video
|
65 |
-
except Exception as e:
|
66 |
-
print(f"Import diffusers failed, please install or upgrade diffusers. Error information: {e}")
|
67 |
-
|
68 |
-
cfg = Config(cfg_file=cfg_general_file)
|
69 |
-
if cfg.have("FILE_SYSTEM"):
|
70 |
-
for file_sys in cfg.FILE_SYSTEM:
|
71 |
-
fs_prefix = FS.init_fs_client(file_sys)
|
72 |
-
else:
|
73 |
-
fs_prefix = FS.init_fs_client(cfg)
|
74 |
-
cfg.WORK_DIR = os.path.join(root_work_dir, cfg.WORK_DIR)
|
75 |
-
if not FS.exists(cfg.WORK_DIR):
|
76 |
-
FS.make_dir(cfg.WORK_DIR)
|
77 |
-
cfg = init_env(cfg)
|
78 |
-
self.cache_dir = cfg.WORK_DIR
|
79 |
-
self.chatbot_examples = get_examples(self.cache_dir) if not cfg.get('SKIP_EXAMPLES', False) else []
|
80 |
-
self.model_cfg_dir = cfg.MODEL.EDIT_MODEL.MODEL_CFG_DIR
|
81 |
-
self.model_yamls = glob.glob(os.path.join(self.model_cfg_dir,
|
82 |
-
'*.yaml'))
|
83 |
-
self.model_choices = dict()
|
84 |
-
self.default_model_name = ''
|
85 |
-
for i in self.model_yamls:
|
86 |
-
model_cfg = Config(load=True, cfg_file=i)
|
87 |
-
model_name = model_cfg.NAME
|
88 |
-
if model_cfg.IS_DEFAULT: self.default_model_name = model_name
|
89 |
-
self.model_choices[model_name] = model_cfg
|
90 |
-
print('Models: ', self.model_choices.keys())
|
91 |
-
local_folder = FS.get_dir_to_local_dir("hf://black-forest-labs/FLUX.1-dev")
|
92 |
-
subprocess.run(shlex.split(f'rm -rf {local_folder}/transformer'))
|
93 |
-
subprocess.run(shlex.split(f'rm -rf {local_folder}/vae'))
|
94 |
-
subprocess.run(shlex.split(f'rm -rf {local_folder}/flux1-dev.safetensors'))
|
95 |
-
|
96 |
-
assert len(self.model_choices) > 0
|
97 |
-
if self.default_model_name == "": self.default_model_name = list(self.model_choices.keys())[0]
|
98 |
-
self.model_name = self.default_model_name
|
99 |
-
pipe_cfg = self.model_choices[self.default_model_name]
|
100 |
-
infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
|
101 |
-
self.pipe = inference_dict[infer_name]()
|
102 |
-
self.pipe.init_from_cfg(pipe_cfg)
|
103 |
-
self.max_msgs = 20
|
104 |
-
self.enable_i2v = cfg.get('ENABLE_I2V', False)
|
105 |
-
self.gradio_version = version('gradio')
|
106 |
-
|
107 |
-
if self.enable_i2v:
|
108 |
-
self.i2v_model_dir = cfg.MODEL.I2V.MODEL_DIR
|
109 |
-
self.i2v_model_name = cfg.MODEL.I2V.MODEL_NAME
|
110 |
-
if self.i2v_model_name == 'CogVideoX-5b-I2V':
|
111 |
-
with FS.get_dir_to_local_dir(self.i2v_model_dir) as local_dir:
|
112 |
-
self.i2v_pipe = CogVideoXImageToVideoPipeline.from_pretrained(
|
113 |
-
local_dir, torch_dtype=torch.bfloat16).cuda()
|
114 |
-
else:
|
115 |
-
raise NotImplementedError
|
116 |
-
|
117 |
-
with FS.get_dir_to_local_dir(
|
118 |
-
cfg.MODEL.CAPTIONER.MODEL_DIR) as local_dir:
|
119 |
-
self.captioner = AutoModel.from_pretrained(
|
120 |
-
local_dir,
|
121 |
-
torch_dtype=torch.bfloat16,
|
122 |
-
low_cpu_mem_usage=True,
|
123 |
-
use_flash_attn=True,
|
124 |
-
trust_remote_code=True).eval().cuda()
|
125 |
-
self.llm_tokenizer = AutoTokenizer.from_pretrained(
|
126 |
-
local_dir, trust_remote_code=True, use_fast=False)
|
127 |
-
self.llm_generation_config = dict(max_new_tokens=1024,
|
128 |
-
do_sample=True)
|
129 |
-
self.llm_prompt = cfg.LLM.PROMPT
|
130 |
-
self.llm_max_num = 2
|
131 |
-
|
132 |
-
with FS.get_dir_to_local_dir(
|
133 |
-
cfg.MODEL.ENHANCER.MODEL_DIR) as local_dir:
|
134 |
-
self.enhancer = transformers.pipeline(
|
135 |
-
'text-generation',
|
136 |
-
model=local_dir,
|
137 |
-
model_kwargs={'torch_dtype': torch.bfloat16},
|
138 |
-
device_map='auto',
|
139 |
-
)
|
140 |
-
|
141 |
-
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
142 |
-
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
143 |
-
There are a few rules to follow:
|
144 |
-
You will only ever output a single video description per user request.
|
145 |
-
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
146 |
-
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
147 |
-
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
148 |
-
"""
|
149 |
-
self.enhance_ctx = [
|
150 |
-
{
|
151 |
-
'role': 'system',
|
152 |
-
'content': sys_prompt
|
153 |
-
},
|
154 |
-
{
|
155 |
-
'role':
|
156 |
-
'user',
|
157 |
-
'content':
|
158 |
-
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
|
159 |
-
},
|
160 |
-
{
|
161 |
-
'role':
|
162 |
-
'assistant',
|
163 |
-
'content':
|
164 |
-
"A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
165 |
-
},
|
166 |
-
{
|
167 |
-
'role':
|
168 |
-
'user',
|
169 |
-
'content':
|
170 |
-
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
|
171 |
-
},
|
172 |
-
{
|
173 |
-
'role':
|
174 |
-
'assistant',
|
175 |
-
'content':
|
176 |
-
"A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
177 |
-
},
|
178 |
-
{
|
179 |
-
'role':
|
180 |
-
'user',
|
181 |
-
'content':
|
182 |
-
'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
183 |
-
},
|
184 |
-
{
|
185 |
-
'role':
|
186 |
-
'assistant',
|
187 |
-
'content':
|
188 |
-
'A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.',
|
189 |
-
},
|
190 |
-
]
|
191 |
-
|
192 |
-
def create_ui(self):
|
193 |
-
|
194 |
-
css = '.chatbot.prose.md {opacity: 1.0 !important} #chatbot {opacity: 1.0 !important}'
|
195 |
-
with gr.Blocks(css=css,
|
196 |
-
title='Chatbot',
|
197 |
-
head='Chatbot',
|
198 |
-
analytics_enabled=False):
|
199 |
-
self.history = gr.State(value=[])
|
200 |
-
self.images = gr.State(value={})
|
201 |
-
self.history_result = gr.State(value={})
|
202 |
-
self.retry_msg = gr.State(value='')
|
203 |
-
with gr.Group():
|
204 |
-
self.ui_mode = gr.State(value='legacy')
|
205 |
-
with gr.Row(equal_height=True, visible=False) as self.chat_group:
|
206 |
-
with gr.Column(visible=True) as self.chat_page:
|
207 |
-
self.chatbot = gr.Chatbot(
|
208 |
-
height=600,
|
209 |
-
value=[],
|
210 |
-
bubble_full_width=False,
|
211 |
-
show_copy_button=True,
|
212 |
-
container=False,
|
213 |
-
placeholder='<strong>Chat Box</strong>')
|
214 |
-
with gr.Row():
|
215 |
-
self.clear_btn = gr.Button(clear_sty +
|
216 |
-
' Clear Chat',
|
217 |
-
size='sm')
|
218 |
-
|
219 |
-
with gr.Column(visible=False) as self.editor_page:
|
220 |
-
with gr.Tabs(visible=False) as self.upload_tabs:
|
221 |
-
with gr.Tab(id='ImageUploader',
|
222 |
-
label='Image Uploader',
|
223 |
-
visible=True) as self.upload_tab:
|
224 |
-
self.image_uploader = gr.Image(
|
225 |
-
height=550,
|
226 |
-
interactive=True,
|
227 |
-
type='pil',
|
228 |
-
image_mode='RGB',
|
229 |
-
sources=['upload'],
|
230 |
-
elem_id='image_uploader',
|
231 |
-
format='png')
|
232 |
-
with gr.Row():
|
233 |
-
self.sub_btn_1 = gr.Button(
|
234 |
-
value='Submit',
|
235 |
-
elem_id='upload_submit')
|
236 |
-
self.ext_btn_1 = gr.Button(value='Exit')
|
237 |
-
with gr.Tabs(visible=False) as self.edit_tabs:
|
238 |
-
with gr.Tab(id='ImageEditor',
|
239 |
-
label='Image Editor') as self.edit_tab:
|
240 |
-
self.mask_type = gr.Dropdown(
|
241 |
-
label='Mask Type',
|
242 |
-
choices=[
|
243 |
-
'Background', 'Composite',
|
244 |
-
'Outpainting'
|
245 |
-
],
|
246 |
-
value='Background')
|
247 |
-
self.mask_type_info = gr.HTML(
|
248 |
-
value=
|
249 |
-
"<div style='background-color: white; padding-left: 15px; color: grey;'>Background mode will not erase the visual content in the mask area</div>"
|
250 |
-
)
|
251 |
-
with gr.Accordion(
|
252 |
-
label='Outpainting Setting',
|
253 |
-
open=True,
|
254 |
-
visible=False) as self.outpaint_tab:
|
255 |
-
with gr.Row(variant='panel'):
|
256 |
-
self.top_ext = gr.Slider(
|
257 |
-
show_label=True,
|
258 |
-
label='Top Extend Ratio',
|
259 |
-
minimum=0.0,
|
260 |
-
maximum=2.0,
|
261 |
-
step=0.1,
|
262 |
-
value=0.25)
|
263 |
-
self.bottom_ext = gr.Slider(
|
264 |
-
show_label=True,
|
265 |
-
label='Bottom Extend Ratio',
|
266 |
-
minimum=0.0,
|
267 |
-
maximum=2.0,
|
268 |
-
step=0.1,
|
269 |
-
value=0.25)
|
270 |
-
with gr.Row(variant='panel'):
|
271 |
-
self.left_ext = gr.Slider(
|
272 |
-
show_label=True,
|
273 |
-
label='Left Extend Ratio',
|
274 |
-
minimum=0.0,
|
275 |
-
maximum=2.0,
|
276 |
-
step=0.1,
|
277 |
-
value=0.25)
|
278 |
-
self.right_ext = gr.Slider(
|
279 |
-
show_label=True,
|
280 |
-
label='Right Extend Ratio',
|
281 |
-
minimum=0.0,
|
282 |
-
maximum=2.0,
|
283 |
-
step=0.1,
|
284 |
-
value=0.25)
|
285 |
-
with gr.Row(variant='panel'):
|
286 |
-
self.img_pad_btn = gr.Button(
|
287 |
-
value='Pad Image')
|
288 |
-
|
289 |
-
self.image_editor = gr.ImageMask(
|
290 |
-
value=None,
|
291 |
-
sources=[],
|
292 |
-
layers=False,
|
293 |
-
label='Edit Image',
|
294 |
-
elem_id='image_editor',
|
295 |
-
format='png')
|
296 |
-
with gr.Row():
|
297 |
-
self.sub_btn_2 = gr.Button(
|
298 |
-
value='Submit', elem_id='edit_submit')
|
299 |
-
self.ext_btn_2 = gr.Button(value='Exit')
|
300 |
-
|
301 |
-
with gr.Tab(id='ImageViewer',
|
302 |
-
label='Image Viewer') as self.image_view_tab:
|
303 |
-
if self.gradio_version >= '5.0.0':
|
304 |
-
self.image_viewer = gr.Image(
|
305 |
-
label='Image',
|
306 |
-
type='pil',
|
307 |
-
show_download_button=True,
|
308 |
-
elem_id='image_viewer')
|
309 |
-
else:
|
310 |
-
try:
|
311 |
-
from gradio_imageslider import ImageSlider
|
312 |
-
except Exception as e:
|
313 |
-
print(f"Import gradio_imageslider failed, please install.")
|
314 |
-
self.image_viewer = ImageSlider(
|
315 |
-
label='Image',
|
316 |
-
type='pil',
|
317 |
-
show_download_button=True,
|
318 |
-
elem_id='image_viewer')
|
319 |
-
|
320 |
-
self.ext_btn_3 = gr.Button(value='Exit')
|
321 |
-
|
322 |
-
with gr.Tab(id='VideoViewer',
|
323 |
-
label='Video Viewer',
|
324 |
-
visible=False) as self.video_view_tab:
|
325 |
-
self.video_viewer = gr.Video(
|
326 |
-
label='Video',
|
327 |
-
interactive=False,
|
328 |
-
sources=[],
|
329 |
-
format='mp4',
|
330 |
-
show_download_button=True,
|
331 |
-
elem_id='video_viewer',
|
332 |
-
loop=True,
|
333 |
-
autoplay=True)
|
334 |
-
|
335 |
-
self.ext_btn_4 = gr.Button(value='Exit')
|
336 |
-
|
337 |
-
with gr.Row(equal_height=True, visible=True) as self.legacy_group:
|
338 |
-
with gr.Column():
|
339 |
-
self.legacy_image_uploader = gr.Image(
|
340 |
-
height=550,
|
341 |
-
interactive=True,
|
342 |
-
type='pil',
|
343 |
-
image_mode='RGB',
|
344 |
-
elem_id='legacy_image_uploader',
|
345 |
-
format='png')
|
346 |
-
with gr.Column():
|
347 |
-
self.legacy_image_viewer = gr.Image(
|
348 |
-
label='Image',
|
349 |
-
height=550,
|
350 |
-
type='pil',
|
351 |
-
interactive=False,
|
352 |
-
show_download_button=True,
|
353 |
-
elem_id='image_viewer')
|
354 |
-
|
355 |
-
with gr.Accordion(label='Setting', open=False):
|
356 |
-
with gr.Row():
|
357 |
-
self.model_name_dd = gr.Dropdown(
|
358 |
-
choices=self.model_choices,
|
359 |
-
value=self.default_model_name,
|
360 |
-
label='Model Version')
|
361 |
-
|
362 |
-
with gr.Row():
|
363 |
-
self.negative_prompt = gr.Textbox(
|
364 |
-
value='',
|
365 |
-
placeholder=
|
366 |
-
'Negative prompt used for Classifier-Free Guidance',
|
367 |
-
label='Negative Prompt',
|
368 |
-
container=False)
|
369 |
-
|
370 |
-
with gr.Row():
|
371 |
-
# REFINER_PROMPT
|
372 |
-
self.refiner_prompt = gr.Textbox(
|
373 |
-
value=self.pipe.input.get("refiner_prompt", ""),
|
374 |
-
visible=self.pipe.input.get("refiner_prompt", None) is not None,
|
375 |
-
placeholder=
|
376 |
-
'Prompt used for refiner',
|
377 |
-
label='Refiner Prompt',
|
378 |
-
container=False)
|
379 |
-
|
380 |
-
with gr.Row():
|
381 |
-
with gr.Column(scale=8, min_width=500):
|
382 |
-
with gr.Row():
|
383 |
-
self.step = gr.Slider(minimum=1,
|
384 |
-
maximum=1000,
|
385 |
-
value=self.pipe.input.get("sample_steps", 20),
|
386 |
-
visible=self.pipe.input.get("sample_steps", None) is not None,
|
387 |
-
label='Sample Step')
|
388 |
-
self.cfg_scale = gr.Slider(
|
389 |
-
minimum=1.0,
|
390 |
-
maximum=20.0,
|
391 |
-
value=self.pipe.input.get("guide_scale", 4.5),
|
392 |
-
visible=self.pipe.input.get("guide_scale", None) is not None,
|
393 |
-
label='Guidance Scale')
|
394 |
-
self.rescale = gr.Slider(minimum=0.0,
|
395 |
-
maximum=1.0,
|
396 |
-
value=self.pipe.input.get("guide_rescale", 0.5),
|
397 |
-
visible=self.pipe.input.get("guide_rescale", None) is not None,
|
398 |
-
label='Rescale')
|
399 |
-
self.refiner_scale = gr.Slider(minimum=-0.1,
|
400 |
-
maximum=1.0,
|
401 |
-
value=self.pipe.input.get("refiner_scale", -1),
|
402 |
-
visible=self.pipe.input.get("refiner_scale",
|
403 |
-
None) is not None,
|
404 |
-
label='Refiner Scale')
|
405 |
-
self.seed = gr.Slider(minimum=-1,
|
406 |
-
maximum=10000000,
|
407 |
-
value=-1,
|
408 |
-
label='Seed')
|
409 |
-
self.output_height = gr.Slider(
|
410 |
-
minimum=256,
|
411 |
-
maximum=1440,
|
412 |
-
value=self.pipe.input.get("output_height", 1024),
|
413 |
-
visible=self.pipe.input.get("output_height", None) is not None,
|
414 |
-
label='Output Height')
|
415 |
-
self.output_width = gr.Slider(
|
416 |
-
minimum=256,
|
417 |
-
maximum=1440,
|
418 |
-
value=self.pipe.input.get("output_width", 1024),
|
419 |
-
visible=self.pipe.input.get("output_width", None) is not None,
|
420 |
-
label='Output Width')
|
421 |
-
with gr.Column(scale=1, min_width=50):
|
422 |
-
self.use_history = gr.Checkbox(value=False,
|
423 |
-
label='Use History')
|
424 |
-
self.use_ace = gr.Checkbox(value=self.pipe.input.get("use_ace", True),
|
425 |
-
visible=self.pipe.input.get("use_ace", None) is not None,
|
426 |
-
label='Use ACE')
|
427 |
-
self.video_auto = gr.Checkbox(
|
428 |
-
value=False,
|
429 |
-
label='Auto Gen Video',
|
430 |
-
visible=self.enable_i2v)
|
431 |
-
|
432 |
-
with gr.Row(variant='panel',
|
433 |
-
equal_height=True,
|
434 |
-
visible=self.enable_i2v):
|
435 |
-
self.video_fps = gr.Slider(minimum=1,
|
436 |
-
maximum=16,
|
437 |
-
value=8,
|
438 |
-
label='Video FPS',
|
439 |
-
visible=True)
|
440 |
-
self.video_frames = gr.Slider(minimum=8,
|
441 |
-
maximum=49,
|
442 |
-
value=49,
|
443 |
-
label='Video Frame Num',
|
444 |
-
visible=True)
|
445 |
-
self.video_step = gr.Slider(minimum=1,
|
446 |
-
maximum=1000,
|
447 |
-
value=50,
|
448 |
-
label='Video Sample Step',
|
449 |
-
visible=True)
|
450 |
-
self.video_cfg_scale = gr.Slider(
|
451 |
-
minimum=1.0,
|
452 |
-
maximum=20.0,
|
453 |
-
value=6.0,
|
454 |
-
label='Video Guidance Scale',
|
455 |
-
visible=True)
|
456 |
-
self.video_seed = gr.Slider(minimum=-1,
|
457 |
-
maximum=10000000,
|
458 |
-
value=-1,
|
459 |
-
label='Video Seed',
|
460 |
-
visible=True)
|
461 |
-
|
462 |
-
with gr.Row():
|
463 |
-
self.chatbot_inst = """
|
464 |
-
**Instruction**:
|
465 |
-
1. Click 'Upload' button to upload one or more images as input images.
|
466 |
-
2. Enter '@' in the text box will exhibit all images in the gallery.
|
467 |
-
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
|
468 |
-
4. Compose the editing instruction for the selected image, incorporating image id '@xxxxxx' into your instruction.
|
469 |
-
For example, you might say, "Change the girl's skirt in @123456 to blue." The '@xxxxx' token will facilitate the identification of the specific image, and will be automatically replaced by a special token '{image}' in the instruction. Furthermore, it is also possible to engage in text-to-image generation without any initial image input.
|
470 |
-
5. Once your instructions are prepared, please click the "Chat" button to view the edited result in the chat window.
|
471 |
-
6. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
|
472 |
-
7. To implement local editing based on a specified mask, simply click on the image within the chat window to access the image editor. Here, you can draw a mask and then click the 'Submit' button to upload the edited image along with the mask. For inpainting tasks, select the 'Composite' mask type, while for outpainting tasks, choose the 'Outpainting' mask type. For all other local editing tasks, please select the 'Background' mask type.
|
473 |
-
8. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
|
474 |
-
"""
|
475 |
-
|
476 |
-
self.legacy_inst = """
|
477 |
-
**Instruction**:
|
478 |
-
1. You can edit the image by uploading it; if no image is uploaded, an image will be generated from text..
|
479 |
-
2. Enter '@' in the text box will exhibit all images in the gallery.
|
480 |
-
3. Select the image you wish to edit from the gallery, and its Image ID will be displayed in the text box.
|
481 |
-
4. **Important** To render text on an image, please ensure to include a space between each letter. For instance, "add text 'g i r l' on the mask area of @xxxxx".
|
482 |
-
5. To perform multi-step editing, partial editing, inpainting, outpainting, and other operations, please click the Chatbot Checkbox to enable the conversational editing mode and follow the relevant instructions..
|
483 |
-
6. If you find our work valuable, we invite you to refer to the [ACE Page](https://ali-vilab.github.io/ace-page/) for comprehensive information.
|
484 |
-
"""
|
485 |
-
|
486 |
-
self.instruction = gr.Markdown(value=self.legacy_inst)
|
487 |
-
|
488 |
-
with gr.Row(variant='panel',
|
489 |
-
equal_height=True,
|
490 |
-
show_progress=False):
|
491 |
-
with gr.Column(scale=1, min_width=100, visible=False) as self.upload_panel:
|
492 |
-
self.upload_btn = gr.Button(value=upload_sty +
|
493 |
-
' Upload',
|
494 |
-
variant='secondary')
|
495 |
-
with gr.Column(scale=5, min_width=500):
|
496 |
-
self.text = gr.Textbox(
|
497 |
-
placeholder='Input "@" find history of image',
|
498 |
-
label='Instruction',
|
499 |
-
container=False)
|
500 |
-
with gr.Column(scale=1, min_width=100):
|
501 |
-
self.chat_btn = gr.Button(value='Generate',
|
502 |
-
variant='primary')
|
503 |
-
with gr.Column(scale=1, min_width=100):
|
504 |
-
self.retry_btn = gr.Button(value=refresh_sty +
|
505 |
-
' Retry',
|
506 |
-
variant='secondary')
|
507 |
-
with gr.Column(scale=1, min_width=100):
|
508 |
-
self.mode_checkbox = gr.Checkbox(
|
509 |
-
value=False,
|
510 |
-
label='ChatBot')
|
511 |
-
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
512 |
-
min_width=0):
|
513 |
-
self.video_gen_btn = gr.Button(value=video_sty +
|
514 |
-
' Gen Video',
|
515 |
-
variant='secondary',
|
516 |
-
visible=self.enable_i2v)
|
517 |
-
with gr.Column(scale=(1 if self.enable_i2v else 0),
|
518 |
-
min_width=0):
|
519 |
-
self.extend_prompt = gr.Checkbox(
|
520 |
-
value=True,
|
521 |
-
label='Extend Prompt',
|
522 |
-
visible=self.enable_i2v)
|
523 |
-
|
524 |
-
with gr.Row():
|
525 |
-
self.gallery = gr.Gallery(visible=False,
|
526 |
-
label='History',
|
527 |
-
columns=10,
|
528 |
-
allow_preview=False,
|
529 |
-
interactive=False)
|
530 |
-
|
531 |
-
self.eg = gr.Column(visible=True)
|
532 |
-
|
533 |
-
def set_callbacks(self, *args, **kwargs):
|
534 |
-
|
535 |
-
########################################
|
536 |
-
# @spaces.GPU(duration=60)
|
537 |
-
def change_model(model_name):
|
538 |
-
if model_name not in self.model_choices:
|
539 |
-
gr.Info('The provided model name is not a valid choice!')
|
540 |
-
return model_name, gr.update(), gr.update()
|
541 |
-
|
542 |
-
if model_name != self.model_name:
|
543 |
-
lock.acquire()
|
544 |
-
del self.pipe
|
545 |
-
torch.cuda.empty_cache()
|
546 |
-
torch.cuda.ipc_collect()
|
547 |
-
pipe_cfg = self.model_choices[model_name]
|
548 |
-
infer_name = pipe_cfg.get("INFERENCE_TYPE", "ACE")
|
549 |
-
self.pipe = inference_dict[infer_name]()
|
550 |
-
self.pipe.init_from_cfg(pipe_cfg)
|
551 |
-
self.model_name = model_name
|
552 |
-
lock.release()
|
553 |
-
|
554 |
-
return (model_name, gr.update(), gr.update(),
|
555 |
-
gr.Slider(
|
556 |
-
value=self.pipe.input.get("sample_steps", 20),
|
557 |
-
visible=self.pipe.input.get("sample_steps", None) is not None),
|
558 |
-
gr.Slider(
|
559 |
-
value=self.pipe.input.get("guide_scale", 4.5),
|
560 |
-
visible=self.pipe.input.get("guide_scale", None) is not None),
|
561 |
-
gr.Slider(
|
562 |
-
value=self.pipe.input.get("guide_rescale", 0.5),
|
563 |
-
visible=self.pipe.input.get("guide_rescale", None) is not None),
|
564 |
-
gr.Slider(
|
565 |
-
value=self.pipe.input.get("output_height", 1024),
|
566 |
-
visible=self.pipe.input.get("output_height", None) is not None),
|
567 |
-
gr.Slider(
|
568 |
-
value=self.pipe.input.get("output_width", 1024),
|
569 |
-
visible=self.pipe.input.get("output_width", None) is not None),
|
570 |
-
gr.Textbox(
|
571 |
-
value=self.pipe.input.get("refiner_prompt", ""),
|
572 |
-
visible=self.pipe.input.get("refiner_prompt", None) is not None),
|
573 |
-
gr.Slider(
|
574 |
-
value=self.pipe.input.get("refiner_scale", -1),
|
575 |
-
visible=self.pipe.input.get("refiner_scale", None) is not None
|
576 |
-
),
|
577 |
-
gr.Checkbox(
|
578 |
-
value=self.pipe.input.get("use_ace", True),
|
579 |
-
visible=self.pipe.input.get("use_ace", None) is not None
|
580 |
-
)
|
581 |
-
)
|
582 |
-
|
583 |
-
self.model_name_dd.change(
|
584 |
-
change_model,
|
585 |
-
inputs=[self.model_name_dd],
|
586 |
-
outputs=[
|
587 |
-
self.model_name_dd, self.chatbot, self.text,
|
588 |
-
self.step,
|
589 |
-
self.cfg_scale, self.rescale, self.output_height,
|
590 |
-
self.output_width, self.refiner_prompt, self.refiner_scale,
|
591 |
-
self.use_ace])
|
592 |
-
|
593 |
-
def mode_change(mode_check):
|
594 |
-
if mode_check:
|
595 |
-
# ChatBot
|
596 |
-
return (
|
597 |
-
gr.Row(visible=False),
|
598 |
-
gr.Row(visible=True),
|
599 |
-
gr.Button(value='Generate'),
|
600 |
-
gr.State(value='chatbot'),
|
601 |
-
gr.Column(visible=True),
|
602 |
-
gr.Markdown(value=self.chatbot_inst)
|
603 |
-
)
|
604 |
-
else:
|
605 |
-
# Legacy
|
606 |
-
return (
|
607 |
-
gr.Row(visible=True),
|
608 |
-
gr.Row(visible=False),
|
609 |
-
gr.Button(value=chat_sty + ' Chat'),
|
610 |
-
gr.State(value='legacy'),
|
611 |
-
gr.Column(visible=False),
|
612 |
-
gr.Markdown(value=self.legacy_inst)
|
613 |
-
)
|
614 |
-
|
615 |
-
self.mode_checkbox.change(mode_change, inputs=[self.mode_checkbox],
|
616 |
-
outputs=[self.legacy_group, self.chat_group,
|
617 |
-
self.chat_btn, self.ui_mode,
|
618 |
-
self.upload_panel, self.instruction])
|
619 |
-
|
620 |
-
########################################
|
621 |
-
def generate_gallery(text, images):
|
622 |
-
if text.endswith(' '):
|
623 |
-
return gr.update(), gr.update(visible=False)
|
624 |
-
elif text.endswith('@'):
|
625 |
-
gallery_info = []
|
626 |
-
for image_id, image_meta in images.items():
|
627 |
-
thumbnail_path = image_meta['thumbnail']
|
628 |
-
gallery_info.append((thumbnail_path, image_id))
|
629 |
-
return gr.update(), gr.update(visible=True, value=gallery_info)
|
630 |
-
else:
|
631 |
-
gallery_info = []
|
632 |
-
match = re.search('@([^@ ]+)$', text)
|
633 |
-
if match:
|
634 |
-
prefix = match.group(1)
|
635 |
-
for image_id, image_meta in images.items():
|
636 |
-
if not image_id.startswith(prefix):
|
637 |
-
continue
|
638 |
-
thumbnail_path = image_meta['thumbnail']
|
639 |
-
gallery_info.append((thumbnail_path, image_id))
|
640 |
-
|
641 |
-
if len(gallery_info) > 0:
|
642 |
-
return gr.update(), gr.update(visible=True,
|
643 |
-
value=gallery_info)
|
644 |
-
else:
|
645 |
-
return gr.update(), gr.update(visible=False)
|
646 |
-
else:
|
647 |
-
return gr.update(), gr.update(visible=False)
|
648 |
-
|
649 |
-
self.text.input(generate_gallery,
|
650 |
-
inputs=[self.text, self.images],
|
651 |
-
outputs=[self.text, self.gallery],
|
652 |
-
show_progress='hidden')
|
653 |
-
|
654 |
-
########################################
|
655 |
-
def select_image(text, evt: gr.SelectData):
|
656 |
-
image_id = evt.value['caption']
|
657 |
-
text = '@'.join(text.split('@')[:-1]) + f'@{image_id} '
|
658 |
-
return gr.update(value=text), gr.update(visible=False, value=None)
|
659 |
-
|
660 |
-
self.gallery.select(select_image,
|
661 |
-
inputs=self.text,
|
662 |
-
outputs=[self.text, self.gallery])
|
663 |
-
|
664 |
-
########################################
|
665 |
-
def generate_video(message,
|
666 |
-
extend_prompt,
|
667 |
-
history,
|
668 |
-
images,
|
669 |
-
num_steps,
|
670 |
-
num_frames,
|
671 |
-
cfg_scale,
|
672 |
-
fps,
|
673 |
-
seed,
|
674 |
-
progress=gr.Progress(track_tqdm=True)):
|
675 |
-
|
676 |
-
from diffusers.utils import export_to_video
|
677 |
-
|
678 |
-
generator = torch.Generator(device='cuda').manual_seed(seed)
|
679 |
-
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
680 |
-
if len(img_ids) == 0:
|
681 |
-
history.append((
|
682 |
-
message,
|
683 |
-
'Sorry, no images were found in the prompt to be used as the first frame of the video.'
|
684 |
-
))
|
685 |
-
while len(history) >= self.max_msgs:
|
686 |
-
history.pop(0)
|
687 |
-
return history, self.get_history(
|
688 |
-
history), gr.update(), gr.update(visible=False)
|
689 |
-
|
690 |
-
img_id = img_ids[0]
|
691 |
-
prompt = re.sub(f'@{img_id}\s+', '', message)
|
692 |
-
|
693 |
-
if extend_prompt:
|
694 |
-
messages = copy.deepcopy(self.enhance_ctx)
|
695 |
-
messages.append({
|
696 |
-
'role':
|
697 |
-
'user',
|
698 |
-
'content':
|
699 |
-
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
|
700 |
-
})
|
701 |
-
lock.acquire()
|
702 |
-
outputs = self.enhancer(
|
703 |
-
messages,
|
704 |
-
max_new_tokens=200,
|
705 |
-
)
|
706 |
-
|
707 |
-
prompt = outputs[0]['generated_text'][-1]['content']
|
708 |
-
print(prompt)
|
709 |
-
lock.release()
|
710 |
-
|
711 |
-
img_meta = images[img_id]
|
712 |
-
img_path = img_meta['image']
|
713 |
-
image = Image.open(img_path).convert('RGB')
|
714 |
-
|
715 |
-
lock.acquire()
|
716 |
-
video = self.i2v_pipe(
|
717 |
-
prompt=prompt,
|
718 |
-
image=image,
|
719 |
-
num_videos_per_prompt=1,
|
720 |
-
num_inference_steps=num_steps,
|
721 |
-
num_frames=num_frames,
|
722 |
-
guidance_scale=cfg_scale,
|
723 |
-
generator=generator,
|
724 |
-
).frames[0]
|
725 |
-
lock.release()
|
726 |
-
|
727 |
-
out_video_path = export_to_video(video, fps=fps)
|
728 |
-
history.append((
|
729 |
-
f"Based on first frame @{img_id} and description '{prompt}', generate a video",
|
730 |
-
'This is generated video:'))
|
731 |
-
history.append((None, out_video_path))
|
732 |
-
while len(history) >= self.max_msgs:
|
733 |
-
history.pop(0)
|
734 |
-
|
735 |
-
return history, self.get_history(history), gr.update(
|
736 |
-
value=''), gr.update(visible=False)
|
737 |
-
|
738 |
-
self.video_gen_btn.click(
|
739 |
-
generate_video,
|
740 |
-
inputs=[
|
741 |
-
self.text, self.extend_prompt, self.history, self.images,
|
742 |
-
self.video_step, self.video_frames, self.video_cfg_scale,
|
743 |
-
self.video_fps, self.video_seed
|
744 |
-
],
|
745 |
-
outputs=[self.history, self.chatbot, self.text, self.gallery])
|
746 |
-
|
747 |
-
########################################
|
748 |
-
@spaces.GPU(duration=120)
|
749 |
-
def run_chat(
|
750 |
-
message,
|
751 |
-
legacy_image,
|
752 |
-
ui_mode,
|
753 |
-
use_ace,
|
754 |
-
extend_prompt,
|
755 |
-
history,
|
756 |
-
images,
|
757 |
-
use_history,
|
758 |
-
history_result,
|
759 |
-
negative_prompt,
|
760 |
-
cfg_scale,
|
761 |
-
rescale,
|
762 |
-
refiner_prompt,
|
763 |
-
refiner_scale,
|
764 |
-
step,
|
765 |
-
seed,
|
766 |
-
output_h,
|
767 |
-
output_w,
|
768 |
-
video_auto,
|
769 |
-
video_steps,
|
770 |
-
video_frames,
|
771 |
-
video_cfg_scale,
|
772 |
-
video_fps,
|
773 |
-
video_seed,
|
774 |
-
progress=gr.Progress(track_tqdm=True)):
|
775 |
-
legacy_img_ids = []
|
776 |
-
if ui_mode == 'legacy':
|
777 |
-
if legacy_image is not None:
|
778 |
-
history, images, img_id = self.add_uploaded_image_to_history(
|
779 |
-
legacy_image, history, images)
|
780 |
-
legacy_img_ids.append(img_id)
|
781 |
-
retry_msg = message
|
782 |
-
gen_id = get_md5(message)[:12]
|
783 |
-
save_path = os.path.join(self.cache_dir, f'{gen_id}.png')
|
784 |
-
|
785 |
-
img_ids = re.findall('@(.*?)[ ,;.?$]', message)
|
786 |
-
history_io = None
|
787 |
-
|
788 |
-
if len(img_ids) < 1:
|
789 |
-
img_ids = legacy_img_ids
|
790 |
-
for img_id in img_ids:
|
791 |
-
if f'@{img_id}' not in message:
|
792 |
-
message = f'@{img_id} ' + message
|
793 |
-
|
794 |
-
new_message = message
|
795 |
-
|
796 |
-
if len(img_ids) > 0:
|
797 |
-
edit_image, edit_image_mask, edit_task = [], [], []
|
798 |
-
for i, img_id in enumerate(img_ids):
|
799 |
-
if img_id not in images:
|
800 |
-
gr.Info(
|
801 |
-
f'The input image ID {img_id} is not exist... Skip loading image.'
|
802 |
-
)
|
803 |
-
continue
|
804 |
-
placeholder = '{image}' if i == 0 else '{' + f'image{i}' + '}'
|
805 |
-
if placeholder not in new_message:
|
806 |
-
new_message = re.sub(f'@{img_id}', placeholder,
|
807 |
-
new_message)
|
808 |
-
else:
|
809 |
-
new_message = re.sub(f'@{img_id} ', "",
|
810 |
-
new_message, 1)
|
811 |
-
img_meta = images[img_id]
|
812 |
-
img_path = img_meta['image']
|
813 |
-
img_mask = img_meta['mask']
|
814 |
-
img_mask_type = img_meta['mask_type']
|
815 |
-
if img_mask_type is not None and img_mask_type == 'Composite':
|
816 |
-
task = 'inpainting'
|
817 |
-
else:
|
818 |
-
task = ''
|
819 |
-
edit_image.append(Image.open(img_path).convert('RGB'))
|
820 |
-
edit_image_mask.append(
|
821 |
-
Image.open(img_mask).
|
822 |
-
convert('L') if img_mask is not None else None)
|
823 |
-
edit_task.append(task)
|
824 |
-
|
825 |
-
if use_history and (img_id in history_result):
|
826 |
-
history_io = history_result[img_id]
|
827 |
-
|
828 |
-
buffered = io.BytesIO()
|
829 |
-
edit_image[0].save(buffered, format='PNG')
|
830 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
831 |
-
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
832 |
-
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image @{img_ids[0]} is:\n {img_str}'
|
833 |
-
else:
|
834 |
-
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
835 |
-
edit_image = None
|
836 |
-
edit_image_mask = None
|
837 |
-
edit_task = ''
|
838 |
-
if new_message == "":
|
839 |
-
new_message = "a beautiful girl wear a skirt."
|
840 |
-
print(new_message)
|
841 |
-
imgs = self.pipe(
|
842 |
-
image=edit_image,
|
843 |
-
mask=edit_image_mask,
|
844 |
-
task=edit_task,
|
845 |
-
prompt=[new_message] *
|
846 |
-
len(edit_image) if edit_image is not None else [new_message],
|
847 |
-
negative_prompt=[negative_prompt] * len(edit_image)
|
848 |
-
if edit_image is not None else [negative_prompt],
|
849 |
-
history_io=history_io,
|
850 |
-
output_height=output_h,
|
851 |
-
output_width=output_w,
|
852 |
-
sampler=self.pipe.input.get("sampler", "ddim"),
|
853 |
-
sample_steps=step,
|
854 |
-
guide_scale=cfg_scale,
|
855 |
-
guide_rescale=rescale,
|
856 |
-
seed=seed,
|
857 |
-
refiner_prompt=refiner_prompt,
|
858 |
-
refiner_scale=refiner_scale,
|
859 |
-
use_ace=use_ace
|
860 |
-
)
|
861 |
-
|
862 |
-
img = imgs[0]
|
863 |
-
img.save(save_path, format='JPEG')
|
864 |
-
|
865 |
-
if history_io:
|
866 |
-
history_io_new = copy.deepcopy(history_io)
|
867 |
-
history_io_new['image'] += edit_image[:1]
|
868 |
-
history_io_new['mask'] += edit_image_mask[:1]
|
869 |
-
history_io_new['task'] += edit_task[:1]
|
870 |
-
history_io_new['prompt'] += [new_message]
|
871 |
-
history_io_new['image'] = history_io_new['image'][-5:]
|
872 |
-
history_io_new['mask'] = history_io_new['mask'][-5:]
|
873 |
-
history_io_new['task'] = history_io_new['task'][-5:]
|
874 |
-
history_io_new['prompt'] = history_io_new['prompt'][-5:]
|
875 |
-
history_result[gen_id] = history_io_new
|
876 |
-
elif edit_image is not None and len(edit_image) > 0:
|
877 |
-
history_io_new = {
|
878 |
-
'image': edit_image[:1],
|
879 |
-
'mask': edit_image_mask[:1],
|
880 |
-
'task': edit_task[:1],
|
881 |
-
'prompt': [new_message]
|
882 |
-
}
|
883 |
-
history_result[gen_id] = history_io_new
|
884 |
-
|
885 |
-
w, h = img.size
|
886 |
-
if w > h:
|
887 |
-
tb_w = 128
|
888 |
-
tb_h = int(h * tb_w / w)
|
889 |
-
else:
|
890 |
-
tb_h = 128
|
891 |
-
tb_w = int(w * tb_h / h)
|
892 |
-
|
893 |
-
thumbnail_path = os.path.join(self.cache_dir,
|
894 |
-
f'{gen_id}_thumbnail.jpg')
|
895 |
-
thumbnail = img.resize((tb_w, tb_h))
|
896 |
-
thumbnail.save(thumbnail_path, format='JPEG')
|
897 |
-
|
898 |
-
images[gen_id] = {
|
899 |
-
'image': save_path,
|
900 |
-
'mask': None,
|
901 |
-
'mask_type': None,
|
902 |
-
'thumbnail': thumbnail_path
|
903 |
-
}
|
904 |
-
|
905 |
-
buffered = io.BytesIO()
|
906 |
-
img.convert('RGB').save(buffered, format='JPEG')
|
907 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
908 |
-
img_str = f'<img src="data:image/jpg;base64,{img_b64}" style="pointer-events: none;">'
|
909 |
-
|
910 |
-
history.append(
|
911 |
-
(message,
|
912 |
-
f'{pre_info} The generated image @{gen_id} is:\n {img_str}'))
|
913 |
-
|
914 |
-
if video_auto:
|
915 |
-
if video_seed is None or video_seed == -1:
|
916 |
-
video_seed = random.randint(0, 10000000)
|
917 |
-
|
918 |
-
lock.acquire()
|
919 |
-
generator = torch.Generator(
|
920 |
-
device='cuda').manual_seed(video_seed)
|
921 |
-
pixel_values = load_image(img.convert('RGB'),
|
922 |
-
max_num=self.llm_max_num).to(
|
923 |
-
torch.bfloat16).cuda()
|
924 |
-
prompt = self.captioner.chat(self.llm_tokenizer, pixel_values,
|
925 |
-
self.llm_prompt,
|
926 |
-
self.llm_generation_config)
|
927 |
-
print(prompt)
|
928 |
-
lock.release()
|
929 |
-
|
930 |
-
if extend_prompt:
|
931 |
-
messages = copy.deepcopy(self.enhance_ctx)
|
932 |
-
messages.append({
|
933 |
-
'role':
|
934 |
-
'user',
|
935 |
-
'content':
|
936 |
-
f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{prompt}"',
|
937 |
-
})
|
938 |
-
lock.acquire()
|
939 |
-
outputs = self.enhancer(
|
940 |
-
messages,
|
941 |
-
max_new_tokens=200,
|
942 |
-
)
|
943 |
-
prompt = outputs[0]['generated_text'][-1]['content']
|
944 |
-
print(prompt)
|
945 |
-
lock.release()
|
946 |
-
|
947 |
-
lock.acquire()
|
948 |
-
video = self.i2v_pipe(
|
949 |
-
prompt=prompt,
|
950 |
-
image=img,
|
951 |
-
num_videos_per_prompt=1,
|
952 |
-
num_inference_steps=video_steps,
|
953 |
-
num_frames=video_frames,
|
954 |
-
guidance_scale=video_cfg_scale,
|
955 |
-
generator=generator,
|
956 |
-
).frames[0]
|
957 |
-
lock.release()
|
958 |
-
|
959 |
-
out_video_path = export_to_video(video, fps=video_fps)
|
960 |
-
history.append((
|
961 |
-
f"Based on first frame @{gen_id} and description '{prompt}', generate a video",
|
962 |
-
'This is generated video:'))
|
963 |
-
history.append((None, out_video_path))
|
964 |
-
|
965 |
-
while len(history) >= self.max_msgs:
|
966 |
-
history.pop(0)
|
967 |
-
|
968 |
-
return (history, images, gr.Image(value=save_path),
|
969 |
-
history_result, self.get_history(
|
970 |
-
history), gr.update(), gr.update(
|
971 |
-
visible=False), retry_msg)
|
972 |
-
|
973 |
-
chat_inputs = [
|
974 |
-
self.legacy_image_uploader, self.ui_mode, self.use_ace,
|
975 |
-
self.extend_prompt, self.history, self.images, self.use_history,
|
976 |
-
self.history_result, self.negative_prompt, self.cfg_scale,
|
977 |
-
self.rescale, self.refiner_prompt, self.refiner_scale,
|
978 |
-
self.step, self.seed, self.output_height,
|
979 |
-
self.output_width, self.video_auto, self.video_step,
|
980 |
-
self.video_frames, self.video_cfg_scale, self.video_fps,
|
981 |
-
self.video_seed
|
982 |
-
]
|
983 |
-
|
984 |
-
chat_outputs = [
|
985 |
-
self.history, self.images, self.legacy_image_viewer,
|
986 |
-
self.history_result, self.chatbot,
|
987 |
-
self.text, self.gallery, self.retry_msg
|
988 |
-
]
|
989 |
-
|
990 |
-
self.chat_btn.click(run_chat,
|
991 |
-
inputs=[self.text] + chat_inputs,
|
992 |
-
outputs=chat_outputs)
|
993 |
-
|
994 |
-
self.text.submit(run_chat,
|
995 |
-
inputs=[self.text] + chat_inputs,
|
996 |
-
outputs=chat_outputs)
|
997 |
-
|
998 |
-
def retry_fn(*args):
|
999 |
-
return run_chat(*args)
|
1000 |
-
|
1001 |
-
self.retry_btn.click(retry_fn,
|
1002 |
-
inputs=[self.retry_msg] + chat_inputs,
|
1003 |
-
outputs=chat_outputs)
|
1004 |
-
|
1005 |
-
########################################
|
1006 |
-
@spaces.GPU(duration=120)
|
1007 |
-
def run_example(task, img, img_mask, ref1, prompt, seed):
|
1008 |
-
edit_image, edit_image_mask, edit_task = [], [], []
|
1009 |
-
if img is not None:
|
1010 |
-
w, h = img.size
|
1011 |
-
if w > 2048:
|
1012 |
-
ratio = w / 2048.
|
1013 |
-
w = 2048
|
1014 |
-
h = int(h / ratio)
|
1015 |
-
if h > 2048:
|
1016 |
-
ratio = h / 2048.
|
1017 |
-
h = 2048
|
1018 |
-
w = int(w / ratio)
|
1019 |
-
img = img.resize((w, h))
|
1020 |
-
edit_image.append(img)
|
1021 |
-
if img_mask is not None:
|
1022 |
-
img_mask = img_mask if np.sum(np.array(img_mask)) > 0 else None
|
1023 |
-
edit_image_mask.append(
|
1024 |
-
img_mask if img_mask is not None else None)
|
1025 |
-
edit_task.append(task)
|
1026 |
-
if ref1 is not None:
|
1027 |
-
ref1 = ref1 if np.sum(np.array(ref1)) > 0 else None
|
1028 |
-
if ref1 is not None:
|
1029 |
-
edit_image.append(ref1)
|
1030 |
-
edit_image_mask.append(None)
|
1031 |
-
edit_task.append('')
|
1032 |
-
|
1033 |
-
buffered = io.BytesIO()
|
1034 |
-
img.save(buffered, format='PNG')
|
1035 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1036 |
-
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1037 |
-
pre_info = f'Received one or more images, so image editing is conducted.\n The first input image is:\n {img_str}'
|
1038 |
-
else:
|
1039 |
-
pre_info = 'No image ids were found in the provided text prompt, so text-guided image generation is conducted. \n'
|
1040 |
-
edit_image = None
|
1041 |
-
edit_image_mask = None
|
1042 |
-
edit_task = ''
|
1043 |
-
|
1044 |
-
img_num = len(edit_image) if edit_image is not None else 1
|
1045 |
-
imgs = self.pipe(
|
1046 |
-
image=edit_image,
|
1047 |
-
mask=edit_image_mask,
|
1048 |
-
task=edit_task,
|
1049 |
-
prompt=[prompt] * img_num,
|
1050 |
-
negative_prompt=[''] * img_num,
|
1051 |
-
seed=seed,
|
1052 |
-
refiner_prompt=self.pipe.input.get("refiner_prompt", ""),
|
1053 |
-
refiner_scale=self.pipe.input.get("refiner_scale", 0.0),
|
1054 |
-
)
|
1055 |
-
|
1056 |
-
img = imgs[0]
|
1057 |
-
buffered = io.BytesIO()
|
1058 |
-
img.convert('RGB').save(buffered, format='JPEG')
|
1059 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1060 |
-
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1061 |
-
history = [(prompt,
|
1062 |
-
f'{pre_info} The generated image is:\n {img_str}')]
|
1063 |
-
|
1064 |
-
img_id = get_md5(img_b64)[:12]
|
1065 |
-
save_path = os.path.join(self.cache_dir, f'{img_id}.jpg')
|
1066 |
-
img.convert('RGB').save(save_path)
|
1067 |
-
|
1068 |
-
return self.get_history(history), gr.update(value=prompt), gr.update(
|
1069 |
-
visible=False), gr.update(value=save_path), gr.update(value=-1)
|
1070 |
-
|
1071 |
-
with self.eg:
|
1072 |
-
self.example_task = gr.Text(label='Task Name',
|
1073 |
-
value='',
|
1074 |
-
visible=False)
|
1075 |
-
self.example_image = gr.Image(label='Edit Image',
|
1076 |
-
type='pil',
|
1077 |
-
image_mode='RGB',
|
1078 |
-
visible=False)
|
1079 |
-
self.example_mask = gr.Image(label='Edit Image Mask',
|
1080 |
-
type='pil',
|
1081 |
-
image_mode='L',
|
1082 |
-
visible=False)
|
1083 |
-
self.example_ref_im1 = gr.Image(label='Ref Image',
|
1084 |
-
type='pil',
|
1085 |
-
image_mode='RGB',
|
1086 |
-
visible=False)
|
1087 |
-
|
1088 |
-
self.examples = gr.Examples(
|
1089 |
-
fn=run_example,
|
1090 |
-
examples=self.chatbot_examples,
|
1091 |
-
inputs=[
|
1092 |
-
self.example_task, self.example_image, self.example_mask,
|
1093 |
-
self.example_ref_im1, self.text, self.seed
|
1094 |
-
],
|
1095 |
-
outputs=[self.chatbot, self.text, self.gallery, self.legacy_image_viewer, self.seed],
|
1096 |
-
examples_per_page=4,
|
1097 |
-
cache_examples=False,
|
1098 |
-
run_on_click=True)
|
1099 |
-
|
1100 |
-
########################################
|
1101 |
-
def upload_image():
|
1102 |
-
return (gr.update(visible=True,
|
1103 |
-
scale=1), gr.update(visible=True, scale=1),
|
1104 |
-
gr.update(visible=True), gr.update(visible=False),
|
1105 |
-
gr.update(visible=False), gr.update(visible=False),
|
1106 |
-
gr.update(visible=True))
|
1107 |
-
|
1108 |
-
self.upload_btn.click(upload_image,
|
1109 |
-
inputs=[],
|
1110 |
-
outputs=[
|
1111 |
-
self.chat_page, self.editor_page,
|
1112 |
-
self.upload_tab, self.edit_tab,
|
1113 |
-
self.image_view_tab, self.video_view_tab,
|
1114 |
-
self.upload_tabs
|
1115 |
-
])
|
1116 |
-
|
1117 |
-
########################################
|
1118 |
-
def edit_image(evt: gr.SelectData):
|
1119 |
-
if isinstance(evt.value, str):
|
1120 |
-
img_b64s = re.findall(
|
1121 |
-
'<img src="data:image/png;base64,(.*?)" style="pointer-events: none;">',
|
1122 |
-
evt.value)
|
1123 |
-
imgs = [
|
1124 |
-
Image.open(io.BytesIO(base64.b64decode(copy.deepcopy(i))))
|
1125 |
-
for i in img_b64s
|
1126 |
-
]
|
1127 |
-
if len(imgs) > 0:
|
1128 |
-
if len(imgs) == 2:
|
1129 |
-
if self.gradio_version >= '5.0.0':
|
1130 |
-
view_img = copy.deepcopy(imgs[-1])
|
1131 |
-
else:
|
1132 |
-
view_img = copy.deepcopy(imgs)
|
1133 |
-
edit_img = copy.deepcopy(imgs[-1])
|
1134 |
-
else:
|
1135 |
-
if self.gradio_version >= '5.0.0':
|
1136 |
-
view_img = copy.deepcopy(imgs[-1])
|
1137 |
-
else:
|
1138 |
-
view_img = [
|
1139 |
-
copy.deepcopy(imgs[-1]),
|
1140 |
-
copy.deepcopy(imgs[-1])
|
1141 |
-
]
|
1142 |
-
edit_img = copy.deepcopy(imgs[-1])
|
1143 |
-
|
1144 |
-
return (gr.update(visible=True,
|
1145 |
-
scale=1), gr.update(visible=True,
|
1146 |
-
scale=1),
|
1147 |
-
gr.update(visible=False), gr.update(visible=True),
|
1148 |
-
gr.update(visible=True), gr.update(visible=False),
|
1149 |
-
gr.update(value=edit_img),
|
1150 |
-
gr.update(value=view_img), gr.update(value=None),
|
1151 |
-
gr.update(visible=True))
|
1152 |
-
else:
|
1153 |
-
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
1154 |
-
gr.update(), gr.update(), gr.update(), gr.update(),
|
1155 |
-
gr.update(), gr.update())
|
1156 |
-
elif isinstance(evt.value, dict) and evt.value.get(
|
1157 |
-
'component', '') == 'video':
|
1158 |
-
value = evt.value['value']['video']['path']
|
1159 |
-
return (gr.update(visible=True,
|
1160 |
-
scale=1), gr.update(visible=True, scale=1),
|
1161 |
-
gr.update(visible=False), gr.update(visible=False),
|
1162 |
-
gr.update(visible=False), gr.update(visible=True),
|
1163 |
-
gr.update(), gr.update(), gr.update(value=value),
|
1164 |
-
gr.update())
|
1165 |
-
else:
|
1166 |
-
return (gr.update(), gr.update(), gr.update(), gr.update(),
|
1167 |
-
gr.update(), gr.update(), gr.update(), gr.update(),
|
1168 |
-
gr.update(), gr.update())
|
1169 |
-
|
1170 |
-
self.chatbot.select(edit_image,
|
1171 |
-
outputs=[
|
1172 |
-
self.chat_page, self.editor_page,
|
1173 |
-
self.upload_tab, self.edit_tab,
|
1174 |
-
self.image_view_tab, self.video_view_tab,
|
1175 |
-
self.image_editor, self.image_viewer,
|
1176 |
-
self.video_viewer, self.edit_tabs
|
1177 |
-
])
|
1178 |
-
|
1179 |
-
if self.gradio_version < '5.0.0':
|
1180 |
-
self.image_viewer.change(lambda x: x,
|
1181 |
-
inputs=self.image_viewer,
|
1182 |
-
outputs=self.image_viewer)
|
1183 |
-
|
1184 |
-
########################################
|
1185 |
-
def submit_upload_image(image, history, images):
|
1186 |
-
history, images, _ = self.add_uploaded_image_to_history(
|
1187 |
-
image, history, images)
|
1188 |
-
return gr.update(visible=False), gr.update(
|
1189 |
-
visible=True), gr.update(
|
1190 |
-
value=self.get_history(history)), history, images
|
1191 |
-
|
1192 |
-
self.sub_btn_1.click(
|
1193 |
-
submit_upload_image,
|
1194 |
-
inputs=[self.image_uploader, self.history, self.images],
|
1195 |
-
outputs=[
|
1196 |
-
self.editor_page, self.chat_page, self.chatbot, self.history,
|
1197 |
-
self.images
|
1198 |
-
])
|
1199 |
-
|
1200 |
-
########################################
|
1201 |
-
def submit_edit_image(imagemask, mask_type, history, images):
|
1202 |
-
history, images = self.add_edited_image_to_history(
|
1203 |
-
imagemask, mask_type, history, images)
|
1204 |
-
return gr.update(visible=False), gr.update(
|
1205 |
-
visible=True), gr.update(
|
1206 |
-
value=self.get_history(history)), history, images
|
1207 |
-
|
1208 |
-
self.sub_btn_2.click(submit_edit_image,
|
1209 |
-
inputs=[
|
1210 |
-
self.image_editor, self.mask_type,
|
1211 |
-
self.history, self.images
|
1212 |
-
],
|
1213 |
-
outputs=[
|
1214 |
-
self.editor_page, self.chat_page,
|
1215 |
-
self.chatbot, self.history, self.images
|
1216 |
-
])
|
1217 |
-
|
1218 |
-
########################################
|
1219 |
-
def exit_edit():
|
1220 |
-
return gr.update(visible=False), gr.update(visible=True, scale=3)
|
1221 |
-
|
1222 |
-
self.ext_btn_1.click(exit_edit,
|
1223 |
-
outputs=[self.editor_page, self.chat_page])
|
1224 |
-
self.ext_btn_2.click(exit_edit,
|
1225 |
-
outputs=[self.editor_page, self.chat_page])
|
1226 |
-
self.ext_btn_3.click(exit_edit,
|
1227 |
-
outputs=[self.editor_page, self.chat_page])
|
1228 |
-
self.ext_btn_4.click(exit_edit,
|
1229 |
-
outputs=[self.editor_page, self.chat_page])
|
1230 |
-
|
1231 |
-
########################################
|
1232 |
-
def update_mask_type_info(mask_type):
|
1233 |
-
if mask_type == 'Background':
|
1234 |
-
info = 'Background mode will not erase the visual content in the mask area'
|
1235 |
-
visible = False
|
1236 |
-
elif mask_type == 'Composite':
|
1237 |
-
info = 'Composite mode will erase the visual content in the mask area'
|
1238 |
-
visible = False
|
1239 |
-
elif mask_type == 'Outpainting':
|
1240 |
-
info = 'Outpaint mode is used for preparing input image for outpainting task'
|
1241 |
-
visible = True
|
1242 |
-
return (gr.update(
|
1243 |
-
visible=True,
|
1244 |
-
value=
|
1245 |
-
f"<div style='background-color: white; padding-left: 15px; color: grey;'>{info}</div>"
|
1246 |
-
), gr.update(visible=visible))
|
1247 |
-
|
1248 |
-
self.mask_type.change(update_mask_type_info,
|
1249 |
-
inputs=self.mask_type,
|
1250 |
-
outputs=[self.mask_type_info, self.outpaint_tab])
|
1251 |
-
|
1252 |
-
########################################
|
1253 |
-
def extend_image(top_ratio, bottom_ratio, left_ratio, right_ratio,
|
1254 |
-
image):
|
1255 |
-
img = cv2.cvtColor(image['background'], cv2.COLOR_RGBA2RGB)
|
1256 |
-
h, w = img.shape[:2]
|
1257 |
-
new_h = int(h * (top_ratio + bottom_ratio + 1))
|
1258 |
-
new_w = int(w * (left_ratio + right_ratio + 1))
|
1259 |
-
start_h = int(h * top_ratio)
|
1260 |
-
start_w = int(w * left_ratio)
|
1261 |
-
new_img = np.zeros((new_h, new_w, 3), dtype=np.uint8)
|
1262 |
-
new_mask = np.ones((new_h, new_w, 1), dtype=np.uint8) * 255
|
1263 |
-
new_img[start_h:start_h + h, start_w:start_w + w, :] = img
|
1264 |
-
new_mask[start_h:start_h + h, start_w:start_w + w] = 0
|
1265 |
-
layer = np.concatenate([new_img, new_mask], axis=2)
|
1266 |
-
value = {
|
1267 |
-
'background': new_img,
|
1268 |
-
'composite': new_img,
|
1269 |
-
'layers': [layer]
|
1270 |
-
}
|
1271 |
-
return gr.update(value=value)
|
1272 |
-
|
1273 |
-
self.img_pad_btn.click(extend_image,
|
1274 |
-
inputs=[
|
1275 |
-
self.top_ext, self.bottom_ext,
|
1276 |
-
self.left_ext, self.right_ext,
|
1277 |
-
self.image_editor
|
1278 |
-
],
|
1279 |
-
outputs=self.image_editor)
|
1280 |
-
|
1281 |
-
########################################
|
1282 |
-
def clear_chat(history, images, history_result):
|
1283 |
-
history.clear()
|
1284 |
-
images.clear()
|
1285 |
-
history_result.clear()
|
1286 |
-
return history, images, history_result, self.get_history(history)
|
1287 |
-
|
1288 |
-
self.clear_btn.click(
|
1289 |
-
clear_chat,
|
1290 |
-
inputs=[self.history, self.images, self.history_result],
|
1291 |
-
outputs=[
|
1292 |
-
self.history, self.images, self.history_result, self.chatbot
|
1293 |
-
])
|
1294 |
-
|
1295 |
-
def get_history(self, history):
|
1296 |
-
info = []
|
1297 |
-
for item in history:
|
1298 |
-
new_item = [None, None]
|
1299 |
-
if isinstance(item[0], str) and item[0].endswith('.mp4'):
|
1300 |
-
new_item[0] = gr.Video(item[0], format='mp4')
|
1301 |
-
else:
|
1302 |
-
new_item[0] = item[0]
|
1303 |
-
if isinstance(item[1], str) and item[1].endswith('.mp4'):
|
1304 |
-
new_item[1] = gr.Video(item[1], format='mp4')
|
1305 |
-
else:
|
1306 |
-
new_item[1] = item[1]
|
1307 |
-
info.append(new_item)
|
1308 |
-
return info
|
1309 |
-
|
1310 |
-
def generate_random_string(self, length=20):
|
1311 |
-
letters_and_digits = string.ascii_letters + string.digits
|
1312 |
-
random_string = ''.join(
|
1313 |
-
random.choice(letters_and_digits) for i in range(length))
|
1314 |
-
return random_string
|
1315 |
-
|
1316 |
-
def add_edited_image_to_history(self, image, mask_type, history, images):
|
1317 |
-
if mask_type == 'Composite':
|
1318 |
-
img = Image.fromarray(image['composite'])
|
1319 |
-
else:
|
1320 |
-
img = Image.fromarray(image['background'])
|
1321 |
-
|
1322 |
-
img_id = get_md5(self.generate_random_string())[:12]
|
1323 |
-
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
|
1324 |
-
img.convert('RGB').save(save_path)
|
1325 |
-
|
1326 |
-
mask = image['layers'][0][:, :, 3]
|
1327 |
-
mask = Image.fromarray(mask).convert('RGB')
|
1328 |
-
mask_path = os.path.join(self.cache_dir, f'{img_id}_mask.png')
|
1329 |
-
mask.save(mask_path)
|
1330 |
-
|
1331 |
-
w, h = img.size
|
1332 |
-
if w > h:
|
1333 |
-
tb_w = 128
|
1334 |
-
tb_h = int(h * tb_w / w)
|
1335 |
-
else:
|
1336 |
-
tb_h = 128
|
1337 |
-
tb_w = int(w * tb_h / h)
|
1338 |
-
|
1339 |
-
if mask_type == 'Background':
|
1340 |
-
comp_mask = np.array(mask, dtype=np.uint8)
|
1341 |
-
mask_alpha = (comp_mask[:, :, 0:1].astype(np.float32) *
|
1342 |
-
0.6).astype(np.uint8)
|
1343 |
-
comp_mask = np.concatenate([comp_mask, mask_alpha], axis=2)
|
1344 |
-
thumbnail = Image.alpha_composite(
|
1345 |
-
img.convert('RGBA'),
|
1346 |
-
Image.fromarray(comp_mask).convert('RGBA')).convert('RGB')
|
1347 |
-
else:
|
1348 |
-
thumbnail = img.convert('RGB')
|
1349 |
-
|
1350 |
-
thumbnail_path = os.path.join(self.cache_dir,
|
1351 |
-
f'{img_id}_thumbnail.jpg')
|
1352 |
-
thumbnail = thumbnail.resize((tb_w, tb_h))
|
1353 |
-
thumbnail.save(thumbnail_path, format='JPEG')
|
1354 |
-
|
1355 |
-
buffered = io.BytesIO()
|
1356 |
-
img.convert('RGB').save(buffered, format='PNG')
|
1357 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1358 |
-
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1359 |
-
|
1360 |
-
buffered = io.BytesIO()
|
1361 |
-
mask.convert('RGB').save(buffered, format='PNG')
|
1362 |
-
mask_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1363 |
-
mask_str = f'<img src="data:image/png;base64,{mask_b64}" style="pointer-events: none;">'
|
1364 |
-
|
1365 |
-
images[img_id] = {
|
1366 |
-
'image': save_path,
|
1367 |
-
'mask': mask_path,
|
1368 |
-
'mask_type': mask_type,
|
1369 |
-
'thumbnail': thumbnail_path
|
1370 |
-
}
|
1371 |
-
history.append((
|
1372 |
-
None,
|
1373 |
-
f'This is edited image and mask:\n {img_str} {mask_str} image ID is: {img_id}'
|
1374 |
-
))
|
1375 |
-
return history, images
|
1376 |
-
|
1377 |
-
def add_uploaded_image_to_history(self, img, history, images):
|
1378 |
-
img_id = get_md5(self.generate_random_string())[:12]
|
1379 |
-
save_path = os.path.join(self.cache_dir, f'{img_id}.png')
|
1380 |
-
w, h = img.size
|
1381 |
-
if w > 2048:
|
1382 |
-
ratio = w / 2048.
|
1383 |
-
w = 2048
|
1384 |
-
h = int(h / ratio)
|
1385 |
-
if h > 2048:
|
1386 |
-
ratio = h / 2048.
|
1387 |
-
h = 2048
|
1388 |
-
w = int(w / ratio)
|
1389 |
-
img = img.resize((w, h))
|
1390 |
-
img.save(save_path)
|
1391 |
-
|
1392 |
-
w, h = img.size
|
1393 |
-
if w > h:
|
1394 |
-
tb_w = 128
|
1395 |
-
tb_h = int(h * tb_w / w)
|
1396 |
-
else:
|
1397 |
-
tb_h = 128
|
1398 |
-
tb_w = int(w * tb_h / h)
|
1399 |
-
thumbnail_path = os.path.join(self.cache_dir,
|
1400 |
-
f'{img_id}_thumbnail.jpg')
|
1401 |
-
thumbnail = img.resize((tb_w, tb_h))
|
1402 |
-
thumbnail.save(thumbnail_path, format='JPEG')
|
1403 |
-
|
1404 |
-
images[img_id] = {
|
1405 |
-
'image': save_path,
|
1406 |
-
'mask': None,
|
1407 |
-
'mask_type': None,
|
1408 |
-
'thumbnail': thumbnail_path
|
1409 |
-
}
|
1410 |
-
|
1411 |
-
buffered = io.BytesIO()
|
1412 |
-
img.convert('RGB').save(buffered, format='PNG')
|
1413 |
-
img_b64 = base64.b64encode(buffered.getvalue()).decode('utf-8')
|
1414 |
-
img_str = f'<img src="data:image/png;base64,{img_b64}" style="pointer-events: none;">'
|
1415 |
-
|
1416 |
-
history.append(
|
1417 |
-
(None,
|
1418 |
-
f'This is uploaded image:\n {img_str} image ID is: {img_id}'))
|
1419 |
-
return history, images, img_id
|
1420 |
-
|
1421 |
-
|
1422 |
-
if __name__ == '__main__':
|
1423 |
-
cfg = "config/chatbot_ui.yaml"
|
1424 |
-
with gr.Blocks() as demo:
|
1425 |
-
chatbot = ChatBotUI(cfg)
|
1426 |
-
chatbot.create_ui()
|
1427 |
-
chatbot.set_callbacks()
|
1428 |
-
demo.launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/chatbot_ui.yaml
DELETED
@@ -1,25 +0,0 @@
|
|
1 |
-
WORK_DIR: chatbot
|
2 |
-
FILE_SYSTEM:
|
3 |
-
- NAME: LocalFs
|
4 |
-
TEMP_DIR: ./cache
|
5 |
-
- NAME: ModelscopeFs
|
6 |
-
TEMP_DIR: ./cache
|
7 |
-
- NAME: HuggingfaceFs
|
8 |
-
TEMP_DIR: ./cache
|
9 |
-
#
|
10 |
-
ENABLE_I2V: False
|
11 |
-
SKIP_EXAMPLES: True
|
12 |
-
#
|
13 |
-
MODEL:
|
14 |
-
EDIT_MODEL:
|
15 |
-
MODEL_CFG_DIR: config/models/
|
16 |
-
I2V:
|
17 |
-
MODEL_NAME: CogVideoX-5b-I2V
|
18 |
-
MODEL_DIR: ms://ZhipuAI/CogVideoX-5b-I2V/
|
19 |
-
CAPTIONER:
|
20 |
-
MODEL_NAME: InternVL2-2B
|
21 |
-
MODEL_DIR: ms://OpenGVLab/InternVL2-2B/
|
22 |
-
PROMPT: '<image>\nThis image is the first frame of a video. Based on this image, please imagine what changes may occur in the next few seconds of the video. Please output brief description, such as "a dog running" or "a person turns to left". No more than 30 words.'
|
23 |
-
ENHANCER:
|
24 |
-
MODEL_NAME: Meta-Llama-3.1-8B-Instruct
|
25 |
-
MODEL_DIR: ms://LLM-Research/Meta-Llama-3.1-8B-Instruct/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
config/models/ace_flux_dev.yaml
DELETED
@@ -1,187 +0,0 @@
|
|
1 |
-
NAME: ACE_FLUX.1_dev
|
2 |
-
IS_DEFAULT: True
|
3 |
-
USE_DYNAMIC_MODEL: False
|
4 |
-
INFERENCE_TYPE: ACE_FLUX
|
5 |
-
MAX_SEQ_LENGTH: 3072
|
6 |
-
SRC_MAX_SEQ_LENGTH: 2048
|
7 |
-
DEFAULT_PARAS:
|
8 |
-
PARAS:
|
9 |
-
#
|
10 |
-
INPUT:
|
11 |
-
INPUT_IMAGE:
|
12 |
-
INPUT_MASK:
|
13 |
-
TASK:
|
14 |
-
PROMPT: ""
|
15 |
-
OUTPUT_HEIGHT: 1024
|
16 |
-
OUTPUT_WIDTH: 1024
|
17 |
-
SAMPLER: flow_euler
|
18 |
-
SAMPLE_STEPS: 20
|
19 |
-
GUIDE_SCALE: 3.5
|
20 |
-
SEED: -1
|
21 |
-
TAR_INDEX: 0
|
22 |
-
ALIGN: False
|
23 |
-
OUTPUT:
|
24 |
-
LATENT:
|
25 |
-
IMAGES:
|
26 |
-
SEED:
|
27 |
-
MODULES_PARAS:
|
28 |
-
FIRST_STAGE_MODEL:
|
29 |
-
FUNCTION:
|
30 |
-
- NAME: encode
|
31 |
-
DTYPE: bfloat16
|
32 |
-
INPUT: [ "IMAGE" ]
|
33 |
-
- NAME: decode
|
34 |
-
DTYPE: bfloat16
|
35 |
-
INPUT: [ "LATENT" ]
|
36 |
-
PARAS:
|
37 |
-
SCALE_FACTOR: 1.5305
|
38 |
-
SHIFT_FACTOR: 0.0609
|
39 |
-
SIZE_FACTOR: 8
|
40 |
-
DIFFUSION_MODEL:
|
41 |
-
FUNCTION:
|
42 |
-
- NAME: forward
|
43 |
-
DTYPE: bfloat16
|
44 |
-
INPUT: [ "SAMPLE_STEPS", "SAMPLE", "GUIDE_SCALE" ]
|
45 |
-
COND_STAGE_MODEL:
|
46 |
-
FUNCTION:
|
47 |
-
- NAME: encode_list
|
48 |
-
DTYPE: bfloat16
|
49 |
-
INPUT: [ "PROMPT" ]
|
50 |
-
#
|
51 |
-
MODEL:
|
52 |
-
NAME: LatentDiffusionACEFlux
|
53 |
-
PARAMETERIZATION: rf
|
54 |
-
PRETRAINED_MODEL:
|
55 |
-
IGNORE_KEYS: [ ]
|
56 |
-
SIZE_FACTOR: 8
|
57 |
-
TEXT_IDENTIFIER: [ '{image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
|
58 |
-
USE_TEXT_POS_EMBEDDINGS: True
|
59 |
-
DIFFUSION:
|
60 |
-
# NAME DESCRIPTION: TYPE: default: 'DiffusionFluxRF'
|
61 |
-
NAME: DiffusionFluxRF
|
62 |
-
PREDICTION_TYPE: raw
|
63 |
-
# NOISE_SCHEDULER DESCRIPTION: TYPE: default: ''
|
64 |
-
NOISE_SCHEDULER:
|
65 |
-
NAME: FlowMatchFluxShiftScheduler
|
66 |
-
SHIFT: True
|
67 |
-
SIGMOID_SCALE: 1
|
68 |
-
BASE_SHIFT: 0.5
|
69 |
-
MAX_SHIFT: 1.15
|
70 |
-
#
|
71 |
-
DIFFUSION_MODEL:
|
72 |
-
# NAME DESCRIPTION: TYPE: default: 'Flux'
|
73 |
-
NAME: ACEFlux
|
74 |
-
PRETRAINED_MODEL: hf://black-forest-labs/[email protected]
|
75 |
-
SWIFT_LORA_MODEL: ["hf://scepter-studio/ACE-FLUX.1-dev@ace_flux.1_dev_lora.bin"]
|
76 |
-
# IN_CHANNELS DESCRIPTION: model's input channels. TYPE: int default: 64
|
77 |
-
IN_CHANNELS: 64
|
78 |
-
# HIDDEN_SIZE DESCRIPTION: model's hidden size. TYPE: int default: 1024
|
79 |
-
HIDDEN_SIZE: 3072
|
80 |
-
# NUM_HEADS DESCRIPTION: number of heads in the transformer. TYPE: int default: 16
|
81 |
-
NUM_HEADS: 24
|
82 |
-
# AXES_DIM DESCRIPTION: dimensions of the axes of the positional encoding. TYPE: list default: [16, 56, 56]
|
83 |
-
AXES_DIM: [ 16, 56, 56 ]
|
84 |
-
# THETA DESCRIPTION: theta for positional encoding. TYPE: int default: 10000
|
85 |
-
THETA: 10000
|
86 |
-
# VEC_IN_DIM DESCRIPTION: dimension of the vector input. TYPE: int default: 768
|
87 |
-
VEC_IN_DIM: 768
|
88 |
-
# GUIDANCE_EMBED DESCRIPTION: whether to use guidance embedding. TYPE: bool default: False
|
89 |
-
GUIDANCE_EMBED: True
|
90 |
-
# CONTEXT_IN_DIM DESCRIPTION: dimension of the context input. TYPE: int default: 4096
|
91 |
-
CONTEXT_IN_DIM: 4096
|
92 |
-
# MLP_RATIO DESCRIPTION: ratio of mlp hidden size to hidden size. TYPE: float default: 4.0
|
93 |
-
MLP_RATIO: 4.0
|
94 |
-
# QKV_BIAS DESCRIPTION: whether to use bias in qkv projection. TYPE: bool default: True
|
95 |
-
QKV_BIAS: True
|
96 |
-
# DEPTH DESCRIPTION: number of transformer blocks. TYPE: int default: 19
|
97 |
-
DEPTH: 19
|
98 |
-
# DEPTH_SINGLE_BLOCKS DESCRIPTION: number of transformer blocks in the single stream block. TYPE: int default: 38
|
99 |
-
DEPTH_SINGLE_BLOCKS: 38
|
100 |
-
ATTN_BACKEND: pytorch
|
101 |
-
|
102 |
-
#
|
103 |
-
FIRST_STAGE_MODEL:
|
104 |
-
NAME: AutoencoderKLFlux
|
105 |
-
EMBED_DIM: 16
|
106 |
-
PRETRAINED_MODEL: hf://black-forest-labs/[email protected]
|
107 |
-
IGNORE_KEYS: [ ]
|
108 |
-
BATCH_SIZE: 8
|
109 |
-
USE_CONV: False
|
110 |
-
SCALE_FACTOR: 0.3611
|
111 |
-
SHIFT_FACTOR: 0.1159
|
112 |
-
#
|
113 |
-
ENCODER:
|
114 |
-
NAME: Encoder
|
115 |
-
USE_CHECKPOINT: True
|
116 |
-
CH: 128
|
117 |
-
OUT_CH: 3
|
118 |
-
NUM_RES_BLOCKS: 2
|
119 |
-
IN_CHANNELS: 3
|
120 |
-
ATTN_RESOLUTIONS: [ ]
|
121 |
-
CH_MULT: [ 1, 2, 4, 4 ]
|
122 |
-
Z_CHANNELS: 16
|
123 |
-
DOUBLE_Z: True
|
124 |
-
DROPOUT: 0.0
|
125 |
-
RESAMP_WITH_CONV: True
|
126 |
-
#
|
127 |
-
DECODER:
|
128 |
-
NAME: Decoder
|
129 |
-
USE_CHECKPOINT: True
|
130 |
-
CH: 128
|
131 |
-
OUT_CH: 3
|
132 |
-
NUM_RES_BLOCKS: 2
|
133 |
-
IN_CHANNELS: 3
|
134 |
-
ATTN_RESOLUTIONS: [ ]
|
135 |
-
CH_MULT: [ 1, 2, 4, 4 ]
|
136 |
-
Z_CHANNELS: 16
|
137 |
-
DROPOUT: 0.0
|
138 |
-
RESAMP_WITH_CONV: True
|
139 |
-
GIVE_PRE_END: False
|
140 |
-
TANH_OUT: False
|
141 |
-
#
|
142 |
-
COND_STAGE_MODEL:
|
143 |
-
# NAME DESCRIPTION: TYPE: default: 'T5PlusClipFluxEmbedder'
|
144 |
-
NAME: T5ACEPlusClipFluxEmbedder
|
145 |
-
# T5_MODEL DESCRIPTION: TYPE: default: ''
|
146 |
-
T5_MODEL:
|
147 |
-
# NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
|
148 |
-
NAME: ACEHFEmbedder
|
149 |
-
# HF_MODEL_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
|
150 |
-
HF_MODEL_CLS: T5EncoderModel
|
151 |
-
# MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
|
152 |
-
MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder_2/
|
153 |
-
# HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
|
154 |
-
HF_TOKENIZER_CLS: T5Tokenizer
|
155 |
-
# TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
|
156 |
-
TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer_2/
|
157 |
-
ADDED_IDENTIFIER: [ '<img>','{image}', '{caption}', '{mask}', '{ref_image}', '{image1}', '{image2}', '{image3}', '{image4}', '{image5}', '{image6}', '{image7}', '{image8}', '{image9}' ]
|
158 |
-
# MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
|
159 |
-
MAX_LENGTH: 512
|
160 |
-
# OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
|
161 |
-
OUTPUT_KEY: last_hidden_state
|
162 |
-
# D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
|
163 |
-
D_TYPE: bfloat16
|
164 |
-
# BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
|
165 |
-
BATCH_INFER: False
|
166 |
-
CLEAN: whitespace
|
167 |
-
# CLIP_MODEL DESCRIPTION: TYPE: default: ''
|
168 |
-
CLIP_MODEL:
|
169 |
-
# NAME DESCRIPTION: TYPE: default: 'HFEmbedder'
|
170 |
-
NAME: ACEHFEmbedder
|
171 |
-
# HF_MODEL_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
|
172 |
-
HF_MODEL_CLS: CLIPTextModel
|
173 |
-
# MODEL_PATH DESCRIPTION: model folder path TYPE: NoneType default: None
|
174 |
-
MODEL_PATH: hf://black-forest-labs/FLUX.1-dev@text_encoder/
|
175 |
-
# HF_TOKENIZER_CLS DESCRIPTION: huggingface cls in transfomer TYPE: NoneType default: None
|
176 |
-
HF_TOKENIZER_CLS: CLIPTokenizer
|
177 |
-
# TOKENIZER_PATH DESCRIPTION: tokenizer folder path TYPE: NoneType default: None
|
178 |
-
TOKENIZER_PATH: hf://black-forest-labs/FLUX.1-dev@tokenizer/
|
179 |
-
# MAX_LENGTH DESCRIPTION: max length of input TYPE: int default: 77
|
180 |
-
MAX_LENGTH: 77
|
181 |
-
# OUTPUT_KEY DESCRIPTION: output key TYPE: str default: 'last_hidden_state'
|
182 |
-
OUTPUT_KEY: pooler_output
|
183 |
-
# D_TYPE DESCRIPTION: dtype TYPE: str default: 'bfloat16'
|
184 |
-
D_TYPE: bfloat16
|
185 |
-
# BATCH_INFER DESCRIPTION: batch infer TYPE: bool default: False
|
186 |
-
BATCH_INFER: True
|
187 |
-
CLEAN: whitespace
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
example.py
DELETED
@@ -1,370 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import os
|
4 |
-
from PIL import Image
|
5 |
-
from scepter.modules.utils.file_system import FS
|
6 |
-
|
7 |
-
|
8 |
-
def download_image(image, local_path=None):
|
9 |
-
if not FS.exists(local_path):
|
10 |
-
local_path = FS.get_from(image, local_path=local_path)
|
11 |
-
if local_path.split(".")[-1] in ['jpg', 'jpeg']:
|
12 |
-
im = Image.open(local_path).convert("RGB")
|
13 |
-
im.save(local_path, format='JPEG')
|
14 |
-
return local_path
|
15 |
-
|
16 |
-
|
17 |
-
def get_examples(cache_dir):
|
18 |
-
print('Downloading Examples ...')
|
19 |
-
examples = [
|
20 |
-
[
|
21 |
-
'Facial Editing',
|
22 |
-
download_image(
|
23 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e33edc106953.png?raw=true',
|
24 |
-
os.path.join(cache_dir, 'examples/e33edc106953.jpg')), None,
|
25 |
-
None, '{image} let the man smile', 6666
|
26 |
-
],
|
27 |
-
[
|
28 |
-
'Facial Editing',
|
29 |
-
download_image(
|
30 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5d2bcc91a3e9.png?raw=true',
|
31 |
-
os.path.join(cache_dir, 'examples/5d2bcc91a3e9.jpg')), None,
|
32 |
-
None, 'let the man in {image} wear sunglasses', 9999
|
33 |
-
],
|
34 |
-
[
|
35 |
-
'Facial Editing',
|
36 |
-
download_image(
|
37 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a52eac708bd.png?raw=true',
|
38 |
-
os.path.join(cache_dir, 'examples/3a52eac708bd.jpg')), None,
|
39 |
-
None, '{image} red hair', 9999
|
40 |
-
],
|
41 |
-
[
|
42 |
-
'Facial Editing',
|
43 |
-
download_image(
|
44 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3f4dc464a0ea.png?raw=true',
|
45 |
-
os.path.join(cache_dir, 'examples/3f4dc464a0ea.jpg')), None,
|
46 |
-
None, '{image} let the man serious', 99999
|
47 |
-
],
|
48 |
-
[
|
49 |
-
'Controllable Generation',
|
50 |
-
download_image(
|
51 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/131ca90fd2a9.png?raw=true',
|
52 |
-
os.path.join(cache_dir,
|
53 |
-
'examples/131ca90fd2a9.jpg')), None, None,
|
54 |
-
'"A person sits contemplatively on the ground, surrounded by falling autumn leaves. Dressed in a green sweater and dark blue pants, they rest their chin on their hand, exuding a relaxed demeanor. Their stylish checkered slip-on shoes add a touch of flair, while a black purse lies in their lap. The backdrop of muted brown enhances the warm, cozy atmosphere of the scene." , generate the image that corresponds to the given scribble {image}.',
|
55 |
-
613725
|
56 |
-
],
|
57 |
-
[
|
58 |
-
'Render Text',
|
59 |
-
download_image(
|
60 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48.png?raw=true',
|
61 |
-
os.path.join(cache_dir, 'examples/33e9f27c2c48.jpg')),
|
62 |
-
download_image(
|
63 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/33e9f27c2c48_mask.png?raw=true',
|
64 |
-
os.path.join(cache_dir,
|
65 |
-
'examples/33e9f27c2c48_mask.jpg')), None,
|
66 |
-
'Put the text "C A T" at the position marked by mask in the {image}',
|
67 |
-
6666
|
68 |
-
],
|
69 |
-
[
|
70 |
-
'Style Transfer',
|
71 |
-
download_image(
|
72 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/9e73e7eeef55.png?raw=true',
|
73 |
-
os.path.join(cache_dir, 'examples/9e73e7eeef55.jpg')), None,
|
74 |
-
download_image(
|
75 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/2e02975293d6.png?raw=true',
|
76 |
-
os.path.join(cache_dir, 'examples/2e02975293d6.jpg')),
|
77 |
-
'edit {image} based on the style of {image1} ', 99999
|
78 |
-
],
|
79 |
-
[
|
80 |
-
'Outpainting',
|
81 |
-
download_image(
|
82 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f.png?raw=true',
|
83 |
-
os.path.join(cache_dir, 'examples/f2b22c08be3f.jpg')),
|
84 |
-
download_image(
|
85 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f2b22c08be3f_mask.png?raw=true',
|
86 |
-
os.path.join(cache_dir,
|
87 |
-
'examples/f2b22c08be3f_mask.jpg')), None,
|
88 |
-
'Could the {image} be widened within the space designated by mask, while retaining the original?',
|
89 |
-
6666
|
90 |
-
],
|
91 |
-
[
|
92 |
-
'Image Segmentation',
|
93 |
-
download_image(
|
94 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/db3ebaa81899.png?raw=true',
|
95 |
-
os.path.join(cache_dir, 'examples/db3ebaa81899.jpg')), None,
|
96 |
-
None, '{image} Segmentation', 6666
|
97 |
-
],
|
98 |
-
[
|
99 |
-
'Depth Estimation',
|
100 |
-
download_image(
|
101 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/f1927c4692ba.png?raw=true',
|
102 |
-
os.path.join(cache_dir, 'examples/f1927c4692ba.jpg')), None,
|
103 |
-
None, '{image} Depth Estimation', 6666
|
104 |
-
],
|
105 |
-
[
|
106 |
-
'Pose Estimation',
|
107 |
-
download_image(
|
108 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/014e5bf3b4d1.png?raw=true',
|
109 |
-
os.path.join(cache_dir, 'examples/014e5bf3b4d1.jpg')), None,
|
110 |
-
None, '{image} distinguish the poses of the figures', 999999
|
111 |
-
],
|
112 |
-
[
|
113 |
-
'Scribble Extraction',
|
114 |
-
download_image(
|
115 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/5f59a202f8ac.png?raw=true',
|
116 |
-
os.path.join(cache_dir, 'examples/5f59a202f8ac.jpg')), None,
|
117 |
-
None, 'Generate a scribble of {image}, please.', 6666
|
118 |
-
],
|
119 |
-
[
|
120 |
-
'Mosaic',
|
121 |
-
download_image(
|
122 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3a2f52361eea.png?raw=true',
|
123 |
-
os.path.join(cache_dir, 'examples/3a2f52361eea.jpg')), None,
|
124 |
-
None, 'Adapt {image} into a mosaic representation.', 6666
|
125 |
-
],
|
126 |
-
[
|
127 |
-
'Edge map Extraction',
|
128 |
-
download_image(
|
129 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/b9d1e519d6e5.png?raw=true',
|
130 |
-
os.path.join(cache_dir, 'examples/b9d1e519d6e5.jpg')), None,
|
131 |
-
None, 'Get the edge-enhanced result for {image}.', 6666
|
132 |
-
],
|
133 |
-
[
|
134 |
-
'Grayscale',
|
135 |
-
download_image(
|
136 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4ebbe2ba29b.png?raw=true',
|
137 |
-
os.path.join(cache_dir, 'examples/c4ebbe2ba29b.jpg')), None,
|
138 |
-
None, 'transform {image} into a black and white one', 6666
|
139 |
-
],
|
140 |
-
[
|
141 |
-
'Contour Extraction',
|
142 |
-
download_image(
|
143 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/19652d0f6c4b.png?raw=true',
|
144 |
-
os.path.join(cache_dir,
|
145 |
-
'examples/19652d0f6c4b.jpg')), None, None,
|
146 |
-
'Would you be able to make a contour picture from {image} for me?',
|
147 |
-
6666
|
148 |
-
],
|
149 |
-
[
|
150 |
-
'Controllable Generation',
|
151 |
-
download_image(
|
152 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/249cda2844b7.png?raw=true',
|
153 |
-
os.path.join(cache_dir,
|
154 |
-
'examples/249cda2844b7.jpg')), None, None,
|
155 |
-
'Following the segmentation outcome in mask of {image}, develop a real-life image using the explanatory note in "a mighty cat lying on the bed”.',
|
156 |
-
6666
|
157 |
-
],
|
158 |
-
[
|
159 |
-
'Controllable Generation',
|
160 |
-
download_image(
|
161 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/411f6c4b8e6c.png?raw=true',
|
162 |
-
os.path.join(cache_dir,
|
163 |
-
'examples/411f6c4b8e6c.jpg')), None, None,
|
164 |
-
'use the depth map {image} and the text caption "a cut white cat" to create a corresponding graphic image',
|
165 |
-
999999
|
166 |
-
],
|
167 |
-
[
|
168 |
-
'Controllable Generation',
|
169 |
-
download_image(
|
170 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a35c96ed137a.png?raw=true',
|
171 |
-
os.path.join(cache_dir,
|
172 |
-
'examples/a35c96ed137a.jpg')), None, None,
|
173 |
-
'help translate this posture schema {image} into a colored image based on the context I provided "A beautiful woman Climbing the climbing wall, wearing a harness and climbing gear, skillfully maneuvering up the wall with her back to the camera, with a safety rope."',
|
174 |
-
3599999
|
175 |
-
],
|
176 |
-
[
|
177 |
-
'Controllable Generation',
|
178 |
-
download_image(
|
179 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/dcb2fc86f1ce.png?raw=true',
|
180 |
-
os.path.join(cache_dir,
|
181 |
-
'examples/dcb2fc86f1ce.jpg')), None, None,
|
182 |
-
'Transform and generate an image using mosaic {image} and "Monarch butterflies gracefully perch on vibrant purple flowers, showcasing their striking orange and black wings in a lush garden setting." description',
|
183 |
-
6666
|
184 |
-
],
|
185 |
-
[
|
186 |
-
'Controllable Generation',
|
187 |
-
download_image(
|
188 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/4cd4ee494962.png?raw=true',
|
189 |
-
os.path.join(cache_dir,
|
190 |
-
'examples/4cd4ee494962.jpg')), None, None,
|
191 |
-
'make this {image} colorful as per the "beautiful sunflowers"',
|
192 |
-
6666
|
193 |
-
],
|
194 |
-
[
|
195 |
-
'Controllable Generation',
|
196 |
-
download_image(
|
197 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/a47e3a9cd166.png?raw=true',
|
198 |
-
os.path.join(cache_dir,
|
199 |
-
'examples/a47e3a9cd166.jpg')), None, None,
|
200 |
-
'Take the edge conscious {image} and the written guideline "A whimsical animated character is depicted holding a delectable cake adorned with blue and white frosting and a drizzle of chocolate. The character wears a yellow headband with a bow, matching a cozy yellow sweater. Her dark hair is styled in a braid, tied with a yellow ribbon. With a golden fork in hand, she stands ready to enjoy a slice, exuding an air of joyful anticipation. The scene is creatively rendered with a charming and playful aesthetic." and produce a realistic image.',
|
201 |
-
613725
|
202 |
-
],
|
203 |
-
[
|
204 |
-
'Controllable Generation',
|
205 |
-
download_image(
|
206 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d890ed8a3ac2.png?raw=true',
|
207 |
-
os.path.join(cache_dir,
|
208 |
-
'examples/d890ed8a3ac2.jpg')), None, None,
|
209 |
-
'creating a vivid image based on {image} and description "This image features a delicious rectangular tart with a flaky, golden-brown crust. The tart is topped with evenly sliced tomatoes, layered over a creamy cheese filling. Aromatic herbs are sprinkled on top, adding a touch of green and enhancing the visual appeal. The background includes a soft, textured fabric and scattered white flowers, creating an elegant and inviting presentation. Bright red tomatoes in the upper right corner hint at the fresh ingredients used in the dish."',
|
210 |
-
6666
|
211 |
-
],
|
212 |
-
[
|
213 |
-
'Image Denoising',
|
214 |
-
download_image(
|
215 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/0844a686a179.png?raw=true',
|
216 |
-
os.path.join(cache_dir,
|
217 |
-
'examples/0844a686a179.jpg')), None, None,
|
218 |
-
'Eliminate noise interference in {image} and maximize the crispness to obtain superior high-definition quality',
|
219 |
-
6666
|
220 |
-
],
|
221 |
-
[
|
222 |
-
'Inpainting',
|
223 |
-
download_image(
|
224 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b.png?raw=true',
|
225 |
-
os.path.join(cache_dir, 'examples/fa91b6b7e59b.jpg')),
|
226 |
-
download_image(
|
227 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/fa91b6b7e59b_mask.png?raw=true',
|
228 |
-
os.path.join(cache_dir,
|
229 |
-
'examples/fa91b6b7e59b_mask.jpg')), None,
|
230 |
-
'Ensure to overhaul the parts of the {image} indicated by the mask.',
|
231 |
-
6666
|
232 |
-
],
|
233 |
-
[
|
234 |
-
'Inpainting',
|
235 |
-
download_image(
|
236 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26.png?raw=true',
|
237 |
-
os.path.join(cache_dir, 'examples/632899695b26.jpg')),
|
238 |
-
download_image(
|
239 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/632899695b26_mask.png?raw=true',
|
240 |
-
os.path.join(cache_dir,
|
241 |
-
'examples/632899695b26_mask.jpg')), None,
|
242 |
-
'Refashion the mask portion of {image} in accordance with "A yellow egg with a smiling face painted on it"',
|
243 |
-
6666
|
244 |
-
],
|
245 |
-
[
|
246 |
-
'General Editing',
|
247 |
-
download_image(
|
248 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/354d17594afe.png?raw=true',
|
249 |
-
os.path.join(cache_dir,
|
250 |
-
'examples/354d17594afe.jpg')), None, None,
|
251 |
-
'{image} change the dog\'s posture to walking in the water, and change the background to green plants and a pond.',
|
252 |
-
6666
|
253 |
-
],
|
254 |
-
[
|
255 |
-
'General Editing',
|
256 |
-
download_image(
|
257 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/38946455752b.png?raw=true',
|
258 |
-
os.path.join(cache_dir,
|
259 |
-
'examples/38946455752b.jpg')), None, None,
|
260 |
-
'{image} change the color of the dress from white to red and the model\'s hair color red brown to blonde.Other parts remain unchanged',
|
261 |
-
6669
|
262 |
-
],
|
263 |
-
[
|
264 |
-
'Facial Editing',
|
265 |
-
download_image(
|
266 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/3ba5202f0cd8.png?raw=true',
|
267 |
-
os.path.join(cache_dir,
|
268 |
-
'examples/3ba5202f0cd8.jpg')), None, None,
|
269 |
-
'Keep the same facial feature in @3ba5202f0cd8, change the woman\'s clothing from a Blue denim jacket to a white turtleneck sweater and adjust her posture so that she is supporting her chin with both hands. Other aspects, such as background, hairstyle, facial expression, etc, remain unchanged.',
|
270 |
-
99999
|
271 |
-
],
|
272 |
-
[
|
273 |
-
'Facial Editing',
|
274 |
-
download_image(
|
275 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/369365b94725.png?raw=true',
|
276 |
-
os.path.join(cache_dir, 'examples/369365b94725.jpg')), None,
|
277 |
-
None, '{image} Make her looking at the camera', 6666
|
278 |
-
],
|
279 |
-
[
|
280 |
-
'Facial Editing',
|
281 |
-
download_image(
|
282 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/92751f2e4a0e.png?raw=true',
|
283 |
-
os.path.join(cache_dir, 'examples/92751f2e4a0e.jpg')), None,
|
284 |
-
None, '{image} Remove the smile from his face', 9899999
|
285 |
-
],
|
286 |
-
[
|
287 |
-
'Remove Text',
|
288 |
-
download_image(
|
289 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/8530a6711b2e.png?raw=true',
|
290 |
-
os.path.join(cache_dir, 'examples/8530a6711b2e.jpg')), None,
|
291 |
-
None, 'Aim to remove any textual element in {image}', 6666
|
292 |
-
],
|
293 |
-
[
|
294 |
-
'Remove Text',
|
295 |
-
download_image(
|
296 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6.png?raw=true',
|
297 |
-
os.path.join(cache_dir, 'examples/c4d7fb28f8f6.jpg')),
|
298 |
-
download_image(
|
299 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/c4d7fb28f8f6_mask.png?raw=true',
|
300 |
-
os.path.join(cache_dir,
|
301 |
-
'examples/c4d7fb28f8f6_mask.jpg')), None,
|
302 |
-
'Rub out any text found in the mask sector of the {image}.', 6666
|
303 |
-
],
|
304 |
-
[
|
305 |
-
'Remove Object',
|
306 |
-
download_image(
|
307 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e2f318fa5e5b.png?raw=true',
|
308 |
-
os.path.join(cache_dir,
|
309 |
-
'examples/e2f318fa5e5b.jpg')), None, None,
|
310 |
-
'Remove the unicorn in this {image}, ensuring a smooth edit.',
|
311 |
-
99999
|
312 |
-
],
|
313 |
-
[
|
314 |
-
'Remove Object',
|
315 |
-
download_image(
|
316 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00.png?raw=true',
|
317 |
-
os.path.join(cache_dir, 'examples/1ae96d8aca00.jpg')),
|
318 |
-
download_image(
|
319 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/1ae96d8aca00_mask.png?raw=true',
|
320 |
-
os.path.join(cache_dir, 'examples/1ae96d8aca00_mask.jpg')),
|
321 |
-
None, 'Discard the contents of the mask area from {image}.', 99999
|
322 |
-
],
|
323 |
-
[
|
324 |
-
'Add Object',
|
325 |
-
download_image(
|
326 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511.png?raw=true',
|
327 |
-
os.path.join(cache_dir, 'examples/80289f48e511.jpg')),
|
328 |
-
download_image(
|
329 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/80289f48e511_mask.png?raw=true',
|
330 |
-
os.path.join(cache_dir,
|
331 |
-
'examples/80289f48e511_mask.jpg')), None,
|
332 |
-
'add a Hot Air Balloon into the {image}, per the mask', 613725
|
333 |
-
],
|
334 |
-
[
|
335 |
-
'Style Transfer',
|
336 |
-
download_image(
|
337 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/d725cb2009e8.png?raw=true',
|
338 |
-
os.path.join(cache_dir, 'examples/d725cb2009e8.jpg')), None,
|
339 |
-
None, 'Change the style of {image} to colored pencil style', 99999
|
340 |
-
],
|
341 |
-
[
|
342 |
-
'Style Transfer',
|
343 |
-
download_image(
|
344 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/e0f48b3fd010.png?raw=true',
|
345 |
-
os.path.join(cache_dir, 'examples/e0f48b3fd010.jpg')), None,
|
346 |
-
None, 'make {image} to Walt Disney Animation style', 99999
|
347 |
-
],
|
348 |
-
[
|
349 |
-
'Try On',
|
350 |
-
download_image(
|
351 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96.png?raw=true',
|
352 |
-
os.path.join(cache_dir, 'examples/ee4ca60b8c96.jpg')),
|
353 |
-
download_image(
|
354 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ee4ca60b8c96_mask.png?raw=true',
|
355 |
-
os.path.join(cache_dir, 'examples/ee4ca60b8c96_mask.jpg')),
|
356 |
-
download_image(
|
357 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/ebe825bbfe3c.png?raw=true',
|
358 |
-
os.path.join(cache_dir, 'examples/ebe825bbfe3c.jpg')),
|
359 |
-
'Change the cloth in {image} to the one in {image1}', 99999
|
360 |
-
],
|
361 |
-
[
|
362 |
-
'Workflow',
|
363 |
-
download_image(
|
364 |
-
'https://github.com/ali-vilab/ace-page/blob/main/assets/examples/cb85353c004b.png?raw=true',
|
365 |
-
os.path.join(cache_dir, 'examples/cb85353c004b.jpg')), None,
|
366 |
-
None, '<workflow> ice cream {image}', 99999
|
367 |
-
],
|
368 |
-
]
|
369 |
-
print('Finish. Start building UI ...')
|
370 |
-
return examples
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/__init__.py
DELETED
@@ -1,2 +0,0 @@
|
|
1 |
-
from .flux import Flux, ACEFlux
|
2 |
-
from .embedder import ACETextEmbedder, T5ACEPlusClipFluxEmbedder, ACEHFEmbedder
|
|
|
|
|
|
models/embedder.py
DELETED
@@ -1,383 +0,0 @@
|
|
1 |
-
# -*- coding: utf-8 -*-
|
2 |
-
# Copyright (c) Alibaba, Inc. and its affiliates.
|
3 |
-
import warnings
|
4 |
-
from contextlib import nullcontext
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch.nn.functional as F
|
8 |
-
import torch.utils.dlpack
|
9 |
-
import transformers
|
10 |
-
from scepter.modules.model.embedder.base_embedder import BaseEmbedder
|
11 |
-
from scepter.modules.model.registry import EMBEDDERS
|
12 |
-
from scepter.modules.model.tokenizer.tokenizer_component import (
|
13 |
-
basic_clean, canonicalize, heavy_clean, whitespace_clean)
|
14 |
-
from scepter.modules.utils.config import dict_to_yaml
|
15 |
-
from scepter.modules.utils.distribute import we
|
16 |
-
from scepter.modules.utils.file_system import FS
|
17 |
-
|
18 |
-
try:
|
19 |
-
from transformers import AutoTokenizer, T5EncoderModel
|
20 |
-
except Exception as e:
|
21 |
-
warnings.warn(
|
22 |
-
f'Import transformers error, please deal with this problem: {e}')
|
23 |
-
|
24 |
-
|
25 |
-
@EMBEDDERS.register_class()
|
26 |
-
class ACETextEmbedder(BaseEmbedder):
|
27 |
-
"""
|
28 |
-
Uses the OpenCLIP transformer encoder for text
|
29 |
-
"""
|
30 |
-
"""
|
31 |
-
Uses the OpenCLIP transformer encoder for text
|
32 |
-
"""
|
33 |
-
para_dict = {
|
34 |
-
'PRETRAINED_MODEL': {
|
35 |
-
'value':
|
36 |
-
'google/umt5-small',
|
37 |
-
'description':
|
38 |
-
'Pretrained Model for umt5, modelcard path or local path.'
|
39 |
-
},
|
40 |
-
'TOKENIZER_PATH': {
|
41 |
-
'value': 'google/umt5-small',
|
42 |
-
'description':
|
43 |
-
'Tokenizer Path for umt5, modelcard path or local path.'
|
44 |
-
},
|
45 |
-
'FREEZE': {
|
46 |
-
'value': True,
|
47 |
-
'description': ''
|
48 |
-
},
|
49 |
-
'USE_GRAD': {
|
50 |
-
'value': False,
|
51 |
-
'description': 'Compute grad or not.'
|
52 |
-
},
|
53 |
-
'CLEAN': {
|
54 |
-
'value':
|
55 |
-
'whitespace',
|
56 |
-
'description':
|
57 |
-
'Set the clean strtegy for tokenizer, used when TOKENIZER_PATH is not None.'
|
58 |
-
},
|
59 |
-
'LAYER': {
|
60 |
-
'value': 'last',
|
61 |
-
'description': ''
|
62 |
-
},
|
63 |
-
'LEGACY': {
|
64 |
-
'value':
|
65 |
-
True,
|
66 |
-
'description':
|
67 |
-
'Whether use legacy returnd feature or not ,default True.'
|
68 |
-
}
|
69 |
-
}
|
70 |
-
|
71 |
-
def __init__(self, cfg, logger=None):
|
72 |
-
super().__init__(cfg, logger=logger)
|
73 |
-
pretrained_path = cfg.get('PRETRAINED_MODEL', None)
|
74 |
-
self.t5_dtype = cfg.get('T5_DTYPE', 'float32')
|
75 |
-
assert pretrained_path
|
76 |
-
with FS.get_dir_to_local_dir(pretrained_path,
|
77 |
-
wait_finish=True) as local_path:
|
78 |
-
self.model = T5EncoderModel.from_pretrained(
|
79 |
-
local_path,
|
80 |
-
torch_dtype=getattr(
|
81 |
-
torch,
|
82 |
-
'float' if self.t5_dtype == 'float32' else self.t5_dtype))
|
83 |
-
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
|
84 |
-
self.length = cfg.get('LENGTH', 77)
|
85 |
-
|
86 |
-
self.use_grad = cfg.get('USE_GRAD', False)
|
87 |
-
self.clean = cfg.get('CLEAN', 'whitespace')
|
88 |
-
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
|
89 |
-
if tokenizer_path:
|
90 |
-
self.tokenize_kargs = {'return_tensors': 'pt'}
|
91 |
-
with FS.get_dir_to_local_dir(tokenizer_path,
|
92 |
-
wait_finish=True) as local_path:
|
93 |
-
if self.added_identifier is not None and isinstance(
|
94 |
-
self.added_identifier, list):
|
95 |
-
self.tokenizer = AutoTokenizer.from_pretrained(local_path)
|
96 |
-
else:
|
97 |
-
self.tokenizer = AutoTokenizer.from_pretrained(local_path)
|
98 |
-
if self.length is not None:
|
99 |
-
self.tokenize_kargs.update({
|
100 |
-
'padding': 'max_length',
|
101 |
-
'truncation': True,
|
102 |
-
'max_length': self.length
|
103 |
-
})
|
104 |
-
self.eos_token = self.tokenizer(
|
105 |
-
self.tokenizer.eos_token)['input_ids'][0]
|
106 |
-
else:
|
107 |
-
self.tokenizer = None
|
108 |
-
self.tokenize_kargs = {}
|
109 |
-
|
110 |
-
self.use_grad = cfg.get('USE_GRAD', False)
|
111 |
-
self.clean = cfg.get('CLEAN', 'whitespace')
|
112 |
-
|
113 |
-
def freeze(self):
|
114 |
-
self.model = self.model.eval()
|
115 |
-
for param in self.parameters():
|
116 |
-
param.requires_grad = False
|
117 |
-
|
118 |
-
# encode && encode_text
|
119 |
-
def forward(self, tokens, return_mask=False, use_mask=True):
|
120 |
-
# tokenization
|
121 |
-
embedding_context = nullcontext if self.use_grad else torch.no_grad
|
122 |
-
with embedding_context():
|
123 |
-
if use_mask:
|
124 |
-
x = self.model(tokens.input_ids.to(we.device_id),
|
125 |
-
tokens.attention_mask.to(we.device_id))
|
126 |
-
else:
|
127 |
-
x = self.model(tokens.input_ids.to(we.device_id))
|
128 |
-
x = x.last_hidden_state
|
129 |
-
|
130 |
-
if return_mask:
|
131 |
-
return x.detach() + 0.0, tokens.attention_mask.to(we.device_id)
|
132 |
-
else:
|
133 |
-
return x.detach() + 0.0, None
|
134 |
-
|
135 |
-
def _clean(self, text):
|
136 |
-
if self.clean == 'whitespace':
|
137 |
-
text = whitespace_clean(basic_clean(text))
|
138 |
-
elif self.clean == 'lower':
|
139 |
-
text = whitespace_clean(basic_clean(text)).lower()
|
140 |
-
elif self.clean == 'canonicalize':
|
141 |
-
text = canonicalize(basic_clean(text))
|
142 |
-
elif self.clean == 'heavy':
|
143 |
-
text = heavy_clean(basic_clean(text))
|
144 |
-
return text
|
145 |
-
|
146 |
-
def encode(self, text, return_mask=False, use_mask=True):
|
147 |
-
if isinstance(text, str):
|
148 |
-
text = [text]
|
149 |
-
if self.clean:
|
150 |
-
text = [self._clean(u) for u in text]
|
151 |
-
assert self.tokenizer is not None
|
152 |
-
cont, mask = [], []
|
153 |
-
with torch.autocast(device_type='cuda',
|
154 |
-
enabled=self.t5_dtype in ('float16', 'bfloat16'),
|
155 |
-
dtype=getattr(torch, self.t5_dtype)):
|
156 |
-
for tt in text:
|
157 |
-
tokens = self.tokenizer([tt], **self.tokenize_kargs)
|
158 |
-
one_cont, one_mask = self(tokens,
|
159 |
-
return_mask=return_mask,
|
160 |
-
use_mask=use_mask)
|
161 |
-
cont.append(one_cont)
|
162 |
-
mask.append(one_mask)
|
163 |
-
if return_mask:
|
164 |
-
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
|
165 |
-
else:
|
166 |
-
return torch.cat(cont, dim=0)
|
167 |
-
|
168 |
-
def encode_list(self, text_list, return_mask=True):
|
169 |
-
cont_list = []
|
170 |
-
mask_list = []
|
171 |
-
for pp in text_list:
|
172 |
-
cont, cont_mask = self.encode(pp, return_mask=return_mask)
|
173 |
-
cont_list.append(cont)
|
174 |
-
mask_list.append(cont_mask)
|
175 |
-
if return_mask:
|
176 |
-
return cont_list, mask_list
|
177 |
-
else:
|
178 |
-
return cont_list
|
179 |
-
|
180 |
-
@staticmethod
|
181 |
-
def get_config_template():
|
182 |
-
return dict_to_yaml('MODELS',
|
183 |
-
__class__.__name__,
|
184 |
-
ACETextEmbedder.para_dict,
|
185 |
-
set_name=True)
|
186 |
-
|
187 |
-
@EMBEDDERS.register_class()
|
188 |
-
class ACEHFEmbedder(BaseEmbedder):
|
189 |
-
para_dict = {
|
190 |
-
"HF_MODEL_CLS": {
|
191 |
-
"value": None,
|
192 |
-
"description": "huggingface cls in transfomer"
|
193 |
-
},
|
194 |
-
"MODEL_PATH": {
|
195 |
-
"value": None,
|
196 |
-
"description": "model folder path"
|
197 |
-
},
|
198 |
-
"HF_TOKENIZER_CLS": {
|
199 |
-
"value": None,
|
200 |
-
"description": "huggingface cls in transfomer"
|
201 |
-
},
|
202 |
-
|
203 |
-
"TOKENIZER_PATH": {
|
204 |
-
"value": None,
|
205 |
-
"description": "tokenizer folder path"
|
206 |
-
},
|
207 |
-
"MAX_LENGTH": {
|
208 |
-
"value": 77,
|
209 |
-
"description": "max length of input"
|
210 |
-
},
|
211 |
-
"OUTPUT_KEY": {
|
212 |
-
"value": "last_hidden_state",
|
213 |
-
"description": "output key"
|
214 |
-
},
|
215 |
-
"D_TYPE": {
|
216 |
-
"value": "float",
|
217 |
-
"description": "dtype"
|
218 |
-
},
|
219 |
-
"BATCH_INFER": {
|
220 |
-
"value": False,
|
221 |
-
"description": "batch infer"
|
222 |
-
}
|
223 |
-
}
|
224 |
-
para_dict.update(BaseEmbedder.para_dict)
|
225 |
-
def __init__(self, cfg, logger=None):
|
226 |
-
super().__init__(cfg, logger=logger)
|
227 |
-
hf_model_cls = cfg.get('HF_MODEL_CLS', None)
|
228 |
-
model_path = cfg.get("MODEL_PATH", None)
|
229 |
-
hf_tokenizer_cls = cfg.get('HF_TOKENIZER_CLS', None)
|
230 |
-
tokenizer_path = cfg.get('TOKENIZER_PATH', None)
|
231 |
-
self.max_length = cfg.get('MAX_LENGTH', 77)
|
232 |
-
self.output_key = cfg.get("OUTPUT_KEY", "last_hidden_state")
|
233 |
-
self.d_type = cfg.get("D_TYPE", "float")
|
234 |
-
self.clean = cfg.get("CLEAN", "whitespace")
|
235 |
-
self.batch_infer = cfg.get("BATCH_INFER", False)
|
236 |
-
self.added_identifier = cfg.get('ADDED_IDENTIFIER', None)
|
237 |
-
torch_dtype = getattr(torch, self.d_type)
|
238 |
-
|
239 |
-
assert hf_model_cls is not None and hf_tokenizer_cls is not None
|
240 |
-
assert model_path is not None and tokenizer_path is not None
|
241 |
-
with FS.get_dir_to_local_dir(tokenizer_path, wait_finish=True) as local_path:
|
242 |
-
self.tokenizer = getattr(transformers, hf_tokenizer_cls).from_pretrained(local_path,
|
243 |
-
max_length = self.max_length,
|
244 |
-
torch_dtype = torch_dtype,
|
245 |
-
additional_special_tokens=self.added_identifier)
|
246 |
-
|
247 |
-
with FS.get_dir_to_local_dir(model_path, wait_finish=True) as local_path:
|
248 |
-
self.hf_module = getattr(transformers, hf_model_cls).from_pretrained(local_path, torch_dtype = torch_dtype)
|
249 |
-
|
250 |
-
|
251 |
-
self.hf_module = self.hf_module.eval().requires_grad_(False)
|
252 |
-
|
253 |
-
def forward(self, text: list[str], return_mask = False):
|
254 |
-
batch_encoding = self.tokenizer(
|
255 |
-
text,
|
256 |
-
truncation=True,
|
257 |
-
max_length=self.max_length,
|
258 |
-
return_length=False,
|
259 |
-
return_overflowing_tokens=False,
|
260 |
-
padding="max_length",
|
261 |
-
return_tensors="pt",
|
262 |
-
)
|
263 |
-
|
264 |
-
outputs = self.hf_module(
|
265 |
-
input_ids=batch_encoding["input_ids"].to(self.hf_module.device),
|
266 |
-
attention_mask=None,
|
267 |
-
output_hidden_states=False,
|
268 |
-
)
|
269 |
-
if return_mask:
|
270 |
-
return outputs[self.output_key], batch_encoding['attention_mask'].to(self.hf_module.device)
|
271 |
-
else:
|
272 |
-
return outputs[self.output_key], None
|
273 |
-
|
274 |
-
def encode(self, text, return_mask = False):
|
275 |
-
if isinstance(text, str):
|
276 |
-
text = [text]
|
277 |
-
if self.clean:
|
278 |
-
text = [self._clean(u) for u in text]
|
279 |
-
if not self.batch_infer:
|
280 |
-
cont, mask = [], []
|
281 |
-
for tt in text:
|
282 |
-
one_cont, one_mask = self([tt], return_mask=return_mask)
|
283 |
-
cont.append(one_cont)
|
284 |
-
mask.append(one_mask)
|
285 |
-
if return_mask:
|
286 |
-
return torch.cat(cont, dim=0), torch.cat(mask, dim=0)
|
287 |
-
else:
|
288 |
-
return torch.cat(cont, dim=0)
|
289 |
-
else:
|
290 |
-
ret_data = self(text, return_mask = return_mask)
|
291 |
-
if return_mask:
|
292 |
-
return ret_data
|
293 |
-
else:
|
294 |
-
return ret_data[0]
|
295 |
-
|
296 |
-
def encode_list(self, text_list, return_mask=True):
|
297 |
-
cont_list = []
|
298 |
-
mask_list = []
|
299 |
-
for pp in text_list:
|
300 |
-
cont = self.encode(pp, return_mask=return_mask)
|
301 |
-
cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
|
302 |
-
mask_list.append(cont[1]) if return_mask else mask_list.append(None)
|
303 |
-
if return_mask:
|
304 |
-
return cont_list, mask_list
|
305 |
-
else:
|
306 |
-
return cont_list
|
307 |
-
|
308 |
-
def encode_list_of_list(self, text_list, return_mask=True):
|
309 |
-
cont_list = []
|
310 |
-
mask_list = []
|
311 |
-
for pp in text_list:
|
312 |
-
cont = self.encode_list(pp, return_mask=return_mask)
|
313 |
-
cont_list.append(cont[0]) if return_mask else cont_list.append(cont)
|
314 |
-
mask_list.append(cont[1]) if return_mask else mask_list.append(None)
|
315 |
-
if return_mask:
|
316 |
-
return cont_list, mask_list
|
317 |
-
else:
|
318 |
-
return cont_list
|
319 |
-
|
320 |
-
def _clean(self, text):
|
321 |
-
if self.clean == 'whitespace':
|
322 |
-
text = whitespace_clean(basic_clean(text))
|
323 |
-
elif self.clean == 'lower':
|
324 |
-
text = whitespace_clean(basic_clean(text)).lower()
|
325 |
-
elif self.clean == 'canonicalize':
|
326 |
-
text = canonicalize(basic_clean(text))
|
327 |
-
return text
|
328 |
-
@staticmethod
|
329 |
-
def get_config_template():
|
330 |
-
return dict_to_yaml('EMBEDDER',
|
331 |
-
__class__.__name__,
|
332 |
-
ACEHFEmbedder.para_dict,
|
333 |
-
set_name=True)
|
334 |
-
|
335 |
-
@EMBEDDERS.register_class()
|
336 |
-
class T5ACEPlusClipFluxEmbedder(BaseEmbedder):
|
337 |
-
"""
|
338 |
-
Uses the OpenCLIP transformer encoder for text
|
339 |
-
"""
|
340 |
-
para_dict = {
|
341 |
-
'T5_MODEL': {},
|
342 |
-
'CLIP_MODEL': {}
|
343 |
-
}
|
344 |
-
|
345 |
-
def __init__(self, cfg, logger=None):
|
346 |
-
super().__init__(cfg, logger=logger)
|
347 |
-
self.t5_model = EMBEDDERS.build(cfg.T5_MODEL, logger=logger)
|
348 |
-
self.clip_model = EMBEDDERS.build(cfg.CLIP_MODEL, logger=logger)
|
349 |
-
|
350 |
-
def encode(self, text, return_mask = False):
|
351 |
-
t5_embeds = self.t5_model.encode(text, return_mask = return_mask)
|
352 |
-
clip_embeds = self.clip_model.encode(text, return_mask = return_mask)
|
353 |
-
# change embedding strategy here
|
354 |
-
return {
|
355 |
-
'context': t5_embeds,
|
356 |
-
'y': clip_embeds,
|
357 |
-
}
|
358 |
-
|
359 |
-
def encode_list(self, text, return_mask = False):
|
360 |
-
t5_embeds = self.t5_model.encode_list(text, return_mask = return_mask)
|
361 |
-
clip_embeds = self.clip_model.encode_list(text, return_mask = return_mask)
|
362 |
-
# change embedding strategy here
|
363 |
-
return {
|
364 |
-
'context': t5_embeds,
|
365 |
-
'y': clip_embeds,
|
366 |
-
}
|
367 |
-
|
368 |
-
def encode_list_of_list(self, text, return_mask = False):
|
369 |
-
t5_embeds = self.t5_model.encode_list_of_list(text, return_mask = return_mask)
|
370 |
-
clip_embeds = self.clip_model.encode_list_of_list(text, return_mask = return_mask)
|
371 |
-
# change embedding strategy here
|
372 |
-
return {
|
373 |
-
'context': t5_embeds,
|
374 |
-
'y': clip_embeds,
|
375 |
-
}
|
376 |
-
|
377 |
-
|
378 |
-
@staticmethod
|
379 |
-
def get_config_template():
|
380 |
-
return dict_to_yaml('EMBEDDER',
|
381 |
-
__class__.__name__,
|
382 |
-
T5ACEPlusClipFluxEmbedder.para_dict,
|
383 |
-
set_name=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/flux.py
DELETED
@@ -1,798 +0,0 @@
|
|
1 |
-
import math, torch
|
2 |
-
from collections import OrderedDict
|
3 |
-
from functools import partial
|
4 |
-
from einops import rearrange, repeat
|
5 |
-
from scepter.modules.model.base_model import BaseModel
|
6 |
-
from scepter.modules.model.registry import BACKBONES
|
7 |
-
from scepter.modules.utils.config import dict_to_yaml
|
8 |
-
from scepter.modules.utils.distribute import we
|
9 |
-
from scepter.modules.utils.file_system import FS
|
10 |
-
from torch import Tensor, nn
|
11 |
-
from torch.nn.utils.rnn import pad_sequence
|
12 |
-
from torch.utils.checkpoint import checkpoint_sequential
|
13 |
-
|
14 |
-
from .layers import (DoubleStreamBlock, EmbedND, LastLayer,
|
15 |
-
MLPEmbedder, SingleStreamBlock,
|
16 |
-
timestep_embedding, DoubleStreamBlockACE, SingleStreamBlockACE)
|
17 |
-
|
18 |
-
@BACKBONES.register_class()
|
19 |
-
class Flux(BaseModel):
|
20 |
-
"""
|
21 |
-
Transformer backbone Diffusion model with RoPE.
|
22 |
-
"""
|
23 |
-
para_dict = {
|
24 |
-
"IN_CHANNELS": {
|
25 |
-
"value": 64,
|
26 |
-
"description": "model's input channels."
|
27 |
-
},
|
28 |
-
"OUT_CHANNELS": {
|
29 |
-
"value": 64,
|
30 |
-
"description": "model's output channels."
|
31 |
-
},
|
32 |
-
"HIDDEN_SIZE": {
|
33 |
-
"value": 1024,
|
34 |
-
"description": "model's hidden size."
|
35 |
-
},
|
36 |
-
"NUM_HEADS": {
|
37 |
-
"value": 16,
|
38 |
-
"description": "number of heads in the transformer."
|
39 |
-
},
|
40 |
-
"AXES_DIM": {
|
41 |
-
"value": [16, 56, 56],
|
42 |
-
"description": "dimensions of the axes of the positional encoding."
|
43 |
-
},
|
44 |
-
"THETA": {
|
45 |
-
"value": 10_000,
|
46 |
-
"description": "theta for positional encoding."
|
47 |
-
},
|
48 |
-
"VEC_IN_DIM": {
|
49 |
-
"value": 768,
|
50 |
-
"description": "dimension of the vector input."
|
51 |
-
},
|
52 |
-
"GUIDANCE_EMBED": {
|
53 |
-
"value": False,
|
54 |
-
"description": "whether to use guidance embedding."
|
55 |
-
},
|
56 |
-
"CONTEXT_IN_DIM": {
|
57 |
-
"value": 4096,
|
58 |
-
"description": "dimension of the context input."
|
59 |
-
},
|
60 |
-
"MLP_RATIO": {
|
61 |
-
"value": 4.0,
|
62 |
-
"description": "ratio of mlp hidden size to hidden size."
|
63 |
-
},
|
64 |
-
"QKV_BIAS": {
|
65 |
-
"value": True,
|
66 |
-
"description": "whether to use bias in qkv projection."
|
67 |
-
},
|
68 |
-
"DEPTH": {
|
69 |
-
"value": 19,
|
70 |
-
"description": "number of transformer blocks."
|
71 |
-
},
|
72 |
-
"DEPTH_SINGLE_BLOCKS": {
|
73 |
-
"value": 38,
|
74 |
-
"description": "number of transformer blocks in the single stream block."
|
75 |
-
},
|
76 |
-
"USE_GRAD_CHECKPOINT": {
|
77 |
-
"value": False,
|
78 |
-
"description": "whether to use gradient checkpointing."
|
79 |
-
},
|
80 |
-
"ATTN_BACKEND": {
|
81 |
-
"value": "pytorch",
|
82 |
-
"description": "backend for the transformer blocks, 'pytorch' or 'flash_attn'."
|
83 |
-
}
|
84 |
-
}
|
85 |
-
def __init__(
|
86 |
-
self,
|
87 |
-
cfg,
|
88 |
-
logger = None
|
89 |
-
):
|
90 |
-
super().__init__(cfg, logger=logger)
|
91 |
-
self.in_channels = cfg.IN_CHANNELS
|
92 |
-
self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
|
93 |
-
hidden_size = cfg.get("HIDDEN_SIZE", 1024)
|
94 |
-
num_heads = cfg.get("NUM_HEADS", 16)
|
95 |
-
axes_dim = cfg.AXES_DIM
|
96 |
-
theta = cfg.THETA
|
97 |
-
vec_in_dim = cfg.VEC_IN_DIM
|
98 |
-
self.guidance_embed = cfg.GUIDANCE_EMBED
|
99 |
-
context_in_dim = cfg.CONTEXT_IN_DIM
|
100 |
-
mlp_ratio = cfg.MLP_RATIO
|
101 |
-
qkv_bias = cfg.QKV_BIAS
|
102 |
-
depth = cfg.DEPTH
|
103 |
-
depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
|
104 |
-
self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
|
105 |
-
self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
|
106 |
-
self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
|
107 |
-
self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
|
108 |
-
self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
|
109 |
-
self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
|
110 |
-
|
111 |
-
if hidden_size % num_heads != 0:
|
112 |
-
raise ValueError(
|
113 |
-
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
114 |
-
)
|
115 |
-
pe_dim = hidden_size // num_heads
|
116 |
-
if sum(axes_dim) != pe_dim:
|
117 |
-
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
118 |
-
self.hidden_size = hidden_size
|
119 |
-
self.num_heads = num_heads
|
120 |
-
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim= axes_dim)
|
121 |
-
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
122 |
-
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
123 |
-
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
|
124 |
-
self.guidance_in = (
|
125 |
-
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
|
126 |
-
)
|
127 |
-
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
|
128 |
-
|
129 |
-
self.double_blocks = nn.ModuleList(
|
130 |
-
[
|
131 |
-
DoubleStreamBlock(
|
132 |
-
self.hidden_size,
|
133 |
-
self.num_heads,
|
134 |
-
mlp_ratio=mlp_ratio,
|
135 |
-
qkv_bias=qkv_bias,
|
136 |
-
backend=self.attn_backend
|
137 |
-
)
|
138 |
-
for _ in range(depth)
|
139 |
-
]
|
140 |
-
)
|
141 |
-
|
142 |
-
self.single_blocks = nn.ModuleList(
|
143 |
-
[
|
144 |
-
SingleStreamBlock(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
|
145 |
-
for _ in range(depth_single_blocks)
|
146 |
-
]
|
147 |
-
)
|
148 |
-
|
149 |
-
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
150 |
-
|
151 |
-
def prepare_input(self, x, context, y, x_shape=None):
|
152 |
-
# x.shape [6, 16, 16, 16] target is [6, 16, 768, 1360]
|
153 |
-
bs, c, h, w = x.shape
|
154 |
-
x = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)
|
155 |
-
x_id = torch.zeros(h // 2, w // 2, 3)
|
156 |
-
x_id[..., 1] = x_id[..., 1] + torch.arange(h // 2)[:, None]
|
157 |
-
x_id[..., 2] = x_id[..., 2] + torch.arange(w // 2)[None, :]
|
158 |
-
x_ids = repeat(x_id, "h w c -> b (h w) c", b=bs)
|
159 |
-
txt_ids = torch.zeros(bs, context.shape[1], 3)
|
160 |
-
return x, x_ids.to(x), context.to(x), txt_ids.to(x), y.to(x), h, w
|
161 |
-
|
162 |
-
def unpack(self, x: Tensor, height: int, width: int) -> Tensor:
|
163 |
-
return rearrange(
|
164 |
-
x,
|
165 |
-
"b (h w) (c ph pw) -> b c (h ph) (w pw)",
|
166 |
-
h=math.ceil(height/2),
|
167 |
-
w=math.ceil(width/2),
|
168 |
-
ph=2,
|
169 |
-
pw=2,
|
170 |
-
)
|
171 |
-
|
172 |
-
# def merge_diffuser_lora(self, ori_sd, lora_sd, scale = 1.0):
|
173 |
-
# key_map = {
|
174 |
-
# "single_blocks.{}.linear1.weight": {"key_list": [
|
175 |
-
# ["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
|
176 |
-
# "transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight"],
|
177 |
-
# ["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
|
178 |
-
# "transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight"],
|
179 |
-
# ["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
|
180 |
-
# "transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight"],
|
181 |
-
# ["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
|
182 |
-
# "transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight"]
|
183 |
-
# ], "num": 38},
|
184 |
-
# "single_blocks.{}.modulation.lin.weight": {"key_list": [
|
185 |
-
# ["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
|
186 |
-
# "transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight"],
|
187 |
-
# ], "num": 38},
|
188 |
-
# "single_blocks.{}.linear2.weight": {"key_list": [
|
189 |
-
# ["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
|
190 |
-
# "transformer.single_transformer_blocks.{}.proj_out.lora_B.weight"],
|
191 |
-
# ], "num": 38},
|
192 |
-
# "double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
|
193 |
-
# ["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
|
194 |
-
# "transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight"],
|
195 |
-
# ["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
|
196 |
-
# "transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight"],
|
197 |
-
# ["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
|
198 |
-
# "transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight"],
|
199 |
-
# ], "num": 19},
|
200 |
-
# "double_blocks.{}.img_attn.qkv.weight": {"key_list": [
|
201 |
-
# ["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
|
202 |
-
# "transformer.transformer_blocks.{}.attn.to_q.lora_B.weight"],
|
203 |
-
# ["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
|
204 |
-
# "transformer.transformer_blocks.{}.attn.to_k.lora_B.weight"],
|
205 |
-
# ["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
|
206 |
-
# "transformer.transformer_blocks.{}.attn.to_v.lora_B.weight"],
|
207 |
-
# ], "num": 19},
|
208 |
-
# "double_blocks.{}.img_attn.proj.weight": {"key_list": [
|
209 |
-
# ["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
|
210 |
-
# "transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight"]
|
211 |
-
# ], "num": 19},
|
212 |
-
# "double_blocks.{}.txt_attn.proj.weight": {"key_list": [
|
213 |
-
# ["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
|
214 |
-
# "transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight"]
|
215 |
-
# ], "num": 19},
|
216 |
-
# "double_blocks.{}.img_mlp.0.weight": {"key_list": [
|
217 |
-
# ["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
|
218 |
-
# "transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight"]
|
219 |
-
# ], "num": 19},
|
220 |
-
# "double_blocks.{}.img_mlp.2.weight": {"key_list": [
|
221 |
-
# ["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
|
222 |
-
# "transformer.transformer_blocks.{}.ff.net.2.lora_B.weight"]
|
223 |
-
# ], "num": 19},
|
224 |
-
# "double_blocks.{}.txt_mlp.0.weight": {"key_list": [
|
225 |
-
# ["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
|
226 |
-
# "transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight"]
|
227 |
-
# ], "num": 19},
|
228 |
-
# "double_blocks.{}.txt_mlp.2.weight": {"key_list": [
|
229 |
-
# ["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
|
230 |
-
# "transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight"]
|
231 |
-
# ], "num": 19},
|
232 |
-
# "double_blocks.{}.img_mod.lin.weight": {"key_list": [
|
233 |
-
# ["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
|
234 |
-
# "transformer.transformer_blocks.{}.norm1.linear.lora_B.weight"]
|
235 |
-
# ], "num": 19},
|
236 |
-
# "double_blocks.{}.txt_mod.lin.weight": {"key_list": [
|
237 |
-
# ["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
|
238 |
-
# "transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight"]
|
239 |
-
# ], "num": 19}
|
240 |
-
# }
|
241 |
-
# have_lora_keys = 0
|
242 |
-
# for k, v in key_map.items():
|
243 |
-
# key_list = v["key_list"]
|
244 |
-
# block_num = v["num"]
|
245 |
-
# for block_id in range(block_num):
|
246 |
-
# current_weight_list = []
|
247 |
-
# for k_list in key_list:
|
248 |
-
# current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
|
249 |
-
# lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
|
250 |
-
# current_weight_list.append(current_weight)
|
251 |
-
# current_weight = torch.cat(current_weight_list, dim=0)
|
252 |
-
# ori_sd[k.format(block_id)] += scale*current_weight
|
253 |
-
# have_lora_keys += 1
|
254 |
-
# self.logger.info(f"merge_swift_lora loads lora'parameters {have_lora_keys}")
|
255 |
-
# return ori_sd
|
256 |
-
|
257 |
-
def merge_diffuser_lora(self, ori_sd, lora_sd, scale=1.0):
|
258 |
-
key_map = {
|
259 |
-
"single_blocks.{}.linear1.weight": {"key_list": [
|
260 |
-
["transformer.single_transformer_blocks.{}.attn.to_q.lora_A.weight",
|
261 |
-
"transformer.single_transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
|
262 |
-
["transformer.single_transformer_blocks.{}.attn.to_k.lora_A.weight",
|
263 |
-
"transformer.single_transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
|
264 |
-
["transformer.single_transformer_blocks.{}.attn.to_v.lora_A.weight",
|
265 |
-
"transformer.single_transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
|
266 |
-
["transformer.single_transformer_blocks.{}.proj_mlp.lora_A.weight",
|
267 |
-
"transformer.single_transformer_blocks.{}.proj_mlp.lora_B.weight", [9216, 21504]]
|
268 |
-
], "num": 38},
|
269 |
-
"single_blocks.{}.modulation.lin.weight": {"key_list": [
|
270 |
-
["transformer.single_transformer_blocks.{}.norm.linear.lora_A.weight",
|
271 |
-
"transformer.single_transformer_blocks.{}.norm.linear.lora_B.weight", [0, 9216]],
|
272 |
-
], "num": 38},
|
273 |
-
"single_blocks.{}.linear2.weight": {"key_list": [
|
274 |
-
["transformer.single_transformer_blocks.{}.proj_out.lora_A.weight",
|
275 |
-
"transformer.single_transformer_blocks.{}.proj_out.lora_B.weight", [0, 3072]],
|
276 |
-
], "num": 38},
|
277 |
-
"double_blocks.{}.txt_attn.qkv.weight": {"key_list": [
|
278 |
-
["transformer.transformer_blocks.{}.attn.add_q_proj.lora_A.weight",
|
279 |
-
"transformer.transformer_blocks.{}.attn.add_q_proj.lora_B.weight", [0, 3072]],
|
280 |
-
["transformer.transformer_blocks.{}.attn.add_k_proj.lora_A.weight",
|
281 |
-
"transformer.transformer_blocks.{}.attn.add_k_proj.lora_B.weight", [3072, 6144]],
|
282 |
-
["transformer.transformer_blocks.{}.attn.add_v_proj.lora_A.weight",
|
283 |
-
"transformer.transformer_blocks.{}.attn.add_v_proj.lora_B.weight", [6144, 9216]],
|
284 |
-
], "num": 19},
|
285 |
-
"double_blocks.{}.img_attn.qkv.weight": {"key_list": [
|
286 |
-
["transformer.transformer_blocks.{}.attn.to_q.lora_A.weight",
|
287 |
-
"transformer.transformer_blocks.{}.attn.to_q.lora_B.weight", [0, 3072]],
|
288 |
-
["transformer.transformer_blocks.{}.attn.to_k.lora_A.weight",
|
289 |
-
"transformer.transformer_blocks.{}.attn.to_k.lora_B.weight", [3072, 6144]],
|
290 |
-
["transformer.transformer_blocks.{}.attn.to_v.lora_A.weight",
|
291 |
-
"transformer.transformer_blocks.{}.attn.to_v.lora_B.weight", [6144, 9216]],
|
292 |
-
], "num": 19},
|
293 |
-
"double_blocks.{}.img_attn.proj.weight": {"key_list": [
|
294 |
-
["transformer.transformer_blocks.{}.attn.to_out.0.lora_A.weight",
|
295 |
-
"transformer.transformer_blocks.{}.attn.to_out.0.lora_B.weight", [0, 3072]]
|
296 |
-
], "num": 19},
|
297 |
-
"double_blocks.{}.txt_attn.proj.weight": {"key_list": [
|
298 |
-
["transformer.transformer_blocks.{}.attn.to_add_out.lora_A.weight",
|
299 |
-
"transformer.transformer_blocks.{}.attn.to_add_out.lora_B.weight", [0, 3072]]
|
300 |
-
], "num": 19},
|
301 |
-
"double_blocks.{}.img_mlp.0.weight": {"key_list": [
|
302 |
-
["transformer.transformer_blocks.{}.ff.net.0.proj.lora_A.weight",
|
303 |
-
"transformer.transformer_blocks.{}.ff.net.0.proj.lora_B.weight", [0, 12288]]
|
304 |
-
], "num": 19},
|
305 |
-
"double_blocks.{}.img_mlp.2.weight": {"key_list": [
|
306 |
-
["transformer.transformer_blocks.{}.ff.net.2.lora_A.weight",
|
307 |
-
"transformer.transformer_blocks.{}.ff.net.2.lora_B.weight", [0, 3072]]
|
308 |
-
], "num": 19},
|
309 |
-
"double_blocks.{}.txt_mlp.0.weight": {"key_list": [
|
310 |
-
["transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_A.weight",
|
311 |
-
"transformer.transformer_blocks.{}.ff_context.net.0.proj.lora_B.weight", [0, 12288]]
|
312 |
-
], "num": 19},
|
313 |
-
"double_blocks.{}.txt_mlp.2.weight": {"key_list": [
|
314 |
-
["transformer.transformer_blocks.{}.ff_context.net.2.lora_A.weight",
|
315 |
-
"transformer.transformer_blocks.{}.ff_context.net.2.lora_B.weight", [0, 3072]]
|
316 |
-
], "num": 19},
|
317 |
-
"double_blocks.{}.img_mod.lin.weight": {"key_list": [
|
318 |
-
["transformer.transformer_blocks.{}.norm1.linear.lora_A.weight",
|
319 |
-
"transformer.transformer_blocks.{}.norm1.linear.lora_B.weight", [0, 18432]]
|
320 |
-
], "num": 19},
|
321 |
-
"double_blocks.{}.txt_mod.lin.weight": {"key_list": [
|
322 |
-
["transformer.transformer_blocks.{}.norm1_context.linear.lora_A.weight",
|
323 |
-
"transformer.transformer_blocks.{}.norm1_context.linear.lora_B.weight", [0, 18432]]
|
324 |
-
], "num": 19}
|
325 |
-
}
|
326 |
-
cover_lora_keys = set()
|
327 |
-
cover_ori_keys = set()
|
328 |
-
for k, v in key_map.items():
|
329 |
-
key_list = v["key_list"]
|
330 |
-
block_num = v["num"]
|
331 |
-
for block_id in range(block_num):
|
332 |
-
for k_list in key_list:
|
333 |
-
if k_list[0].format(block_id) in lora_sd and k_list[1].format(block_id) in lora_sd:
|
334 |
-
cover_lora_keys.add(k_list[0].format(block_id))
|
335 |
-
cover_lora_keys.add(k_list[1].format(block_id))
|
336 |
-
current_weight = torch.matmul(lora_sd[k_list[0].format(block_id)].permute(1, 0),
|
337 |
-
lora_sd[k_list[1].format(block_id)].permute(1, 0)).permute(1, 0)
|
338 |
-
ori_sd[k.format(block_id)][k_list[2][0]:k_list[2][1], ...] += scale * current_weight
|
339 |
-
cover_ori_keys.add(k.format(block_id))
|
340 |
-
# lora_sd.pop(k_list[0].format(block_id))
|
341 |
-
# lora_sd.pop(k_list[1].format(block_id))
|
342 |
-
self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
|
343 |
-
f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
|
344 |
-
f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
|
345 |
-
return ori_sd
|
346 |
-
|
347 |
-
def merge_swift_lora(self, ori_sd, lora_sd, scale = 1.0):
|
348 |
-
have_lora_keys = {}
|
349 |
-
for k, v in lora_sd.items():
|
350 |
-
k = k[len("model."):] if k.startswith("model.") else k
|
351 |
-
ori_key = k.split("lora")[0] + "weight"
|
352 |
-
if ori_key not in ori_sd:
|
353 |
-
raise f"{ori_key} should in the original statedict"
|
354 |
-
if ori_key not in have_lora_keys:
|
355 |
-
have_lora_keys[ori_key] = {}
|
356 |
-
if "lora_A" in k:
|
357 |
-
have_lora_keys[ori_key]["lora_A"] = v
|
358 |
-
elif "lora_B" in k:
|
359 |
-
have_lora_keys[ori_key]["lora_B"] = v
|
360 |
-
else:
|
361 |
-
raise NotImplementedError
|
362 |
-
self.logger.info(f"merge_swift_lora loads lora'parameters {len(have_lora_keys)}")
|
363 |
-
for key, v in have_lora_keys.items():
|
364 |
-
current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
|
365 |
-
ori_sd[key] += scale * current_weight
|
366 |
-
return ori_sd
|
367 |
-
|
368 |
-
|
369 |
-
def merge_blackforest_lora(self, ori_sd, lora_sd, scale = 1.0):
|
370 |
-
have_lora_keys = {}
|
371 |
-
cover_lora_keys = set()
|
372 |
-
cover_ori_keys = set()
|
373 |
-
for k, v in lora_sd.items():
|
374 |
-
if "lora" in k:
|
375 |
-
ori_key = k.split("lora")[0] + "weight"
|
376 |
-
if ori_key not in ori_sd:
|
377 |
-
raise f"{ori_key} should in the original statedict"
|
378 |
-
if ori_key not in have_lora_keys:
|
379 |
-
have_lora_keys[ori_key] = {}
|
380 |
-
if "lora_A" in k:
|
381 |
-
have_lora_keys[ori_key]["lora_A"] = v
|
382 |
-
cover_lora_keys.add(k)
|
383 |
-
cover_ori_keys.add(ori_key)
|
384 |
-
elif "lora_B" in k:
|
385 |
-
have_lora_keys[ori_key]["lora_B"] = v
|
386 |
-
cover_lora_keys.add(k)
|
387 |
-
cover_ori_keys.add(ori_key)
|
388 |
-
else:
|
389 |
-
if k in ori_sd:
|
390 |
-
ori_sd[k] = v
|
391 |
-
cover_lora_keys.add(k)
|
392 |
-
cover_ori_keys.add(k)
|
393 |
-
else:
|
394 |
-
print("unsurpport keys: ", k)
|
395 |
-
self.logger.info(f"merge_blackforest_lora loads lora'parameters lora-paras: \n"
|
396 |
-
f"cover-{len(cover_lora_keys)} vs total {len(lora_sd)} \n"
|
397 |
-
f"cover ori-{len(cover_ori_keys)} vs total {len(ori_sd)}")
|
398 |
-
|
399 |
-
for key, v in have_lora_keys.items():
|
400 |
-
current_weight = torch.matmul(v["lora_A"].permute(1, 0), v["lora_B"].permute(1, 0)).permute(1, 0)
|
401 |
-
# print(key, ori_sd[key].shape, current_weight.shape)
|
402 |
-
ori_sd[key] += scale * current_weight
|
403 |
-
return ori_sd
|
404 |
-
|
405 |
-
def load_pretrained_model(self, pretrained_model):
|
406 |
-
if next(self.parameters()).device.type == 'meta':
|
407 |
-
map_location = torch.device(we.device_id)
|
408 |
-
safe_device = we.device_id
|
409 |
-
else:
|
410 |
-
map_location = "cpu"
|
411 |
-
safe_device = "cpu"
|
412 |
-
|
413 |
-
if pretrained_model is not None:
|
414 |
-
with FS.get_from(pretrained_model, wait_finish=True) as local_model:
|
415 |
-
if local_model.endswith('safetensors'):
|
416 |
-
from safetensors.torch import load_file as load_safetensors
|
417 |
-
sd = load_safetensors(local_model, device=safe_device)
|
418 |
-
else:
|
419 |
-
sd = torch.load(local_model, map_location=map_location, weights_only=True)
|
420 |
-
if "state_dict" in sd:
|
421 |
-
sd = sd["state_dict"]
|
422 |
-
if "model" in sd:
|
423 |
-
sd = sd["model"]["model"]
|
424 |
-
|
425 |
-
|
426 |
-
new_ckpt = OrderedDict()
|
427 |
-
for k, v in sd.items():
|
428 |
-
if k in ("img_in.weight"):
|
429 |
-
model_p = self.state_dict()[k]
|
430 |
-
if v.shape != model_p.shape:
|
431 |
-
expanded_state_dict_weight = torch.zeros_like(model_p, device=v.device)
|
432 |
-
slices = tuple(slice(0, dim) for dim in v.shape)
|
433 |
-
expanded_state_dict_weight[slices] = v
|
434 |
-
new_ckpt[k] = expanded_state_dict_weight
|
435 |
-
else:
|
436 |
-
new_ckpt[k] = v
|
437 |
-
else:
|
438 |
-
new_ckpt[k] = v
|
439 |
-
|
440 |
-
|
441 |
-
if self.lora_model is not None:
|
442 |
-
with FS.get_from(self.lora_model, wait_finish=True) as local_model:
|
443 |
-
if local_model.endswith('safetensors'):
|
444 |
-
from safetensors.torch import load_file as load_safetensors
|
445 |
-
lora_sd = load_safetensors(local_model, device=safe_device)
|
446 |
-
else:
|
447 |
-
lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
|
448 |
-
new_ckpt = self.merge_diffuser_lora(new_ckpt, lora_sd)
|
449 |
-
if self.swift_lora_model is not None:
|
450 |
-
if not isinstance(self.swift_lora_model, list):
|
451 |
-
self.swift_lora_model = [self.swift_lora_model]
|
452 |
-
for lora_model in self.swift_lora_model:
|
453 |
-
self.logger.info(f"load swift lora model: {lora_model}")
|
454 |
-
with FS.get_from(lora_model, wait_finish=True) as local_model:
|
455 |
-
if local_model.endswith('safetensors'):
|
456 |
-
from safetensors.torch import load_file as load_safetensors
|
457 |
-
lora_sd = load_safetensors(local_model, device=safe_device)
|
458 |
-
else:
|
459 |
-
lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
|
460 |
-
new_ckpt = self.merge_swift_lora(new_ckpt, lora_sd)
|
461 |
-
if self.blackforest_lora_model is not None:
|
462 |
-
|
463 |
-
with FS.get_from(self.blackforest_lora_model, wait_finish=True) as local_model:
|
464 |
-
if local_model.endswith('safetensors'):
|
465 |
-
from safetensors.torch import load_file as load_safetensors
|
466 |
-
lora_sd = load_safetensors(local_model, device=safe_device)
|
467 |
-
else:
|
468 |
-
lora_sd = torch.load(local_model, map_location=map_location, weights_only=True)
|
469 |
-
new_ckpt = self.merge_blackforest_lora(new_ckpt, lora_sd)
|
470 |
-
|
471 |
-
|
472 |
-
adapter_ckpt = {}
|
473 |
-
if self.pretrain_adapter is not None:
|
474 |
-
with FS.get_from(self.pretrain_adapter, wait_finish=True) as local_adapter:
|
475 |
-
if local_adapter.endswith('safetensors'):
|
476 |
-
from safetensors.torch import load_file as load_safetensors
|
477 |
-
adapter_ckpt = load_safetensors(local_adapter, device=safe_device)
|
478 |
-
else:
|
479 |
-
adapter_ckpt = torch.load(local_adapter, map_location=map_location, weights_only=True)
|
480 |
-
new_ckpt.update(adapter_ckpt)
|
481 |
-
|
482 |
-
missing, unexpected = self.load_state_dict(new_ckpt, strict=False, assign=True)
|
483 |
-
self.logger.info(
|
484 |
-
f'Restored from {pretrained_model} with {len(missing)} missing and {len(unexpected)} unexpected keys'
|
485 |
-
)
|
486 |
-
if len(missing) > 0:
|
487 |
-
self.logger.info(f'Missing Keys:\n {missing}')
|
488 |
-
if len(unexpected) > 0:
|
489 |
-
self.logger.info(f'\nUnexpected Keys:\n {unexpected}')
|
490 |
-
|
491 |
-
def forward(
|
492 |
-
self,
|
493 |
-
x: Tensor,
|
494 |
-
t: Tensor,
|
495 |
-
cond: dict = {},
|
496 |
-
guidance: Tensor | None = None,
|
497 |
-
gc_seg: int = 0
|
498 |
-
) -> Tensor:
|
499 |
-
x, x_ids, txt, txt_ids, y, h, w = self.prepare_input(x, cond["context"], cond["y"])
|
500 |
-
# running on sequences img
|
501 |
-
x = self.img_in(x)
|
502 |
-
vec = self.time_in(timestep_embedding(t, 256))
|
503 |
-
if self.guidance_embed:
|
504 |
-
if guidance is None:
|
505 |
-
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
506 |
-
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
507 |
-
vec = vec + self.vector_in(y)
|
508 |
-
txt = self.txt_in(txt)
|
509 |
-
ids = torch.cat((txt_ids, x_ids), dim=1)
|
510 |
-
pe = self.pe_embedder(ids)
|
511 |
-
kwargs = dict(
|
512 |
-
vec=vec,
|
513 |
-
pe=pe,
|
514 |
-
txt_length=txt.shape[1],
|
515 |
-
)
|
516 |
-
x = torch.cat((txt, x), 1)
|
517 |
-
if self.use_grad_checkpoint and gc_seg >= 0:
|
518 |
-
x = checkpoint_sequential(
|
519 |
-
functions=[partial(block, **kwargs) for block in self.double_blocks],
|
520 |
-
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
|
521 |
-
input=x,
|
522 |
-
use_reentrant=False
|
523 |
-
)
|
524 |
-
else:
|
525 |
-
for block in self.double_blocks:
|
526 |
-
x = block(x, **kwargs)
|
527 |
-
|
528 |
-
kwargs = dict(
|
529 |
-
vec=vec,
|
530 |
-
pe=pe,
|
531 |
-
)
|
532 |
-
|
533 |
-
if self.use_grad_checkpoint and gc_seg >= 0:
|
534 |
-
x = checkpoint_sequential(
|
535 |
-
functions=[partial(block, **kwargs) for block in self.single_blocks],
|
536 |
-
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
|
537 |
-
input=x,
|
538 |
-
use_reentrant=False
|
539 |
-
)
|
540 |
-
else:
|
541 |
-
for block in self.single_blocks:
|
542 |
-
x = block(x, **kwargs)
|
543 |
-
x = x[:, txt.shape[1] :, ...]
|
544 |
-
x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
|
545 |
-
x = self.unpack(x, h, w)
|
546 |
-
return x
|
547 |
-
|
548 |
-
@staticmethod
|
549 |
-
def get_config_template():
|
550 |
-
return dict_to_yaml('MODEL',
|
551 |
-
__class__.__name__,
|
552 |
-
Flux.para_dict,
|
553 |
-
set_name=True)
|
554 |
-
@BACKBONES.register_class()
|
555 |
-
class ACEFlux(Flux):
|
556 |
-
'''
|
557 |
-
cat[x_seq, edit_seq]
|
558 |
-
pe[x_seq] pe[edit_seq]
|
559 |
-
'''
|
560 |
-
|
561 |
-
def __init__(
|
562 |
-
self,
|
563 |
-
cfg,
|
564 |
-
logger=None
|
565 |
-
):
|
566 |
-
super().__init__(cfg, logger=logger)
|
567 |
-
self.in_channels = cfg.IN_CHANNELS
|
568 |
-
self.out_channels = cfg.get("OUT_CHANNELS", self.in_channels)
|
569 |
-
hidden_size = cfg.get("HIDDEN_SIZE", 1024)
|
570 |
-
num_heads = cfg.get("NUM_HEADS", 16)
|
571 |
-
axes_dim = cfg.AXES_DIM
|
572 |
-
theta = cfg.THETA
|
573 |
-
vec_in_dim = cfg.VEC_IN_DIM
|
574 |
-
self.guidance_embed = cfg.GUIDANCE_EMBED
|
575 |
-
context_in_dim = cfg.CONTEXT_IN_DIM
|
576 |
-
mlp_ratio = cfg.MLP_RATIO
|
577 |
-
qkv_bias = cfg.QKV_BIAS
|
578 |
-
depth = cfg.DEPTH
|
579 |
-
depth_single_blocks = cfg.DEPTH_SINGLE_BLOCKS
|
580 |
-
self.use_grad_checkpoint = cfg.get("USE_GRAD_CHECKPOINT", False)
|
581 |
-
self.attn_backend = cfg.get("ATTN_BACKEND", "pytorch")
|
582 |
-
self.lora_model = cfg.get("DIFFUSERS_LORA_MODEL", None)
|
583 |
-
self.swift_lora_model = cfg.get("SWIFT_LORA_MODEL", None)
|
584 |
-
self.blackforest_lora_model = cfg.get("BLACKFOREST_LORA_MODEL", None)
|
585 |
-
self.pretrain_adapter = cfg.get("PRETRAIN_ADAPTER", None)
|
586 |
-
|
587 |
-
if hidden_size % num_heads != 0:
|
588 |
-
raise ValueError(
|
589 |
-
f"Hidden size {hidden_size} must be divisible by num_heads {num_heads}"
|
590 |
-
)
|
591 |
-
pe_dim = hidden_size // num_heads
|
592 |
-
if sum(axes_dim) != pe_dim:
|
593 |
-
raise ValueError(f"Got {axes_dim} but expected positional dim {pe_dim}")
|
594 |
-
self.hidden_size = hidden_size
|
595 |
-
self.num_heads = num_heads
|
596 |
-
self.pe_embedder = EmbedND(dim=pe_dim, theta=theta, axes_dim=axes_dim)
|
597 |
-
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
|
598 |
-
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
|
599 |
-
self.vector_in = MLPEmbedder(vec_in_dim, self.hidden_size)
|
600 |
-
self.guidance_in = (
|
601 |
-
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) if self.guidance_embed else nn.Identity()
|
602 |
-
)
|
603 |
-
self.txt_in = nn.Linear(context_in_dim, self.hidden_size)
|
604 |
-
|
605 |
-
self.double_blocks = nn.ModuleList(
|
606 |
-
[
|
607 |
-
DoubleStreamBlockACE(
|
608 |
-
self.hidden_size,
|
609 |
-
self.num_heads,
|
610 |
-
mlp_ratio=mlp_ratio,
|
611 |
-
qkv_bias=qkv_bias,
|
612 |
-
backend=self.attn_backend
|
613 |
-
)
|
614 |
-
for _ in range(depth)
|
615 |
-
]
|
616 |
-
)
|
617 |
-
|
618 |
-
self.single_blocks = nn.ModuleList(
|
619 |
-
[
|
620 |
-
SingleStreamBlockACE(self.hidden_size, self.num_heads, mlp_ratio=mlp_ratio, backend=self.attn_backend)
|
621 |
-
for _ in range(depth_single_blocks)
|
622 |
-
]
|
623 |
-
)
|
624 |
-
|
625 |
-
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
|
626 |
-
|
627 |
-
def prepare_input(self, x, cond, *args, **kwargs):
|
628 |
-
context, y = cond["context"], cond["y"]
|
629 |
-
# import pdb;pdb.set_trace()
|
630 |
-
batch_shift = []
|
631 |
-
x_list, x_id_list, mask_x_list, x_seq_length = [], [], [], []
|
632 |
-
for ix, shape, is_align in zip(x, cond["x_shapes"], cond['align']):
|
633 |
-
# unpack image from sequence
|
634 |
-
ix = ix[:, :shape[0] * shape[1]].view(-1, shape[0], shape[1])
|
635 |
-
c, h, w = ix.shape
|
636 |
-
ix = rearrange(ix, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
|
637 |
-
ix_id = torch.zeros(h // 2, w // 2, 3)
|
638 |
-
ix_id[..., 1] = ix_id[..., 1] + torch.arange(h // 2)[:, None]
|
639 |
-
ix_id[..., 2] = ix_id[..., 2] + torch.arange(w // 2)[None, :]
|
640 |
-
batch_shift.append(w // 2) if is_align < 1 else batch_shift.append(0)
|
641 |
-
ix_id = rearrange(ix_id, "h w c -> (h w) c")
|
642 |
-
ix = self.img_in(ix)
|
643 |
-
x_list.append(ix)
|
644 |
-
x_id_list.append(ix_id)
|
645 |
-
mask_x_list.append(torch.ones(ix.shape[0]).to(ix.device, non_blocking=True).bool())
|
646 |
-
x_seq_length.append(ix.shape[0])
|
647 |
-
|
648 |
-
x = pad_sequence(tuple(x_list), batch_first=True)
|
649 |
-
x_ids = pad_sequence(tuple(x_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
|
650 |
-
mask_x = pad_sequence(tuple(mask_x_list), batch_first=True)
|
651 |
-
|
652 |
-
if 'edit' in cond and sum(len(e) for e in cond['edit']) > 0:
|
653 |
-
batch_frames, batch_frames_ids = [], []
|
654 |
-
for i, edit in enumerate(cond['edit']):
|
655 |
-
batch_frames.append([])
|
656 |
-
batch_frames_ids.append([])
|
657 |
-
for ie in edit:
|
658 |
-
ie = ie.squeeze(0)
|
659 |
-
c, h, w = ie.shape
|
660 |
-
ie = rearrange(ie, "c (h ph) (w pw) -> (h w) (c ph pw)", ph=2, pw=2)
|
661 |
-
ie_id = torch.zeros(h // 2, w // 2, 3)
|
662 |
-
ie_id[..., 1] = ie_id[..., 1] + torch.arange(h // 2)[:, None]
|
663 |
-
ie_id[..., 2] = ie_id[..., 2] + torch.arange(batch_shift[i], batch_shift[i] + w // 2)[None, :]
|
664 |
-
ie_id = rearrange(ie_id, "h w c -> (h w) c")
|
665 |
-
batch_frames[i].append(ie)
|
666 |
-
batch_frames_ids[i].append(ie_id)
|
667 |
-
edit_list, edit_id_list, edit_mask_x_list = [], [], []
|
668 |
-
for frames, frame_ids in zip(batch_frames, batch_frames_ids):
|
669 |
-
proj_frames = []
|
670 |
-
for idx, one_frame in enumerate(frames):
|
671 |
-
one_frame = self.img_in(one_frame)
|
672 |
-
proj_frames.append(one_frame)
|
673 |
-
ie = torch.cat(proj_frames, dim=0)
|
674 |
-
ie_id = torch.cat(frame_ids, dim=0)
|
675 |
-
edit_list.append(ie)
|
676 |
-
edit_id_list.append(ie_id)
|
677 |
-
edit_mask_x_list.append(torch.ones(ie.shape[0]).to(ie.device, non_blocking=True).bool())
|
678 |
-
edit = pad_sequence(tuple(edit_list), batch_first=True)
|
679 |
-
edit_ids = pad_sequence(tuple(edit_id_list), batch_first=True).to(x) # [b,pad_seq,2] pad (0.,0.) at dim2
|
680 |
-
edit_mask_x = pad_sequence(tuple(edit_mask_x_list), batch_first=True)
|
681 |
-
else:
|
682 |
-
edit, edit_ids, edit_mask_x = None, None, None
|
683 |
-
|
684 |
-
txt_list, mask_txt_list, y_list = [], [], []
|
685 |
-
for sample_id, (ctx, yy) in enumerate(zip(context, y)):
|
686 |
-
txt_list.append(self.txt_in(ctx.to(x)))
|
687 |
-
mask_txt_list.append(torch.ones(txt_list[-1].shape[0]).to(ctx.device, non_blocking=True).bool())
|
688 |
-
y_list.append(yy.to(x))
|
689 |
-
txt = pad_sequence(tuple(txt_list), batch_first=True)
|
690 |
-
txt_ids = torch.zeros(txt.shape[0], txt.shape[1], 3).to(x)
|
691 |
-
mask_txt = pad_sequence(tuple(mask_txt_list), batch_first=True)
|
692 |
-
y = torch.cat(y_list, dim=0)
|
693 |
-
return x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, x_seq_length
|
694 |
-
|
695 |
-
def unpack(self, x: Tensor, cond: dict = None, x_seq_length: list = None) -> Tensor:
|
696 |
-
x_list = []
|
697 |
-
image_shapes = cond["x_shapes"]
|
698 |
-
for u, shape, seq_length in zip(x, image_shapes, x_seq_length):
|
699 |
-
height, width = shape
|
700 |
-
h, w = math.ceil(height / 2), math.ceil(width / 2)
|
701 |
-
u = rearrange(
|
702 |
-
u[:h * w, ...],
|
703 |
-
"(h w) (c ph pw) -> (h ph w pw) c",
|
704 |
-
h=h,
|
705 |
-
w=w,
|
706 |
-
ph=2,
|
707 |
-
pw=2,
|
708 |
-
)
|
709 |
-
x_list.append(u)
|
710 |
-
x = pad_sequence(tuple(x_list), batch_first=True).permute(0, 2, 1)
|
711 |
-
return x
|
712 |
-
|
713 |
-
def forward(
|
714 |
-
self,
|
715 |
-
x: Tensor,
|
716 |
-
t: Tensor,
|
717 |
-
cond: dict = {},
|
718 |
-
guidance: Tensor | None = None,
|
719 |
-
gc_seg: int = 0,
|
720 |
-
**kwargs
|
721 |
-
) -> Tensor:
|
722 |
-
x, x_ids, edit, edit_ids, txt, txt_ids, y, mask_x, edit_mask_x, mask_txt, seq_length_list = self.prepare_input(x, cond)
|
723 |
-
# running on sequences img
|
724 |
-
# condition use zero t
|
725 |
-
x_length = x.shape[1]
|
726 |
-
vec = self.time_in(timestep_embedding(t, 256))
|
727 |
-
|
728 |
-
if edit is not None:
|
729 |
-
edit_vec = self.time_in(timestep_embedding(t * 0, 256))
|
730 |
-
# print("edit_vec", torch.sum(edit_vec))
|
731 |
-
else:
|
732 |
-
edit_vec = None
|
733 |
-
|
734 |
-
if self.guidance_embed:
|
735 |
-
if guidance is None:
|
736 |
-
raise ValueError("Didn't get guidance strength for guidance distilled model.")
|
737 |
-
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
|
738 |
-
if edit is not None:
|
739 |
-
edit_vec = edit_vec + self.guidance_in(timestep_embedding(guidance, 256))
|
740 |
-
|
741 |
-
vec = vec + self.vector_in(y)
|
742 |
-
if edit is not None:
|
743 |
-
edit_vec = edit_vec + self.vector_in(y)
|
744 |
-
ids = torch.cat((txt_ids, x_ids, edit_ids), dim=1)
|
745 |
-
mask_aside = torch.cat((mask_txt, mask_x, edit_mask_x), dim=1)
|
746 |
-
x = torch.cat((txt, x, edit), 1)
|
747 |
-
else:
|
748 |
-
ids = torch.cat((txt_ids, x_ids), dim=1)
|
749 |
-
mask_aside = torch.cat((mask_txt, mask_x), dim=1)
|
750 |
-
x = torch.cat((txt, x), 1)
|
751 |
-
|
752 |
-
pe = self.pe_embedder(ids)
|
753 |
-
mask = mask_aside[:, None, :] * mask_aside[:, :, None]
|
754 |
-
|
755 |
-
kwargs = dict(
|
756 |
-
vec=vec,
|
757 |
-
pe=pe,
|
758 |
-
mask=mask,
|
759 |
-
txt_length=txt.shape[1],
|
760 |
-
x_length=x_length,
|
761 |
-
edit_vec=edit_vec,
|
762 |
-
|
763 |
-
)
|
764 |
-
|
765 |
-
if self.use_grad_checkpoint and gc_seg >= 0:
|
766 |
-
x = checkpoint_sequential(
|
767 |
-
functions=[partial(block, **kwargs) for block in self.double_blocks],
|
768 |
-
segments=gc_seg if gc_seg > 0 else len(self.double_blocks),
|
769 |
-
input=x,
|
770 |
-
use_reentrant=False
|
771 |
-
)
|
772 |
-
else:
|
773 |
-
for idx, block in enumerate(self.double_blocks):
|
774 |
-
# print("double block", idx)
|
775 |
-
x = block(x, **kwargs)
|
776 |
-
|
777 |
-
if self.use_grad_checkpoint and gc_seg >= 0:
|
778 |
-
x = checkpoint_sequential(
|
779 |
-
functions=[partial(block, **kwargs) for block in self.single_blocks],
|
780 |
-
segments=gc_seg if gc_seg > 0 else len(self.single_blocks),
|
781 |
-
input=x,
|
782 |
-
use_reentrant=False
|
783 |
-
)
|
784 |
-
else:
|
785 |
-
for idx, block in enumerate(self.single_blocks):
|
786 |
-
# print("single block", idx)
|
787 |
-
x = block(x, **kwargs)
|
788 |
-
x = x[:, txt.shape[1]:txt.shape[1] + x_length, ...]
|
789 |
-
x = self.final_layer(x, vec) # (N, T, patch_size ** 2 * out_channels) 6 64 64
|
790 |
-
x = self.unpack(x, cond, seq_length_list)
|
791 |
-
return x
|
792 |
-
|
793 |
-
@staticmethod
|
794 |
-
def get_config_template():
|
795 |
-
return dict_to_yaml('MODEL',
|
796 |
-
__class__.__name__,
|
797 |
-
ACEFlux.para_dict,
|
798 |
-
set_name=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
models/layers.py
DELETED
@@ -1,497 +0,0 @@
|
|
1 |
-
from __future__ import annotations
|
2 |
-
|
3 |
-
import math
|
4 |
-
from dataclasses import dataclass
|
5 |
-
from torch import Tensor, nn
|
6 |
-
import torch
|
7 |
-
from einops import rearrange, repeat
|
8 |
-
from torch import Tensor
|
9 |
-
from torch.nn.utils.rnn import pad_sequence
|
10 |
-
|
11 |
-
try:
|
12 |
-
from flash_attn import (
|
13 |
-
flash_attn_varlen_func
|
14 |
-
)
|
15 |
-
FLASHATTN_IS_AVAILABLE = True
|
16 |
-
except ImportError:
|
17 |
-
FLASHATTN_IS_AVAILABLE = False
|
18 |
-
flash_attn_varlen_func = None
|
19 |
-
|
20 |
-
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask: Tensor | None = None, backend = 'pytorch') -> Tensor:
|
21 |
-
q, k = apply_rope(q, k, pe)
|
22 |
-
if backend == 'pytorch':
|
23 |
-
if mask is not None and mask.dtype == torch.bool:
|
24 |
-
mask = torch.zeros_like(mask).to(q).masked_fill_(mask.logical_not(), -1e20)
|
25 |
-
x = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask)
|
26 |
-
# x = torch.nan_to_num(x, nan=0.0, posinf=1e10, neginf=-1e10)
|
27 |
-
x = rearrange(x, "B H L D -> B L (H D)")
|
28 |
-
elif backend == 'flash_attn':
|
29 |
-
# q: (B, H, L, D)
|
30 |
-
# k: (B, H, S, D) now L = S
|
31 |
-
# v: (B, H, S, D)
|
32 |
-
b, h, lq, d = q.shape
|
33 |
-
_, _, lk, _ = k.shape
|
34 |
-
q = rearrange(q, "B H L D -> B L H D")
|
35 |
-
k = rearrange(k, "B H S D -> B S H D")
|
36 |
-
v = rearrange(v, "B H S D -> B S H D")
|
37 |
-
if mask is None:
|
38 |
-
q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(q.device, non_blocking=True)
|
39 |
-
k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(k.device, non_blocking=True)
|
40 |
-
else:
|
41 |
-
q_lens = torch.sum(mask[:, 0, :, 0], dim=1).int()
|
42 |
-
k_lens = torch.sum(mask[:, 0, 0, :], dim=1).int()
|
43 |
-
q = torch.cat([q_v[:q_l] for q_v, q_l in zip(q, q_lens)])
|
44 |
-
k = torch.cat([k_v[:k_l] for k_v, k_l in zip(k, k_lens)])
|
45 |
-
v = torch.cat([v_v[:v_l] for v_v, v_l in zip(v, k_lens)])
|
46 |
-
cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
|
47 |
-
cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
|
48 |
-
max_seqlen_q = q_lens.max()
|
49 |
-
max_seqlen_k = k_lens.max()
|
50 |
-
|
51 |
-
x = flash_attn_varlen_func(
|
52 |
-
q,
|
53 |
-
k,
|
54 |
-
v,
|
55 |
-
cu_seqlens_q=cu_seqlens_q,
|
56 |
-
cu_seqlens_k=cu_seqlens_k,
|
57 |
-
max_seqlen_q=max_seqlen_q,
|
58 |
-
max_seqlen_k=max_seqlen_k
|
59 |
-
)
|
60 |
-
x_list = [x[cu_seqlens_q[i]:cu_seqlens_q[i+1]] for i in range(b)]
|
61 |
-
x = pad_sequence(tuple(x_list), batch_first=True)
|
62 |
-
x = rearrange(x, "B L H D -> B L (H D)")
|
63 |
-
else:
|
64 |
-
raise NotImplementedError
|
65 |
-
return x
|
66 |
-
|
67 |
-
|
68 |
-
def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
69 |
-
assert dim % 2 == 0
|
70 |
-
scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
|
71 |
-
omega = 1.0 / (theta**scale)
|
72 |
-
out = torch.einsum("...n,d->...nd", pos, omega)
|
73 |
-
out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
|
74 |
-
out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
|
75 |
-
return out.float()
|
76 |
-
|
77 |
-
|
78 |
-
def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
|
79 |
-
xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
|
80 |
-
xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
|
81 |
-
xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
|
82 |
-
xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
|
83 |
-
return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
|
84 |
-
|
85 |
-
class EmbedND(nn.Module):
|
86 |
-
def __init__(self, dim: int, theta: int, axes_dim: list[int]):
|
87 |
-
super().__init__()
|
88 |
-
self.dim = dim
|
89 |
-
self.theta = theta
|
90 |
-
self.axes_dim = axes_dim
|
91 |
-
|
92 |
-
def forward(self, ids: Tensor) -> Tensor:
|
93 |
-
n_axes = ids.shape[-1]
|
94 |
-
emb = torch.cat(
|
95 |
-
[rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
|
96 |
-
dim=-3,
|
97 |
-
)
|
98 |
-
|
99 |
-
return emb.unsqueeze(1)
|
100 |
-
|
101 |
-
|
102 |
-
def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
|
103 |
-
"""
|
104 |
-
Create sinusoidal timestep embeddings.
|
105 |
-
:param t: a 1-D Tensor of N indices, one per batch element.
|
106 |
-
These may be fractional.
|
107 |
-
:param dim: the dimension of the output.
|
108 |
-
:param max_period: controls the minimum frequency of the embeddings.
|
109 |
-
:return: an (N, D) Tensor of positional embeddings.
|
110 |
-
"""
|
111 |
-
t = time_factor * t
|
112 |
-
half = dim // 2
|
113 |
-
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
|
114 |
-
t.device
|
115 |
-
)
|
116 |
-
|
117 |
-
args = t[:, None].float() * freqs[None]
|
118 |
-
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
119 |
-
if dim % 2:
|
120 |
-
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
121 |
-
if torch.is_floating_point(t):
|
122 |
-
embedding = embedding.to(t)
|
123 |
-
return embedding
|
124 |
-
|
125 |
-
|
126 |
-
class MLPEmbedder(nn.Module):
|
127 |
-
def __init__(self, in_dim: int, hidden_dim: int):
|
128 |
-
super().__init__()
|
129 |
-
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
|
130 |
-
self.silu = nn.SiLU()
|
131 |
-
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
|
132 |
-
|
133 |
-
def forward(self, x: Tensor) -> Tensor:
|
134 |
-
return self.out_layer(self.silu(self.in_layer(x)))
|
135 |
-
|
136 |
-
|
137 |
-
class RMSNorm(torch.nn.Module):
|
138 |
-
def __init__(self, dim: int):
|
139 |
-
super().__init__()
|
140 |
-
self.scale = nn.Parameter(torch.ones(dim))
|
141 |
-
|
142 |
-
def forward(self, x: Tensor):
|
143 |
-
x_dtype = x.dtype
|
144 |
-
x = x.float()
|
145 |
-
rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + 1e-6)
|
146 |
-
return (x * rrms).to(dtype=x_dtype) * self.scale
|
147 |
-
|
148 |
-
|
149 |
-
class QKNorm(torch.nn.Module):
|
150 |
-
def __init__(self, dim: int):
|
151 |
-
super().__init__()
|
152 |
-
self.query_norm = RMSNorm(dim)
|
153 |
-
self.key_norm = RMSNorm(dim)
|
154 |
-
|
155 |
-
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
|
156 |
-
q = self.query_norm(q)
|
157 |
-
k = self.key_norm(k)
|
158 |
-
return q.to(v), k.to(v)
|
159 |
-
|
160 |
-
|
161 |
-
class SelfAttention(nn.Module):
|
162 |
-
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
163 |
-
super().__init__()
|
164 |
-
self.num_heads = num_heads
|
165 |
-
head_dim = dim // num_heads
|
166 |
-
|
167 |
-
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
168 |
-
self.norm = QKNorm(head_dim)
|
169 |
-
self.proj = nn.Linear(dim, dim)
|
170 |
-
|
171 |
-
def forward(self, x: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
|
172 |
-
qkv = self.qkv(x)
|
173 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
174 |
-
q, k = self.norm(q, k, v)
|
175 |
-
x = attention(q, k, v, pe=pe, mask=mask)
|
176 |
-
x = self.proj(x)
|
177 |
-
return x
|
178 |
-
|
179 |
-
class CrossAttention(nn.Module):
|
180 |
-
def __init__(self, dim: int, context_dim: int, num_heads: int = 8, qkv_bias: bool = False):
|
181 |
-
super().__init__()
|
182 |
-
self.num_heads = num_heads
|
183 |
-
head_dim = dim // num_heads
|
184 |
-
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
185 |
-
self.kv = nn.Linear(dim, context_dim * 2, bias=qkv_bias)
|
186 |
-
self.norm = QKNorm(head_dim)
|
187 |
-
self.proj = nn.Linear(dim, dim)
|
188 |
-
|
189 |
-
def forward(self, x: Tensor, context: Tensor, pe: Tensor, mask: Tensor | None = None) -> Tensor:
|
190 |
-
qkv = self.qkv(x)
|
191 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
192 |
-
q, k = self.norm(q, k, v)
|
193 |
-
x = attention(q, k, v, pe=pe, mask=mask)
|
194 |
-
x = self.proj(x)
|
195 |
-
return x
|
196 |
-
|
197 |
-
|
198 |
-
@dataclass
|
199 |
-
class ModulationOut:
|
200 |
-
shift: Tensor
|
201 |
-
scale: Tensor
|
202 |
-
gate: Tensor
|
203 |
-
|
204 |
-
|
205 |
-
class Modulation(nn.Module):
|
206 |
-
def __init__(self, dim: int, double: bool):
|
207 |
-
super().__init__()
|
208 |
-
self.is_double = double
|
209 |
-
self.multiplier = 6 if double else 3
|
210 |
-
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
|
211 |
-
|
212 |
-
def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
|
213 |
-
out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(self.multiplier, dim=-1)
|
214 |
-
|
215 |
-
return (
|
216 |
-
ModulationOut(*out[:3]),
|
217 |
-
ModulationOut(*out[3:]) if self.is_double else None,
|
218 |
-
)
|
219 |
-
|
220 |
-
|
221 |
-
class DoubleStreamBlock(nn.Module):
|
222 |
-
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, backend = 'pytorch'):
|
223 |
-
super().__init__()
|
224 |
-
|
225 |
-
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
226 |
-
self.num_heads = num_heads
|
227 |
-
self.hidden_size = hidden_size
|
228 |
-
self.img_mod = Modulation(hidden_size, double=True)
|
229 |
-
self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
230 |
-
self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
231 |
-
|
232 |
-
self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
233 |
-
self.img_mlp = nn.Sequential(
|
234 |
-
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
235 |
-
nn.GELU(approximate="tanh"),
|
236 |
-
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
237 |
-
)
|
238 |
-
|
239 |
-
self.backend = backend
|
240 |
-
|
241 |
-
self.txt_mod = Modulation(hidden_size, double=True)
|
242 |
-
self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
243 |
-
self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias)
|
244 |
-
|
245 |
-
self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
246 |
-
self.txt_mlp = nn.Sequential(
|
247 |
-
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
|
248 |
-
nn.GELU(approximate="tanh"),
|
249 |
-
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
250 |
-
)
|
251 |
-
|
252 |
-
|
253 |
-
|
254 |
-
|
255 |
-
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None, txt_length = None):
|
256 |
-
img_mod1, img_mod2 = self.img_mod(vec)
|
257 |
-
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
258 |
-
|
259 |
-
txt, img = x[:, :txt_length], x[:, txt_length:]
|
260 |
-
|
261 |
-
# prepare image for attention
|
262 |
-
img_modulated = self.img_norm1(img)
|
263 |
-
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
264 |
-
img_qkv = self.img_attn.qkv(img_modulated)
|
265 |
-
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
266 |
-
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
267 |
-
# prepare txt for attention
|
268 |
-
txt_modulated = self.txt_norm1(txt)
|
269 |
-
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
270 |
-
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
271 |
-
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
272 |
-
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
273 |
-
|
274 |
-
# run actual attention
|
275 |
-
q = torch.cat((txt_q, img_q), dim=2)
|
276 |
-
k = torch.cat((txt_k, img_k), dim=2)
|
277 |
-
v = torch.cat((txt_v, img_v), dim=2)
|
278 |
-
if mask is not None:
|
279 |
-
mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
|
280 |
-
attn = attention(q, k, v, pe=pe, mask = mask, backend = self.backend)
|
281 |
-
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
282 |
-
|
283 |
-
# calculate the img bloks
|
284 |
-
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
285 |
-
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
286 |
-
|
287 |
-
# calculate the txt bloks
|
288 |
-
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
289 |
-
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
290 |
-
x = torch.cat((txt, img), 1)
|
291 |
-
return x
|
292 |
-
|
293 |
-
|
294 |
-
class SingleStreamBlock(nn.Module):
|
295 |
-
"""
|
296 |
-
A DiT block with parallel linear layers as described in
|
297 |
-
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
298 |
-
"""
|
299 |
-
|
300 |
-
def __init__(
|
301 |
-
self,
|
302 |
-
hidden_size: int,
|
303 |
-
num_heads: int,
|
304 |
-
mlp_ratio: float = 4.0,
|
305 |
-
qk_scale: float | None = None,
|
306 |
-
backend='pytorch'
|
307 |
-
):
|
308 |
-
super().__init__()
|
309 |
-
self.hidden_dim = hidden_size
|
310 |
-
self.num_heads = num_heads
|
311 |
-
head_dim = hidden_size // num_heads
|
312 |
-
self.scale = qk_scale or head_dim**-0.5
|
313 |
-
|
314 |
-
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
315 |
-
# qkv and mlp_in
|
316 |
-
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
|
317 |
-
# proj and mlp_out
|
318 |
-
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
|
319 |
-
|
320 |
-
self.norm = QKNorm(head_dim)
|
321 |
-
|
322 |
-
self.hidden_size = hidden_size
|
323 |
-
self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
324 |
-
|
325 |
-
self.mlp_act = nn.GELU(approximate="tanh")
|
326 |
-
self.modulation = Modulation(hidden_size, double=False)
|
327 |
-
self.backend = backend
|
328 |
-
|
329 |
-
def forward(self, x: Tensor, vec: Tensor, pe: Tensor, mask: Tensor = None) -> Tensor:
|
330 |
-
mod, _ = self.modulation(vec)
|
331 |
-
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
332 |
-
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
333 |
-
|
334 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
335 |
-
q, k = self.norm(q, k, v)
|
336 |
-
if mask is not None:
|
337 |
-
mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
|
338 |
-
# compute attention
|
339 |
-
attn = attention(q, k, v, pe=pe, mask = mask, backend=self.backend)
|
340 |
-
# compute activation in mlp stream, cat again and run second linear layer
|
341 |
-
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
342 |
-
return x + mod.gate * output
|
343 |
-
|
344 |
-
|
345 |
-
class DoubleStreamBlockACE(DoubleStreamBlock):
|
346 |
-
def forward(self,
|
347 |
-
x: Tensor,
|
348 |
-
vec: Tensor,
|
349 |
-
pe: Tensor,
|
350 |
-
edit_vec: Tensor | None = None,
|
351 |
-
mask: Tensor = None,
|
352 |
-
txt_length = None,
|
353 |
-
x_length = None):
|
354 |
-
img_mod1, img_mod2 = self.img_mod(vec)
|
355 |
-
txt_mod1, txt_mod2 = self.txt_mod(vec)
|
356 |
-
if edit_vec is not None:
|
357 |
-
edit_mod1, edit_mod2 = self.img_mod(edit_vec)
|
358 |
-
txt, img, edit = x[:, :txt_length], x[:, txt_length:txt_length+x_length], x[:, txt_length+x_length:]
|
359 |
-
else:
|
360 |
-
edit_mod1, edit_mod2 = None, None
|
361 |
-
txt, img = x[:, :txt_length], x[:, txt_length:]
|
362 |
-
edit = None
|
363 |
-
|
364 |
-
|
365 |
-
# prepare image for attention
|
366 |
-
img_modulated = self.img_norm1(img)
|
367 |
-
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
|
368 |
-
img_qkv = self.img_attn.qkv(img_modulated)
|
369 |
-
img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
370 |
-
img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)
|
371 |
-
# prepare txt for attention
|
372 |
-
txt_modulated = self.txt_norm1(txt)
|
373 |
-
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
|
374 |
-
txt_qkv = self.txt_attn.qkv(txt_modulated)
|
375 |
-
txt_q, txt_k, txt_v = rearrange(txt_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
376 |
-
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
|
377 |
-
# prepare edit for attention
|
378 |
-
if edit_vec is not None:
|
379 |
-
edit_modulated = self.img_norm1(edit)
|
380 |
-
edit_modulated = (1 + edit_mod1.scale) * edit_modulated + edit_mod1.shift
|
381 |
-
edit_qkv = self.img_attn.qkv(edit_modulated)
|
382 |
-
edit_q, edit_k, edit_v = rearrange(edit_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
383 |
-
edit_q, edit_k = self.img_attn.norm(edit_q, edit_k, edit_v)
|
384 |
-
q = torch.cat((txt_q, img_q, edit_q), dim=2)
|
385 |
-
k = torch.cat((txt_k, img_k, edit_k), dim=2)
|
386 |
-
v = torch.cat((txt_v, img_v, edit_v), dim=2)
|
387 |
-
else:
|
388 |
-
q = torch.cat((txt_q, img_q), dim=2)
|
389 |
-
k = torch.cat((txt_k, img_k), dim=2)
|
390 |
-
v = torch.cat((txt_v, img_v), dim=2)
|
391 |
-
|
392 |
-
# run actual attention
|
393 |
-
if mask is not None:
|
394 |
-
mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
|
395 |
-
attn = attention(q, k, v, pe=pe, mask = mask, backend = "pytorch")
|
396 |
-
if edit_vec is not None:
|
397 |
-
txt_attn, img_attn, edit_attn = (attn[:, : txt.shape[1]],
|
398 |
-
attn[:, txt.shape[1] : txt.shape[1]+img.shape[1]],
|
399 |
-
attn[:, txt.shape[1]+img.shape[1]:])
|
400 |
-
# calculate the img bloks
|
401 |
-
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
402 |
-
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
403 |
-
|
404 |
-
# calculate the img bloks
|
405 |
-
edit = edit + edit_mod1.gate * self.img_attn.proj(edit_attn)
|
406 |
-
edit = edit + edit_mod2.gate * self.img_mlp((1 + edit_mod2.scale) * self.img_norm2(edit) + edit_mod2.shift)
|
407 |
-
|
408 |
-
# calculate the txt bloks
|
409 |
-
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
410 |
-
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
411 |
-
|
412 |
-
x = torch.cat((txt, img, edit), 1)
|
413 |
-
else:
|
414 |
-
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]
|
415 |
-
# calculate the img bloks
|
416 |
-
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
417 |
-
img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)
|
418 |
-
|
419 |
-
# calculate the txt bloks
|
420 |
-
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
421 |
-
txt = txt + txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)
|
422 |
-
x = torch.cat((txt, img), 1)
|
423 |
-
return x
|
424 |
-
|
425 |
-
|
426 |
-
class SingleStreamBlockACE(SingleStreamBlock):
|
427 |
-
"""
|
428 |
-
A DiT block with parallel linear layers as described in
|
429 |
-
https://arxiv.org/abs/2302.05442 and adapted modulation interface.
|
430 |
-
"""
|
431 |
-
|
432 |
-
def forward(self, x: Tensor, vec: Tensor,
|
433 |
-
pe: Tensor, mask: Tensor = None,
|
434 |
-
edit_vec: Tensor | None = None,
|
435 |
-
txt_length=None,
|
436 |
-
x_length=None
|
437 |
-
) -> Tensor:
|
438 |
-
mod, _ = self.modulation(vec)
|
439 |
-
if edit_vec is not None:
|
440 |
-
x, edit = x[:, :txt_length + x_length], x[:, txt_length + x_length:]
|
441 |
-
e_mod, _ = self.modulation(edit_vec)
|
442 |
-
edit_mod = (1 + e_mod.scale) * self.pre_norm(edit) + e_mod.shift
|
443 |
-
edit_qkv, edit_mlp = torch.split(self.linear1(edit_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
444 |
-
|
445 |
-
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
446 |
-
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
447 |
-
qkv, mlp = torch.cat([qkv, edit_qkv], 1), torch.cat([mlp, edit_mlp], 1)
|
448 |
-
else:
|
449 |
-
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
|
450 |
-
qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)
|
451 |
-
|
452 |
-
q, k, v = rearrange(qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
|
453 |
-
q, k = self.norm(q, k, v)
|
454 |
-
if mask is not None:
|
455 |
-
mask = repeat(mask, 'B L S-> B H L S', H=self.num_heads)
|
456 |
-
# compute attention
|
457 |
-
attn = attention(q, k, v, pe=pe, mask = mask, backend="pytorch")
|
458 |
-
# compute activation in mlp stream, cat again and run second linear layer
|
459 |
-
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
|
460 |
-
|
461 |
-
if edit_vec is not None:
|
462 |
-
x_output, edit_output = output.split([x.shape[1], edit.shape[1]], dim = 1)
|
463 |
-
x = x + mod.gate * x_output
|
464 |
-
edit = edit + e_mod.gate * edit_output
|
465 |
-
x = torch.cat((x, edit), 1)
|
466 |
-
return x
|
467 |
-
else:
|
468 |
-
return x + mod.gate * output
|
469 |
-
|
470 |
-
|
471 |
-
class LastLayer(nn.Module):
|
472 |
-
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
|
473 |
-
super().__init__()
|
474 |
-
self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
475 |
-
self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
|
476 |
-
self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))
|
477 |
-
|
478 |
-
def forward(self, x: Tensor, vec: Tensor) -> Tensor:
|
479 |
-
shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
|
480 |
-
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
|
481 |
-
x = self.linear(x)
|
482 |
-
return x
|
483 |
-
|
484 |
-
|
485 |
-
if __name__ == '__main__':
|
486 |
-
pe = EmbedND(dim=64, theta=10000, axes_dim=[16, 56, 56])
|
487 |
-
|
488 |
-
ix_id = torch.zeros(64 // 2, 64 // 2, 3)
|
489 |
-
ix_id[..., 1] = ix_id[..., 1] + torch.arange(64 // 2)[:, None]
|
490 |
-
ix_id[..., 2] = ix_id[..., 2] + torch.arange(64 // 2)[None, :]
|
491 |
-
ix_id = rearrange(ix_id, "h w c -> 1 (h w) c")
|
492 |
-
pos = torch.cat([ix_id, ix_id], dim = 1)
|
493 |
-
a = pe(pos)
|
494 |
-
|
495 |
-
b = torch.cat([pe(ix_id), pe(ix_id)], dim = 2)
|
496 |
-
|
497 |
-
print(a - b)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
DELETED
@@ -1,8 +0,0 @@
|
|
1 |
-
huggingface_hub
|
2 |
-
diffusers
|
3 |
-
transformers
|
4 |
-
gradio_imageslider
|
5 |
-
torch==2.4.0
|
6 |
-
xformers==0.0.27.post2
|
7 |
-
torchvision
|
8 |
-
gradio==4.44.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
utils.py
DELETED
@@ -1,95 +0,0 @@
|
|
1 |
-
#copyright (c) Alibaba, Inc. and its affiliates.
|
2 |
-
import torch
|
3 |
-
import torchvision.transforms as T
|
4 |
-
from PIL import Image
|
5 |
-
from torchvision.transforms.functional import InterpolationMode
|
6 |
-
|
7 |
-
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
8 |
-
IMAGENET_STD = (0.229, 0.224, 0.225)
|
9 |
-
|
10 |
-
|
11 |
-
def build_transform(input_size):
|
12 |
-
MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
|
13 |
-
transform = T.Compose([
|
14 |
-
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
15 |
-
T.Resize((input_size, input_size),
|
16 |
-
interpolation=InterpolationMode.BICUBIC),
|
17 |
-
T.ToTensor(),
|
18 |
-
T.Normalize(mean=MEAN, std=STD)
|
19 |
-
])
|
20 |
-
return transform
|
21 |
-
|
22 |
-
|
23 |
-
def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
|
24 |
-
image_size):
|
25 |
-
best_ratio_diff = float('inf')
|
26 |
-
best_ratio = (1, 1)
|
27 |
-
area = width * height
|
28 |
-
for ratio in target_ratios:
|
29 |
-
target_aspect_ratio = ratio[0] / ratio[1]
|
30 |
-
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
31 |
-
if ratio_diff < best_ratio_diff:
|
32 |
-
best_ratio_diff = ratio_diff
|
33 |
-
best_ratio = ratio
|
34 |
-
elif ratio_diff == best_ratio_diff:
|
35 |
-
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
36 |
-
best_ratio = ratio
|
37 |
-
return best_ratio
|
38 |
-
|
39 |
-
|
40 |
-
def dynamic_preprocess(image,
|
41 |
-
min_num=1,
|
42 |
-
max_num=12,
|
43 |
-
image_size=448,
|
44 |
-
use_thumbnail=False):
|
45 |
-
orig_width, orig_height = image.size
|
46 |
-
aspect_ratio = orig_width / orig_height
|
47 |
-
|
48 |
-
# calculate the existing image aspect ratio
|
49 |
-
target_ratios = set((i, j) for n in range(min_num, max_num + 1)
|
50 |
-
for i in range(1, n + 1) for j in range(1, n + 1)
|
51 |
-
if i * j <= max_num and i * j >= min_num)
|
52 |
-
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
53 |
-
|
54 |
-
# find the closest aspect ratio to the target
|
55 |
-
target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
|
56 |
-
target_ratios, orig_width,
|
57 |
-
orig_height, image_size)
|
58 |
-
|
59 |
-
# calculate the target width and height
|
60 |
-
target_width = image_size * target_aspect_ratio[0]
|
61 |
-
target_height = image_size * target_aspect_ratio[1]
|
62 |
-
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
63 |
-
|
64 |
-
# resize the image
|
65 |
-
resized_img = image.resize((target_width, target_height))
|
66 |
-
processed_images = []
|
67 |
-
for i in range(blocks):
|
68 |
-
box = ((i % (target_width // image_size)) * image_size,
|
69 |
-
(i // (target_width // image_size)) * image_size,
|
70 |
-
((i % (target_width // image_size)) + 1) * image_size,
|
71 |
-
((i // (target_width // image_size)) + 1) * image_size)
|
72 |
-
# split the image
|
73 |
-
split_img = resized_img.crop(box)
|
74 |
-
processed_images.append(split_img)
|
75 |
-
assert len(processed_images) == blocks
|
76 |
-
if use_thumbnail and len(processed_images) != 1:
|
77 |
-
thumbnail_img = image.resize((image_size, image_size))
|
78 |
-
processed_images.append(thumbnail_img)
|
79 |
-
return processed_images
|
80 |
-
|
81 |
-
|
82 |
-
def load_image(image_file, input_size=448, max_num=12):
|
83 |
-
if isinstance(image_file, str):
|
84 |
-
image = Image.open(image_file).convert('RGB')
|
85 |
-
else:
|
86 |
-
image = image_file
|
87 |
-
transform = build_transform(input_size=input_size)
|
88 |
-
images = dynamic_preprocess(image,
|
89 |
-
image_size=input_size,
|
90 |
-
use_thumbnail=True,
|
91 |
-
max_num=max_num)
|
92 |
-
pixel_values = [transform(image) for image in images]
|
93 |
-
pixel_values = torch.stack(pixel_values)
|
94 |
-
return pixel_values
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|