upload mixofshow and orthogonal_mats folder
- mixofshow/.DS_Store +0 -0
- mixofshow/data/__init__.py +0 -0
- mixofshow/data/__pycache__/__init__.cpython-38.pyc +0 -0
- mixofshow/data/__pycache__/__init__.cpython-39.pyc +0 -0
- mixofshow/data/__pycache__/lora_dataset.cpython-38.pyc +0 -0
- mixofshow/data/__pycache__/lora_dataset.cpython-39.pyc +0 -0
- mixofshow/data/__pycache__/pil_transform.cpython-38.pyc +0 -0
- mixofshow/data/__pycache__/pil_transform.cpython-39.pyc +0 -0
- mixofshow/data/__pycache__/prompt_dataset.cpython-38.pyc +0 -0
- mixofshow/data/__pycache__/prompt_dataset.cpython-39.pyc +0 -0
- mixofshow/data/lora_dataset.py +102 -0
- mixofshow/data/pil_transform.py +366 -0
- mixofshow/data/prompt_dataset.py +67 -0
- mixofshow/models/__pycache__/edlora.cpython-310.pyc +0 -0
- mixofshow/models/__pycache__/edlora.cpython-38.pyc +0 -0
- mixofshow/models/__pycache__/edlora.cpython-39.pyc +0 -0
- mixofshow/models/edlora.py +259 -0
- mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-310.pyc +0 -0
- mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-38.pyc +0 -0
- mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-39.pyc +0 -0
- mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-310.pyc +0 -0
- mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-38.pyc +0 -0
- mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-39.pyc +0 -0
- mixofshow/pipelines/__pycache__/trainer_edlora.cpython-38.pyc +0 -0
- mixofshow/pipelines/__pycache__/trainer_edlora.cpython-39.pyc +0 -0
- mixofshow/pipelines/pipeline_edlora.py +322 -0
- mixofshow/pipelines/pipeline_regionally_t2iadapter.py +608 -0
- mixofshow/pipelines/trainer_edlora.py +380 -0
- mixofshow/utils/__init__.py +0 -0
- mixofshow/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- mixofshow/utils/__pycache__/__init__.cpython-38.pyc +0 -0
- mixofshow/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-38.pyc +0 -0
- mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-39.pyc +0 -0
- mixofshow/utils/__pycache__/ptp_util.cpython-38.pyc +0 -0
- mixofshow/utils/__pycache__/ptp_util.cpython-39.pyc +0 -0
- mixofshow/utils/__pycache__/registry.cpython-38.pyc +0 -0
- mixofshow/utils/__pycache__/registry.cpython-39.pyc +0 -0
- mixofshow/utils/__pycache__/util.cpython-310.pyc +0 -0
- mixofshow/utils/__pycache__/util.cpython-38.pyc +0 -0
- mixofshow/utils/__pycache__/util.cpython-39.pyc +0 -0
- mixofshow/utils/arial.ttf +0 -0
- mixofshow/utils/convert_edlora_to_diffusers.py +99 -0
- mixofshow/utils/ptp_util.py +200 -0
- mixofshow/utils/registry.py +79 -0
- mixofshow/utils/util.py +313 -0
- orthogonal_mats/1280.npy +3 -0
- orthogonal_mats/320.npy +3 -0
- orthogonal_mats/640.npy +3 -0
- orthogonal_mats/768.npy +3 -0
mixofshow/.DS_Store
ADDED
Binary file (6.15 kB)

mixofshow/data/__init__.py
ADDED
File without changes

mixofshow/data/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (148 Bytes)

mixofshow/data/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (148 Bytes)

mixofshow/data/__pycache__/lora_dataset.cpython-38.pyc
ADDED
Binary file (3.02 kB)

mixofshow/data/__pycache__/lora_dataset.cpython-39.pyc
ADDED
Binary file (3.07 kB)

mixofshow/data/__pycache__/pil_transform.cpython-38.pyc
ADDED
Binary file (10.9 kB)

mixofshow/data/__pycache__/pil_transform.cpython-39.pyc
ADDED
Binary file (10.8 kB)

mixofshow/data/__pycache__/prompt_dataset.cpython-38.pyc
ADDED
Binary file (2.35 kB)

mixofshow/data/__pycache__/prompt_dataset.cpython-39.pyc
ADDED
Binary file (2.36 kB)
mixofshow/data/lora_dataset.py
ADDED
@@ -0,0 +1,102 @@
import json
import os
import random
import re
from pathlib import Path

from PIL import Image
from torch.utils.data import Dataset

from mixofshow.data.pil_transform import PairCompose, build_transform


class LoraDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(self, opt):
        self.opt = opt
        self.instance_images_path = []

        with open(opt['concept_list'], 'r') as f:
            concept_list = json.load(f)

        replace_mapping = opt.get('replace_mapping', {})
        use_caption = opt.get('use_caption', False)
        use_mask = opt.get('use_mask', False)

        for concept in concept_list:
            instance_prompt = concept['instance_prompt']
            caption_dir = concept.get('caption_dir')
            mask_dir = concept.get('mask_dir')

            instance_prompt = self.process_text(instance_prompt, replace_mapping)

            inst_img_path = []
            for x in Path(concept['instance_data_dir']).iterdir():
                if x.is_file() and x.name != '.DS_Store':
                    basename = os.path.splitext(os.path.basename(x))[0]
                    caption_path = os.path.join(caption_dir, f'{basename}.txt') if caption_dir is not None else None

                    if use_caption and caption_path is not None and os.path.exists(caption_path):
                        with open(caption_path, 'r') as fr:
                            line = fr.readlines()[0]
                            instance_prompt_image = self.process_text(line, replace_mapping)
                    else:
                        instance_prompt_image = instance_prompt

                    if use_mask and mask_dir is not None:
                        mask_path = os.path.join(mask_dir, f'{basename}.png')
                    else:
                        mask_path = None

                    inst_img_path.append((x, instance_prompt_image, mask_path))

            self.instance_images_path.extend(inst_img_path)

        random.shuffle(self.instance_images_path)
        self.num_instance_images = len(self.instance_images_path)

        self.instance_transform = PairCompose([
            build_transform(transform_opt)
            for transform_opt in opt['instance_transform']
        ])

    def process_text(self, instance_prompt, replace_mapping):
        for k, v in replace_mapping.items():
            instance_prompt = instance_prompt.replace(k, v)
        instance_prompt = instance_prompt.strip()
        instance_prompt = re.sub(' +', ' ', instance_prompt)
        return instance_prompt

    def __len__(self):
        return self.num_instance_images * self.opt['dataset_enlarge_ratio']

    def __getitem__(self, index):
        example = {}
        instance_image, instance_prompt, instance_mask = self.instance_images_path[index % self.num_instance_images]
        instance_image = Image.open(instance_image).convert('RGB')

        extra_args = {'prompts': instance_prompt}
        if instance_mask is not None:
            instance_mask = Image.open(instance_mask).convert('L')
            extra_args.update({'mask': instance_mask})

        instance_image, extra_args = self.instance_transform(instance_image, **extra_args)
        example['images'] = instance_image

        if 'mask' in extra_args:
            example['masks'] = extra_args['mask']
            example['masks'] = example['masks'].unsqueeze(0)
        else:
            pass

        if 'img_mask' in extra_args:
            example['img_masks'] = extra_args['img_mask']
            example['img_masks'] = example['img_masks'].unsqueeze(0)
        else:
            raise NotImplementedError

        example['prompts'] = extra_args['prompts']
        return example
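For context, a minimal sketch of how this dataset could be driven. The concept-list JSON, image directory, and transform choices below are hypothetical placeholders, not files in this commit; only the option keys (`concept_list`, `use_caption`, `use_mask`, `dataset_enlarge_ratio`, `instance_transform`) and the per-concept fields come from the code above.

# Hypothetical usage sketch for LoraDataset (paths and values are placeholders).
import json

from mixofshow.data.lora_dataset import LoraDataset

# One entry per concept, matching the keys read in LoraDataset.__init__.
concept_list = [{
    'instance_prompt': 'a photo of a <TOK> dog',   # placeholder concept prompt
    'instance_data_dir': 'datasets/dog/images',    # placeholder image folder
    'caption_dir': None,
    'mask_dir': None,
}]
with open('concept_list_example.json', 'w') as f:
    json.dump(concept_list, f)

opt = {
    'concept_list': 'concept_list_example.json',
    'use_caption': False,
    'use_mask': False,
    'dataset_enlarge_ratio': 100,
    # Each entry is passed to build_transform(); 'type' selects a registered transform.
    'instance_transform': [
        {'type': 'HumanResizeCropFinalV3', 'size': 512, 'crop_p': 0.5},
        {'type': 'ToTensor'},
        {'type': 'Normalize', 'mean': [0.5, 0.5, 0.5], 'std': [0.5, 0.5, 0.5]},
    ],
}

dataset = LoraDataset(opt)
sample = dataset[0]  # dict with 'images', 'img_masks', 'prompts' (plus 'masks' when masks are used)

Because `__getitem__` requires an `img_mask` from the transform chain, the first transform here is one of the registered classes that produces it.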
mixofshow/data/pil_transform.py
ADDED
@@ -0,0 +1,366 @@
import inspect
import random
from copy import deepcopy

import cv2
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.transforms import CenterCrop, Normalize, RandomCrop, RandomHorizontalFlip, Resize
from torchvision.transforms.functional import InterpolationMode

from mixofshow.utils.registry import TRANSFORM_REGISTRY


def build_transform(opt):
    """Build transform from options.

    Args:
        opt (dict): Configuration. Must contain a 'type' key naming a registered transform.
    """
    opt = deepcopy(opt)
    transform_type = opt.pop('type')
    transform = TRANSFORM_REGISTRY.get(transform_type)(**opt)
    return transform


TRANSFORM_REGISTRY.register(Normalize)
TRANSFORM_REGISTRY.register(Resize)
TRANSFORM_REGISTRY.register(RandomHorizontalFlip)
TRANSFORM_REGISTRY.register(CenterCrop)
TRANSFORM_REGISTRY.register(RandomCrop)


@TRANSFORM_REGISTRY.register()
class BILINEARResize(Resize):
    def __init__(self, size):
        super(BILINEARResize,
              self).__init__(size, interpolation=InterpolationMode.BILINEAR)


@TRANSFORM_REGISTRY.register()
class PairRandomCrop(nn.Module):
    def __init__(self, size):
        super().__init__()
        if isinstance(size, int):
            self.height, self.width = size, size
        else:
            self.height, self.width = size

    def forward(self, img, **kwargs):
        img_width, img_height = img.size
        mask_width, mask_height = kwargs['mask'].size

        assert img_height >= self.height and img_height == mask_height
        assert img_width >= self.width and img_width == mask_width

        x = random.randint(0, img_width - self.width)
        y = random.randint(0, img_height - self.height)
        img = F.crop(img, y, x, self.height, self.width)
        kwargs['mask'] = F.crop(kwargs['mask'], y, x, self.height, self.width)
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class ToTensor(nn.Module):
    def __init__(self) -> None:
        super().__init__()

    def forward(self, pic):
        return F.to_tensor(pic)

    def __repr__(self) -> str:
        return f'{self.__class__.__name__}()'


@TRANSFORM_REGISTRY.register()
class PairRandomHorizontalFlip(torch.nn.Module):
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, img, **kwargs):
        if torch.rand(1) < self.p:
            kwargs['mask'] = F.hflip(kwargs['mask'])
            return F.hflip(img), kwargs
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class PairResize(nn.Module):
    def __init__(self, size):
        super().__init__()
        self.resize = Resize(size=size)

    def forward(self, img, **kwargs):
        kwargs['mask'] = self.resize(kwargs['mask'])
        img = self.resize(img)
        return img, kwargs


class PairCompose(nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, img, **kwargs):
        for t in self.transforms:
            if len(inspect.signature(t.forward).parameters) == 1:  # count args, not counting self
                img = t(img)
            else:
                img, kwargs = t(img, **kwargs)
        return img, kwargs

    def __repr__(self) -> str:
        format_string = self.__class__.__name__ + '('
        for t in self.transforms:
            format_string += '\n'
            format_string += f'    {t}'
        format_string += '\n)'
        return format_string


@TRANSFORM_REGISTRY.register()
class HumanResizeCropFinalV3(nn.Module):
    def __init__(self, size, crop_p=0.5):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def forward(self, img, **kwargs):
        # step 1: short edge resize to 512
        img = F.resize(img, size=self.size)
        if 'mask' in kwargs:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: random crop
        width, height = img.size
        if random.random() < self.crop_p:
            if height > width:
                crop_pos = random.randint(0, height - width)
                img = F.crop(img, 0, 0, width + crop_pos, width)
                if 'mask' in kwargs:
                    kwargs['mask'] = F.crop(kwargs['mask'], 0, 0, width + crop_pos, width)
            else:
                if 'mask' in kwargs:
                    img, kwargs = self.paired_random_crop(img, **kwargs)
                else:
                    img = self.random_crop(img)
        else:
            img = img

        # step 3: long edge resize
        img = F.resize(img, size=self.size - 1, max_size=self.size)
        if 'mask' in kwargs:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        new_width, new_height = img.size

        img = np.array(img)
        if 'mask' in kwargs:
            kwargs['mask'] = np.array(kwargs['mask']) / 255
            new_width = min(new_width, kwargs['mask'].shape[1])
            new_height = min(new_height, kwargs['mask'].shape[0])

        start_y = random.randint(0, 512 - new_height)
        start_x = random.randint(0, 512 - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img[:new_height, :new_width]
        if 'mask' in kwargs:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask'][:new_height, :new_width]
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        if 'mask' in kwargs:
            kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
            kwargs['mask'] = torch.from_numpy(kwargs['mask'])
        kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
        kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class ResizeFillMaskNew(nn.Module):
    def __init__(self, size, crop_p, scale_ratio):
        super().__init__()
        self.size = size
        self.crop_p = crop_p
        self.scale_ratio = scale_ratio
        self.random_crop = RandomCrop(size=size)
        self.paired_random_crop = PairRandomCrop(size=size)

    def forward(self, img, **kwargs):
        # width, height = img.size

        # step 1: short edge resize to 512
        img = F.resize(img, size=self.size)
        if 'mask' in kwargs:
            kwargs['mask'] = F.resize(kwargs['mask'], size=self.size)

        # step 2: random crop
        if random.random() < self.crop_p:
            if 'mask' in kwargs:
                img, kwargs = self.paired_random_crop(img, **kwargs)  # 512
            else:
                img = self.random_crop(img)  # 512
        else:
            # long edge resize
            img = F.resize(img, size=self.size - 1, max_size=self.size)
            if 'mask' in kwargs:
                kwargs['mask'] = F.resize(kwargs['mask'], size=self.size - 1, max_size=self.size)

        # step 3: random aspect ratio
        width, height = img.size
        ratio = random.uniform(*self.scale_ratio)

        img = F.resize(img, size=(int(height * ratio), int(width * ratio)))
        if 'mask' in kwargs:
            kwargs['mask'] = F.resize(kwargs['mask'], size=(int(height * ratio), int(width * ratio)), interpolation=0)

        # step 4: random place
        new_width, new_height = img.size

        img = np.array(img)
        if 'mask' in kwargs:
            kwargs['mask'] = np.array(kwargs['mask']) / 255

        start_y = random.randint(0, 512 - new_height)
        start_x = random.randint(0, 512 - new_width)

        res_img = np.zeros((self.size, self.size, 3), dtype=np.uint8)
        res_mask = np.zeros((self.size, self.size))
        res_img_mask = np.zeros((self.size, self.size))

        res_img[start_y:start_y + new_height, start_x:start_x + new_width, :] = img
        if 'mask' in kwargs:
            res_mask[start_y:start_y + new_height, start_x:start_x + new_width] = kwargs['mask']
            kwargs['mask'] = res_mask

        res_img_mask[start_y:start_y + new_height, start_x:start_x + new_width] = 1
        kwargs['img_mask'] = res_img_mask

        img = Image.fromarray(res_img)

        if 'mask' in kwargs:
            kwargs['mask'] = cv2.resize(kwargs['mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
            kwargs['mask'] = torch.from_numpy(kwargs['mask'])
        kwargs['img_mask'] = cv2.resize(kwargs['img_mask'], (self.size // 8, self.size // 8), cv2.INTER_NEAREST)
        kwargs['img_mask'] = torch.from_numpy(kwargs['img_mask'])

        return img, kwargs


@TRANSFORM_REGISTRY.register()
class ShuffleCaption(nn.Module):
    def __init__(self, keep_token_num):
        super().__init__()
        self.keep_token_num = keep_token_num

    def forward(self, img, **kwargs):
        prompts = kwargs['prompts'].strip()

        fixed_tokens = []
        flex_tokens = [t.strip() for t in prompts.strip().split(',')]
        if self.keep_token_num > 0:
            fixed_tokens = flex_tokens[:self.keep_token_num]
            flex_tokens = flex_tokens[self.keep_token_num:]

        random.shuffle(flex_tokens)
        prompts = ', '.join(fixed_tokens + flex_tokens)
        kwargs['prompts'] = prompts
        return img, kwargs


@TRANSFORM_REGISTRY.register()
class EnhanceText(nn.Module):
    def __init__(self, enhance_type='object'):
        super().__init__()
        STYLE_TEMPLATE = [
            'a painting in the style of {}',
            'a rendering in the style of {}',
            'a cropped painting in the style of {}',
            'the painting in the style of {}',
            'a clean painting in the style of {}',
            'a dirty painting in the style of {}',
            'a dark painting in the style of {}',
            'a picture in the style of {}',
            'a cool painting in the style of {}',
            'a close-up painting in the style of {}',
            'a bright painting in the style of {}',
            'a cropped painting in the style of {}',
            'a good painting in the style of {}',
            'a close-up painting in the style of {}',
            'a rendition in the style of {}',
            'a nice painting in the style of {}',
            'a small painting in the style of {}',
            'a weird painting in the style of {}',
            'a large painting in the style of {}',
        ]

        OBJECT_TEMPLATE = [
            'a photo of a {}',
            'a rendering of a {}',
            'a cropped photo of the {}',
            'the photo of a {}',
            'a photo of a clean {}',
            'a photo of a dirty {}',
            'a dark photo of the {}',
            'a photo of my {}',
            'a photo of the cool {}',
            'a close-up photo of a {}',
            'a bright photo of the {}',
            'a cropped photo of a {}',
            'a photo of the {}',
            'a good photo of the {}',
            'a photo of one {}',
            'a close-up photo of the {}',
            'a rendition of the {}',
            'a photo of the clean {}',
            'a rendition of a {}',
            'a photo of a nice {}',
            'a good photo of a {}',
            'a photo of the nice {}',
            'a photo of the small {}',
            'a photo of the weird {}',
            'a photo of the large {}',
            'a photo of a cool {}',
            'a photo of a small {}',
        ]

        HUMAN_TEMPLATE = [
            'a photo of a {}', 'a photo of one {}', 'a photo of the {}',
            'the photo of a {}', 'a rendering of a {}',
            'a rendition of the {}', 'a rendition of a {}',
            'a cropped photo of the {}', 'a cropped photo of a {}',
            'a bad photo of the {}', 'a bad photo of a {}',
            'a photo of a weird {}', 'a weird photo of a {}',
            'a bright photo of the {}', 'a good photo of the {}',
            'a photo of a nice {}', 'a good photo of a {}',
            'a photo of a cool {}', 'a bright photo of the {}'
        ]

        if enhance_type == 'object':
            self.templates = OBJECT_TEMPLATE
        elif enhance_type == 'style':
            self.templates = STYLE_TEMPLATE
        elif enhance_type == 'human':
            self.templates = HUMAN_TEMPLATE
        else:
            raise NotImplementedError

    def forward(self, img, **kwargs):
        concept_token = kwargs['prompts'].strip()
        kwargs['prompts'] = random.choice(self.templates).format(concept_token)
        return img, kwargs
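A minimal sketch of the registry and `PairCompose` used in isolation; the transform choices and prompt string below are illustrative, not values shipped in this commit.

# Illustrative use of build_transform / PairCompose (values are placeholders).
from PIL import Image

from mixofshow.data.pil_transform import PairCompose, build_transform

pipeline = PairCompose([
    build_transform({'type': 'HumanResizeCropFinalV3', 'size': 512, 'crop_p': 0.5}),
    build_transform({'type': 'ShuffleCaption', 'keep_token_num': 1}),
    build_transform({'type': 'ToTensor'}),
])

img = Image.new('RGB', (640, 480))  # stand-in for a training image
img, extra = pipeline(img, prompts='a <TOK> dog, sitting, park')
# img is now a [3, 512, 512] tensor; extra carries the 64x64 'img_mask' and shuffled 'prompts'.
print(img.shape, extra['img_mask'].shape, extra['prompts'])

`PairCompose` inspects each transform's `forward` signature, so single-argument transforms (ToTensor, Normalize) receive only the image while paired transforms also receive and update the keyword arguments.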
mixofshow/data/prompt_dataset.py
ADDED
@@ -0,0 +1,67 @@
import os
import random
import re

import torch
from torch.utils.data import Dataset


class PromptDataset(Dataset):
    'A simple dataset to prepare the prompts to generate class images on multiple GPUs.'

    def __init__(self, opt):
        self.opt = opt

        self.prompts = opt['prompts']

        if isinstance(self.prompts, list):
            self.prompts = self.prompts
        elif os.path.exists(self.prompts):
            # is file
            with open(self.prompts, 'r') as fr:
                lines = fr.readlines()
                lines = [item.strip() for item in lines]
            self.prompts = lines
        else:
            raise ValueError(
                'prompts should be a prompt file path or prompt list, please check!'
            )

        self.prompts = self.replace_placeholder(self.prompts)

        self.num_samples_per_prompt = opt['num_samples_per_prompt']
        self.prompts_to_generate = [
            (p, i) for i in range(1, self.num_samples_per_prompt + 1)
            for p in self.prompts
        ]
        self.latent_size = opt['latent_size']  # (4, 64, 64)
        self.share_latent_across_prompt = opt.get('share_latent_across_prompt', True)  # (true, false)

    def replace_placeholder(self, prompts):
        # replace placeholder tokens
        replace_mapping = self.opt.get('replace_mapping', {})
        new_lines = []
        for line in self.prompts:
            if len(line.strip()) == 0:
                continue
            for k, v in replace_mapping.items():
                line = line.replace(k, v)
            line = line.strip()
            line = re.sub(' +', ' ', line)
            new_lines.append(line)
        return new_lines

    def __len__(self):
        return len(self.prompts_to_generate)

    def __getitem__(self, index):
        prompt, indice = self.prompts_to_generate[index]
        example = {}
        example['prompts'] = prompt
        example['indices'] = indice
        if self.share_latent_across_prompt:
            seed = indice
        else:
            seed = random.randint(0, 1000)
        example['latents'] = torch.randn(self.latent_size, generator=torch.manual_seed(seed))
        return example
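Illustrative options for `PromptDataset`; the keys mirror those read in `__init__`, while the prompt strings and token mapping are hypothetical placeholders.

# Hypothetical PromptDataset options (prompts and mapping are placeholders).
from mixofshow.data.prompt_dataset import PromptDataset

opt = {
    'prompts': ['a photo of a <TOK> dog', 'a <TOK> dog on the beach'],  # or a path to a .txt file
    'num_samples_per_prompt': 4,
    'latent_size': (4, 64, 64),
    'share_latent_across_prompt': True,   # reuse the sample index as the latent seed for every prompt
    'replace_mapping': {'<TOK>': '<dog1> <dog2>'},  # hypothetical placeholder-token mapping
}

dataset = PromptDataset(opt)
print(len(dataset))  # 2 prompts x 4 samples = 8
print(dataset[0]['prompts'], dataset[0]['latents'].shape)

With `share_latent_across_prompt=True`, every prompt with the same sample index draws the same initial latent, which makes samples comparable across prompts.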
mixofshow/models/__pycache__/edlora.cpython-310.pyc
ADDED
Binary file (6.96 kB)

mixofshow/models/__pycache__/edlora.cpython-38.pyc
ADDED
Binary file (6.96 kB)

mixofshow/models/__pycache__/edlora.cpython-39.pyc
ADDED
Binary file (6.95 kB)
mixofshow/models/edlora.py
ADDED
@@ -0,0 +1,259 @@
import math

import numpy as np
import torch
import torch.nn as nn
from diffusers.models.attention_processor import AttnProcessor
from diffusers.utils.import_utils import is_xformers_available

if is_xformers_available():
    import xformers


def remove_edlora_unet_attention_forward(unet):
    def change_forward(unet):  # omit processor in new diffusers
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and name == 'attn2':
                layer.set_processor(AttnProcessor())
            else:
                change_forward(layer)
    change_forward(unet)


class EDLoRA_Control_AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """
    def __init__(self, cross_attention_idx, place_in_unet, controller, attention_op=None):
        self.cross_attention_idx = cross_attention_idx
        self.place_in_unet = place_in_unet
        self.controller = controller
        self.attention_op = attention_op

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        else:
            is_cross = True
            if len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
            else:  # single layer embedding
                encoder_hidden_states = encoder_hidden_states

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if is_xformers_available() and not is_cross:
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            attention_probs = self.controller(attention_probs, is_cross, self.place_in_unet)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


class EDLoRA_AttnProcessor:
    def __init__(self, cross_attention_idx, attention_op=None):
        self.attention_op = attention_op
        self.cross_attention_idx = cross_attention_idx

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
    ):
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            if len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
            else:  # single layer embedding
                encoder_hidden_states = encoder_hidden_states

        assert not attn.norm_cross

        batch_size, sequence_length, _ = encoder_hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query).contiguous()
        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        if is_xformers_available():
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states


def revise_edlora_unet_attention_forward(unet):
    def change_forward(unet, count):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:
                layer.set_processor(EDLoRA_AttnProcessor(count))
                count += 1
            else:
                count = change_forward(layer, count)
        return count

    # use this to ensure the order
    cross_attention_idx = change_forward(unet.down_blocks, 0)
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx)
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx)
    print(f'Number of attention layers registered: {cross_attention_idx}')


def revise_edlora_unet_attention_controller_forward(unet, controller):
    class DummyController:
        def __call__(self, *args):
            return args[0]

        def __init__(self):
            self.num_att_layers = 0

    if controller is None:
        controller = DummyController()

    def change_forward(unet, count, place_in_unet):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention' and 'attn2' in name:  # only register controller for cross-attention
                layer.set_processor(EDLoRA_Control_AttnProcessor(count, place_in_unet, controller))
                count += 1
            else:
                count = change_forward(layer, count, place_in_unet)
        return count

    # use this to ensure the order
    cross_attention_idx = change_forward(unet.down_blocks, 0, 'down')
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx, 'mid')
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx, 'up')
    print(f'Number of attention layers registered: {cross_attention_idx}')
    controller.num_att_layers = cross_attention_idx


class LoRALinearLayer(nn.Module):
    def __init__(self, name, original_module, rank=4, alpha=1):
        super().__init__()

        self.name = name

        ### Hard-coded LoRA rank (overrides the `rank` argument)
        rank = 32

        if original_module.__class__.__name__ == 'Conv2d':
            in_channels, out_channels = original_module.in_channels, original_module.out_channels
            self.lora_down = torch.nn.Conv2d(in_channels, rank, (1, 1), bias=False)
            self.lora_up = torch.nn.Conv2d(rank, out_channels, (1, 1), bias=False)
        else:
            in_features, out_features = original_module.in_features, original_module.out_features
            self.lora_down = nn.Linear(in_features, rank, bias=False)
            self.lora_up = nn.Linear(rank, out_features, bias=False)

        self.register_buffer('alpha', torch.tensor(alpha))

        ### Load and initialize orthogonal B: the frozen down-projection is a random
        ### subset of rows of a precomputed orthogonal matrix, scaled by 1/2
        m = np.load(f'orthogonal_mats/{in_features}.npy')
        idxs = np.random.choice(in_features, size=rank, replace=False)
        m = m[idxs] / 2
        with torch.no_grad():
            self.lora_down.weight = torch.nn.Parameter(torch.tensor(m, dtype=self.lora_down.weight.dtype))

        torch.nn.init.zeros_(self.lora_up.weight)

        for param in self.lora_down.parameters():
            param.requires_grad = False

        self.original_forward = original_module.forward
        original_module.forward = self.forward

    def forward(self, hidden_states):
        hidden_states = self.original_forward(hidden_states) + self.alpha * self.lora_up(self.lora_down(hidden_states))
        return hidden_states
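The `orthogonal_mats/{320,640,768,1280}.npy` files consumed by `LoRALinearLayer` are square matrices whose row dimension matches the feature widths used here. A hedged sketch follows: the QR-based construction is an assumption about how such a matrix could be produced (the commit does not say how the shipped files were generated), and the layer name and dimension are placeholders.

# Sketch: build an orthogonal matrix of the kind LoRALinearLayer loads, then wrap a Linear module.
# Assumption: the .npy files hold square orthogonal matrices; QR of a random Gaussian matrix
# is one standard way to obtain one.
import os

import numpy as np
import torch.nn as nn

from mixofshow.models.edlora import LoRALinearLayer

dim = 768  # e.g. the cross-attention context width
q, _ = np.linalg.qr(np.random.randn(dim, dim))  # q is orthogonal: q @ q.T ~= I
os.makedirs('orthogonal_mats', exist_ok=True)
np.save(f'orthogonal_mats/{dim}.npy', q.astype(np.float32))

linear = nn.Linear(dim, dim)
lora = LoRALinearLayer('to_k', linear)  # monkey-patches linear.forward to add the LoRA branch
# lora_down is a frozen 32-row slice of the orthogonal matrix (scaled by 1/2);
# lora_up starts at zero, so initially linear(x) is unchanged and only lora_up is trained.

Freezing an orthogonal down-projection and training only `lora_up` keeps the different concepts' LoRA branches operating in near-orthogonal subspaces, which is presumably why the matrices are shipped with the code rather than sampled per run.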
mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-310.pyc
ADDED
Binary file (8.81 kB)

mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-38.pyc
ADDED
Binary file (8.69 kB)

mixofshow/pipelines/__pycache__/pipeline_edlora.cpython-39.pyc
ADDED
Binary file (8.7 kB)

mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-310.pyc
ADDED
Binary file (19.1 kB)

mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-38.pyc
ADDED
Binary file (19 kB)

mixofshow/pipelines/__pycache__/pipeline_regionally_t2iadapter.cpython-39.pyc
ADDED
Binary file (19 kB)

mixofshow/pipelines/__pycache__/trainer_edlora.cpython-38.pyc
ADDED
Binary file (10.9 kB)

mixofshow/pipelines/__pycache__/trainer_edlora.cpython-39.pyc
ADDED
Binary file (10.9 kB)
mixofshow/pipelines/pipeline_edlora.py
ADDED
@@ -0,0 +1,322 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionPipeline
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import deprecate
from einops import rearrange
from packaging import version
from transformers import CLIPTextModel, CLIPTokenizer

from mixofshow.models.edlora import (revise_edlora_unet_attention_controller_forward,
                                     revise_edlora_unet_attention_forward)


def bind_concept_prompt(prompts, new_concept_cfg):
    if isinstance(prompts, str):
        prompts = [prompts]
    new_prompts = []
    for prompt in prompts:
        prompt = [prompt] * 16
        for concept_name, new_token_cfg in new_concept_cfg.items():
            prompt = [
                p.replace(concept_name, new_name) for p, new_name in zip(prompt, new_token_cfg['concept_token_names'])
            ]
        new_prompts.extend(prompt)
    return new_prompts


class EDLoRAPipeline(StableDiffusionPipeline):

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker=None,
        feature_extractor=None,
        requires_safety_checker: bool = False,
    ):
        if hasattr(scheduler.config, 'steps_offset') and scheduler.config.steps_offset != 1:
            deprecation_message = (
                f'The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`'
                f' should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure '
                'to update the config accordingly, as leaving `steps_offset` might lead to incorrect results'
                ' in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,'
                ' it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`'
                ' file'
            )
            deprecate('steps_offset!=1', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config['steps_offset'] = 1
            scheduler._internal_dict = FrozenDict(new_config)

        if hasattr(scheduler.config, 'clip_sample') and scheduler.config.clip_sample is True:
            deprecation_message = (
                f'The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`.'
                ' `clip_sample` should be set to False in the configuration file. Please make sure to update the'
                ' config accordingly, as not setting `clip_sample` in the config might lead to incorrect results in'
                ' future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very'
                ' nice if you could open a Pull request for the `scheduler/scheduler_config.json` file'
            )
            deprecate('clip_sample not set', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(scheduler.config)
            new_config['clip_sample'] = False
            scheduler._internal_dict = FrozenDict(new_config)

        is_unet_version_less_0_9_0 = hasattr(unet.config, '_diffusers_version') and version.parse(
            version.parse(unet.config._diffusers_version).base_version
        ) < version.parse('0.9.0.dev0')
        is_unet_sample_size_less_64 = hasattr(unet.config, 'sample_size') and unet.config.sample_size < 64
        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
            deprecation_message = (
                'The configuration file of the unet has set the default `sample_size` to smaller than'
                ' 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the'
                ' following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-'
                ' CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5'
                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
                ' configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`'
                ' in the config might lead to incorrect results in future versions. If you have downloaded this'
                ' checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for'
                ' the `unet/config.json` file'
            )
            deprecate('sample_size<64', '1.0.0', deprecation_message, standard_warn=False)
            new_config = dict(unet.config)
            new_config['sample_size'] = 64
            unet._internal_dict = FrozenDict(new_config)

        revise_edlora_unet_attention_forward(unet)
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.new_concept_cfg = None

    def set_new_concept_cfg(self, new_concept_cfg=None):
        self.new_concept_cfg = new_concept_cfg

    def set_controller(self, controller):
        self.controller = controller
        revise_edlora_unet_attention_controller_forward(self.unet, controller)

    def _encode_prompt(self,
                       prompt,
                       new_concept_cfg,
                       device,
                       num_images_per_prompt,
                       do_classifier_free_guidance,
                       negative_prompt=None,
                       prompt_embeds: Optional[torch.FloatTensor] = None,
                       negative_prompt_embeds: Optional[torch.FloatTensor] = None
                       ):

        assert num_images_per_prompt == 1, 'only support num_images_per_prompt=1 now'

        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:

            prompt_extend = bind_concept_prompt(prompt, new_concept_cfg)

            text_inputs = self.tokenizer(
                prompt_extend,
                padding='max_length',
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt',
            )
            text_input_ids = text_inputs.input_ids

            prompt_embeds = self.text_encoder(text_input_ids.to(device))[0]
            prompt_embeds = rearrange(prompt_embeds, '(b n) m c -> b n m c', b=batch_size)

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        bs_embed, layer_num, seq_len, _ = prompt_embeds.shape

        # get unconditional embeddings for classifier free guidance
        if do_classifier_free_guidance and negative_prompt_embeds is None:
            uncond_tokens: List[str]
            if negative_prompt is None:
                uncond_tokens = [''] * batch_size
            elif type(prompt) is not type(negative_prompt):
                raise TypeError(f'`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !='
                                f' {type(prompt)}.')
            elif isinstance(negative_prompt, str):
                uncond_tokens = [negative_prompt]
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f'`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:'
                    f' {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches'
                    ' the batch size of `prompt`.')
            else:
                uncond_tokens = negative_prompt

            uncond_input = self.tokenizer(
                uncond_tokens,
                padding='max_length',
                max_length=seq_len,
                truncation=True,
                return_tensors='pt',
            )

            negative_prompt_embeds = self.text_encoder(uncond_input.input_ids.to(device))[0]

        if do_classifier_free_guidance:
            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]

            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
            negative_prompt_embeds = (negative_prompt_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):

        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt; this supports P+ and EDLoRA (layer-wise embedding)
        assert self.new_concept_cfg is not None
        prompt_embeds = self._encode_prompt(
            prompt,
            self.new_concept_cfg,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                if hasattr(self, 'controller'):
                    dtype = latents.dtype
                    latents = self.controller.step_callback(latents)
                    latents = latents.to(dtype)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == 'latent':
            image = latents
        elif output_type == 'pil':
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

        # Offload last model to CPU
        if hasattr(self, 'final_offload_hook') and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image)

        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
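A hedged sketch of driving the pipeline above. The checkpoint path, concept tokens, and prompt are placeholders; the only structural assumption is the one `bind_concept_prompt` relies on, namely that each concept maps to a `concept_token_names` list with one layer-wise token name per text-encoder layer (16 here).

# Hypothetical inference sketch for EDLoRAPipeline (checkpoint path and tokens are placeholders).
import torch

from mixofshow.pipelines.pipeline_edlora import EDLoRAPipeline

pipe = EDLoRAPipeline.from_pretrained('path/to/edlora_checkpoint', torch_dtype=torch.float16).to('cuda')

# One layer-wise token name per CLIP layer, as consumed by bind_concept_prompt.
new_concept_cfg = {
    '<dog1> <dog2>': {'concept_token_names': [f'<dog1>_{i} <dog2>_{i}' for i in range(16)]},
}
pipe.set_new_concept_cfg(new_concept_cfg)

image = pipe('a photo of a <dog1> <dog2> in the park',
             negative_prompt='low quality',
             height=512, width=512, guidance_scale=7.5).images[0]
image.save('edlora_sample.png')

The per-layer replacement is what makes the prompt embedding four-dimensional (batch, layer, token, channel), which the EDLoRA attention processors then index by `cross_attention_idx`.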
mixofshow/pipelines/pipeline_regionally_t2iadapter.py
ADDED
@@ -0,0 +1,608 @@
import math
from typing import Any, Callable, Dict, List, Optional, Union

import PIL
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.t2i_adapter.pipeline_stable_diffusion_adapter import (StableDiffusionAdapterPipeline,
                                                                               StableDiffusionAdapterPipelineOutput,
                                                                               _preprocess_adapter_image)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import logging
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from torch import einsum
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

if is_xformers_available():
    import xformers

from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


class RegionT2I_AttnProcessor:
    def __init__(self, cross_attention_idx, attention_op=None):
        self.attention_op = attention_op
        self.cross_attention_idx = cross_attention_idx

    def region_rewrite(self, attn, hidden_states, query, region_list, height, width):

        def get_region_mask(region_list, feat_height, feat_width):
            exclusive_mask = torch.zeros((feat_height, feat_width))
            for region in region_list:
                start_h, start_w, end_h, end_w = region[-1]
                start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
                    start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)
                exclusive_mask[start_h:end_h, start_w:end_w] += 1
            return exclusive_mask

        dtype = query.dtype
        seq_lens = query.shape[1]
        downscale = math.sqrt(height * width / seq_lens)

        # 0: context, >=1: may be overlap
        feat_height, feat_width = int(height // downscale), int(width // downscale)
        region_mask = get_region_mask(region_list, feat_height, feat_width)

        query = rearrange(query, 'b (h w) c -> b h w c', h=feat_height, w=feat_width)
        hidden_states = rearrange(hidden_states, 'b (h w) c -> b h w c', h=feat_height, w=feat_width)

        new_hidden_state = torch.zeros_like(hidden_states)
        new_hidden_state[:, region_mask == 0, :] = hidden_states[:, region_mask == 0, :]

        replace_ratio = 1.0
        new_hidden_state[:, region_mask != 0, :] = (1 - replace_ratio) * hidden_states[:, region_mask != 0, :]

        for region in region_list:
            region_key, region_value, region_box = region

            if attn.upcast_attention:
                query = query.float()
                region_key = region_key.float()

            start_h, start_w, end_h, end_w = region_box
            start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
                start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)

            attention_region = einsum('b h w c, b n c -> b h w n', query[:, start_h:end_h, start_w:end_w, :], region_key) * attn.scale
            if attn.upcast_softmax:
                attention_region = attention_region.float()

            attention_region = attention_region.softmax(dim=-1)
            attention_region = attention_region.to(dtype)

            hidden_state_region = einsum('b h w n, b n c -> b h w c', attention_region, region_value)
            new_hidden_state[:, start_h:end_h, start_w:end_w, :] += \
                replace_ratio * (hidden_state_region / (
                    region_mask.reshape(
                        1, *region_mask.shape, 1)[:, start_h:end_h, start_w:end_w, :]
                ).to(query.device))

        new_hidden_state = rearrange(new_hidden_state, 'b h w c -> b (h w) c')
        return new_hidden_state

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, temb=None, **cross_attention_kwargs):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
        query = attn.to_q(hidden_states)

        if encoder_hidden_states is None:
            is_cross = False
            encoder_hidden_states = hidden_states
        else:
            is_cross = True

            if len(encoder_hidden_states.shape) == 4:  # multi-layer embedding
                encoder_hidden_states = encoder_hidden_states[:, self.cross_attention_idx, ...]
            else:
                encoder_hidden_states = encoder_hidden_states

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        if is_xformers_available() and not is_cross:
            hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
            hidden_states = hidden_states.to(query.dtype)
        else:
            attention_probs = attn.get_attention_scores(query, key, attention_mask)
            hidden_states = torch.bmm(attention_probs, value)

        if is_cross:
            region_list = []
            for region in cross_attention_kwargs['region_list']:
                if len(region[0].shape) == 4:
                    region_key = attn.to_k(region[0][:, self.cross_attention_idx, ...])
                    region_value = attn.to_v(region[0][:, self.cross_attention_idx, ...])
                else:
                    region_key = attn.to_k(region[0])
                    region_value = attn.to_v(region[0])
                region_key = attn.head_to_batch_dim(region_key)
                region_value = attn.head_to_batch_dim(region_value)
                region_list.append((region_key, region_value, region[1]))

            hidden_states = self.region_rewrite(
                attn=attn,
                hidden_states=hidden_states,
                query=query,
                region_list=region_list,
                height=cross_attention_kwargs['height'],
                width=cross_attention_kwargs['width'])

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


def revise_regionally_t2iadapter_attention_forward(unet):
    def change_forward(unet, count):
        for name, layer in unet.named_children():
            if layer.__class__.__name__ == 'Attention':
                layer.set_processor(RegionT2I_AttnProcessor(count))
                if 'attn2' in name:
                    count += 1
            else:
                count = change_forward(layer, count)
        return count

    # use this to ensure the order
    cross_attention_idx = change_forward(unet.down_blocks, 0)
    cross_attention_idx = change_forward(unet.mid_block, cross_attention_idx)
    cross_attention_idx = change_forward(unet.up_blocks, cross_attention_idx)
    print(f'Number of attention layers registered: {cross_attention_idx}')


class RegionallyT2IAdapterPipeline(StableDiffusionAdapterPipeline):
    _optional_components = ['safety_checker', 'feature_extractor']

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: KarrasDiffusionSchedulers,
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPFeatureExtractor,
        requires_safety_checker: bool = False,
    ):

        if safety_checker is None and requires_safety_checker:
            logger.warning(
                f'You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure'
                ' that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered'
                ' results in services or applications open to the public. Both the diffusers team and Hugging Face'
                ' strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling'
                ' it only for use-cases that involve analyzing network behavior or auditing its results. For more'
                ' information, please have a look at https://github.com/huggingface/diffusers/pull/254 .'
            )

        if safety_checker is not None and feature_extractor is None:
            raise ValueError(
                f'Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety'
                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
        self.register_to_config(requires_safety_checker=requires_safety_checker)
        self.new_concept_cfg = None
        revise_regionally_t2iadapter_attention_forward(self.unet)

    def set_new_concept_cfg(self, new_concept_cfg=None):
        self.new_concept_cfg = new_concept_cfg

    def _encode_region_prompt(self,
                              prompt,
                              new_concept_cfg,
                              device,
                              num_images_per_prompt,
                              do_classifier_free_guidance,
                              negative_prompt=None,
                              prompt_embeds: Optional[torch.FloatTensor] = None,
                              negative_prompt_embeds: Optional[torch.FloatTensor] = None,
                              height=512,
                              width=512):
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        assert batch_size == 1, 'only sample one prompt once in this version'

        if prompt_embeds is None:
            context_prompt, region_list = prompt[0][0], prompt[0][1]
            context_prompt = bind_concept_prompt([context_prompt], new_concept_cfg)
            context_prompt_input_ids = self.tokenizer(
                context_prompt,
                padding='max_length',
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt',
            ).input_ids

            prompt_embeds = self.text_encoder(context_prompt_input_ids.to(device), attention_mask=None)[0]
            prompt_embeds = rearrange(prompt_embeds, '(b n) m c -> b n m c', b=batch_size)
            prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

            bs_embed, layer_num, seq_len, _ = prompt_embeds.shape

            if negative_prompt is None:
                negative_prompt = [''] * batch_size

            negative_prompt_input_ids = self.tokenizer(
                negative_prompt,
                padding='max_length',
                max_length=self.tokenizer.model_max_length,
                truncation=True,
                return_tensors='pt').input_ids

            negative_prompt_embeds = self.text_encoder(
                negative_prompt_input_ids.to(device),
                attention_mask=None,
            )[0]

            negative_prompt_embeds = (negative_prompt_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)
            negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

            for idx, region in enumerate(region_list):
                region_prompt, region_neg_prompt, pos = region
                region_prompt = bind_concept_prompt([region_prompt], new_concept_cfg)
                region_prompt_input_ids = self.tokenizer(
                    region_prompt,
                    padding='max_length',
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors='pt').input_ids
                region_embeds = self.text_encoder(region_prompt_input_ids.to(device), attention_mask=None)[0]
                region_embeds = rearrange(region_embeds, '(b n) m c -> b n m c', b=batch_size)
                region_embeds.to(dtype=self.text_encoder.dtype, device=device)
                bs_embed, layer_num, seq_len, _ = region_embeds.shape

                if region_neg_prompt is None:
                    region_neg_prompt = [''] * batch_size
                region_negprompt_input_ids = self.tokenizer(
                    region_neg_prompt,
                    padding='max_length',
                    max_length=self.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors='pt').input_ids
                region_neg_embeds = self.text_encoder(region_negprompt_input_ids.to(device), attention_mask=None)[0]
                region_neg_embeds = (region_neg_embeds).view(batch_size, 1, seq_len, -1).repeat(1, layer_num, 1, 1)
                region_neg_embeds.to(dtype=self.text_encoder.dtype, device=device)
                region_list[idx] = (torch.cat([region_neg_embeds, region_embeds]), pos)

        return prompt_embeds, region_list

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        keypose_adapter_input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        keypose_adaptor_weight=1.0,
        region_keypose_adaptor_weight='',
        sketch_adapter_input: Union[torch.Tensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
        sketch_adaptor_weight=1.0,
        region_sketch_adaptor_weight='',
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = 'pil',
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[List[PIL.Image.Image]]`):
                The Adapter input condition. Adapter uses this input condition to generate guidance to Unet. If the
                type is specified as `torch.FloatTensor`, it is passed to Adapter as is. `PIL.Image.Image` can also be
                accepted as an image. The control image is automatically resized to fit the output image.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
                is less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] instead
                of a plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
            adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
                residual in the original unet. If multiple adapters are specified in init, you can set the
                corresponding scale as a list.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionAdapterPipelineOutput`] if `return_dict` is True, otherwise a
            `tuple`. When returning a tuple, the first element is a list with the generated images, and the second
            element is a list of `bool`s denoting whether the corresponding generated image likely represents
            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        device = self._execution_device

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        if keypose_adapter_input is not None:
            keypose_input = _preprocess_adapter_image(keypose_adapter_input, height, width).to(self.device)
            keypose_input = keypose_input.to(self.keypose_adapter.dtype)
        else:
            keypose_input = None

        if sketch_adapter_input is not None:
            sketch_input = _preprocess_adapter_image(sketch_adapter_input, height, width).to(self.device)
            sketch_input = sketch_input.to(self.sketch_adapter.dtype)
        else:
            sketch_input = None

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        assert self.new_concept_cfg is not None
        prompt_embeds, region_list = self._encode_region_prompt(
            prompt,
            self.new_concept_cfg,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            height=height,
            width=width
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        if keypose_input is not None:
            keypose_adapter_state = self.keypose_adapter(keypose_input)
        else:
            keypose_adapter_state = None

        if sketch_input is not None:
            sketch_adapter_state = self.sketch_adapter(sketch_input)
        else:
            sketch_adapter_state = None

        num_states = len(keypose_adapter_state) if keypose_adapter_state is not None else len(sketch_adapter_state)

        adapter_state = []

        for idx in range(num_states):
            if keypose_adapter_state is not None:
                feat_keypose = keypose_adapter_state[idx]

                spatial_adaptor_weight = keypose_adaptor_weight * torch.ones(*feat_keypose.shape[2:]).to(
                    feat_keypose.dtype).to(feat_keypose.device)

                if region_keypose_adaptor_weight != '':
                    region_list = region_keypose_adaptor_weight.split('|')

                    for region_weight in region_list:
                        region, weight = region_weight.split('-')
                        region = eval(region)
                        weight = eval(weight)
                        feat_height, feat_width = feat_keypose.shape[2:]
                        start_h, start_w, end_h, end_w = region
                        start_h, end_h = start_h / height, end_h / height
                        start_w, end_w = start_w / width, end_w / width

                        start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
                            start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)

                        spatial_adaptor_weight[start_h:end_h, start_w:end_w] = weight
                feat_keypose = spatial_adaptor_weight * feat_keypose

            else:
                feat_keypose = 0

            if sketch_adapter_state is not None:
                feat_sketch = sketch_adapter_state[idx]
                # print(feat_keypose.shape) # torch.Size([1, 320, 64, 128])
                spatial_adaptor_weight = sketch_adaptor_weight * torch.ones(*feat_sketch.shape[2:]).to(
                    feat_sketch.dtype).to(feat_sketch.device)

                if region_sketch_adaptor_weight != '':
                    region_list = region_sketch_adaptor_weight.split('|')

                    for region_weight in region_list:
                        region, weight = region_weight.split('-')
                        region = eval(region)
                        weight = eval(weight)
                        feat_height, feat_width = feat_sketch.shape[2:]
                        start_h, start_w, end_h, end_w = region
                        start_h, end_h = start_h / height, end_h / height
                        start_w, end_w = start_w / width, end_w / width

                        start_h, start_w, end_h, end_w = math.ceil(start_h * feat_height), math.ceil(
                            start_w * feat_width), math.floor(end_h * feat_height), math.floor(end_w * feat_width)

                        spatial_adaptor_weight[start_h:end_h, start_w:end_w] = weight
                feat_sketch = spatial_adaptor_weight * feat_sketch
            else:
                feat_sketch = 0

            adapter_state.append(feat_keypose + feat_sketch)

        if do_classifier_free_guidance:
            for k, v in enumerate(adapter_state):
                adapter_state[k] = torch.cat([v] * 2, dim=0)

        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs={
                        'region_list': region_list,
                        'height': height,
                        'width': width,
                    },
                    down_block_additional_residuals=[state.clone() for state in adapter_state],
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == 'latent':
            image = latents
            has_nsfw_concept = None
        elif output_type == 'pil':
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)

        # Offload last model to CPU
        if hasattr(self, 'final_offload_hook') and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionAdapterPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
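Note: a minimal invocation sketch for RegionallyT2IAdapterPipeline, illustrating the prompt/region format consumed by `_encode_region_prompt` and the region weight strings parsed in `__call__`. It assumes `pipe` is an already-constructed pipeline with a `sketch_adapter` attached and `set_new_concept_cfg` already called, and that `sketch_image` is a conditioning image; all prompts, boxes, and weights are placeholders.

# Sketch only. Each region is (region_prompt, region_negative_prompt, box), where the box
# (start_h, start_w, end_h, end_w) is given as fractions of the output size, matching region_rewrite.
context_prompt = 'two people standing in a garden, 4k, best quality'
region_list = [
    ('a <concept1> person', 'low quality, blurry', (0.1, 0.0, 1.0, 0.5)),
    ('a <concept2> person', 'low quality, blurry', (0.1, 0.5, 1.0, 1.0)),
]

result = pipe(
    prompt=[(context_prompt, region_list)],          # batch size 1, as asserted in _encode_region_prompt
    sketch_adapter_input=sketch_image,
    sketch_adaptor_weight=1.0,
    # per-region adapter weights: '[start_h, start_w, end_h, end_w]-weight' in pixels, joined by '|'
    region_sketch_adaptor_weight='[0, 0, 768, 512]-1.0|[0, 512, 768, 1024]-0.8',
    height=768,
    width=1024,
    num_inference_steps=50,
    guidance_scale=7.5,
)
image = result.images[0]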
mixofshow/pipelines/trainer_edlora.py
ADDED
@@ -0,0 +1,380 @@
import itertools
import math
import re

import torch
import torch.nn as nn
import torch.nn.functional as F
from accelerate.logging import get_logger
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from diffusers.utils.import_utils import is_xformers_available
from einops import rearrange
from transformers import CLIPTextModel, CLIPTokenizer

from mixofshow.models.edlora import (LoRALinearLayer, revise_edlora_unet_attention_controller_forward,
                                     revise_edlora_unet_attention_forward)
from mixofshow.pipelines.pipeline_edlora import bind_concept_prompt
from mixofshow.utils.ptp_util import AttentionStore


class EDLoRATrainer(nn.Module):
    def __init__(
        self,
        pretrained_path,
        new_concept_token,
        initializer_token,
        enable_edlora,  # true for ED-LoRA, false for LoRA
        finetune_cfg=None,
        noise_offset=None,
        attn_reg_weight=None,
        reg_full_identity=True,  # True for thanos, False for real person (don't need to encode clothes)
        use_mask_loss=True,
        enable_xformers=False,
        gradient_checkpoint=False
    ):
        super().__init__()

        # 1. Load the model.
        self.vae = AutoencoderKL.from_pretrained(pretrained_path, subfolder='vae')
        self.tokenizer = CLIPTokenizer.from_pretrained(pretrained_path, subfolder='tokenizer')
        self.text_encoder = CLIPTextModel.from_pretrained(pretrained_path, subfolder='text_encoder')
        self.unet = UNet2DConditionModel.from_pretrained(pretrained_path, subfolder='unet')

        if gradient_checkpoint:
            self.unet.enable_gradient_checkpointing()

        if enable_xformers:
            assert is_xformers_available(), 'need to install xformers first'

        # 2. Define train scheduler
        self.scheduler = DDPMScheduler.from_pretrained(pretrained_path, subfolder='scheduler')

        # 3. define training cfg
        self.enable_edlora = enable_edlora
        self.new_concept_cfg = self.init_new_concept(new_concept_token, initializer_token, enable_edlora=enable_edlora)

        self.attn_reg_weight = attn_reg_weight
        self.reg_full_identity = reg_full_identity
        if self.attn_reg_weight is not None:
            self.controller = AttentionStore(training=True)
            revise_edlora_unet_attention_controller_forward(self.unet, self.controller)  # support both lora and edlora forward
        else:
            revise_edlora_unet_attention_forward(self.unet)  # support both lora and edlora forward

        if finetune_cfg:
            self.set_finetune_cfg(finetune_cfg)

        self.noise_offset = noise_offset
        self.use_mask_loss = use_mask_loss

    def set_finetune_cfg(self, finetune_cfg):
        logger = get_logger('mixofshow', log_level='INFO')
        params_to_freeze = [self.vae.parameters(), self.text_encoder.parameters(), self.unet.parameters()]

        # step 1: freeze all parameters, set requires_grad to False
        for params in itertools.chain(*params_to_freeze):
            params.requires_grad = False

        # step 2: begin to add trainable parameters
        params_group_list = []

        # 1. text embedding
        if finetune_cfg['text_embedding']['enable_tuning']:
            text_embedding_cfg = finetune_cfg['text_embedding']

            params_list = []
            for params in self.text_encoder.get_input_embeddings().parameters():
                params.requires_grad = True
                params_list.append(params)

            params_group = {'params': params_list, 'lr': text_embedding_cfg['lr']}
            if 'weight_decay' in text_embedding_cfg:
                params_group.update({'weight_decay': text_embedding_cfg['weight_decay']})
            params_group_list.append(params_group)
            logger.info(f"optimizing embedding using lr: {text_embedding_cfg['lr']}")

        # 2. text encoder
        if finetune_cfg['text_encoder']['enable_tuning'] and finetune_cfg['text_encoder'].get('lora_cfg'):
            text_encoder_cfg = finetune_cfg['text_encoder']

            where = text_encoder_cfg['lora_cfg'].pop('where')
            assert where in ['CLIPEncoderLayer', 'CLIPAttention']

            self.text_encoder_lora = nn.ModuleList()
            params_list = []

            for name, module in self.text_encoder.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == 'Linear':
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **text_encoder_cfg['lora_cfg'])
                            self.text_encoder_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': text_encoder_cfg['lr']})
            logger.info(f"optimizing text_encoder ({len(self.text_encoder_lora)} LoRAs), using lr: {text_encoder_cfg['lr']}")

        # 3. unet
        if finetune_cfg['unet']['enable_tuning'] and finetune_cfg['unet'].get('lora_cfg'):
            unet_cfg = finetune_cfg['unet']

            where = unet_cfg['lora_cfg'].pop('where')
            assert where in ['Transformer2DModel', 'Attention']

            self.unet_lora = nn.ModuleList()
            params_list = []

            for name, module in self.unet.named_modules():
                if module.__class__.__name__ == where:
                    for child_name, child_module in module.named_modules():
                        if child_module.__class__.__name__ == 'Linear' or (child_module.__class__.__name__ == 'Conv2d' and child_module.kernel_size == (1, 1)):
                            lora_module = LoRALinearLayer(name + '.' + child_name, child_module, **unet_cfg['lora_cfg'])
                            self.unet_lora.append(lora_module)
                            params_list.extend(list(lora_module.parameters()))

            params_group_list.append({'params': params_list, 'lr': unet_cfg['lr']})
            logger.info(f"optimizing unet ({len(self.unet_lora)} LoRAs), using lr: {unet_cfg['lr']}")

        # 4. optimize params
        self.params_to_optimize_iterator = params_group_list

    def get_params_to_optimize(self):
        return self.params_to_optimize_iterator

    def init_new_concept(self, new_concept_tokens, initializer_tokens, enable_edlora=True):
        logger = get_logger('mixofshow', log_level='INFO')
        new_concept_cfg = {}
        new_concept_tokens = new_concept_tokens.split('+')

        if initializer_tokens is None:
            initializer_tokens = ['<rand-0.017>'] * len(new_concept_tokens)
        else:
            initializer_tokens = initializer_tokens.split('+')
        assert len(new_concept_tokens) == len(initializer_tokens), 'concept token should match init token.'

        for idx, (concept_name, init_token) in enumerate(zip(new_concept_tokens, initializer_tokens)):
            if enable_edlora:
                num_new_embedding = 16
            else:
                num_new_embedding = 1
            new_token_names = [f'<new{idx * num_new_embedding + layer_id}>' for layer_id in range(num_new_embedding)]

            num_added_tokens = self.tokenizer.add_tokens(new_token_names)
            assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer'
            new_token_ids = [self.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names]

            # init embedding
            self.text_encoder.resize_token_embeddings(len(self.tokenizer))
            token_embeds = self.text_encoder.get_input_embeddings().weight.data

            if init_token.startswith('<rand'):
                sigma_val = float(re.findall(r'<rand-(.*)>', init_token)[0])
                init_feature = torch.randn_like(token_embeds[0]) * sigma_val
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is randomly initialized by: {init_token}')
            else:
                # Convert the initializer_token, placeholder_token to ids
                init_token_ids = self.tokenizer.encode(init_token, add_special_tokens=False)
                # print(token_ids)
                # Check if initializer_token is a single token or a sequence of tokens
                if len(init_token_ids) > 1 or init_token_ids[0] == 40497:
                    raise ValueError('The initializer token must be a single existing token.')
                init_feature = token_embeds[init_token_ids]
                logger.info(f'{concept_name} ({min(new_token_ids)}-{max(new_token_ids)}) is initialized by existing token ({init_token}): {init_token_ids[0]}')

            for token_id in new_token_ids:
                token_embeds[token_id] = init_feature.clone()

            new_concept_cfg.update({
                concept_name: {
                    'concept_token_ids': new_token_ids,
                    'concept_token_names': new_token_names
                }
            })

        return new_concept_cfg

    def get_all_concept_token_ids(self):
        new_concept_token_ids = []
        for _, new_token_cfg in self.new_concept_cfg.items():
            new_concept_token_ids.extend(new_token_cfg['concept_token_ids'])
        return new_concept_token_ids

    def forward(self, images, prompts, masks, img_masks):
        latents = self.vae.encode(images).latent_dist.sample()
        latents = latents * 0.18215

        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        if self.noise_offset is not None:
            noise += self.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1), device=latents.device)

        bsz = latents.shape[0]
        # Sample a random timestep for each image
        timesteps = torch.randint(0, self.scheduler.config.num_train_timesteps, (bsz, ), device=latents.device)
        timesteps = timesteps.long()

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        if self.enable_edlora:
            prompts = bind_concept_prompt(prompts, new_concept_cfg=self.new_concept_cfg)  # edlora

        # get text ids
        text_input_ids = self.tokenizer(
            prompts,
            padding='max_length',
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors='pt').input_ids.to(latents.device)

        # Get the text embedding for conditioning
        encoder_hidden_states = self.text_encoder(text_input_ids)[0]
        if self.enable_edlora:
            encoder_hidden_states = rearrange(encoder_hidden_states, '(b n) m c -> b n m c', b=latents.shape[0])  # edlora

        # Predict the noise residual
        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample

        # Get the target for loss depending on the prediction type
        if self.scheduler.config.prediction_type == 'epsilon':
            target = noise
        elif self.scheduler.config.prediction_type == 'v_prediction':
            target = self.scheduler.get_velocity(latents, noise, timesteps)
        else:
            raise ValueError(f'Unknown prediction type {self.scheduler.config.prediction_type}')

        if self.use_mask_loss:
            loss_mask = masks
        else:
            loss_mask = img_masks
        loss = F.mse_loss(model_pred.float(), target.float(), reduction='none')
        loss = ((loss * loss_mask).sum([1, 2, 3]) / loss_mask.sum([1, 2, 3])).mean()

        if self.attn_reg_weight is not None:
            attention_maps = self.controller.get_average_attention()
            attention_loss = self.cal_attn_reg(attention_maps, masks, text_input_ids)
            if not torch.isnan(attention_loss):  # full mask
                loss = loss + attention_loss
            self.controller.reset()

        return loss

    def cal_attn_reg(self, attention_maps, masks, text_input_ids):
        '''
        attention_maps: {down_cross: [], mid_cross: [], up_cross: []}
        masks: torch.Size([1, 1, 64, 64])
        text_input_ids: torch.Size([16, 77])
        '''
        # step 1: find token position
        batch_size = masks.shape[0]
        text_input_ids = rearrange(text_input_ids, '(b l) n -> b l n', b=batch_size)
        # print(masks.shape) # torch.Size([2, 1, 64, 64])
        # print(text_input_ids.shape) # torch.Size([2, 16, 77])

        new_token_pos = []
        all_concept_token_ids = self.get_all_concept_token_ids()
        for text in text_input_ids:
            text = text[0]  # even with multi-layer embedding, we extract the first one
            new_token_pos.append([idx for idx in range(len(text)) if text[idx] in all_concept_token_ids])

        # step 2: aggregate attention maps with resolution and concat heads
        attention_groups = {'64': [], '32': [], '16': [], '8': []}
        for _, attention_list in attention_maps.items():
            for attn in attention_list:
                res = int(math.sqrt(attn.shape[1]))
                cross_map = attn.reshape(batch_size, -1, res, res, attn.shape[-1])
                attention_groups[str(res)].append(cross_map)

        for k, cross_map in attention_groups.items():
            cross_map = torch.cat(cross_map, dim=-4)  # concat heads
            cross_map = cross_map.sum(-4) / cross_map.shape[-4]  # e.g., 64 torch.Size([2, 64, 64, 77])
            cross_map = torch.stack([batch_map[..., batch_pos] for batch_pos, batch_map in zip(new_token_pos, cross_map)])  # torch.Size([2, 64, 64, 2])
            attention_groups[k] = cross_map

        attn_reg_total = 0
        # step 3: calculate loss for each resolution: <new1> <new2> -> <new1> is to penalize outside mask, <new2> to align with mask
        for k, cross_map in attention_groups.items():
            map_adjective, map_subject = cross_map[..., 0], cross_map[..., 1]

            map_subject = map_subject / map_subject.max()
            map_adjective = map_adjective / map_adjective.max()

            gt_mask = F.interpolate(masks, size=map_subject.shape[1:], mode='nearest').squeeze(1)

            if self.reg_full_identity:
                loss_subject = F.mse_loss(map_subject.float(), gt_mask.float(), reduction='mean')
            else:
                loss_subject = map_subject[gt_mask == 0].mean()

            loss_adjective = map_adjective[gt_mask == 0].mean()

            attn_reg_total += self.attn_reg_weight * (loss_subject + loss_adjective)
        return attn_reg_total

    def load_delta_state_dict(self, delta_state_dict):
        # load embedding
        logger = get_logger('mixofshow', log_level='INFO')

        if 'new_concept_embedding' in delta_state_dict and len(delta_state_dict['new_concept_embedding']) != 0:
            new_concept_tokens = list(delta_state_dict['new_concept_embedding'].keys())

            # check whether new concept is initialized
            token_embeds = self.text_encoder.get_input_embeddings().weight.data
            if set(new_concept_tokens) != set(self.new_concept_cfg.keys()):
                logger.warning('Your checkpoint has different concepts from this model; loading existing concepts')

            for concept_name, concept_cfg in self.new_concept_cfg.items():
                logger.info(f'load: concept_{concept_name}')
                token_embeds[concept_cfg['concept_token_ids']] = token_embeds[
                    concept_cfg['concept_token_ids']].copy_(delta_state_dict['new_concept_embedding'][concept_name])

        # load text_encoder
        if 'text_encoder' in delta_state_dict and len(delta_state_dict['text_encoder']) != 0:
            load_keys = delta_state_dict['text_encoder'].keys()
            if hasattr(self, 'text_encoder_lora') and len(load_keys) == 2 * len(self.text_encoder_lora):
                logger.info('loading LoRA for text encoder:')
                for lora_module in self.text_encoder_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.text_encoder.named_parameters():
                    if name in load_keys and 'token_embedding' not in name:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['text_encoder'][f'{name}'])

        # load unet
        if 'unet' in delta_state_dict and len(delta_state_dict['unet']) != 0:
            load_keys = delta_state_dict['unet'].keys()
            if hasattr(self, 'unet_lora') and len(load_keys) == 2 * len(self.unet_lora):
                logger.info('loading LoRA for unet:')
                for lora_module in self.unet_lora:
                    for name, param in lora_module.named_parameters():
                        logger.info(f'load: {lora_module.name}.{name}')
                        param.data.copy_(delta_state_dict['unet'][f'{lora_module.name}.{name}'])
            else:
                for name, param in self.unet.named_parameters():
                    if name in load_keys:
                        logger.info(f'load: {name}')
                        param.data.copy_(delta_state_dict['unet'][f'{name}'])

    def delta_state_dict(self):
        delta_dict = {'new_concept_embedding': {}, 'text_encoder': {}, 'unet': {}}

        # save embedding
        for concept_name, concept_cfg in self.new_concept_cfg.items():
            learned_embeds = self.text_encoder.get_input_embeddings().weight[concept_cfg['concept_token_ids']]
            delta_dict['new_concept_embedding'][concept_name] = learned_embeds.detach().cpu()

        # save text model
        for lora_module in self.text_encoder_lora:
            for name, param in lora_module.named_parameters():
                delta_dict['text_encoder'][f'{lora_module.name}.{name}'] = param.cpu().clone()

        # save unet model
        for lora_module in self.unet_lora:
            for name, param in lora_module.named_parameters():
                delta_dict['unet'][f'{lora_module.name}.{name}'] = param.cpu().clone()

        return delta_dict
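Note: a minimal configuration sketch for EDLoRATrainer, showing only the keys read by `set_finetune_cfg` and `init_new_concept` above. The paths, concept tokens, ranks, and learning rates are placeholders, and the exact keyword arguments accepted under `lora_cfg` (besides `where`) depend on `LoRALinearLayer` in mixofshow/models/edlora.py.

import torch

from mixofshow.pipelines.trainer_edlora import EDLoRATrainer

finetune_cfg = {
    'text_embedding': {'enable_tuning': True, 'lr': 1e-3},
    'text_encoder': {
        'enable_tuning': True,
        'lr': 1e-5,
        'lora_cfg': {'where': 'CLIPAttention', 'rank': 4},   # rank kwarg assumed
    },
    'unet': {
        'enable_tuning': True,
        'lr': 1e-4,
        'lora_cfg': {'where': 'Attention', 'rank': 4},       # rank kwarg assumed
    },
}

trainer = EDLoRATrainer(
    pretrained_path='runwayml/stable-diffusion-v1-5',        # placeholder base model
    new_concept_token='<concept1>+<concept2>',               # '+'-separated, as parsed by init_new_concept
    initializer_token='person+person',
    enable_edlora=True,
    finetune_cfg=finetune_cfg,
)
optimizer = torch.optim.AdamW(trainer.get_params_to_optimize())  # list of per-group {'params', 'lr'} dicts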
mixofshow/utils/__init__.py
ADDED
File without changes
|
mixofshow/utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (151 Bytes). View file
|
|
mixofshow/utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (149 Bytes). View file
|
|
mixofshow/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (149 Bytes). View file
|
|
mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-38.pyc
ADDED
Binary file (3.64 kB). View file
|
|
mixofshow/utils/__pycache__/convert_edlora_to_diffusers.cpython-39.pyc
ADDED
Binary file (3.58 kB). View file
|
|
mixofshow/utils/__pycache__/ptp_util.cpython-38.pyc
ADDED
Binary file (7.21 kB). View file
|
|
mixofshow/utils/__pycache__/ptp_util.cpython-39.pyc
ADDED
Binary file (7.2 kB). View file
|
|
mixofshow/utils/__pycache__/registry.cpython-38.pyc
ADDED
Binary file (2.49 kB). View file
|
|
mixofshow/utils/__pycache__/registry.cpython-39.pyc
ADDED
Binary file (2.48 kB). View file
|
|
mixofshow/utils/__pycache__/util.cpython-310.pyc
ADDED
Binary file (9.59 kB). View file
|
|
mixofshow/utils/__pycache__/util.cpython-38.pyc
ADDED
Binary file (9.51 kB). View file
|
|
mixofshow/utils/__pycache__/util.cpython-39.pyc
ADDED
Binary file (9.51 kB). View file
|
|
mixofshow/utils/arial.ttf
ADDED
Binary file (367 kB). View file
|
|
mixofshow/utils/convert_edlora_to_diffusers.py
ADDED
@@ -0,0 +1,99 @@
1 |
+
import copy
|
2 |
+
|
3 |
+
|
4 |
+
def load_new_concept(pipe, new_concept_embedding, enable_edlora=True):
|
5 |
+
new_concept_cfg = {}
|
6 |
+
|
7 |
+
for idx, (concept_name, concept_embedding) in enumerate(new_concept_embedding.items()):
|
8 |
+
if enable_edlora:
|
9 |
+
num_new_embedding = 16
|
10 |
+
else:
|
11 |
+
num_new_embedding = 1
|
12 |
+
new_token_names = [f'<new{idx * num_new_embedding + layer_id}>' for layer_id in range(num_new_embedding)]
|
13 |
+
num_added_tokens = pipe.tokenizer.add_tokens(new_token_names)
|
14 |
+
assert num_added_tokens == len(new_token_names), 'some token is already in tokenizer'
|
15 |
+
new_token_ids = [pipe.tokenizer.convert_tokens_to_ids(token_name) for token_name in new_token_names]
|
16 |
+
|
17 |
+
# init embedding
|
18 |
+
        pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
        token_embeds = pipe.text_encoder.get_input_embeddings().weight.data
        token_embeds[new_token_ids] = concept_embedding.clone().to(token_embeds.device, dtype=token_embeds.dtype)
        print(f'load embedding: {concept_name}')

        new_concept_cfg.update({
            concept_name: {
                'concept_token_ids': new_token_ids,
                'concept_token_names': new_token_names
            }
        })

    return pipe, new_concept_cfg


def merge_lora_into_weight(original_state_dict, lora_state_dict, model_type, alpha):
    def get_lora_down_name(original_layer_name):
        if model_type == 'text_encoder':
            lora_down_name = original_layer_name.replace('q_proj.weight', 'q_proj.lora_down.weight') \
                .replace('k_proj.weight', 'k_proj.lora_down.weight') \
                .replace('v_proj.weight', 'v_proj.lora_down.weight') \
                .replace('out_proj.weight', 'out_proj.lora_down.weight') \
                .replace('fc1.weight', 'fc1.lora_down.weight') \
                .replace('fc2.weight', 'fc2.lora_down.weight')
        else:
            lora_down_name = original_layer_name.replace('to_q.weight', 'to_q.lora_down.weight') \
                .replace('to_k.weight', 'to_k.lora_down.weight') \
                .replace('to_v.weight', 'to_v.lora_down.weight') \
                .replace('to_out.0.weight', 'to_out.0.lora_down.weight') \
                .replace('ff.net.0.proj.weight', 'ff.net.0.proj.lora_down.weight') \
                .replace('ff.net.2.weight', 'ff.net.2.lora_down.weight') \
                .replace('proj_out.weight', 'proj_out.lora_down.weight') \
                .replace('proj_in.weight', 'proj_in.lora_down.weight')

        return lora_down_name

    assert model_type in ['unet', 'text_encoder']
    new_state_dict = copy.deepcopy(original_state_dict)

    load_cnt = 0
    for k in new_state_dict.keys():
        lora_down_name = get_lora_down_name(k)
        lora_up_name = lora_down_name.replace('lora_down', 'lora_up')

        if lora_up_name in lora_state_dict:
            load_cnt += 1
            original_params = new_state_dict[k]
            lora_down_params = lora_state_dict[lora_down_name].to(original_params.device)
            lora_up_params = lora_state_dict[lora_up_name].to(original_params.device)
            if len(original_params.shape) == 4:
                lora_param = lora_up_params.squeeze() @ lora_down_params.squeeze()
                lora_param = lora_param.unsqueeze(-1).unsqueeze(-1)
            else:
                lora_param = lora_up_params @ lora_down_params
            merge_params = original_params + alpha * lora_param
            new_state_dict[k] = merge_params

    print(f'load {load_cnt} LoRAs of {model_type}')
    return new_state_dict


def convert_edlora(pipe, state_dict, enable_edlora, alpha=0.6):

    state_dict = state_dict['params'] if 'params' in state_dict.keys() else state_dict

    # step 1: load embedding
    if 'new_concept_embedding' in state_dict and len(state_dict['new_concept_embedding']) != 0:
        pipe, new_concept_cfg = load_new_concept(pipe, state_dict['new_concept_embedding'], enable_edlora)

    # step 2: merge lora weight to unet
    unet_lora_state_dict = state_dict['unet']
    pretrained_unet_state_dict = pipe.unet.state_dict()
    updated_unet_state_dict = merge_lora_into_weight(pretrained_unet_state_dict, unet_lora_state_dict, model_type='unet', alpha=alpha)
    pipe.unet.load_state_dict(updated_unet_state_dict)

    # step 3: merge lora weight to text_encoder
    text_encoder_lora_state_dict = state_dict['text_encoder']
    pretrained_text_encoder_state_dict = pipe.text_encoder.state_dict()
    updated_text_encoder_state_dict = merge_lora_into_weight(pretrained_text_encoder_state_dict, text_encoder_lora_state_dict, model_type='text_encoder', alpha=alpha)
    pipe.text_encoder.load_state_dict(updated_text_encoder_state_dict)

    return pipe, new_concept_cfg
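For context, a minimal usage sketch of `convert_edlora` (not part of the uploaded file): the merge simply applies W + alpha * (lora_up @ lora_down) to each matched layer and loads the result back into the pipeline. The checkpoint path, base model id, and alpha value below are placeholders; in the full Mix-of-Show setup the custom pipeline under mixofshow/pipelines is normally used instead of the plain diffusers one.

import torch
from diffusers import StableDiffusionPipeline
from mixofshow.utils.convert_edlora_to_diffusers import convert_edlora

# Placeholder paths/ids; any single-concept EDLoRA checkpoint produced by this repo should fit.
pipe = StableDiffusionPipeline.from_pretrained('runwayml/stable-diffusion-v1-5')
ckpt = torch.load('experiments/example_concept/models/edlora_model-latest.pth', map_location='cpu')

pipe, new_concept_cfg = convert_edlora(pipe, ckpt, enable_edlora=True, alpha=0.7)
print(new_concept_cfg)  # e.g. {'<concept>': {'concept_token_ids': [...], 'concept_token_names': [...]}}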
mixofshow/utils/ptp_util.py
ADDED
@@ -0,0 +1,200 @@
import abc
from typing import List, Tuple

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image


class EmptyControl:
    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionControl(abc.ABC):
    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return self.num_att_layers if self.low_resource else 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if self.low_resource:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                if self.training:
                    attn = self.forward(attn, is_cross, place_in_unet)
                else:
                    h = attn.shape[0]
                    attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)

        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self, low_resource, training):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0
        self.low_resource = low_resource
        self.training = training


class AttentionStore(AttentionControl):
    @staticmethod
    def get_empty_store():
        return {
            'down_cross': [],
            'mid_cross': [],
            'up_cross': [],
            'down_self': [],
            'mid_self': [],
            'up_self': []
        }

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        if len(self.attention_store) == 0:
            self.attention_store = self.step_store
        else:
            for key in self.attention_store:
                for i in range(len(self.attention_store[key])):
                    self.attention_store[key][i] = self.attention_store[key][i] + self.step_store[key][i]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = {
            key: [item / self.cur_step for item in self.attention_store[key]]
            for key in self.attention_store
        }
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    def __init__(self, low_resource=False, training=False):
        super(AttentionStore, self).__init__(low_resource, training)
        self.step_store = self.get_empty_store()
        self.attention_store = {}


def text_under_image(image: np.ndarray,
                     text: str,
                     text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02, notebook=True):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8)
              for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones(
        (h * num_rows + offset * (num_rows - 1), w * num_cols + offset *
         (num_cols - 1), 3),
        dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset):i * (h + offset) + h:, j * (w + offset):j *
                   (w + offset) + w] = images[i * num_cols + j]

    pil_img = Image.fromarray(image_)
    if notebook is True:
        display(pil_img)
    else:
        return pil_img


def aggregate_attention(attention_store: AttentionStore, res: int,
                        from_where: List[str], prompts: List[str],
                        is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res**2
    for location in from_where:
        for item in attention_maps[
                f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out.cpu()


def show_cross_attention(attention_store: AttentionStore,
                         res: int,
                         from_where: List[str],
                         prompts: List[str],
                         tokenizer,
                         select: int = 0,
                         notebook=True):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(attention_store, res, from_where, prompts, True, select)

    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)

    if notebook is True:
        view_images(np.stack(images, axis=0))
    else:
        return view_images(np.stack(images, axis=0), notebook=False)
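A small, self-contained sketch of the plotting helpers above (not part of the uploaded file; importing this module also pulls in IPython because of the `display` import, and the output filename is only an example):

import numpy as np
from mixofshow.utils.ptp_util import text_under_image, view_images

# Caption three dummy 256x256 gray images and tile them into a single row.
captioned = [text_under_image(np.full((256, 256, 3), 200, dtype=np.uint8), f'token {i}') for i in range(3)]
grid = view_images(np.stack(captioned, axis=0), num_rows=1, notebook=False)  # returns a PIL.Image
grid.save('ptp_grid_example.png')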
mixofshow/utils/registry.py
ADDED
@@ -0,0 +1,79 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501


class Registry():
    """
    The registry that provides name -> object mapping, to support third-party
    users' custom modules.

    To create a registry (e.g. a backbone registry):

    .. code-block:: python

        BACKBONE_REGISTRY = Registry('BACKBONE')

    To register an object:

    .. code-block:: python

        @BACKBONE_REGISTRY.register()
        class MyBackbone():
            ...

    Or:

    .. code-block:: python

        BACKBONE_REGISTRY.register(MyBackbone)
    """
    def __init__(self, name):
        """
        Args:
            name (str): the name of this registry
        """
        self._name = name
        self._obj_map = {}

    def _do_register(self, name, obj):
        assert (name not in self._obj_map), (
            f"An object named '{name}' was already registered "
            f"in '{self._name}' registry!")
        self._obj_map[name] = obj

    def register(self, obj=None):
        """
        Register the given object under the name `obj.__name__`.
        Can be used as either a decorator or not.
        See docstring of this class for usage.
        """
        if obj is None:
            # used as a decorator
            def deco(func_or_class):
                name = func_or_class.__name__
                self._do_register(name, func_or_class)
                return func_or_class

            return deco

        # used as a function call
        name = obj.__name__
        self._do_register(name, obj)

    def get(self, name):
        ret = self._obj_map.get(name)
        if ret is None:
            raise KeyError(
                f"No object named '{name}' found in '{self._name}' registry!")
        return ret

    def __contains__(self, name):
        return name in self._obj_map

    def __iter__(self):
        return iter(self._obj_map.items())

    def keys(self):
        return self._obj_map.keys()


TRANSFORM_REGISTRY = Registry('transform')
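Usage mirrors the fvcore original; a short sketch (the `IdentityTransform` class below is made up purely for illustration):

from mixofshow.utils.registry import TRANSFORM_REGISTRY

@TRANSFORM_REGISTRY.register()
class IdentityTransform:
    def __call__(self, x):
        return x

cls = TRANSFORM_REGISTRY.get('IdentityTransform')   # look up by class name
print('IdentityTransform' in TRANSFORM_REGISTRY)    # True
print(list(TRANSFORM_REGISTRY.keys()))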
mixofshow/utils/util.py
ADDED
@@ -0,0 +1,313 @@
import datetime
import logging
import os
import os.path
import os.path as osp
import time
from collections import OrderedDict

import PIL
import torch
from accelerate.logging import get_logger
from accelerate.state import PartialState
from PIL import Image, ImageDraw, ImageFont
from torchvision.transforms.transforms import ToTensor
from torchvision.utils import make_grid

NEGATIVE_PROMPT = 'longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality'


# ----------- file/logger util ----------
def get_time_str():
    return time.strftime('%Y%m%d_%H%M%S', time.localtime())


def mkdir_and_rename(path):
    """mkdirs. If path exists, rename it with timestamp and create a new one.

    Args:
        path (str): Folder path.
    """
    if osp.exists(path):
        new_name = path + '_archived_' + get_time_str()
        print(f'Path already exists. Rename it to {new_name}', flush=True)
        os.rename(path, new_name)
    os.makedirs(path, exist_ok=True)


def make_exp_dirs(opt):
    """Make dirs for experiments."""
    path_opt = opt['path'].copy()
    if opt['is_train']:
        mkdir_and_rename(path_opt.pop('experiments_root'))
    else:
        mkdir_and_rename(path_opt.pop('results_root'))
    for key, path in path_opt.items():
        if ('strict_load' in key) or ('pretrain_network' in key) or (
                'resume' in key) or ('param_key' in key) or ('lora_path' in key):
            continue
        else:
            os.makedirs(path, exist_ok=True)


def copy_opt_file(opt_file, experiments_root):
    # copy the yml file to the experiment root
    import sys
    import time
    from shutil import copyfile
    cmd = ' '.join(sys.argv)
    filename = osp.join(experiments_root, osp.basename(opt_file))
    copyfile(opt_file, filename)

    with open(filename, 'r+') as f:
        lines = f.readlines()
        lines.insert(
            0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
        f.seek(0)
        f.writelines(lines)


def set_path_logger(accelerator, root_path, config_path, opt, is_train=True):
    opt['is_train'] = is_train

    if is_train:
        experiments_root = osp.join(root_path, 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['log'] = experiments_root
        opt['path']['visualization'] = osp.join(experiments_root,
                                                'visualization')
    else:
        results_root = osp.join(root_path, 'results', opt['name'])
        opt['path']['results_root'] = results_root
        opt['path']['log'] = results_root
        opt['path']['visualization'] = osp.join(results_root, 'visualization')

    # Handle the output folder creation
    if accelerator.is_main_process:
        make_exp_dirs(opt)

    accelerator.wait_for_everyone()

    if is_train:
        copy_opt_file(config_path, opt['path']['experiments_root'])
        log_file = osp.join(opt['path']['log'],
                            f"train_{opt['name']}_{get_time_str()}.log")
        set_logger(log_file)
    else:
        copy_opt_file(config_path, opt['path']['results_root'])
        log_file = osp.join(opt['path']['log'],
                            f"test_{opt['name']}_{get_time_str()}.log")
        set_logger(log_file)


def set_logger(log_file=None):
    # Make one log on every process with the configuration for debugging.
    format_str = '%(asctime)s %(levelname)s: %(message)s'
    log_level = logging.INFO
    handlers = []

    file_handler = logging.FileHandler(log_file, 'w')
    file_handler.setFormatter(logging.Formatter(format_str))
    file_handler.setLevel(log_level)
    handlers.append(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(logging.Formatter(format_str))
    handlers.append(stream_handler)

    logging.basicConfig(handlers=handlers, level=log_level)


def dict2str(opt, indent_level=1):
    """dict to string for printing options.

    Args:
        opt (dict): Option dict.
        indent_level (int): Indent level. Default: 1.

    Return:
        (str): Option string for printing.
    """
    msg = '\n'
    for k, v in opt.items():
        if isinstance(v, dict):
            msg += ' ' * (indent_level * 2) + k + ':['
            msg += dict2str(v, indent_level + 1)
            msg += ' ' * (indent_level * 2) + ']\n'
        else:
            msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
    return msg


class MessageLogger():
    """Message logger for printing.

    Args:
        opt (dict): Config. It contains the following keys:
            name (str): Exp name.
            logger (dict): Contains 'print_freq' (str) for logger interval.
            train (dict): Contains 'total_iter' (int) for total iters.
            use_tb_logger (bool): Use tensorboard logger.
        start_iter (int): Start iter. Default: 1.
        tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
    """
    def __init__(self, opt, start_iter=1):
        self.exp_name = opt['name']
        self.interval = opt['logger']['print_freq']
        self.start_iter = start_iter
        self.max_iters = opt['train']['total_iter']
        self.start_time = time.time()
        self.logger = get_logger('mixofshow', log_level='INFO')

    def reset_start_time(self):
        self.start_time = time.time()

    def __call__(self, log_vars):
        """Format logging message.

        Args:
            log_vars (dict): It contains the following keys:
                epoch (int): Epoch number.
                iter (int): Current iter.
                lrs (list): List for learning rates.

                time (float): Iter time.
                data_time (float): Data time for each iter.
        """
        # epoch, iter, learning rates
        current_iter = log_vars.pop('iter')
        lrs = log_vars.pop('lrs')

        message = (
            f'[{self.exp_name[:5]}..][Iter:{current_iter:8,d}, lr:('
        )
        for v in lrs:
            message += f'{v:.3e},'
        message += ')] '

        # time and estimated time
        total_time = time.time() - self.start_time
        time_sec_avg = total_time / (current_iter - self.start_iter + 1)
        eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
        message += f'[eta: {eta_str}] '

        # other items, especially losses
        for k, v in log_vars.items():
            message += f'{k}: {v:.4e} '

        self.logger.info(message)


def reduce_loss_dict(accelerator, loss_dict):
    """reduce loss dict.

    In distributed training, it averages the losses among different GPUs.

    Args:
        loss_dict (OrderedDict): Loss dict.
    """
    with torch.no_grad():
        keys = []
        losses = []
        for name, value in loss_dict.items():
            keys.append(name)
            losses.append(value)
        losses = torch.stack(losses, 0)
        losses = accelerator.reduce(losses)

        world_size = PartialState().num_processes
        losses /= world_size

        loss_dict = {key: loss for key, loss in zip(keys, losses)}

        log_dict = OrderedDict()
        for name, value in loss_dict.items():
            log_dict[name] = value.mean().item()

        return log_dict


def pil_imwrite(img, file_path, auto_mkdir=True):
    """Write image to file.

    Args:
        img (ndarray): Image array to be written.
        file_path (str): Image file path.
        params (None or list): Same as opencv's :func:`imwrite` interface.
        auto_mkdir (bool): If the parent folder of `file_path` does not exist,
            whether to create it automatically.

    Returns:
        bool: Successful or not.
    """
    assert isinstance(
        img, PIL.Image.Image), 'model should return a list of PIL images'
    if auto_mkdir:
        dir_name = os.path.abspath(os.path.dirname(file_path))
        os.makedirs(dir_name, exist_ok=True)
    img.save(file_path)


def draw_prompt(text, height, width, font_size=45):
    img = Image.new('RGB', (width, height), (255, 255, 255))
    draw = ImageDraw.Draw(img)
    font = ImageFont.truetype(
        osp.join(osp.dirname(osp.abspath(__file__)), 'arial.ttf'), font_size)

    guess_count = 0

    while font.font.getsize(text[:guess_count])[0][
            0] + 0.1 * width < width - 0.1 * width and guess_count < len(
                text):  # centerize
        guess_count += 1

    text_new = ''
    for idx, s in enumerate(text):
        if idx % guess_count == 0:
            text_new += '\n'
            if s == ' ':
                s = ''  # a new line strips the leading space
        text_new += s

    draw.text([int(0.1 * width), int(0.3 * height)],
              text_new,
              font=font,
              fill='black')
    return img


def compose_visualize(dir_path):
    file_list = sorted(os.listdir(dir_path))
    img_list = []
    info_dict = {'prompts': set(), 'sample_args': set(), 'suffix': set()}
    for filename in file_list:
        prompt, sample_args, index, suffix = osp.splitext(
            osp.basename(filename))[0].split('---')

        filepath = osp.join(dir_path, filename)
        img = ToTensor()(Image.open(filepath))
        height, width = img.shape[1:]

        if prompt not in info_dict['prompts']:
            img_list.append(ToTensor()(draw_prompt(prompt,
                                                   height=height,
                                                   width=width,
                                                   font_size=45)))
            info_dict['prompts'].add(prompt)
        info_dict['sample_args'].add(sample_args)
        info_dict['suffix'].add(suffix)

        img_list.append(img)
    assert len(
        info_dict['sample_args']
    ) == 1, 'compose dir should contain images from the same sample args.'
    assert len(info_dict['suffix']
               ) == 1, 'compose dir should contain images from the same suffix.'

    grid = make_grid(img_list, nrow=len(img_list) // len(info_dict['prompts']))
    # Add 0.5 after unnormalizing to [0, 255] to round to nearest integer
    ndarr = grid.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to(
        'cpu', torch.uint8).numpy()
    im = Image.fromarray(ndarr)
    save_name = f"{info_dict['sample_args'].pop()}---{info_dict['suffix'].pop()}.jpg"
    im.save(osp.join(osp.dirname(dir_path), save_name))
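A quick sketch of the lighter helpers in this module (not part of the uploaded file; the output path is a placeholder):

from PIL import Image
from mixofshow.utils.util import NEGATIVE_PROMPT, dict2str, pil_imwrite

opt = {'name': 'demo', 'train': {'total_iter': 1000, 'lr': 1e-4}}
print(dict2str(opt))        # nested option dict rendered as an indented string
print(NEGATIVE_PROMPT)      # module-level default negative prompt defined above

img = Image.new('RGB', (64, 64), (255, 255, 255))
pil_imwrite(img, 'tmp/demo/white.png')  # parent folders are created automatically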
orthogonal_mats/1280.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8cacbfe6dda3140404a86019e487349f7b693667cd7efe5848fe9fc04b1a3618
size 13107328
orthogonal_mats/320.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:99b8a3b3cf101ac83f43eda74392447207bdc4745e8e77626219d994ea8f2ae9
size 819328
orthogonal_mats/640.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:0c3c38d378cf4c22f7ab4cd3fb734c846afd4500ca58f253cb63c44304c360aa
size 3276928
orthogonal_mats/768.npy
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64bd95b552140a47665e46cb5736cd8562e40b5f27f0f262a4b7563d13061daf
size 4718720
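The four .npy files are Git LFS pointers; judging by the folder name and the file sizes (e.g. 320*320*8 bytes of float64 plus a 128-byte header = 819328), each one appears to hold a square matrix. A hedged sanity check, assuming `git lfs pull` has already fetched the real files:

import numpy as np

Q = np.load('orthogonal_mats/320.npy')
print(Q.shape, Q.dtype)                                     # expected: (320, 320) float64
print(np.allclose(Q @ Q.T, np.eye(Q.shape[0]), atol=1e-6))  # True if the matrix is orthogonal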