Garrett Goon committed
Commit ad18a5b · 1 Parent(s): 8680dd4
Files changed (3)
  1. app.py +44 -14
  2. learned_embeddings_dict.pt +3 -0
  3. learned_embeddings_dict.py +0 -0
app.py CHANGED
@@ -1,25 +1,55 @@
+ import pathlib
  import os

  import gradio as gr
  import torch
  from diffusers import StableDiffusionPipeline

- use_auth_token = os.environ["HF_AUTH_TOKEN"]
-
- # pipeline = StableDiffusionPipeline.from_pretrained(
- #     pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
- #     use_auth_token=use_auth_token,
- #     revision="fp16",
- #     torch_dtype=torch.float16,
- # ).to("cuda")
-
- concept_to_dummy_tokens_map = torch.load("concept_to_dummy_tokens_map.pt")
+ import utils

+ use_auth_token = os.environ["HF_AUTH_TOKEN"]

- def replace_concept_tokens(text: str):
-     for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
-         text = text.replace(concept_token, dummy_tokens)
-     return text
+ # Instantiate the pipeline.
+ device, revision, torch_dtype = (
+     ("cuda", "fp16", torch.float16)
+     if torch.cuda.is_available()
+     else ("cpu", "main", torch.float32)
+ )
+ pipeline = StableDiffusionPipeline.from_pretrained(
+     pretrained_model_name_or_path="CompVis/stable-diffusion-v1-4",
+     use_auth_token=use_auth_token,
+     revision=revision,
+     torch_dtype=torch_dtype,
+ ).to(device)
+
+ # Load in the new concepts.
+ CONCEPT_PATH = pathlib.Path("learned_embeddings_dict.pt")
+ learned_embeddings_dict = torch.load(CONCEPT_PATH)
+ #
+ # concept_to_dummy_tokens_map = {}
+ # for concept_token, embedding_dict in learned_embeddings_dict.items():
+ #     initializer_tokens = embedding_dict["initializer_tokens"]
+ #     learned_embeddings = embedding_dict["learned_embeddings"]
+ #     (
+ #         initializer_ids,
+ #         dummy_placeholder_ids,
+ #         dummy_placeholder_tokens,
+ #     ) = utils.add_new_tokens_to_tokenizer(
+ #         concept_token=concept_token,
+ #         initializer_tokens=initializer_tokens,
+ #         tokenizer=pipeline.tokenizer,
+ #     )
+ #     pipeline.text_encoder.resize_token_embeddings(len(pipeline.tokenizer))
+ #     token_embeddings = pipeline.text_encoder.get_input_embeddings().weight.data
+ #     for d_id, tensor in zip(dummy_placeholder_ids, learned_embeddings):
+ #         token_embeddings[d_id] = tensor
+ #     concept_to_dummy_tokens_map[concept_token] = dummy_placeholder_tokens
+ #
+ #
+ # def replace_concept_tokens(text: str):
+ #     for concept_token, dummy_tokens in concept_to_dummy_tokens_map.items():
+ #         text = text.replace(concept_token, dummy_tokens)
+ #     return text


  # def inference(
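
Note: the concept-loading block lands commented out, and it calls utils.add_new_tokens_to_tokenizer, a helper that is not included in this commit (only the import utils line is added). The sketch below is one guess at that helper's shape, inferred from the call site alone: it assumes one dummy placeholder token is registered per initializer token / learned embedding vector, that the initializer tokens already exist in the CLIP vocabulary, and that the dummy tokens come back joined into a single string so replace_concept_tokens can drop them into a prompt with str.replace. All names and the placeholder format here are hypothetical.

    from typing import List, Tuple

    from transformers import CLIPTokenizer


    def add_new_tokens_to_tokenizer(
        concept_token: str,
        initializer_tokens: List[str],
        tokenizer: CLIPTokenizer,
    ) -> Tuple[List[int], List[int], str]:
        # Assumes each initializer token is already a single token in the vocabulary.
        initializer_ids = tokenizer.convert_tokens_to_ids(initializer_tokens)
        # One dummy placeholder per learned embedding vector, e.g. "<my-concept-0>".
        dummy_placeholder_tokens = [
            f"<{concept_token.strip('<>')}-{i}>" for i in range(len(initializer_tokens))
        ]
        tokenizer.add_tokens(dummy_placeholder_tokens)
        dummy_placeholder_ids = tokenizer.convert_tokens_to_ids(dummy_placeholder_tokens)
        # Joined into one string so a single str.replace() swaps it into prompts.
        return initializer_ids, dummy_placeholder_ids, " ".join(dummy_placeholder_tokens)

Once the block is uncommented, inference would presumably rewrite prompts before they reach the pipeline, e.g. image = pipeline(replace_concept_tokens(prompt)).images[0].
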
learned_embeddings_dict.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:73ab240e6ef7b16a70e14b4625882d8f63050f1d96ffc0eef6e0e0caa2844109
+ size 16235
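
Note: this file is tracked with Git LFS, so the diff records only the pointer (spec version, sha256 oid, and a 16,235-byte size) rather than the tensors. From the way app.py reads it above, the checkpoint is a dict mapping each concept token to its initializer tokens and the learned embedding vectors that will overwrite the dummy tokens' rows in the text encoder's embedding table. A minimal illustration of that layout, with hypothetical values (768 is the embedding width of the CLIP text encoder in Stable Diffusion v1-4):

    import torch

    learned_embeddings_dict = {
        "<my-concept>": {  # hypothetical concept token
            "initializer_tokens": ["sketch", "drawing"],  # ordinary vocabulary words
            # One learned vector per dummy placeholder token.
            "learned_embeddings": [torch.randn(768), torch.randn(768)],
        },
    }
    torch.save(learned_embeddings_dict, "learned_embeddings_dict.pt")
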
learned_embeddings_dict.py DELETED
Binary file (16.2 kB)