tests
Garrett Goon committed · Commit 4e26ea8
Parent(s): ad18a5b
app.py CHANGED
@@ -25,31 +25,31 @@ pipeline = StableDiffusionPipeline.from_pretrained(
 # Load in the new concepts.
 CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
 learned_embeddings_dict = torch.load(CONCEPT_PATH)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+concept_to_dummy_tokens_map = {}
+for concept_token, embedding_dict in learned_embeddings_dict.items():
+    initializer_tokens = embedding_dict["initializer_tokens"]
+    learned_embeddings = embedding_dict["learned_embeddings"]
+    (
+        initializer_ids,
+        dummy_placeholder_ids,
+        dummy_placeholder_tokens,
+    ) = utils.add_new_tokens_to_tokenizer(
+        concept_token=concept_token,
+        initializer_tokens=initializer_tokens,
+        tokenizer=pipeline.tokenizer,
+    )
+    pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
+    token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
+    for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
+        token_embeddings[d_id] = tensor
+    concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
+
+
+def replace_concept_tokens(text: str):
+    for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
+        text = text.replace(concept_token, dummy_tokens)
+    return text
 
 
 # def inference(
@@ -101,7 +101,7 @@ with gr.Blocks() as demo:
         interactive=True,
     )
     output = gr.Textbox(
-        label="output", placeholder=use_auth_token[:
+        label="output", placeholder=use_auth_token[:5], interactive=False
     )
     gr.Button("test").click(
         lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output
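
For context on how the added pieces fit together at inference time: the loop above registers each concept's dummy placeholder tokens with the tokenizer, copies the learned embeddings into the text encoder's embedding matrix, and replace_concept_tokens rewrites a prompt so those dummy tokens are what actually get encoded. Below is a minimal usage sketch; the concept token name, the prompt text, and the generation arguments are illustrative assumptions and are not part of this commit (which still has the inference function commented out).

# Sketch only: "<my-concept>" and the prompt string are made-up examples, not
# names defined by this commit.
prompt = "a photo of <my-concept> on a beach"

# Swap the human-readable concept token for its dummy placeholder tokens so the
# tokenizer maps them onto the embeddings injected above.
prompt_with_dummies = replace_concept_tokens(prompt)

# Run the already-loaded StableDiffusionPipeline on the rewritten prompt.
image = pipeline(prompt_with_dummies, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("sample.png")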