multimodalart HF staff committed on
Commit
dbaf842
·
1 Parent(s): a25de9b
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +10 -0
  2. .gitignore +6 -0
  3. LICENSE +21 -0
  4. app.py +19 -0
  5. assets/.DS_Store +0 -0
  6. assets/comparison.jpg +3 -0
  7. assets/embeddings_sd_1.4/cat.pt +3 -0
  8. assets/embeddings_sd_1.4/dog.pt +3 -0
  9. assets/embeddings_sd_1.4/horse.pt +3 -0
  10. assets/embeddings_sd_1.4/zebra.pt +3 -0
  11. assets/grid_cat2dog.jpg +3 -0
  12. assets/grid_dog2cat.jpg +3 -0
  13. assets/grid_horse2zebra.jpg +3 -0
  14. assets/grid_tree2fall.jpg +3 -0
  15. assets/grid_zebra2horse.jpg +3 -0
  16. assets/main.gif +3 -0
  17. assets/method.jpeg +3 -0
  18. assets/results_real.jpg +3 -0
  19. assets/results_syn.jpg +3 -0
  20. assets/results_teaser.jpg +0 -0
  21. assets/test_images/cats/cat_1.png +0 -0
  22. assets/test_images/cats/cat_2.png +0 -0
  23. assets/test_images/cats/cat_3.png +0 -0
  24. assets/test_images/cats/cat_4.png +0 -0
  25. assets/test_images/cats/cat_5.png +0 -0
  26. assets/test_images/cats/cat_6.png +0 -0
  27. assets/test_images/cats/cat_7.png +0 -0
  28. assets/test_images/cats/cat_8.png +0 -0
  29. assets/test_images/cats/cat_9.png +0 -0
  30. assets/test_images/dogs/dog_1.png +0 -0
  31. assets/test_images/dogs/dog_2.png +0 -0
  32. assets/test_images/dogs/dog_3.png +0 -0
  33. assets/test_images/dogs/dog_4.png +0 -0
  34. assets/test_images/dogs/dog_5.png +0 -0
  35. assets/test_images/dogs/dog_6.png +0 -0
  36. assets/test_images/dogs/dog_7.png +0 -0
  37. assets/test_images/dogs/dog_8.png +0 -0
  38. assets/test_images/dogs/dog_9.png +0 -0
  39. environment.yml +23 -0
  40. output/test_cat/edit/cat_1.png +0 -0
  41. output/test_cat/inversion/cat_1.pt +3 -0
  42. output/test_cat/prompt/cat_1.txt +1 -0
  43. output/test_cat/reconstruction/cat_1.png +0 -0
  44. output/test_cat2/edit/cat_1.png +0 -0
  45. output/test_cat2/reconstruction/cat_1.png +0 -0
  46. requirements.txt +11 -0
  47. src/edit_real.py +65 -0
  48. src/edit_synthetic.py +52 -0
  49. src/inversion.py +64 -0
  50. src/make_edit_direction.py +61 -0
.gitattributes CHANGED
@@ -32,3 +32,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ assets/comparison.jpg filter=lfs diff=lfs merge=lfs -text
36
+ assets/grid_cat2dog.jpg filter=lfs diff=lfs merge=lfs -text
37
+ assets/grid_dog2cat.jpg filter=lfs diff=lfs merge=lfs -text
38
+ assets/grid_horse2zebra.jpg filter=lfs diff=lfs merge=lfs -text
39
+ assets/grid_tree2fall.jpg filter=lfs diff=lfs merge=lfs -text
40
+ assets/grid_zebra2horse.jpg filter=lfs diff=lfs merge=lfs -text
41
+ assets/main.gif filter=lfs diff=lfs merge=lfs -text
42
+ assets/method.jpeg filter=lfs diff=lfs merge=lfs -text
43
+ assets/results_real.jpg filter=lfs diff=lfs merge=lfs -text
44
+ assets/results_syn.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ output
2
+ scripts
3
+ src/folder_*.py
4
+ src/ig_*.py
5
+ assets/edit_sentences
6
+ src/utils/edit_pipeline_spatial.py
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 pix2pixzero
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
app.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+
4
def update(name):
    """Run DDIM inversion on the bundled cat image and report the outcome.

    Args:
        name: Text from the input box. It is not used; the input image and
            the results folder are fixed.

    Returns:
        "Inverted!" when ``src/inversion.py`` exits with status 0, otherwise a
        message that includes the non-zero exit status.
    """
    # Imported locally so the function does not depend on file-level imports.
    import subprocess
    import sys

    # An argument list (no shell) avoids quoting problems, and
    # ``sys.executable`` runs the script with the interpreter serving the app.
    command = [
        sys.executable,
        "src/inversion.py",
        "--input_image", "assets/test_images/cats/cat_1.png",
        "--results_folder", "output/test_cat",
    ]
    completed = subprocess.run(command, check=False)
    if completed.returncode != 0:
        # Surface the failure instead of always claiming success.
        return f"Inversion failed (exit code {completed.returncode})"
    return "Inverted!"
10
+
11
# Page layout: an instruction line, a row holding the request and status text
# boxes, and a Run button that passes the request text to update().
# NOTE(review): the diff view drops indentation; the button is assumed to sit
# inside the Blocks context but below the row -- confirm against the raw file.
demo = gr.Blocks()
with demo:
    gr.Markdown("Start typing below and then click **Run** to see the output.")
    with gr.Row():
        request_box = gr.Textbox(placeholder="Do you want to invert?")
        status_box = gr.Textbox()
    run_button = gr.Button("Run")
    run_button.click(fn=update, inputs=request_box, outputs=status_box)

# Launch the demo.
demo.launch()
assets/.DS_Store ADDED
Binary file (6.15 kB). View file
 
assets/comparison.jpg ADDED

Git LFS Details

  • SHA256: acab8ed1680a42dd2f540e8188a43eb0d101895fca8ed36c0e06c8b351d2c276
  • Pointer size: 132 Bytes
  • Size of remote file: 3.39 MB
assets/embeddings_sd_1.4/cat.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:aa9441dc014d5e86567c5ef165e10b50d2a7b3a68d90686d0cd1006792adf334
3
+ size 237300
assets/embeddings_sd_1.4/dog.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:becf079d61d7f35727bcc0d8506ddcdcddb61e62d611840ff3d18eca7fb6338c
3
+ size 237300
assets/embeddings_sd_1.4/horse.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c5d499299544d11371f84674761292b0512055ef45776c700c0b0da164cbf6c7
3
+ size 118949
assets/embeddings_sd_1.4/zebra.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a29f6a11d91f3a276e27326b7623fae9d61a3d253ad430bb868bd40fb7e02fec
3
+ size 118949
assets/grid_cat2dog.jpg ADDED

Git LFS Details

  • SHA256: 0080134b70277af723e25c4627494fda8555d43a9f6376e682b67b3341d1f1f3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.21 MB
assets/grid_dog2cat.jpg ADDED

Git LFS Details

  • SHA256: 0e5059ec1ad8e4b07fe8b715295e82fcead652b9c366733793674e84d51427d9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.25 MB
assets/grid_horse2zebra.jpg ADDED

Git LFS Details

  • SHA256: a31e0a456e9323697c966e675b02403511ebf0b7c334416a8da91df1c14723df
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/grid_tree2fall.jpg ADDED

Git LFS Details

  • SHA256: 559ab066e4ef0972748d0a7f004d2ca18fd15062c667ac6665309727f6dc0cc8
  • Pointer size: 132 Bytes
  • Size of remote file: 1.63 MB
assets/grid_zebra2horse.jpg ADDED

Git LFS Details

  • SHA256: b44b4aa4576be49289515f0aa9023dfd4424b3ba2476c66516b876dd83a06713
  • Pointer size: 132 Bytes
  • Size of remote file: 1.05 MB
assets/main.gif ADDED

Git LFS Details

  • SHA256: d1ebc380a461c4847beece13bdc9b5ea88312e8a8013f384eb8809109ff198fc
  • Pointer size: 132 Bytes
  • Size of remote file: 6.19 MB
assets/method.jpeg ADDED

Git LFS Details

  • SHA256: 8b1b4ea3608b9ad3797c4c7423bf2fd88e5e24f34fecbb00d3d2de22a99fd2ee
  • Pointer size: 132 Bytes
  • Size of remote file: 2.35 MB
assets/results_real.jpg ADDED

Git LFS Details

  • SHA256: 94095526e76b7a000ed56df15f7b5208c0f5a069b20b04fc9bcade14c54d92dc
  • Pointer size: 132 Bytes
  • Size of remote file: 1.48 MB
assets/results_syn.jpg ADDED

Git LFS Details

  • SHA256: 5731190e33098406995de563ca12bd6d2f84d9db725618a6d6580b4d1f2f0813
  • Pointer size: 132 Bytes
  • Size of remote file: 1.28 MB
assets/results_teaser.jpg ADDED
assets/test_images/cats/cat_1.png ADDED
assets/test_images/cats/cat_2.png ADDED
assets/test_images/cats/cat_3.png ADDED
assets/test_images/cats/cat_4.png ADDED
assets/test_images/cats/cat_5.png ADDED
assets/test_images/cats/cat_6.png ADDED
assets/test_images/cats/cat_7.png ADDED
assets/test_images/cats/cat_8.png ADDED
assets/test_images/cats/cat_9.png ADDED
assets/test_images/dogs/dog_1.png ADDED
assets/test_images/dogs/dog_2.png ADDED
assets/test_images/dogs/dog_3.png ADDED
assets/test_images/dogs/dog_4.png ADDED
assets/test_images/dogs/dog_5.png ADDED
assets/test_images/dogs/dog_6.png ADDED
assets/test_images/dogs/dog_7.png ADDED
assets/test_images/dogs/dog_8.png ADDED
assets/test_images/dogs/dog_9.png ADDED
environment.yml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: pix2pix-zero
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ - defaults
6
+ dependencies:
7
+ - pip
8
+ - pytorch-cuda=11.6
9
+ - torchvision
10
+ - pytorch
11
+ - pip:
12
+ - accelerate
13
+ - diffusers
14
+ - einops
15
+ - gradio
16
+ - ipython
17
+ - numpy
18
+ - opencv-python-headless
19
+ - pillow
20
+ - psutil
21
+ - tqdm
22
+ - transformers
23
+ - salesforce-lavis
output/test_cat/edit/cat_1.png ADDED
output/test_cat/inversion/cat_1.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fd7cd2554d2d695841ede6038f7906b50085841706a2f62429ee32c08a0dc85
3
+ size 66283
output/test_cat/prompt/cat_1.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a dog with his paws on top of a ball, painting
output/test_cat/reconstruction/cat_1.png ADDED
output/test_cat2/edit/cat_1.png ADDED
output/test_cat2/reconstruction/cat_1.png ADDED
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate
2
+ diffusers
3
+ einops
4
+ numpy
5
+ opencv-python-headless
6
+ pillow
7
+ psutil
8
+ transformers
9
+ tqdm
10
+ torch
11
+ salesforce-lavis
src/edit_real.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.ddim_inv import DDIMInversion
11
+ from utils.edit_directions import construct_direction
12
+ from utils.edit_pipeline import EditingPipeline
13
+
14
+
15
def collect_inversion_prompt_pairs(inversion, prompt):
    """Pair every inversion file with its prompt file.

    Args:
        inversion: Path to one ``.pt`` inversion file, or to a folder of them.
        prompt: Path to one prompt text file, or to a folder holding a
            ``.txt`` file with the same base name as each inversion file.

    Returns:
        A list of ``(inversion_path, prompt_path)`` tuples. In folder mode the
        list is sorted by inversion path.

    Raises:
        ValueError: If exactly one of the two arguments is a folder.
    """
    # Imported locally: the file-level imports do not bring ``glob`` in.
    from glob import glob

    inversion_is_dir = os.path.isdir(inversion)
    if inversion_is_dir != os.path.isdir(prompt):
        raise ValueError("If the inversion is a folder, the prompt should also be a folder")
    if not inversion_is_dir:
        return [(inversion, prompt)]

    pairs = []
    for inv_path in sorted(glob(os.path.join(inversion, "*.pt"))):
        stem = os.path.splitext(os.path.basename(inv_path))[0]
        pairs.append((inv_path, os.path.join(prompt, stem + ".txt")))
    return pairs


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--inversion', required=True)
    parser.add_argument('--prompt', type=str, required=True)
    parser.add_argument('--task_name', type=str, default='cat2dog')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--xa_guidance', default=0.1, type=float)
    parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
    parser.add_argument('--use_float_16', action='store_true')

    args = parser.parse_args()

    # make the output folders
    for subfolder in ("edit", "reconstruction"):
        os.makedirs(os.path.join(args.results_folder, subfolder), exist_ok=True)

    torch_dtype = torch.float16 if args.use_float_16 else torch.float32

    # if the inversion is a folder, the prompt should also be a folder
    l_pairs = collect_inversion_prompt_pairs(args.inversion, args.prompt)

    # Make the editing pipeline
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # The edit direction depends only on the task name, so build it once.
    edit_dir = construct_direction(args.task_name)

    for inv_path, prompt_path in l_pairs:
        with open(prompt_path) as f:
            prompt_str = f.read().strip()
        rec_pil, edit_pil = pipe(
            prompt_str,
            num_inference_steps=args.num_ddim_steps,
            x_in=torch.load(inv_path).unsqueeze(0),
            edit_dir=edit_dir,
            guidance_amount=args.xa_guidance,
            guidance_scale=args.negative_guidance_scale,
            negative_prompt=prompt_str,  # use the unedited prompt for the negative prompt
        )

        # Name the outputs after the current inversion file so that folder
        # mode writes one result per input instead of overwriting one file.
        bname = os.path.basename(inv_path).split(".")[0]
        edit_pil[0].save(os.path.join(args.results_folder, f"edit/{bname}.png"))
        rec_pil[0].save(os.path.join(args.results_folder, f"reconstruction/{bname}.png"))
src/edit_synthetic.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_directions import construct_direction
11
+ from utils.edit_pipeline import EditingPipeline
12
+
13
+
14
def build_arg_parser():
    """Build the command-line parser for synthetic-image editing.

    Returns:
        An ``argparse.ArgumentParser`` with the script's options.
        ``--random_seed`` is parsed as an integer so an invalid seed is
        rejected while parsing the command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--prompt_str', type=str, required=True)
    parser.add_argument('--random_seed', type=int, default=0)
    parser.add_argument('--task_name', type=str, default='cat2dog')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--xa_guidance', default=0.15, type=float)
    parser.add_argument('--negative_guidance_scale', default=5.0, type=float)
    parser.add_argument('--use_float_16', action='store_true')
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    os.makedirs(args.results_folder, exist_ok=True)

    torch_dtype = torch.float16 if args.use_float_16 else torch.float32

    # make the input noise map
    torch.cuda.manual_seed(args.random_seed)
    x = torch.randn((1, 4, 64, 64), device="cuda")

    # Make the editing pipeline
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    rec_pil, edit_pil = pipe(
        args.prompt_str,
        num_inference_steps=args.num_ddim_steps,
        x_in=x,
        edit_dir=construct_direction(args.task_name),
        guidance_amount=args.xa_guidance,
        guidance_scale=args.negative_guidance_scale,
        negative_prompt="",  # use the empty string for the negative prompt
    )

    edit_pil[0].save(os.path.join(args.results_folder, "edit.png"))
    rec_pil[0].save(os.path.join(args.results_folder, "reconstruction.png"))
src/inversion.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from lavis.models import load_model_and_preprocess
10
+
11
+ from utils.ddim_inv import DDIMInversion
12
+ from utils.scheduler import DDIMInverseScheduler
13
+
14
def collect_image_paths(input_image):
    """Return the list of images to invert.

    Args:
        input_image: Path to a single image, or to a folder of ``.png`` files.

    Returns:
        The sorted ``.png`` paths inside the folder when ``input_image`` is a
        folder, otherwise a one-element list holding ``input_image``.
    """
    # Imported locally: the file-level imports do not bring ``glob`` in.
    from glob import glob

    if os.path.isdir(input_image):
        return sorted(glob(os.path.join(input_image, "*.png")))
    return [input_image]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    # NOTE(review): the previous default, 'assets/test_images/cat_a.png', does
    # not appear in this commit's file list; cats/cat_1.png does.
    parser.add_argument('--input_image', type=str, default='assets/test_images/cats/cat_1.png')
    parser.add_argument('--results_folder', type=str, default='output/test_cat')
    parser.add_argument('--num_ddim_steps', type=int, default=50)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    parser.add_argument('--use_float_16', action='store_true')
    args = parser.parse_args()

    # make the output folders
    for subfolder in ("inversion", "prompt"):
        os.makedirs(os.path.join(args.results_folder, subfolder), exist_ok=True)

    torch_dtype = torch.float16 if args.use_float_16 else torch.float32

    # load the BLIP model
    model_blip, vis_processors, _ = load_model_and_preprocess(name="blip_caption", model_type="base_coco", is_eval=True, device=torch.device("cuda"))
    # make the DDIM inversion pipeline
    pipe = DDIMInversion.from_pretrained(args.model_path, torch_dtype=torch_dtype).to("cuda")
    pipe.scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)

    # if the input is a folder, collect all the images as a list
    l_img_paths = collect_image_paths(args.input_image)

    for img_path in l_img_paths:
        # Use the current image, not the CLI argument, so that folder mode
        # processes each file and writes one result per image.
        bname = os.path.basename(img_path).split(".")[0]
        img = Image.open(img_path).resize((512, 512), Image.Resampling.LANCZOS)
        # generate the caption
        _image = vis_processors["eval"](img).unsqueeze(0).cuda()
        prompt_str = model_blip.generate({"image": _image})[0]
        x_inv, x_inv_image, x_dec_img = pipe(
            prompt_str,
            guidance_scale=1,
            num_inversion_steps=args.num_ddim_steps,
            img=img,
            torch_dtype=torch_dtype,
        )
        # save the inversion
        torch.save(x_inv[0], os.path.join(args.results_folder, f"inversion/{bname}.pt"))
        # save the prompt string
        with open(os.path.join(args.results_folder, f"prompt/{bname}.txt"), "w") as f:
            f.write(prompt_str)
src/make_edit_direction.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, pdb
2
+
3
+ import argparse
4
+ import numpy as np
5
+ import torch
6
+ import requests
7
+ from PIL import Image
8
+
9
+ from diffusers import DDIMScheduler
10
+ from utils.edit_pipeline import EditingPipeline
11
+
12
+
13
+ ## convert sentences to sentence embeddings
14
def load_sentence_embeddings(l_sentences, tokenizer, text_encoder, device="cuda"):
    """Encode each sentence and return the mean of the embeddings.

    Args:
        l_sentences: Non-empty sequence of sentences to encode.
        tokenizer: Callable tokenizer exposing ``model_max_length``. It is
            called once per sentence with max-length padding and truncation,
            and must return an object with an ``input_ids`` tensor.
        text_encoder: Callable taking the token ids and ``attention_mask``;
            the first element of its return value is used as the embedding.
        device: Device the token ids are moved to before encoding.

    Returns:
        The per-sentence embeddings concatenated along dim 0, averaged over
        that dim, with a leading dimension of size 1 added back.

    Raises:
        ValueError: If ``l_sentences`` is empty.
    """
    if not l_sentences:
        raise ValueError("l_sentences must contain at least one sentence")

    with torch.no_grad():
        l_embeddings = []
        for sent in l_sentences:
            text_inputs = tokenizer(
                sent,
                padding="max_length",
                max_length=tokenizer.model_max_length,
                truncation=True,
                return_tensors="pt",
            )
            text_input_ids = text_inputs.input_ids
            prompt_embeds = text_encoder(text_input_ids.to(device), attention_mask=None)[0]
            l_embeddings.append(prompt_embeds)
        # torch.cat is available in every PyTorch release; the
        # torch.concatenate alias only exists in newer ones.
        return torch.cat(l_embeddings, dim=0).mean(dim=0).unsqueeze(0)
29
+
30
+
31
def embedding_output_path(sentence_file, output_folder):
    """Return the ``.pt`` path inside ``output_folder`` for a sentence file.

    The file extension is removed with ``os.path.splitext``. ``str.strip``
    removes a set of characters from both ends rather than a suffix, so
    ``"cat.txt".strip(".txt")`` yields ``"ca"``.
    """
    stem = os.path.splitext(os.path.basename(sentence_file))[0]
    return os.path.join(output_folder, stem + ".pt")


def read_sentences(sentence_file):
    """Return the stripped, non-blank lines of ``sentence_file``."""
    with open(sentence_file, "r") as f:
        stripped = (line.strip() for line in f)
        # Blank lines would otherwise be embedded as empty sentences and
        # pulled into the mean.
        return [sent for sent in stripped if sent]


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--file_source_sentences', required=True)
    parser.add_argument('--file_target_sentences', required=True)
    parser.add_argument('--output_folder', required=True)
    parser.add_argument('--model_path', type=str, default='CompVis/stable-diffusion-v1-4')
    args = parser.parse_args()

    # torch.save does not create missing parent folders.
    os.makedirs(args.output_folder, exist_ok=True)

    # load the model
    pipe = EditingPipeline.from_pretrained(args.model_path, torch_dtype=torch.float16).to("cuda")

    # The source and target files go through the same steps.
    jobs = (
        ("source", args.file_source_sentences),
        ("target", args.file_target_sentences),
    )
    for label, sentence_file in jobs:
        outf = embedding_output_path(sentence_file, args.output_folder)
        if os.path.exists(outf):
            print(f"Skipping {label} file {outf} as it already exists")
            continue
        l_sents = read_sentences(sentence_file)
        mean_emb = load_sentence_embeddings(l_sents, pipe.tokenizer, pipe.text_encoder, device="cuda")
        print(mean_emb.shape)
        torch.save(mean_emb, outf)