cymic committed
Commit 679ffae · 1 Parent(s): 8cb9c9e

Upload 4 files

modules/textual_inversion/dataset.py ADDED
@@ -0,0 +1,81 @@
+ import os
+ import numpy as np
+ import PIL
+ import torch
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+
+ import random
+ import tqdm
+ from modules import devices
+ import re
+
+ re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
+
+
+ class PersonalizedBase(Dataset):
+     def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
+
+         self.placeholder_token = placeholder_token
+
+         self.size = size
+         self.width = width
+         self.height = height
+         self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+
+         self.dataset = []
+
+         with open(template_file, "r") as file:
+             lines = [x.strip() for x in file.readlines()]
+
+         self.lines = lines
+
+         assert data_root, 'dataset directory not specified'
+
+         self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
+         print("Preparing dataset...")
+         for path in tqdm.tqdm(self.image_paths):
+             image = Image.open(path)
+             image = image.convert('RGB')
+             image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
+
+             filename = os.path.basename(path)
+             filename_tokens = os.path.splitext(filename)[0]
+             filename_tokens = re_tag.findall(filename_tokens)
+
+             npimage = np.array(image).astype(np.uint8)
+             npimage = (npimage / 127.5 - 1.0).astype(np.float32)
+
+             torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
+             torchdata = torch.moveaxis(torchdata, 2, 0)
+
+             init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
+             init_latent = init_latent.to(devices.cpu)
+
+             self.dataset.append((init_latent, filename_tokens))
+
+         self.length = len(self.dataset) * repeats
+
+         self.initial_indexes = np.arange(self.length) % len(self.dataset)
+         self.indexes = None
+         self.shuffle()
+
+     def shuffle(self):
+         self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
+
+     def __len__(self):
+         return self.length
+
+     def __getitem__(self, i):
+         if i % len(self.dataset) == 0:
+             self.shuffle()
+
+         index = self.indexes[i % len(self.indexes)]
+         x, filename_tokens = self.dataset[index]
+
+         text = random.choice(self.lines)
+         text = text.replace("[name]", self.placeholder_token)
+         text = text.replace("[filewords]", ' '.join(filename_tokens))
+
+         return x, text
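For reference, a minimal usage sketch of this dataset (assuming shared.sd_model is an already-loaded LatentDiffusion model; the paths and the placeholder token are illustrative, not values from this commit):

from modules import shared, devices
from modules.textual_inversion.dataset import PersonalizedBase

# Hypothetical paths; in the web UI these come from the training tab.
ds = PersonalizedBase(
    data_root="training/my_subject",                                   # directory of training images
    placeholder_token="my-embedding",                                  # substituted for [name] in template lines
    model=shared.sd_model,
    device=devices.device,
    template_file="textual_inversion_templates/style_filewords.txt",   # assumed prompt template file
)

latent, prompt = ds[0]  # cached VAE latent plus a template prompt with [name]/[filewords] filled in

Note that latents are precomputed once per image in __init__ and kept on the CPU, so each __getitem__ only does string substitution.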
modules/textual_inversion/preprocess.py ADDED
@@ -0,0 +1,104 @@
+ import os
+ from PIL import Image, ImageOps
+ import platform
+ import sys
+ import tqdm
+
+ from modules import shared, images
+
+
+ def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
+     size = 512
+     src = os.path.abspath(process_src)
+     dst = os.path.abspath(process_dst)
+
+     assert src != dst, 'same directory specified as source and destination'
+
+     os.makedirs(dst, exist_ok=True)
+
+     files = os.listdir(src)
+
+     shared.state.textinfo = "Preprocessing..."
+     shared.state.job_count = len(files)
+
+     if process_caption:
+         shared.interrogator.load()
+
+     def save_pic_with_caption(image, index):
+         if process_caption:
+             caption = "-" + shared.interrogator.generate_caption(image)
+             caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
+         else:
+             caption = filename
+             caption = os.path.splitext(caption)[0]
+             caption = os.path.basename(caption)
+
+         image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
+         subindex[0] += 1
+
+     def save_pic(image, index):
+         save_pic_with_caption(image, index)
+
+         if process_flip:
+             save_pic_with_caption(ImageOps.mirror(image), index)
+
+     for index, imagefile in enumerate(tqdm.tqdm(files)):
+         subindex = [0]
+         filename = os.path.join(src, imagefile)
+         img = Image.open(filename).convert("RGB")
+
+         if shared.state.interrupted:
+             break
+
+         ratio = img.height / img.width
+         is_tall = ratio > 1.35
+         is_wide = ratio < 1 / 1.35
+
+         if process_split and is_tall:
+             img = img.resize((size, size * img.height // img.width))
+
+             top = img.crop((0, 0, size, size))
+             save_pic(top, index)
+
+             bot = img.crop((0, img.height - size, size, img.height))
+             save_pic(bot, index)
+         elif process_split and is_wide:
+             img = img.resize((size * img.width // img.height, size))
+
+             left = img.crop((0, 0, size, size))
+             save_pic(left, index)
+
+             right = img.crop((img.width - size, 0, img.width, size))
+             save_pic(right, index)
+         else:
+             img = images.resize_image(1, img, size, size)
+             save_pic(img, index)
+
+         shared.state.nextjob()
+
+     if process_caption:
+         shared.interrogator.send_blip_to_ram()
+
+ def sanitize_caption(base_path, original_caption, suffix):
+     operating_system = platform.system().lower()
+     if (operating_system == "windows"):
+         invalid_path_characters = "\\/:*?\"<>|"
+         max_path_length = 259
+     else:
+         invalid_path_characters = "/"  # linux/macos
+         max_path_length = 1023
+     caption = original_caption
+     for invalid_character in invalid_path_characters:
+         caption = caption.replace(invalid_character, "")
+     fixed_path_length = len(base_path) + len(suffix)
+     if fixed_path_length + len(caption) <= max_path_length:
+         return caption
+     caption_tokens = caption.split()
+     new_caption = ""
+     for token in caption_tokens:
+         last_caption = new_caption
+         new_caption = new_caption + token + " "
+         if (len(new_caption) + fixed_path_length - 1 > max_path_length):
+             break
+     print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
+     return last_caption.strip()
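As a usage sketch (directory names are placeholders; shared.state and shared.interrogator must already be initialized by the web UI, especially when captioning is enabled):

from modules.textual_inversion.preprocess import preprocess

# Resize a folder of source images into 512x512 training crops,
# optionally mirroring each crop and splitting very tall/wide images into two crops.
preprocess(
    process_src="raw_images",          # hypothetical source directory
    process_dst="processed_images",    # hypothetical destination directory
    process_flip=True,                 # also save a horizontally mirrored copy
    process_split=True,                # split images with aspect ratio beyond ~1.35 into two crops
    process_caption=False,             # set True to append a BLIP interrogator caption to each filename
)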
modules/textual_inversion/textual_inversion.py ADDED
@@ -0,0 +1,271 @@
+ import os
+ import sys
+ import traceback
+
+ import torch
+ import tqdm
+ import html
+ import datetime
+
+
+ from modules import shared, devices, sd_hijack, processing, sd_models
+ import modules.textual_inversion.dataset
+
+
+ class Embedding:
+     def __init__(self, vec, name, step=None):
+         self.vec = vec
+         self.name = name
+         self.step = step
+         self.cached_checksum = None
+         self.sd_checkpoint = None
+         self.sd_checkpoint_name = None
+
+     def save(self, filename):
+         embedding_data = {
+             "string_to_token": {"*": 265},
+             "string_to_param": {"*": self.vec},
+             "name": self.name,
+             "step": self.step,
+             "sd_checkpoint": self.sd_checkpoint,
+             "sd_checkpoint_name": self.sd_checkpoint_name,
+         }
+
+         torch.save(embedding_data, filename)
+
+     def checksum(self):
+         if self.cached_checksum is not None:
+             return self.cached_checksum
+
+         def const_hash(a):
+             r = 0
+             for v in a:
+                 r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
+             return r
+
+         self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
+         return self.cached_checksum
+
+
+ class EmbeddingDatabase:
+     def __init__(self, embeddings_dir):
+         self.ids_lookup = {}
+         self.word_embeddings = {}
+         self.dir_mtime = None
+         self.embeddings_dir = embeddings_dir
+
+     def register_embedding(self, embedding, model):
+
+         self.word_embeddings[embedding.name] = embedding
+
+         ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
+
+         first_id = ids[0]
+         if first_id not in self.ids_lookup:
+             self.ids_lookup[first_id] = []
+
+         self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
+
+         return embedding
+
+     def load_textual_inversion_embeddings(self):
+         mt = os.path.getmtime(self.embeddings_dir)
+         if self.dir_mtime is not None and mt <= self.dir_mtime:
+             return
+
+         self.dir_mtime = mt
+         self.ids_lookup.clear()
+         self.word_embeddings.clear()
+
+         def process_file(path, filename):
+             name = os.path.splitext(filename)[0]
+
+             data = torch.load(path, map_location="cpu")
+
+             # textual inversion embeddings
+             if 'string_to_param' in data:
+                 param_dict = data['string_to_param']
+                 if hasattr(param_dict, '_parameters'):
+                     param_dict = getattr(param_dict, '_parameters')  # fix for torch 1.12.1 loading saved file from torch 1.11
+                 assert len(param_dict) == 1, 'embedding file has multiple terms in it'
+                 emb = next(iter(param_dict.items()))[1]
+             # diffuser concepts
+             elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
+                 assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+
+                 emb = next(iter(data.values()))
+                 if len(emb.shape) == 1:
+                     emb = emb.unsqueeze(0)
+             else:
+                 raise Exception(f"Couldn't identify {filename} as either a textual inversion embedding or a diffuser concept.")
+
+             vec = emb.detach().to(devices.device, dtype=torch.float32)
+             embedding = Embedding(vec, name)
+             embedding.step = data.get('step', None)
+             embedding.sd_checkpoint = data.get('hash', None)
+             embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
+             self.register_embedding(embedding, shared.sd_model)
+
+         for fn in os.listdir(self.embeddings_dir):
+             try:
+                 fullfn = os.path.join(self.embeddings_dir, fn)
+
+                 if os.stat(fullfn).st_size == 0:
+                     continue
+
+                 process_file(fullfn, fn)
+             except Exception:
+                 print(f"Error loading embedding {fn}:", file=sys.stderr)
+                 print(traceback.format_exc(), file=sys.stderr)
+                 continue
+
+         print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
+
+     def find_embedding_at_position(self, tokens, offset):
+         token = tokens[offset]
+         possible_matches = self.ids_lookup.get(token, None)
+
+         if possible_matches is None:
+             return None, None
+
+         for ids, embedding in possible_matches:
+             if tokens[offset:offset + len(ids)] == ids:
+                 return embedding, len(ids)
+
+         return None, None
+
+
+ def create_embedding(name, num_vectors_per_token, init_text='*'):
+     cond_model = shared.sd_model.cond_stage_model
+     embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
+
+     ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
+     embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
+     vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
+
+     for i in range(num_vectors_per_token):
+         vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
+
+     fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
+     assert not os.path.exists(fn), f"file {fn} already exists"
+
+     embedding = Embedding(vec, name)
+     embedding.step = 0
+     embedding.save(fn)
+
+     return fn
+
+
+ def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
+     assert embedding_name, 'embedding not selected'
+
+     shared.state.textinfo = "Initializing textual inversion training..."
+     shared.state.job_count = steps
+
+     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
+
+     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+
+     if save_embedding_every > 0:
+         embedding_dir = os.path.join(log_directory, "embeddings")
+         os.makedirs(embedding_dir, exist_ok=True)
+     else:
+         embedding_dir = None
+
+     if create_image_every > 0:
+         images_dir = os.path.join(log_directory, "images")
+         os.makedirs(images_dir, exist_ok=True)
+     else:
+         images_dir = None
+
+     cond_model = shared.sd_model.cond_stage_model
+
+     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
+     with torch.autocast("cuda"):
+         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
+
+     hijack = sd_hijack.model_hijack
+
+     embedding = hijack.embedding_db.word_embeddings[embedding_name]
+     embedding.vec.requires_grad = True
+
+     optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
+
+     losses = torch.zeros((32,))
+
+     last_saved_file = "<none>"
+     last_saved_image = "<none>"
+
+     initial_step = embedding.step or 0
+     if initial_step > steps:
+         return embedding, filename
+
+     pbar = tqdm.tqdm(enumerate(ds), total=steps - initial_step)
+     for i, (x, text) in pbar:
+         embedding.step = i + initial_step
+
+         if embedding.step > steps:
+             break
+
+         if shared.state.interrupted:
+             break
+
+         with torch.autocast("cuda"):
+             c = cond_model([text])
+
+             x = x.to(devices.device)
+             loss = shared.sd_model(x.unsqueeze(0), c)[0]
+             del x
+
+             losses[embedding.step % losses.shape[0]] = loss.item()
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+
+         pbar.set_description(f"loss: {losses.mean():.7f}")
+
+         if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
+             last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
+             embedding.save(last_saved_file)
+
+         if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
+             last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
+
+             p = processing.StableDiffusionProcessingTxt2Img(
+                 sd_model=shared.sd_model,
+                 prompt=text,
+                 steps=20,
+                 do_not_save_grid=True,
+                 do_not_save_samples=True,
+             )
+
+             processed = processing.process_images(p)
+             image = processed.images[0]
+
+             shared.state.current_image = image
+             image.save(last_saved_image)
+
+             last_saved_image += f", prompt: {text}"
+
+         shared.state.job_no = embedding.step
+
+         shared.state.textinfo = f"""
+ <p>
+ Loss: {losses.mean():.7f}<br/>
+ Step: {embedding.step}<br/>
+ Last prompt: {html.escape(text)}<br/>
+ Last saved embedding: {html.escape(last_saved_file)}<br/>
+ Last saved image: {html.escape(last_saved_image)}<br/>
+ </p>
+ """
+
+     checkpoint = sd_models.select_checkpoint()
+
+     embedding.sd_checkpoint = checkpoint.hash
+     embedding.sd_checkpoint_name = checkpoint.model_name
+     embedding.cached_checksum = None
+     embedding.save(filename)
+
+     return embedding, filename
+
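A hedged end-to-end sketch of how these entry points fit together (argument values are illustrative, not defaults from this commit; the model and embeddings directory must already be set up by the web UI):

import modules.textual_inversion.textual_inversion as ti
from modules import sd_hijack

# Create a new 2-vector embedding file under shared.cmd_opts.embeddings_dir,
# initialized from the token embedding of init_text.
fn = ti.create_embedding("my-embedding", num_vectors_per_token=2, init_text="*")

# Reload embeddings so the hijacked text encoder knows about the new token,
# then train it. (In the web UI this wiring lives in modules/textual_inversion/ui.py.)
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()

embedding, filename = ti.train_embedding(
    embedding_name="my-embedding",
    learn_rate=5e-3,                    # illustrative learning rate
    data_root="processed_images",       # e.g. the output directory of preprocess()
    log_directory="textual_inversion",  # logs are written under <date>/<embedding name>/
    steps=3000,
    create_image_every=500,             # sample an image every N steps
    save_embedding_every=500,           # snapshot the embedding every N steps
    template_file="textual_inversion_templates/style_filewords.txt",  # assumed template path
)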
modules/textual_inversion/ui.py ADDED
@@ -0,0 +1,40 @@
+ import html
+
+ import gradio as gr
+
+ import modules.textual_inversion.textual_inversion
+ import modules.textual_inversion.preprocess
+ from modules import sd_hijack, shared
+
+
+ def create_embedding(name, initialization_text, nvpt):
+     filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
+
+     sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
+
+     return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
+
+
+ def preprocess(*args):
+     modules.textual_inversion.preprocess.preprocess(*args)
+
+     return "Preprocessing finished.", ""
+
+
+ def train_embedding(*args):
+
+     try:
+         sd_hijack.undo_optimizations()
+
+         embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
+
+         res = f"""
+ Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
+ Embedding saved to {html.escape(filename)}
+ """
+         return res, ""
+     except Exception:
+         raise
+     finally:
+         sd_hijack.apply_optimizations()
+
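A rough sketch of how these helpers could be wired to Gradio components (component names and layout are hypothetical; the actual wiring is done elsewhere in the web UI, e.g. modules/ui.py):

import gradio as gr
import modules.textual_inversion.ui as ti_ui

with gr.Blocks() as demo:
    name = gr.Textbox(label="Name")
    init_text = gr.Textbox(label="Initialization text", value="*")
    nvpt = gr.Slider(label="Number of vectors per token", minimum=1, maximum=8, step=1, value=1)
    create = gr.Button("Create embedding")

    embedding_list = gr.Dropdown(label="Embedding", choices=[])
    status = gr.HTML()
    errors = gr.HTML()

    # create_embedding returns (dropdown update, status text, error text),
    # which maps one-to-one onto the three output components below.
    create.click(fn=ti_ui.create_embedding, inputs=[name, init_text, nvpt], outputs=[embedding_list, status, errors])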