kamwoh committed on
Commit
617065a
1 Parent(s): d5e1f83

copied from dreamcreature main repo

Browse files
.gitignore ADDED
@@ -0,0 +1,8 @@
+ .idea/
+ src/data/cub200_2011/train
+ src/data/dogs/Images
+ __pycache__
+ */.ipynb_checkpoints
+ /.ipynb_checkpoints/requirements-checkpoint.txt
+
+ *.bin
__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,215 @@
1
+ import gc
2
+ import os
3
+ import re
4
+ import shutil
5
+
6
+ import gradio as gr
7
+ import requests
8
+ import torch
9
+
10
+ from dreamcreature.pipeline import create_args, load_pipeline
11
+
12
+ CUB_DESCRIPTION = """
13
+ # DreamCreature (CUB-200-2011)
14
+ To create your own creature, you can type:
15
+
16
+ `"a photo of a <head:id> <wing:id> bird"` where `id` ranges from 1~200 (200 classes corresponding to CUB-200-2011)
17
+
18
+ For instance `"a photo of a <head:17> <wing:18> bird"` using head of `cardinal (17)` and wing of `spotted catbird (18)`
19
+
20
+ Please see `id` in https://github.com/kamwoh/dreamcreature/blob/master/src/data/cub200_2011/class_names.txt
21
+
22
+ You can also try any prompt you like such as:
23
+
24
+ Sub-concept transfer: `"a photo of a <wing:17> cat"`
25
+
26
+ Inspiring design: `"a photo of a <head:101> <wing:191> teddy bear"`
27
+
28
+ (Experimental) You can also mix two sub-concepts of the same part, such as:
29
+
30
+ `"a photo of a <head:17> <head:18> bird"` mixing head of `cardinal (17)` and `spotted catbird (18)`
31
+
32
+ The current available parts are: `head`, `body`, `wing`, `tail`, and `leg`
33
+
34
+ """
35
+
36
+ DOG_DESCRIPTION = """
37
+ # DreamCreature (Stanford Dogs)
38
+ To create your own creature, you can type:
39
+
40
+ `"a photo of a <nose:id> <ear:id> dog"` where `id` ranges from 0~119 (120 classes corresponding to Stanford Dogs)
41
+
42
+ For instance `"a photo of a <nose:2> <ear:112> dog"` using head of `maltese dog (2)` and wing of `cardigan (112)`
43
+
44
+ Please see `id` in https://github.com/kamwoh/dreamcreature/blob/master/src/data/dogs/class_names.txt
45
+
46
+ Sub-concept transfer: `"a photo of a <ear:112> cat"`
47
+
48
+ Inspiring design: `"a photo of a <eye:38> <body:38> teddy bear"`
49
+
50
+ (Experimental) You can also mix two sub-concepts of the same part, such as:
51
+
52
+ `"a photo of a <nose:1> <nose:112> dog"` mixing head of `maltese dog (2)` and `spotted cardigan (112)`
53
+
54
+ The current available parts are: `eye`, `neck`, `ear`, `body`, `leg`, `nose` and `forehead`
55
+
56
+ """
57
+
58
+
59
+ def prepare_pipeline(model_name):
60
+ is_cub = 'cub' in model_name
61
+
62
+ checkpoint_name = {
63
+ 'dreamcreature-sd1.5-cub200': 'checkpoint-74900',
64
+ 'dreamcreature-sd1.5-dog': 'checkpoint-150000'
65
+ }[model_name]  # select the checkpoint folder for the chosen model
66
+
67
+ repo_url = f"https://huggingface.co/kamwoh/{model_name}/resolve/main"
68
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model.bin"
69
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model.bin"
70
+ os.makedirs(f"{model_name}/{checkpoint_name}", exist_ok=True)
71
+ download_file(file_url, local_path)
72
+
73
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model_1.bin"
74
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model_1.bin"
75
+ download_file(file_url, local_path)
76
+
77
+ OUTPUT_DIR = model_name
78
+
79
+ args = create_args(OUTPUT_DIR)
80
+ if 'dpo' in OUTPUT_DIR:
81
+ args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
82
+
83
+ pipe = load_pipeline(args, torch.float16, 'cuda')
84
+ pipe = pipe.to(torch.float16)
85
+
86
+ pipe.verbose = True
87
+ pipe.v = 're'
88
+
89
+ if is_cub:
90
+ pipe.num_k_per_part = 200
91
+
92
+ MAPPING = {
93
+ 'body': 0,
94
+ 'tail': 1,
95
+ 'head': 2,
96
+ 'wing': 4,
97
+ 'leg': 6
98
+ }
99
+
100
+ ID2NAME = open('data/cub200_2011/class_names.txt').readlines()
101
+ ID2NAME = [line.strip() for line in ID2NAME]
102
+
103
+ else:
104
+ pipe.num_k_per_part = 120
105
+
106
+ MAPPING = {
107
+ 'eye': 0,
108
+ 'neck': 2,
109
+ 'ear': 3,
110
+ 'body': 4,
111
+ 'leg': 5,
112
+ 'nose': 6,
113
+ 'forehead': 7
114
+ }
115
+
116
+ ID2NAME = open('data/dogs/class_names.txt').readlines()
117
+ ID2NAME = [line.strip() for line in ID2NAME]
118
+
119
+ return pipe, MAPPING, ID2NAME
120
+
121
+
122
+ def download_file(url, local_path):
123
+ if os.path.exists(local_path):
124
+ return
125
+
126
+ with requests.get(url, stream=True) as r:
127
+ with open(local_path, 'wb') as f:
128
+ shutil.copyfileobj(r.raw, f)
129
+
130
+
131
+ def process_text(text, MAPPING, ID2NAME):
132
+ pattern = r"<([^:>]+):(\d+)>"
133
+ result = text
134
+ offset = 0
135
+
136
+ part2id = []
137
+
138
+ for match in re.finditer(pattern, text):
139
+ key = match.group(1)
140
+ clsid = int(match.group(2))
141
+ clsid = min(max(clsid, 1), len(ID2NAME)) # clamp to a valid class id (1 ~ number of classes)
142
+
143
+ replacement = f"<{MAPPING[key]}:{clsid - 1}>"
144
+ start, end = match.span()
145
+
146
+ # Adjust the start and end positions based on the offset from previous replacements
147
+ start += offset
148
+ end += offset
149
+
150
+ # Replace the matched text with the replacement
151
+ result = result[:start] + replacement + result[end:]
152
+
153
+ # Update the offset for the next replacement
154
+ offset += len(replacement) - (end - start)
155
+
156
+ part2id.append(f'{key}: {ID2NAME[clsid - 1]}')
157
+
158
+ return result, part2id
159
+
160
+
161
+ def generate_images(model_name, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
162
+ generator = torch.Generator(device='cuda')
163
+ generator = generator.manual_seed(int(seed))
164
+
165
+ try:
166
+ pipe, MAPPING, ID2NAME = prepare_pipeline(model_name)
167
+
168
+ prompt, part2id = process_text(prompt, MAPPING, ID2NAME)
169
+ negative_prompt, _ = process_text(negative_prompt, MAPPING, ID2NAME)
170
+
171
+ images = pipe(prompt,
172
+ negative_prompt=negative_prompt, generator=generator,
173
+ num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
174
+ num_images_per_prompt=num_images).images
175
+
176
+ del pipe
177
+ except Exception as e:
178
+ raise gr.Error(f"Probably due to the prompt have invalid input, please follow the instruction. "
179
+ f"The error message: {e}")
180
+ finally:
181
+ gc.collect()
182
+ torch.cuda.empty_cache()
183
+
184
+ return images, '; '.join(part2id)
185
+
186
+
187
+ with gr.Blocks(title="DreamCreature") as demo:
188
+ with gr.Row():
189
+ main_desc = gr.Markdown(CUB_DESCRIPTION)
190
+ with gr.Column():
191
+ with gr.Row():
192
+ with gr.Group():
193
+ dropdown = gr.Dropdown(choices=["dreamcreature-sd1.5-cub200",
194
+ "dreamcreature-sd1.5-dog"],
195
+ value="dreamcreature-sd1.5-cub200")
196
+ prompt = gr.Textbox(label="Prompt", value="a photo of a <head:101> <wing:191> teddy bear")
197
+ negative_prompt = gr.Textbox(label="Negative Prompt",
198
+ value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic")
199
+ num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Num Inference Steps")
200
+ guidance_scale = gr.Slider(minimum=2, maximum=20, step=0.1, value=7.5, label="Guidance Scale")
201
+ num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
202
+ seed = gr.Number(label="Seed", value=777881414)
203
+ button = gr.Button()
204
+
205
+ with gr.Column():
206
+ output_images = gr.Gallery(columns=4, label='Output')
207
+ markdown_labels = gr.Markdown("")
208
+
209
+ dropdown.change(fn=lambda x: {'dreamcreature-sd1.5-cub200': CUB_DESCRIPTION,
210
+ 'dreamcreature-sd1.5-dog': DOG_DESCRIPTION}[x], inputs=dropdown, outputs=main_desc)
211
+ button.click(fn=generate_images,
212
+ inputs=[dropdown, prompt, negative_prompt, num_inference_steps, guidance_scale, num_images,
213
+ seed], outputs=[output_images, markdown_labels], show_progress=True)
214
+
215
+ demo.queue().launch(inline=False, share=True, debug=True, server_name='0.0.0.0')
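
For readers skimming the diff: the `<part:id>` syntax described in `CUB_DESCRIPTION` is rewritten by `process_text` above before it reaches the pipeline. Below is a minimal standalone sketch of that rewriting (it uses `re.sub` instead of the manual offset bookkeeping; the two class names are an illustrative subset of `class_names.txt`):

```python
import re

# Part-name -> token-index mapping for the CUB model, as defined in app.py above.
MAPPING = {'body': 0, 'tail': 1, 'head': 2, 'wing': 4, 'leg': 6}
ID2NAME = {17: 'cardinal', 18: 'spotted catbird'}  # illustrative subset of class_names.txt


def rewrite(prompt: str) -> str:
    """Rewrite e.g. "<head:17>" into "<2:16>": part name -> index, 1-based id -> 0-based code."""

    def repl(m: re.Match) -> str:
        part, clsid = m.group(1), int(m.group(2))
        return f"<{MAPPING[part]}:{clsid - 1}>"

    return re.sub(r"<([^:>]+):(\d+)>", repl, prompt)


print(rewrite("a photo of a <head:17> <wing:18> bird"))
# a photo of a <2:16> <4:17> bird  (head of cardinal, wing of spotted catbird)
```

The resulting `<index:code>` form is what `DreamCreatureSDPipeline` (added further down in this commit) consumes.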
data/cub200_2011/class_names.txt ADDED
@@ -0,0 +1,200 @@
1
+ black footed albatross
2
+ laysan albatross
3
+ sooty albatross
4
+ groove billed ani
5
+ crested auklet
6
+ least auklet
7
+ parakeet auklet
8
+ rhinoceros auklet
9
+ brewer blackbird
10
+ red winged blackbird
11
+ rusty blackbird
12
+ yellow headed blackbird
13
+ bobolink
14
+ indigo bunting
15
+ lazuli bunting
16
+ painted bunting
17
+ cardinal
18
+ spotted catbird
19
+ gray catbird
20
+ yellow breasted chat
21
+ eastern towhee
22
+ chuck will widow
23
+ brandt cormorant
24
+ red faced cormorant
25
+ pelagic cormorant
26
+ bronzed cowbird
27
+ shiny cowbird
28
+ brown creeper
29
+ american crow
30
+ fish crow
31
+ black billed cuckoo
32
+ mangrove cuckoo
33
+ yellow billed cuckoo
34
+ gray crowned rosy finch
35
+ purple finch
36
+ northern flicker
37
+ acadian flycatcher
38
+ great crested flycatcher
39
+ least flycatcher
40
+ olive sided flycatcher
41
+ scissor tailed flycatcher
42
+ vermilion flycatcher
43
+ yellow bellied flycatcher
44
+ frigatebird
45
+ northern fulmar
46
+ gadwall
47
+ american goldfinch
48
+ european goldfinch
49
+ boat tailed grackle
50
+ eared grebe
51
+ horned grebe
52
+ pied billed grebe
53
+ western grebe
54
+ blue grosbeak
55
+ evening grosbeak
56
+ pine grosbeak
57
+ rose breasted grosbeak
58
+ pigeon guillemot
59
+ california gull
60
+ glaucous winged gull
61
+ heermann gull
62
+ herring gull
63
+ ivory gull
64
+ ring billed gull
65
+ slaty backed gull
66
+ western gull
67
+ anna hummingbird
68
+ ruby throated hummingbird
69
+ rufous hummingbird
70
+ green violetear
71
+ long tailed jaeger
72
+ pomarine jaeger
73
+ blue jay
74
+ florida jay
75
+ green jay
76
+ dark eyed junco
77
+ tropical kingbird
78
+ gray kingbird
79
+ belted kingfisher
80
+ green kingfisher
81
+ pied kingfisher
82
+ ringed kingfisher
83
+ white breasted kingfisher
84
+ red legged kittiwake
85
+ horned lark
86
+ pacific loon
87
+ mallard
88
+ western meadowlark
89
+ hooded merganser
90
+ red breasted merganser
91
+ mockingbird
92
+ nighthawk
93
+ clark nutcracker
94
+ white breasted nuthatch
95
+ baltimore oriole
96
+ hooded oriole
97
+ orchard oriole
98
+ scott oriole
99
+ ovenbird
100
+ brown pelican
101
+ white pelican
102
+ western wood pewee
103
+ sayornis
104
+ american pipit
105
+ whip poor will
106
+ horned puffin
107
+ common raven
108
+ white necked raven
109
+ american redstart
110
+ geococcyx
111
+ loggerhead shrike
112
+ great grey shrike
113
+ baird sparrow
114
+ black throated sparrow
115
+ brewer sparrow
116
+ chipping sparrow
117
+ clay colored sparrow
118
+ house sparrow
119
+ field sparrow
120
+ fox sparrow
121
+ grasshopper sparrow
122
+ harris sparrow
123
+ henslow sparrow
124
+ le conte sparrow
125
+ lincoln sparrow
126
+ nelson sharp tailed sparrow
127
+ savannah sparrow
128
+ seaside sparrow
129
+ song sparrow
130
+ tree sparrow
131
+ vesper sparrow
132
+ white crowned sparrow
133
+ white throated sparrow
134
+ cape glossy starling
135
+ bank swallow
136
+ barn swallow
137
+ cliff swallow
138
+ tree swallow
139
+ scarlet tanager
140
+ summer tanager
141
+ artic tern
142
+ black tern
143
+ caspian tern
144
+ common tern
145
+ elegant tern
146
+ forsters tern
147
+ least tern
148
+ green tailed towhee
149
+ brown thrasher
150
+ sage thrasher
151
+ black capped vireo
152
+ blue headed vireo
153
+ philadelphia vireo
154
+ red eyed vireo
155
+ warbling vireo
156
+ white eyed vireo
157
+ yellow throated vireo
158
+ bay breasted warbler
159
+ black and white warbler
160
+ black throated blue warbler
161
+ blue winged warbler
162
+ canada warbler
163
+ cape may warbler
164
+ cerulean warbler
165
+ chestnut sided warbler
166
+ golden winged warbler
167
+ hooded warbler
168
+ kentucky warbler
169
+ magnolia warbler
170
+ mourning warbler
171
+ myrtle warbler
172
+ nashville warbler
173
+ orange crowned warbler
174
+ palm warbler
175
+ pine warbler
176
+ prairie warbler
177
+ prothonotary warbler
178
+ swainson warbler
179
+ tennessee warbler
180
+ wilson warbler
181
+ worm eating warbler
182
+ yellow warbler
183
+ northern waterthrush
184
+ louisiana waterthrush
185
+ bohemian waxwing
186
+ cedar waxwing
187
+ american three toed woodpecker
188
+ pileated woodpecker
189
+ red bellied woodpecker
190
+ red cockaded woodpecker
191
+ red headed woodpecker
192
+ downy woodpecker
193
+ bewick wren
194
+ cactus wren
195
+ carolina wren
196
+ house wren
197
+ marsh wren
198
+ rock wren
199
+ winter wren
200
+ common yellowthroat
data/cub200_2011/pretrained_kmeans.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:00b2ff84f80daa3cdbd4b18e4088fd900b70a8a192a70957c34e4369d6065e65
+ size 6874495
data/cub200_2011/train.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/cub200_2011/train_caps_better_m8_k256.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/dogs/class_names.txt ADDED
@@ -0,0 +1,120 @@
1
+ Chihuahua
2
+ Japanese spaniel
3
+ Maltese dog
4
+ Pekinese
5
+ Shih Tzu
6
+ Blenheim spaniel
7
+ papillon
8
+ toy terrier
9
+ Rhodesian ridgeback
10
+ Afghan hound
11
+ basset
12
+ beagle
13
+ bloodhound
14
+ bluetick
15
+ black and tan coonhound
16
+ Walker hound
17
+ English foxhound
18
+ redbone
19
+ borzoi
20
+ Irish wolfhound
21
+ Italian greyhound
22
+ whippet
23
+ Ibizan hound
24
+ Norwegian elkhound
25
+ otterhound
26
+ Saluki
27
+ Scottish deerhound
28
+ Weimaraner
29
+ Staffordshire bullterrier
30
+ American Staffordshire terrier
31
+ Bedlington terrier
32
+ Border terrier
33
+ Kerry blue terrier
34
+ Irish terrier
35
+ Norfolk terrier
36
+ Norwich terrier
37
+ Yorkshire terrier
38
+ wire haired fox terrier
39
+ Lakeland terrier
40
+ Sealyham terrier
41
+ Airedale
42
+ cairn
43
+ Australian terrier
44
+ Dandie Dinmont
45
+ Boston bull
46
+ miniature schnauzer
47
+ giant schnauzer
48
+ standard schnauzer
49
+ Scotch terrier
50
+ Tibetan terrier
51
+ silky terrier
52
+ soft coated wheaten terrier
53
+ West Highland white terrier
54
+ Lhasa
55
+ flat coated retriever
56
+ curly coated retriever
57
+ golden retriever
58
+ Labrador retriever
59
+ Chesapeake Bay retriever
60
+ German short haired pointer
61
+ vizsla
62
+ English setter
63
+ Irish setter
64
+ Gordon setter
65
+ Brittany spaniel
66
+ clumber
67
+ English springer
68
+ Welsh springer spaniel
69
+ cocker spaniel
70
+ Sussex spaniel
71
+ Irish water spaniel
72
+ kuvasz
73
+ schipperke
74
+ groenendael
75
+ malinois
76
+ briard
77
+ kelpie
78
+ komondor
79
+ Old English sheepdog
80
+ Shetland sheepdog
81
+ collie
82
+ Border collie
83
+ Bouvier des Flandres
84
+ Rottweiler
85
+ German shepherd
86
+ Doberman
87
+ miniature pinscher
88
+ Greater Swiss Mountain dog
89
+ Bernese mountain dog
90
+ Appenzeller
91
+ EntleBucher
92
+ boxer
93
+ bull mastiff
94
+ Tibetan mastiff
95
+ French bulldog
96
+ Great Dane
97
+ Saint Bernard
98
+ Eskimo dog
99
+ malamute
100
+ Siberian husky
101
+ affenpinscher
102
+ basenji
103
+ pug
104
+ Leonberg
105
+ Newfoundland
106
+ Great Pyrenees
107
+ Samoyed
108
+ Pomeranian
109
+ chow
110
+ keeshond
111
+ Brabancon griffon
112
+ Pembroke
113
+ Cardigan
114
+ toy poodle
115
+ miniature poodle
116
+ standard poodle
117
+ Mexican hairless
118
+ dingo
119
+ dhole
120
+ African hunting dog
data/dogs/pretrained_kmeans.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:1bf723669b2d6dad50d58a6d7b3dad7fafa6b49d8a3fca3fd5b713662ccd4b88
+ size 6874495
data/dogs/train.txt ADDED
The diff for this file is too large to render. See raw diff
 
data/dogs/train_caps_better_m8_k256.txt ADDED
The diff for this file is too large to render. See raw diff
 
dreamcreature/__init__.py ADDED
File without changes
dreamcreature/attn_processor.py ADDED
@@ -0,0 +1,143 @@
1
+ from diffusers.models.attention_processor import *
2
+
3
+
4
+ class LoRAAttnProcessorCustom(nn.Module, AttnProcessor):
5
+ r"""
6
+ Processor for implementing the LoRA attention mechanism.
7
+
8
+ Args:
9
+ hidden_size (`int`, *optional*):
10
+ The hidden size of the attention layer.
11
+ cross_attention_dim (`int`, *optional*):
12
+ The number of channels in the `encoder_hidden_states`.
13
+ rank (`int`, defaults to 4):
14
+ The dimension of the LoRA update matrices.
15
+ network_alpha (`int`, *optional*):
16
+ Equivalent to `alpha` but its usage is specific to Kohya (A1111) style LoRAs.
17
+ """
18
+
19
+ def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None, **kwargs):
20
+ super().__init__()
21
+
22
+ self.hidden_size = hidden_size
23
+ self.cross_attention_dim = cross_attention_dim
24
+ self.rank = rank
25
+
26
+ q_rank = kwargs.pop("q_rank", None)
27
+ q_hidden_size = kwargs.pop("q_hidden_size", None)
28
+ q_rank = q_rank if q_rank is not None else rank
29
+ q_hidden_size = q_hidden_size if q_hidden_size is not None else hidden_size
30
+
31
+ v_rank = kwargs.pop("v_rank", None)
32
+ v_hidden_size = kwargs.pop("v_hidden_size", None)
33
+ v_rank = v_rank if v_rank is not None else rank
34
+ v_hidden_size = v_hidden_size if v_hidden_size is not None else hidden_size
35
+
36
+ out_rank = kwargs.pop("out_rank", None)
37
+ out_hidden_size = kwargs.pop("out_hidden_size", None)
38
+ out_rank = out_rank if out_rank is not None else rank
39
+ out_hidden_size = out_hidden_size if out_hidden_size is not None else hidden_size
40
+
41
+ self.to_q_lora = LoRALinearLayer(q_hidden_size, q_hidden_size, q_rank, network_alpha)
42
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
43
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or v_hidden_size, v_hidden_size, v_rank, network_alpha)
44
+ self.to_out_lora = LoRALinearLayer(out_hidden_size, out_hidden_size, out_rank, network_alpha)
45
+
46
+ def __call__(self, attn: Attention, hidden_states, *args, **kwargs):
47
+ self_cls_name = self.__class__.__name__
48
+ deprecate(
49
+ self_cls_name,
50
+ "0.26.0",
51
+ (
52
+ f"Make sure use {self_cls_name[4:]} instead by setting"
53
+ "LoRA layers to `self.{to_q,to_k,to_v,to_out[0]}.lora_layer` respectively. This will be done automatically when using"
54
+ " `LoraLoaderMixin.load_lora_weights`"
55
+ ),
56
+ )
57
+ attn.to_q.lora_layer = self.to_q_lora.to(hidden_states.device)
58
+ attn.to_k.lora_layer = self.to_k_lora.to(hidden_states.device)
59
+ attn.to_v.lora_layer = self.to_v_lora.to(hidden_states.device)
60
+ attn.to_out[0].lora_layer = self.to_out_lora.to(hidden_states.device)
61
+
62
+ attn._modules.pop("processor")
63
+ attn.processor = AttnProcessorCustom(16)
64
+ return attn.processor(attn, hidden_states, *args, **kwargs)
65
+
66
+
67
+ class AttnProcessorCustom(AttnProcessor):
68
+ r"""
69
+ Default processor for performing attention-related computations.
70
+ """
71
+
72
+ def __init__(self, attn_size):
73
+ self.attn_size = attn_size
74
+
75
+ def __call__(
76
+ self,
77
+ attn: Attention,
78
+ hidden_states,
79
+ encoder_hidden_states=None,
80
+ attention_mask=None,
81
+ temb=None,
82
+ scale=1.0,
83
+ ):
84
+ residual = hidden_states
85
+
86
+ args = () if USE_PEFT_BACKEND else (scale,)
87
+
88
+ if attn.spatial_norm is not None:
89
+ hidden_states = attn.spatial_norm(hidden_states, temb)
90
+
91
+ input_ndim = hidden_states.ndim
92
+
93
+ if input_ndim == 4:
94
+ batch_size, channel, height, width = hidden_states.shape
95
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
96
+
97
+ batch_size, sequence_length, _ = (
98
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
99
+ )
100
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
101
+
102
+ if attn.group_norm is not None:
103
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
104
+
105
+ query = attn.to_q(hidden_states, *args)
106
+
107
+ if encoder_hidden_states is None:
108
+ encoder_hidden_states = hidden_states
109
+ elif attn.norm_cross:
110
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
111
+
112
+ key = attn.to_k(encoder_hidden_states, *args)
113
+ value = attn.to_v(encoder_hidden_states, *args)
114
+
115
+ query = attn.head_to_batch_dim(query)
116
+ key = attn.head_to_batch_dim(key)
117
+ value = attn.head_to_batch_dim(value)
118
+
119
+ attn_size = self.attn_size
120
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
121
+ if attention_probs.size(2) == 77 and attention_probs.size(1) == (attn_size * attn_size): # (B*Head,HW,L)
122
+ attn_probs_cache = attention_probs.reshape(batch_size, -1, attn_size, attn_size, 77)
123
+ attn.attn_probs = attn_probs_cache
124
+ else:
125
+ attn.attn_probs = None
126
+
127
+ hidden_states = torch.bmm(attention_probs, value)
128
+ hidden_states = attn.batch_to_head_dim(hidden_states)
129
+
130
+ # linear proj
131
+ hidden_states = attn.to_out[0](hidden_states, *args)
132
+ # dropout
133
+ hidden_states = attn.to_out[1](hidden_states)
134
+
135
+ if input_ndim == 4:
136
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
137
+
138
+ if attn.residual_connection:
139
+ hidden_states = hidden_states + residual
140
+
141
+ hidden_states = hidden_states / attn.rescale_output_factor
142
+
143
+ return hidden_states
dreamcreature/dataset.py ADDED
@@ -0,0 +1,155 @@
1
+ import os
2
+ import random
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from torch.utils.data import Dataset
8
+
9
+
10
+ class ImageDataset(Dataset):
11
+
12
+ def __init__(self,
13
+ rootdir,
14
+ filename='train.txt',
15
+ path_prefix='',
16
+ transform=None,
17
+ target_transform=None):
18
+ super().__init__()
19
+
20
+ self.rootdir = rootdir
21
+ self.filename = filename
22
+ self.path_prefix = path_prefix
23
+
24
+ self.image_paths = []
25
+ self.image_labels = []
26
+
27
+ filename = os.path.join(self.rootdir, self.filename)
28
+
29
+ with open(filename, 'r') as f:
30
+ while True:
31
+ lines = f.readline()
32
+ if not lines:
33
+ break
34
+
35
+ lines = lines.strip()
36
+ split_lines = lines.split(' ')
37
+ path_tmp = split_lines[0]
38
+ label_tmp = split_lines[1:]
39
+ self.is_onehot = len(label_tmp) != 1
40
+ if not self.is_onehot:
41
+ label_tmp = label_tmp[0]
42
+ self.image_paths.append(path_tmp)
43
+ self.image_labels.append(label_tmp)
44
+
45
+ self.image_paths = np.array(self.image_paths)
46
+ self.image_labels = np.array(self.image_labels, dtype=np.float32)
47
+
48
+ self.transform = transform
49
+ self.target_transform = target_transform
50
+
51
+ def __getitem__(self, index):
52
+ """
53
+ Args:
54
+ index (int): Index
55
+ Returns:
56
+ tuple: (image, target) where target is index of the target class.
57
+ """
58
+ path, target = self.image_paths[index], self.image_labels[index]
59
+ target = torch.tensor(target)
60
+
61
+ img = Image.open(f'{self.path_prefix}{path}').convert('RGB')
62
+
63
+ if self.transform is not None:
64
+ img = self.transform(img)
65
+
66
+ if self.target_transform is not None:
67
+ target = self.target_transform(target)
68
+
69
+ return img, target, index
70
+
71
+ def __len__(self):
72
+ return len(self.image_paths)
73
+
74
+
75
+ class DreamCreatureDataset(ImageDataset):
76
+
77
+ def __init__(self,
78
+ rootdir,
79
+ filename='train.txt',
80
+ path_prefix='',
81
+ code_filename='train_caps.txt',
82
+ num_parts=8, num_k_per_part=256, repeat=1,
83
+ use_gt_label=False,
84
+ bg_code=7,
85
+ transform=None,
86
+ target_transform=None):
87
+ super().__init__(rootdir, filename, path_prefix, transform, target_transform)
88
+
89
+ self.image_codes = np.array(open(rootdir + '/' + code_filename).readlines())
90
+ self.num_parts = num_parts
91
+ self.num_k_per_part = num_k_per_part
92
+ self.repeat = repeat
93
+ self.use_gt_label = use_gt_label
94
+ self.bg_code = bg_code
95
+
96
+ def filter_by_class(self, target):
97
+ target_mask = self.image_labels == target
98
+ self.image_paths = self.image_paths[target_mask]
99
+ self.image_codes = self.image_codes[target_mask]
100
+ self.image_labels = self.image_labels[target_mask]
101
+
102
+ def set_max_samples(self, n, seed):
103
+ np.random.seed(seed)
104
+ rand_idx = np.arange(len(self.image_paths))
105
+ np.random.shuffle(rand_idx)
106
+
107
+ self.image_paths = self.image_paths[rand_idx[:n]]
108
+ self.image_codes = self.image_codes[rand_idx[:n]]
109
+ self.image_labels = self.image_labels[rand_idx[:n]]
110
+
111
+ def __len__(self):
112
+ return len(self.image_paths) * self.repeat
113
+
114
+ def __getitem__(self, index):
115
+ """
116
+ Args:
117
+ index (int): Index
118
+ Returns:
119
+ tuple: (image, target) where target is index of the target class.
120
+ """
121
+ index = index % len(self.image_paths)
122
+ path, target = self.image_paths[index], self.image_labels[index]
123
+ target = torch.tensor(target)
124
+
125
+ img = Image.open(f'{self.path_prefix}{path}').convert('RGB')
126
+
127
+ cap = self.image_codes[index].strip()
128
+
129
+ if self.transform is not None:
130
+ img = self.transform(img)
131
+
132
+ if self.target_transform is not None:
133
+ target = self.target_transform(target)
134
+
135
+ appeared = []
136
+
137
+ code = torch.ones(self.num_parts) * self.num_k_per_part # num_k_per_part marks a part that is not present
138
+ splits = cap.strip().replace('.', '').split(' ')
139
+ for c in splits:
140
+ idx, intval = c.split(':')
141
+ appeared.append(int(idx))
142
+ if self.use_gt_label and self.bg_code != int(idx):
143
+ code[int(idx)] = target
144
+ else:
145
+ code[int(idx)] = int(intval)
146
+
147
+ example = {
148
+ 'pixel_values': img,
149
+ 'captions': cap,
150
+ 'codes': code,
151
+ 'labels': target,
152
+ 'appeared': appeared
153
+ }
154
+
155
+ return example
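
The caption files (`train_caps_better_m8_k256.txt`) store one space-separated list of `part:code` pairs per image. A small self-contained sketch of how `DreamCreatureDataset.__getitem__` above turns such a line into the fixed-length `codes` tensor (the caption string here is made up; 8 parts and 256 codes per part follow the `m8_k256` naming):

```python
import torch

# How DreamCreatureDataset.__getitem__ parses a caption line into a code tensor.
num_parts, num_k_per_part = 8, 256   # from the m8_k256 caption files
cap = "0:12 2:5 7:201"               # made-up example: part 0 -> code 12, part 2 -> 5, part 7 -> 201

code = torch.ones(num_parts) * num_k_per_part  # 256 marks "part not present"
appeared = []
for c in cap.strip().replace('.', '').split(' '):
    idx, intval = c.split(':')
    appeared.append(int(idx))
    code[int(idx)] = int(intval)

print(code.tolist())  # [12.0, 256.0, 5.0, 256.0, 256.0, 256.0, 256.0, 201.0]
print(appeared)       # [0, 2, 7]
```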
dreamcreature/dino.py ADDED
@@ -0,0 +1,42 @@
1
+ import torch.nn as nn
2
+ from transformers import AutoModel, AutoImageProcessor
3
+
4
+
5
+ class DINO(nn.Module):
6
+ NAMES = {
7
+ 'dino': 'facebook/dino-vits16',
8
+ 'dinov2': 'facebook/dinov2-base'
9
+ }
10
+
11
+ def __init__(self, name='facebook/dinov2-base', **kwargs):
12
+ super().__init__()
13
+
14
+ self.model = AutoModel.from_pretrained(name)
15
+ self.processor = AutoImageProcessor.from_pretrained(name)
16
+
17
+ def forward(self, image):
18
+ vit_output = self.model(image,
19
+ output_hidden_states=True,
20
+ return_dict=True)
21
+
22
+ outputs = {}
23
+ for i in range(1, len(vit_output.hidden_states)):
24
+ outputs[f'block{i}'] = vit_output.hidden_states[i][:, 0] # get cls only
25
+ outputs['feats'] = outputs[f'block{i}']
26
+ return outputs
27
+
28
+ def preprocess(self, image, size=None):
29
+ inputs = self.processor(images=image, return_tensors="pt", size=size)
30
+ return inputs['pixel_values']
31
+
32
+ def get_feat_maps(self, image, index=-1):
33
+ vit_output = self.model(image,
34
+ output_hidden_states=True,
35
+ return_dict=True)
36
+
37
+ last_hidden_states = vit_output.hidden_states[index]
38
+
39
+ B, T, C = last_hidden_states.size()
40
+ HW = int((T - 1) ** 0.5)
41
+
42
+ return last_hidden_states[:, 1:, :].reshape(B, HW, HW, C).permute(0, 3, 1, 2) # (B, C, H, W)
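
A hedged usage sketch of the `DINO` wrapper above; it assumes the `dreamcreature` package from this commit is importable and that the `facebook/dinov2-base` weights can be downloaded:

```python
import torch
from PIL import Image

from dreamcreature.dino import DINO

# Assumes the dreamcreature package in this commit is on the path and the
# facebook/dinov2-base weights are downloadable.
dino = DINO('facebook/dinov2-base').eval()

image = Image.new('RGB', (256, 256))             # placeholder image, just to show shapes
pixel_values = dino.preprocess(image, size=224)  # (1, 3, 224, 224)

with torch.no_grad():
    feat_map = dino.get_feat_maps(pixel_values)  # (1, 768, 16, 16): 224 / patch size 14 = 16 per side

print(feat_map.shape)
```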
dreamcreature/kmeans_segmentation.py ADDED
@@ -0,0 +1,188 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchpq
7
+
8
+
9
+ class KMeansSegmentation:
10
+ FOREGROUND = 'foreground_background'
11
+ COARSE = 'coarse_kmeans'
12
+ FINE = 'fine_kmeans'
13
+
14
+ def __init__(self, path, foreground_idx=0, background_code=7, M=8, K=256):
15
+ if not os.path.exists(path):
16
+ raise FileNotFoundError(f'please train {path}')
17
+ kmeans = torch.load(path)
18
+
19
+ self.foreground_idx = foreground_idx
20
+ self.kmeans = kmeans
21
+ self.background_code = background_code
22
+ self.M = M
23
+ self.K = K
24
+
25
+ self.fg: torchpq.clustering.KMeans = kmeans[KMeansSegmentation.FOREGROUND]
26
+ self.coarse: torchpq.clustering.KMeans = kmeans[KMeansSegmentation.COARSE]
27
+ self.fine: List[torchpq.clustering.KMeans] = kmeans[KMeansSegmentation.FINE]
28
+
29
+ def obtain_fine_feats(self, prompt, filter_idxs=[]):
30
+ if isinstance(prompt, str):
31
+ code = np.zeros((self.M,), dtype=int)
32
+ splits = prompt.strip().split(' ')
33
+ for s in splits:
34
+ m, k = s.split(':')
35
+ code[int(m)] = int(k)
36
+ else:
37
+ code = prompt
38
+
39
+ fine_feats = []
40
+ for m in range(self.M):
41
+ fine_feats.append(self.fine[m].centroids.cpu().t()[code[m]])
42
+ fine_feats = torch.stack(fine_feats, dim=0)
43
+
44
+ if len(filter_idxs) != 0:
45
+ new_fine_feats = []
46
+
47
+ for m in range(self.M):
48
+ if m not in filter_idxs:
49
+ new_fine_feats.append(fine_feats[m])
50
+
51
+ fine_feats = torch.stack(new_fine_feats, dim=0)
52
+
53
+ return fine_feats
54
+
55
+ def get_segmask(self, feat_map, with_appeared_tokens=False):
56
+ N, C, H, W = feat_map.size()
57
+ query = feat_map.cuda().reshape(N, C, H * W).permute(0, 2, 1) # (N, H*W, C)
58
+
59
+ fg_labels = self.fg.predict(query.reshape(N * H * W, C).t().contiguous()) # (N*H*W)
60
+ fg_labels = fg_labels.reshape(N, H * W)
61
+
62
+ fg_idx = self.foreground_idx
63
+ bg_idx = 1 - self.foreground_idx
64
+
65
+ nobg = []
66
+ bgmean = []
67
+
68
+ for i in range(N):
69
+ bgnorm_mean = query[i][fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)
70
+
71
+ if fg_idx == 0:
72
+ bg_mask = fg_labels[i]
73
+ else:
74
+ bg_mask = 1 - fg_labels[i]
75
+
76
+ bg_mask = bg_mask.unsqueeze(1)
77
+ nobg.append(query[i] * (1 - bg_mask) + (-1 * bg_mask))
78
+ bgmean.append(bgnorm_mean)
79
+
80
+ nobg = torch.stack(nobg, dim=0) # (B, H*W, C)
81
+ coarse_labels = self.coarse.predict(nobg.reshape(N * H * W, 768).t().contiguous())
82
+ coarse_labels = coarse_labels.reshape(N, H, W)
83
+
84
+ segmasks = []
85
+ for m in range(self.M):
86
+ mask = (coarse_labels == m).float() # (N, H, W)
87
+ segmasks.append(mask)
88
+ segmasks = torch.stack(segmasks, dim=1) # (N, M, H, W)
89
+
90
+ if with_appeared_tokens:
91
+ appeared_tokens = []
92
+ for i in range(N):
93
+ appeared_tokens.append(torch.unique(coarse_labels[i].reshape(-1)).tolist())
94
+ return segmasks, appeared_tokens
95
+
96
+ return segmasks
97
+
98
+ def predict(self, feat_map, disable=True, filter_idxs=[]):
99
+ # feat_map: (B, C, H, W)
100
+
101
+ N, C, H, W = feat_map.size()
102
+ query = feat_map.reshape(N, C, H * W).permute(0, 2, 1) # (N, H*W, C)
103
+
104
+ fg_labels = self.fg.predict(query.reshape(N * H * W, C).t().contiguous().cuda()).cpu() # (N*H*W)
105
+ fg_labels = fg_labels.reshape(N, H * W)
106
+
107
+ fg_idx = self.foreground_idx
108
+ bg_idx = 1 - self.foreground_idx
109
+
110
+ nobg = []
111
+ bgmean = []
112
+
113
+ for i in range(N):
114
+ bgnorm_mean = query[i][fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)
115
+
116
+ if fg_idx == 0:
117
+ bg_mask = fg_labels[i]
118
+ else:
119
+ bg_mask = 1 - fg_labels[i]
120
+
121
+ bg_mask = bg_mask.unsqueeze(1)
122
+ nobg.append(query[i] * (1 - bg_mask) + (-1 * bg_mask))
123
+ bgmean.append(bgnorm_mean)
124
+
125
+ nobg = torch.stack(nobg, dim=0) # (B, H*W, C)
126
+ bgmean = torch.cat(bgmean, dim=0)
127
+
128
+ coarse_labels = self.coarse.predict(nobg.reshape(N * H * W, 768).t().contiguous().cuda()).cpu()
129
+ coarse_labels = coarse_labels.reshape(N, H * W)
130
+
131
+ from tqdm.auto import tqdm
132
+
133
+ fgmean = []
134
+ M = self.M
135
+
136
+ locs = np.zeros((N, M, 2))
137
+
138
+ for i in tqdm(range(N), disable=disable):
139
+ mean_feats = []
140
+ for m in range(M):
141
+ coarse_mask = coarse_labels[i] == m
142
+ if coarse_mask.sum().item() == 0:
143
+ m_mean_feats = torch.zeros(1, C)
144
+ else:
145
+ locs[i, m] = (coarse_mask.reshape(H, W).nonzero().float().add(0.5).mean(dim=0) / H).cpu().numpy()
146
+ m_mean_feats = query[i][coarse_mask].mean(dim=0, keepdim=True) # (H*W,C) -> (1,C)
147
+
148
+ mean_feats.append(m_mean_feats)
149
+
150
+ mean_feats = torch.cat(mean_feats, dim=0)
151
+ fgmean.append(mean_feats)
152
+
153
+ fgmean = torch.stack(fgmean, dim=0) # (N, M, C)
154
+ final_labels = torch.ones(N, M) * self.K
155
+
156
+ for m in range(M):
157
+ fine_kmeans = self.fine[m]
158
+
159
+ if m == self.background_code:
160
+ fine_labels = fine_kmeans.predict(bgmean.t().contiguous().cuda()).cpu()
161
+ final_labels[:, m] = fine_labels
162
+ else:
163
+ fine_inp = fgmean[:, m].reshape(N, C)
164
+ is_zero = fine_inp.sum(dim=1) == 0
165
+ fine_labels = fine_kmeans.predict(fine_inp.t().contiguous().cuda()).cpu()
166
+ fine_labels[is_zero] = self.K
167
+
168
+ final_labels[:, m] = fine_labels
169
+
170
+ fgmean[:, self.background_code] = bgmean
171
+ fine_prompts = []
172
+
173
+ for fine_label in final_labels:
174
+ prompt_dict = {k: int(v) for k, v in enumerate(list(fine_label))}
175
+ if len(filter_idxs) != 0:
176
+ for i in filter_idxs:
177
+ del prompt_dict[i]
178
+ prompt = ' '.join([f'{k}:{v}' for k, v in prompt_dict.items() if v != self.K])
179
+ fine_prompts.append(prompt)
180
+
181
+ return {
182
+ 'features': fgmean,
183
+ 'fg_labels': fg_labels,
184
+ 'coarse_labels': coarse_labels,
185
+ 'fine_labels': final_labels,
186
+ 'fine_prompts': fine_prompts,
187
+ 'location': locs
188
+ }
dreamcreature/loss.py ADDED
@@ -0,0 +1,73 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from diffusers import UNet2DConditionModel
4
+ from diffusers.models.attention_processor import Attention
5
+
6
+ from dreamcreature.dino import DINO
7
+ from dreamcreature.kmeans_segmentation import KMeansSegmentation
8
+
9
+
10
+ def dreamcreature_loss(batch,
11
+ unet: UNet2DConditionModel,
12
+ dino: DINO,
13
+ seg: KMeansSegmentation,
14
+ placeholder_token_ids,
15
+ accelerator):
16
+ attn_probs = {}
17
+
18
+ for name, module in unet.named_modules():
19
+ if isinstance(module, Attention) and module.attn_probs is not None:
20
+ a = module.attn_probs.mean(dim=1) # (B,Head,H,W,77) -> (B,H,W,77)
21
+ attn_probs[name] = a
22
+
23
+ avg_attn_map = []
24
+ for name in attn_probs:
25
+ avg_attn_map.append(attn_probs[name])
26
+
27
+ avg_attn_map = torch.stack(avg_attn_map, dim=0).mean(dim=0) # (L,B,H,W,77) -> (B,H,W,77)
28
+ B, H, W, seq_length = avg_attn_map.size()
29
+ located_attn_map = []
30
+
31
+ # locate the attn map
32
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
33
+ for bi in range(B):
34
+ if "input_ids" in batch:
35
+ learnable_idx = (batch["input_ids"][bi] == placeholder_token_id).nonzero(as_tuple=True)[0]
36
+ else:
37
+ learnable_idx = (batch["input_ids_one"][bi] == placeholder_token_id).nonzero(as_tuple=True)[0]
38
+
39
+ if len(learnable_idx) != 0: # only assign if found
40
+ if len(learnable_idx) == 1:
41
+ offset_learnable_idx = learnable_idx
42
+ else: # if there is two and above.
43
+ raise NotImplementedError
44
+
45
+ located_map = avg_attn_map[bi, :, :, offset_learnable_idx]
46
+ located_attn_map.append(located_map)
47
+ else:
48
+ located_attn_map.append(torch.zeros(H, W, 1).to(accelerator.device))
49
+
50
+ M = len(placeholder_token_ids)
51
+ located_attn_map = torch.stack(located_attn_map, dim=0).reshape(M, B, H, W).transpose(0, 1) # (B, M, 16, 16)
52
+
53
+ raw_images = batch['raw_images']
54
+ dino_input = dino.preprocess(raw_images, size=224)
55
+ with torch.no_grad():
56
+ dino_ft = dino.get_feat_maps(dino_input)
57
+ segmasks, appeared_tokens = seg.get_segmask(dino_ft, True) # (B, M, H, W)
58
+ segmasks = segmasks.to(located_attn_map.dtype)
59
+ if H != 16: # for res 1024
60
+ segmasks = F.interpolate(segmasks, (H, W), mode='nearest')
61
+
62
+ masks = []
63
+ for i, appeared in enumerate(appeared_tokens):
64
+ mask = (segmasks[i, appeared].sum(dim=0) > 0).float() # (A, H, W) -> (H, W)
65
+ masks.append(mask)
66
+ masks = torch.stack(masks, dim=0) # (B, H, W)
67
+ batch['masks'] = masks
68
+
69
+ norm_map = located_attn_map / located_attn_map.sum(dim=1, keepdim=True).clamp(min=1e-6)
70
+ # if norm_map was assigned manually, the sub-concept token was not found, hence no gradient is backpropagated
71
+ attn_loss = F.binary_cross_entropy(norm_map.clamp(min=0, max=1),
72
+ segmasks.clamp(min=0, max=1))
73
+ return attn_loss, located_attn_map.detach().max()
dreamcreature/mapper.py ADDED
@@ -0,0 +1,75 @@
1
+ from typing import Optional
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class TokenMapper(nn.Module):
8
+ def __init__(self,
9
+ num_parts,
10
+ num_k_per_part,
11
+ out_dims,
12
+ projection_nlayers=1,
13
+ projection_activation=nn.ReLU(),
14
+ with_pe=True):
15
+ super().__init__()
16
+
17
+ self.num_parts = num_parts
18
+ self.num_k_per_part = num_k_per_part
19
+ self.with_pe = with_pe
20
+ self.out_dims = out_dims
21
+
22
+ self.embedding = nn.Embedding((self.num_k_per_part + 1) * num_parts, out_dims)
23
+ if with_pe:
24
+ self.pe = nn.Parameter(torch.randn(num_parts, out_dims))
25
+ else:
26
+ self.register_buffer('pe', torch.zeros(num_parts, out_dims))
27
+
28
+ if projection_nlayers == 0:
29
+ self.projection = nn.Identity()
30
+ else:
31
+ projections = []
32
+ for i in range(projection_nlayers - 1):
33
+ projections.append(nn.Linear(out_dims, out_dims))
34
+ projections.append(projection_activation)
35
+
36
+ projections.append(nn.Linear(out_dims, out_dims))
37
+ self.projection = nn.Sequential(*projections)
38
+
39
+ def get_all_embeddings(self, no_projection=False, no_pe=False):
40
+ idx = torch.arange(self.num_parts * (self.num_k_per_part + 1)).long().to(self.embedding.weight.device)
41
+ idx = idx.reshape(self.num_parts, self.num_k_per_part + 1)
42
+ emb = self.embedding(idx) # (K, N, d)
43
+
44
+ if not no_pe:
45
+ emb_pe = emb + self.pe.unsqueeze(1)
46
+ else:
47
+ emb_pe = emb
48
+
49
+ if not no_projection:
50
+ projected = self.projection(emb_pe)
51
+ else:
52
+ projected = emb_pe
53
+
54
+ return projected
55
+
56
+ def forward(self, hashes, index: Optional[torch.Tensor] = None):
57
+ B = hashes.size(0)
58
+
59
+ # 0, 257, 514, ...
60
+ if index is None:
61
+ offset = torch.arange(self.num_parts, device=hashes.device) * (self.num_k_per_part + 1)
62
+ hashes = self.embedding(hashes.long() + offset.reshape(1, -1)) # (B, N, d)
63
+ else:
64
+ offset = index.reshape(-1) * (self.num_k_per_part + 1)
65
+ hashes = self.embedding(hashes.long() + offset.reshape(B, -1).long()) # (B, N, d)
66
+
67
+ if index is not None:
68
+ pe = self.pe[index.reshape(-1)] # index must be equal size
69
+ pe = pe.reshape(B, -1, self.out_dims)
70
+ hashes = hashes + pe
71
+ else:
72
+ hashes = hashes + self.pe.unsqueeze(0).repeat(B, 1, 1)
73
+ projected = self.projection(hashes)
74
+
75
+ return projected
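
A short sketch of how the `TokenMapper` above is indexed; the sizes (8 parts, 256 codes per part, 768-dim outputs) mirror the CUB defaults used elsewhere in this commit and are assumptions here:

```python
import torch

from dreamcreature.mapper import TokenMapper

# 8 parts, 256 codes per part; the embedding table has 8 * (256 + 1) rows,
# one extra slot per part for "code not present", with row offsets 0, 257, 514, ...
mapper = TokenMapper(num_parts=8, num_k_per_part=256, out_dims=768)

codes = torch.randint(0, 256, (1, 8))      # one code per part, batch of one prompt
emb = mapper(codes)                        # (1, 8, 768): one token embedding per part
print(emb.shape)

# Only two parts (e.g. head=2 and wing=4), selected with explicit part indices.
index = torch.tensor([[2, 4]])
emb_subset = mapper(codes[:, :2], index)   # (1, 2, 768)
print(emb_subset.shape)
```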
dreamcreature/pipeline.py ADDED
@@ -0,0 +1,771 @@
1
+ import os
2
+ import re
3
+
4
+ from diffusers.loaders import AttnProcsLayers
5
+ from diffusers.models.attention_processor import Attention
6
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
7
+ from omegaconf import OmegaConf
8
+
9
+ from dreamcreature.attn_processor import LoRAAttnProcessorCustom
10
+ from dreamcreature.mapper import TokenMapper
11
+ from dreamcreature.text_encoder import CustomCLIPTextModel
12
+ from dreamcreature.tokenizer import MultiTokenCLIPTokenizer
13
+ from utils import add_tokens, get_attn_processors
14
+
15
+
16
+ def setup_attn_processor(unet, **kwargs):
17
+ lora_attn_procs = {}
18
+ for name in unet.attn_processors.keys():
19
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
20
+ if name.startswith("mid_block"):
21
+ hidden_size = unet.config.block_out_channels[-1]
22
+ elif name.startswith("up_blocks"):
23
+ block_id = int(name[len("up_blocks.")])
24
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
25
+ elif name.startswith("down_blocks"):
26
+ block_id = int(name[len("down_blocks.")])
27
+ hidden_size = unet.config.block_out_channels[block_id]
28
+
29
+ lora_attn_procs[name] = LoRAAttnProcessorCustom(
30
+ hidden_size=hidden_size,
31
+ cross_attention_dim=cross_attention_dim,
32
+ rank=kwargs['rank'],
33
+ )
34
+
35
+ unet.set_attn_processor(lora_attn_procs)
36
+
37
+
38
+ def load_attn_processor(unet, filename):
39
+ lora_layers = AttnProcsLayers(get_attn_processors(unet))
40
+ lora_layers.load_state_dict(torch.load(filename))
41
+
42
+
43
+ def convert_prompt_re(prompt: str):
44
+ pattern = r"<(\d+):(\d+)>"
45
+ result = prompt
46
+ offset = 0
47
+
48
+ ints = []
49
+ parts_i = []
50
+
51
+ for match in re.finditer(pattern, prompt):
52
+ i = int(match.group(1))
53
+ b = int(match.group(2))
54
+
55
+ replacement = f"<part>_{i}"
56
+ start, end = match.span()
57
+
58
+ # Adjust the start and end positions based on the offset from previous replacements
59
+ start += offset
60
+ end += offset
61
+
62
+ # Replace the matched text with the replacement
63
+ result = result[:start] + replacement + result[end:]
64
+
65
+ # Update the offset for the next replacement
66
+ offset += len(replacement) - (end - start)
67
+
68
+ parts_i.append(i)
69
+ ints.append(b)
70
+
71
+ result = result.strip()
72
+
73
+ if len(ints) == 0:
74
+ return result, None, None
75
+
76
+ ints = torch.tensor(ints) # (nparts,)
77
+ return result, ints, parts_i
78
+
79
+
80
+ def convert_prompt(prompt: str, replace_token: bool = False, v='v1'):
81
+ r"""
82
+ Parameters:
83
+ prompt (`str`):
84
+ The prompt to guide the image generation.
85
+
86
+ Returns:
87
+ `str`: The converted prompt
88
+ """
89
+ if v == 're':
90
+ return convert_prompt_re(prompt)
91
+
92
+ if ':' not in prompt:
93
+ return prompt, None, None
94
+
95
+ splits = prompt.replace('.', '').strip().split(' ')
96
+ # v1: a photo of a 0:1 1:24 ...
97
+ # v2: a photo of a <0:1> <1:24> ...
98
+ ints = []
99
+ noncode_start = ''
100
+ noncode_end = ''
101
+ parts = ''
102
+ parts_i = []
103
+ split_tokens = []
104
+ for b in splits:
105
+ if ':' not in b:
106
+ split_tokens.append(b)
107
+ continue
108
+
109
+ if v == 'v1':
110
+ i, b = b.strip().split(':')
111
+ has_comma = ',' in b
112
+ if has_comma:
113
+ b = b[:-1]
114
+ intb = int(b)
115
+ parts += f'<part>_{i} '
116
+ split_tokens.append(f'<part>_{i}')
117
+ if has_comma:
118
+ split_tokens.append(',')
119
+ else:
120
+ if b[0] == '<':
121
+ if '>' not in b: # no closing >, ignore
122
+ split_tokens.append(b)
123
+ continue
124
+
125
+ i, b = b[1:].strip().split(':')
126
+ token_to_add = ''
127
+ if b[-1] in [',', '.']:
128
+ token_to_add = b[-1]
129
+ b = b[:-1]
130
+
131
+ if b[-1] == '>':
132
+ b = b[:-1]
133
+ else: # not >, just search for the first >
134
+ for ci, char in enumerate(b):
135
+ if char == '>':
136
+ token_to_add = b[ci + 1:] + token_to_add
137
+ b = b[:ci] # skip >
138
+ break
139
+ else: # has : but not start with <
140
+ split_tokens.append(b)
141
+ continue
142
+
143
+ intb = abs(int(b)) # just force negative one to positive
144
+
145
+ parts += f'<part>_{i} '
146
+ split_tokens.append(f'<part>_{i}')
147
+ if len(token_to_add) != 0:
148
+ split_tokens.append(token_to_add)
149
+
150
+ try:
151
+ int(i)
152
+ except:
153
+ raise ValueError('cannot parse the `part` index as an integer, please make sure the input is correct')
154
+
155
+ parts_i.append(int(i))
156
+ ints.append(intb)
157
+
158
+ ints = torch.tensor(ints) # (nparts,)
159
+
160
+ if replace_token:
161
+ new_caption = f'{noncode_start.strip()} <part> {noncode_end.strip()}'
162
+ else:
163
+ new_caption = ' '.join(split_tokens)
164
+
165
+ new_caption = new_caption.strip()
166
+
167
+ return new_caption, ints, parts_i
168
+
169
+
170
+ class DreamCreatureSDPipeline(StableDiffusionPipeline):
171
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: MultiTokenCLIPTokenizer):
172
+ r"""
173
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
174
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
175
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
176
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
177
+
178
+ Parameters:
179
+ prompt (`str`):
180
+ The prompt to guide the image generation.
181
+ tokenizer (`PreTrainedTokenizer`):
182
+ The tokenizer responsible for encoding the prompt into input tokens.
183
+
184
+ Returns:
185
+ `str`: The converted prompt
186
+ """
187
+ if hasattr(self, 'replace_token'):
188
+ replace_token = self.replace_token
189
+ else:
190
+ replace_token = True
191
+
192
+ if hasattr(self, 'v'):
193
+ v = self.v
194
+ else:
195
+ v = 'v1'
196
+
197
+ new_caption, code, parts_i = convert_prompt(prompt, replace_token, v)
198
+ if hasattr(self, 'num_k_per_part'):
199
+ if code is not None and any(code >= self.num_k_per_part):
200
+ raise ValueError(f'`id` cannot more than {self.num_k_per_part}')
201
+
202
+ if hasattr(self, 'verbose') and self.verbose:
203
+ print(new_caption)
204
+
205
+ return new_caption, code, parts_i
206
+
207
+ def compute_prompt_embeddings(self, prompts, device):
208
+ # textual inversion: process multi-vector tokens if necessary
209
+ if not isinstance(prompts, List):
210
+ prompts = [prompts]
211
+
212
+ prompt_embeds_concat = []
213
+ for prompt in prompts:
214
+ prompt, code, parts_i = self.maybe_convert_prompt(prompt, self.tokenizer)
215
+
216
+ if hasattr(self, 'replace_token'):
217
+ replace_token = self.replace_token
218
+ else:
219
+ replace_token = True
220
+
221
+ text_inputs = self.tokenizer(
222
+ prompt,
223
+ replace_token=replace_token,
224
+ padding="max_length",
225
+ max_length=self.tokenizer.model_max_length,
226
+ truncation=True,
227
+ return_tensors="pt",
228
+ )
229
+ text_input_ids = text_inputs.input_ids
230
+ if hasattr(self, 'verbose') and self.verbose:
231
+ print(text_input_ids)
232
+
233
+ untruncated_ids = self.tokenizer(prompt,
234
+ replace_token=replace_token,
235
+ padding="longest",
236
+ return_tensors="pt").input_ids
237
+
238
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
239
+ text_input_ids, untruncated_ids
240
+ ):
241
+ removed_text = self.tokenizer.batch_decode(
242
+ untruncated_ids[:, self.tokenizer.model_max_length - 1: -1]
243
+ )
244
+ logger.warning(
245
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
246
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
247
+ )
248
+
249
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
250
+ attention_mask = text_inputs.attention_mask.to(device)
251
+ else:
252
+ attention_mask = None
253
+
254
+ if code is None:
255
+ modified_hs = None
256
+ else:
257
+ placeholder_token_ids = self.placeholder_token_ids
258
+ placeholder_token_ids = [placeholder_token_ids[i] for i in parts_i] # follow the order of prompt's i
259
+ mapper_outputs = self.simple_mapper(code.unsqueeze(0).to(device), torch.tensor(parts_i).to(device))
260
+ modified_hs = self.text_encoder.text_model.forward_embeddings_with_mapper(text_input_ids.to(device),
261
+ None,
262
+ mapper_outputs,
263
+ placeholder_token_ids)
264
+
265
+ prompt_embeds = self.text_encoder(
266
+ text_input_ids.to(device),
267
+ attention_mask=attention_mask,
268
+ hidden_states=modified_hs
269
+ )
270
+ prompt_embeds = prompt_embeds[0]
271
+ prompt_embeds_concat.append(prompt_embeds)
272
+
273
+ if len(prompt_embeds_concat) == 1:
274
+ return prompt_embeds_concat[0]
275
+ else:
276
+ return torch.cat(prompt_embeds_concat, dim=0)
277
+
278
+ def encode_prompt(
279
+ self,
280
+ prompt,
281
+ device,
282
+ num_images_per_prompt,
283
+ do_classifier_free_guidance,
284
+ negative_prompt=None,
285
+ prompt_embeds: Optional[torch.FloatTensor] = None,
286
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
287
+ lora_scale: Optional[float] = None,
288
+ ):
289
+ r"""
290
+ Encodes the prompt into text encoder hidden states.
291
+
292
+ Args:
293
+ prompt (`str` or `List[str]`, *optional*):
294
+ prompt to be encoded
295
+ device: (`torch.device`):
296
+ torch device
297
+ num_images_per_prompt (`int`):
298
+ number of images that should be generated per prompt
299
+ do_classifier_free_guidance (`bool`):
300
+ whether to use classifier free guidance or not
301
+ negative_prompt (`str` or `List[str]`, *optional*):
302
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
303
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
304
+ less than `1`).
305
+ prompt_embeds (`torch.FloatTensor`, *optional*):
306
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
307
+ provided, text embeddings will be generated from `prompt` input argument.
308
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
309
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
310
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
311
+ argument.
312
+ lora_scale (`float`, *optional*):
313
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
314
+ """
315
+ # set lora scale so that monkey patched LoRA
316
+ # function of text encoder can correctly access it
317
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
318
+ self._lora_scale = lora_scale
319
+
320
+ if prompt is not None and isinstance(prompt, str):
321
+ batch_size = 1
322
+ elif prompt is not None and isinstance(prompt, list):
323
+ batch_size = len(prompt)
324
+ else:
325
+ batch_size = prompt_embeds.shape[0]
326
+
327
+ if prompt_embeds is None:
328
+ prompt_embeds = self.compute_prompt_embeddings(prompt, device)
329
+
330
+ # if self.text_encoder is not None:
331
+ # prompt_embeds_dtype = self.text_encoder.dtype
332
+ # elif self.unet is not None:
333
+ # prompt_embeds_dtype = self.unet.dtype
334
+ # else:
335
+ # prompt_embeds_dtype = prompt_embeds.dtype
336
+
337
+ prompt_embeds_dtype = self.unet.dtype # should be unet only because this is unet's condition input
338
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
339
+
340
+ bs_embed, seq_len, _ = prompt_embeds.shape
341
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
342
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
343
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
344
+
345
+ # get unconditional embeddings for classifier free guidance
346
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
347
+ uncond_tokens: List[str]
348
+ if negative_prompt is None:
349
+ uncond_tokens = [""] * batch_size
350
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
351
+ raise TypeError(
352
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
353
+ f" {type(prompt)}."
354
+ )
355
+ elif isinstance(negative_prompt, str):
356
+ uncond_tokens = [negative_prompt] * batch_size
357
+ elif batch_size != len(negative_prompt):
358
+ raise ValueError(
359
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
360
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
361
+ " the batch size of `prompt`."
362
+ )
363
+ else:
364
+ uncond_tokens = negative_prompt
365
+
366
+ negative_prompt_embeds = []
367
+ for u_tokens in uncond_tokens:
368
+ negative_prompt_embeds.append(self.compute_prompt_embeddings(u_tokens, device))
369
+ negative_prompt_embeds = torch.cat(negative_prompt_embeds, dim=0)
370
+
371
+ # if isinstance(self, TextualInversionLoaderMixin):
372
+ # uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
373
+
374
+ # max_length = prompt_embeds.shape[1]
375
+ # uncond_input = self.tokenizer(
376
+ # uncond_tokens,
377
+ # padding="max_length",
378
+ # max_length=max_length,
379
+ # truncation=True,
380
+ # return_tensors="pt",
381
+ # )
382
+ #
383
+ # if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
384
+ # attention_mask = uncond_input.attention_mask.to(device)
385
+ # else:
386
+ # attention_mask = None
387
+ #
388
+ # negative_prompt_embeds = self.text_encoder(
389
+ # uncond_input.input_ids.to(device),
390
+ # attention_mask=attention_mask,
391
+ # )
392
+ # negative_prompt_embeds = negative_prompt_embeds[0]
393
+
394
+ if do_classifier_free_guidance:
395
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
396
+ seq_len = negative_prompt_embeds.shape[1]
397
+
398
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
399
+
400
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
401
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
402
+
403
+ return prompt_embeds, negative_prompt_embeds
404
+
405
+ @torch.no_grad()
406
+ def obtain_attention_map(self, image, prompt, timesteps):
407
+ prompt, codes, index = self.maybe_convert_prompt(prompt, self.tokenizer)
408
+
409
+ if hasattr(self, 'replace_token'):
410
+ replace_token = self.replace_token
411
+ else:
412
+ replace_token = True
413
+
414
+ text_inputs = self.tokenizer(
415
+ prompt,
416
+ replace_token=replace_token,
417
+ padding="max_length",
418
+ max_length=self.tokenizer.model_max_length,
419
+ truncation=True,
420
+ return_tensors="pt",
421
+ )
422
+ input_ids = text_inputs.input_ids
423
+
424
+ placeholder_token_ids = self.placeholder_token_ids
425
+ placeholder_token_ids = [placeholder_token_ids[i] for i in index]
426
+
427
+ # encode the image, add noise at the given timesteps, run one UNet forward pass and collect the attention maps
428
+ device = self._execution_device
429
+
430
+ latents = self.vae.encode(image.to(device, dtype=self.weight_dtype)).latent_dist.sample()
431
+ latents = latents * self.vae.config.scaling_factor
432
+
433
+ # bsz = latents.shape[0]
434
+ # Sample a random timestep for each image
435
+ # timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
436
+ timesteps = timesteps.long().to(latents.device)
437
+
438
+ # Add noise to the latents according to the noise magnitude at each timestep
439
+ # (this is the forward diffusion process)
440
+ noise = torch.randn_like(latents)
441
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
442
+
443
+ mapper_outputs = self.simple_mapper(codes.unsqueeze(0).to(device), torch.tensor(index).to(device))
444
+ # print(mapper_outputs.size(), batch["input_ids"].size())
445
+ modified_hs = self.text_encoder.text_model.forward_embeddings_with_mapper(input_ids.to(device),
446
+ None,
447
+ mapper_outputs,
448
+ placeholder_token_ids)
449
+ encoder_hidden_states = self.text_encoder(input_ids, hidden_states=modified_hs)[0]
450
+ model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states.to(dtype=self.weight_dtype)).sample
451
+
452
+ attn_probs = {}
453
+
454
+ for name, module in self.unet.named_modules():
455
+ if isinstance(module, Attention) and module.attn_probs is not None:
456
+ a = module.attn_probs[0].mean(dim=0) # (2,Head,H,W,77)->(H,W,77)
457
+ attn_probs[name] = a
458
+
459
+ avg_attn_map = []
460
+ for name in attn_probs:
461
+ avg_attn_map.append(attn_probs[name])
462
+ avg_attn_map = torch.stack(avg_attn_map, dim=0).mean(dim=0) # (N,H,W,77) -> (H,W,77), averaged over the N attention modules
463
+
464
+ return attn_probs, avg_attn_map, input_ids
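The maps returned above are keyed by UNet attention-module name, each entry already averaged over heads to shape (H, W, 77), with one column per text token. A minimal inspection sketch, assuming `pipe` is a loaded DreamCreatureSDPipeline and an `image` tensor is available; note the method also reads `weight_dtype` and `noise_scheduler`, which the loader further below does not assign, so they are set here as assumptions, and the prompt and token index are illustrative only:

    import torch

    pipe.weight_dtype = torch.float16         # assumed dtype used when encoding the image with the VAE
    pipe.noise_scheduler = pipe.scheduler     # assumption: reuse the inference scheduler's add_noise
    timesteps = torch.tensor([500])           # a single mid-range diffusion step
    attn_probs, avg_attn_map, input_ids = pipe.obtain_attention_map(
        image,                                # assumed (1, 3, H, W) tensor normalised to [-1, 1]
        'a photo of a <tail:5> bird',         # illustrative part prompt
        timesteps)
    token_map = avg_attn_map[..., 5]          # hypothetical token position within input_ids
    token_map = token_map / token_map.max()   # normalise to [0, 1] for visualisation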
465
+
466
+ @torch.no_grad()
467
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
468
+ def __call__(
469
+ self,
470
+ prompt: Union[str, List[str]] = None,
471
+ height: Optional[int] = None,
472
+ width: Optional[int] = None,
473
+ num_inference_steps: int = 50,
474
+ guidance_scale: float = 7.5,
475
+ negative_prompt: Optional[Union[str, List[str]]] = None,
476
+ num_images_per_prompt: Optional[int] = 1,
477
+ eta: float = 0.0,
478
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
479
+ latents: Optional[torch.FloatTensor] = None,
480
+ prompt_embeds: Optional[torch.FloatTensor] = None,
481
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
482
+ output_type: Optional[str] = "pil",
483
+ return_dict: bool = True,
484
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
485
+ callback_steps: int = 1,
486
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
487
+ guidance_rescale: float = 0.0,
488
+ get_attention_map: bool = False
489
+ ):
490
+ r"""
491
+ The call function to the pipeline for generation.
492
+
493
+ Args:
494
+ prompt (`str` or `List[str]`, *optional*):
495
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
496
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
497
+ The height in pixels of the generated image.
498
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
499
+ The width in pixels of the generated image.
500
+ num_inference_steps (`int`, *optional*, defaults to 50):
501
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
502
+ expense of slower inference.
503
+ guidance_scale (`float`, *optional*, defaults to 7.5):
504
+ A higher guidance scale value encourages the model to generate images closely linked to the text
505
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
506
+ negative_prompt (`str` or `List[str]`, *optional*):
507
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
508
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
509
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
510
+ The number of images to generate per prompt.
511
+ eta (`float`, *optional*, defaults to 0.0):
512
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
513
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
514
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
515
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
516
+ generation deterministic.
517
+ latents (`torch.FloatTensor`, *optional*):
518
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
519
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
520
+ tensor is generated by sampling using the supplied random `generator`.
521
+ prompt_embeds (`torch.FloatTensor`, *optional*):
522
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
523
+ provided, text embeddings are generated from the `prompt` input argument.
524
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
525
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
526
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
527
+ output_type (`str`, *optional*, defaults to `"pil"`):
528
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
529
+ return_dict (`bool`, *optional*, defaults to `True`):
530
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
531
+ plain tuple.
532
+ callback (`Callable`, *optional*):
533
+ A function that calls every `callback_steps` steps during inference. The function is called with the
534
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
535
+ callback_steps (`int`, *optional*, defaults to 1):
536
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
537
+ every step.
538
+ cross_attention_kwargs (`dict`, *optional*):
539
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
540
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
541
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
542
+ Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
543
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
544
+ using zero terminal SNR.
545
+
546
+ Examples:
547
+
548
+ Returns:
549
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
550
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
551
+ otherwise a `tuple` is returned where the first element is a list with the generated images and the
552
+ second element is a list of `bool`s indicating whether the corresponding generated image contains
553
+ "not-safe-for-work" (nsfw) content.
554
+ """
555
+ # 0. Default height and width to unet
556
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
557
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
558
+
559
+ # 1. Check inputs. Raise error if not correct
560
+ self.check_inputs(
561
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
562
+ )
563
+
564
+ # 2. Define call parameters
565
+ if prompt is not None and isinstance(prompt, str):
566
+ batch_size = 1
567
+ elif prompt is not None and isinstance(prompt, list):
568
+ batch_size = len(prompt)
569
+ else:
570
+ batch_size = prompt_embeds.shape[0]
571
+
572
+ device = self._execution_device
573
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
574
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
575
+ # corresponds to doing no classifier free guidance.
576
+ do_classifier_free_guidance = guidance_scale > 1.0
577
+
578
+ # 3. Encode input prompt
579
+ text_encoder_lora_scale = (
580
+ cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
581
+ )
582
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
583
+ prompt,
584
+ device,
585
+ num_images_per_prompt,
586
+ do_classifier_free_guidance,
587
+ negative_prompt,
588
+ prompt_embeds=prompt_embeds,
589
+ negative_prompt_embeds=negative_prompt_embeds,
590
+ lora_scale=text_encoder_lora_scale,
591
+ )
592
+ # For classifier free guidance, we need to do two forward passes.
593
+ # Here we concatenate the unconditional and text embeddings into a single batch
594
+ # to avoid doing two forward passes
595
+ if do_classifier_free_guidance:
596
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
597
+
598
+ # 4. Prepare timesteps
599
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
600
+ timesteps = self.scheduler.timesteps
601
+
602
+ # 5. Prepare latent variables
603
+ num_channels_latents = self.unet.config.in_channels
604
+ latents = self.prepare_latents(
605
+ batch_size * num_images_per_prompt,
606
+ num_channels_latents,
607
+ height,
608
+ width,
609
+ prompt_embeds.dtype,
610
+ device,
611
+ generator,
612
+ latents,
613
+ )
614
+
615
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
616
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
617
+
618
+ # 7. Denoising loop
619
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
620
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
621
+ if get_attention_map:
622
+ attn_maps = {} # each t one attn map
623
+
624
+ for i, t in enumerate(timesteps):
625
+ # expand the latents if we are doing classifier free guidance
626
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
627
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
628
+
629
+ # predict the noise residual
630
+ noise_pred = self.unet(
631
+ latent_model_input,
632
+ t,
633
+ encoder_hidden_states=prompt_embeds,
634
+ cross_attention_kwargs=cross_attention_kwargs,
635
+ return_dict=False,
636
+ )[0]
637
+ if get_attention_map:
638
+ attn_probs = {}
639
+
640
+ for name, module in self.unet.named_modules():
641
+ if isinstance(module, Attention) and module.attn_probs is not None:
642
+ a = module.attn_probs[1].mean(dim=0) # (2,Head,H,W,77)->(H,W,77)
643
+ attn_probs[name] = a
644
+
645
+ attn_maps[i] = attn_probs
646
+
647
+ # perform guidance
648
+ if do_classifier_free_guidance:
649
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
650
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
651
+
652
+ if do_classifier_free_guidance and guidance_rescale > 0.0:
653
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
654
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
655
+
656
+ # compute the previous noisy sample x_t -> x_t-1
657
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
658
+
659
+ # call the callback, if provided
660
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
661
+ progress_bar.update()
662
+ if callback is not None and i % callback_steps == 0:
663
+ callback(i, t, latents)
664
+
665
+ if get_attention_map:
666
+ output_maps = {}
667
+ for name in attn_probs.keys():
668
+ timeavg_maps = []
669
+ for i in attn_maps.keys():
670
+ timeavg_maps.append(attn_maps[i][name])
671
+ timeavg_maps = torch.stack(timeavg_maps, dim=0).mean(dim=0)
672
+ output_maps[name] = timeavg_maps
673
+
674
+ avg_attn_map = []
675
+ for name in attn_probs:
676
+ avg_attn_map.append(attn_probs[name])
677
+ avg_attn_map = torch.stack(avg_attn_map, dim=0).mean(dim=0) # (N,H,W,77) -> (H,W,77), averaged over the N attention modules
678
+ output_maps['avg'] = avg_attn_map
679
+
680
+ del attn_maps
681
+ del attn_probs
682
+ self.attn_maps = output_maps
683
+
684
+ if not output_type == "latent":
685
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
686
+ image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
687
+ else:
688
+ image = latents
689
+ has_nsfw_concept = None
690
+
691
+ if has_nsfw_concept is None:
692
+ do_denormalize = [True] * image.shape[0]
693
+ else:
694
+ do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
695
+
696
+ image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
697
+
698
+ # Offload last model to CPU
699
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
700
+ self.final_offload_hook.offload()
701
+
702
+ if not return_dict:
703
+ return (image, has_nsfw_concept)
704
+
705
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
706
+
707
+
708
+ def create_args(output_dir, num_parts=8, num_k_per_part=256):
709
+ args = OmegaConf.create({
710
+ 'pretrained_model_name_or_path': 'runwayml/stable-diffusion-v1-5',
711
+ 'num_parts': num_parts,
712
+ 'num_k_per_part': num_k_per_part,
713
+ 'revision': None,
714
+ 'variant': None,
715
+ 'rank': 4,
716
+ 'projection_nlayers': 1,
717
+ 'output_dir': output_dir
718
+ })
719
+ folders = sorted(os.listdir(args.output_dir))
720
+ cps = [int(f.split('-')[1]) for f in folders if 'checkpoint' in f and '.ipynb' not in f]
721
+ maxcp = max(cps)
722
+
723
+ args.maxcp = maxcp
724
+ args.unet_path = None
725
+ return args
726
+
727
+
728
+ def load_pipeline(args, weight_dtype=torch.float16, device=torch.device('cuda')):
729
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
730
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
731
+ )
732
+ text_encoder = CustomCLIPTextModel.from_pretrained(
733
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
734
+ )
735
+ unet_path = args.unet_path if args.unet_path is not None else args.pretrained_model_name_or_path
736
+ unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
737
+ unet_path, subfolder="unet", revision=args.revision
738
+ )
739
+ pipeline = DreamCreatureSDPipeline.from_pretrained(
740
+ args.pretrained_model_name_or_path,
741
+ unet=unet,
742
+ text_encoder=text_encoder,
743
+ tokenizer=tokenizer,
744
+ revision=args.revision,
745
+ torch_dtype=weight_dtype,
746
+ )
747
+ pipeline.num_k_per_part = args.num_k_per_part
748
+ pipeline.num_parts = args.num_parts
749
+ placeholder_token = "<part>"
750
+ initializer_token = None
751
+ placeholder_token_ids = add_tokens(tokenizer,
752
+ text_encoder,
753
+ placeholder_token,
754
+ args.num_parts,
755
+ initializer_token)
756
+ pipeline.placeholder_token_ids = placeholder_token_ids
757
+ pipeline.simple_mapper = TokenMapper(args.num_parts,
758
+ args.num_k_per_part,
759
+ 768,
760
+ args.projection_nlayers)
761
+ pipeline.simple_mapper.load_state_dict(torch.load(args.output_dir + f'/checkpoint-{args.maxcp}/pytorch_model_1.bin',
762
+ map_location='cpu'))
763
+ pipeline.simple_mapper.to(device)
764
+ pipeline.replace_token = False
765
+ pipeline = pipeline.to(device)
766
+
767
+ # load attention processors
768
+ # pipeline.unet.load_attn_procs(args.output_dir, use_safetensors=not args.custom_diffusion)
769
+ setup_attn_processor(pipeline.unet, rank=args.rank)
770
+ load_attn_processor(pipeline.unet, args.output_dir + f'/checkpoint-{args.maxcp}/pytorch_model.bin')
771
+ return pipeline
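Taken together, `create_args` and `load_pipeline` are enough to run inference. A minimal usage sketch, assuming a local output directory that contains the downloaded `checkpoint-*` folders and that the default hyperparameters above match that checkpoint's training configuration; the directory name, prompt and output filename are illustrative:

    import torch

    args = create_args('path/to/output_dir')              # hypothetical checkpoint directory
    pipe = load_pipeline(args, weight_dtype=torch.float16,
                         device=torch.device('cuda'))
    image = pipe('a photo of a <tail:5> bird',             # illustrative part prompt
                 num_inference_steps=50,
                 guidance_scale=7.5).images[0]
    image.save('composed_bird.png')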
dreamcreature/pipeline_xl.py ADDED
@@ -0,0 +1,895 @@
1
+ import os
2
+
3
+ from diffusers.models.attention_processor import Attention
4
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import *
5
+ from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl import (
6
+ StableDiffusionXLPipeline,
7
+ StableDiffusionXLPipelineOutput,
8
+ XLA_AVAILABLE,
9
+ Tuple,
10
+ StableDiffusionXLLoraLoaderMixin,
11
+ PipelineImageInput
12
+ )
13
+ from omegaconf import OmegaConf
14
+
15
+ from dreamcreature.attn_processor import AttnProcessorCustom
16
+ from dreamcreature.mapper import TokenMapper
17
+ from dreamcreature.pipeline import convert_prompt
18
+ from dreamcreature.text_encoder import CustomCLIPTextModel, CustomCLIPTextModelWithProjection
19
+ from dreamcreature.tokenizer import MultiTokenCLIPTokenizer
20
+ from utils import add_tokens
21
+
22
+
23
+ def init_for_pipeline(args):
24
+ tokenizer_one = MultiTokenCLIPTokenizer.from_pretrained(
25
+ args.pretrained_model_name_or_path,
26
+ subfolder="tokenizer",
27
+ revision=args.revision,
28
+ use_fast=False,
29
+ )
30
+ tokenizer_two = MultiTokenCLIPTokenizer.from_pretrained(
31
+ args.pretrained_model_name_or_path,
32
+ subfolder="tokenizer_2",
33
+ revision=args.revision,
34
+ use_fast=False,
35
+ )
36
+ text_encoder_cls_one = CustomCLIPTextModel
37
+ text_encoder_cls_two = CustomCLIPTextModelWithProjection
38
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
39
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
40
+ )
41
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
42
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
43
+ )
44
+
45
+ OUT_DIMS = 768 + 1280 # 2048
46
+ simple_mapper = TokenMapper(args.num_parts,
47
+ args.num_k_per_part,
48
+ OUT_DIMS,
49
+ args.projection_nlayers)
50
+ return text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, simple_mapper
51
+
52
+
53
+ class DreamCreatureSDXLPipeline(StableDiffusionXLPipeline):
54
+ def _maybe_convert_prompt(self, prompt: str, tokenizer: MultiTokenCLIPTokenizer):
55
+ r"""
56
+ Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
57
+ to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
58
+ is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
59
+ inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.
60
+
61
+ Parameters:
62
+ prompt (`str`):
63
+ The prompt to guide the image generation.
64
+ tokenizer (`PreTrainedTokenizer`):
65
+ The tokenizer responsible for encoding the prompt into input tokens.
66
+
67
+ Returns:
68
+ `str`: The converted prompt
69
+ """
70
+ if hasattr(self, 'replace_token'):
71
+ replace_token = self.replace_token
72
+ else:
73
+ replace_token = True
74
+
75
+ new_caption, code, parts_i = convert_prompt(prompt, replace_token)
76
+
77
+ if hasattr(self, 'verbose') and self.verbose:
78
+ print(new_caption)
79
+
80
+ return new_caption, code, parts_i
81
+
82
+ def compute_prompt_embeddings(self, prompts, text_encoder, tokenizer, device, index):
83
+ # textual inversion: process multi-vector tokens if necessary
84
+ if not isinstance(prompts, List):
85
+ prompts = [prompts]
86
+
87
+ prompt_embeds_concat = []
88
+ pooled_prompt_embeds_concat = []
89
+ for prompt in prompts:
90
+ prompt, code, parts_i = self.maybe_convert_prompt(prompt, tokenizer)
91
+
92
+ if hasattr(self, 'replace_token'):
93
+ replace_token = self.replace_token
94
+ else:
95
+ replace_token = True
96
+
97
+ text_inputs = tokenizer(
98
+ prompt,
99
+ replace_token=replace_token,
100
+ padding="max_length",
101
+ max_length=self.tokenizer.model_max_length,
102
+ truncation=True,
103
+ return_tensors="pt",
104
+ )
105
+ text_input_ids = text_inputs.input_ids
106
+ if hasattr(self, 'verbose') and self.verbose:
107
+ print(text_input_ids)
108
+
109
+ untruncated_ids = tokenizer(prompt,
110
+ replace_token=replace_token,
111
+ padding="longest",
112
+ return_tensors="pt").input_ids
113
+
114
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
115
+ text_input_ids, untruncated_ids
116
+ ):
117
+ removed_text = tokenizer.batch_decode(
118
+ untruncated_ids[:, tokenizer.model_max_length - 1: -1]
119
+ )
120
+ logger.warning(
121
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
122
+ f" {tokenizer.model_max_length} tokens: {removed_text}"
123
+ )
124
+
125
+ if hasattr(text_encoder.config, "use_attention_mask") and text_encoder.config.use_attention_mask:
126
+ attention_mask = text_inputs.attention_mask.to(device)
127
+ else:
128
+ attention_mask = None
129
+
130
+ if code is None:
131
+ modified_hs = None
132
+ else:
133
+ placeholder_token_ids = self.placeholder_token_ids
134
+ placeholder_token_ids = [placeholder_token_ids[i] for i in parts_i] # follow the order of prompt's i
135
+ mapper_outputs = self.simple_mapper(code.unsqueeze(0).to(device), torch.tensor(parts_i).to(device))
136
+ if index == 0: # first encoder
137
+ mapper_outputs = mapper_outputs[..., :768]
138
+ else:
139
+ mapper_outputs = mapper_outputs[..., 768:]
140
+ modified_hs = text_encoder.text_model.forward_embeddings_with_mapper(text_input_ids.to(device),
141
+ None,
142
+ mapper_outputs,
143
+ placeholder_token_ids)
144
+
145
+ prompt_embeds = text_encoder(
146
+ text_input_ids.to(device),
147
+ output_hidden_states=True,
148
+ attention_mask=attention_mask,
149
+ hidden_states=modified_hs
150
+ )
151
+ # We are always interested only in the pooled output of the final text encoder
152
+ pooled_prompt_embeds = prompt_embeds[0]
153
+ prompt_embeds = prompt_embeds.hidden_states[-2]
154
+
155
+ pooled_prompt_embeds_concat.append(pooled_prompt_embeds)
156
+ prompt_embeds_concat.append(prompt_embeds)
157
+
158
+ if len(prompt_embeds_concat) == 1:
159
+ return prompt_embeds_concat[0], pooled_prompt_embeds_concat[0]
160
+ else:
161
+ return torch.cat(prompt_embeds_concat, dim=0), torch.cat(pooled_prompt_embeds_concat, dim=0)
162
+
163
+ def encode_prompt(
164
+ self,
165
+ prompt: str,
166
+ prompt_2: Optional[str] = None,
167
+ device: Optional[torch.device] = None,
168
+ num_images_per_prompt: int = 1,
169
+ do_classifier_free_guidance: bool = True,
170
+ negative_prompt: Optional[str] = None,
171
+ negative_prompt_2: Optional[str] = None,
172
+ prompt_embeds: Optional[torch.FloatTensor] = None,
173
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
174
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
175
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
176
+ lora_scale: Optional[float] = None,
177
+ clip_skip: Optional[int] = None,
178
+ ):
179
+ r"""
180
+ Encodes the prompt into text encoder hidden states.
181
+
182
+ Args:
183
+ prompt (`str` or `List[str]`, *optional*):
184
+ prompt to be encoded
185
+ prompt_2 (`str` or `List[str]`, *optional*):
186
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
187
+ used in both text-encoders
188
+ device: (`torch.device`):
189
+ torch device
190
+ num_images_per_prompt (`int`):
191
+ number of images that should be generated per prompt
192
+ do_classifier_free_guidance (`bool`):
193
+ whether to use classifier free guidance or not
194
+ negative_prompt (`str` or `List[str]`, *optional*):
195
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
196
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
197
+ less than `1`).
198
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
199
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
200
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
201
+ prompt_embeds (`torch.FloatTensor`, *optional*):
202
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
203
+ provided, text embeddings will be generated from `prompt` input argument.
204
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
205
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
206
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
207
+ argument.
208
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
209
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
210
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
211
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
212
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
213
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
214
+ input argument.
215
+ lora_scale (`float`, *optional*):
216
+ A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
217
+ clip_skip (`int`, *optional*):
218
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
219
+ the output of the pre-final layer will be used for computing the prompt embeddings.
220
+ """
221
+ device = device or self._execution_device
222
+
223
+ # set lora scale so that monkey patched LoRA
224
+ # function of text encoder can correctly access it
225
+ if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
226
+ self._lora_scale = lora_scale
227
+
228
+ # dynamically adjust the LoRA scale
229
+ if self.text_encoder is not None:
230
+ if not USE_PEFT_BACKEND:
231
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
232
+ else:
233
+ scale_lora_layers(self.text_encoder, lora_scale)
234
+
235
+ if self.text_encoder_2 is not None:
236
+ if not USE_PEFT_BACKEND:
237
+ adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
238
+ else:
239
+ scale_lora_layers(self.text_encoder_2, lora_scale)
240
+
241
+ prompt = [prompt] if isinstance(prompt, str) else prompt
242
+
243
+ if prompt is not None:
244
+ batch_size = len(prompt)
245
+ else:
246
+ batch_size = prompt_embeds.shape[0]
247
+
248
+ # Define tokenizers and text encoders
249
+ tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2]
250
+ text_encoders = (
251
+ [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2]
252
+ )
253
+
254
+ if prompt_embeds is None:
255
+ prompt_2 = prompt_2 or prompt
256
+ prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
257
+
258
+ # textual inversion: process multi-vector tokens if necessary
259
+ prompt_embeds_list = []
260
+ prompts = [prompt, prompt_2]
261
+ index = 0
262
+ for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
263
+ prompt_embeds, pooled_prompt_embeds = self.compute_prompt_embeddings(prompt,
264
+ text_encoder,
265
+ tokenizer,
266
+ device,
267
+ index)
268
+ prompt_embeds_list.append(prompt_embeds)
269
+ index += 1
270
+
271
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
272
+
273
+ # get unconditional embeddings for classifier free guidance
274
+ zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt
275
+ if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt:
276
+ negative_prompt_embeds = torch.zeros_like(prompt_embeds)
277
+ negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
278
+ elif do_classifier_free_guidance and negative_prompt_embeds is None:
279
+ negative_prompt = negative_prompt or ""
280
+ negative_prompt_2 = negative_prompt_2 or negative_prompt
281
+
282
+ # normalize str to list
283
+ negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt
284
+ negative_prompt_2 = (
285
+ batch_size * [negative_prompt_2] if isinstance(negative_prompt_2, str) else negative_prompt_2
286
+ )
287
+
288
+ uncond_tokens: List[str]
289
+ if prompt is not None and type(prompt) is not type(negative_prompt):
290
+ raise TypeError(
291
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
292
+ f" {type(prompt)}."
293
+ )
294
+ elif batch_size != len(negative_prompt):
295
+ raise ValueError(
296
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
297
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
298
+ " the batch size of `prompt`."
299
+ )
300
+ else:
301
+ uncond_tokens = [negative_prompt, negative_prompt_2]
302
+
303
+ negative_prompt_embeds_list = []
304
+ index = 0
305
+ for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
306
+ negative_prompt_embeds, negative_pooled_prompt_embeds = self.compute_prompt_embeddings(negative_prompt,
307
+ text_encoder,
308
+ tokenizer,
309
+ device,
310
+ index)
311
+ negative_prompt_embeds_list.append(negative_prompt_embeds)
312
+ index += 1
313
+
314
+ negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
315
+
316
+ if self.text_encoder_2 is not None:
317
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
318
+ else:
319
+ prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
320
+
321
+ bs_embed, seq_len, _ = prompt_embeds.shape
322
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
323
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
324
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
325
+
326
+ if do_classifier_free_guidance:
327
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
328
+ seq_len = negative_prompt_embeds.shape[1]
329
+
330
+ if self.text_encoder_2 is not None:
331
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
332
+ else:
333
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
334
+
335
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
336
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
337
+
338
+ pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
339
+ bs_embed * num_images_per_prompt, -1
340
+ )
341
+ if do_classifier_free_guidance:
342
+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
343
+ bs_embed * num_images_per_prompt, -1
344
+ )
345
+
346
+ if self.text_encoder is not None:
347
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
348
+ # Retrieve the original scale by scaling back the LoRA layers
349
+ unscale_lora_layers(self.text_encoder, lora_scale)
350
+
351
+ if self.text_encoder_2 is not None:
352
+ if isinstance(self, StableDiffusionXLLoraLoaderMixin) and USE_PEFT_BACKEND:
353
+ # Retrieve the original scale by scaling back the LoRA layers
354
+ unscale_lora_layers(self.text_encoder_2, lora_scale)
355
+
356
+ return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
357
+
358
+ @torch.no_grad()
359
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
360
+ def __call__(
361
+ self,
362
+ prompt: Union[str, List[str]] = None,
363
+ prompt_2: Optional[Union[str, List[str]]] = None,
364
+ height: Optional[int] = None,
365
+ width: Optional[int] = None,
366
+ num_inference_steps: int = 50,
367
+ timesteps: List[int] = None,
368
+ denoising_end: Optional[float] = None,
369
+ guidance_scale: float = 5.0,
370
+ negative_prompt: Optional[Union[str, List[str]]] = None,
371
+ negative_prompt_2: Optional[Union[str, List[str]]] = None,
372
+ num_images_per_prompt: Optional[int] = 1,
373
+ eta: float = 0.0,
374
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
375
+ latents: Optional[torch.FloatTensor] = None,
376
+ prompt_embeds: Optional[torch.FloatTensor] = None,
377
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
378
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
379
+ negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
380
+ ip_adapter_image: Optional[PipelineImageInput] = None,
381
+ output_type: Optional[str] = "pil",
382
+ return_dict: bool = True,
383
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
384
+ guidance_rescale: float = 0.0,
385
+ original_size: Optional[Tuple[int, int]] = None,
386
+ crops_coords_top_left: Tuple[int, int] = (0, 0),
387
+ target_size: Optional[Tuple[int, int]] = None,
388
+ negative_original_size: Optional[Tuple[int, int]] = None,
389
+ negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
390
+ negative_target_size: Optional[Tuple[int, int]] = None,
391
+ clip_skip: Optional[int] = None,
392
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
393
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
394
+ get_attention_map: bool = False,
395
+ **kwargs,
396
+ ):
397
+ r"""
398
+ Function invoked when calling the pipeline for generation.
399
+
400
+ Args:
401
+ prompt (`str` or `List[str]`, *optional*):
402
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
404
+ instead.
404
+ prompt_2 (`str` or `List[str]`, *optional*):
405
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
406
+ used in both text-encoders
407
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
408
+ The height in pixels of the generated image. This is set to 1024 by default for the best results.
409
+ Anything below 512 pixels won't work well for
410
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
411
+ and checkpoints that are not specifically fine-tuned on low resolutions.
412
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
413
+ The width in pixels of the generated image. This is set to 1024 by default for the best results.
414
+ Anything below 512 pixels won't work well for
415
+ [stabilityai/stable-diffusion-xl-base-1.0](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
416
+ and checkpoints that are not specifically fine-tuned on low resolutions.
417
+ num_inference_steps (`int`, *optional*, defaults to 50):
418
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
419
+ expense of slower inference.
420
+ timesteps (`List[int]`, *optional*):
421
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
422
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
423
+ passed will be used. Must be in descending order.
424
+ denoising_end (`float`, *optional*):
425
+ When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
426
+ completed before it is intentionally prematurely terminated. As a result, the returned sample will
427
+ still retain a substantial amount of noise as determined by the discrete timesteps selected by the
428
+ scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
429
+ "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
430
+ Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
431
+ guidance_scale (`float`, *optional*, defaults to 5.0):
432
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
433
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
434
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
435
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
436
+ usually at the expense of lower image quality.
437
+ negative_prompt (`str` or `List[str]`, *optional*):
438
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
439
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
440
+ less than `1`).
441
+ negative_prompt_2 (`str` or `List[str]`, *optional*):
442
+ The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
443
+ `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
444
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
445
+ The number of images to generate per prompt.
446
+ eta (`float`, *optional*, defaults to 0.0):
447
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
448
+ [`schedulers.DDIMScheduler`], will be ignored for others.
449
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
450
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
451
+ to make generation deterministic.
452
+ latents (`torch.FloatTensor`, *optional*):
453
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
454
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
455
+ tensor will be generated by sampling using the supplied random `generator`.
456
+ prompt_embeds (`torch.FloatTensor`, *optional*):
457
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
458
+ provided, text embeddings will be generated from `prompt` input argument.
459
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
460
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
461
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
462
+ argument.
463
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
464
+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
465
+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
466
+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
467
+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
468
+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
469
+ input argument.
470
+ ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
471
+ output_type (`str`, *optional*, defaults to `"pil"`):
472
+ The output format of the generated image. Choose between
473
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
474
+ return_dict (`bool`, *optional*, defaults to `True`):
475
+ Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
476
+ of a plain tuple.
477
+ cross_attention_kwargs (`dict`, *optional*):
478
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
479
+ `self.processor` in
480
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
481
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
482
+ Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
483
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
484
+ [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
485
+ Guidance rescale factor should fix overexposure when using zero terminal SNR.
486
+ original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
487
+ If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
488
+ `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
489
+ explained in section 2.2 of
490
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
491
+ crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
492
+ `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
493
+ `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
494
+ `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
495
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
496
+ target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
497
+ For most cases, `target_size` should be set to the desired height and width of the generated image. If
498
+ not specified it will default to `(height, width)`. Part of SDXL's micro-conditioning as explained in
499
+ section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
500
+ negative_original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
501
+ To negatively condition the generation process based on a specific image resolution. Part of SDXL's
502
+ micro-conditioning as explained in section 2.2 of
503
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
504
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
505
+ negative_crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
506
+ To negatively condition the generation process based on specific crop coordinates. Part of SDXL's
507
+ micro-conditioning as explained in section 2.2 of
508
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
509
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
510
+ negative_target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
511
+ To negatively condition the generation process based on a target image resolution. It should be the same
512
+ as the `target_size` for most cases. Part of SDXL's micro-conditioning as explained in section 2.2 of
513
+ [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952). For more
514
+ information, refer to this issue thread: https://github.com/huggingface/diffusers/issues/4208.
515
+ callback_on_step_end (`Callable`, *optional*):
516
+ A function that calls at the end of each denoising steps during the inference. The function is called
517
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
518
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
519
+ `callback_on_step_end_tensor_inputs`.
520
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
521
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
522
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
523
+ `._callback_tensor_inputs` attribute of your pipeline class.
524
+
525
+ Examples:
526
+
527
+ Returns:
528
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
529
+ [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
530
+ `tuple`. When returning a tuple, the first element is a list with the generated images.
531
+ """
532
+
533
+ callback = kwargs.pop("callback", None)
534
+ callback_steps = kwargs.pop("callback_steps", None)
535
+
536
+ if callback is not None:
537
+ deprecate(
538
+ "callback",
539
+ "1.0.0",
540
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
541
+ )
542
+ if callback_steps is not None:
543
+ deprecate(
544
+ "callback_steps",
545
+ "1.0.0",
546
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider use `callback_on_step_end`",
547
+ )
548
+
549
+ # 0. Default height and width to unet
550
+ height = height or self.default_sample_size * self.vae_scale_factor
551
+ width = width or self.default_sample_size * self.vae_scale_factor
552
+
553
+ original_size = original_size or (height, width)
554
+ target_size = target_size or (height, width)
555
+
556
+ # 1. Check inputs. Raise error if not correct
557
+ self.check_inputs(
558
+ prompt,
559
+ prompt_2,
560
+ height,
561
+ width,
562
+ callback_steps,
563
+ negative_prompt,
564
+ negative_prompt_2,
565
+ prompt_embeds,
566
+ negative_prompt_embeds,
567
+ pooled_prompt_embeds,
568
+ negative_pooled_prompt_embeds,
569
+ callback_on_step_end_tensor_inputs,
570
+ )
571
+
572
+ self._guidance_scale = guidance_scale
573
+ self._guidance_rescale = guidance_rescale
574
+ self._clip_skip = clip_skip
575
+ self._cross_attention_kwargs = cross_attention_kwargs
576
+ self._denoising_end = denoising_end
577
+
578
+ # 2. Define call parameters
579
+ if prompt is not None and isinstance(prompt, str):
580
+ batch_size = 1
581
+ elif prompt is not None and isinstance(prompt, list):
582
+ batch_size = len(prompt)
583
+ else:
584
+ batch_size = prompt_embeds.shape[0]
585
+
586
+ device = self._execution_device
587
+
588
+ # 3. Encode input prompt
589
+ lora_scale = (
590
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
591
+ )
592
+
593
+ (
594
+ prompt_embeds,
595
+ negative_prompt_embeds,
596
+ pooled_prompt_embeds,
597
+ negative_pooled_prompt_embeds,
598
+ ) = self.encode_prompt(
599
+ prompt=prompt,
600
+ prompt_2=prompt_2,
601
+ device=device,
602
+ num_images_per_prompt=num_images_per_prompt,
603
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
604
+ negative_prompt=negative_prompt,
605
+ negative_prompt_2=negative_prompt_2,
606
+ prompt_embeds=prompt_embeds,
607
+ negative_prompt_embeds=negative_prompt_embeds,
608
+ pooled_prompt_embeds=pooled_prompt_embeds,
609
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
610
+ lora_scale=lora_scale,
611
+ clip_skip=self.clip_skip,
612
+ )
613
+
614
+ # 4. Prepare timesteps
615
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
616
+
617
+ # 5. Prepare latent variables
618
+ num_channels_latents = self.unet.config.in_channels
619
+ latents = self.prepare_latents(
620
+ batch_size * num_images_per_prompt,
621
+ num_channels_latents,
622
+ height,
623
+ width,
624
+ prompt_embeds.dtype,
625
+ device,
626
+ generator,
627
+ latents,
628
+ )
629
+
630
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
631
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
632
+
633
+ # 7. Prepare added time ids & embeddings
634
+ add_text_embeds = pooled_prompt_embeds
635
+ if self.text_encoder_2 is None:
636
+ text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
637
+ else:
638
+ text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
639
+
640
+ add_time_ids = self._get_add_time_ids(
641
+ original_size,
642
+ crops_coords_top_left,
643
+ target_size,
644
+ dtype=prompt_embeds.dtype,
645
+ text_encoder_projection_dim=text_encoder_projection_dim,
646
+ )
647
+ if negative_original_size is not None and negative_target_size is not None:
648
+ negative_add_time_ids = self._get_add_time_ids(
649
+ negative_original_size,
650
+ negative_crops_coords_top_left,
651
+ negative_target_size,
652
+ dtype=prompt_embeds.dtype,
653
+ text_encoder_projection_dim=text_encoder_projection_dim,
654
+ )
655
+ else:
656
+ negative_add_time_ids = add_time_ids
657
+
658
+ if self.do_classifier_free_guidance:
659
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
660
+ add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
661
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
662
+
663
+ prompt_embeds = prompt_embeds.to(device)
664
+ add_text_embeds = add_text_embeds.to(device)
665
+ add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
666
+
667
+ if ip_adapter_image is not None:
668
+ output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
669
+ image_embeds, negative_image_embeds = self.encode_image(
670
+ ip_adapter_image, device, num_images_per_prompt, output_hidden_state
671
+ )
672
+ if self.do_classifier_free_guidance:
673
+ image_embeds = torch.cat([negative_image_embeds, image_embeds])
674
+ image_embeds = image_embeds.to(device)
675
+
676
+ # 8. Denoising loop
677
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
678
+
679
+ # 8.1 Apply denoising_end
680
+ if (
681
+ self.denoising_end is not None
682
+ and isinstance(self.denoising_end, float)
683
+ and self.denoising_end > 0
684
+ and self.denoising_end < 1
685
+ ):
686
+ discrete_timestep_cutoff = int(
687
+ round(
688
+ self.scheduler.config.num_train_timesteps
689
+ - (self.denoising_end * self.scheduler.config.num_train_timesteps)
690
+ )
691
+ )
692
+ num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
693
+ timesteps = timesteps[:num_inference_steps]
694
+
695
+ # 9. Optionally get Guidance Scale Embedding
696
+ timestep_cond = None
697
+ if self.unet.config.time_cond_proj_dim is not None:
698
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
699
+ timestep_cond = self.get_guidance_scale_embedding(
700
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
701
+ ).to(device=device, dtype=latents.dtype)
702
+
703
+ self._num_timesteps = len(timesteps)
704
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
705
+ if get_attention_map:
706
+ attn_maps = {} # each t one attn map
707
+
708
+ for i, t in enumerate(timesteps):
709
+ # expand the latents if we are doing classifier free guidance
710
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
711
+
712
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
713
+
714
+ # predict the noise residual
715
+ added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
716
+ if ip_adapter_image is not None:
717
+ added_cond_kwargs["image_embeds"] = image_embeds
718
+ noise_pred = self.unet(
719
+ latent_model_input,
720
+ t,
721
+ encoder_hidden_states=prompt_embeds,
722
+ timestep_cond=timestep_cond,
723
+ cross_attention_kwargs=self.cross_attention_kwargs,
724
+ added_cond_kwargs=added_cond_kwargs,
725
+ return_dict=False,
726
+ )[0]
727
+ if get_attention_map:
728
+ attn_probs = {}
729
+
730
+ for name, module in self.unet.named_modules():
731
+ if isinstance(module, Attention) and module.attn_probs is not None:
732
+ # take index 1 because it corresponds to noise_pred_text
733
+ a = module.attn_probs[1].mean(dim=0) # (2,Head,H,W,77)->(H,W,77)
734
+ attn_probs[name] = a
735
+
736
+ attn_maps[i] = attn_probs
737
+
738
+ # perform guidance
739
+ if self.do_classifier_free_guidance:
740
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
741
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
742
+
743
+ if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
744
+ # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
745
+ noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
746
+
747
+ # compute the previous noisy sample x_t -> x_t-1
748
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
749
+
750
+ if callback_on_step_end is not None:
751
+ callback_kwargs = {}
752
+ for k in callback_on_step_end_tensor_inputs:
753
+ callback_kwargs[k] = locals()[k]
754
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
755
+
756
+ latents = callback_outputs.pop("latents", latents)
757
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
758
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
759
+ add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
760
+ negative_pooled_prompt_embeds = callback_outputs.pop(
761
+ "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
762
+ )
763
+ add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
764
+ negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
765
+
766
+ # call the callback, if provided
767
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
768
+ progress_bar.update()
769
+ if callback is not None and i % callback_steps == 0:
770
+ step_idx = i // getattr(self.scheduler, "order", 1)
771
+ callback(step_idx, t, latents)
772
+
773
+ if XLA_AVAILABLE:
774
+ import torch_xla.core.xla_model as xm
775
+ xm.mark_step()
776
+
777
+ if get_attention_map:
778
+ output_maps = {}
779
+ for name in attn_probs.keys():
780
+ timeavg_maps = []
781
+ for i in attn_maps.keys():
782
+ timeavg_maps.append(attn_maps[i][name])
783
+ timeavg_maps = torch.stack(timeavg_maps, dim=0).mean(dim=0)
784
+ output_maps[name] = timeavg_maps
785
+
786
+ avg_attn_map = []
787
+ for name in attn_probs:
788
+ avg_attn_map.append(attn_probs[name])
789
+ avg_attn_map = torch.stack(avg_attn_map, dim=0).mean(dim=0) # (N,H,W,77) -> (H,W,77), averaged over the N attention modules
790
+ output_maps['avg'] = avg_attn_map
791
+
792
+ del attn_maps
793
+ del attn_probs
794
+ self.attn_maps = output_maps
795
+
796
+ if not output_type == "latent":
797
+ # make sure the VAE is in float32 mode, as it overflows in float16
798
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
799
+
800
+ if needs_upcasting:
801
+ self.upcast_vae()
802
+ latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
803
+
804
+ image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
805
+
806
+ # cast back to fp16 if needed
807
+ if needs_upcasting:
808
+ self.vae.to(dtype=torch.float16)
809
+ else:
810
+ image = latents
811
+
812
+ if not output_type == "latent":
813
+ # apply watermark if available
814
+ if self.watermark is not None:
815
+ image = self.watermark.apply_watermark(image)
816
+
817
+ image = self.image_processor.postprocess(image, output_type=output_type)
818
+
819
+ # Offload all models
820
+ self.maybe_free_model_hooks()
821
+
822
+ if not return_dict:
823
+ return (image,)
824
+
825
+ return StableDiffusionXLPipelineOutput(images=image)
826
+
827
+
828
+ def create_args_xl(output_dir, num_parts=8, num_k_per_part=256):
829
+ args = OmegaConf.create({
830
+ 'pretrained_model_name_or_path': 'stabilityai/stable-diffusion-xl-base-1.0',
831
+ 'num_parts': num_parts,
832
+ 'num_k_per_part': num_k_per_part,
833
+ 'revision': None,
834
+ 'variant': None,
835
+ 'rank': 4,
836
+ 'projection_nlayers': 1,
837
+ 'output_dir': output_dir
838
+ })
839
+ folders = sorted(os.listdir(args.output_dir))
840
+ cps = [int(f.split('-')[1]) for f in folders if 'checkpoint' in f and '.ipynb' not in f]
841
+ maxcp = max(cps)
842
+
843
+ args.maxcp = maxcp
844
+ return args
845
+
846
+
847
+ def load_pipeline_xl(args, weight_dtype=torch.float16, device=torch.device('cuda')):
848
+ text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, simple_mapper = init_for_pipeline(args)
849
+ placeholder_token = "<part>"
850
+ initializer_token = None
851
+ placeholder_token_ids_one = add_tokens(tokenizer_one,
852
+ text_encoder_one,
853
+ placeholder_token,
854
+ args.num_parts,
855
+ initializer_token)
856
+ placeholder_token_ids_two = add_tokens(tokenizer_two,
857
+ text_encoder_two,
858
+ placeholder_token,
859
+ args.num_parts,
860
+ initializer_token)
861
+
862
+ vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
863
+ pipeline = DreamCreatureSDXLPipeline.from_pretrained(
864
+ args.pretrained_model_name_or_path,
865
+ vae=vae,
866
+ tokenizer=tokenizer_one,
867
+ tokenizer_2=tokenizer_two,
868
+ text_encoder=text_encoder_one,
869
+ text_encoder_2=text_encoder_two,
870
+ revision=args.revision,
871
+ variant=args.variant,
872
+ torch_dtype=weight_dtype,
873
+ )
874
+ pipeline.placeholder_token_ids = placeholder_token_ids_one
875
+ pipeline.replace_token = False
876
+ pipeline.simple_mapper = simple_mapper
877
+ pipeline.simple_mapper.load_state_dict(torch.load(args.output_dir + f'/checkpoint-{args.maxcp}/hash_mapper.pth',
878
+ map_location='cpu'))
879
+
880
+ pipeline.simple_mapper.to(device)
881
+ pipeline = pipeline.to(device)
882
+
883
+ # load attention processors
884
+ pipeline.load_lora_weights(args.output_dir + f'/checkpoint-{args.maxcp}')
885
+
886
+ def setup_attn_processors(unet, attn_size):
887
+ attn_procs = {}
888
+ for name in unet.attn_processors.keys():
889
+ attn_procs[name] = AttnProcessorCustom(attn_size)
890
+ unet.set_attn_processor(attn_procs)
891
+
892
+ pipeline = pipeline.to(weight_dtype)
893
+ setup_attn_processors(pipeline.unet, 16)
894
+
895
+ return pipeline
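A minimal usage sketch (not part of the diff) of how the two SDXL helpers above could be wired together. The output directory name matches the `--output_dir` used in `run_sdxl_sup.sh` below, and the prompt token format is assumed to mirror the `<part_index:class_code>` convention of the SD1.5 demos; only `create_args_xl` and `load_pipeline_xl` come from this file.

```python
import torch

# Assumes a local training output folder containing checkpoint-<step>/ subfolders
# with hash_mapper.pth and the saved LoRA weights (see run_sdxl_sup.sh).
args = create_args_xl("sdxlbase-cub200-sup", num_parts=8, num_k_per_part=256)
pipe = load_pipeline_xl(args, weight_dtype=torch.float16, device=torch.device("cuda"))

# "<2:16> <4:17>" = part indices 2 (head) and 4 (wing) with 0-based class codes 16 and 17,
# following the same convention as the SD1.5 Gradio demos (an assumption for SDXL).
images = pipe("a photo of a <2:16> <4:17> bird", num_inference_steps=30).images
images[0].save("sample.png")
```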
dreamcreature/text_encoder.py ADDED
@@ -0,0 +1,201 @@
1
+ from collections import defaultdict
2
+ from typing import Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.utils.checkpoint
6
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
7
+ from transformers.models.clip.modeling_clip import (CLIPTextConfig,
8
+ CLIPTextModel,
9
+ CLIPTextModelWithProjection,
10
+ CLIPTextModelOutput)
11
+ from transformers.models.clip.modeling_clip import (CLIPTextTransformer,
12
+ _prepare_4d_attention_mask,
13
+ _create_4d_causal_attention_mask)
14
+
15
+
16
+ class CustomCLIPTextModel(CLIPTextModel):
17
+ """ Modification of CLIPTextModel to use our NeTI mapper for computing the embeddings of the concept. """
18
+
19
+ def __init__(self, config: CLIPTextConfig):
20
+ super().__init__(config)
21
+ self.text_model = CustomCLIPTextTransformer(config)
22
+ self.post_init()
23
+
24
+ def forward(self, input_ids: Optional[torch.Tensor] = None,
25
+ attention_mask: Optional[torch.Tensor] = None,
26
+ position_ids: Optional[torch.Tensor] = None,
27
+ output_attentions: Optional[bool] = None,
28
+ output_hidden_states: Optional[bool] = None,
29
+ return_dict: Optional[bool] = None,
30
+ hidden_states: Optional[torch.Tensor] = None) -> Union[Tuple, BaseModelOutputWithPooling]:
31
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
32
+
33
+ return self.text_model.forward(
34
+ input_ids=input_ids,
35
+ attention_mask=attention_mask,
36
+ position_ids=position_ids,
37
+ output_attentions=output_attentions,
38
+ output_hidden_states=output_hidden_states,
39
+ return_dict=return_dict,
40
+ hidden_states=hidden_states
41
+ )
42
+
43
+
44
+ class CustomCLIPTextTransformer(CLIPTextTransformer):
45
+ def forward_embeddings(self, input_ids, position_ids, inputs_embeds):
46
+ input_shape = input_ids.size()
47
+ input_ids = input_ids.view(-1, input_shape[-1])
48
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids, inputs_embeds=inputs_embeds)
49
+ return hidden_states
50
+
51
+ def forward_embeddings_with_mapper(self, input_ids, position_ids, mapper_outputs, placeholder_token_ids):
52
+ inputs_embeds = self.embeddings.token_embedding(input_ids)
53
+ dtype = inputs_embeds.dtype
54
+
55
+ offset = defaultdict(int)
56
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
57
+ # Overwrite the index of the placeholder token with the mapper output for each entry in the batch
58
+ # learnable_idxs = (input_ids == placeholder_token_id).nonzero(as_tuple=True)[1]
59
+ # inputs_embeds[torch.arange(input_ids.shape[0]), learnable_idxs] = mapper_outputs[:, i].to(dtype)
60
+
61
+ for bi in range(input_ids.shape[0]):
62
+ learnable_idx = (input_ids[bi] == placeholder_token_id).nonzero(as_tuple=True)[0]
63
+
64
+ if len(learnable_idx) != 0: # only assign if found
65
+ if len(learnable_idx) == 1:
66
+ offset_learnable_idx = learnable_idx
67
+ else: # if there is two and above.
68
+ start = offset[(bi, placeholder_token_id)]
69
+ offset_learnable_idx = learnable_idx[start:start + 1]
70
+ offset[(bi, placeholder_token_id)] += 1
71
+
72
+ # print(i, offset_learnable_idx)
73
+
74
+ # before = inputs_embeds[bi, learnable_idx]
75
+ inputs_embeds[bi, offset_learnable_idx] = mapper_outputs[bi, i].to(dtype)
76
+ # after = inputs_embeds[bi, learnable_idx]
77
+
78
+ return self.forward_embeddings(input_ids, position_ids, inputs_embeds)
79
+
80
+ def forward(
81
+ self,
82
+ input_ids: Optional[torch.Tensor] = None,
83
+ attention_mask: Optional[torch.Tensor] = None,
84
+ position_ids: Optional[torch.Tensor] = None,
85
+ output_attentions: Optional[bool] = None,
86
+ output_hidden_states: Optional[bool] = None,
87
+ return_dict: Optional[bool] = None,
88
+ hidden_states: Optional[torch.Tensor] = None
89
+ ) -> Union[Tuple, BaseModelOutputWithPooling]:
90
+ r"""
91
+ Returns:
92
+
93
+ """
94
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
95
+ output_hidden_states = (
96
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
97
+ )
98
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
99
+
100
+ if input_ids is None:
101
+ raise ValueError("You have to specify input_ids")
102
+
103
+ input_shape = input_ids.size()
104
+ input_ids = input_ids.view(-1, input_shape[-1])
105
+ if hidden_states is None:
106
+ hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)
107
+
108
+ # CLIP's text model uses causal mask, prepare it here.
109
+ # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
110
+ causal_attention_mask = _create_4d_causal_attention_mask(
111
+ input_shape, hidden_states.dtype, device=hidden_states.device
112
+ )
113
+ # expand attention_mask
114
+ if attention_mask is not None:
115
+ # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
116
+ attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)
117
+
118
+ # # bsz, seq_len = input_shape
119
+ # # CLIP's text model uses causal mask, prepare it here.
120
+ # # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
121
+ # causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
122
+ # # causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
123
+ # # hidden_states.device
124
+ # # )
125
+ #
126
+ # # expand attention_mask
127
+ # if attention_mask is not None:
128
+ # # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
129
+ # attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
130
+
131
+ encoder_outputs = self.encoder(
132
+ inputs_embeds=hidden_states,
133
+ attention_mask=attention_mask,
134
+ causal_attention_mask=causal_attention_mask,
135
+ output_attentions=output_attentions,
136
+ output_hidden_states=output_hidden_states,
137
+ return_dict=return_dict,
138
+ )
139
+
140
+ last_hidden_state = encoder_outputs[0]
141
+ last_hidden_state = self.final_layer_norm(last_hidden_state)
142
+
143
+ # text_embeds.shape = [batch_size, sequence_length, transformer.width]
144
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
145
+ # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
146
+ pooled_output = last_hidden_state[
147
+ torch.arange(last_hidden_state.shape[0], device=input_ids.device), input_ids.to(torch.int).argmax(dim=-1)
148
+ ]
149
+
150
+ if not return_dict:
151
+ return (last_hidden_state, pooled_output) + encoder_outputs[1:]
152
+
153
+ return BaseModelOutputWithPooling(
154
+ last_hidden_state=last_hidden_state,
155
+ pooler_output=pooled_output,
156
+ hidden_states=encoder_outputs.hidden_states,
157
+ attentions=encoder_outputs.attentions,
158
+ )
159
+
160
+
161
+ class CustomCLIPTextModelWithProjection(CLIPTextModelWithProjection):
162
+ """ Modification of CLIPTextModel to use our NeTI mapper for computing the embeddings of the concept. """
163
+
164
+ def __init__(self, config: CLIPTextConfig):
165
+ super().__init__(config)
166
+ self.text_model = CustomCLIPTextTransformer(config)
167
+ self.post_init()
168
+
169
+ def forward(self, input_ids: Optional[torch.Tensor] = None,
170
+ attention_mask: Optional[torch.Tensor] = None,
171
+ position_ids: Optional[torch.Tensor] = None,
172
+ output_attentions: Optional[bool] = None,
173
+ output_hidden_states: Optional[bool] = None,
174
+ return_dict: Optional[bool] = None,
175
+ hidden_states: Optional[torch.Tensor] = None) -> Union[Tuple, CLIPTextModelOutput]:
176
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
177
+
178
+ text_outputs = self.text_model(
179
+ input_ids=input_ids,
180
+ attention_mask=attention_mask,
181
+ position_ids=position_ids,
182
+ output_attentions=output_attentions,
183
+ output_hidden_states=output_hidden_states,
184
+ return_dict=return_dict,
185
+ hidden_states=hidden_states
186
+ )
187
+
188
+ pooled_output = text_outputs[1]
189
+
190
+ text_embeds = self.text_projection(pooled_output)
191
+
192
+ if not return_dict:
193
+ outputs = (text_embeds, text_outputs[0]) + text_outputs[2:]
194
+ return tuple(output for output in outputs if output is not None)
195
+
196
+ return CLIPTextModelOutput(
197
+ text_embeds=text_embeds,
198
+ last_hidden_state=text_outputs.last_hidden_state,
199
+ hidden_states=text_outputs.hidden_states,
200
+ attentions=text_outputs.attentions,
201
+ )
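A minimal sketch (not part of the diff) of how this custom text encoder is meant to be driven: token embeddings are pre-computed (during training, `forward_embeddings_with_mapper` overwrites the `<part>_i` positions with `TokenMapper` outputs) and then fed back through the transformer via the `hidden_states` argument. The base model id and prompt are illustrative assumptions.

```python
import torch
from dreamcreature.text_encoder import CustomCLIPTextModel
from dreamcreature.tokenizer import MultiTokenCLIPTokenizer

text_encoder = CustomCLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder")
tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer")

input_ids = tokenizer("a photo of a bird", padding="max_length", truncation=True,
                      max_length=77, return_tensors="pt", replace_token=False).input_ids

with torch.no_grad():
    # Without a mapper this reproduces the standard CLIP embedding lookup; with one,
    # forward_embeddings_with_mapper would additionally replace the <part>_i embeddings.
    hidden_states = text_encoder.text_model.forward_embeddings(input_ids, None, None)
    prompt_embeds = text_encoder(input_ids=input_ids,
                                 hidden_states=hidden_states).last_hidden_state

print(prompt_embeds.shape)  # torch.Size([1, 77, 768])
```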
dreamcreature/tokenizer.py ADDED
@@ -0,0 +1,109 @@
1
+ """
2
+ The main idea of this code is to save users the hassle of handling multiple tokens for a concept, i.e. typing
3
+ a photo of <concept>_0 <concept>_1 ... and so on
4
+ and instead just do
5
+ a photo of <concept>
6
+ which gets translated to the above. This needs to work for both inference and training.
7
+ For inference,
8
+ the tokenizer encodes the text. So, we would want logic for our tokenizer to replace the placeholder token with
9
+ its underlying vectors
10
+ For training,
11
+ we would want to abstract away some logic like
12
+ 1. Adding tokens
13
+ 2. Updating gradient mask
14
+ 3. Saving embeddings
15
+ to our Util class here.
16
+ so
17
+ TODO:
18
+ 1. have tokenizer keep track of concept, multiconcept pairs and replace during encode call x
19
+ 2. have mechanism for adding tokens x
20
+ 3. have mechanism for saving embeddings x
21
+ 4. get mask to update x
22
+ 5. Loading tokens from embedding x
23
+ 6. Integrate to training x
24
+ 7. Test
25
+ """
26
+ import copy
27
+ import random
28
+
29
+ from transformers import CLIPTokenizer
30
+
31
+
32
+ class MultiTokenCLIPTokenizer(CLIPTokenizer):
33
+ def __init__(self, *args, **kwargs):
34
+ super().__init__(*args, **kwargs)
35
+ self.token_map = {}
36
+
37
+ def try_adding_tokens(self, placeholder_token, *args, **kwargs):
38
+ num_added_tokens = super().add_tokens(placeholder_token, *args, **kwargs)
39
+ if num_added_tokens == 0:
40
+ raise ValueError(
41
+ f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
42
+ " `placeholder_token` that is not already in the tokenizer."
43
+ )
44
+
45
+ def add_placeholder_tokens(self, placeholder_token, *args, num_vec_per_token=1, **kwargs):
46
+ output = []
47
+ # if num_vec_per_token == 1:
48
+ # self.try_adding_tokens(placeholder_token, *args, **kwargs)
49
+ # output.append(placeholder_token)
50
+ # else:
51
+ output = []
52
+ for i in range(num_vec_per_token):
53
+ ith_token = placeholder_token + f"_{i}"
54
+ self.try_adding_tokens(ith_token, *args, **kwargs)
55
+ output.append(ith_token)
56
+ # handle cases where there is a new placeholder token that contains the current placeholder token but is larger
57
+ for token in self.token_map:
58
+ if token in placeholder_token:
59
+ raise ValueError(
60
+ f"The tokenizer already has placeholder token {token} that can get confused with"
61
+ f" {placeholder_token}. Keep placeholder tokens independent."
62
+ )
63
+ self.token_map[placeholder_token] = output
64
+
65
+ def replace_placeholder_tokens_in_text(self, text, vector_shuffle=False, prop_tokens_to_load=1.0):
66
+ """
67
+ Here, we replace the placeholder tokens in text recorded in token_map so that the text_encoder
68
+ can encode them
69
+ vector_shuffle was inspired by https://github.com/rinongal/textual_inversion/pull/119
70
+ where shuffling tokens was found to force the model to learn the concepts more descriptively.
71
+ """
72
+ if isinstance(text, list):
73
+ output = []
74
+ for i in range(len(text)):
75
+ output.append(self.replace_placeholder_tokens_in_text(text[i], vector_shuffle=vector_shuffle))
76
+ return output
77
+ for placeholder_token in self.token_map:
78
+ if placeholder_token in text:
79
+ tokens = self.token_map[placeholder_token]
80
+ tokens = tokens[: 1 + int(len(tokens) * prop_tokens_to_load)]
81
+ if vector_shuffle:
82
+ tokens = copy.copy(tokens)
83
+ random.shuffle(tokens)
84
+ text = text.replace(placeholder_token, " ".join(tokens)) # <part> -> <part>_0 <part>_1 ...
85
+ return text
86
+
87
+ def __call__(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, replace_token=True, **kwargs):
88
+ if replace_token:
89
+ return super().__call__(
90
+ self.replace_placeholder_tokens_in_text(
91
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
92
+ ),
93
+ *args,
94
+ **kwargs,
95
+ )
96
+ else:
97
+ return super().__call__(text, *args, **kwargs)
98
+
99
+ def encode(self, text, *args, vector_shuffle=False, prop_tokens_to_load=1.0, replace_token=True, **kwargs):
100
+ if replace_token:
101
+ return super().encode(
102
+ self.replace_placeholder_tokens_in_text(
103
+ text, vector_shuffle=vector_shuffle, prop_tokens_to_load=prop_tokens_to_load
104
+ ),
105
+ *args,
106
+ **kwargs,
107
+ )
108
+ else:
109
+ return super().encode(text, *args, **kwargs)
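A minimal sketch (not part of the diff) of the expansion this tokenizer performs; the base model id is an assumption:

```python
from dreamcreature.tokenizer import MultiTokenCLIPTokenizer

tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="tokenizer")
tokenizer.add_placeholder_tokens("<part>", num_vec_per_token=8)  # registers <part>_0 ... <part>_7

print(tokenizer.replace_placeholder_tokens_in_text("a photo of a <part> bird"))
# a photo of a <part>_0 <part>_1 <part>_2 <part>_3 <part>_4 <part>_5 <part>_6 <part>_7 bird

# __call__ and encode apply the same expansion before tokenizing unless replace_token=False:
input_ids = tokenizer("a photo of a <part> bird", padding="max_length",
                      truncation=True, max_length=77, return_tensors="pt").input_ids
```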
gradio_demo_cub200.py ADDED
@@ -0,0 +1,166 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import re
5
+ import shutil
6
+
7
+ import gradio as gr
8
+ import requests
9
+ import torch
10
+
11
+ from dreamcreature.pipeline import create_args, load_pipeline
12
+
13
+
14
+ def download_file(url, local_path):
15
+ if os.path.exists(local_path):
16
+ return
17
+
18
+ with requests.get(url, stream=True) as r:
19
+ with open(local_path, 'wb') as f:
20
+ shutil.copyfileobj(r.raw, f)
21
+
22
+ # Example usage
23
+
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--model_name', default='dreamcreature-sd1.5-cub200')
27
+ parser.add_argument('--checkpoint', default='checkpoint-74900')
28
+ opt = parser.parse_args()
29
+
30
+ model_name = opt.model_name
31
+ checkpoint_name = opt.checkpoint
32
+
33
+ repo_url = f"https://huggingface.co/kamwoh/{model_name}/resolve/main"
34
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model.bin"
35
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model.bin"
36
+ os.makedirs(f"{model_name}/{checkpoint_name}", exist_ok=True)
37
+ download_file(file_url, local_path)
38
+
39
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model_1.bin"
40
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model_1.bin"
41
+ download_file(file_url, local_path)
42
+
43
+ OUTPUT_DIR = model_name
44
+
45
+ args = create_args(OUTPUT_DIR)
46
+ if 'dpo' in OUTPUT_DIR:
47
+ args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
48
+
49
+ pipe = load_pipeline(args, torch.float16, 'cuda')
50
+ pipe = pipe.to(torch.float16)
51
+
52
+ pipe.verbose = True
53
+ pipe.v = 're'
54
+ pipe.num_k_per_part = 200
55
+
56
+ MAPPING = {
57
+ 'body': 0,
58
+ 'tail': 1,
59
+ 'head': 2,
60
+ 'wing': 4,
61
+ 'leg': 6
62
+ }
63
+
64
+ ID2NAME = open('data/cub200_2011/class_names.txt').readlines()
65
+ ID2NAME = [line.strip() for line in ID2NAME]
66
+
67
+
68
+ def process_text(text):
69
+ pattern = r"<([^:>]+):(\d+)>"
70
+ result = text
71
+ offset = 0
72
+
73
+ part2id = []
74
+
75
+ for match in re.finditer(pattern, text):
76
+ key = match.group(1)
77
+ clsid = int(match.group(2))
78
+ clsid = min(max(clsid, 1), 200) # must be 1~200
79
+
80
+ replacement = f"<{MAPPING[key]}:{clsid - 1}>"
81
+ start, end = match.span()
82
+
83
+ # Adjust the start and end positions based on the offset from previous replacements
84
+ start += offset
85
+ end += offset
86
+
87
+ # Replace the matched text with the replacement
88
+ result = result[:start] + replacement + result[end:]
89
+
90
+ # Update the offset for the next replacement
91
+ offset += len(replacement) - (end - start)
92
+
93
+ part2id.append(f'{key}: {ID2NAME[clsid - 1]}')
94
+
95
+ return result, part2id
96
+
97
+
98
+ def generate_images(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
99
+ generator = torch.Generator(device='cuda')
100
+ generator = generator.manual_seed(int(seed))
101
+
102
+ try:
103
+ prompt, part2id = process_text(prompt)
104
+ negative_prompt, _ = process_text(negative_prompt)
105
+
106
+ images = pipe(prompt,
107
+ negative_prompt=negative_prompt, generator=generator,
108
+ num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
109
+ num_images_per_prompt=num_images).images
110
+ except Exception as e:
111
+ raise gr.Error(f"The prompt probably contains invalid input; please follow the instructions above. "
112
+ f"The error message: {e}")
113
+ finally:
114
+ gc.collect()
115
+ torch.cuda.empty_cache()
116
+
117
+ return images, '; '.join(part2id)
118
+
119
+
120
+ with gr.Blocks(title="DreamCreature") as demo:
121
+ with gr.Row():
122
+ gr.Markdown(
123
+ """
124
+ # DreamCreature (CUB-200-2011)
125
+ To create your own creature, you can type:
126
+
127
+ `"a photo of a <head:id> <wing:id> bird"` where `id` ranges from 1~200 (200 classes corresponding to CUB-200-2011)
128
+
129
+ For instance `"a photo of a <head:17> <wing:18> bird"` using head of `cardinal (17)` and wing of `spotted catbird (18)`
130
+
131
+ Please see `id` in https://github.com/kamwoh/dreamcreature/blob/master/src/data/cub200_2011/class_names.txt
132
+
133
+ You can also try any prompt you like such as:
134
+
135
+ Sub-concept transfer: `"a photo of a <wing:17> cat"`
136
+
137
+ Inspiring design: `"a photo of a <head:101> <wing:191> teddy bear"`
138
+
139
+ (Experimental) You can also use two parts together such as:
140
+
141
+ `"a photo of a <head:17> <head:18> bird"` mixing head of `cardinal (17)` and `spotted catbird (18)`
142
+
143
+ The current available parts are: `head`, `body`, `wing`, `tail`, and `leg`
144
+
145
+ """)
146
+ with gr.Column():
147
+ with gr.Row():
148
+ with gr.Group():
149
+ prompt = gr.Textbox(label="Prompt", value="a photo of a <head:101> <wing:191> teddy bear")
150
+ negative_prompt = gr.Textbox(label="Negative Prompt",
151
+ value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic")
152
+ num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Num Inference Steps")
153
+ guidance_scale = gr.Slider(minimum=2, maximum=20, step=0.1, value=7.5, label="Guidance Scale")
154
+ num_images = gr.Slider(minimum=1, maximum=4, step=1, value=4, label="Number of Images")
155
+ seed = gr.Number(label="Seed", value=777881414)
156
+ button = gr.Button()
157
+
158
+ with gr.Column():
159
+ output_images = gr.Gallery(columns=4, label='Output')
160
+ markdown_labels = gr.Markdown("")
161
+
162
+ button.click(fn=generate_images,
163
+ inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, num_images,
164
+ seed], outputs=[output_images, markdown_labels], show_progress=True)
165
+
166
+ demo.queue().launch(inline=False, share=True, debug=True, server_name='0.0.0.0')
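For reference, a short worked example (not part of the diff) of what `process_text` above produces; the exact class names depend on `data/cub200_2011/class_names.txt`:

```python
# MAPPING sends the user-facing part name to its internal part index (head -> 2, wing -> 4)
# and the 1-based class id is shifted to a 0-based class code before reaching the pipeline.
result, part2id = process_text("a photo of a <head:17> <wing:18> bird")
print(result)   # a photo of a <2:16> <4:17> bird
print(part2id)  # e.g. ['head: <name of class 17>', 'wing: <name of class 18>']
```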
gradio_demo_dog.py ADDED
@@ -0,0 +1,166 @@
1
+ import argparse
2
+ import gc
3
+ import os
4
+ import re
5
+ import shutil
6
+
7
+ import gradio as gr
8
+ import requests
9
+ import torch
10
+
11
+ from dreamcreature.pipeline import create_args, load_pipeline
12
+
13
+
14
+ def download_file(url, local_path):
15
+ if os.path.exists(local_path):
16
+ return
17
+
18
+ with requests.get(url, stream=True) as r:
19
+ with open(local_path, 'wb') as f:
20
+ shutil.copyfileobj(r.raw, f)
21
+
22
+ # Example usage
23
+
24
+
25
+ parser = argparse.ArgumentParser()
26
+ parser.add_argument('--model_name', default='dreamcreature-sd1.5-dog')
27
+ parser.add_argument('--checkpoint', default='checkpoint-150000')
28
+ opt = parser.parse_args()
29
+
30
+ model_name = opt.model_name
31
+ checkpoint_name = opt.checkpoint
32
+
33
+ repo_url = f"https://huggingface.co/kamwoh/{model_name}/resolve/main"
34
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model.bin"
35
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model.bin"
36
+ os.makedirs(f"{model_name}/{checkpoint_name}", exist_ok=True)
37
+ download_file(file_url, local_path)
38
+
39
+ file_url = repo_url + f"/{checkpoint_name}/pytorch_model_1.bin"
40
+ local_path = f"{model_name}/{checkpoint_name}/pytorch_model_1.bin"
41
+ download_file(file_url, local_path)
42
+
43
+ OUTPUT_DIR = model_name
44
+
45
+ args = create_args(OUTPUT_DIR)
46
+ if 'dpo' in OUTPUT_DIR:
47
+ args.unet_path = "mhdang/dpo-sd1.5-text2image-v1"
48
+
49
+ pipe = load_pipeline(args, torch.float16, 'cuda')
50
+ pipe = pipe.to(torch.float16)
51
+
52
+ pipe.verbose = True
53
+ pipe.v = 're'
54
+ pipe.num_k_per_part = 120
55
+
56
+ MAPPING = {
57
+ 'eye': 0,
58
+ 'neck': 2,
59
+ 'ear': 3,
60
+ 'body': 4,
61
+ 'leg': 5,
62
+ 'nose': 6,
63
+ 'forehead': 7
64
+ }
65
+
66
+ ID2NAME = open('data/dogs/class_names.txt').readlines()
67
+ ID2NAME = [line.strip() for line in ID2NAME]
68
+
69
+
70
+ def process_text(text):
71
+ pattern = r"<([^:>]+):(\d+)>"
72
+ result = text
73
+ offset = 0
74
+
75
+ part2id = []
76
+
77
+ for match in re.finditer(pattern, text):
78
+ key = match.group(1)
79
+ clsid = int(match.group(2))
80
+ clsid = min(max(clsid, 1), 120) # must be 1~120 (Stanford Dogs has 120 classes)
81
+
82
+ replacement = f"<{MAPPING[key]}:{clsid - 1}>"
83
+ start, end = match.span()
84
+
85
+ # Adjust the start and end positions based on the offset from previous replacements
86
+ start += offset
87
+ end += offset
88
+
89
+ # Replace the matched text with the replacement
90
+ result = result[:start] + replacement + result[end:]
91
+
92
+ # Update the offset for the next replacement
93
+ offset += len(replacement) - (end - start)
94
+
95
+ part2id.append(f'{key}: {ID2NAME[clsid - 1]}')
96
+
97
+ return result, part2id
98
+
99
+
100
+ def generate_images(prompt, negative_prompt, num_inference_steps, guidance_scale, num_images, seed):
101
+ generator = torch.Generator(device='cuda')
102
+ generator = generator.manual_seed(int(seed))
103
+
104
+ try:
105
+ prompt, part2id = process_text(prompt)
106
+ negative_prompt, _ = process_text(negative_prompt)
107
+
108
+ images = pipe(prompt,
109
+ negative_prompt=negative_prompt, generator=generator,
110
+ num_inference_steps=int(num_inference_steps), guidance_scale=guidance_scale,
111
+ num_images_per_prompt=num_images).images
112
+ except Exception as e:
113
+ raise gr.Error(f"Probably due to the prompt have invalid input, please follow the instruction. "
114
+ f"The error message: {e}")
115
+ finally:
116
+ gc.collect()
117
+ torch.cuda.empty_cache()
118
+
119
+ return images, '; '.join(part2id)
120
+
121
+
122
+ with gr.Blocks(title="DreamCreature") as demo:
123
+ with gr.Row():
124
+ gr.Markdown(
125
+ """
126
+ # DreamCreature (Stanford Dogs)
127
+ To create your own creature, you can type:
128
+
129
+ `"a photo of a <nose:id> <ear:id> dog"` where `id` ranges from 0~119 (120 classes corresponding to Stanford Dogs)
130
+
131
+ For instance `"a photo of a <nose:2> <ear:112> dog"` using head of `maltese dog (2)` and wing of `cardigan (112)`
132
+
133
+ Please see `id` in https://github.com/kamwoh/dreamcreature/blob/master/src/data/dogs/class_names.txt
134
+
135
+ Sub-concept transfer: `"a photo of a <ear:112> cat"`
136
+
137
+ Inspiring design: `"a photo of a <eye:38> <body:38> teddy bear"`
138
+
139
+ (Experimental) You can also use two parts together such as:
140
+
141
+ `"a photo of a <nose:2> <nose:112> dog"` mixing the noses of `maltese dog (2)` and `cardigan (112)`
142
+
143
+ The current available parts are: `eye`, `neck`, `ear`, `body`, `leg`, `nose` and `forehead`
144
+
145
+ """)
146
+ with gr.Column():
147
+ with gr.Row():
148
+ with gr.Group():
149
+ prompt = gr.Textbox(label="Prompt", value="a photo of a <eye:37> <body:37> teddy bear")
150
+ negative_prompt = gr.Textbox(label="Negative Prompt",
151
+ value="blurry, ugly, duplicate, poorly drawn, deformed, mosaic")
152
+ num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, value=30, label="Num Inference Steps")
153
+ guidance_scale = gr.Slider(minimum=2, maximum=20, step=0.1, value=7.5, label="Guidance Scale")
154
+ num_images = gr.Slider(minimum=1, maximum=4, step=1, value=1, label="Number of Images")
155
+ seed = gr.Number(label="Seed", value=777881414)
156
+ button = gr.Button()
157
+
158
+ with gr.Column():
159
+ output_images = gr.Gallery(columns=4, label='Output')
160
+ markdown_labels = gr.Markdown("")
161
+
162
+ button.click(fn=generate_images,
163
+ inputs=[prompt, negative_prompt, num_inference_steps, guidance_scale, num_images,
164
+ seed], outputs=[output_images, markdown_labels], show_progress=True)
165
+
166
+ demo.queue().launch(inline=False, share=True, debug=True, server_name='0.0.0.0')
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ torch
2
+ torchvision
3
+ git+https://github.com/huggingface/diffusers
4
+ transformers
5
+ torchpq
6
+ omegaconf
7
+ scikit-learn
8
+ faiss-cpu
9
+ tqdm
10
+ accelerate
11
+ gradio
12
+ huggingface_hub
run_sd_sup.sh ADDED
@@ -0,0 +1,12 @@
1
+ python train_dreamcreature_sd.py \
2
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3
+ --train_data_dir=data/cub200_2011 \
4
+ --resolution=512 --random_flip --train_batch_size=2 --gradient_accumulation_steps=4 \
5
+ --num_train_epochs=100 --checkpointing_steps=749 --learning_rate=0.0001 \
6
+ --lr_scheduler="constant" --lr_warmup_steps=0 --seed=42 --output_dir="sd15-cub200-sup" \
7
+ --validation_prompt="a photo of a 0:16 1:16 2:16 4:16 6:16" \
8
+ --num_validation_images 8 --num_parts 8 --num_k_per_part 256 --filename="train.txt" \
9
+ --code_filename="train_caps_better_m8_k256.txt" --projection_nlayers=1 \
10
+ --use_templates --vector_shuffle --snr_gamma=5 \
11
+ --attn_loss=0.01 --use_gt_label --bg_code=7 \
12
+ --resume_from_checkpoint="latest" --mixed_precision="fp16"
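Note (not part of the diff): in `--validation_prompt`, each `i:k` pair appears to denote part index `i` (one of the `--num_parts`, here 8) and its class/cluster code `k` (out of `--num_k_per_part`, here 256), consistent with the `m8_k256` naming of `--code_filename` and with the `<part_index:class_code>` tokens emitted by `process_text` in the Gradio demos above.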
run_sd_unsup.sh ADDED
@@ -0,0 +1,12 @@
1
+ python train_dreamcreature_sd.py \
2
+ --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
3
+ --train_data_dir=data/cub200_2011 \
4
+ --resolution=512 --random_flip --train_batch_size=2 --gradient_accumulation_steps=4 \
5
+ --num_train_epochs=100 --checkpointing_steps=749 --learning_rate=0.0001 \
6
+ --lr_scheduler="constant" --lr_warmup_steps=0 --seed=42 --output_dir="sd15-cub200-unsup" \
7
+ --validation_prompt="a photo of a 0:16 1:16 2:16 4:16 6:16" \
8
+ --num_validation_images 8 --num_parts 8 --num_k_per_part 256 --filename="train.txt" \
9
+ --code_filename="train_caps_better_m8_k256.txt" --projection_nlayers=1 \
10
+ --use_templates --vector_shuffle --snr_gamma=5 \
11
+ --attn_loss=0.01 \
12
+ --resume_from_checkpoint="latest"
run_sdxl_sup.sh ADDED
@@ -0,0 +1,13 @@
1
+ python train_dreamcreature_sdxl.py \
2
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
3
+ --scheduler_steps 1000 \
4
+ --train_data_dir=data/cub200_2011 \
5
+ --resolution=512 --random_flip --train_batch_size=2 --gradient_accumulation_steps=4 \
6
+ --num_train_epochs=100 --checkpointing_steps=749 --learning_rate=0.0001 \
7
+ --lr_scheduler="constant" --lr_warmup_steps=0 --seed=42 --output_dir="sdxlbase-cub200-sup" \
8
+ --validation_prompt="a photo of a 0:1 2:1 3:1 4:1 5:1 6:1 7:1" \
9
+ --num_validation_images 8 --num_parts 8 --num_k_per_part 256 --filename="train.txt" \
10
+ --code_filename="train_caps_better_m8_k256.txt" --projection_nlayers=1 \
11
+ --use_templates --vector_shuffle --snr_gamma=5 \
12
+ --attn_loss=0.1 --use_gt_label --bg_code=7 \
13
+ --resume_from_checkpoint="latest"
run_sdxl_unsup.sh ADDED
@@ -0,0 +1,13 @@
1
+ python train_dreamcreature_sdxl.py \
2
+ --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
3
+ --scheduler_steps 1000 \
4
+ --train_data_dir=data/cub200_2011 \
5
+ --resolution=512 --random_flip --train_batch_size=2 --gradient_accumulation_steps=4 \
6
+ --num_train_epochs=100 --checkpointing_steps=749 --learning_rate=0.0001 \
7
+ --lr_scheduler="constant" --lr_warmup_steps=0 --seed=42 --output_dir="sdxlbase-cub200-unsup" \
8
+ --validation_prompt="a photo of a 0:1 2:1 3:1 4:1 5:1 6:1 7:1" \
9
+ --num_validation_images 8 --num_parts 8 --num_k_per_part 256 --filename="train.txt" \
10
+ --code_filename="train_caps_better_m8_k256.txt" --projection_nlayers=1 \
11
+ --use_templates --vector_shuffle --snr_gamma=5 \
12
+ --attn_loss=0.1 \
13
+ --resume_from_checkpoint="latest"
train_dreamcreature_sd.py ADDED
@@ -0,0 +1,1122 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """Fine-tuning script for Stable Diffusion for text2image with support for LoRA."""
16
+
17
+ import argparse
18
+ import copy
19
+ import logging
20
+ import math
21
+ import os
22
+ import random
23
+ import shutil
24
+ from pathlib import Path
25
+
26
+ import datasets
27
+ import diffusers
28
+ import numpy as np
29
+ import torch
30
+ import torch.nn.functional as F
31
+ import torch.utils.checkpoint
32
+ import transformers
33
+ from accelerate import Accelerator
34
+ from accelerate.logging import get_logger
35
+ from accelerate.utils import ProjectConfiguration, set_seed
36
+ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
37
+ from diffusers.loaders import AttnProcsLayers
38
+ from diffusers.optimization import get_scheduler
39
+ from diffusers.training_utils import compute_snr
40
+ from diffusers.utils import check_min_version, is_wandb_available
41
+ from diffusers.utils.import_utils import is_xformers_available
42
+ from huggingface_hub import create_repo, upload_folder
43
+ from packaging import version
44
+ from torchvision import transforms
45
+ from torchvision.transforms import InterpolationMode
46
+ from tqdm.auto import tqdm
47
+
48
+ from dreamcreature.attn_processor import LoRAAttnProcessorCustom
49
+ from dreamcreature.dataset import DreamCreatureDataset
50
+ from dreamcreature.dino import DINO
51
+ from dreamcreature.kmeans_segmentation import KMeansSegmentation
52
+ from dreamcreature.loss import dreamcreature_loss
53
+ from dreamcreature.mapper import TokenMapper
54
+ from dreamcreature.pipeline import DreamCreatureSDPipeline
55
+ from dreamcreature.text_encoder import CustomCLIPTextModel
56
+ from dreamcreature.tokenizer import MultiTokenCLIPTokenizer
57
+ from utils import add_tokens, tokenize_prompt, get_attn_processors
58
+
59
+ imagenet_templates = [
60
+ "a photo of a {}",
61
+ "a rendering of a {}",
62
+ "a cropped photo of the {}",
63
+ "the photo of a {}",
64
+ "a photo of a clean {}",
65
+ "a photo of a dirty {}",
66
+ "a dark photo of the {}",
67
+ "a photo of my {}",
68
+ "a photo of the cool {}",
69
+ "a close-up photo of a {}",
70
+ "a bright photo of the {}",
71
+ "a cropped photo of a {}",
72
+ "a photo of the {}",
73
+ "a good photo of the {}",
74
+ "a photo of one {}",
75
+ "a close-up photo of the {}",
76
+ "a rendition of the {}",
77
+ "a photo of the clean {}",
78
+ "a rendition of a {}",
79
+ "a photo of a nice {}",
80
+ "a good photo of a {}",
81
+ "a photo of the nice {}",
82
+ "a photo of the small {}",
83
+ "a photo of the weird {}",
84
+ "a photo of the large {}",
85
+ "a photo of a cool {}",
86
+ "a photo of a small {}",
87
+ ]
88
+
89
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
90
+ check_min_version("0.21.0.dev0")
91
+
92
+ logger = get_logger(__name__, log_level="INFO")
93
+
94
+
95
+ def save_model_card(repo_id: str, images=None, base_model=str, dataset_name=str, repo_folder=None):
96
+ img_str = ""
97
+ for i, image in enumerate(images):
98
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
99
+ img_str += f"![img_{i}](./image_{i}.png)\n"
100
+
101
+ yaml = f"""
102
+ ---
103
+ license: creativeml-openrail-m
104
+ base_model: {base_model}
105
+ tags:
106
+ - stable-diffusion
107
+ - stable-diffusion-diffusers
108
+ - text-to-image
109
+ - diffusers
110
+ - lora
111
+ inference: true
112
+ ---
113
+ """
114
+ model_card = f"""
115
+ # LoRA text2image fine-tuning - {repo_id}
116
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
117
+ {img_str}
118
+ """
119
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
120
+ f.write(yaml + model_card)
121
+
122
+
123
+ def parse_args():
124
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
125
+ parser.add_argument(
126
+ "--pretrained_model_name_or_path",
127
+ type=str,
128
+ default=None,
129
+ required=True,
130
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
131
+ )
132
+ parser.add_argument(
133
+ "--revision",
134
+ type=str,
135
+ default=None,
136
+ required=False,
137
+ help="Revision of pretrained model identifier from huggingface.co/models.",
138
+ )
139
+ parser.add_argument(
140
+ "--dataset_name",
141
+ type=str,
142
+ default=None,
143
+ help=(
144
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
145
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
146
+ " or to a folder containing files that 🤗 Datasets can understand."
147
+ ),
148
+ )
149
+ parser.add_argument(
150
+ "--dataset_config_name",
151
+ type=str,
152
+ default=None,
153
+ help="The config of the Dataset, leave as None if there's only one config.",
154
+ )
155
+ parser.add_argument(
156
+ "--train_data_dir",
157
+ type=str,
158
+ default=None,
159
+ help=(
160
+ "A folder containing the training data. Folder contents must follow the structure described in"
161
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
162
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
163
+ ),
164
+ )
165
+ parser.add_argument(
166
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
167
+ )
168
+ parser.add_argument(
169
+ "--caption_column",
170
+ type=str,
171
+ default="text",
172
+ help="The column of the dataset containing a caption or a list of captions.",
173
+ )
174
+ parser.add_argument(
175
+ "--validation_prompt", type=str, default=None, help="A prompt that is sampled during training for inference."
176
+ )
177
+ parser.add_argument(
178
+ "--num_validation_images",
179
+ type=int,
180
+ default=4,
181
+ help="Number of images that should be generated during validation with `validation_prompt`.",
182
+ )
183
+ parser.add_argument(
184
+ "--validation_epochs",
185
+ type=int,
186
+ default=1,
187
+ help=(
188
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
189
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
190
+ ),
191
+ )
192
+ parser.add_argument(
193
+ "--max_train_samples",
194
+ type=int,
195
+ default=None,
196
+ help=(
197
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
198
+ "value if set."
199
+ ),
200
+ )
201
+ parser.add_argument(
202
+ "--output_dir",
203
+ type=str,
204
+ default="sd-model-finetuned-lora",
205
+ help="The output directory where the model predictions and checkpoints will be written.",
206
+ )
207
+ parser.add_argument(
208
+ "--cache_dir",
209
+ type=str,
210
+ default=None,
211
+ help="The directory where the downloaded models and datasets will be stored.",
212
+ )
213
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
214
+ parser.add_argument(
215
+ "--resolution",
216
+ type=int,
217
+ default=512,
218
+ help=(
219
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
220
+ " resolution"
221
+ ),
222
+ )
223
+ parser.add_argument(
224
+ "--center_crop",
225
+ default=False,
226
+ action="store_true",
227
+ help=(
228
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
229
+ " cropped. The images will be resized to the resolution first before cropping."
230
+ ),
231
+ )
232
+ parser.add_argument(
233
+ "--random_flip",
234
+ action="store_true",
235
+ help="whether to randomly flip images horizontally",
236
+ )
237
+ parser.add_argument(
238
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
239
+ )
240
+ parser.add_argument("--num_train_epochs", type=int, default=100)
241
+ parser.add_argument(
242
+ "--max_train_steps",
243
+ type=int,
244
+ default=None,
245
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
246
+ )
247
+ parser.add_argument(
248
+ "--gradient_accumulation_steps",
249
+ type=int,
250
+ default=1,
251
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
252
+ )
253
+ parser.add_argument(
254
+ "--gradient_checkpointing",
255
+ action="store_true",
256
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
257
+ )
258
+ parser.add_argument(
259
+ "--learning_rate",
260
+ type=float,
261
+ default=1e-4,
262
+ help="Initial learning rate (after the potential warmup period) to use.",
263
+ )
264
+ parser.add_argument(
265
+ "--scale_lr",
266
+ action="store_true",
267
+ default=False,
268
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
269
+ )
270
+ parser.add_argument(
271
+ "--lr_scheduler",
272
+ type=str,
273
+ default="constant",
274
+ help=(
275
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
276
+ ' "constant", "constant_with_warmup"]'
277
+ ),
278
+ )
279
+ parser.add_argument(
280
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
281
+ )
282
+ parser.add_argument(
283
+ "--snr_gamma",
284
+ type=float,
285
+ default=None,
286
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
287
+ "More details here: https://arxiv.org/abs/2303.09556.",
288
+ )
289
+ parser.add_argument(
290
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
291
+ )
292
+ parser.add_argument(
293
+ "--allow_tf32",
294
+ action="store_true",
295
+ help=(
296
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
297
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
298
+ ),
299
+ )
300
+ parser.add_argument(
301
+ "--dataloader_num_workers",
302
+ type=int,
303
+ default=0,
304
+ help=(
305
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
306
+ ),
307
+ )
308
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
309
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
310
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
311
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
312
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
313
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
314
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
315
+ parser.add_argument(
316
+ "--prediction_type",
317
+ type=str,
318
+ default=None,
319
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
320
+ )
321
+ parser.add_argument(
322
+ "--hub_model_id",
323
+ type=str,
324
+ default=None,
325
+ help="The name of the repository to keep in sync with the local `output_dir`.",
326
+ )
327
+ parser.add_argument(
328
+ "--logging_dir",
329
+ type=str,
330
+ default="logs",
331
+ help=(
332
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
333
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
334
+ ),
335
+ )
336
+ parser.add_argument(
337
+ "--mixed_precision",
338
+ type=str,
339
+ default=None,
340
+ choices=["no", "fp16", "bf16"],
341
+ help=(
342
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
343
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
344
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
345
+ ),
346
+ )
347
+ parser.add_argument(
348
+ "--report_to",
349
+ type=str,
350
+ default="tensorboard",
351
+ help=(
352
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
353
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
354
+ ),
355
+ )
356
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
357
+ parser.add_argument(
358
+ "--checkpointing_steps",
359
+ type=int,
360
+ default=500,
361
+ help=(
362
+ "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
363
+ " training using `--resume_from_checkpoint`."
364
+ ),
365
+ )
366
+ parser.add_argument(
367
+ "--checkpoints_total_limit",
368
+ type=int,
369
+ default=None,
370
+ help=("Max number of checkpoints to store."),
371
+ )
372
+ parser.add_argument(
373
+ "--resume_from_checkpoint",
374
+ type=str,
375
+ default=None,
376
+ help=(
377
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
378
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
379
+ ),
380
+ )
381
+ parser.add_argument(
382
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
383
+ )
384
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
385
+ parser.add_argument(
386
+ "--rank",
387
+ type=int,
388
+ default=4,
389
+ help=("The dimension of the LoRA update matrices."),
390
+ )
391
+
392
+ parser.add_argument('--filename', default='train.txt')
393
+ parser.add_argument('--code_filename', default='train_caps_better_m8_k256.txt')
394
+ parser.add_argument('--repeat', default=1, type=int)
395
+
396
+ parser.add_argument('--scheduler_steps', default=1000, type=int, help='scheduler step, if turbo, set to 4')
397
+ parser.add_argument('--num_parts', type=int, default=4, help="Number of parts")
398
+ parser.add_argument('--num_k_per_part', type=int, default=256, help='Number of k')
399
+
400
+ parser.add_argument('--mapper_lr_scale', default=1, type=float)
401
+ parser.add_argument('--mapper_lr', default=0.0001, type=float)
402
+ parser.add_argument('--attn_loss', default=0, type=float)
403
+ parser.add_argument('--projection_nlayers', default=3, type=int)
404
+
405
+ parser.add_argument('--masked_training', action='store_true')
406
+ parser.add_argument('--drop_tokens', action='store_true')
407
+ parser.add_argument('--drop_rate', type=float, default=0.5)
408
+ parser.add_argument('--drop_counts', default='half')
409
+
410
+ parser.add_argument('--class_name', default='')
411
+ parser.add_argument('--no_pe', action='store_true')
412
+ parser.add_argument('--vector_shuffle', action='store_true')
413
+ parser.add_argument('--use_templates', action='store_true')
414
+
415
+ parser.add_argument('--use_gt_label', action='store_true')
416
+ parser.add_argument('--bg_code', default=7, type=int) # for gt_label
417
+ parser.add_argument('--fg_idx', default=0, type=int) # for gt_label
418
+
419
+ parser.add_argument('--filter_class', default=None, type=int, help='debugging purpose')
420
+
421
+ parser.add_argument('--unet_path', default=None)
422
+
423
+ args = parser.parse_args()
424
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
425
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
426
+ args.local_rank = env_local_rank
427
+
428
+ # Sanity checks
429
+ if args.dataset_name is None and args.train_data_dir is None:
430
+ raise ValueError("Need either a dataset name or a training folder.")
431
+
432
+ return args
433
+
434
+
435
+ def collate_fn(args, tokenizer, placeholder_token):
436
+ train_resizecrop = transforms.Compose([
437
+ transforms.Resize(int(args.resolution), InterpolationMode.BILINEAR),
438
+ transforms.RandomCrop(args.resolution),
439
+ ])
440
+
441
+ train_transforms = transforms.Compose(
442
+ [
443
+ transforms.ToTensor(),
444
+ transforms.Normalize([0.5], [0.5]),
445
+ ]
446
+ )
447
+
448
+ def f(examples):
449
+ raw_images = [train_resizecrop(example["pixel_values"]) for example in examples]
450
+
451
+ pixel_values = torch.stack([train_transforms(image) for image in raw_images])
452
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
453
+
454
+ captions = []
455
+ appeared_tokens = []
456
+
457
+ for i in range(len(examples)):
458
+ if args.use_templates and random.random() <= 0.5: # 50% using templates
459
+ if args.class_name != '':
460
+ caption = random.choice(imagenet_templates).format(f'{placeholder_token} {args.class_name}')
461
+ else:
462
+ caption = random.choice(imagenet_templates).format(placeholder_token)
463
+ else:
464
+ if args.class_name != '':
465
+ caption = f'{placeholder_token} {args.class_name}'
466
+ else:
467
+ caption = placeholder_token
468
+
469
+ tokens = tokenizer.token_map[placeholder_token][:args.num_parts]
470
+ tokens = [tokens[a] for a in examples[i]['appeared']]
471
+
472
+ if args.vector_shuffle or args.drop_tokens:
473
+ tokens = copy.copy(tokens)
474
+ random.shuffle(tokens)
475
+
476
+ if args.drop_tokens and random.random() < args.drop_rate and len(tokens) >= 2:
477
+ # randomly drop half of the tokens
478
+ if args.drop_counts == 'half':
479
+ tokens = tokens[:len(tokens) // 2]
480
+ else:
481
+ tokens = tokens[:int(args.drop_counts)]
482
+
483
+ appeared = [int(t.split('_')[1]) for t in tokens] # <part>_i
484
+ appeared_tokens.append(appeared)
485
+
486
+ caption = caption.replace(placeholder_token, ' '.join(tokens))
487
+ captions.append(caption)
488
+
489
+ input_ids = tokenize_prompt(tokenizer, captions)
490
+ # input_ids = inputs.input_ids.repeat(len(examples), 1) # (1, 77) -> (B, 77)
491
+
492
+ codes = torch.stack([example["codes"] for example in examples])
493
+
494
+ return {"pixel_values": pixel_values,
495
+ "raw_images": raw_images,
496
+ "appeared_tokens": appeared_tokens,
497
+ "input_ids": input_ids,
498
+ "codes": codes}
499
+
500
+ return f
501
+
502
+
503
+ def setup_attn_processor(unet, **kwargs):
504
+ lora_attn_procs = {}
505
+ for name in unet.attn_processors.keys():
506
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
507
+ if name.startswith("mid_block"):
508
+ hidden_size = unet.config.block_out_channels[-1]
509
+ elif name.startswith("up_blocks"):
510
+ block_id = int(name[len("up_blocks.")])
511
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
512
+ elif name.startswith("down_blocks"):
513
+ block_id = int(name[len("down_blocks.")])
514
+ hidden_size = unet.config.block_out_channels[block_id]
515
+
516
+ lora_attn_procs[name] = LoRAAttnProcessorCustom(
517
+ hidden_size=hidden_size,
518
+ cross_attention_dim=cross_attention_dim,
519
+ rank=kwargs['rank'],
520
+ )
521
+
522
+ unet.set_attn_processor(lora_attn_procs)
523
+
524
+
525
+ def load_attn_processor(unet, filename):
526
+ logger.info(f'Load attn processors from {filename}')
527
+ lora_layers = AttnProcsLayers(get_attn_processors(unet))
528
+ lora_layers.load_state_dict(torch.load(filename))
529
+
530
+
531
+ def main():
532
+ args = parse_args()
533
+ logging_dir = Path(args.output_dir, args.logging_dir)
534
+
535
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
536
+
537
+ accelerator = Accelerator(
538
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
539
+ mixed_precision=args.mixed_precision,
540
+ log_with=args.report_to,
541
+ project_config=accelerator_project_config,
542
+ )
543
+ if args.report_to == "wandb":
544
+ if not is_wandb_available():
545
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
546
+ import wandb
547
+
548
+ # Make one log on every process with the configuration for debugging.
549
+ logging.basicConfig(
550
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
551
+ datefmt="%m/%d/%Y %H:%M:%S",
552
+ level=logging.INFO,
553
+ )
554
+ logger.info(accelerator.state, main_process_only=False)
555
+ if accelerator.is_local_main_process:
556
+ datasets.utils.logging.set_verbosity_warning()
557
+ transformers.utils.logging.set_verbosity_warning()
558
+ diffusers.utils.logging.set_verbosity_info()
559
+ else:
560
+ datasets.utils.logging.set_verbosity_error()
561
+ transformers.utils.logging.set_verbosity_error()
562
+ diffusers.utils.logging.set_verbosity_error()
563
+
564
+ # If passed along, set the training seed now.
565
+ if args.seed is not None:
566
+ set_seed(args.seed)
567
+
568
+ # Handle the repository creation
569
+ if accelerator.is_main_process:
570
+ if args.output_dir is not None:
571
+ os.makedirs(args.output_dir, exist_ok=True)
572
+
573
+ if args.push_to_hub:
574
+ repo_id = create_repo(
575
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
576
+ ).repo_id
577
+ # Load scheduler, tokenizer and models.
578
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
579
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
580
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
581
+ )
582
+
583
+ OUT_DIMS = 1024 if 'stabilityai/stable-diffusion-2-1' in args.pretrained_model_name_or_path else 768
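+ # OUT_DIMS matches the text encoder hidden size (1024 for SD 2.1's OpenCLIP ViT-H, 768 for SD 1.x's CLIP ViT-L),
+ # so the TokenMapper outputs can be injected directly as token embeddings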
584
+
585
+ text_encoder = CustomCLIPTextModel.from_pretrained(
586
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
587
+ )
588
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
589
+
590
+ unet_path = args.unet_path if args.unet_path is not None else args.pretrained_model_name_or_path
591
+ unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
592
+ unet_path, subfolder="unet", revision=args.revision
593
+ )
594
+
595
+ dino = DINO()
596
+ seg = KMeansSegmentation(args.train_data_dir + '/pretrained_kmeans.pth',
597
+ args.fg_idx,
598
+ args.bg_code,
599
+ args.num_parts,
600
+ args.num_k_per_part)
601
+
602
+ simple_mapper = TokenMapper(args.num_parts,
603
+ args.num_k_per_part,
604
+ OUT_DIMS,
605
+ args.projection_nlayers)
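+ # TokenMapper turns each (part, code) pair into an OUT_DIMS-dim embedding; these embeddings later replace
+ # the <part> placeholder tokens via forward_embeddings_with_mapper during training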
606
+ # initialize placeholder token
607
+ placeholder_token = "<part>"
608
+ initializer_token = None
609
+ placeholder_token_ids = add_tokens(tokenizer,
610
+ text_encoder,
611
+ placeholder_token,
612
+ args.num_parts,
613
+ initializer_token)
614
+
615
+ # freeze parameters of models to save more memory
616
+ unet.requires_grad_(False)
617
+ vae.requires_grad_(False)
618
+ text_encoder.requires_grad_(False)
619
+ dino.requires_grad_(False)
620
+
621
+ # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
622
+ # as these weights are only used for inference, keeping weights in full precision is not required.
623
+ weight_dtype = torch.float32
624
+ if accelerator.mixed_precision == "fp16":
625
+ weight_dtype = torch.float16
626
+ elif accelerator.mixed_precision == "bf16":
627
+ weight_dtype = torch.bfloat16
628
+
629
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
630
+ unet.to(accelerator.device, dtype=weight_dtype)
631
+ vae.to(accelerator.device, dtype=weight_dtype)
632
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
633
+
634
+ # now we will add new LoRA weights to the attention layers
635
+ # It's important to realize here how many attention weights will be added and of which sizes
636
+ # The sizes of the attention layers consist only of two different variables:
637
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
638
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
639
+
640
+ # Let's first see how many attention processors we will have to set.
641
+ # For Stable Diffusion, it should be equal to:
642
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
643
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
644
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
645
+ # => 32 layers
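+ # (each key in unet.attn_processors names one attention module; "attn1" processors are self-attention,
+ # so cross_attention_dim is None for them, while "attn2" processors use the text cross-attention dim)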
646
+
647
+ # Set correct lora layers
648
+ setup_attn_processor(unet, rank=args.rank)
649
+
650
+ if args.enable_xformers_memory_efficient_attention:
651
+ if is_xformers_available():
652
+ import xformers
653
+
654
+ xformers_version = version.parse(xformers.__version__)
655
+ if xformers_version == version.parse("0.0.16"):
656
+ logger.warning(
657
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
658
+ )
659
+ unet.enable_xformers_memory_efficient_attention()
660
+ else:
661
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
662
+
663
+ lora_layers = AttnProcsLayers(get_attn_processors(unet))
664
+
665
+ # Enable TF32 for faster training on Ampere GPUs,
666
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
667
+ if args.allow_tf32:
668
+ torch.backends.cuda.matmul.allow_tf32 = True
669
+
670
+ if args.scale_lr:
671
+ args.learning_rate = (
672
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
673
+ )
674
+
675
+ # Initialize the optimizer
676
+ if args.use_8bit_adam:
677
+ try:
678
+ import bitsandbytes as bnb
679
+ except ImportError:
680
+ raise ImportError(
681
+ "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
682
+ )
683
+
684
+ optimizer_cls = bnb.optim.AdamW8bit
685
+ else:
686
+ optimizer_cls = torch.optim.AdamW
687
+
688
+ extra_params = list(simple_mapper.parameters())
689
+ mapper_lr = args.learning_rate * args.mapper_lr_scale if args.learning_rate != 0 else args.mapper_lr
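+ # the mapper's LR is the main LR scaled by --mapper_lr_scale; --mapper_lr is only used when --learning_rate is 0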
690
+
691
+ optimizer = optimizer_cls(
692
+ [{'params': lora_layers.parameters()},
693
+ {'params': extra_params, 'lr': mapper_lr}],
694
+ lr=args.learning_rate,
695
+ betas=(args.adam_beta1, args.adam_beta2),
696
+ weight_decay=args.adam_weight_decay,
697
+ eps=args.adam_epsilon,
698
+ )
699
+
700
+ train_dataset = DreamCreatureDataset(args.train_data_dir,
701
+ args.filename,
702
+ code_filename=args.code_filename,
703
+ num_parts=args.num_parts,
704
+ num_k_per_part=args.num_k_per_part,
705
+ use_gt_label=args.use_gt_label,
706
+ bg_code=args.bg_code,
707
+ repeat=args.repeat)
708
+
709
+ with accelerator.main_process_first():
710
+ if args.max_train_samples is not None:
711
+ train_dataset.set_max_samples(args.max_train_samples, args.seed)
712
+
713
+ # DataLoaders creation:
714
+ train_dataloader = torch.utils.data.DataLoader(
715
+ train_dataset,
716
+ shuffle=True,
717
+ collate_fn=collate_fn(args, tokenizer, placeholder_token),
718
+ batch_size=args.train_batch_size,
719
+ num_workers=args.dataloader_num_workers,
720
+ )
721
+
722
+ # Scheduler and math around the number of training steps.
723
+ overrode_max_train_steps = False
724
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
725
+ if args.max_train_steps is None:
726
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
727
+ overrode_max_train_steps = True
728
+
729
+ lr_scheduler = get_scheduler(
730
+ args.lr_scheduler,
731
+ optimizer=optimizer,
732
+ num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
733
+ num_training_steps=args.max_train_steps * accelerator.num_processes,
734
+ )
735
+
736
+ # Prepare everything with our `accelerator`.
737
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
738
+ lora_layers, optimizer, train_dataloader, lr_scheduler
739
+ )
740
+ simple_mapper = accelerator.prepare(simple_mapper)
741
+
742
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
743
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
744
+ if overrode_max_train_steps:
745
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
746
+ # Afterwards we recalculate our number of training epochs
747
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
748
+
749
+ # We need to initialize the trackers we use, and also store our configuration.
750
+ # The trackers initializes automatically on the main process.
751
+ if accelerator.is_main_process:
752
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
753
+
754
+ # Train!
755
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
756
+
757
+ logger.info("***** Running training *****")
758
+ logger.info(f" Num examples = {len(train_dataset)}")
759
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
760
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
761
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
762
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
763
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
764
+ global_step = 0
765
+ first_epoch = 0
766
+
767
+ # Potentially load in the weights and states from a previous save
768
+ if args.resume_from_checkpoint:
769
+ if args.resume_from_checkpoint != "latest":
770
+ path = os.path.basename(args.resume_from_checkpoint)
771
+ else:
772
+ # Get the most recent checkpoint
773
+ dirs = os.listdir(args.output_dir)
774
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
775
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
776
+ path = dirs[-1] if len(dirs) > 0 else None
777
+
778
+ if path is None:
779
+ accelerator.print(
780
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
781
+ )
782
+ args.resume_from_checkpoint = None
783
+ else:
784
+ accelerator.print(f"Resuming from checkpoint {path}")
785
+ accelerator.load_state(os.path.join(args.output_dir, path))
786
+ global_step = int(path.split("-")[1])
787
+
788
+ resume_global_step = global_step * args.gradient_accumulation_steps
789
+ first_epoch = global_step // num_update_steps_per_epoch
790
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
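+ # e.g. global_step=1100, gradient_accumulation_steps=2, 250 update steps per epoch:
+ # resume_global_step=2200, first_epoch=4, and the first 200 dataloader batches of that epoch are skipped below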
791
+
792
+ # Only show the progress bar once on each machine.
793
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process,
794
+ bar_format="{l_bar}{bar:10}{r_bar}{bar:-10b}")
795
+ progress_bar.set_description("Steps")
796
+
797
+ print(simple_mapper)
798
+
799
+ for epoch in range(first_epoch, args.num_train_epochs):
800
+ unet.train()
801
+ train_loss = 0.0
802
+ train_attn_loss = 0.0
803
+ train_diff_loss = 0.0
804
+ for step, batch in enumerate(train_dataloader):
805
+ # Skip steps until we reach the resumed step
806
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
807
+ if step % args.gradient_accumulation_steps == 0:
808
+ progress_bar.update(1)
809
+ continue
810
+
811
+ with accelerator.accumulate(unet, simple_mapper):
812
+ # Convert images to latent space
813
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
814
+ latents = latents * vae.config.scaling_factor
815
+
816
+ # Sample noise that we'll add to the latents
817
+ noise = torch.randn_like(latents)
818
+ if args.noise_offset:
819
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
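+ # offset noise adds a per-sample, per-channel constant to the noise, which helps the model learn
+ # images whose mean brightness differs from the training distribution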
820
+ noise += args.noise_offset * torch.randn(
821
+ (latents.shape[0], latents.shape[1], 1, 1), device=latents.device
822
+ )
823
+
824
+ bsz = latents.shape[0]
825
+ # Sample a random timestep for each image
826
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
827
+ timesteps = timesteps.long()
828
+
829
+ # Add noise to the latents according to the noise magnitude at each timestep
830
+ # (this is the forward diffusion process)
831
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
832
+
833
+ # Get the text embedding for conditioning
834
+ mapper_outputs = simple_mapper(batch['codes'])
835
+ # print(mapper_outputs.size(), batch["input_ids"].size())
836
+ modified_hs = text_encoder.text_model.forward_embeddings_with_mapper(batch["input_ids"],
837
+ None,
838
+ mapper_outputs,
839
+ placeholder_token_ids)
840
+ # print(modified_hs.size())
841
+ encoder_hidden_states = text_encoder(batch["input_ids"], hidden_states=modified_hs)[0]
842
+
843
+ # Get the target for loss depending on the prediction type
844
+ if args.prediction_type is not None:
845
+ # set prediction_type of scheduler if defined
846
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
847
+
848
+ if noise_scheduler.config.prediction_type == "epsilon":
849
+ target = noise
850
+ elif noise_scheduler.config.prediction_type == "v_prediction":
851
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
852
+ else:
853
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
854
+
855
+ # Predict the noise residual and compute loss
856
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
857
+
858
+ if args.snr_gamma is None:
859
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
860
+ attn_loss, max_attn = dreamcreature_loss(batch,
861
+ unet,
862
+ dino,
863
+ seg,
864
+ placeholder_token_ids,
865
+ accelerator)
866
+ if args.masked_training:
867
+ masks = batch['masks'].unsqueeze(1).to(accelerator.device)
868
+ loss_image_mask = F.interpolate(masks.float(),
869
+ size=target.shape[-2:],
870
+ mode='bilinear') * torch.ones_like(target)
871
+ loss = loss * loss_image_mask
872
+ loss = loss.sum() / loss_image_mask.sum()
873
+ else:
874
+ loss = loss.mean()
875
+ else:
876
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
877
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
878
+ # This is discussed in Section 4.2 of the same paper.
879
+ snr = compute_snr(noise_scheduler, timesteps)
880
+ if noise_scheduler.config.prediction_type == "v_prediction":
881
+ # Velocity objective requires that we add one to SNR values before we divide by them.
882
+ snr = snr + 1
883
+ mse_loss_weights = (
884
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
885
+ )
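+ # i.e. weight = min(snr, snr_gamma) / snr; e.g. with snr_gamma=5: snr=20 -> weight 0.25, snr=2 -> weight 1.0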
886
+
887
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
888
+ attn_loss, max_attn = dreamcreature_loss(batch,
889
+ unet,
890
+ dino,
891
+ seg,
892
+ placeholder_token_ids,
893
+ accelerator)
894
+ if args.masked_training:
895
+ masks = batch['masks'].unsqueeze(1).to(accelerator.device)
896
+ loss_image_mask = F.interpolate(masks.float(),
897
+ size=target.shape[-2:],
898
+ mode='bilinear') * torch.ones_like(target)
899
+ loss = loss * loss_image_mask
900
+ loss = loss.sum(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
901
+ loss = loss.sum() / loss_image_mask.sum()
902
+ else:
903
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
904
+ loss = loss.mean()
905
+
906
+ diff_loss = loss.clone().detach()
907
+ avg_diff_loss = accelerator.gather(diff_loss.repeat(args.train_batch_size)).mean()
908
+ train_diff_loss += avg_diff_loss.item() / args.gradient_accumulation_steps
909
+
910
+ avg_attn_loss = accelerator.gather(attn_loss.repeat(args.train_batch_size)).mean()
911
+ train_attn_loss += avg_attn_loss.item() / args.gradient_accumulation_steps
912
+
913
+ loss += args.attn_loss * attn_loss
914
+
915
+ # Gather the losses across all processes for logging (if we use distributed training).
916
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
917
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
918
+
919
+ # Backpropagate
920
+ accelerator.backward(loss)
921
+ if accelerator.sync_gradients:
922
+ params_to_clip = list(lora_layers.parameters()) + list(simple_mapper.parameters())
923
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
924
+
925
+ optimizer.step()
926
+ lr_scheduler.step()
927
+ optimizer.zero_grad()
928
+
929
+ # Checks if the accelerator has performed an optimization step behind the scenes
930
+ if accelerator.sync_gradients:
931
+ progress_bar.update(1)
932
+ global_step += 1
933
+ accelerator.log({"train_loss": train_loss,
934
+ "diff_loss": train_diff_loss,
935
+ "attn_loss": train_attn_loss,
936
+ "mapper_norm": mapper_outputs.detach().norm().item(),
937
+ "max_attn": max_attn.item()
938
+ }, step=global_step)
939
+ train_loss = 0.0
940
+ train_attn_loss = 0.0
941
+ train_diff_loss = 0.0
942
+
943
+ if global_step % args.checkpointing_steps == 0:
944
+ if accelerator.is_main_process:
945
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
946
+ if args.checkpoints_total_limit is not None:
947
+ checkpoints = os.listdir(args.output_dir)
948
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
949
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
950
+
951
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
952
+ if len(checkpoints) >= args.checkpoints_total_limit:
953
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
954
+ removing_checkpoints = checkpoints[0:num_to_remove]
955
+
956
+ logger.info(
957
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
958
+ )
959
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
960
+
961
+ for removing_checkpoint in removing_checkpoints:
962
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
963
+ shutil.rmtree(removing_checkpoint)
964
+
965
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
966
+ accelerator.save_state(save_path)
967
+ logger.info(f"Saved state to {save_path}")
968
+
969
+ logs = {"step_loss": diff_loss.detach().item(),
970
+ "attn_loss": attn_loss.detach().item(),
971
+ "lr": lr_scheduler.get_last_lr()[0]}
972
+ progress_bar.set_postfix(**logs)
973
+
974
+ if global_step >= args.max_train_steps:
975
+ break
976
+
977
+ if accelerator.is_main_process:
978
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
979
+ logger.info(
980
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
981
+ f" {args.validation_prompt}."
982
+ )
983
+ pipeline = DreamCreatureSDPipeline.from_pretrained(
984
+ args.pretrained_model_name_or_path,
985
+ unet=accelerator.unwrap_model(unet),
986
+ text_encoder=accelerator.unwrap_model(text_encoder),
987
+ tokenizer=tokenizer,
988
+ revision=args.revision,
989
+ torch_dtype=weight_dtype,
990
+ )
991
+ pipeline.placeholder_token_ids = placeholder_token_ids
992
+ pipeline.simple_mapper = accelerator.unwrap_model(simple_mapper)
993
+ pipeline.replace_token = False
994
+
995
+ pipeline = pipeline.to(accelerator.device)
996
+ pipeline.set_progress_bar_config(disable=True)
997
+
998
+ # run inference
999
+ generator = torch.Generator(device=accelerator.device)
1000
+ if args.seed is not None:
1001
+ generator = generator.manual_seed(args.seed)
1002
+ images = []
1003
+ for _ in range(args.num_validation_images):
1004
+ images.append(
1005
+ pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0]
1006
+ )
1007
+
1008
+ for tracker in accelerator.trackers:
1009
+ if tracker.name == "tensorboard":
1010
+ np_images = np.stack([np.asarray(img) for img in images])
1011
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1012
+ if tracker.name == "wandb":
1013
+ tracker.log(
1014
+ {
1015
+ "validation": [
1016
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1017
+ for i, image in enumerate(images)
1018
+ ]
1019
+ }
1020
+ )
1021
+
1022
+ del pipeline
1023
+ torch.cuda.empty_cache()
1024
+
1025
+ # Save the lora layers
1026
+ accelerator.wait_for_everyone()
1027
+ if accelerator.is_main_process:
1028
+ # unet = unet.to(torch.float32)
1029
+ # unet.save_attn_procs(args.output_dir, safe_serialization=not args.custom_diffusion)
1030
+
1031
+ torch.save(lora_layers.to(torch.float32).state_dict(), args.output_dir + '/lora_layers.pth')
1032
+ torch.save(simple_mapper.to(torch.float32).state_dict(), args.output_dir + '/hash_mapper.pth')
1033
+
1034
+ if args.push_to_hub:
1035
+ save_model_card(
1036
+ repo_id,
1037
+ images=images,
1038
+ base_model=args.pretrained_model_name_or_path,
1039
+ dataset_name=args.dataset_name,
1040
+ repo_folder=args.output_dir,
1041
+ )
1042
+ upload_folder(
1043
+ repo_id=repo_id,
1044
+ folder_path=args.output_dir,
1045
+ commit_message="End of training",
1046
+ ignore_patterns=["step_*", "epoch_*"],
1047
+ )
1048
+
1049
+ del unet
1050
+
1051
+ # Final inference
1052
+ # Load previous pipeline
1053
+ tokenizer = MultiTokenCLIPTokenizer.from_pretrained(
1054
+ args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
1055
+ )
1056
+ text_encoder = CustomCLIPTextModel.from_pretrained(
1057
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
1058
+ )
1059
+ unet_path = args.unet_path if args.unet_path is not None else args.pretrained_model_name_or_path
1060
+ unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(
1061
+ unet_path, subfolder="unet", revision=args.revision
1062
+ )
1063
+ pipeline = DreamCreatureSDPipeline.from_pretrained(
1064
+ args.pretrained_model_name_or_path,
1065
+ unet=unet,
1066
+ text_encoder=text_encoder,
1067
+ tokenizer=tokenizer,
1068
+ revision=args.revision,
1069
+ torch_dtype=weight_dtype,
1070
+ )
1071
+ placeholder_token = "<part>"
1072
+ initializer_token = None
1073
+ placeholder_token_ids = add_tokens(tokenizer,
1074
+ text_encoder,
1075
+ placeholder_token,
1076
+ args.num_parts,
1077
+ initializer_token)
1078
+ pipeline.placeholder_token_ids = placeholder_token_ids
1079
+ pipeline.simple_mapper = TokenMapper(args.num_parts,
1080
+ args.num_k_per_part,
1081
+ OUT_DIMS,
1082
+ args.projection_nlayers)
1083
+ pipeline.simple_mapper.load_state_dict(torch.load(args.output_dir + '/hash_mapper.pth', map_location='cpu'))
1084
+ pipeline.simple_mapper.to(accelerator.device)
1085
+
1086
+ pipeline = pipeline.to(accelerator.device)
1087
+
1088
+ # load attention processors
1089
+ # pipeline.unet.load_attn_procs(args.output_dir, use_safetensors=not args.custom_diffusion)
1090
+ setup_attn_processor(pipeline.unet, rank=args.rank)
1091
+ load_attn_processor(pipeline.unet, args.output_dir + '/lora_layers.pth')
1092
+
1093
+ # run inference
1094
+ pipeline.replace_token = False
1095
+ generator = torch.Generator(device=accelerator.device)
1096
+ if args.seed is not None:
1097
+ generator = generator.manual_seed(args.seed)
1098
+ images = []
1099
+ for _ in range(args.num_validation_images):
1100
+ images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])
1101
+
1102
+ if accelerator.is_main_process:
1103
+ for tracker in accelerator.trackers:
1104
+ if len(images) != 0:
1105
+ if tracker.name == "tensorboard":
1106
+ np_images = np.stack([np.asarray(img) for img in images])
1107
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1108
+ if tracker.name == "wandb":
1109
+ tracker.log(
1110
+ {
1111
+ "test": [
1112
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1113
+ for i, image in enumerate(images)
1114
+ ]
1115
+ }
1116
+ )
1117
+
1118
+ accelerator.end_training()
1119
+
1120
+
1121
+ if __name__ == "__main__":
1122
+ main()
train_dreamcreature_sdxl.py ADDED
@@ -0,0 +1,1539 @@
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ # Copyright 2023 The HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Fine-tuning script for Stable Diffusion XL for text2image with support for LoRA."""
17
+
18
+ import argparse
19
+ import copy
20
+ import itertools
21
+ import logging
22
+ import math
23
+ import os
24
+ import random
25
+ import shutil
26
+ from pathlib import Path
27
+ from typing import Dict
28
+
29
+ import datasets
30
+ import diffusers
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn.functional as F
34
+ import torch.utils.checkpoint
35
+ import transformers
36
+ from accelerate import Accelerator
37
+ from accelerate.logging import get_logger
38
+ from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
39
+ from diffusers import (
40
+ AutoencoderKL,
41
+ DDPMScheduler,
42
+ StableDiffusionXLPipeline,
43
+ UNet2DConditionModel,
44
+ )
45
+ from diffusers.loaders import LoraLoaderMixin
46
+ from diffusers.models.lora import LoRALinearLayer
47
+ from diffusers.optimization import get_scheduler
48
+ from diffusers.training_utils import compute_snr
49
+ from diffusers.utils import check_min_version, is_wandb_available
50
+ from diffusers.utils.import_utils import is_xformers_available
51
+ from huggingface_hub import create_repo, upload_folder
52
+ from packaging import version
53
+ from torchvision import transforms
54
+ from torchvision.transforms.functional import crop
55
+ from tqdm.auto import tqdm
56
+ from transformers import PretrainedConfig
57
+
58
+ from dreamcreature.attn_processor import AttnProcessorCustom
59
+ from dreamcreature.dataset import DreamCreatureDataset
60
+ from dreamcreature.dino import DINO
61
+ from dreamcreature.kmeans_segmentation import KMeansSegmentation
62
+ from dreamcreature.loss import dreamcreature_loss
63
+ from dreamcreature.mapper import TokenMapper
64
+ from dreamcreature.pipeline_xl import DreamCreatureSDXLPipeline
65
+ from dreamcreature.text_encoder import CustomCLIPTextModel, CustomCLIPTextModelWithProjection
66
+ from dreamcreature.tokenizer import MultiTokenCLIPTokenizer
67
+ from utils import add_tokens, tokenize_prompt, get_attn_processors
68
+
69
+ IMAGENET_TEMPLATES = [
70
+ "a photo of a {}",
71
+ "a rendering of a {}",
72
+ "a cropped photo of the {}",
73
+ "the photo of a {}",
74
+ "a photo of a clean {}",
75
+ "a photo of a dirty {}",
76
+ "a dark photo of the {}",
77
+ "a photo of my {}",
78
+ "a photo of the cool {}",
79
+ "a close-up photo of a {}",
80
+ "a bright photo of the {}",
81
+ "a cropped photo of a {}",
82
+ "a photo of the {}",
83
+ "a good photo of the {}",
84
+ "a photo of one {}",
85
+ "a close-up photo of the {}",
86
+ "a rendition of the {}",
87
+ "a photo of the clean {}",
88
+ "a rendition of a {}",
89
+ "a photo of a nice {}",
90
+ "a good photo of a {}",
91
+ "a photo of the nice {}",
92
+ "a photo of the small {}",
93
+ "a photo of the weird {}",
94
+ "a photo of the large {}",
95
+ "a photo of a cool {}",
96
+ "a photo of a small {}",
97
+ ]
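+ # textual-inversion-style prompt templates; when --use_templates is set, roughly 50% of training captions
+ # are sampled from this list with the placeholder/part tokens filled into {}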
98
+
99
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
100
+ check_min_version("0.25.0.dev0")
101
+
102
+ logger = get_logger(__name__)
103
+
104
+
105
+ # TODO: This function should be removed once training scripts are rewritten in PEFT
106
+ def text_encoder_lora_state_dict(text_encoder):
107
+ state_dict = {}
108
+
109
+ def text_encoder_attn_modules(text_encoder):
110
+ from transformers import CLIPTextModel, CLIPTextModelWithProjection
111
+
112
+ attn_modules = []
113
+
114
+ if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
115
+ for i, layer in enumerate(text_encoder.text_model.encoder.layers):
116
+ name = f"text_model.encoder.layers.{i}.self_attn"
117
+ mod = layer.self_attn
118
+ attn_modules.append((name, mod))
119
+
120
+ return attn_modules
121
+
122
+ for name, module in text_encoder_attn_modules(text_encoder):
123
+ for k, v in module.q_proj.lora_linear_layer.state_dict().items():
124
+ state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
125
+
126
+ for k, v in module.k_proj.lora_linear_layer.state_dict().items():
127
+ state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
128
+
129
+ for k, v in module.v_proj.lora_linear_layer.state_dict().items():
130
+ state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
131
+
132
+ for k, v in module.out_proj.lora_linear_layer.state_dict().items():
133
+ state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
134
+
135
+ return state_dict
136
+
137
+
138
+ def save_model_card(
139
+ repo_id: str,
140
+ images=None,
141
+ base_model=str,
142
+ dataset_name=str,
143
+ train_text_encoder=False,
144
+ repo_folder=None,
145
+ vae_path=None,
146
+ ):
147
+ img_str = ""
148
+ for i, image in enumerate(images):
149
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
150
+ img_str += f"![img_{i}](./image_{i}.png)\n"
151
+
152
+ yaml = f"""
153
+ ---
154
+ license: creativeml-openrail-m
155
+ base_model: {base_model}
156
+ dataset: {dataset_name}
157
+ tags:
158
+ - stable-diffusion-xl
159
+ - stable-diffusion-xl-diffusers
160
+ - text-to-image
161
+ - diffusers
162
+ - lora
163
+ inference: true
164
+ ---
165
+ """
166
+ model_card = f"""
167
+ # LoRA text2image fine-tuning - {repo_id}
168
+
169
+ These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
170
+ {img_str}
171
+
172
+ LoRA for the text encoder was enabled: {train_text_encoder}.
173
+
174
+ Special VAE used for training: {vae_path}.
175
+ """
176
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
177
+ f.write(yaml + model_card)
178
+
179
+
180
+ def import_model_class_from_model_name_or_path(
181
+ pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder"
182
+ ):
183
+ text_encoder_config = PretrainedConfig.from_pretrained(
184
+ pretrained_model_name_or_path, subfolder=subfolder, revision=revision
185
+ )
186
+ model_class = text_encoder_config.architectures[0]
187
+
188
+ if model_class == "CLIPTextModel":
189
+ from transformers import CLIPTextModel
190
+
191
+ return CLIPTextModel
192
+ elif model_class == "CLIPTextModelWithProjection":
193
+ from transformers import CLIPTextModelWithProjection
194
+
195
+ return CLIPTextModelWithProjection
196
+ else:
197
+ raise ValueError(f"{model_class} is not supported.")
198
+
199
+
200
+ def parse_args(input_args=None):
201
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
202
+ parser.add_argument(
203
+ "--pretrained_model_name_or_path",
204
+ type=str,
205
+ default=None,
206
+ required=True,
207
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
208
+ )
209
+ parser.add_argument(
210
+ "--pretrained_vae_model_name_or_path",
211
+ type=str,
212
+ default=None,
213
+ help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
214
+ )
215
+ parser.add_argument(
216
+ "--revision",
217
+ type=str,
218
+ default=None,
219
+ required=False,
220
+ help="Revision of pretrained model identifier from huggingface.co/models.",
221
+ )
222
+ parser.add_argument(
223
+ "--variant",
224
+ type=str,
225
+ default=None,
226
+ help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
227
+ )
228
+ parser.add_argument(
229
+ "--dataset_name",
230
+ type=str,
231
+ default=None,
232
+ help=(
233
+ "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
234
+ " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
235
+ " or to a folder containing files that 🤗 Datasets can understand."
236
+ ),
237
+ )
238
+ parser.add_argument(
239
+ "--dataset_config_name",
240
+ type=str,
241
+ default=None,
242
+ help="The config of the Dataset, leave as None if there's only one config.",
243
+ )
244
+ parser.add_argument(
245
+ "--train_data_dir",
246
+ type=str,
247
+ default=None,
248
+ help=(
249
+ "A folder containing the training data. Folder contents must follow the structure described in"
250
+ " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
251
+ " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
252
+ ),
253
+ )
254
+ parser.add_argument(
255
+ "--image_column", type=str, default="image", help="The column of the dataset containing an image."
256
+ )
257
+ parser.add_argument(
258
+ "--caption_column",
259
+ type=str,
260
+ default="text",
261
+ help="The column of the dataset containing a caption or a list of captions.",
262
+ )
263
+ parser.add_argument(
264
+ "--validation_prompt",
265
+ type=str,
266
+ default=None,
267
+ help="A prompt that is used during validation to verify that the model is learning.",
268
+ )
269
+ parser.add_argument(
270
+ "--num_validation_images",
271
+ type=int,
272
+ default=4,
273
+ help="Number of images that should be generated during validation with `validation_prompt`.",
274
+ )
275
+ parser.add_argument(
276
+ "--validation_epochs",
277
+ type=int,
278
+ default=1,
279
+ help=(
280
+ "Run fine-tuning validation every X epochs. The validation process consists of running the prompt"
281
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
282
+ ),
283
+ )
284
+ parser.add_argument(
285
+ "--max_train_samples",
286
+ type=int,
287
+ default=None,
288
+ help=(
289
+ "For debugging purposes or quicker training, truncate the number of training examples to this "
290
+ "value if set."
291
+ ),
292
+ )
293
+ parser.add_argument(
294
+ "--output_dir",
295
+ type=str,
296
+ default="sd-model-finetuned-lora",
297
+ help="The output directory where the model predictions and checkpoints will be written.",
298
+ )
299
+ parser.add_argument(
300
+ "--cache_dir",
301
+ type=str,
302
+ default=None,
303
+ help="The directory where the downloaded models and datasets will be stored.",
304
+ )
305
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
306
+ parser.add_argument(
307
+ "--resolution",
308
+ type=int,
309
+ default=1024,
310
+ help=(
311
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
312
+ " resolution"
313
+ ),
314
+ )
315
+ parser.add_argument(
316
+ "--center_crop",
317
+ default=False,
318
+ action="store_true",
319
+ help=(
320
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
321
+ " cropped. The images will be resized to the resolution first before cropping."
322
+ ),
323
+ )
324
+ parser.add_argument(
325
+ "--random_flip",
326
+ action="store_true",
327
+ help="whether to randomly flip images horizontally",
328
+ )
329
+ parser.add_argument(
330
+ "--train_text_encoder",
331
+ action="store_true",
332
+ help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
333
+ )
334
+ parser.add_argument(
335
+ "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
336
+ )
337
+ parser.add_argument("--num_train_epochs", type=int, default=100)
338
+ parser.add_argument(
339
+ "--max_train_steps",
340
+ type=int,
341
+ default=None,
342
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
343
+ )
344
+ parser.add_argument(
345
+ "--checkpointing_steps",
346
+ type=int,
347
+ default=500,
348
+ help=(
349
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
350
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
351
+ " training using `--resume_from_checkpoint`."
352
+ ),
353
+ )
354
+ parser.add_argument(
355
+ "--checkpoints_total_limit",
356
+ type=int,
357
+ default=None,
358
+ help=("Max number of checkpoints to store."),
359
+ )
360
+ parser.add_argument(
361
+ "--resume_from_checkpoint",
362
+ type=str,
363
+ default=None,
364
+ help=(
365
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
366
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
367
+ ),
368
+ )
369
+ parser.add_argument(
370
+ "--gradient_accumulation_steps",
371
+ type=int,
372
+ default=1,
373
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
374
+ )
375
+ parser.add_argument(
376
+ "--gradient_checkpointing",
377
+ action="store_true",
378
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
379
+ )
380
+ parser.add_argument(
381
+ "--learning_rate",
382
+ type=float,
383
+ default=1e-4,
384
+ help="Initial learning rate (after the potential warmup period) to use.",
385
+ )
386
+ parser.add_argument(
387
+ "--scale_lr",
388
+ action="store_true",
389
+ default=False,
390
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
391
+ )
392
+ parser.add_argument(
393
+ "--lr_scheduler",
394
+ type=str,
395
+ default="constant",
396
+ help=(
397
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
398
+ ' "constant", "constant_with_warmup"]'
399
+ ),
400
+ )
401
+ parser.add_argument(
402
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
403
+ )
404
+ parser.add_argument(
405
+ "--snr_gamma",
406
+ type=float,
407
+ default=None,
408
+ help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
409
+ "More details here: https://arxiv.org/abs/2303.09556.",
410
+ )
411
+ parser.add_argument(
412
+ "--allow_tf32",
413
+ action="store_true",
414
+ help=(
415
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
416
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
417
+ ),
418
+ )
419
+ parser.add_argument(
420
+ "--dataloader_num_workers",
421
+ type=int,
422
+ default=0,
423
+ help=(
424
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
425
+ ),
426
+ )
427
+ parser.add_argument(
428
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
429
+ )
430
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
431
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
432
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
433
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
434
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
435
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
436
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
437
+ parser.add_argument(
438
+ "--prediction_type",
439
+ type=str,
440
+ default=None,
441
+ help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
442
+ )
443
+ parser.add_argument(
444
+ "--hub_model_id",
445
+ type=str,
446
+ default=None,
447
+ help="The name of the repository to keep in sync with the local `output_dir`.",
448
+ )
449
+ parser.add_argument(
450
+ "--logging_dir",
451
+ type=str,
452
+ default="logs",
453
+ help=(
454
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
455
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
456
+ ),
457
+ )
458
+ parser.add_argument(
459
+ "--report_to",
460
+ type=str,
461
+ default="tensorboard",
462
+ help=(
463
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
464
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
465
+ ),
466
+ )
467
+ parser.add_argument(
468
+ "--mixed_precision",
469
+ type=str,
470
+ default=None,
471
+ choices=["no", "fp16", "bf16"],
472
+ help=(
473
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
474
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
475
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
476
+ ),
477
+ )
478
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
479
+ parser.add_argument(
480
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
481
+ )
482
+ parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
483
+ parser.add_argument(
484
+ "--rank",
485
+ type=int,
486
+ default=4,
487
+ help=("The dimension of the LoRA update matrices."),
488
+ )
489
+
490
+ parser.add_argument('--filename', default='train.txt')
491
+ parser.add_argument('--code_filename', default='train_caps_better_m8_k256.txt')
492
+ parser.add_argument('--repeat', default=1, type=int)
493
+
494
+ parser.add_argument('--scheduler_steps', default=1000, type=int, help='number of scheduler timesteps; if using a turbo model, set to 4')
495
+ parser.add_argument('--num_parts', type=int, default=4, help="Number of parts")
496
+ parser.add_argument('--num_k_per_part', type=int, default=256, help='Number of k')
497
+
498
+ parser.add_argument('--mapper_lr_scale', default=1, type=float)
499
+ parser.add_argument('--mapper_lr', default=0.0001, type=float)
500
+ parser.add_argument('--attn_loss', default=0, type=float)
501
+ parser.add_argument('--projection_nlayers', default=3, type=int)
502
+
503
+ parser.add_argument('--masked_training', action='store_true')
504
+ parser.add_argument('--drop_tokens', action='store_true')
505
+ parser.add_argument('--drop_rate', type=float, default=0.5)
506
+ parser.add_argument('--drop_counts', default='half')
507
+
508
+ parser.add_argument('--class_name', default='')
509
+ parser.add_argument('--no_pe', action='store_true')
510
+ parser.add_argument('--vector_shuffle', action='store_true')
511
+
512
+ parser.add_argument('--use_gt_label', action='store_true')
513
+ parser.add_argument('--bg_code', default=7, type=int) # for gt_label
514
+ parser.add_argument('--fg_idx', default=0, type=int)
515
+ parser.add_argument('--use_templates', action='store_true')
516
+
517
+ parser.add_argument('--filter_class', default=None, type=int, help='debugging purpose')
518
+
519
+ if input_args is not None:
520
+ args = parser.parse_args(input_args)
521
+ else:
522
+ args = parser.parse_args()
523
+
524
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
525
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
526
+ args.local_rank = env_local_rank
527
+
528
+ # Sanity checks
529
+ if args.dataset_name is None and args.train_data_dir is None:
530
+ raise ValueError("Need either a dataset name or a training folder.")
531
+
532
+ return args
533
+
534
+
535
+ def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
536
+ """
537
+ Returns:
538
+ a state dict containing just the attention processor parameters.
539
+ """
540
+ attn_processors = get_attn_processors(unet)
541
+
542
+ attn_processors_state_dict = {}
543
+
544
+ for attn_processor_key, attn_processor in attn_processors.items():
545
+ for parameter_key, parameter in attn_processor.state_dict().items():
546
+ attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
547
+
548
+ return attn_processors_state_dict
549
+
550
+
551
+ def encode_prompt(text_encoders, text_input_ids_list, placeholder_token_ids, mapper_outputs):
552
+ prompt_embeds_list = []
553
+
554
+ for i, text_encoder in enumerate(text_encoders):
555
+ text_input_ids = text_input_ids_list[i]
556
+
557
+ modified_hs = text_encoder.text_model.forward_embeddings_with_mapper(text_input_ids,
558
+ None,
559
+ mapper_outputs[i],
560
+ placeholder_token_ids)
561
+
562
+ prompt_embeds = text_encoder(text_input_ids,
563
+ hidden_states=modified_hs,
564
+ output_hidden_states=True)
565
+ # prompt_embeds = text_encoder(
566
+ # text_input_ids.to(text_encoder.device),
567
+ # output_hidden_states=True,
568
+ # )
569
+
570
+ # We are always interested only in the pooled output of the final (second) text encoder
571
+ pooled_prompt_embeds = prompt_embeds[0]
572
+ prompt_embeds = prompt_embeds.hidden_states[-2]
573
+ bs_embed, seq_len, _ = prompt_embeds.shape
574
+ prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1)
575
+ prompt_embeds_list.append(prompt_embeds)
576
+
577
+ prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
578
+ pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1)
579
+ return prompt_embeds, pooled_prompt_embeds
580
+
581
+
582
+ def collate_fn(args, tokenizer_one, tokenizer_two, placeholder_token):
583
+ # Preprocessing the datasets.
584
+ train_resize = transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR)
585
+ train_crop = transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)
586
+ train_flip = transforms.RandomHorizontalFlip(p=1.0)
587
+ train_transforms = transforms.Compose(
588
+ [
589
+ transforms.ToTensor(),
590
+ transforms.Normalize([0.5], [0.5]),
591
+ ]
592
+ )
593
+
594
+ def f(examples):
595
+ # image aug
596
+ original_sizes = []
597
+ all_images = []
598
+ crop_top_lefts = []
599
+ captions = []
600
+ raw_images = []
601
+ appeared_tokens = []
602
+ codes = []
603
+ for i in range(len(examples)):
604
+ ##### original sdxl process #####
605
+ image = examples[i]['pixel_values'].convert('RGB')
606
+ original_sizes.append((image.height, image.width))
607
+ image = train_resize(image)
608
+ if args.center_crop:
609
+ y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
610
+ x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
611
+ image = train_crop(image)
612
+ else:
613
+ y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
614
+ image = crop(image, y1, x1, h, w)
615
+ if args.random_flip and random.random() < 0.5:
616
+ # flip
617
+ x1 = image.width - x1
618
+ image = train_flip(image)
619
+ crop_top_left = (y1, x1)
620
+ crop_top_lefts.append(crop_top_left)
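+ # original_sizes and crop_top_lefts are collected for SDXL's size/crop micro-conditioning (time ids)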
621
+ raw_images.append(image)
622
+ image = train_transforms(image)
623
+ all_images.append(image)
624
+
625
+ ##### dreamcreature caption #####
626
+ if args.use_templates and random.random() <= 0.5: # 50% using templates
627
+ if args.class_name != '':
628
+ caption = random.choice(IMAGENET_TEMPLATES).format(f'{placeholder_token} {args.class_name}')
629
+ else:
630
+ caption = random.choice(IMAGENET_TEMPLATES).format(placeholder_token)
631
+ else:
632
+ if args.class_name != '':
633
+ caption = f'{placeholder_token} {args.class_name}'
634
+ else:
635
+ caption = placeholder_token
636
+
637
+ tokens = tokenizer_one.token_map[placeholder_token][:args.num_parts]
638
+ tokens = [tokens[a] for a in examples[i]['appeared']]
639
+
640
+ if args.vector_shuffle or args.drop_tokens:
641
+ tokens = copy.copy(tokens)
642
+ random.shuffle(tokens)
643
+
644
+ if args.drop_tokens and random.random() < args.drop_rate and len(tokens) >= 2:
645
+ # randomly drop tokens: drop_counts == 'half' keeps half of them, otherwise keep the first int(drop_counts)
646
+ if args.drop_counts == 'half':
647
+ tokens = tokens[:len(tokens) // 2]
648
+ else:
649
+ tokens = tokens[:int(args.drop_counts)]
650
+
651
+ caption = caption.replace(placeholder_token, ' '.join(tokens))
652
+ captions.append(caption)
653
+
654
+ appeared = [int(t.split('_')[1]) for t in tokens] # <part>_i
655
+ # examples[i]['appeared'] = appeared
656
+
657
+ appeared_tokens.append(appeared)
658
+
659
+ code = examples[i]['codes']
660
+ codes.append(code)
661
+
662
+ tokens_one = tokenize_prompt(tokenizer_one, captions)
663
+ tokens_two = tokenize_prompt(tokenizer_two, captions)
664
+
665
+ ##### start stacking #####
666
+ pixel_values = torch.stack([image for image in all_images])
667
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
668
+ original_sizes = [s for s in original_sizes]
669
+ crop_top_lefts = [c for c in crop_top_lefts]
670
+ input_ids_one = torch.stack([t for t in tokens_one])
671
+ input_ids_two = torch.stack([t for t in tokens_two])
672
+
673
+ codes = torch.stack(codes, dim=0)
674
+
675
+ collate_output = {
676
+ "original_sizes": original_sizes,
677
+ "crop_top_lefts": crop_top_lefts,
678
+ "pixel_values": pixel_values,
679
+ "input_ids_one": input_ids_one,
680
+ "input_ids_two": input_ids_two,
681
+ "raw_images": raw_images,
682
+ "appeared_tokens": appeared_tokens,
683
+ "codes": codes
684
+ }
685
+
686
+ return collate_output
687
+
688
+ return f
689
+
690
+
691
+ def setup_attn_processors(unet, args):
692
+ attn_size = args.resolution // 32
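+ # e.g. 1024 // 32 = 32 with the default --resolution; every attention processor receives this attn_size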
693
+ attn_procs = {}
694
+ for name in unet.attn_processors.keys():
695
+ attn_procs[name] = AttnProcessorCustom(attn_size)
696
+ unet.set_attn_processor(attn_procs)
697
+
698
+
699
+ def init_for_pipeline(args):
700
+ tokenizer_one = MultiTokenCLIPTokenizer.from_pretrained(
701
+ args.pretrained_model_name_or_path,
702
+ subfolder="tokenizer",
703
+ revision=args.revision,
704
+ use_fast=False,
705
+ )
706
+ tokenizer_two = MultiTokenCLIPTokenizer.from_pretrained(
707
+ args.pretrained_model_name_or_path,
708
+ subfolder="tokenizer_2",
709
+ revision=args.revision,
710
+ use_fast=False,
711
+ )
712
+ text_encoder_cls_one = CustomCLIPTextModel
713
+ text_encoder_cls_two = CustomCLIPTextModelWithProjection
714
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
715
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
716
+ )
717
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
718
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
719
+ )
720
+
721
+ OUT_DIMS = 768 + 1280 # 2048
722
+ simple_mapper = TokenMapper(args.num_parts,
723
+ args.num_k_per_part,
724
+ OUT_DIMS,
725
+ args.projection_nlayers)
726
+ return text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, simple_mapper
727
+
728
+
729
+ def main(args):
730
+ logging_dir = Path(args.output_dir, args.logging_dir)
731
+
732
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
733
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
734
+ accelerator = Accelerator(
735
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
736
+ mixed_precision=args.mixed_precision,
737
+ log_with=args.report_to,
738
+ project_config=accelerator_project_config,
739
+ kwargs_handlers=[kwargs],
740
+ )
741
+
742
+ if args.report_to == "wandb":
743
+ if not is_wandb_available():
744
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
745
+ import wandb
746
+
747
+ # Make one log on every process with the configuration for debugging.
748
+ logging.basicConfig(
749
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
750
+ datefmt="%m/%d/%Y %H:%M:%S",
751
+ level=logging.INFO,
752
+ )
753
+ logger.info(accelerator.state, main_process_only=False)
754
+ if accelerator.is_local_main_process:
755
+ datasets.utils.logging.set_verbosity_warning()
756
+ transformers.utils.logging.set_verbosity_warning()
757
+ diffusers.utils.logging.set_verbosity_info()
758
+ else:
759
+ datasets.utils.logging.set_verbosity_error()
760
+ transformers.utils.logging.set_verbosity_error()
761
+ diffusers.utils.logging.set_verbosity_error()
762
+
763
+ # If passed along, set the training seed now.
764
+ if args.seed is not None:
765
+ set_seed(args.seed)
766
+
767
+ # Handle the repository creation
768
+ if accelerator.is_main_process:
769
+ if args.output_dir is not None:
770
+ os.makedirs(args.output_dir, exist_ok=True)
771
+
772
+ if args.push_to_hub:
773
+ repo_id = create_repo(
774
+ repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
775
+ ).repo_id
776
+
777
+ # Load the tokenizers (replace AutoTokenizer with the custom MultiTokenCLIPTokenizer)
778
+ tokenizer_one = MultiTokenCLIPTokenizer.from_pretrained(
779
+ args.pretrained_model_name_or_path,
780
+ subfolder="tokenizer",
781
+ revision=args.revision,
782
+ use_fast=False,
783
+ )
784
+ tokenizer_two = MultiTokenCLIPTokenizer.from_pretrained(
785
+ args.pretrained_model_name_or_path,
786
+ subfolder="tokenizer_2",
787
+ revision=args.revision,
788
+ use_fast=False,
789
+ )
790
+ # import correct text encoder classes
791
+ # text_encoder_cls_one = import_model_class_from_model_name_or_path(
792
+ # args.pretrained_model_name_or_path, args.revision
793
+ # )
794
+ # text_encoder_cls_two = import_model_class_from_model_name_or_path(
795
+ # args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2"
796
+ # )
797
+ text_encoder_cls_one = CustomCLIPTextModel
798
+ text_encoder_cls_two = CustomCLIPTextModelWithProjection
799
+
800
+ # Load scheduler and models
801
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path,
802
+ subfolder="scheduler",
803
+ num_train_steps=args.scheduler_steps)
804
+ text_encoder_one = text_encoder_cls_one.from_pretrained(
805
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant
806
+ )
807
+ text_encoder_two = text_encoder_cls_two.from_pretrained(
808
+ args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant
809
+ )
810
+ vae_path = (
811
+ args.pretrained_model_name_or_path
812
+ if args.pretrained_vae_model_name_or_path is None
813
+ else args.pretrained_vae_model_name_or_path
814
+ )
815
+ vae = AutoencoderKL.from_pretrained(
816
+ vae_path,
817
+ subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
818
+ revision=args.revision,
819
+ variant=args.variant,
820
+ )
821
+ unet = UNet2DConditionModel.from_pretrained(
822
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, variant=args.variant
823
+ )
824
+
825
+ ##### dreamcreature init #####
826
+ OUT_DIMS = 768 + 1280 # 2048
827
+
828
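+ # DINO feature extractor and k-means part segmentation, used to compute the part-level attention loss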
+ dino = DINO()
829
+ seg = KMeansSegmentation(args.train_data_dir + '/pretrained_kmeans.pth',
830
+ args.fg_idx,
831
+ args.bg_code,
832
+ args.num_parts,
833
+ args.num_k_per_part)
834
+
835
+ simple_mapper = TokenMapper(args.num_parts,
836
+ args.num_k_per_part,
837
+ OUT_DIMS,
838
+ args.projection_nlayers)
839
+
840
+ # We only train the additional adapter LoRA layers
841
+ vae.requires_grad_(False)
842
+ text_encoder_one.requires_grad_(False)
843
+ text_encoder_two.requires_grad_(False)
844
+ unet.requires_grad_(False)
845
+ dino.requires_grad_(False)
846
+
847
+ ##### dreamcreature, add sub-concepts token ids ####
848
+ placeholder_token = "<part>"
849
+ initializer_token = None
850
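+ # register num_parts "<part>" placeholder tokens in both tokenizers; with initializer_token=None the new embeddings are randomly initialized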
+ placeholder_token_ids_one = add_tokens(tokenizer_one,
851
+ text_encoder_one,
852
+ placeholder_token,
853
+ args.num_parts,
854
+ initializer_token)
855
+ placeholder_token_ids_two = add_tokens(tokenizer_two,
856
+ text_encoder_two,
857
+ placeholder_token,
858
+ args.num_parts,
859
+ initializer_token)
860
+
861
+ # For mixed precision training we cast all non-trainable weights (vae, non-LoRA text_encoder and non-LoRA unet) to half-precision,
862
+ # as these weights are only used for inference and keeping them in full precision is not required.
863
+ weight_dtype = torch.float32
864
+ if accelerator.mixed_precision == "fp16":
865
+ weight_dtype = torch.float16
866
+ elif accelerator.mixed_precision == "bf16":
867
+ weight_dtype = torch.bfloat16
868
+
869
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
870
+ # The VAE is in float32 to avoid NaN losses.
871
+ unet.to(accelerator.device, dtype=weight_dtype)
872
+ if args.pretrained_vae_model_name_or_path is None:
873
+ vae.to(accelerator.device, dtype=torch.float32)
874
+ else:
875
+ vae.to(accelerator.device, dtype=weight_dtype)
876
+ text_encoder_one.to(accelerator.device, dtype=weight_dtype)
877
+ text_encoder_two.to(accelerator.device, dtype=weight_dtype)
878
+ simple_mapper.to(accelerator.device)
879
+
880
+ if args.enable_xformers_memory_efficient_attention:
881
+ if is_xformers_available():
882
+ import xformers
883
+
884
+ xformers_version = version.parse(xformers.__version__)
885
+ if xformers_version == version.parse("0.0.16"):
886
+ logger.warn(
887
+ "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
888
+ )
889
+ unet.enable_xformers_memory_efficient_attention()
890
+ else:
891
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
892
+
893
+ # now we will add new LoRA weights to the attention layers
894
+ # Set correct lora layers
895
+ unet_lora_parameters = []
896
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
897
+ # Parse the attention module.
898
+ attn_module = unet
899
+ for n in attn_processor_name.split(".")[:-1]:
900
+ attn_module = getattr(attn_module, n)
901
+
902
+ # Set the `lora_layer` attribute of the attention-related matrices.
903
+ attn_module.to_q.set_lora_layer(
904
+ LoRALinearLayer(
905
+ in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
906
+ )
907
+ )
908
+ attn_module.to_k.set_lora_layer(
909
+ LoRALinearLayer(
910
+ in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
911
+ )
912
+ )
913
+ attn_module.to_v.set_lora_layer(
914
+ LoRALinearLayer(
915
+ in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
916
+ )
917
+ )
918
+ attn_module.to_out[0].set_lora_layer(
919
+ LoRALinearLayer(
920
+ in_features=attn_module.to_out[0].in_features,
921
+ out_features=attn_module.to_out[0].out_features,
922
+ rank=args.rank,
923
+ )
924
+ )
925
+
926
+ # Accumulate the LoRA params to optimize.
927
+ unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
928
+ unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
929
+ unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
930
+ unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
931
+
932
+ setup_attn_processors(unet, args)
933
+
934
+ # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
935
+ # So, instead, we monkey-patch the forward calls of its attention-blocks.
936
+ if args.train_text_encoder:
937
+ # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
938
+ text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(
939
+ text_encoder_one, dtype=torch.float32, rank=args.rank
940
+ )
941
+ text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(
942
+ text_encoder_two, dtype=torch.float32, rank=args.rank
943
+ )
944
+
945
+ # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
946
+ def save_model_hook(models, weights, output_dir):
947
+ if accelerator.is_main_process:
948
+ # the saved models are either just the unet attention processor (LoRA) layers,
949
+ # or the unet and text encoder attention layers, plus the token mapper
950
+ unet_lora_layers_to_save = None
951
+ text_encoder_one_lora_layers_to_save = None
952
+ text_encoder_two_lora_layers_to_save = None
953
+ mapper_to_save = None
954
+
955
+ for model in models:
956
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
957
+ unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
958
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
959
+ text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
960
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
961
+ text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
962
+ elif isinstance(model, TokenMapper):
963
+ mapper_to_save = model.state_dict()
964
+ else:
965
+ raise ValueError(f"unexpected save model: {model.__class__}")
966
+
967
+ # make sure to pop the weight so that the corresponding model is not saved again
968
+ weights.pop()
969
+
970
+ StableDiffusionXLPipeline.save_lora_weights(
971
+ output_dir,
972
+ unet_lora_layers=unet_lora_layers_to_save,
973
+ text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
974
+ text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
975
+ )
976
+ torch.save(mapper_to_save, output_dir + '/hash_mapper.pth')
977
+
978
+ def load_model_hook(models, input_dir):
979
+ unet_ = None
980
+ text_encoder_one_ = None
981
+ text_encoder_two_ = None
982
+ mapper_ = None
983
+
984
+ while len(models) > 0:
985
+ model = models.pop()
986
+
987
+ if isinstance(model, type(accelerator.unwrap_model(unet))):
988
+ unet_ = model
989
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
990
+ text_encoder_one_ = model
991
+ elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
992
+ text_encoder_two_ = model
993
+ elif isinstance(model, TokenMapper):
994
+ mapper_ = model
995
+ else:
996
+ raise ValueError(f"unexpected save model: {model.__class__}")
997
+
998
+ lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
999
+ LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
1000
+
1001
+ text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
1002
+ LoraLoaderMixin.load_lora_into_text_encoder(
1003
+ text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
1004
+ )
1005
+
1006
+ text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
1007
+ LoraLoaderMixin.load_lora_into_text_encoder(
1008
+ text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
1009
+ )
1010
+ mapper_.load_state_dict(torch.load(input_dir + '/hash_mapper.pth'))
1011
+
1012
+ accelerator.register_save_state_pre_hook(save_model_hook)
1013
+ accelerator.register_load_state_pre_hook(load_model_hook)
1014
+
1015
+ # Enable TF32 for faster training on Ampere GPUs,
1016
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
1017
+ if args.allow_tf32:
1018
+ torch.backends.cuda.matmul.allow_tf32 = True
1019
+
1020
+ if args.scale_lr:
1021
+ args.learning_rate = (
1022
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
1023
+ )
1024
+
1025
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
1026
+ if args.use_8bit_adam:
1027
+ try:
1028
+ import bitsandbytes as bnb
1029
+ except ImportError:
1030
+ raise ImportError(
1031
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
1032
+ )
1033
+
1034
+ optimizer_class = bnb.optim.AdamW8bit
1035
+ else:
1036
+ optimizer_class = torch.optim.AdamW
1037
+
1038
+ extra_params = list(simple_mapper.parameters())
1039
+ mapper_lr = args.learning_rate * args.mapper_lr_scale if args.learning_rate != 0 else args.mapper_lr
1040
+
1041
+ # Optimizer creation
1042
+ params_to_optimize = (
1043
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1044
+ if args.train_text_encoder
1045
+ else unet_lora_parameters
1046
+ )
1047
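+ # two parameter groups: LoRA parameters at args.learning_rate and the token mapper at mapper_lr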
+ optimizer = optimizer_class(
1048
+ [{'params': params_to_optimize},
1049
+ {'params': extra_params, 'lr': mapper_lr}],
1050
+ lr=args.learning_rate,
1051
+ betas=(args.adam_beta1, args.adam_beta2),
1052
+ weight_decay=args.adam_weight_decay,
1053
+ eps=args.adam_epsilon,
1054
+ )
1055
+
1056
+ # create the training dataset
1057
+ train_dataset = DreamCreatureDataset(args.train_data_dir,
1058
+ args.filename,
1059
+ code_filename=args.code_filename,
1060
+ num_parts=args.num_parts,
1061
+ num_k_per_part=args.num_k_per_part,
1062
+ repeat=args.repeat,
1063
+ use_gt_label=args.use_gt_label,
1064
+ bg_code=args.bg_code)
1065
+
1066
+ with accelerator.main_process_first():
1067
+ if args.filter_class is not None:
1068
+ train_dataset.filter_by_class(args.filter_class)
1069
+ print('selected', len(train_dataset))
1070
+ if args.max_train_samples is not None:
1071
+ train_dataset.set_max_samples(args.max_train_samples, args.seed)
1072
+
1073
+ # DataLoaders creation:
1074
+ train_dataloader = torch.utils.data.DataLoader(
1075
+ train_dataset,
1076
+ shuffle=True,
1077
+ collate_fn=collate_fn(args, tokenizer_one, tokenizer_two, placeholder_token),
1078
+ batch_size=args.train_batch_size,
1079
+ num_workers=args.dataloader_num_workers,
1080
+ )
1081
+
1082
+ # Scheduler and math around the number of training steps.
1083
+ overrode_max_train_steps = False
1084
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1085
+ if args.max_train_steps is None:
1086
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1087
+ overrode_max_train_steps = True
1088
+
1089
+ lr_scheduler = get_scheduler(
1090
+ args.lr_scheduler,
1091
+ optimizer=optimizer,
1092
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
1093
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
1094
+ )
1095
+
1096
+ # Prepare everything with our `accelerator`.
1097
+ if args.train_text_encoder:
1098
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1099
+ unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
1100
+ )
1101
+ else:
1102
+ unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
1103
+ unet, optimizer, train_dataloader, lr_scheduler
1104
+ )
1105
+ simple_mapper = accelerator.prepare(simple_mapper)
1106
+
1107
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
1108
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
1109
+ if overrode_max_train_steps:
1110
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
1111
+ # Afterwards we recalculate our number of training epochs
1112
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
1113
+
1114
+ # We need to initialize the trackers we use, and also store our configuration.
1115
+ # The trackers are initialized automatically on the main process.
1116
+ if accelerator.is_main_process:
1117
+ accelerator.init_trackers("text2image-fine-tune", config=vars(args))
1118
+
1119
+ # Train!
1120
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
1121
+
1122
+ logger.info("***** Running training *****")
1123
+ logger.info(f" Num examples = {len(train_dataset)}")
1124
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
1125
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
1126
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
1127
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
1128
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
1129
+ global_step = 0
1130
+ first_epoch = 0
1131
+
1132
+ # Potentially load in the weights and states from a previous save
1133
+ if args.resume_from_checkpoint:
1134
+ if args.resume_from_checkpoint != "latest":
1135
+ path = os.path.basename(args.resume_from_checkpoint)
1136
+ else:
1137
+ # Get the most recent checkpoint
1138
+ dirs = os.listdir(args.output_dir)
1139
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
1140
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
1141
+ path = dirs[-1] if len(dirs) > 0 else None
1142
+
1143
+ if path is None:
1144
+ accelerator.print(
1145
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
1146
+ )
1147
+ args.resume_from_checkpoint = None
1148
+ initial_global_step = 0
1149
+ else:
1150
+ accelerator.print(f"Resuming from checkpoint {path}")
1151
+ accelerator.load_state(os.path.join(args.output_dir, path))
1152
+ global_step = int(path.split("-")[1])
1153
+
1154
+ initial_global_step = global_step
1155
+ first_epoch = global_step // num_update_steps_per_epoch
1156
+
1157
+ else:
1158
+ initial_global_step = 0
1159
+
1160
+ progress_bar = tqdm(
1161
+ range(0, args.max_train_steps),
1162
+ initial=initial_global_step,
1163
+ desc="Steps",
1164
+ # Only show the progress bar once on each machine.
1165
+ disable=not accelerator.is_local_main_process,
1166
+ )
1167
+
1168
+ for epoch in range(first_epoch, args.num_train_epochs):
1169
+ unet.train()
1170
+ if args.train_text_encoder:
1171
+ text_encoder_one.train()
1172
+ text_encoder_two.train()
1173
+ train_loss = 0.0
1174
+ train_diff_loss = 0.0
1175
+ train_attn_loss = 0.0
1176
+ for step, batch in enumerate(train_dataloader):
1177
+ with accelerator.accumulate(unet, simple_mapper):
1178
+ # Convert images to latent space
1179
+ if args.pretrained_vae_model_name_or_path is not None:
1180
+ pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
1181
+ else:
1182
+ pixel_values = batch["pixel_values"]
1183
+
1184
+ model_input = vae.encode(pixel_values).latent_dist.sample()
1185
+ model_input = model_input * vae.config.scaling_factor
1186
+ if args.pretrained_vae_model_name_or_path is None:
1187
+ model_input = model_input.to(weight_dtype)
1188
+
1189
+ # Sample noise that we'll add to the latents
1190
+ noise = torch.randn_like(model_input)
1191
+ if args.noise_offset:
1192
+ # https://www.crosslabs.org//blog/diffusion-with-offset-noise
1193
+ noise += args.noise_offset * torch.randn(
1194
+ (model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
1195
+ )
1196
+
1197
+ bsz = model_input.shape[0]
1198
+ # Sample a random timestep for each image
1199
+ timesteps = torch.randint(
1200
+ 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=model_input.device
1201
+ )
1202
+ timesteps = timesteps.long()
1203
+
1204
+ # Add noise to the model input according to the noise magnitude at each timestep
1205
+ # (this is the forward diffusion process)
1206
+ noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
1207
+
1208
+ # time ids
1209
+ def compute_time_ids(original_size, crops_coords_top_left):
1210
+ # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
1211
+ target_size = (args.resolution, args.resolution)
1212
+ add_time_ids = list(original_size + crops_coords_top_left + target_size)
1213
+ add_time_ids = torch.tensor([add_time_ids])
1214
+ add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
1215
+ return add_time_ids
1216
+
1217
+ add_time_ids = torch.cat(
1218
+ [compute_time_ids(s, c) for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])]
1219
+ )
1220
+
1221
+ # Predict the noise residual
1222
+ unet_added_conditions = {"time_ids": add_time_ids}
1223
+ # prompt_embeds, pooled_prompt_embeds = encode_prompt(
1224
+ # text_encoders=[text_encoder_one, text_encoder_two],
1225
+ # tokenizers=None,
1226
+ # prompt=None,
1227
+ # text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
1228
+ # )
1229
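+ # map each image's part codes to token embeddings; the first 768 dims go to text_encoder, the remaining 1280 to text_encoder_2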
+ mapper_outputs = simple_mapper(batch['codes'])
1230
+ prompt_embeds, pooled_prompt_embeds = encode_prompt(
1231
+ text_encoders=[text_encoder_one, text_encoder_two],
1232
+ text_input_ids_list=[batch["input_ids_one"], batch["input_ids_two"]],
1233
+ placeholder_token_ids=placeholder_token_ids_one,
1234
+ mapper_outputs=[mapper_outputs[..., :768], mapper_outputs[..., 768:]]
1235
+ )
1236
+
1237
+ unet_added_conditions.update({"text_embeds": pooled_prompt_embeds})
1238
+ model_pred = unet(
1239
+ noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
1240
+ ).sample
1241
+
1242
+ # Get the target for loss depending on the prediction type
1243
+ if args.prediction_type is not None:
1244
+ # set prediction_type of scheduler if defined
1245
+ noise_scheduler.register_to_config(prediction_type=args.prediction_type)
1246
+
1247
+ if noise_scheduler.config.prediction_type == "epsilon":
1248
+ target = noise
1249
+ elif noise_scheduler.config.prediction_type == "v_prediction":
1250
+ target = noise_scheduler.get_velocity(model_input, noise, timesteps)
1251
+ else:
1252
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
1253
+
1254
+ if args.snr_gamma is None:
1255
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1256
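+ # part-level attention loss computed from the unet cross-attention maps, DINO features and the k-means segmentation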
+ attn_loss, max_attn = dreamcreature_loss(batch,
1257
+ unet,
1258
+ dino,
1259
+ seg,
1260
+ placeholder_token_ids_one,
1261
+ accelerator)
1262
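+ # optionally restrict the diffusion loss to the masked region (masks are resized to the latent resolution)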
+ if args.masked_training:
1263
+ masks = batch['masks'].unsqueeze(1).to(accelerator.device)
1264
+ loss_image_mask = F.interpolate(masks.float(),
1265
+ size=target.shape[-2:],
1266
+ mode='bilinear') * torch.ones_like(target)
1267
+ loss = loss * loss_image_mask
1268
+ loss = loss.sum() / loss_image_mask.sum()
1269
+ else:
1270
+ loss = loss.mean()
1271
+ else:
1272
+ # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
1273
+ # Since we predict the noise instead of x_0, the original formulation is slightly changed.
1274
+ # This is discussed in Section 4.2 of the same paper.
1275
+ snr = compute_snr(noise_scheduler, timesteps)
1276
+ if noise_scheduler.config.prediction_type == "v_prediction":
1277
+ # Velocity objective requires that we add one to SNR values before we divide by them.
1278
+ snr = snr + 1
1279
+ mse_loss_weights = (
1280
+ torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
1281
+ )
1282
+
1283
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
1284
+ attn_loss, max_attn = dreamcreature_loss(batch,
1285
+ unet,
1286
+ dino,
1287
+ seg,
1288
+ placeholder_token_ids_one,
1289
+ accelerator)
1290
+ if args.masked_training:
1291
+ masks = batch['masks'].unsqueeze(1).to(accelerator.device)
1292
+ loss_image_mask = F.interpolate(masks.float(),
1293
+ size=target.shape[-2:],
1294
+ mode='bilinear') * torch.ones_like(target)
1295
+ loss = loss * loss_image_mask
1296
+ loss = loss.sum(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1297
+ loss = loss.sum() / loss_image_mask.sum()
1298
+ else:
1299
+ loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
1300
+ loss = loss.mean()
1301
+
1302
+ diff_loss = loss.clone().detach()
1303
+ avg_diff_loss = accelerator.gather(diff_loss.repeat(args.train_batch_size)).mean()
1304
+ train_diff_loss += avg_diff_loss.item() / args.gradient_accumulation_steps
1305
+
1306
+ avg_attn_loss = accelerator.gather(attn_loss.repeat(args.train_batch_size)).mean()
1307
+ train_attn_loss += avg_attn_loss.item() / args.gradient_accumulation_steps
1308
+
1309
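+ # add the attention loss, weighted by args.attn_loss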
+ loss += args.attn_loss * attn_loss
1310
+
1311
+ # Gather the losses across all processes for logging (if we use distributed training).
1312
+ avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
1313
+ train_loss += avg_loss.item() / args.gradient_accumulation_steps
1314
+
1315
+ # Backpropagate
1316
+ accelerator.backward(loss)
1317
+ if accelerator.sync_gradients:
1318
+ params_to_clip = (
1319
+ itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
1320
+ if args.train_text_encoder
1321
+ else unet_lora_parameters
1322
+ )
1323
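+ # also clip the gradients of the token mapper parameters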
+ params_to_clip = list(params_to_clip) + extra_params
1324
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
1325
+ optimizer.step()
1326
+ lr_scheduler.step()
1327
+ optimizer.zero_grad()
1328
+
1329
+ # Checks if the accelerator has performed an optimization step behind the scenes
1330
+ if accelerator.sync_gradients:
1331
+ progress_bar.update(1)
1332
+ global_step += 1
1333
+ accelerator.log({"train_loss": train_loss,
1334
+ "diff_loss": train_diff_loss,
1335
+ "attn_loss": train_attn_loss,
1336
+ "max_attn": max_attn.item()
1337
+ }, step=global_step)
1338
+ train_loss = 0.0
1339
+ train_attn_loss = 0.0
1340
+ train_diff_loss = 0.0
1341
+
1342
+ if accelerator.is_main_process:
1343
+ if global_step % args.checkpointing_steps == 0:
1344
+ # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
1345
+ if args.checkpoints_total_limit is not None:
1346
+ checkpoints = os.listdir(args.output_dir)
1347
+ checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
1348
+ checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
1349
+
1350
+ # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
1351
+ if len(checkpoints) >= args.checkpoints_total_limit:
1352
+ num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
1353
+ removing_checkpoints = checkpoints[0:num_to_remove]
1354
+
1355
+ logger.info(
1356
+ f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
1357
+ )
1358
+ logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
1359
+
1360
+ for removing_checkpoint in removing_checkpoints:
1361
+ removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
1362
+ shutil.rmtree(removing_checkpoint)
1363
+
1364
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
1365
+ accelerator.save_state(save_path)
1366
+ logger.info(f"Saved state to {save_path}")
1367
+
1368
+ logs = {"step_loss": diff_loss.detach().item(),
1369
+ "attn_loss": attn_loss.detach().item(),
1370
+ "lr": lr_scheduler.get_last_lr()[0]}
1371
+ progress_bar.set_postfix(**logs)
1372
+
1373
+ if global_step >= args.max_train_steps:
1374
+ break
1375
+
1376
+ if accelerator.is_main_process:
1377
+ # todo: change pipeline
1378
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
1379
+ logger.info(
1380
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
1381
+ f" {args.validation_prompt}."
1382
+ )
1383
+ # create pipeline
1384
+ pipeline = DreamCreatureSDXLPipeline.from_pretrained(
1385
+ args.pretrained_model_name_or_path,
1386
+ vae=vae,
1387
+ tokenizer=tokenizer_one,
1388
+ tokenizer_2=tokenizer_two,
1389
+ text_encoder=accelerator.unwrap_model(text_encoder_one),
1390
+ text_encoder_2=accelerator.unwrap_model(text_encoder_two),
1391
+ unet=accelerator.unwrap_model(unet),
1392
+ revision=args.revision,
1393
+ variant=args.variant,
1394
+ torch_dtype=weight_dtype,
1395
+ )
1396
+ pipeline.placeholder_token_ids = placeholder_token_ids_one
1397
+ pipeline.simple_mapper = accelerator.unwrap_model(simple_mapper)
1398
+ pipeline.replace_token = False
1399
+
1400
+ pipeline = pipeline.to(accelerator.device)
1401
+ pipeline.set_progress_bar_config(disable=True)
1402
+
1403
+ # run inference
1404
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1405
+ pipeline_args = {"prompt": args.validation_prompt}
1406
+
1407
+ num_steps = 4 if 'turbo' in args.pretrained_model_name_or_path else 25
1408
+ gs = 0 if 'turbo' in args.pretrained_model_name_or_path else 5.0
1409
+
1410
+ images = [
1411
+ pipeline(**pipeline_args, num_inference_steps=num_steps, guidance_scale=gs,
1412
+ generator=generator, height=args.resolution, width=args.resolution).images[0]
1413
+ for _ in range(args.num_validation_images)
1414
+ ]
1415
+
1416
+ for tracker in accelerator.trackers:
1417
+ if tracker.name == "tensorboard":
1418
+ np_images = np.stack([np.asarray(img) for img in images])
1419
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
1420
+ if tracker.name == "wandb":
1421
+ tracker.log(
1422
+ {
1423
+ "validation": [
1424
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1425
+ for i, image in enumerate(images)
1426
+ ]
1427
+ }
1428
+ )
1429
+
1430
+ del pipeline
1431
+ torch.cuda.empty_cache()
1432
+
1433
+ # Save the lora layers
1434
+ accelerator.wait_for_everyone()
1435
+ if accelerator.is_main_process:
1436
+ unet = accelerator.unwrap_model(unet)
1437
+ unet_lora_layers = unet_attn_processors_state_dict(unet)
1438
+
1439
+ if args.train_text_encoder:
1440
+ text_encoder_one = accelerator.unwrap_model(text_encoder_one)
1441
+ text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one)
1442
+ text_encoder_two = accelerator.unwrap_model(text_encoder_two)
1443
+ text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two)
1444
+ else:
1445
+ text_encoder_lora_layers = None
1446
+ text_encoder_2_lora_layers = None
1447
+
1448
+ StableDiffusionXLPipeline.save_lora_weights(
1449
+ save_directory=args.output_dir,
1450
+ unet_lora_layers=unet_lora_layers,
1451
+ text_encoder_lora_layers=text_encoder_lora_layers,
1452
+ text_encoder_2_lora_layers=text_encoder_2_lora_layers,
1453
+ )
1454
+ torch.save(simple_mapper.to(torch.float32).state_dict(), args.output_dir + '/hash_mapper.pth')
1455
+
1456
+ del unet
1457
+ del text_encoder_one
1458
+ del text_encoder_two
1459
+ del text_encoder_lora_layers
1460
+ del text_encoder_2_lora_layers
1461
+ del simple_mapper
1462
+ torch.cuda.empty_cache()
1463
+
1464
+ # Final inference
1465
+ # Load previous pipeline
1466
+ text_encoder_one, text_encoder_two, tokenizer_one, tokenizer_two, simple_mapper = init_for_pipeline(args)
1467
+ pipeline = DreamCreatureSDXLPipeline.from_pretrained(
1468
+ args.pretrained_model_name_or_path,
1469
+ vae=vae,
1470
+ tokenizer=tokenizer_one,
1471
+ tokenizer_2=tokenizer_two,
1472
+ text_encoder=text_encoder_one,
1473
+ text_encoder_2=text_encoder_two,
1474
+ revision=args.revision,
1475
+ variant=args.variant,
1476
+ torch_dtype=weight_dtype,
1477
+ )
1478
+ pipeline.placeholder_token_ids = placeholder_token_ids_one
1479
+ pipeline.replace_token = False
1480
+ pipeline.simple_mapper = simple_mapper
1481
+ pipeline.simple_mapper.load_state_dict(torch.load(args.output_dir + '/hash_mapper.pth', map_location='cpu'))
1482
+ pipeline.simple_mapper.to(accelerator.device)
1483
+ setup_attn_processors(pipeline.unet, args)
1484
+
1485
+ pipeline = pipeline.to(accelerator.device)
1486
+
1487
+ # load attention processors
1488
+ pipeline.load_lora_weights(args.output_dir)
1489
+
1490
+ # run inference
1491
+ images = []
1492
+ if args.validation_prompt and args.num_validation_images > 0:
1493
+ num_steps = 4 if 'turbo' in args.pretrained_model_name_or_path else 25
1494
+ gs = 0 if 'turbo' in args.pretrained_model_name_or_path else 5.0
1495
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
1496
+ images = [
1497
+ pipeline(args.validation_prompt, num_inference_steps=num_steps,
1498
+ guidance_scale=gs, generator=generator, height=args.resolution,
1499
+ width=args.resolution).images[0]
1500
+ for _ in range(args.num_validation_images)
1501
+ ]
1502
+
1503
+ for tracker in accelerator.trackers:
1504
+ if tracker.name == "tensorboard":
1505
+ np_images = np.stack([np.asarray(img) for img in images])
1506
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1507
+ if tracker.name == "wandb":
1508
+ tracker.log(
1509
+ {
1510
+ "test": [
1511
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1512
+ for i, image in enumerate(images)
1513
+ ]
1514
+ }
1515
+ )
1516
+
1517
+ if args.push_to_hub:
1518
+ save_model_card(
1519
+ repo_id,
1520
+ images=images,
1521
+ base_model=args.pretrained_model_name_or_path,
1522
+ dataset_name=args.dataset_name,
1523
+ train_text_encoder=args.train_text_encoder,
1524
+ repo_folder=args.output_dir,
1525
+ vae_path=args.pretrained_vae_model_name_or_path,
1526
+ )
1527
+ upload_folder(
1528
+ repo_id=repo_id,
1529
+ folder_path=args.output_dir,
1530
+ commit_message="End of training",
1531
+ ignore_patterns=["step_*", "epoch_*"],
1532
+ )
1533
+
1534
+ accelerator.end_training()
1535
+
1536
+
1537
+ if __name__ == "__main__":
1538
+ args = parse_args()
1539
+ main(args)
train_kmeans_segmentation.ipynb ADDED
@@ -0,0 +1,578 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "1c073d83-8e73-407a-a669-3a837a90aac6",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Compute DINO features"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": 2,
14
+ "id": "901240c3-5111-4733-97ee-69891e4e7184",
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "import argparse\n",
19
+ "import math\n",
20
+ "import os\n",
21
+ "\n",
22
+ "import torch\n",
23
+ "import torchpq\n",
24
+ "from omegaconf import OmegaConf\n",
25
+ "from torch.utils.data import DataLoader\n",
26
+ "from sklearn.decomposition import PCA\n",
27
+ "from torchvision.transforms import transforms\n",
28
+ "from tqdm import tqdm\n",
29
+ "from transformers.utils import constants\n",
30
+ "\n",
31
+ "from dreamcreature.dino import DINO\n",
32
+ "from dreamcreature.dataset import ImageDataset\n",
33
+ "\n",
34
+ "MEAN = constants.IMAGENET_DEFAULT_MEAN\n",
35
+ "STD = constants.IMAGENET_DEFAULT_STD"
36
+ ]
37
+ },
38
+ {
39
+ "cell_type": "code",
40
+ "execution_count": 4,
41
+ "id": "2f83c248-c3d5-4fe2-a111-ecdeda648214",
42
+ "metadata": {},
43
+ "outputs": [],
44
+ "source": [
45
+ "dataset_name = 'cub200_2011'\n",
46
+ "# dataset_name = 'dogs'\n",
47
+ "\n",
48
+ "rootdir = f'data/{dataset_name}'\n",
49
+ "resize = 256\n",
50
+ "crop = 224\n",
51
+ "\n",
52
+ "dataset = ImageDataset(rootdir,\n",
53
+ " 'train.txt',\n",
54
+ " transform=transforms.Compose([\n",
55
+ " transforms.Resize(resize, interpolation=transforms.InterpolationMode.BICUBIC),\n",
56
+ " transforms.CenterCrop(crop),\n",
57
+ " transforms.ToTensor(),\n",
58
+ " transforms.Normalize(MEAN, STD)\n",
59
+ " ]))"
60
+ ]
61
+ },
62
+ {
63
+ "cell_type": "code",
64
+ "execution_count": null,
65
+ "id": "946a88e5-b368-4cc3-a625-fd952650036b",
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "dataloader = DataLoader(dataset, 32, shuffle=False, drop_last=False, num_workers=4)\n",
70
+ "model = DINO()\n",
71
+ "model.eval()\n",
72
+ "\n",
73
+ "device = torch.device('cuda')\n",
74
+ "model = model.to(device)"
75
+ ]
76
+ },
77
+ {
78
+ "cell_type": "code",
79
+ "execution_count": null,
80
+ "id": "f526a02a-745d-43df-94af-1a50ed438fda",
81
+ "metadata": {},
82
+ "outputs": [],
83
+ "source": [
84
+ "os.makedirs(rootdir + '/dinov2', exist_ok=True)\n",
85
+ "\n",
86
+ "image_feats = []\n",
87
+ "with tqdm(dataloader, bar_format='{l_bar}{bar:10}{r_bar}') as tepoch:\n",
88
+ " for i, (image, label, index) in enumerate(tepoch):\n",
89
+ " image = image.to(device)\n",
90
+ "\n",
91
+ " with torch.no_grad():\n",
92
+ " output = model.get_feat_maps(image) # (B, C, H, W)\n",
93
+ "\n",
94
+ " B, C, H, W = output.size()\n",
95
+ " output = output.reshape(B, C, H * W)\n",
96
+ " image_feats.append(output.cpu())\n",
97
+ "\n",
98
+ "image_feats = torch.cat(image_feats, dim=0) # (N, C, H*W)\n",
99
+ "torch.save(image_feats, rootdir + '/dinov2_image_feats.pth')"
100
+ ]
101
+ },
102
+ {
103
+ "cell_type": "markdown",
104
+ "id": "50f9ed32-5231-47d0-864e-cfbcd8b6d732",
105
+ "metadata": {},
106
+ "source": [
107
+ "# Train Kmeans Segmentation"
108
+ ]
109
+ },
110
+ {
111
+ "cell_type": "code",
112
+ "execution_count": null,
113
+ "id": "da8b9bde-2c1a-40d5-a32c-dda98334fe17",
114
+ "metadata": {},
115
+ "outputs": [],
116
+ "source": [
117
+ "import torch\n",
118
+ "import random\n",
119
+ "import numpy as np\n",
120
+ "\n",
121
+ "dataset_name = 'cub200_2011'\n",
122
+ "# dataset_name = 'dogs'\n",
123
+ "\n",
124
+ "sd = torch.load(f'data/{dataset_name}/dinov2_image_feats.pth', map_location='cpu')\n",
125
+ "sd.size()"
126
+ ]
127
+ },
128
+ {
129
+ "cell_type": "code",
130
+ "execution_count": null,
131
+ "id": "67b0f5ba-e272-4aea-b17d-05a80e2ce025",
132
+ "metadata": {},
133
+ "outputs": [],
134
+ "source": [
135
+ "from dataset import code_to_int, int_to_caption\n",
136
+ "from dataset import ImageDataset\n",
137
+ "from torchvision.transforms import transforms\n",
138
+ "\n",
139
+ "ds = ImageDataset(f'data/{dataset_name}', transform=transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)]))\n",
140
+ "train_lines = open(f'data/{dataset_name}/train.txt').readlines()"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "id": "46c57aa3-de67-4d1a-b9a7-584687e437f5",
147
+ "metadata": {},
148
+ "outputs": [],
149
+ "source": [
150
+ "def set_seed(seed):\n",
151
+ " random.seed(seed)\n",
152
+ " np.random.seed(seed)\n",
153
+ " torch.manual_seed(seed)\n",
154
+ " torch.cuda.manual_seed(seed)\n",
155
+ "\n",
156
+ "set_seed(42)\n",
157
+ " \n",
158
+ "n = 100 # use a small subset of training images to avoid OOM\n",
159
+ "randidx = torch.randperm(len(sd))[:n]\n",
160
+ "randsd = sd[randidx].permute(0, 2, 1) # (N, HW, C)\n",
161
+ "randsd.size()"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": null,
167
+ "id": "fab9e741-ff9a-4b6a-8ee3-41bc1d43cb52",
168
+ "metadata": {},
169
+ "outputs": [],
170
+ "source": [
171
+ "import numpy as np\n",
172
+ "import torchpq\n",
173
+ "import torch.nn.functional as F\n",
174
+ "import matplotlib.pyplot as plt\n",
175
+ "import random\n",
176
+ "from sklearn.decomposition import PCA\n",
177
+ "\n",
178
+ "set_seed(42)\n",
179
+ "\n",
180
+ "fg_kmeans = torchpq.clustering.KMeans(n_clusters=2,\n",
181
+ " distance=\"cosine\",\n",
182
+ " verbose=1,\n",
183
+ " n_redo=5,\n",
184
+ " max_iter=1000)\n",
185
+ "fg_labels = fg_kmeans.fit(randsd.reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(n, -1)"
186
+ ]
187
+ },
188
+ {
189
+ "cell_type": "code",
190
+ "execution_count": null,
191
+ "id": "9a12550b-1c66-48a0-9bd6-a39b928cf57d",
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "torch.unique(fg_labels, return_counts=True)"
196
+ ]
197
+ },
198
+ {
199
+ "cell_type": "code",
200
+ "execution_count": null,
201
+ "id": "b9a3b9f0-fc73-4e44-bca8-a2a2da16df38",
202
+ "metadata": {},
203
+ "outputs": [],
204
+ "source": [
205
+ "for i in range(100):\n",
206
+ " plt.subplot(10, 10, i+1)\n",
207
+ " plt.imshow(fg_labels[i].reshape(16, 16))\n",
208
+ " plt.axis('off')"
209
+ ]
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "execution_count": null,
214
+ "id": "6f8e6121-74ee-4539-bf61-9d0ba2198ef5",
215
+ "metadata": {},
216
+ "outputs": [],
217
+ "source": [
218
+ "fg_idx = 0 # set manually after inspecting the visualization above to see which cluster is foreground\n",
219
+ "bg_idx = 1 - fg_idx\n",
220
+ "\n",
221
+ "randsd_bgnorm = []\n",
222
+ "randsd_nobg = []\n",
223
+ "randsd_bgmean = []\n",
224
+ "\n",
225
+ "for i in range(n):\n",
226
+ " bgnorm_mean = randsd[i][fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)\n",
227
+ " \n",
228
+ " if fg_idx == 0:\n",
229
+ " bg_mask = fg_labels[i]\n",
230
+ " else:\n",
231
+ " bg_mask = 1 - fg_labels[i]\n",
232
+ " \n",
233
+ " bg_mask = bg_mask.unsqueeze(1)\n",
234
+ " bgnorm = (randsd[i] * (1 - bg_mask)) + (bgnorm_mean * bg_mask)\n",
235
+ " \n",
236
+ " randsd_bgnorm.append(bgnorm)\n",
237
+ " randsd_nobg.append(randsd[i] * (1 - bg_mask) + (-1 * bg_mask))\n",
238
+ " randsd_bgmean.append(bgnorm_mean)\n",
239
+ " \n",
240
+ "randsd_bgnorm = torch.stack(randsd_bgnorm, dim=0)\n",
241
+ "randsd_nobg = torch.stack(randsd_nobg, dim=0)\n",
242
+ "randsd_bgmean = torch.cat(randsd_bgmean, dim=0)"
243
+ ]
244
+ },
245
+ {
246
+ "cell_type": "code",
247
+ "execution_count": null,
248
+ "id": "f90785f8-d1e4-4dc9-9e6d-c242058639a5",
249
+ "metadata": {},
250
+ "outputs": [],
251
+ "source": [
252
+ "set_seed(42)\n",
253
+ "M = 8\n",
254
+ "\n",
255
+ "coarse_kmeans = torchpq.clustering.KMeans(n_clusters=M,\n",
256
+ " distance=\"cosine\",\n",
257
+ " verbose=1,\n",
258
+ " n_redo=5,\n",
259
+ " max_iter=1000)\n",
260
+ "coarse_labels = coarse_kmeans.fit(randsd_nobg.reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(n, -1)"
261
+ ]
262
+ },
263
+ {
264
+ "cell_type": "code",
265
+ "execution_count": null,
266
+ "id": "d4930216-de8a-420b-b7cd-45260299b7b9",
267
+ "metadata": {},
268
+ "outputs": [],
269
+ "source": [
270
+ "for i in range(100):\n",
271
+ " plt.subplot(10, 10, i+1)\n",
272
+ " plt.imshow(coarse_labels[i].reshape(16, 16))\n",
273
+ " plt.axis('off')"
274
+ ]
275
+ },
276
+ {
277
+ "cell_type": "code",
278
+ "execution_count": null,
279
+ "id": "d5e72d5e-df24-485d-85aa-de17ba381d84",
280
+ "metadata": {},
281
+ "outputs": [],
282
+ "source": [
283
+ "import torch\n",
284
+ "import numpy as np\n",
285
+ "import matplotlib.pyplot as plt\n",
286
+ "\n",
287
+ "disp = coarse_labels[0].reshape(16, 16)\n",
288
+ "\n",
289
+ "plt.imshow(disp)\n",
290
+ "plt.axis('off')"
291
+ ]
292
+ },
293
+ {
294
+ "cell_type": "code",
295
+ "execution_count": null,
296
+ "id": "b508e11d-082a-40a6-83db-4d306c0f9f00",
297
+ "metadata": {},
298
+ "outputs": [],
299
+ "source": [
300
+ "torch.unique(coarse_labels, return_counts=True)"
301
+ ]
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "execution_count": null,
306
+ "id": "8f4bf418-2297-4dc5-8bdd-15af7cf44c7a",
307
+ "metadata": {},
308
+ "outputs": [],
309
+ "source": [
310
+ "sd_bgnorm = []\n",
311
+ "sd_nobg = []\n",
312
+ "sd_bgmean = []\n",
313
+ "\n",
314
+ "inp = sd.permute(0, 2, 1)\n",
315
+ "N = inp.size(0)\n",
316
+ "\n",
317
+ "sd_fg_labels = []\n",
318
+ "bs = 1000\n",
319
+ "for bidx in range(N // bs + 1):\n",
320
+ " if bidx * bs >= N:\n",
321
+ " break\n",
322
+ " \n",
323
+ " start_bidx = bidx*bs\n",
324
+ " end_bidx = min((bidx+1)*bs, N)\n",
325
+ " \n",
326
+ " sd_fg_labels.append(fg_kmeans.predict(inp[start_bidx:end_bidx].reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(end_bidx - start_bidx, -1))\n",
327
+ " \n",
328
+ "sd_fg_labels = torch.cat(sd_fg_labels, dim=0)\n",
329
+ "\n",
330
+ "for i in range(N):\n",
331
+ " bgnorm_mean = inp[i][sd_fg_labels[i] == bg_idx].mean(dim=0, keepdim=True)\n",
332
+ " \n",
333
+ " if fg_idx == 0:\n",
334
+ " bg_mask = sd_fg_labels[i]\n",
335
+ " else:\n",
336
+ " bg_mask = 1 - sd_fg_labels[i]\n",
337
+ " \n",
338
+ " bg_mask = bg_mask.unsqueeze(1)\n",
339
+ " bgnorm = (inp[i] * (1 - bg_mask)) + (bgnorm_mean * bg_mask)\n",
340
+ " \n",
341
+ " sd_bgnorm.append(bgnorm)\n",
342
+ " sd_nobg.append(inp[i] * (1 - bg_mask) + (-1 * bg_mask))\n",
343
+ " sd_bgmean.append(bgnorm_mean)\n",
344
+ " print(i, end='\\r')\n",
345
+ " \n",
346
+ "sd_bgnorm = torch.stack(sd_bgnorm, dim=0)\n",
347
+ "sd_nobg = torch.stack(sd_nobg, dim=0)\n",
348
+ "sd_bgmean = torch.cat(sd_bgmean, dim=0)"
349
+ ]
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "execution_count": null,
354
+ "id": "d8026046-2ac5-4d62-9d18-4cbf7e2ebdb0",
355
+ "metadata": {},
356
+ "outputs": [],
357
+ "source": [
358
+ "sd_coarse_labels = []\n",
359
+ "bs = 1000\n",
360
+ "for bidx in range(N // bs + 1):\n",
361
+ " if bidx * bs >= N:\n",
362
+ " break\n",
363
+ " \n",
364
+ " start_bidx = bidx*bs\n",
365
+ " end_bidx = min((bidx+1)*bs, N)\n",
366
+ " \n",
367
+ " sd_coarse_labels.append(coarse_kmeans.predict(sd_nobg[start_bidx:end_bidx].reshape(-1, 768).t().contiguous().cuda()).cpu().reshape(end_bidx - start_bidx, -1))\n",
368
+ " \n",
369
+ "sd_coarse_labels = torch.cat(sd_coarse_labels, dim=0)"
370
+ ]
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "execution_count": null,
375
+ "id": "80a46c2d-6cd4-4344-b56f-4ff1fa7235d3",
376
+ "metadata": {},
377
+ "outputs": [],
378
+ "source": [
379
+ "for i in range(100):\n",
380
+ " plt.subplot(10, 10, i+1)\n",
381
+ " coarse_mask = sd_coarse_labels[i].reshape(16, 16)\n",
382
+ " plt.imshow(coarse_mask)\n",
383
+ " plt.axis('off')"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "execution_count": null,
389
+ "id": "e41a997b-cccc-40c2-a983-c6092ffe69be",
390
+ "metadata": {},
391
+ "outputs": [],
392
+ "source": [
393
+ "torch.save(sd_coarse_labels.reshape(N, 16, 16).long().cpu(), f'data/{dataset_name}/coarse_mask_m8.pth')"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "code",
398
+ "execution_count": null,
399
+ "id": "cc918d47-35ca-4177-b63b-a12f4f4a3d5a",
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "torch.unique(sd_coarse_labels, return_counts=True)"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "id": "9d4beb2d-e385-4ecc-9f90-287d7bc13c0d",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "from tqdm.auto import tqdm\n",
414
+ "\n",
415
+ "sd_fgmean = []\n",
416
+ "\n",
417
+ "inp = sd.permute(0, 2, 1)\n",
418
+ "N = inp.size(0)\n",
419
+ "M = 8\n",
420
+ "\n",
421
+ "for i in tqdm(range(N)):\n",
422
+ " mean_feats = []\n",
423
+ " for m in range(M):\n",
424
+ " coarse_mask = sd_coarse_labels[i] == m\n",
425
+ " if coarse_mask.sum().item() == 0:\n",
426
+ " m_mean_feats = torch.zeros(1, 768)\n",
427
+ " else:\n",
428
+ " m_mean_feats = inp[i][coarse_mask].mean(dim=0, keepdim=True)\n",
429
+ " \n",
430
+ " mean_feats.append(m_mean_feats)\n",
431
+ " \n",
432
+ " mean_feats = torch.cat(mean_feats, dim=0)\n",
433
+ " sd_fgmean.append(mean_feats)\n",
434
+ " print(i, end='\\r')\n",
435
+ " \n",
436
+ "sd_fgmean = torch.stack(sd_fgmean, dim=0)"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "id": "5034a91a-ae6d-468b-a738-f9ec0d019d72",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "N = inp.size(0)\n",
447
+ "M = 8\n",
448
+ "K = 256\n",
449
+ "bgm = {'cub200_2011': 7, 'dogs': 1}[dataset_name] # coarse cluster index treated as background: 7 for cub, 1 for dogs\n",
450
+ "\n",
451
+ "final_labels = torch.ones(N, M) * K\n",
452
+ "\n",
453
+ "set_seed(42)\n",
454
+ "\n",
455
+ "zero_mean_idxs = []\n",
456
+ "fine_feats = []\n",
457
+ "fine_kmeans_trained = []\n",
458
+ "\n",
459
+ "for m in range(M):\n",
460
+ " fine_kmeans = torchpq.clustering.KMeans(n_clusters=K,\n",
461
+ " distance=\"cosine\",\n",
462
+ " verbose=1,\n",
463
+ " n_redo=5,\n",
464
+ " max_iter=1000)\n",
465
+ " \n",
466
+ " if m == bgm:\n",
467
+ " fine_labels = fine_kmeans.fit(sd_bgmean.t().contiguous().cuda()).cpu()\n",
468
+ " final_labels[:, m] = fine_labels\n",
469
+ " else:\n",
470
+ " fine_inp = sd_fgmean[:, m].reshape(-1, 768)\n",
471
+ " fine_labels = fine_kmeans.fit(fine_inp.t().contiguous().cuda()).cpu()\n",
472
+ " \n",
473
+ " final_labels[:, m] = fine_labels\n",
474
+ " \n",
475
+ " fine_kmeans_trained.append(fine_kmeans)\n",
476
+ " \n",
477
+ " fine_feats.append(fine_kmeans.centroids.cpu().t()[fine_labels])\n",
478
+ " \n",
479
+ " print('zero mean', torch.arange(K)[fine_kmeans.centroids.t().sum(dim=-1).cpu() == 0].tolist())\n",
480
+ " zero_mean_idxs.append(torch.arange(K)[fine_kmeans.centroids.t().sum(dim=-1).cpu() == 0].tolist())\n",
481
+ " \n",
482
+ "fine_feats = torch.cat(fine_feats, dim=1)\n",
483
+ "print(fine_feats.size())"
484
+ ]
485
+ },
486
+ {
487
+ "cell_type": "code",
488
+ "execution_count": null,
489
+ "id": "450bc5c1-cc41-44b3-ab6a-c0c65eb1210a",
490
+ "metadata": {},
491
+ "outputs": [],
492
+ "source": [
493
+ "torch.save({\n",
494
+ " 'foreground_background': fg_kmeans,\n",
495
+ " 'coarse_kmeans': coarse_kmeans,\n",
496
+ " 'fine_kmeans': fine_kmeans_trained,\n",
497
+ "}, f'data/{dataset_name}/pretrained_kmeans.pth')"
498
+ ]
499
+ },
500
+ {
501
+ "cell_type": "code",
502
+ "execution_count": null,
503
+ "id": "9ad2f095-f99d-489e-b645-8933b6f66372",
504
+ "metadata": {},
505
+ "outputs": [],
506
+ "source": [
507
+ "from tqdm.auto import tqdm\n",
508
+ "\n",
509
+ "final_code_captions = []\n",
510
+ "counts = [[0 for _ in range(K)] for _ in range(M)]\n",
511
+ "\n",
512
+ "for i in tqdm(range(N)):\n",
513
+ " m_labels = final_labels[i] # M\n",
514
+ " \n",
515
+ " line = []\n",
516
+ " for m in range(M):\n",
517
+ " k = m_labels[m].long().item()\n",
518
+ " \n",
519
+ " if k not in zero_mean_idxs[m]:\n",
520
+ " line.append(f'{m}:{k}')\n",
521
+ " counts[m][k] += 1\n",
522
+ " \n",
523
+ " assert len(line) != 0, f'error at {i}'\n",
524
+ " final_code_captions.append(' '.join(line))"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "code",
529
+ "execution_count": null,
530
+ "id": "1be7233d-d9fe-4e0a-bda9-93fc45f8e982",
531
+ "metadata": {},
532
+ "outputs": [],
533
+ "source": [
534
+ "import matplotlib.pyplot as plt\n",
535
+ "\n",
536
+ "for m in range(M):\n",
537
+ " if max(counts[m]) == 0:\n",
538
+ " continue\n",
539
+ " \n",
540
+ " plt.scatter(range(K), counts[m])\n",
541
+ " print(m, min(counts[m]), max(counts[m]), np.mean(counts[m]))"
542
+ ]
543
+ },
544
+ {
545
+ "cell_type": "code",
546
+ "execution_count": null,
547
+ "id": "e344ab48-fba6-4c17-a88c-e2528c0ac5cf",
548
+ "metadata": {},
549
+ "outputs": [],
550
+ "source": [
551
+ "with open(f'data/{dataset_name}/train_caps_better_m{M}_k{K}.txt', 'w+') as f:\n",
552
+ " for line in final_code_captions:\n",
553
+ " f.write(line + '\\n')"
554
+ ]
555
+ }
556
+ ],
557
+ "metadata": {
558
+ "kernelspec": {
559
+ "display_name": "Python 3 (ipykernel)",
560
+ "language": "python",
561
+ "name": "python3"
562
+ },
563
+ "language_info": {
564
+ "codemirror_mode": {
565
+ "name": "ipython",
566
+ "version": 3
567
+ },
568
+ "file_extension": ".py",
569
+ "mimetype": "text/x-python",
570
+ "name": "python",
571
+ "nbconvert_exporter": "python",
572
+ "pygments_lexer": "ipython3",
573
+ "version": "3.10.8"
574
+ }
575
+ },
576
+ "nbformat": 4,
577
+ "nbformat_minor": 5
578
+ }
utils.py ADDED
@@ -0,0 +1,127 @@
1
+ import torch
2
+ from diffusers.models.attention_processor import LoRAAttnProcessor
3
+
4
+
5
+ def add_tokens(tokenizer, text_encoder, placeholder_token, num_vec_per_token=1, initializer_token=None):
6
+ """
7
+ Add tokens to the tokenizer and set the initial value of token embeddings
8
+ """
9
+ tokenizer.add_placeholder_tokens(placeholder_token, num_vec_per_token=num_vec_per_token)
10
+ text_encoder.resize_token_embeddings(len(tokenizer))
11
+ token_embeds = text_encoder.get_input_embeddings().weight.data
12
+ placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
13
+ if initializer_token:
14
+ token_ids = tokenizer.encode(initializer_token, add_special_tokens=False)
15
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
16
+ token_embeds[placeholder_token_id] = token_embeds[token_ids[i * len(token_ids) // num_vec_per_token]]
17
+ else:
18
+ for i, placeholder_token_id in enumerate(placeholder_token_ids):
19
+ token_embeds[placeholder_token_id] = torch.randn_like(token_embeds[placeholder_token_id])
20
+ return placeholder_token_ids
21
+
22
+
23
+ def tokenize_prompt(tokenizer, prompt, replace_token=False):
24
+ text_inputs = tokenizer(
25
+ prompt,
26
+ replace_token=replace_token,
27
+ padding="max_length",
28
+ max_length=tokenizer.model_max_length,
29
+ truncation=True,
30
+ return_tensors="pt",
31
+ )
32
+ text_input_ids = text_inputs.input_ids
33
+ return text_input_ids
34
+
35
+
36
+ def get_processor(self, return_deprecated_lora: bool = False):
37
+ r"""
38
+ Get the attention processor in use.
39
+
40
+ Args:
41
+ return_deprecated_lora (`bool`, *optional*, defaults to `False`):
42
+ Set to `True` to return the deprecated LoRA attention processor.
43
+
44
+ Returns:
45
+ "AttentionProcessor": The attention processor in use.
46
+ """
47
+ if not return_deprecated_lora:
48
+ return self.processor
49
+
50
+ # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible
51
+ # serialization format for LoRA Attention Processors. It should be deleted once the integration
52
+ # with PEFT is completed.
53
+ is_lora_activated = {
54
+ name: module.lora_layer is not None
55
+ for name, module in self.named_modules()
56
+ if hasattr(module, "lora_layer")
57
+ }
58
+
59
+ # 1. if no layer has a LoRA activated we can return the processor as usual
60
+ if not any(is_lora_activated.values()):
61
+ return self.processor
62
+
63
+ # Ignore `add_k_proj` and `add_v_proj` if LoRA is not applied to them
64
+ is_lora_activated.pop("add_k_proj", None)
65
+ is_lora_activated.pop("add_v_proj", None)
66
+ # 2. otherwise it is not possible that only some layers have LoRA activated
67
+ if not all(is_lora_activated.values()):
68
+ raise ValueError(
69
+ f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}"
70
+ )
71
+
72
+ # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor
73
+ # non_lora_processor_cls_name = self.processor.__class__.__name__
74
+ # lora_processor_cls = getattr(import_module(__name__), "LoRA" + non_lora_processor_cls_name)
75
+
76
+ hidden_size = self.inner_dim
77
+
78
+ # now create a LoRA attention processor from the LoRA layers
79
+ kwargs = {
80
+ "cross_attention_dim": self.cross_attention_dim,
81
+ "rank": self.to_q.lora_layer.rank,
82
+ "network_alpha": self.to_q.lora_layer.network_alpha,
83
+ "q_rank": self.to_q.lora_layer.rank,
84
+ "q_hidden_size": self.to_q.lora_layer.out_features,
85
+ "k_rank": self.to_k.lora_layer.rank,
86
+ "k_hidden_size": self.to_k.lora_layer.out_features,
87
+ "v_rank": self.to_v.lora_layer.rank,
88
+ "v_hidden_size": self.to_v.lora_layer.out_features,
89
+ "out_rank": self.to_out[0].lora_layer.rank,
90
+ "out_hidden_size": self.to_out[0].lora_layer.out_features,
91
+ }
92
+
93
+ if hasattr(self.processor, "attention_op"):
94
+ kwargs["attention_op"] = self.processor.attention_op
95
+
96
+ lora_processor = LoRAAttnProcessor(hidden_size, **kwargs)
97
+ lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict())
98
+ lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict())
99
+ lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict())
100
+ lora_processor.to_out_lora.load_state_dict(self.to_out[0].lora_layer.state_dict())
101
+
102
+ return lora_processor
103
+
104
+
105
+ def get_attn_processors(self):
106
+ r"""
107
+ Returns:
108
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
109
+ indexed by its weight name.
110
+ """
111
+ # set recursively
112
+ processors = {}
113
+
114
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors):
115
+ if hasattr(module, "get_processor"):
116
+ processors[f"{name}.processor"] = get_processor(module, return_deprecated_lora=True)
117
+
118
+ for sub_name, child in module.named_children():
119
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
120
+
121
+ return processors
122
+
123
+ for name, module in self.named_children():
124
+ fn_recursive_add_processors(name, module, processors)
125
+
126
+ return processors
127
+