Garrett Goon committed · commit 1da6b3f
1 parent: 7b17c3f
updated syntax to match reviewed repo

Files changed:
- __pycache__/utils.cpython-38.pyc +0 -0
- app.py +19 -19
- learned_embeddings_dict.pt +1 -1
- utils.py +31 -41
__pycache__/utils.cpython-38.pyc
CHANGED
Binary files a/__pycache__/utils.cpython-38.pyc and b/__pycache__/utils.cpython-38.pyc differ
app.py
CHANGED
@@ -34,37 +34,37 @@ pipeline = StableDiffusionPipeline.from_pretrained(
 CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
 learned_embeddings_dict = torch.load(CONCEPT_PATH)
 
-
+concept_to_dummy_strs_map = {}
 for concept_token, embedding_dict in learned_embeddings_dict.items():
-
+    initializer_strs = embedding_dict["initializer_strs"]
     learned_embeddings = embedding_dict["learned_embeddings"]
     (
         initializer_ids,
         dummy_placeholder_ids,
-
+        dummy_placeholder_strs,
     ) = utils.add_new_tokens_to_tokenizer(
-
-
+        concept_str=concept_token,
+        initializer_strs=initializer_strs,
         tokenizer=pipeline.tokenizer,
     )
     pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
     token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
     for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
         token_embeddings[d_id] = tensor
-
+    concept_to_dummy_strs_map[concept_token] = dummy_placeholder_strs
 
 
-def
-    for concept_token,
-        text = text.replace(concept_token,
+def replace_concept_strs(text: str):
+    for concept_token, dummy_strs in concept_to_dummy_strs_map.items():
+        text = text.replace(concept_token, dummy_strs)
     return text
 
 def inference(prompt: str, guidance_scale: int, num_inference_steps: int, seed: int):
     if not prompt:
         raise ValueError("Please enter a prompt.")
-    if '
-        raise ValueError('"
-    prompt =
+    if 'det-logo' not in prompt:
+        raise ValueError('"det-logo" must be included in the prompt.')
+    prompt = replace_concept_strs(prompt)
     generator = torch.Generator(device=device).manual_seed(seed)
     output = pipeline(
         prompt=[prompt] * BATCH_SIZE,
@@ -275,35 +275,35 @@ block = gr.Blocks(css=css)
 
 examples = [
     [
-        "a Van Gogh painting of a
+        "a Van Gogh painting of a det-logo with thick strokes, masterful composition",
         # 4,
         # 45,
         # 7.5,
         # 1024,
     ],
     [
-        "Futuristic
+        "Futuristic det-logo in a desert, painting, octane render, 4 k, anime sky, warm colors",
         # 4,
         # 45,
         # 7,
         # 1024,
     ],
     [
-        "cell shaded cartoon of a
+        "cell shaded cartoon of a det-logo, subtle colors, post grunge, concept art by josan gonzales and wlop, by james jean, victo ngai, david rubin, mike mignola, deviantart, art by artgem",
         # 4,
         # 45,
         # 7,
         # 1024,
     ],
     [
-        "a surreal Salvador Dali painting of a
+        "a surreal Salvador Dali painting of a det-logo, soft blended colors",
         # 4,
         # 45,
         # 7,
         # 1024,
     ],
     [
-        "Beautiful tarot illustration of a
+        "Beautiful tarot illustration of a det-logo, in the style of james jean and victo ngai, mystical colors, trending on artstation",
         # 4,
         # 45,
         # 7,
@@ -334,10 +334,10 @@ with block:
     with gr.Box():
         with gr.Row(elem_id="prompt-container").style(equal_height=True):
             prompt = gr.Textbox(
-                label='Enter a prompt including "
+                label='Enter a prompt including "det-logo"',
                 show_label=False,
                 max_lines=1,
-                placeholder='Enter a prompt including "
+                placeholder='Enter a prompt including "det-logo"',
                 elem_id="prompt-text-input",
             ).style(
                 container=False,
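The net effect of the app.py changes is easiest to see in isolation. Below is a minimal, self-contained sketch of the new prompt-rewriting step; the map contents are illustrative (in the app they are produced by utils.add_new_tokens_to_tokenizer at startup), but the dummy-string format follows the f"<{concept_str}>_{n}" pattern from utils.py:

```python
# Illustrative stand-in for the map that app.py builds at startup.
concept_to_dummy_strs_map = {"det-logo": "<det-logo>_0 <det-logo>_1 <det-logo>_2"}

def replace_concept_strs(text: str) -> str:
    # Swap each human-readable concept token for the dummy placeholder tokens
    # whose embeddings were actually learned during textual inversion.
    for concept_token, dummy_strs in concept_to_dummy_strs_map.items():
        text = text.replace(concept_token, dummy_strs)
    return text

print(replace_concept_strs("a Van Gogh painting of a det-logo"))
# a Van Gogh painting of a <det-logo>_0 <det-logo>_1 <det-logo>_2
```

This is why inference() rejects prompts that omit "det-logo": without the concept token present, there is nothing to replace, and the learned embeddings would never be used.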
learned_embeddings_dict.pt
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:5184c747567ac6240bd45b701cb29416752fcc925b2a967a811c28729451b942
 size 16235
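For reference, this Git LFS pointer tracks a small (~16 kB) torch checkpoint. Its layout can be inferred from the keys app.py reads; a sketch, with the concept name and inner comments as assumptions:

```python
import torch

# Load the checkpoint the same way app.py does.
learned_embeddings_dict = torch.load("learned_embeddings_dict.pt")

# Inferred layout (keyed by concept token, e.g. "det-logo"):
#   embedding_dict["initializer_strs"]   -> string the concept was initialized from
#   embedding_dict["learned_embeddings"] -> one learned vector per dummy token
for concept_token, embedding_dict in learned_embeddings_dict.items():
    print(concept_token, len(embedding_dict["learned_embeddings"]))
```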
utils.py
CHANGED
@@ -1,59 +1,49 @@
-from typing import List,
+from typing import List, Tuple
 
 import torch
 import torch.nn as nn
 
 
 def add_new_tokens_to_tokenizer(
-
-
+    concept_str: str,
+    initializer_strs: str,
     tokenizer: nn.Module,
-) -> Tuple[
+) -> Tuple[torch.Tensor, List[int], str]:
     """Helper function for adding new tokens to the tokenizer and extending the corresponding
     embeddings appropriately, given a single concept token and its sequence of corresponding
-    initializer tokens. Returns the
+    initializer tokens. Returns the tensor of ids for the initializer tokens and their dummy
     replacements, as well as the string representation of the dummies.
     """
+    assert not token_exists_in_tokenizer(
+        concept_str, tokenizer
+    ), f"concept_str {concept_str} already exists in tokenizer."
+
     initializer_ids = tokenizer(
-
-        padding="max_length",
-        truncation=True,
-        max_length=tokenizer.model_max_length,
+        initializer_strs,
         return_tensors="pt",
         add_special_tokens=False,
-    ).input_ids
-
-    try:
-        special_token_ids = tokenizer.all_special_ids
-    except AttributeError:
-        special_token_ids = []
-
-    non_special_initializer_locations = torch.isin(
-        initializer_ids, torch.tensor(special_token_ids), invert=True
-    )
-    non_special_initializer_ids = initializer_ids[non_special_initializer_locations]
-    if len(non_special_initializer_ids) == 0:
-        raise ValueError(
-            f'"{initializer_tokens}" maps to trivial tokens, please choose a different initializer.'
-        )
+    ).input_ids[0]
 
     # Add a dummy placeholder token for every token in the initializer.
-
-
-
-
-
-
-
-
-
-
-    dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(
-
-    )
-    # Sanity check
+    dummy_placeholder_str_list = [f"<{concept_str}>_{n}" for n in range(len(initializer_ids))]
+    # Sanity check.
+    for dummy in dummy_placeholder_str_list:
+        assert not token_exists_in_tokenizer(
+            dummy, tokenizer
+        ), f"dummy {dummy} already exists in tokenizer."
+
+    dummy_placeholder_strs = " ".join(dummy_placeholder_str_list)
+
+    tokenizer.add_tokens(dummy_placeholder_str_list)
+    dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(dummy_placeholder_str_list)
+    # Sanity check that the dummies correspond to the correct number of ids.
     assert len(dummy_placeholder_ids) == len(
-
-    ), 'Length of "dummy_placeholder_ids" and "
+        initializer_ids
+    ), 'Length of "dummy_placeholder_ids" and "initializer_ids" must match.'
+
+    return initializer_ids, dummy_placeholder_ids, dummy_placeholder_strs
+
 
-
+def token_exists_in_tokenizer(token: str, tokenizer: nn.Module) -> bool:
+    exists = tokenizer.convert_tokens_to_ids([token]) != [tokenizer.unk_token_id]
+    return exists
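Putting the reworked helpers together, a minimal usage sketch; the CLIP tokenizer checkpoint and the example strings are assumptions (the Space passes pipeline.tokenizer instead):

```python
from transformers import CLIPTokenizer

import utils

# Stable Diffusion's text encoder uses a CLIP tokenizer; any HF tokenizer with
# add_tokens / convert_tokens_to_ids / unk_token_id should work here.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

initializer_ids, dummy_placeholder_ids, dummy_placeholder_strs = utils.add_new_tokens_to_tokenizer(
    concept_str="det-logo",
    initializer_strs="a corporate logo",
    tokenizer=tokenizer,
)
# One dummy token per initializer token, e.g. "<det-logo>_0 <det-logo>_1 ...";
# the caller must still resize the text encoder's embedding matrix and copy the
# learned vectors into the rows indexed by dummy_placeholder_ids, as app.py does.
print(dummy_placeholder_strs)
assert utils.token_exists_in_tokenizer("<det-logo>_0", tokenizer)
```

Note the design change: the old version tokenized the initializer with max-length padding and then filtered out special token ids; the new version simply tokenizes with add_special_tokens=False and takes the first row, which makes the filtering and the try/except around all_special_ids unnecessary.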