Spaces:
Runtime error
Runtime error
Upload 4 files
Browse files
modules/textual_inversion/dataset.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import PIL
|
4 |
+
import torch
|
5 |
+
from PIL import Image
|
6 |
+
from torch.utils.data import Dataset
|
7 |
+
from torchvision import transforms
|
8 |
+
|
9 |
+
import random
|
10 |
+
import tqdm
|
11 |
+
from modules import devices
|
12 |
+
import re
|
13 |
+
|
14 |
+
re_tag = re.compile(r"[a-zA-Z][_\w\d()]+")
|
15 |
+
|
16 |
+
|
17 |
+
class PersonalizedBase(Dataset):
|
18 |
+
def __init__(self, data_root, size=None, repeats=100, flip_p=0.5, placeholder_token="*", width=512, height=512, model=None, device=None, template_file=None):
|
19 |
+
|
20 |
+
self.placeholder_token = placeholder_token
|
21 |
+
|
22 |
+
self.size = size
|
23 |
+
self.width = width
|
24 |
+
self.height = height
|
25 |
+
self.flip = transforms.RandomHorizontalFlip(p=flip_p)
|
26 |
+
|
27 |
+
self.dataset = []
|
28 |
+
|
29 |
+
with open(template_file, "r") as file:
|
30 |
+
lines = [x.strip() for x in file.readlines()]
|
31 |
+
|
32 |
+
self.lines = lines
|
33 |
+
|
34 |
+
assert data_root, 'dataset directory not specified'
|
35 |
+
|
36 |
+
self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]
|
37 |
+
print("Preparing dataset...")
|
38 |
+
for path in tqdm.tqdm(self.image_paths):
|
39 |
+
image = Image.open(path)
|
40 |
+
image = image.convert('RGB')
|
41 |
+
image = image.resize((self.width, self.height), PIL.Image.BICUBIC)
|
42 |
+
|
43 |
+
filename = os.path.basename(path)
|
44 |
+
filename_tokens = os.path.splitext(filename)[0]
|
45 |
+
filename_tokens = re_tag.findall(filename_tokens)
|
46 |
+
|
47 |
+
npimage = np.array(image).astype(np.uint8)
|
48 |
+
npimage = (npimage / 127.5 - 1.0).astype(np.float32)
|
49 |
+
|
50 |
+
torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
|
51 |
+
torchdata = torch.moveaxis(torchdata, 2, 0)
|
52 |
+
|
53 |
+
init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
|
54 |
+
init_latent = init_latent.to(devices.cpu)
|
55 |
+
|
56 |
+
self.dataset.append((init_latent, filename_tokens))
|
57 |
+
|
58 |
+
self.length = len(self.dataset) * repeats
|
59 |
+
|
60 |
+
self.initial_indexes = np.arange(self.length) % len(self.dataset)
|
61 |
+
self.indexes = None
|
62 |
+
self.shuffle()
|
63 |
+
|
64 |
+
def shuffle(self):
|
65 |
+
self.indexes = self.initial_indexes[torch.randperm(self.initial_indexes.shape[0])]
|
66 |
+
|
67 |
+
def __len__(self):
|
68 |
+
return self.length
|
69 |
+
|
70 |
+
def __getitem__(self, i):
|
71 |
+
if i % len(self.dataset) == 0:
|
72 |
+
self.shuffle()
|
73 |
+
|
74 |
+
index = self.indexes[i % len(self.indexes)]
|
75 |
+
x, filename_tokens = self.dataset[index]
|
76 |
+
|
77 |
+
text = random.choice(self.lines)
|
78 |
+
text = text.replace("[name]", self.placeholder_token)
|
79 |
+
text = text.replace("[filewords]", ' '.join(filename_tokens))
|
80 |
+
|
81 |
+
return x, text
|
modules/textual_inversion/preprocess.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from PIL import Image, ImageOps
|
3 |
+
import platform
|
4 |
+
import sys
|
5 |
+
import tqdm
|
6 |
+
|
7 |
+
from modules import shared, images
|
8 |
+
|
9 |
+
|
10 |
+
def preprocess(process_src, process_dst, process_flip, process_split, process_caption):
|
11 |
+
size = 512
|
12 |
+
src = os.path.abspath(process_src)
|
13 |
+
dst = os.path.abspath(process_dst)
|
14 |
+
|
15 |
+
assert src != dst, 'same directory specified as source and destination'
|
16 |
+
|
17 |
+
os.makedirs(dst, exist_ok=True)
|
18 |
+
|
19 |
+
files = os.listdir(src)
|
20 |
+
|
21 |
+
shared.state.textinfo = "Preprocessing..."
|
22 |
+
shared.state.job_count = len(files)
|
23 |
+
|
24 |
+
if process_caption:
|
25 |
+
shared.interrogator.load()
|
26 |
+
|
27 |
+
def save_pic_with_caption(image, index):
|
28 |
+
if process_caption:
|
29 |
+
caption = "-" + shared.interrogator.generate_caption(image)
|
30 |
+
caption = sanitize_caption(os.path.join(dst, f"{index:05}-{subindex[0]}"), caption, ".png")
|
31 |
+
else:
|
32 |
+
caption = filename
|
33 |
+
caption = os.path.splitext(caption)[0]
|
34 |
+
caption = os.path.basename(caption)
|
35 |
+
|
36 |
+
image.save(os.path.join(dst, f"{index:05}-{subindex[0]}{caption}.png"))
|
37 |
+
subindex[0] += 1
|
38 |
+
|
39 |
+
def save_pic(image, index):
|
40 |
+
save_pic_with_caption(image, index)
|
41 |
+
|
42 |
+
if process_flip:
|
43 |
+
save_pic_with_caption(ImageOps.mirror(image), index)
|
44 |
+
|
45 |
+
for index, imagefile in enumerate(tqdm.tqdm(files)):
|
46 |
+
subindex = [0]
|
47 |
+
filename = os.path.join(src, imagefile)
|
48 |
+
img = Image.open(filename).convert("RGB")
|
49 |
+
|
50 |
+
if shared.state.interrupted:
|
51 |
+
break
|
52 |
+
|
53 |
+
ratio = img.height / img.width
|
54 |
+
is_tall = ratio > 1.35
|
55 |
+
is_wide = ratio < 1 / 1.35
|
56 |
+
|
57 |
+
if process_split and is_tall:
|
58 |
+
img = img.resize((size, size * img.height // img.width))
|
59 |
+
|
60 |
+
top = img.crop((0, 0, size, size))
|
61 |
+
save_pic(top, index)
|
62 |
+
|
63 |
+
bot = img.crop((0, img.height - size, size, img.height))
|
64 |
+
save_pic(bot, index)
|
65 |
+
elif process_split and is_wide:
|
66 |
+
img = img.resize((size * img.width // img.height, size))
|
67 |
+
|
68 |
+
left = img.crop((0, 0, size, size))
|
69 |
+
save_pic(left, index)
|
70 |
+
|
71 |
+
right = img.crop((img.width - size, 0, img.width, size))
|
72 |
+
save_pic(right, index)
|
73 |
+
else:
|
74 |
+
img = images.resize_image(1, img, size, size)
|
75 |
+
save_pic(img, index)
|
76 |
+
|
77 |
+
shared.state.nextjob()
|
78 |
+
|
79 |
+
if process_caption:
|
80 |
+
shared.interrogator.send_blip_to_ram()
|
81 |
+
|
82 |
+
def sanitize_caption(base_path, original_caption, suffix):
|
83 |
+
operating_system = platform.system().lower()
|
84 |
+
if (operating_system == "windows"):
|
85 |
+
invalid_path_characters = "\\/:*?\"<>|"
|
86 |
+
max_path_length = 259
|
87 |
+
else:
|
88 |
+
invalid_path_characters = "/" #linux/macos
|
89 |
+
max_path_length = 1023
|
90 |
+
caption = original_caption
|
91 |
+
for invalid_character in invalid_path_characters:
|
92 |
+
caption = caption.replace(invalid_character, "")
|
93 |
+
fixed_path_length = len(base_path) + len(suffix)
|
94 |
+
if fixed_path_length + len(caption) <= max_path_length:
|
95 |
+
return caption
|
96 |
+
caption_tokens = caption.split()
|
97 |
+
new_caption = ""
|
98 |
+
for token in caption_tokens:
|
99 |
+
last_caption = new_caption
|
100 |
+
new_caption = new_caption + token + " "
|
101 |
+
if (len(new_caption) + fixed_path_length - 1 > max_path_length):
|
102 |
+
break
|
103 |
+
print(f"\nPath will be too long. Truncated caption: {original_caption}\nto: {last_caption}", file=sys.stderr)
|
104 |
+
return last_caption.strip()
|
modules/textual_inversion/textual_inversion.py
ADDED
@@ -0,0 +1,271 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import traceback
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import tqdm
|
7 |
+
import html
|
8 |
+
import datetime
|
9 |
+
|
10 |
+
|
11 |
+
from modules import shared, devices, sd_hijack, processing, sd_models
|
12 |
+
import modules.textual_inversion.dataset
|
13 |
+
|
14 |
+
|
15 |
+
class Embedding:
|
16 |
+
def __init__(self, vec, name, step=None):
|
17 |
+
self.vec = vec
|
18 |
+
self.name = name
|
19 |
+
self.step = step
|
20 |
+
self.cached_checksum = None
|
21 |
+
self.sd_checkpoint = None
|
22 |
+
self.sd_checkpoint_name = None
|
23 |
+
|
24 |
+
def save(self, filename):
|
25 |
+
embedding_data = {
|
26 |
+
"string_to_token": {"*": 265},
|
27 |
+
"string_to_param": {"*": self.vec},
|
28 |
+
"name": self.name,
|
29 |
+
"step": self.step,
|
30 |
+
"sd_checkpoint": self.sd_checkpoint,
|
31 |
+
"sd_checkpoint_name": self.sd_checkpoint_name,
|
32 |
+
}
|
33 |
+
|
34 |
+
torch.save(embedding_data, filename)
|
35 |
+
|
36 |
+
def checksum(self):
|
37 |
+
if self.cached_checksum is not None:
|
38 |
+
return self.cached_checksum
|
39 |
+
|
40 |
+
def const_hash(a):
|
41 |
+
r = 0
|
42 |
+
for v in a:
|
43 |
+
r = (r * 281 ^ int(v) * 997) & 0xFFFFFFFF
|
44 |
+
return r
|
45 |
+
|
46 |
+
self.cached_checksum = f'{const_hash(self.vec.reshape(-1) * 100) & 0xffff:04x}'
|
47 |
+
return self.cached_checksum
|
48 |
+
|
49 |
+
|
50 |
+
class EmbeddingDatabase:
|
51 |
+
def __init__(self, embeddings_dir):
|
52 |
+
self.ids_lookup = {}
|
53 |
+
self.word_embeddings = {}
|
54 |
+
self.dir_mtime = None
|
55 |
+
self.embeddings_dir = embeddings_dir
|
56 |
+
|
57 |
+
def register_embedding(self, embedding, model):
|
58 |
+
|
59 |
+
self.word_embeddings[embedding.name] = embedding
|
60 |
+
|
61 |
+
ids = model.cond_stage_model.tokenizer([embedding.name], add_special_tokens=False)['input_ids'][0]
|
62 |
+
|
63 |
+
first_id = ids[0]
|
64 |
+
if first_id not in self.ids_lookup:
|
65 |
+
self.ids_lookup[first_id] = []
|
66 |
+
|
67 |
+
self.ids_lookup[first_id] = sorted(self.ids_lookup[first_id] + [(ids, embedding)], key=lambda x: len(x[0]), reverse=True)
|
68 |
+
|
69 |
+
return embedding
|
70 |
+
|
71 |
+
def load_textual_inversion_embeddings(self):
|
72 |
+
mt = os.path.getmtime(self.embeddings_dir)
|
73 |
+
if self.dir_mtime is not None and mt <= self.dir_mtime:
|
74 |
+
return
|
75 |
+
|
76 |
+
self.dir_mtime = mt
|
77 |
+
self.ids_lookup.clear()
|
78 |
+
self.word_embeddings.clear()
|
79 |
+
|
80 |
+
def process_file(path, filename):
|
81 |
+
name = os.path.splitext(filename)[0]
|
82 |
+
|
83 |
+
data = torch.load(path, map_location="cpu")
|
84 |
+
|
85 |
+
# textual inversion embeddings
|
86 |
+
if 'string_to_param' in data:
|
87 |
+
param_dict = data['string_to_param']
|
88 |
+
if hasattr(param_dict, '_parameters'):
|
89 |
+
param_dict = getattr(param_dict, '_parameters') # fix for torch 1.12.1 loading saved file from torch 1.11
|
90 |
+
assert len(param_dict) == 1, 'embedding file has multiple terms in it'
|
91 |
+
emb = next(iter(param_dict.items()))[1]
|
92 |
+
# diffuser concepts
|
93 |
+
elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
|
94 |
+
assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
|
95 |
+
|
96 |
+
emb = next(iter(data.values()))
|
97 |
+
if len(emb.shape) == 1:
|
98 |
+
emb = emb.unsqueeze(0)
|
99 |
+
else:
|
100 |
+
raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")
|
101 |
+
|
102 |
+
vec = emb.detach().to(devices.device, dtype=torch.float32)
|
103 |
+
embedding = Embedding(vec, name)
|
104 |
+
embedding.step = data.get('step', None)
|
105 |
+
embedding.sd_checkpoint = data.get('hash', None)
|
106 |
+
embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
|
107 |
+
self.register_embedding(embedding, shared.sd_model)
|
108 |
+
|
109 |
+
for fn in os.listdir(self.embeddings_dir):
|
110 |
+
try:
|
111 |
+
fullfn = os.path.join(self.embeddings_dir, fn)
|
112 |
+
|
113 |
+
if os.stat(fullfn).st_size == 0:
|
114 |
+
continue
|
115 |
+
|
116 |
+
process_file(fullfn, fn)
|
117 |
+
except Exception:
|
118 |
+
print(f"Error loading emedding {fn}:", file=sys.stderr)
|
119 |
+
print(traceback.format_exc(), file=sys.stderr)
|
120 |
+
continue
|
121 |
+
|
122 |
+
print(f"Loaded a total of {len(self.word_embeddings)} textual inversion embeddings.")
|
123 |
+
|
124 |
+
def find_embedding_at_position(self, tokens, offset):
|
125 |
+
token = tokens[offset]
|
126 |
+
possible_matches = self.ids_lookup.get(token, None)
|
127 |
+
|
128 |
+
if possible_matches is None:
|
129 |
+
return None, None
|
130 |
+
|
131 |
+
for ids, embedding in possible_matches:
|
132 |
+
if tokens[offset:offset + len(ids)] == ids:
|
133 |
+
return embedding, len(ids)
|
134 |
+
|
135 |
+
return None, None
|
136 |
+
|
137 |
+
|
138 |
+
def create_embedding(name, num_vectors_per_token, init_text='*'):
|
139 |
+
cond_model = shared.sd_model.cond_stage_model
|
140 |
+
embedding_layer = cond_model.wrapped.transformer.text_model.embeddings
|
141 |
+
|
142 |
+
ids = cond_model.tokenizer(init_text, max_length=num_vectors_per_token, return_tensors="pt", add_special_tokens=False)["input_ids"]
|
143 |
+
embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)
|
144 |
+
vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)
|
145 |
+
|
146 |
+
for i in range(num_vectors_per_token):
|
147 |
+
vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]
|
148 |
+
|
149 |
+
fn = os.path.join(shared.cmd_opts.embeddings_dir, f"{name}.pt")
|
150 |
+
assert not os.path.exists(fn), f"file {fn} already exists"
|
151 |
+
|
152 |
+
embedding = Embedding(vec, name)
|
153 |
+
embedding.step = 0
|
154 |
+
embedding.save(fn)
|
155 |
+
|
156 |
+
return fn
|
157 |
+
|
158 |
+
|
159 |
+
def train_embedding(embedding_name, learn_rate, data_root, log_directory, steps, create_image_every, save_embedding_every, template_file):
|
160 |
+
assert embedding_name, 'embedding not selected'
|
161 |
+
|
162 |
+
shared.state.textinfo = "Initializing textual inversion training..."
|
163 |
+
shared.state.job_count = steps
|
164 |
+
|
165 |
+
filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
|
166 |
+
|
167 |
+
log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
|
168 |
+
|
169 |
+
if save_embedding_every > 0:
|
170 |
+
embedding_dir = os.path.join(log_directory, "embeddings")
|
171 |
+
os.makedirs(embedding_dir, exist_ok=True)
|
172 |
+
else:
|
173 |
+
embedding_dir = None
|
174 |
+
|
175 |
+
if create_image_every > 0:
|
176 |
+
images_dir = os.path.join(log_directory, "images")
|
177 |
+
os.makedirs(images_dir, exist_ok=True)
|
178 |
+
else:
|
179 |
+
images_dir = None
|
180 |
+
|
181 |
+
cond_model = shared.sd_model.cond_stage_model
|
182 |
+
|
183 |
+
shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
|
184 |
+
with torch.autocast("cuda"):
|
185 |
+
ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, size=512, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file)
|
186 |
+
|
187 |
+
hijack = sd_hijack.model_hijack
|
188 |
+
|
189 |
+
embedding = hijack.embedding_db.word_embeddings[embedding_name]
|
190 |
+
embedding.vec.requires_grad = True
|
191 |
+
|
192 |
+
optimizer = torch.optim.AdamW([embedding.vec], lr=learn_rate)
|
193 |
+
|
194 |
+
losses = torch.zeros((32,))
|
195 |
+
|
196 |
+
last_saved_file = "<none>"
|
197 |
+
last_saved_image = "<none>"
|
198 |
+
|
199 |
+
ititial_step = embedding.step or 0
|
200 |
+
if ititial_step > steps:
|
201 |
+
return embedding, filename
|
202 |
+
|
203 |
+
pbar = tqdm.tqdm(enumerate(ds), total=steps-ititial_step)
|
204 |
+
for i, (x, text) in pbar:
|
205 |
+
embedding.step = i + ititial_step
|
206 |
+
|
207 |
+
if embedding.step > steps:
|
208 |
+
break
|
209 |
+
|
210 |
+
if shared.state.interrupted:
|
211 |
+
break
|
212 |
+
|
213 |
+
with torch.autocast("cuda"):
|
214 |
+
c = cond_model([text])
|
215 |
+
|
216 |
+
x = x.to(devices.device)
|
217 |
+
loss = shared.sd_model(x.unsqueeze(0), c)[0]
|
218 |
+
del x
|
219 |
+
|
220 |
+
losses[embedding.step % losses.shape[0]] = loss.item()
|
221 |
+
|
222 |
+
optimizer.zero_grad()
|
223 |
+
loss.backward()
|
224 |
+
optimizer.step()
|
225 |
+
|
226 |
+
pbar.set_description(f"loss: {losses.mean():.7f}")
|
227 |
+
|
228 |
+
if embedding.step > 0 and embedding_dir is not None and embedding.step % save_embedding_every == 0:
|
229 |
+
last_saved_file = os.path.join(embedding_dir, f'{embedding_name}-{embedding.step}.pt')
|
230 |
+
embedding.save(last_saved_file)
|
231 |
+
|
232 |
+
if embedding.step > 0 and images_dir is not None and embedding.step % create_image_every == 0:
|
233 |
+
last_saved_image = os.path.join(images_dir, f'{embedding_name}-{embedding.step}.png')
|
234 |
+
|
235 |
+
p = processing.StableDiffusionProcessingTxt2Img(
|
236 |
+
sd_model=shared.sd_model,
|
237 |
+
prompt=text,
|
238 |
+
steps=20,
|
239 |
+
do_not_save_grid=True,
|
240 |
+
do_not_save_samples=True,
|
241 |
+
)
|
242 |
+
|
243 |
+
processed = processing.process_images(p)
|
244 |
+
image = processed.images[0]
|
245 |
+
|
246 |
+
shared.state.current_image = image
|
247 |
+
image.save(last_saved_image)
|
248 |
+
|
249 |
+
last_saved_image += f", prompt: {text}"
|
250 |
+
|
251 |
+
shared.state.job_no = embedding.step
|
252 |
+
|
253 |
+
shared.state.textinfo = f"""
|
254 |
+
<p>
|
255 |
+
Loss: {losses.mean():.7f}<br/>
|
256 |
+
Step: {embedding.step}<br/>
|
257 |
+
Last prompt: {html.escape(text)}<br/>
|
258 |
+
Last saved embedding: {html.escape(last_saved_file)}<br/>
|
259 |
+
Last saved image: {html.escape(last_saved_image)}<br/>
|
260 |
+
</p>
|
261 |
+
"""
|
262 |
+
|
263 |
+
checkpoint = sd_models.select_checkpoint()
|
264 |
+
|
265 |
+
embedding.sd_checkpoint = checkpoint.hash
|
266 |
+
embedding.sd_checkpoint_name = checkpoint.model_name
|
267 |
+
embedding.cached_checksum = None
|
268 |
+
embedding.save(filename)
|
269 |
+
|
270 |
+
return embedding, filename
|
271 |
+
|
modules/textual_inversion/ui.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import html
|
2 |
+
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
import modules.textual_inversion.textual_inversion
|
6 |
+
import modules.textual_inversion.preprocess
|
7 |
+
from modules import sd_hijack, shared
|
8 |
+
|
9 |
+
|
10 |
+
def create_embedding(name, initialization_text, nvpt):
|
11 |
+
filename = modules.textual_inversion.textual_inversion.create_embedding(name, nvpt, init_text=initialization_text)
|
12 |
+
|
13 |
+
sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings()
|
14 |
+
|
15 |
+
return gr.Dropdown.update(choices=sorted(sd_hijack.model_hijack.embedding_db.word_embeddings.keys())), f"Created: {filename}", ""
|
16 |
+
|
17 |
+
|
18 |
+
def preprocess(*args):
|
19 |
+
modules.textual_inversion.preprocess.preprocess(*args)
|
20 |
+
|
21 |
+
return "Preprocessing finished.", ""
|
22 |
+
|
23 |
+
|
24 |
+
def train_embedding(*args):
|
25 |
+
|
26 |
+
try:
|
27 |
+
sd_hijack.undo_optimizations()
|
28 |
+
|
29 |
+
embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
|
30 |
+
|
31 |
+
res = f"""
|
32 |
+
Training {'interrupted' if shared.state.interrupted else 'finished'} at {embedding.step} steps.
|
33 |
+
Embedding saved to {html.escape(filename)}
|
34 |
+
"""
|
35 |
+
return res, ""
|
36 |
+
except Exception:
|
37 |
+
raise
|
38 |
+
finally:
|
39 |
+
sd_hijack.apply_optimizations()
|
40 |
+
|