Toraong Bingsu committed on
Commit 8b21943 · 0 Parent(s)

Duplicate from Bingsu/color_textual_inversion


Co-authored-by: Dowon Hwang <[email protected]>

Files changed (9)
  1. .gitignore +173 -0
  2. LICENSE.md +22 -0
  3. README.md +11 -0
  4. app.py +128 -0
  5. info.txt +7 -0
  6. pdm.lock +0 -0
  7. pyproject.toml +40 -0
  8. requirements.txt +9 -0
  9. textual_inversion.py +769 -0
.gitignore ADDED
@@ -0,0 +1,173 @@
+ # Created by https://www.toptal.com/developers/gitignore/api/python
+ # Edit at https://www.toptal.com/developers/gitignore?templates=python
+
+ ### Python ###
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ ### Python Patch ###
+ # Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration
+ poetry.toml
+
+
+ # End of https://www.toptal.com/developers/gitignore/api/python
+ dataset/
+ *.pt
LICENSE.md ADDED
@@ -0,0 +1,22 @@
+
+ The MIT License (MIT)
+
+ Copyright (c) 2022 Bingsu
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: color_textual_inversion
+ emoji: 🖌️
+ sdk: streamlit
+ python_version: 3.9
+ sdk_version: 1.10.0
+ app_file: app.py
+ duplicated_from: Bingsu/color_textual_inversion
+ ---
+
+ # color_textual_inversion
app.py ADDED
@@ -0,0 +1,128 @@
+ from __future__ import annotations
+
+ import shlex
+ import subprocess
+ from pathlib import Path
+ from tempfile import TemporaryDirectory
+ from textwrap import dedent
+
+ import numpy as np
+ import streamlit as st
+ import torch
+ from PIL import Image
+ from transformers import CLIPTokenizer
+
+
+ def hex_to_rgb(s: str) -> tuple[int, int, int]:
+     value = s.lstrip("#")
+     return (int(value[:2], 16), int(value[2:4], 16), int(value[4:6], 16))
+
+
+ st.header("Color Textual Inversion")
+ with st.expander(label="info"):
+     with open("info.txt", "r", encoding="utf-8") as f:
+         st.markdown(f.read())
+
+ duplicate_button = """<a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/Bingsu/color_textual_inversion?duplicate=true"><img style="margin: 0" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a>"""
+ st.markdown(duplicate_button, unsafe_allow_html=True)
+
+ col1, col2 = st.columns([15, 85])
+ color = col1.color_picker("Pick a color", "#00f900")
+ col2.text_input("", color, disabled=True)
+
+ emb_name = st.text_input("Embedding name", color.lstrip("#").upper())
+ init_token = st.text_input("Initializer token", "init token name")
+ rgb = hex_to_rgb(color)
+
+ img_array = np.zeros((128, 128, 3), dtype=np.uint8)
+ for i in range(3):
+     img_array[..., i] = rgb[i]
+
+ dataset_temp = TemporaryDirectory(prefix="dataset_", dir=".")
+ dataset_path = Path(dataset_temp.name)
+ output_temp = TemporaryDirectory(prefix="output_", dir=".")
+ output_path = Path(output_temp.name)
+
+ img_path = dataset_path / f"{emb_name}.png"
+ Image.fromarray(img_array).save(img_path)
+
+ with st.sidebar:
+     model_name = st.text_input("Model name", "Linaqruf/anything-v3.0")
+     steps = st.slider("Steps", 1, 100, 30, step=1)
+     learning_rate = st.text_input("Learning rate", "0.005")
+     learning_rate = float(learning_rate)
+
+ tokenizer = CLIPTokenizer.from_pretrained(model_name, subfolder="tokenizer")
+
+ # case 1: init_token is not a single token
+ token = tokenizer.tokenize(init_token)
+ if len(token) > 1:
+     st.warning("Initializer token must be a single token")
+     st.stop()
+
+ # case 2: init_token already exists in the tokenizer
+ num_added_tokens = tokenizer.add_tokens(emb_name)
+ if num_added_tokens == 0:
+     st.warning(f"The tokenizer already contains the token {emb_name}")
+     st.stop()
+
+ cmd = """
+ accelerate launch textual_inversion.py \
+ --pretrained_model_name_or_path={model_name} \
+ --train_data_dir={dataset_path} \
+ --learnable_property="style" \
+ --placeholder_token="{emb_name}" \
+ --initializer_token="{init}" \
+ --resolution=128 \
+ --train_batch_size=1 \
+ --repeats=1 \
+ --gradient_accumulation_steps=1 \
+ --max_train_steps={steps} \
+ --learning_rate={lr} \
+ --output_dir={output_path} \
+ --only_save_embeds
+ """.strip()
+
+ cmd = dedent(cmd).format(
+     model_name=model_name,
+     dataset_path=dataset_path.as_posix(),
+     emb_name=emb_name,
+     init=init_token,
+     steps=steps,
+     lr=learning_rate,
+     output_path=output_path.as_posix(),
+ )
+ cmd = shlex.split(cmd)
+
+ result_path = output_path / "learned_embeds.bin"
+ captured = ""
+
+ start_button = st.button("Start")
+ download_button = st.empty()
+
+ if start_button:
+     with st.spinner("Training..."):
+         placeholder = st.empty()
+         p = subprocess.Popen(
+             cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, encoding="utf-8"
+         )
+
+         while line := p.stderr.readline():
+             captured += line
+             placeholder.code(captured, language="bash")
+
+     if not result_path.exists():
+         st.stop()
+
+     # fix unknown file volume bug
+     trained_emb = torch.load(result_path, map_location="cpu")
+     for k, v in trained_emb.items():
+         trained_emb[k] = torch.from_numpy(v.numpy())
+     torch.save(trained_emb, result_path)
+
+     file = result_path.read_bytes()
+     download_button.download_button(f"Download {emb_name}.pt", file, f"{emb_name}.pt")
+     st.download_button(f"Download {emb_name}.pt ", file, f"{emb_name}.pt")
+
+     dataset_temp.cleanup()
+     output_temp.cleanup()
info.txt ADDED
@@ -0,0 +1,7 @@
+ Create an embedding that represents a color code.
+
+ Textual inversion training is performed using only a single flat-color image.
+
+ This idea is from the arcalive AI image channel: [Let's create and use the color codes we want](https://arca.live/b/aiart/64702219).
+
+ However, this Space uses the Hugging Face diffusers implementation, so the results differ from those of the webui. Please keep this in mind.
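For reference, here is a minimal sketch (not part of this commit) of how the learned_embeds.bin produced by the app could be loaded into a diffusers pipeline. The model name, file path, and prompt below are placeholder assumptions, not values prescribed by this Space.

import torch
from diffusers import StableDiffusionPipeline

# Placeholder model and file names; substitute whatever you actually trained with and downloaded.
pipe = StableDiffusionPipeline.from_pretrained("Linaqruf/anything-v3.0")

# The saved file is a dict mapping the placeholder token to its learned embedding tensor.
learned = torch.load("learned_embeds.bin", map_location="cpu")
token, embedding = next(iter(learned.items()))

# Register the placeholder token and copy its trained embedding into the text encoder.
pipe.tokenizer.add_tokens(token)
token_id = pipe.tokenizer.convert_tokens_to_ids(token)
pipe.text_encoder.resize_token_embeddings(len(pipe.tokenizer))
pipe.text_encoder.get_input_embeddings().weight.data[token_id] = embedding

image = pipe(f"a painting of a mountain, art by {token}").images[0]
image.save("sample.png")

Newer diffusers releases also provide a load_textual_inversion helper on the pipeline that performs the same registration in one call.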
pdm.lock ADDED
The diff for this file is too large to render. See raw diff
 
pyproject.toml ADDED
@@ -0,0 +1,40 @@
+ [project]
+ name = "color-textual-inversion"
+ version = "0.1.3"
+ description = ""
+ authors = [
+     {name = "Bingsu", email = "[email protected]"},
+ ]
+ dependencies = [
+     "torch",
+     "torchvision",
+     "accelerate",
+     "ftfy",
+     "tensorboard",
+     "modelcards",
+     "transformers>=4.21.0",
+     "diffusers",
+     "streamlit==1.10.0",
+ ]
+ license = {text = "MIT"}
+ requires-python = ">=3.9"
+
+ [tool]
+ [tool.pdm]
+ [tool.pdm.dev-dependencies]
+ dev = [
+     "black>=22.10.0",
+     "isort>=5.10.1",
+     "mypy>=0.991",
+     "flake8-bugbear>=22.12.6",
+     "ipywidgets>=8.0.3",
+ ]
+
+ [tool.pdm.scripts]
+ st = "streamlit run app.py"
+ black = "black ."
+ isort = "isort ."
+ format = {composite = ["isort", "black"]}
+
+ [tool.isort]
+ profile = "black"
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch
+ torchvision
+ accelerate
+ ftfy
+ tensorboard
+ modelcards
+ transformers>=4.21.0
+ diffusers
+ streamlit==1.10.0
textual_inversion.py ADDED
@@ -0,0 +1,769 @@
+ import argparse
+ import itertools
+ import math
+ import os
+ import random
+ from pathlib import Path
+ from typing import Optional
+
+ import numpy as np
+ import PIL
+ import torch
+ import torch.nn.functional as F
+ import torch.utils.checkpoint
+ from accelerate import Accelerator
+ from accelerate.logging import get_logger
+ from accelerate.utils import set_seed
+ from diffusers import (
+     AutoencoderKL,
+     DDPMScheduler,
+     PNDMScheduler,
+     StableDiffusionPipeline,
+     UNet2DConditionModel,
+ )
+ from diffusers.optimization import get_scheduler
+ from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
+
+ # from diffusers.utils import check_min_version
+ from huggingface_hub import HfFolder, Repository, whoami
+
+ # TODO: remove and import from diffusers.utils when the new version of diffusers is released
+ from packaging import version
+ from PIL import Image
+ from torch.utils.data import Dataset
+ from torchvision import transforms
+ from tqdm.auto import tqdm
+ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+ if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
+     PIL_INTERPOLATION = {
+         "linear": PIL.Image.Resampling.BILINEAR,
+         "bilinear": PIL.Image.Resampling.BILINEAR,
+         "bicubic": PIL.Image.Resampling.BICUBIC,
+         "lanczos": PIL.Image.Resampling.LANCZOS,
+         "nearest": PIL.Image.Resampling.NEAREST,
+     }
+ else:
+     PIL_INTERPOLATION = {
+         "linear": PIL.Image.LINEAR,
+         "bilinear": PIL.Image.BILINEAR,
+         "bicubic": PIL.Image.BICUBIC,
+         "lanczos": PIL.Image.LANCZOS,
+         "nearest": PIL.Image.NEAREST,
+     }
+ # ------------------------------------------------------------------------------
+
+
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
+ # check_min_version("0.10.0.dev0")
+
+
+ logger = get_logger(__name__)
+
+
+ def save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path):
+     logger.info("Saving embeddings")
+     learned_embeds = (
+         accelerator.unwrap_model(text_encoder)
+         .get_input_embeddings()
+         .weight[placeholder_token_id]
+     )
+     learned_embeds_dict = {args.placeholder_token: learned_embeds.detach().cpu()}
+     torch.save(learned_embeds_dict, save_path)
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description="Simple example of a training script.")
+     parser.add_argument(
+         "--save_steps",
+         type=int,
+         default=500,
+         help="Save learned_embeds.bin every X updates steps.",
+     )
+     parser.add_argument(
+         "--only_save_embeds",
+         action="store_true",
+         default=False,
+         help="Save only the embeddings for the new concept.",
+     )
+     parser.add_argument(
+         "--pretrained_model_name_or_path",
+         type=str,
+         default=None,
+         required=True,
+         help="Path to pretrained model or model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--revision",
+         type=str,
+         default=None,
+         required=False,
+         help="Revision of pretrained model identifier from huggingface.co/models.",
+     )
+     parser.add_argument(
+         "--tokenizer_name",
+         type=str,
+         default=None,
+         help="Pretrained tokenizer name or path if not the same as model_name",
+     )
+     parser.add_argument(
+         "--train_data_dir",
+         type=str,
+         default=None,
+         required=True,
+         help="A folder containing the training data.",
+     )
+     parser.add_argument(
+         "--placeholder_token",
+         type=str,
+         default=None,
+         required=True,
+         help="A token to use as a placeholder for the concept.",
+     )
+     parser.add_argument(
+         "--initializer_token",
+         type=str,
+         default=None,
+         required=True,
+         help="A token to use as initializer word.",
+     )
+     parser.add_argument(
+         "--learnable_property",
+         type=str,
+         default="object",
+         help="Choose between 'object' and 'style'",
+     )
+     parser.add_argument(
+         "--repeats",
+         type=int,
+         default=100,
+         help="How many times to repeat the training data.",
+     )
+     parser.add_argument(
+         "--output_dir",
+         type=str,
+         default="text-inversion-model",
+         help="The output directory where the model predictions and checkpoints will be written.",
+     )
+     parser.add_argument(
+         "--seed", type=int, default=None, help="A seed for reproducible training."
+     )
+     parser.add_argument(
+         "--resolution",
+         type=int,
+         default=512,
+         help=(
+             "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+             " resolution"
+         ),
+     )
+     parser.add_argument(
+         "--center_crop",
+         action="store_true",
+         help="Whether to center crop images before resizing to resolution",
+     )
+     parser.add_argument(
+         "--train_batch_size",
+         type=int,
+         default=16,
+         help="Batch size (per device) for the training dataloader.",
+     )
+     parser.add_argument("--num_train_epochs", type=int, default=100)
+     parser.add_argument(
+         "--max_train_steps",
+         type=int,
+         default=5000,
+         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+     )
+     parser.add_argument(
+         "--gradient_accumulation_steps",
+         type=int,
+         default=1,
+         help="Number of updates steps to accumulate before performing a backward/update pass.",
+     )
+     parser.add_argument(
+         "--learning_rate",
+         type=float,
+         default=1e-4,
+         help="Initial learning rate (after the potential warmup period) to use.",
+     )
+     parser.add_argument(
+         "--scale_lr",
+         action="store_true",
+         default=True,
+         help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
+     )
+     parser.add_argument(
+         "--lr_scheduler",
+         type=str,
+         default="constant",
+         help=(
+             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
+             ' "constant", "constant_with_warmup"]'
+         ),
+     )
+     parser.add_argument(
+         "--lr_warmup_steps",
+         type=int,
+         default=500,
+         help="Number of steps for the warmup in the lr scheduler.",
+     )
+     parser.add_argument(
+         "--adam_beta1",
+         type=float,
+         default=0.9,
+         help="The beta1 parameter for the Adam optimizer.",
+     )
+     parser.add_argument(
+         "--adam_beta2",
+         type=float,
+         default=0.999,
+         help="The beta2 parameter for the Adam optimizer.",
+     )
+     parser.add_argument(
+         "--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use."
+     )
+     parser.add_argument(
+         "--adam_epsilon",
+         type=float,
+         default=1e-08,
+         help="Epsilon value for the Adam optimizer",
+     )
+     parser.add_argument(
+         "--push_to_hub",
+         action="store_true",
+         help="Whether or not to push the model to the Hub.",
+     )
+     parser.add_argument(
+         "--hub_token",
+         type=str,
+         default=None,
+         help="The token to use to push to the Model Hub.",
+     )
+     parser.add_argument(
+         "--hub_model_id",
+         type=str,
+         default=None,
+         help="The name of the repository to keep in sync with the local `output_dir`.",
+     )
+     parser.add_argument(
+         "--logging_dir",
+         type=str,
+         default="logs",
+         help=(
+             "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
+             " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
+         ),
+     )
+     parser.add_argument(
+         "--mixed_precision",
+         type=str,
+         default="no",
+         choices=["no", "fp16", "bf16"],
+         help=(
+             "Whether to use mixed precision. Choose"
+             "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
+             "and an Nvidia Ampere GPU."
+         ),
+     )
+     parser.add_argument(
+         "--local_rank",
+         type=int,
+         default=-1,
+         help="For distributed training: local_rank",
+     )
+
+     args = parser.parse_args()
+     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
+     if env_local_rank != -1 and env_local_rank != args.local_rank:
+         args.local_rank = env_local_rank
+
+     if args.train_data_dir is None:
+         raise ValueError("You must specify a train data directory.")
+
+     return args
+
+
+ imagenet_templates_small = [
+     "a photo of a {}",
+     "a rendering of a {}",
+     "a cropped photo of the {}",
+     "the photo of a {}",
+     "a photo of a clean {}",
+     "a photo of a dirty {}",
+     "a dark photo of the {}",
+     "a photo of my {}",
+     "a photo of the cool {}",
+     "a close-up photo of a {}",
+     "a bright photo of the {}",
+     "a cropped photo of a {}",
+     "a photo of the {}",
+     "a good photo of the {}",
+     "a photo of one {}",
+     "a close-up photo of the {}",
+     "a rendition of the {}",
+     "a photo of the clean {}",
+     "a rendition of a {}",
+     "a photo of a nice {}",
+     "a good photo of a {}",
+     "a photo of the nice {}",
+     "a photo of the small {}",
+     "a photo of the weird {}",
+     "a photo of the large {}",
+     "a photo of a cool {}",
+     "a photo of a small {}",
+ ]
+
+ imagenet_style_templates_small = [
+     "a painting of {}, art by *",
+     "a rendering of {}, art by *",
+     "a cropped painting of {}, art by *",
+     "the painting of {}, art by *",
+     "a clean painting of {}, art by *",
+     "a dirty painting of {}, art by *",
+     "a dark painting of {}, art by *",
+     "a picture of {}, art by *",
+     "a cool painting of {}, art by *",
+     "a close-up painting of {}, art by *",
+     "a bright painting of {}, art by *",
+     "a cropped painting of {}, art by *",
+     "a good painting of {}, art by *",
+     "a close-up painting of {}, art by *",
+     "a rendition of {}, art by *",
+     "a nice painting of {}, art by *",
+     "a small painting of {}, art by *",
+     "a weird painting of {}, art by *",
+     "a large painting of {}, art by *",
+ ]
+
+
+ class TextualInversionDataset(Dataset):
+     def __init__(
+         self,
+         data_root,
+         tokenizer,
+         learnable_property="object",  # [object, style]
+         size=512,
+         repeats=100,
+         interpolation="bicubic",
+         flip_p=0.5,
+         set="train",
+         placeholder_token="*",
+         center_crop=False,
+     ):
+         self.data_root = data_root
+         self.tokenizer = tokenizer
+         self.learnable_property = learnable_property
+         self.size = size
+         self.placeholder_token = placeholder_token
+         self.center_crop = center_crop
+         self.flip_p = flip_p
+
+         self.image_paths = [
+             os.path.join(self.data_root, file_path)
+             for file_path in os.listdir(self.data_root)
+         ]
+
+         self.num_images = len(self.image_paths)
+         self._length = self.num_images
+
+         if set == "train":
+             self._length = self.num_images * repeats
+
+         self.interpolation = {
+             "linear": PIL_INTERPOLATION["linear"],
+             "bilinear": PIL_INTERPOLATION["bilinear"],
+             "bicubic": PIL_INTERPOLATION["bicubic"],
+             "lanczos": PIL_INTERPOLATION["lanczos"],
+         }[interpolation]
+
+         self.templates = (
+             imagenet_style_templates_small
+             if learnable_property == "style"
+             else imagenet_templates_small
+         )
+         self.flip_transform = transforms.RandomHorizontalFlip(p=self.flip_p)
+
+     def __len__(self):
+         return self._length
+
+     def __getitem__(self, i):
+         example = {}
+         image = Image.open(self.image_paths[i % self.num_images])
+
+         if image.mode != "RGB":
+             image = image.convert("RGB")
+
+         placeholder_string = self.placeholder_token
+         text = random.choice(self.templates).format(placeholder_string)
+
+         example["input_ids"] = self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         ).input_ids[0]
+
+         # default to score-sde preprocessing
+         img = np.array(image).astype(np.uint8)
+
+         if self.center_crop:
+             crop = min(img.shape[0], img.shape[1])
+             h, w, = (
+                 img.shape[0],
+                 img.shape[1],
+             )
+             img = img[
+                 (h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2
+             ]
+
+         image = Image.fromarray(img)
+         image = image.resize((self.size, self.size), resample=self.interpolation)
+
+         image = self.flip_transform(image)
+         image = np.array(image).astype(np.uint8)
+         image = (image / 127.5 - 1.0).astype(np.float32)
+
+         example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
+         return example
+
+
+ def get_full_repo_name(
+     model_id: str, organization: Optional[str] = None, token: Optional[str] = None
+ ):
+     if token is None:
+         token = HfFolder.get_token()
+     if organization is None:
+         username = whoami(token)["name"]
+         return f"{username}/{model_id}"
+     else:
+         return f"{organization}/{model_id}"
+
+
+ def freeze_params(params):
+     for param in params:
+         param.requires_grad = False
+
+
+ def main():
+     args = parse_args()
+     # logging_dir = os.path.join(args.output_dir, args.logging_dir)
+
+     accelerator = Accelerator(
+         gradient_accumulation_steps=args.gradient_accumulation_steps,
+         mixed_precision=args.mixed_precision,
+     )
+
+     # If passed along, set the training seed now.
+     if args.seed is not None:
+         set_seed(args.seed)
+
+     # Handle the repository creation
+     if accelerator.is_main_process:
+         if args.push_to_hub:
+             if args.hub_model_id is None:
+                 repo_name = get_full_repo_name(
+                     Path(args.output_dir).name, token=args.hub_token
+                 )
+             else:
+                 repo_name = args.hub_model_id
+             repo = Repository(args.output_dir, clone_from=repo_name)
+
+             with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
+                 if "step_*" not in gitignore:
+                     gitignore.write("step_*\n")
+                 if "epoch_*" not in gitignore:
+                     gitignore.write("epoch_*\n")
+         elif args.output_dir is not None:
+             os.makedirs(args.output_dir, exist_ok=True)
+
+     # Load the tokenizer and add the placeholder token as an additional special token
+     if args.tokenizer_name:
+         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
+     elif args.pretrained_model_name_or_path:
+         tokenizer = CLIPTokenizer.from_pretrained(
+             args.pretrained_model_name_or_path, subfolder="tokenizer"
+         )
+
+     # Add the placeholder token in tokenizer
+     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+     if num_added_tokens == 0:
+         raise ValueError(
+             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+             " `placeholder_token` that is not already in the tokenizer."
+         )
+
+     # Convert the initializer_token, placeholder_token to ids
+     token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+     # Check if initializer_token is a single token or a sequence of tokens
+     if len(token_ids) > 1:
+         raise ValueError("The initializer token must be a single token.")
+
+     initializer_token_id = token_ids[0]
+     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+     # Load models and create wrapper for stable diffusion
+     text_encoder = CLIPTextModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="text_encoder",
+         revision=args.revision,
+     )
+     vae = AutoencoderKL.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="vae",
+         revision=args.revision,
+     )
+     unet = UNet2DConditionModel.from_pretrained(
+         args.pretrained_model_name_or_path,
+         subfolder="unet",
+         revision=args.revision,
+     )
+
+     # Resize the token embeddings as we are adding new special tokens to the tokenizer
+     text_encoder.resize_token_embeddings(len(tokenizer))
+
+     # Initialise the newly added placeholder token with the embeddings of the initializer token
+     token_embeds = text_encoder.get_input_embeddings().weight.data
+     token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+     # Freeze vae and unet
+     freeze_params(vae.parameters())
+     freeze_params(unet.parameters())
+     # Freeze all parameters except for the token embeddings in text encoder
+     params_to_freeze = itertools.chain(
+         text_encoder.text_model.encoder.parameters(),
+         text_encoder.text_model.final_layer_norm.parameters(),
+         text_encoder.text_model.embeddings.position_embedding.parameters(),
+     )
+     freeze_params(params_to_freeze)
+
+     if args.scale_lr:
+         args.learning_rate = (
+             args.learning_rate
+             * args.gradient_accumulation_steps
+             * args.train_batch_size
+             * accelerator.num_processes
+         )
+
+     # Initialize the optimizer
+     optimizer = torch.optim.AdamW(
+         text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
+         lr=args.learning_rate,
+         betas=(args.adam_beta1, args.adam_beta2),
+         weight_decay=args.adam_weight_decay,
+         eps=args.adam_epsilon,
+     )
+
+     noise_scheduler = DDPMScheduler.from_pretrained(
+         args.pretrained_model_name_or_path, subfolder="scheduler"
+     )
+
+     train_dataset = TextualInversionDataset(
+         data_root=args.train_data_dir,
+         tokenizer=tokenizer,
+         size=args.resolution,
+         placeholder_token=args.placeholder_token,
+         repeats=args.repeats,
+         learnable_property=args.learnable_property,
+         center_crop=args.center_crop,
+         set="train",
+     )
+     train_dataloader = torch.utils.data.DataLoader(
+         train_dataset, batch_size=args.train_batch_size, shuffle=True
+     )
+
+     # Scheduler and math around the number of training steps.
+     overrode_max_train_steps = False
+     num_update_steps_per_epoch = math.ceil(
+         len(train_dataloader) / args.gradient_accumulation_steps
+     )
+     if args.max_train_steps is None:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+         overrode_max_train_steps = True
+
+     lr_scheduler = get_scheduler(
+         args.lr_scheduler,
+         optimizer=optimizer,
+         num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
+         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
+     )
+
+     text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+         text_encoder, optimizer, train_dataloader, lr_scheduler
+     )
+
+     # Move vae and unet to device
+     vae.to(accelerator.device)
+     unet.to(accelerator.device)
+
+     # Keep vae and unet in eval mode as we don't train these
+     vae.eval()
+     unet.eval()
+
+     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
+     num_update_steps_per_epoch = math.ceil(
+         len(train_dataloader) / args.gradient_accumulation_steps
+     )
+     if overrode_max_train_steps:
+         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
+     # Afterwards we recalculate our number of training epochs
+     args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+
+     # We need to initialize the trackers we use, and also store our configuration.
+     # The trackers initialize automatically on the main process.
+     if accelerator.is_main_process:
+         accelerator.init_trackers("textual_inversion", config=vars(args))
+
+     # Train!
+     total_batch_size = (
+         args.train_batch_size
+         * accelerator.num_processes
+         * args.gradient_accumulation_steps
+     )
+
+     logger.info("***** Running training *****")
+     logger.info(f" Num examples = {len(train_dataset)}")
+     logger.info(f" Num Epochs = {args.num_train_epochs}")
+     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
+     logger.info(
+         f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
+     )
+     logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+     logger.info(f" Total optimization steps = {args.max_train_steps}")
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(
+         range(args.max_train_steps), disable=not accelerator.is_local_main_process
+     )
+     progress_bar.set_description("Steps")
+     global_step = 0
+
+     for epoch in range(args.num_train_epochs):
+         text_encoder.train()
+         for step, batch in enumerate(train_dataloader):
+             with accelerator.accumulate(text_encoder):
+                 # Convert images to latent space
+                 latents = (
+                     vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                 )
+                 latents = latents * 0.18215
+
+                 # Sample noise that we'll add to the latents
+                 noise = torch.randn(latents.shape).to(latents.device)
+                 bsz = latents.shape[0]
+                 # Sample a random timestep for each image
+                 timesteps = torch.randint(
+                     0,
+                     noise_scheduler.config.num_train_timesteps,
+                     (bsz,),
+                     device=latents.device,
+                 ).long()
+
+                 # Add noise to the latents according to the noise magnitude at each timestep
+                 # (this is the forward diffusion process)
+                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                 # Get the text embedding for conditioning
+                 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+
+                 # Predict the noise residual
+                 model_pred = unet(
+                     noisy_latents, timesteps, encoder_hidden_states
+                 ).sample
+
+                 # Get the target for loss depending on the prediction type
+                 if noise_scheduler.config.prediction_type == "epsilon":
+                     target = noise
+                 elif noise_scheduler.config.prediction_type == "v_prediction":
+                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
+                 else:
+                     raise ValueError(
+                         f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+                     )
+
+                 loss = (
+                     F.mse_loss(model_pred, target, reduction="none")
+                     .mean([1, 2, 3])
+                     .mean()
+                 )
+                 accelerator.backward(loss)
+
+                 # Zero out the gradients for all token embeddings except the newly added
+                 # embeddings for the concept, as we only want to optimize the concept embeddings
+                 if accelerator.num_processes > 1:
+                     grads = text_encoder.module.get_input_embeddings().weight.grad
+                 else:
+                     grads = text_encoder.get_input_embeddings().weight.grad
+                 # Get the index for tokens that we want to zero the grads for
+                 index_grads_to_zero = (
+                     torch.arange(len(tokenizer)) != placeholder_token_id
+                 )
+                 grads.data[index_grads_to_zero, :] = grads.data[
+                     index_grads_to_zero, :
+                 ].fill_(0)
+
+                 optimizer.step()
+                 lr_scheduler.step()
+                 optimizer.zero_grad()
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 global_step += 1
+                 if global_step % args.save_steps == 0:
+                     save_path = os.path.join(
+                         args.output_dir, f"learned_embeds-steps-{global_step}.bin"
+                     )
+                     save_progress(
+                         text_encoder, placeholder_token_id, accelerator, args, save_path
+                     )
+
+             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+             progress_bar.set_postfix(**logs)
+             accelerator.log(logs, step=global_step)
+
+             if global_step >= args.max_train_steps:
+                 break
+
+         accelerator.wait_for_everyone()
+
+     # Create the pipeline using the trained modules and save it.
+     if accelerator.is_main_process:
+         if args.push_to_hub and args.only_save_embeds:
+             logger.warn(
+                 "Enabling full model saving because --push_to_hub=True was specified."
+             )
+             save_full_model = True
+         else:
+             save_full_model = not args.only_save_embeds
+         if save_full_model:
+             pipeline = StableDiffusionPipeline(
+                 text_encoder=accelerator.unwrap_model(text_encoder),
+                 vae=vae,
+                 unet=unet,
+                 tokenizer=tokenizer,
+                 scheduler=PNDMScheduler.from_pretrained(
+                     args.pretrained_model_name_or_path, subfolder="scheduler"
+                 ),
+                 safety_checker=StableDiffusionSafetyChecker.from_pretrained(
+                     "CompVis/stable-diffusion-safety-checker"
+                 ),
+                 feature_extractor=CLIPFeatureExtractor.from_pretrained(
+                     "openai/clip-vit-base-patch32"
+                 ),
+             )
+             pipeline.save_pretrained(args.output_dir)
+         # Save the newly trained embeddings
+         save_path = os.path.join(args.output_dir, "learned_embeds.bin")
+         save_progress(text_encoder, placeholder_token_id, accelerator, args, save_path)
+
+         if args.push_to_hub:
+             repo.push_to_hub(
+                 commit_message="End of training", blocking=False, auto_lfs_prune=True
+             )
+
+     accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     main()