Garrett Goon commited on
Commit
4e26ea8
·
1 Parent(s): ad18a5b
Files changed (1) hide show
  1. app.py +26 -26
app.py CHANGED
@@ -25,31 +25,31 @@ pipeline = StableDiffusionPipeline.from_pretrained(
25
  # Load in the new concepts.
26
  CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
27
  learned_embeddings_dict = torch.load(CONCEPT_PATH)
28
- #
29
- # concept_to_dummy_tokens_map = {}
30
- # for concept_token, embedding_dict in learned_embeddings_dict.items():
31
- # initializer_tokens = embedding_dict["initializer_tokens"]
32
- # learned_embeddings = embedding_dict["learned_embeddings"]
33
- # (
34
- # initializer_ids,
35
- # dummy_placeholder_ids,
36
- # dummy_placeholder_tokens,
37
- # ) = utils.add_new_tokens_to_tokenizer(
38
- # concept_token=concept_token,
39
- # initializer_tokens=initializer_tokens,
40
- # tokenizer=pipeline.tokenizer,
41
- # )
42
- # pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
43
- # token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
44
- # for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
45
- # token_embeddings[d_id] = tensor
46
- # concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
47
- #
48
- #
49
- # def replace_concept_tokens(text: str):
50
- # for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
51
- # text = text.replace(concept_token, dummy_tokens)
52
- # return text
53
 
54
 
55
  # def inference(
@@ -101,7 +101,7 @@ with gr.Blocks() as demo:
101
  interactive=True,
102
  )
103
  output = gr.Textbox(
104
- label="output", placeholder=use_auth_token[:10], interactive=False
105
  )
106
  gr.Button("test").click(
107
  lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output
 
25
  # Load in the new concepts.
26
  CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
27
  learned_embeddings_dict = torch.load(CONCEPT_PATH)
28
+
29
+ concept_to_dummy_tokens_map = {}
30
+ for concept_token, embedding_dict in learned_embeddings_dict.items():
31
+ initializer_tokens = embedding_dict["initializer_tokens"]
32
+ learned_embeddings = embedding_dict["learned_embeddings"]
33
+ (
34
+ initializer_ids,
35
+ dummy_placeholder_ids,
36
+ dummy_placeholder_tokens,
37
+ ) = utils.add_new_tokens_to_tokenizer(
38
+ concept_token=concept_token,
39
+ initializer_tokens=initializer_tokens,
40
+ tokenizer=pipeline.tokenizer,
41
+ )
42
+ pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
43
+ token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
44
+ for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
45
+ token_embeddings[d_id] = tensor
46
+ concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
47
+
48
+
49
+ def replace_concept_tokens(text: str):
50
+ for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
51
+ text = text.replace(concept_token, dummy_tokens)
52
+ return text
53
 
54
 
55
  # def inference(
 
101
  interactive=True,
102
  )
103
  output = gr.Textbox(
104
+ label="output", placeholder=use_auth_token[:5], interactive=False
105
  )
106
  gr.Button("test").click(
107
  lambda s: replace_concept_tokens(s), inputs=[prompt], outputs=output