EAraid12 hysts HF staff commited on
Commit
73fd4a6
0 Parent(s):

Duplicate from lora-library/LoRA-DreamBooth-Training-UI

Browse files

Co-authored-by: hysts <[email protected]>

Files changed (18) hide show
  1. .gitattributes +34 -0
  2. .gitignore +165 -0
  3. .pre-commit-config.yaml +37 -0
  4. .style.yapf +5 -0
  5. LICENSE +21 -0
  6. README.md +15 -0
  7. app.py +76 -0
  8. app_inference.py +176 -0
  9. app_training.py +144 -0
  10. app_upload.py +100 -0
  11. constants.py +6 -0
  12. inference.py +94 -0
  13. requirements.txt +14 -0
  14. style.css +3 -0
  15. train_dreambooth_lora.py +1026 -0
  16. trainer.py +166 -0
  17. uploader.py +42 -0
  18. utils.py +59 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tflite filter=lfs diff=lfs merge=lfs -text
29
+ *.tgz filter=lfs diff=lfs merge=lfs -text
30
+ *.wasm filter=lfs diff=lfs merge=lfs -text
31
+ *.xz filter=lfs diff=lfs merge=lfs -text
32
+ *.zip filter=lfs diff=lfs merge=lfs -text
33
+ *.zst filter=lfs diff=lfs merge=lfs -text
34
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ training_data/
2
+ experiments/
3
+ wandb/
4
+
5
+
6
+ # Byte-compiled / optimized / DLL files
7
+ __pycache__/
8
+ *.py[cod]
9
+ *$py.class
10
+
11
+ # C extensions
12
+ *.so
13
+
14
+ # Distribution / packaging
15
+ .Python
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+ cover/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ .pybuilder/
81
+ target/
82
+
83
+ # Jupyter Notebook
84
+ .ipynb_checkpoints
85
+
86
+ # IPython
87
+ profile_default/
88
+ ipython_config.py
89
+
90
+ # pyenv
91
+ # For a library or package, you might want to ignore these files since the code is
92
+ # intended to run in multiple environments; otherwise, check them in:
93
+ # .python-version
94
+
95
+ # pipenv
96
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
97
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
98
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
99
+ # install all needed dependencies.
100
+ #Pipfile.lock
101
+
102
+ # poetry
103
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
104
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
105
+ # commonly ignored for libraries.
106
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
107
+ #poetry.lock
108
+
109
+ # pdm
110
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
111
+ #pdm.lock
112
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
113
+ # in version control.
114
+ # https://pdm.fming.dev/#use-with-ide
115
+ .pdm.toml
116
+
117
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
118
+ __pypackages__/
119
+
120
+ # Celery stuff
121
+ celerybeat-schedule
122
+ celerybeat.pid
123
+
124
+ # SageMath parsed files
125
+ *.sage.py
126
+
127
+ # Environments
128
+ .env
129
+ .venv
130
+ env/
131
+ venv/
132
+ ENV/
133
+ env.bak/
134
+ venv.bak/
135
+
136
+ # Spyder project settings
137
+ .spyderproject
138
+ .spyproject
139
+
140
+ # Rope project settings
141
+ .ropeproject
142
+
143
+ # mkdocs documentation
144
+ /site
145
+
146
+ # mypy
147
+ .mypy_cache/
148
+ .dmypy.json
149
+ dmypy.json
150
+
151
+ # Pyre type checker
152
+ .pyre/
153
+
154
+ # pytype static type analyzer
155
+ .pytype/
156
+
157
+ # Cython debug symbols
158
+ cython_debug/
159
+
160
+ # PyCharm
161
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
162
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
163
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
164
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
165
+ #.idea/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ exclude: train_dreambooth_lora.py
2
+ repos:
3
+ - repo: https://github.com/pre-commit/pre-commit-hooks
4
+ rev: v4.2.0
5
+ hooks:
6
+ - id: check-executables-have-shebangs
7
+ - id: check-json
8
+ - id: check-merge-conflict
9
+ - id: check-shebang-scripts-are-executable
10
+ - id: check-toml
11
+ - id: check-yaml
12
+ - id: double-quote-string-fixer
13
+ - id: end-of-file-fixer
14
+ - id: mixed-line-ending
15
+ args: ['--fix=lf']
16
+ - id: requirements-txt-fixer
17
+ - id: trailing-whitespace
18
+ - repo: https://github.com/myint/docformatter
19
+ rev: v1.4
20
+ hooks:
21
+ - id: docformatter
22
+ args: ['--in-place']
23
+ - repo: https://github.com/pycqa/isort
24
+ rev: 5.10.1
25
+ hooks:
26
+ - id: isort
27
+ - repo: https://github.com/pre-commit/mirrors-mypy
28
+ rev: v0.991
29
+ hooks:
30
+ - id: mypy
31
+ args: ['--ignore-missing-imports']
32
+ additional_dependencies: ['types-python-slugify']
33
+ - repo: https://github.com/google/yapf
34
+ rev: v0.32.0
35
+ hooks:
36
+ - id: yapf
37
+ args: ['--parallel', '--in-place']
.style.yapf ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ [style]
2
+ based_on_style = pep8
3
+ blank_line_before_nested_class_or_def = false
4
+ spaces_before_comment = 2
5
+ split_before_logical_operator = true
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 hysts
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: LoRA DreamBooth Training UI
3
+ emoji: ⚡
4
+ colorFrom: red
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 3.16.2
8
+ python_version: 3.10.9
9
+ app_file: app.py
10
+ pinned: false
11
+ license: mit
12
+ duplicated_from: lora-library/LoRA-DreamBooth-Training-UI
13
+ ---
14
+
15
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+ import torch
9
+
10
+ from app_inference import create_inference_demo
11
+ from app_training import create_training_demo
12
+ from app_upload import create_upload_demo
13
+ from inference import InferencePipeline
14
+ from trainer import Trainer
15
+
16
+ TITLE = '# LoRA DreamBooth Training UI'
17
+
18
+ ORIGINAL_SPACE_ID = 'lora-library/LoRA-DreamBooth-Training-UI'
19
+ SPACE_ID = os.getenv('SPACE_ID', ORIGINAL_SPACE_ID)
20
+ SHARED_UI_WARNING = f'''# Attention - This Space doesn't work in this shared UI. You can duplicate and use it with a paid private T4 GPU.
21
+
22
+ <center><a class="duplicate-button" style="display:inline-block" target="_blank" href="https://huggingface.co/spaces/{SPACE_ID}?duplicate=true"><img src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a></center>
23
+ '''
24
+
25
+ if os.getenv('SYSTEM') == 'spaces' and SPACE_ID != ORIGINAL_SPACE_ID:
26
+ SETTINGS = f'<a href="https://huggingface.co/spaces/{SPACE_ID}/settings">Settings</a>'
27
+ else:
28
+ SETTINGS = 'Settings'
29
+ CUDA_NOT_AVAILABLE_WARNING = f'''# Attention - Running on CPU.
30
+ <center>
31
+ You can assign a GPU in the {SETTINGS} tab if you are running this on HF Spaces.
32
+ "T4 small" is sufficient to run this demo.
33
+ </center>
34
+ '''
35
+
36
+ HF_TOKEN_NOT_SPECIFIED_WARNING = f'''# Attention - The environment variable `HF_TOKEN` is not specified. Please specify your Hugging Face token with write permission as the value of it.
37
+ <center>
38
+ You can check and create your Hugging Face tokens <a href="https://huggingface.co/settings/tokens" target="_blank">here</a>.
39
+ You can specify environment variables in the "Repository secrets" section of the {SETTINGS} tab.
40
+ </center>
41
+ '''
42
+
43
+ HF_TOKEN = os.getenv('HF_TOKEN')
44
+
45
+
46
+ def show_warning(warning_text: str) -> gr.Blocks:
47
+ with gr.Blocks() as demo:
48
+ with gr.Box():
49
+ gr.Markdown(warning_text)
50
+ return demo
51
+
52
+
53
+ pipe = InferencePipeline(HF_TOKEN)
54
+ trainer = Trainer(HF_TOKEN)
55
+
56
+ with gr.Blocks(css='style.css') as demo:
57
+ if os.getenv('IS_SHARED_UI'):
58
+ show_warning(SHARED_UI_WARNING)
59
+ if not torch.cuda.is_available():
60
+ show_warning(CUDA_NOT_AVAILABLE_WARNING)
61
+ if not HF_TOKEN:
62
+ show_warning(HF_TOKEN_NOT_SPECIFIED_WARNING)
63
+
64
+ gr.Markdown(TITLE)
65
+ with gr.Tabs():
66
+ with gr.TabItem('Train'):
67
+ create_training_demo(trainer, pipe)
68
+ with gr.TabItem('Test'):
69
+ create_inference_demo(pipe, HF_TOKEN)
70
+ with gr.TabItem('Upload'):
71
+ gr.Markdown('''
72
+ - You can use this tab to upload models later if you choose not to upload models in training time or if upload in training time failed.
73
+ ''')
74
+ create_upload_demo(HF_TOKEN)
75
+
76
+ demo.queue(max_size=1).launch(share=False)
app_inference.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import enum
6
+
7
+ import gradio as gr
8
+ from huggingface_hub import HfApi
9
+
10
+ from inference import InferencePipeline
11
+ from utils import find_exp_dirs
12
+
13
+ SAMPLE_MODEL_IDS = [
14
+ 'patrickvonplaten/lora_dreambooth_dog_example',
15
+ 'sayakpaul/sd-model-finetuned-lora-t4',
16
+ ]
17
+
18
+
19
+ class ModelSource(enum.Enum):
20
+ SAMPLE = 'Sample'
21
+ HUB_LIB = 'Hub (lora-library)'
22
+ LOCAL = 'Local'
23
+
24
+
25
+ class InferenceUtil:
26
+ def __init__(self, hf_token: str | None):
27
+ self.hf_token = hf_token
28
+
29
+ @staticmethod
30
+ def load_sample_lora_model_list():
31
+ return gr.update(choices=SAMPLE_MODEL_IDS, value=SAMPLE_MODEL_IDS[0])
32
+
33
+ def load_hub_lora_model_list(self) -> dict:
34
+ api = HfApi(token=self.hf_token)
35
+ choices = [
36
+ info.modelId for info in api.list_models(author='lora-library')
37
+ ]
38
+ return gr.update(choices=choices,
39
+ value=choices[0] if choices else None)
40
+
41
+ @staticmethod
42
+ def load_local_lora_model_list() -> dict:
43
+ choices = find_exp_dirs()
44
+ return gr.update(choices=choices,
45
+ value=choices[0] if choices else None)
46
+
47
+ def reload_lora_model_list(self, model_source: str) -> dict:
48
+ if model_source == ModelSource.SAMPLE.value:
49
+ return self.load_sample_lora_model_list()
50
+ elif model_source == ModelSource.HUB_LIB.value:
51
+ return self.load_hub_lora_model_list()
52
+ elif model_source == ModelSource.LOCAL.value:
53
+ return self.load_local_lora_model_list()
54
+ else:
55
+ raise ValueError
56
+
57
+ def load_model_info(self, lora_model_id: str) -> tuple[str, str]:
58
+ try:
59
+ card = InferencePipeline.get_model_card(lora_model_id,
60
+ self.hf_token)
61
+ except Exception:
62
+ return '', ''
63
+ base_model = getattr(card.data, 'base_model', '')
64
+ instance_prompt = getattr(card.data, 'instance_prompt', '')
65
+ return base_model, instance_prompt
66
+
67
+ def reload_lora_model_list_and_update_model_info(
68
+ self, model_source: str) -> tuple[dict, str, str]:
69
+ model_list_update = self.reload_lora_model_list(model_source)
70
+ model_list = model_list_update['choices']
71
+ model_info = self.load_model_info(model_list[0] if model_list else '')
72
+ return model_list_update, *model_info
73
+
74
+
75
+ def create_inference_demo(pipe: InferencePipeline,
76
+ hf_token: str | None = None) -> gr.Blocks:
77
+ app = InferenceUtil(hf_token)
78
+
79
+ with gr.Blocks() as demo:
80
+ with gr.Row():
81
+ with gr.Column():
82
+ with gr.Box():
83
+ model_source = gr.Radio(
84
+ label='Model Source',
85
+ choices=[_.value for _ in ModelSource],
86
+ value=ModelSource.SAMPLE.value)
87
+ reload_button = gr.Button('Reload Model List')
88
+ lora_model_id = gr.Dropdown(label='LoRA Model ID',
89
+ choices=SAMPLE_MODEL_IDS,
90
+ value=SAMPLE_MODEL_IDS[0])
91
+ with gr.Accordion(
92
+ label=
93
+ 'Model info (Base model and instance prompt used for training)',
94
+ open=False):
95
+ with gr.Row():
96
+ base_model_used_for_training = gr.Text(
97
+ label='Base model', interactive=False)
98
+ instance_prompt_used_for_training = gr.Text(
99
+ label='Instance prompt', interactive=False)
100
+ prompt = gr.Textbox(
101
+ label='Prompt',
102
+ max_lines=1,
103
+ placeholder='Example: "A picture of a sks dog in a bucket"'
104
+ )
105
+ alpha = gr.Slider(label='LoRA alpha',
106
+ minimum=0,
107
+ maximum=2,
108
+ step=0.05,
109
+ value=1)
110
+ seed = gr.Slider(label='Seed',
111
+ minimum=0,
112
+ maximum=100000,
113
+ step=1,
114
+ value=0)
115
+ with gr.Accordion('Other Parameters', open=False):
116
+ num_steps = gr.Slider(label='Number of Steps',
117
+ minimum=0,
118
+ maximum=100,
119
+ step=1,
120
+ value=25)
121
+ guidance_scale = gr.Slider(label='CFG Scale',
122
+ minimum=0,
123
+ maximum=50,
124
+ step=0.1,
125
+ value=7.5)
126
+
127
+ run_button = gr.Button('Generate')
128
+
129
+ gr.Markdown('''
130
+ - After training, you can press "Reload Model List" button to load your trained model names.
131
+ ''')
132
+ with gr.Column():
133
+ result = gr.Image(label='Result')
134
+
135
+ model_source.change(
136
+ fn=app.reload_lora_model_list_and_update_model_info,
137
+ inputs=model_source,
138
+ outputs=[
139
+ lora_model_id,
140
+ base_model_used_for_training,
141
+ instance_prompt_used_for_training,
142
+ ])
143
+ reload_button.click(
144
+ fn=app.reload_lora_model_list_and_update_model_info,
145
+ inputs=model_source,
146
+ outputs=[
147
+ lora_model_id,
148
+ base_model_used_for_training,
149
+ instance_prompt_used_for_training,
150
+ ])
151
+ lora_model_id.change(fn=app.load_model_info,
152
+ inputs=lora_model_id,
153
+ outputs=[
154
+ base_model_used_for_training,
155
+ instance_prompt_used_for_training,
156
+ ])
157
+ inputs = [
158
+ lora_model_id,
159
+ prompt,
160
+ alpha,
161
+ seed,
162
+ num_steps,
163
+ guidance_scale,
164
+ ]
165
+ prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
166
+ run_button.click(fn=pipe.run, inputs=inputs, outputs=result)
167
+ return demo
168
+
169
+
170
+ if __name__ == '__main__':
171
+ import os
172
+
173
+ hf_token = os.getenv('HF_TOKEN')
174
+ pipe = InferencePipeline(hf_token)
175
+ demo = create_inference_demo(pipe, hf_token)
176
+ demo.queue(max_size=10).launch(share=False)
app_training.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import os
6
+
7
+ import gradio as gr
8
+
9
+ from constants import UploadTarget
10
+ from inference import InferencePipeline
11
+ from trainer import Trainer
12
+
13
+
14
+ def create_training_demo(trainer: Trainer,
15
+ pipe: InferencePipeline | None = None) -> gr.Blocks:
16
+ with gr.Blocks() as demo:
17
+ with gr.Row():
18
+ with gr.Column():
19
+ with gr.Box():
20
+ gr.Markdown('Training Data')
21
+ instance_images = gr.Files(label='Instance images')
22
+ instance_prompt = gr.Textbox(label='Instance prompt',
23
+ max_lines=1)
24
+ gr.Markdown('''
25
+ - Upload images of the style you are planning on training on.
26
+ - For an instance prompt, use a unique, made up word to avoid collisions.
27
+ ''')
28
+ with gr.Box():
29
+ gr.Markdown('Output Model')
30
+ output_model_name = gr.Text(label='Name of your model',
31
+ max_lines=1)
32
+ delete_existing_model = gr.Checkbox(
33
+ label='Delete existing model of the same name',
34
+ value=False)
35
+ validation_prompt = gr.Text(label='Validation Prompt')
36
+ with gr.Box():
37
+ gr.Markdown('Upload Settings')
38
+ with gr.Row():
39
+ upload_to_hub = gr.Checkbox(
40
+ label='Upload model to Hub', value=True)
41
+ use_private_repo = gr.Checkbox(label='Private',
42
+ value=True)
43
+ delete_existing_repo = gr.Checkbox(
44
+ label='Delete existing repo of the same name',
45
+ value=False)
46
+ upload_to = gr.Radio(
47
+ label='Upload to',
48
+ choices=[_.value for _ in UploadTarget],
49
+ value=UploadTarget.LORA_LIBRARY.value)
50
+ gr.Markdown('''
51
+ - By default, trained models will be uploaded to [LoRA Library](https://huggingface.co/lora-library) (see [this example model](https://huggingface.co/lora-library/lora-dreambooth-sample-dog)).
52
+ - You can also choose "Personal Profile", in which case, the model will be uploaded to https://huggingface.co/{your_username}/{model_name}.
53
+ ''')
54
+
55
+ with gr.Box():
56
+ gr.Markdown('Training Parameters')
57
+ with gr.Row():
58
+ base_model = gr.Text(
59
+ label='Base Model',
60
+ value='stabilityai/stable-diffusion-2-1-base',
61
+ max_lines=1)
62
+ resolution = gr.Dropdown(choices=['512', '768'],
63
+ value='512',
64
+ label='Resolution')
65
+ num_training_steps = gr.Number(
66
+ label='Number of Training Steps', value=1000, precision=0)
67
+ learning_rate = gr.Number(label='Learning Rate', value=0.0001)
68
+ gradient_accumulation = gr.Number(
69
+ label='Number of Gradient Accumulation',
70
+ value=1,
71
+ precision=0)
72
+ seed = gr.Slider(label='Seed',
73
+ minimum=0,
74
+ maximum=100000,
75
+ step=1,
76
+ value=0)
77
+ fp16 = gr.Checkbox(label='FP16', value=True)
78
+ use_8bit_adam = gr.Checkbox(label='Use 8bit Adam', value=True)
79
+ checkpointing_steps = gr.Number(label='Checkpointing Steps',
80
+ value=100,
81
+ precision=0)
82
+ use_wandb = gr.Checkbox(label='Use W&B',
83
+ value=False,
84
+ interactive=bool(
85
+ os.getenv('WANDB_API_KEY')))
86
+ validation_epochs = gr.Number(label='Validation Epochs',
87
+ value=100,
88
+ precision=0)
89
+ gr.Markdown('''
90
+ - The base model must be a model that is compatible with [diffusers](https://github.com/huggingface/diffusers) library.
91
+ - It takes a few minutes to download the base model first.
92
+ - It will take about 8 minutes to train for 1000 steps with a T4 GPU.
93
+ - You may want to try a small number of steps first, like 1, to see if everything works fine in your environment.
94
+ - You can check the training status by pressing the "Open logs" button if you are running this on your Space.
95
+ - You need to set the environment variable `WANDB_API_KEY` if you'd like to use [W&B](https://wandb.ai/site). See [W&B documentation](https://docs.wandb.ai/guides/track/advanced/environment-variables).
96
+ - **Note:** Due to [this issue](https://github.com/huggingface/accelerate/issues/944), currently, training will not terminate properly if you use W&B.
97
+ ''')
98
+
99
+ remove_gpu_after_training = gr.Checkbox(
100
+ label='Remove GPU after training',
101
+ value=False,
102
+ interactive=bool(os.getenv('SPACE_ID')),
103
+ visible=False)
104
+ run_button = gr.Button('Start Training')
105
+
106
+ with gr.Box():
107
+ gr.Markdown('Output message')
108
+ output_message = gr.Markdown()
109
+
110
+ if pipe is not None:
111
+ run_button.click(fn=pipe.clear)
112
+ run_button.click(fn=trainer.run,
113
+ inputs=[
114
+ instance_images,
115
+ instance_prompt,
116
+ output_model_name,
117
+ delete_existing_model,
118
+ validation_prompt,
119
+ base_model,
120
+ resolution,
121
+ num_training_steps,
122
+ learning_rate,
123
+ gradient_accumulation,
124
+ seed,
125
+ fp16,
126
+ use_8bit_adam,
127
+ checkpointing_steps,
128
+ use_wandb,
129
+ validation_epochs,
130
+ upload_to_hub,
131
+ use_private_repo,
132
+ delete_existing_repo,
133
+ upload_to,
134
+ remove_gpu_after_training,
135
+ ],
136
+ outputs=output_message)
137
+ return demo
138
+
139
+
140
+ if __name__ == '__main__':
141
+ hf_token = os.getenv('HF_TOKEN')
142
+ trainer = Trainer(hf_token)
143
+ demo = create_training_demo(trainer)
144
+ demo.queue(max_size=1).launch(share=False)
app_upload.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ from __future__ import annotations
4
+
5
+ import pathlib
6
+
7
+ import gradio as gr
8
+ import slugify
9
+
10
+ from constants import UploadTarget
11
+ from uploader import Uploader
12
+ from utils import find_exp_dirs
13
+
14
+
15
+ class LoRAModelUploader(Uploader):
16
+ def upload_lora_model(
17
+ self,
18
+ folder_path: str,
19
+ repo_name: str,
20
+ upload_to: str,
21
+ private: bool,
22
+ delete_existing_repo: bool,
23
+ ) -> str:
24
+ if not folder_path:
25
+ raise ValueError
26
+ if not repo_name:
27
+ repo_name = pathlib.Path(folder_path).name
28
+ repo_name = slugify.slugify(repo_name)
29
+
30
+ if upload_to == UploadTarget.PERSONAL_PROFILE.value:
31
+ organization = ''
32
+ elif upload_to == UploadTarget.LORA_LIBRARY.value:
33
+ organization = 'lora-library'
34
+ else:
35
+ raise ValueError
36
+
37
+ return self.upload(folder_path,
38
+ repo_name,
39
+ organization=organization,
40
+ private=private,
41
+ delete_existing_repo=delete_existing_repo)
42
+
43
+
44
+ def load_local_lora_model_list() -> dict:
45
+ choices = find_exp_dirs(ignore_repo=True)
46
+ return gr.update(choices=choices, value=choices[0] if choices else None)
47
+
48
+
49
+ def create_upload_demo(hf_token: str | None) -> gr.Blocks:
50
+ uploader = LoRAModelUploader(hf_token)
51
+ model_dirs = find_exp_dirs(ignore_repo=True)
52
+
53
+ with gr.Blocks() as demo:
54
+ with gr.Box():
55
+ gr.Markdown('Local Models')
56
+ reload_button = gr.Button('Reload Model List')
57
+ model_dir = gr.Dropdown(
58
+ label='Model names',
59
+ choices=model_dirs,
60
+ value=model_dirs[0] if model_dirs else None)
61
+ with gr.Box():
62
+ gr.Markdown('Upload Settings')
63
+ with gr.Row():
64
+ use_private_repo = gr.Checkbox(label='Private', value=True)
65
+ delete_existing_repo = gr.Checkbox(
66
+ label='Delete existing repo of the same name', value=False)
67
+ upload_to = gr.Radio(label='Upload to',
68
+ choices=[_.value for _ in UploadTarget],
69
+ value=UploadTarget.LORA_LIBRARY.value)
70
+ model_name = gr.Textbox(label='Model Name')
71
+ upload_button = gr.Button('Upload')
72
+ gr.Markdown('''
73
+ - You can upload your trained model to your personal profile (i.e. https://huggingface.co/{your_username}/{model_name}) or to the public [LoRA Concepts Library](https://huggingface.co/lora-library) (i.e. https://huggingface.co/lora-library/{model_name}).
74
+ ''')
75
+ with gr.Box():
76
+ gr.Markdown('Output message')
77
+ output_message = gr.Markdown()
78
+
79
+ reload_button.click(fn=load_local_lora_model_list,
80
+ inputs=None,
81
+ outputs=model_dir)
82
+ upload_button.click(fn=uploader.upload_lora_model,
83
+ inputs=[
84
+ model_dir,
85
+ model_name,
86
+ upload_to,
87
+ use_private_repo,
88
+ delete_existing_repo,
89
+ ],
90
+ outputs=output_message)
91
+
92
+ return demo
93
+
94
+
95
+ if __name__ == '__main__':
96
+ import os
97
+
98
+ hf_token = os.getenv('HF_TOKEN')
99
+ demo = create_upload_demo(hf_token)
100
+ demo.queue(max_size=1).launch(share=False)
constants.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ import enum
2
+
3
+
4
+ class UploadTarget(enum.Enum):
5
+ PERSONAL_PROFILE = 'Personal Profile'
6
+ LORA_LIBRARY = 'LoRA Library'
inference.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import gc
4
+ import pathlib
5
+
6
+ import gradio as gr
7
+ import PIL.Image
8
+ import torch
9
+ from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
10
+ from huggingface_hub import ModelCard
11
+
12
+
13
+ class InferencePipeline:
14
+ def __init__(self, hf_token: str | None = None):
15
+ self.hf_token = hf_token
16
+ self.pipe = None
17
+ self.device = torch.device(
18
+ 'cuda:0' if torch.cuda.is_available() else 'cpu')
19
+ self.lora_model_id = None
20
+ self.base_model_id = None
21
+
22
+ def clear(self) -> None:
23
+ self.lora_model_id = None
24
+ self.base_model_id = None
25
+ del self.pipe
26
+ self.pipe = None
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+ @staticmethod
31
+ def check_if_model_is_local(lora_model_id: str) -> bool:
32
+ return pathlib.Path(lora_model_id).exists()
33
+
34
+ @staticmethod
35
+ def get_model_card(model_id: str,
36
+ hf_token: str | None = None) -> ModelCard:
37
+ if InferencePipeline.check_if_model_is_local(model_id):
38
+ card_path = (pathlib.Path(model_id) / 'README.md').as_posix()
39
+ else:
40
+ card_path = model_id
41
+ return ModelCard.load(card_path, token=hf_token)
42
+
43
+ @staticmethod
44
+ def get_base_model_info(lora_model_id: str,
45
+ hf_token: str | None = None) -> str:
46
+ card = InferencePipeline.get_model_card(lora_model_id, hf_token)
47
+ return card.data.base_model
48
+
49
+ def load_pipe(self, lora_model_id: str) -> None:
50
+ if lora_model_id == self.lora_model_id:
51
+ return
52
+ base_model_id = self.get_base_model_info(lora_model_id, self.hf_token)
53
+ if base_model_id != self.base_model_id:
54
+ if self.device.type == 'cpu':
55
+ pipe = DiffusionPipeline.from_pretrained(
56
+ base_model_id, use_auth_token=self.hf_token)
57
+ else:
58
+ pipe = DiffusionPipeline.from_pretrained(
59
+ base_model_id,
60
+ torch_dtype=torch.float16,
61
+ use_auth_token=self.hf_token)
62
+ pipe = pipe.to(self.device)
63
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(
64
+ pipe.scheduler.config)
65
+ self.pipe = pipe
66
+ self.pipe.unet.load_attn_procs( # type: ignore
67
+ lora_model_id, use_auth_token=self.hf_token)
68
+
69
+ self.lora_model_id = lora_model_id # type: ignore
70
+ self.base_model_id = base_model_id # type: ignore
71
+
72
+ def run(
73
+ self,
74
+ lora_model_id: str,
75
+ prompt: str,
76
+ lora_scale: float,
77
+ seed: int,
78
+ n_steps: int,
79
+ guidance_scale: float,
80
+ ) -> PIL.Image.Image:
81
+ if not torch.cuda.is_available():
82
+ raise gr.Error('CUDA is not available.')
83
+
84
+ self.load_pipe(lora_model_id)
85
+
86
+ generator = torch.Generator(device=self.device).manual_seed(seed)
87
+ out = self.pipe(
88
+ prompt,
89
+ num_inference_steps=n_steps,
90
+ guidance_scale=guidance_scale,
91
+ generator=generator,
92
+ cross_attention_kwargs={'scale': lora_scale},
93
+ ) # type: ignore
94
+ return out.images[0]
requirements.txt ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.15.0
2
+ bitsandbytes==0.36.0.post2
3
+ datasets==2.8.0
4
+ git+https://github.com/huggingface/diffusers@31be42209ddfdb69d9640a777b32e9b5c6259bf0#egg=diffusers
5
+ ftfy==6.1.1
6
+ gradio==3.16.2
7
+ huggingface-hub==0.12.0
8
+ Pillow==9.4.0
9
+ python-slugify==7.0.0
10
+ tensorboard==2.11.2
11
+ torch==1.13.1
12
+ torchvision==0.14.1
13
+ transformers==4.26.0
14
+ wandb==0.13.9
style.css ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ h1 {
2
+ text-align: center;
3
+ }
train_dreambooth_lora.py ADDED
@@ -0,0 +1,1026 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # coding=utf-8
3
+ #
4
+ # This file is adapted from https://github.com/huggingface/diffusers/blob/febaf863026bd014b7a14349336544fc109d0f57/examples/dreambooth/train_dreambooth_lora.py
5
+ # The original license is as below:
6
+ #
7
+ # Copyright 2022 The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+
20
+ import argparse
21
+ import hashlib
22
+ import logging
23
+ import math
24
+ import os
25
+ import warnings
26
+ from pathlib import Path
27
+ from typing import Optional
28
+
29
+ import numpy as np
30
+ import torch
31
+ import torch.nn.functional as F
32
+ import torch.utils.checkpoint
33
+ from torch.utils.data import Dataset
34
+
35
+ import datasets
36
+ import diffusers
37
+ import transformers
38
+ from accelerate import Accelerator
39
+ from accelerate.logging import get_logger
40
+ from accelerate.utils import set_seed
41
+ from diffusers import (
42
+ AutoencoderKL,
43
+ DDPMScheduler,
44
+ DiffusionPipeline,
45
+ DPMSolverMultistepScheduler,
46
+ UNet2DConditionModel,
47
+ )
48
+ from diffusers.loaders import AttnProcsLayers
49
+ from diffusers.models.cross_attention import LoRACrossAttnProcessor
50
+ from diffusers.optimization import get_scheduler
51
+ from diffusers.utils import check_min_version, is_wandb_available
52
+ from diffusers.utils.import_utils import is_xformers_available
53
+ from huggingface_hub import HfFolder, Repository, create_repo, whoami
54
+ from PIL import Image
55
+ from torchvision import transforms
56
+ from tqdm.auto import tqdm
57
+ from transformers import AutoTokenizer, PretrainedConfig
58
+
59
+
60
+ # Will error if the minimal version of diffusers is not installed. Remove at your own risks.
61
+ check_min_version("0.12.0.dev0")
62
+
63
+ logger = get_logger(__name__)
64
+
65
+
66
+ def save_model_card(repo_name, images=None, base_model=str, prompt=str, repo_folder=None):
67
+ img_str = ""
68
+ for i, image in enumerate(images):
69
+ image.save(os.path.join(repo_folder, f"image_{i}.png"))
70
+ img_str += f"![img_{i}](./image_{i}.png)\n"
71
+
72
+ yaml = f"""
73
+ ---
74
+ license: creativeml-openrail-m
75
+ base_model: {base_model}
76
+ tags:
77
+ - stable-diffusion
78
+ - stable-diffusion-diffusers
79
+ - text-to-image
80
+ - diffusers
81
+ - lora
82
+ inference: true
83
+ ---
84
+ """
85
+ model_card = f"""
86
+ # LoRA DreamBooth - {repo_name}
87
+
88
+ These are LoRA adaption weights for {repo_name}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
89
+ {img_str}
90
+ """
91
+ with open(os.path.join(repo_folder, "README.md"), "w") as f:
92
+ f.write(yaml + model_card)
93
+
94
+
95
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
96
+ text_encoder_config = PretrainedConfig.from_pretrained(
97
+ pretrained_model_name_or_path,
98
+ subfolder="text_encoder",
99
+ revision=revision,
100
+ )
101
+ model_class = text_encoder_config.architectures[0]
102
+
103
+ if model_class == "CLIPTextModel":
104
+ from transformers import CLIPTextModel
105
+
106
+ return CLIPTextModel
107
+ elif model_class == "RobertaSeriesModelWithTransformation":
108
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
109
+
110
+ return RobertaSeriesModelWithTransformation
111
+ else:
112
+ raise ValueError(f"{model_class} is not supported.")
113
+
114
+
115
+ def parse_args(input_args=None):
116
+ parser = argparse.ArgumentParser(description="Simple example of a training script.")
117
+ parser.add_argument(
118
+ "--pretrained_model_name_or_path",
119
+ type=str,
120
+ default=None,
121
+ required=True,
122
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
123
+ )
124
+ parser.add_argument(
125
+ "--revision",
126
+ type=str,
127
+ default=None,
128
+ required=False,
129
+ help="Revision of pretrained model identifier from huggingface.co/models.",
130
+ )
131
+ parser.add_argument(
132
+ "--tokenizer_name",
133
+ type=str,
134
+ default=None,
135
+ help="Pretrained tokenizer name or path if not the same as model_name",
136
+ )
137
+ parser.add_argument(
138
+ "--instance_data_dir",
139
+ type=str,
140
+ default=None,
141
+ required=True,
142
+ help="A folder containing the training data of instance images.",
143
+ )
144
+ parser.add_argument(
145
+ "--class_data_dir",
146
+ type=str,
147
+ default=None,
148
+ required=False,
149
+ help="A folder containing the training data of class images.",
150
+ )
151
+ parser.add_argument(
152
+ "--instance_prompt",
153
+ type=str,
154
+ default=None,
155
+ required=True,
156
+ help="The prompt with identifier specifying the instance",
157
+ )
158
+ parser.add_argument(
159
+ "--class_prompt",
160
+ type=str,
161
+ default=None,
162
+ help="The prompt to specify images in the same class as provided instance images.",
163
+ )
164
+ parser.add_argument(
165
+ "--validation_prompt",
166
+ type=str,
167
+ default=None,
168
+ help="A prompt that is used during validation to verify that the model is learning.",
169
+ )
170
+ parser.add_argument(
171
+ "--num_validation_images",
172
+ type=int,
173
+ default=4,
174
+ help="Number of images that should be generated during validation with `validation_prompt`.",
175
+ )
176
+ parser.add_argument(
177
+ "--validation_epochs",
178
+ type=int,
179
+ default=50,
180
+ help=(
181
+ "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
182
+ " `args.validation_prompt` multiple times: `args.num_validation_images`."
183
+ ),
184
+ )
185
+ parser.add_argument(
186
+ "--with_prior_preservation",
187
+ default=False,
188
+ action="store_true",
189
+ help="Flag to add prior preservation loss.",
190
+ )
191
+ parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
192
+ parser.add_argument(
193
+ "--num_class_images",
194
+ type=int,
195
+ default=100,
196
+ help=(
197
+ "Minimal class images for prior preservation loss. If there are not enough images already present in"
198
+ " class_data_dir, additional images will be sampled with class_prompt."
199
+ ),
200
+ )
201
+ parser.add_argument(
202
+ "--output_dir",
203
+ type=str,
204
+ default="lora-dreambooth-model",
205
+ help="The output directory where the model predictions and checkpoints will be written.",
206
+ )
207
+ parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
208
+ parser.add_argument(
209
+ "--resolution",
210
+ type=int,
211
+ default=512,
212
+ help=(
213
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
214
+ " resolution"
215
+ ),
216
+ )
217
+ parser.add_argument(
218
+ "--center_crop",
219
+ default=False,
220
+ action="store_true",
221
+ help=(
222
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
223
+ " cropped. The images will be resized to the resolution first before cropping."
224
+ ),
225
+ )
226
+ parser.add_argument(
227
+ "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
228
+ )
229
+ parser.add_argument(
230
+ "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
231
+ )
232
+ parser.add_argument("--num_train_epochs", type=int, default=1)
233
+ parser.add_argument(
234
+ "--max_train_steps",
235
+ type=int,
236
+ default=None,
237
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
238
+ )
239
+ parser.add_argument(
240
+ "--checkpointing_steps",
241
+ type=int,
242
+ default=500,
243
+ help=(
244
+ "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
245
+ " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
246
+ " training using `--resume_from_checkpoint`."
247
+ ),
248
+ )
249
+ parser.add_argument(
250
+ "--resume_from_checkpoint",
251
+ type=str,
252
+ default=None,
253
+ help=(
254
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
255
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
256
+ ),
257
+ )
258
+ parser.add_argument(
259
+ "--gradient_accumulation_steps",
260
+ type=int,
261
+ default=1,
262
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
263
+ )
264
+ parser.add_argument(
265
+ "--gradient_checkpointing",
266
+ action="store_true",
267
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
268
+ )
269
+ parser.add_argument(
270
+ "--learning_rate",
271
+ type=float,
272
+ default=5e-4,
273
+ help="Initial learning rate (after the potential warmup period) to use.",
274
+ )
275
+ parser.add_argument(
276
+ "--scale_lr",
277
+ action="store_true",
278
+ default=False,
279
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
280
+ )
281
+ parser.add_argument(
282
+ "--lr_scheduler",
283
+ type=str,
284
+ default="constant",
285
+ help=(
286
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
287
+ ' "constant", "constant_with_warmup"]'
288
+ ),
289
+ )
290
+ parser.add_argument(
291
+ "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
292
+ )
293
+ parser.add_argument(
294
+ "--lr_num_cycles",
295
+ type=int,
296
+ default=1,
297
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
298
+ )
299
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
300
+ parser.add_argument(
301
+ "--dataloader_num_workers",
302
+ type=int,
303
+ default=0,
304
+ help=(
305
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
306
+ ),
307
+ )
308
+ parser.add_argument(
309
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
310
+ )
311
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
312
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
313
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
314
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
315
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
316
+ parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
317
+ parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
318
+ parser.add_argument(
319
+ "--hub_model_id",
320
+ type=str,
321
+ default=None,
322
+ help="The name of the repository to keep in sync with the local `output_dir`.",
323
+ )
324
+ parser.add_argument(
325
+ "--logging_dir",
326
+ type=str,
327
+ default="logs",
328
+ help=(
329
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
330
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
331
+ ),
332
+ )
333
+ parser.add_argument(
334
+ "--allow_tf32",
335
+ action="store_true",
336
+ help=(
337
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
338
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
339
+ ),
340
+ )
341
+ parser.add_argument(
342
+ "--report_to",
343
+ type=str,
344
+ default="tensorboard",
345
+ help=(
346
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
347
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
348
+ ),
349
+ )
350
+ parser.add_argument(
351
+ "--mixed_precision",
352
+ type=str,
353
+ default=None,
354
+ choices=["no", "fp16", "bf16"],
355
+ help=(
356
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
357
+ " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
358
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
359
+ ),
360
+ )
361
+ parser.add_argument(
362
+ "--prior_generation_precision",
363
+ type=str,
364
+ default=None,
365
+ choices=["no", "fp32", "fp16", "bf16"],
366
+ help=(
367
+ "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
368
+ " 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
369
+ ),
370
+ )
371
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
372
+ parser.add_argument(
373
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
374
+ )
375
+
376
+ if input_args is not None:
377
+ args = parser.parse_args(input_args)
378
+ else:
379
+ args = parser.parse_args()
380
+
381
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
382
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
383
+ args.local_rank = env_local_rank
384
+
385
+ if args.with_prior_preservation:
386
+ if args.class_data_dir is None:
387
+ raise ValueError("You must specify a data directory for class images.")
388
+ if args.class_prompt is None:
389
+ raise ValueError("You must specify prompt for class images.")
390
+ else:
391
+ # logger is not available yet
392
+ if args.class_data_dir is not None:
393
+ warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
394
+ if args.class_prompt is not None:
395
+ warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
396
+
397
+ return args
398
+
399
+
400
+ class DreamBoothDataset(Dataset):
401
+ """
402
+ A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
403
+ It pre-processes the images and the tokenizes prompts.
404
+ """
405
+
406
+ def __init__(
407
+ self,
408
+ instance_data_root,
409
+ instance_prompt,
410
+ tokenizer,
411
+ class_data_root=None,
412
+ class_prompt=None,
413
+ size=512,
414
+ center_crop=False,
415
+ ):
416
+ self.size = size
417
+ self.center_crop = center_crop
418
+ self.tokenizer = tokenizer
419
+
420
+ self.instance_data_root = Path(instance_data_root)
421
+ if not self.instance_data_root.exists():
422
+ raise ValueError("Instance images root doesn't exists.")
423
+
424
+ self.instance_images_path = list(Path(instance_data_root).iterdir())
425
+ self.num_instance_images = len(self.instance_images_path)
426
+ self.instance_prompt = instance_prompt
427
+ self._length = self.num_instance_images
428
+
429
+ if class_data_root is not None:
430
+ self.class_data_root = Path(class_data_root)
431
+ self.class_data_root.mkdir(parents=True, exist_ok=True)
432
+ self.class_images_path = list(self.class_data_root.iterdir())
433
+ self.num_class_images = len(self.class_images_path)
434
+ self._length = max(self.num_class_images, self.num_instance_images)
435
+ self.class_prompt = class_prompt
436
+ else:
437
+ self.class_data_root = None
438
+
439
+ self.image_transforms = transforms.Compose(
440
+ [
441
+ transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
442
+ transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
443
+ transforms.ToTensor(),
444
+ transforms.Normalize([0.5], [0.5]),
445
+ ]
446
+ )
447
+
448
+ def __len__(self):
449
+ return self._length
450
+
451
+ def __getitem__(self, index):
452
+ example = {}
453
+ instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
454
+ if not instance_image.mode == "RGB":
455
+ instance_image = instance_image.convert("RGB")
456
+ example["instance_images"] = self.image_transforms(instance_image)
457
+ example["instance_prompt_ids"] = self.tokenizer(
458
+ self.instance_prompt,
459
+ truncation=True,
460
+ padding="max_length",
461
+ max_length=self.tokenizer.model_max_length,
462
+ return_tensors="pt",
463
+ ).input_ids
464
+
465
+ if self.class_data_root:
466
+ class_image = Image.open(self.class_images_path[index % self.num_class_images])
467
+ if not class_image.mode == "RGB":
468
+ class_image = class_image.convert("RGB")
469
+ example["class_images"] = self.image_transforms(class_image)
470
+ example["class_prompt_ids"] = self.tokenizer(
471
+ self.class_prompt,
472
+ truncation=True,
473
+ padding="max_length",
474
+ max_length=self.tokenizer.model_max_length,
475
+ return_tensors="pt",
476
+ ).input_ids
477
+
478
+ return example
479
+
480
+
481
+ def collate_fn(examples, with_prior_preservation=False):
482
+ input_ids = [example["instance_prompt_ids"] for example in examples]
483
+ pixel_values = [example["instance_images"] for example in examples]
484
+
485
+ # Concat class and instance examples for prior preservation.
486
+ # We do this to avoid doing two forward passes.
487
+ if with_prior_preservation:
488
+ input_ids += [example["class_prompt_ids"] for example in examples]
489
+ pixel_values += [example["class_images"] for example in examples]
490
+
491
+ pixel_values = torch.stack(pixel_values)
492
+ pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
493
+
494
+ input_ids = torch.cat(input_ids, dim=0)
495
+
496
+ batch = {
497
+ "input_ids": input_ids,
498
+ "pixel_values": pixel_values,
499
+ }
500
+ return batch
501
+
502
+
503
+ class PromptDataset(Dataset):
504
+ "A simple dataset to prepare the prompts to generate class images on multiple GPUs."
505
+
506
+ def __init__(self, prompt, num_samples):
507
+ self.prompt = prompt
508
+ self.num_samples = num_samples
509
+
510
+ def __len__(self):
511
+ return self.num_samples
512
+
513
+ def __getitem__(self, index):
514
+ example = {}
515
+ example["prompt"] = self.prompt
516
+ example["index"] = index
517
+ return example
518
+
519
+
520
+ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
521
+ if token is None:
522
+ token = HfFolder.get_token()
523
+ if organization is None:
524
+ username = whoami(token)["name"]
525
+ return f"{username}/{model_id}"
526
+ else:
527
+ return f"{organization}/{model_id}"
528
+
529
+
530
+ def main(args):
531
+ logging_dir = Path(args.output_dir, args.logging_dir)
532
+
533
+ accelerator = Accelerator(
534
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
535
+ mixed_precision=args.mixed_precision,
536
+ log_with=args.report_to,
537
+ logging_dir=logging_dir,
538
+ )
539
+
540
+ if args.report_to == "wandb":
541
+ if not is_wandb_available():
542
+ raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
543
+ import wandb
544
+
545
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
546
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
547
+ # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
548
+ # Make one log on every process with the configuration for debugging.
549
+ logging.basicConfig(
550
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
551
+ datefmt="%m/%d/%Y %H:%M:%S",
552
+ level=logging.INFO,
553
+ )
554
+ logger.info(accelerator.state, main_process_only=False)
555
+ if accelerator.is_local_main_process:
556
+ datasets.utils.logging.set_verbosity_warning()
557
+ transformers.utils.logging.set_verbosity_warning()
558
+ diffusers.utils.logging.set_verbosity_info()
559
+ else:
560
+ datasets.utils.logging.set_verbosity_error()
561
+ transformers.utils.logging.set_verbosity_error()
562
+ diffusers.utils.logging.set_verbosity_error()
563
+
564
+ # If passed along, set the training seed now.
565
+ if args.seed is not None:
566
+ set_seed(args.seed)
567
+
568
+ # Generate class images if prior preservation is enabled.
569
+ if args.with_prior_preservation:
570
+ class_images_dir = Path(args.class_data_dir)
571
+ if not class_images_dir.exists():
572
+ class_images_dir.mkdir(parents=True)
573
+ cur_class_images = len(list(class_images_dir.iterdir()))
574
+
575
+ if cur_class_images < args.num_class_images:
576
+ torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
577
+ if args.prior_generation_precision == "fp32":
578
+ torch_dtype = torch.float32
579
+ elif args.prior_generation_precision == "fp16":
580
+ torch_dtype = torch.float16
581
+ elif args.prior_generation_precision == "bf16":
582
+ torch_dtype = torch.bfloat16
583
+ pipeline = DiffusionPipeline.from_pretrained(
584
+ args.pretrained_model_name_or_path,
585
+ torch_dtype=torch_dtype,
586
+ safety_checker=None,
587
+ revision=args.revision,
588
+ )
589
+ pipeline.set_progress_bar_config(disable=True)
590
+
591
+ num_new_images = args.num_class_images - cur_class_images
592
+ logger.info(f"Number of class images to sample: {num_new_images}.")
593
+
594
+ sample_dataset = PromptDataset(args.class_prompt, num_new_images)
595
+ sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
596
+
597
+ sample_dataloader = accelerator.prepare(sample_dataloader)
598
+ pipeline.to(accelerator.device)
599
+
600
+ for example in tqdm(
601
+ sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
602
+ ):
603
+ images = pipeline(example["prompt"]).images
604
+
605
+ for i, image in enumerate(images):
606
+ hash_image = hashlib.sha1(image.tobytes()).hexdigest()
607
+ image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
608
+ image.save(image_filename)
609
+
610
+ del pipeline
611
+ if torch.cuda.is_available():
612
+ torch.cuda.empty_cache()
613
+
614
+ # Handle the repository creation
615
+ if accelerator.is_main_process:
616
+ if args.push_to_hub:
617
+ if args.hub_model_id is None:
618
+ repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
619
+ else:
620
+ repo_name = args.hub_model_id
621
+
622
+ create_repo(repo_name, exist_ok=True, token=args.hub_token)
623
+ repo = Repository(args.output_dir, clone_from=repo_name, token=args.hub_token)
624
+
625
+ with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
626
+ if "step_*" not in gitignore:
627
+ gitignore.write("step_*\n")
628
+ if "epoch_*" not in gitignore:
629
+ gitignore.write("epoch_*\n")
630
+ elif args.output_dir is not None:
631
+ os.makedirs(args.output_dir, exist_ok=True)
632
+
633
+ # Load the tokenizer
634
+ if args.tokenizer_name:
635
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
636
+ elif args.pretrained_model_name_or_path:
637
+ tokenizer = AutoTokenizer.from_pretrained(
638
+ args.pretrained_model_name_or_path,
639
+ subfolder="tokenizer",
640
+ revision=args.revision,
641
+ use_fast=False,
642
+ )
643
+
644
+ # import correct text encoder class
645
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
646
+
647
+ # Load scheduler and models
648
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
649
+ text_encoder = text_encoder_cls.from_pretrained(
650
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
651
+ )
652
+ vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
653
+ unet = UNet2DConditionModel.from_pretrained(
654
+ args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
655
+ )
656
+
657
+ # We only train the additional adapter LoRA layers
658
+ vae.requires_grad_(False)
659
+ text_encoder.requires_grad_(False)
660
+ unet.requires_grad_(False)
661
+
662
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
663
+ # as these models are only used for inference, keeping weights in full precision is not required.
664
+ weight_dtype = torch.float32
665
+ if accelerator.mixed_precision == "fp16":
666
+ weight_dtype = torch.float16
667
+ elif accelerator.mixed_precision == "bf16":
668
+ weight_dtype = torch.bfloat16
669
+
670
+ # Move unet, vae and text_encoder to device and cast to weight_dtype
671
+ unet.to(accelerator.device, dtype=weight_dtype)
672
+ vae.to(accelerator.device, dtype=weight_dtype)
673
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
674
+
675
+ if args.enable_xformers_memory_efficient_attention:
676
+ if is_xformers_available():
677
+ unet.enable_xformers_memory_efficient_attention()
678
+ else:
679
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
680
+
681
+ # now we will add new LoRA weights to the attention layers
682
+ # It's important to realize here how many attention weights will be added and of which sizes
683
+ # The sizes of the attention layers consist only of two different variables:
684
+ # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
685
+ # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
686
+
687
+ # Let's first see how many attention processors we will have to set.
688
+ # For Stable Diffusion, it should be equal to:
689
+ # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
690
+ # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
691
+ # - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
692
+ # => 32 layers
693
+
694
+ # Set correct lora layers
695
+ lora_attn_procs = {}
696
+ for name in unet.attn_processors.keys():
697
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
698
+ if name.startswith("mid_block"):
699
+ hidden_size = unet.config.block_out_channels[-1]
700
+ elif name.startswith("up_blocks"):
701
+ block_id = int(name[len("up_blocks.")])
702
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
703
+ elif name.startswith("down_blocks"):
704
+ block_id = int(name[len("down_blocks.")])
705
+ hidden_size = unet.config.block_out_channels[block_id]
706
+
707
+ lora_attn_procs[name] = LoRACrossAttnProcessor(
708
+ hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
709
+ )
710
+
711
+ unet.set_attn_processor(lora_attn_procs)
712
+ lora_layers = AttnProcsLayers(unet.attn_processors)
713
+
714
+ accelerator.register_for_checkpointing(lora_layers)
715
+
716
+ if args.scale_lr:
717
+ args.learning_rate = (
718
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
719
+ )
720
+
721
+ # Enable TF32 for faster training on Ampere GPUs,
722
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
723
+ if args.allow_tf32:
724
+ torch.backends.cuda.matmul.allow_tf32 = True
725
+
726
+ if args.scale_lr:
727
+ args.learning_rate = (
728
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
729
+ )
730
+
731
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
732
+ if args.use_8bit_adam:
733
+ try:
734
+ import bitsandbytes as bnb
735
+ except ImportError:
736
+ raise ImportError(
737
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
738
+ )
739
+
740
+ optimizer_class = bnb.optim.AdamW8bit
741
+ else:
742
+ optimizer_class = torch.optim.AdamW
743
+
744
+ # Optimizer creation
745
+ optimizer = optimizer_class(
746
+ lora_layers.parameters(),
747
+ lr=args.learning_rate,
748
+ betas=(args.adam_beta1, args.adam_beta2),
749
+ weight_decay=args.adam_weight_decay,
750
+ eps=args.adam_epsilon,
751
+ )
752
+
753
+ # Dataset and DataLoaders creation:
754
+ train_dataset = DreamBoothDataset(
755
+ instance_data_root=args.instance_data_dir,
756
+ instance_prompt=args.instance_prompt,
757
+ class_data_root=args.class_data_dir if args.with_prior_preservation else None,
758
+ class_prompt=args.class_prompt,
759
+ tokenizer=tokenizer,
760
+ size=args.resolution,
761
+ center_crop=args.center_crop,
762
+ )
763
+
764
+ train_dataloader = torch.utils.data.DataLoader(
765
+ train_dataset,
766
+ batch_size=args.train_batch_size,
767
+ shuffle=True,
768
+ collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
769
+ num_workers=args.dataloader_num_workers,
770
+ )
771
+
772
+ # Scheduler and math around the number of training steps.
773
+ overrode_max_train_steps = False
774
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
775
+ if args.max_train_steps is None:
776
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
777
+ overrode_max_train_steps = True
778
+
779
+ lr_scheduler = get_scheduler(
780
+ args.lr_scheduler,
781
+ optimizer=optimizer,
782
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
783
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
784
+ num_cycles=args.lr_num_cycles,
785
+ power=args.lr_power,
786
+ )
787
+
788
+ # Prepare everything with our `accelerator`.
789
+ lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
790
+ lora_layers, optimizer, train_dataloader, lr_scheduler
791
+ )
792
+
793
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
794
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
795
+ if overrode_max_train_steps:
796
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
797
+ # Afterwards we recalculate our number of training epochs
798
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
799
+
800
+ # We need to initialize the trackers we use, and also store our configuration.
801
+ # The trackers initializes automatically on the main process.
802
+ if accelerator.is_main_process:
803
+ accelerator.init_trackers("dreambooth-lora", config=vars(args))
804
+
805
+ # Train!
806
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
807
+
808
+ logger.info("***** Running training *****")
809
+ logger.info(f" Num examples = {len(train_dataset)}")
810
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
811
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
812
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
813
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
814
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
815
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
816
+ global_step = 0
817
+ first_epoch = 0
818
+
819
+ # Potentially load in the weights and states from a previous save
820
+ if args.resume_from_checkpoint:
821
+ if args.resume_from_checkpoint != "latest":
822
+ path = os.path.basename(args.resume_from_checkpoint)
823
+ else:
824
+ # Get the mos recent checkpoint
825
+ dirs = os.listdir(args.output_dir)
826
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
827
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
828
+ path = dirs[-1] if len(dirs) > 0 else None
829
+
830
+ if path is None:
831
+ accelerator.print(
832
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
833
+ )
834
+ args.resume_from_checkpoint = None
835
+ else:
836
+ accelerator.print(f"Resuming from checkpoint {path}")
837
+ accelerator.load_state(os.path.join(args.output_dir, path))
838
+ global_step = int(path.split("-")[1])
839
+
840
+ resume_global_step = global_step * args.gradient_accumulation_steps
841
+ first_epoch = global_step // num_update_steps_per_epoch
842
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
843
+
844
+ # Only show the progress bar once on each machine.
845
+ progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
846
+ progress_bar.set_description("Steps")
847
+
848
+ for epoch in range(first_epoch, args.num_train_epochs):
849
+ unet.train()
850
+ for step, batch in enumerate(train_dataloader):
851
+ # Skip steps until we reach the resumed step
852
+ if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
853
+ if step % args.gradient_accumulation_steps == 0:
854
+ progress_bar.update(1)
855
+ continue
856
+
857
+ with accelerator.accumulate(unet):
858
+ # Convert images to latent space
859
+ latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
860
+ latents = latents * 0.18215
861
+
862
+ # Sample noise that we'll add to the latents
863
+ noise = torch.randn_like(latents)
864
+ bsz = latents.shape[0]
865
+ # Sample a random timestep for each image
866
+ timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
867
+ timesteps = timesteps.long()
868
+
869
+ # Add noise to the latents according to the noise magnitude at each timestep
870
+ # (this is the forward diffusion process)
871
+ noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
872
+
873
+ # Get the text embedding for conditioning
874
+ encoder_hidden_states = text_encoder(batch["input_ids"])[0]
875
+
876
+ # Predict the noise residual
877
+ model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
878
+
879
+ # Get the target for loss depending on the prediction type
880
+ if noise_scheduler.config.prediction_type == "epsilon":
881
+ target = noise
882
+ elif noise_scheduler.config.prediction_type == "v_prediction":
883
+ target = noise_scheduler.get_velocity(latents, noise, timesteps)
884
+ else:
885
+ raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
886
+
887
+ if args.with_prior_preservation:
888
+ # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
889
+ model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
890
+ target, target_prior = torch.chunk(target, 2, dim=0)
891
+
892
+ # Compute instance loss
893
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
894
+
895
+ # Compute prior loss
896
+ prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
897
+
898
+ # Add the prior loss to the instance loss.
899
+ loss = loss + args.prior_loss_weight * prior_loss
900
+ else:
901
+ loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
902
+
903
+ accelerator.backward(loss)
904
+ if accelerator.sync_gradients:
905
+ params_to_clip = lora_layers.parameters()
906
+ accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
907
+ optimizer.step()
908
+ lr_scheduler.step()
909
+ optimizer.zero_grad()
910
+
911
+ # Checks if the accelerator has performed an optimization step behind the scenes
912
+ if accelerator.sync_gradients:
913
+ progress_bar.update(1)
914
+ global_step += 1
915
+
916
+ if global_step % args.checkpointing_steps == 0:
917
+ if accelerator.is_main_process:
918
+ save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
919
+ accelerator.save_state(save_path)
920
+ logger.info(f"Saved state to {save_path}")
921
+
922
+ logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
923
+ progress_bar.set_postfix(**logs)
924
+ accelerator.log(logs, step=global_step)
925
+
926
+ if global_step >= args.max_train_steps:
927
+ break
928
+
929
+ if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
930
+ logger.info(
931
+ f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
932
+ f" {args.validation_prompt}."
933
+ )
934
+ # create pipeline
935
+ pipeline = DiffusionPipeline.from_pretrained(
936
+ args.pretrained_model_name_or_path,
937
+ unet=accelerator.unwrap_model(unet),
938
+ text_encoder=accelerator.unwrap_model(text_encoder),
939
+ revision=args.revision,
940
+ torch_dtype=weight_dtype,
941
+ )
942
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
943
+ pipeline = pipeline.to(accelerator.device)
944
+ pipeline.set_progress_bar_config(disable=True)
945
+
946
+ # run inference
947
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
948
+ prompt = args.num_validation_images * [args.validation_prompt]
949
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
950
+
951
+ for tracker in accelerator.trackers:
952
+ if tracker.name == "tensorboard":
953
+ np_images = np.stack([np.asarray(img) for img in images])
954
+ tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
955
+ if tracker.name == "wandb":
956
+ tracker.log(
957
+ {
958
+ "validation": [
959
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
960
+ for i, image in enumerate(images)
961
+ ]
962
+ }
963
+ )
964
+
965
+ del pipeline
966
+ torch.cuda.empty_cache()
967
+
968
+ # Save the lora layers
969
+ accelerator.wait_for_everyone()
970
+ if accelerator.is_main_process:
971
+ unet = unet.to(torch.float32)
972
+ unet.save_attn_procs(args.output_dir)
973
+
974
+ # Final inference
975
+ # Load previous pipeline
976
+ pipeline = DiffusionPipeline.from_pretrained(
977
+ args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
978
+ )
979
+ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
980
+ pipeline = pipeline.to(accelerator.device)
981
+
982
+ # load attention processors
983
+ pipeline.unet.load_attn_procs(args.output_dir)
984
+
985
+ # run inference
986
+ if args.validation_prompt and args.num_validation_images > 0:
987
+ generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
988
+ prompt = args.num_validation_images * [args.validation_prompt]
989
+ images = pipeline(prompt, num_inference_steps=25, generator=generator).images
990
+
991
+ test_image_dir = Path(args.output_dir) / 'test_images'
992
+ test_image_dir.mkdir()
993
+ for i, image in enumerate(images):
994
+ out_path = test_image_dir / f'image_{i}.png'
995
+ image.save(out_path)
996
+
997
+ for tracker in accelerator.trackers:
998
+ if tracker.name == "tensorboard":
999
+ np_images = np.stack([np.asarray(img) for img in images])
1000
+ tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
1001
+ if tracker.name == "wandb":
1002
+ tracker.log(
1003
+ {
1004
+ "test": [
1005
+ wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
1006
+ for i, image in enumerate(images)
1007
+ ]
1008
+ }
1009
+ )
1010
+
1011
+ if args.push_to_hub:
1012
+ save_model_card(
1013
+ repo_name,
1014
+ images=images,
1015
+ base_model=args.pretrained_model_name_or_path,
1016
+ prompt=args.instance_prompt,
1017
+ repo_folder=args.output_dir,
1018
+ )
1019
+ repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
1020
+
1021
+ accelerator.end_training()
1022
+
1023
+
1024
+ if __name__ == "__main__":
1025
+ args = parse_args()
1026
+ main(args)
trainer.py ADDED
@@ -0,0 +1,166 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import datetime
4
+ import os
5
+ import pathlib
6
+ import shlex
7
+ import shutil
8
+ import subprocess
9
+
10
+ import gradio as gr
11
+ import PIL.Image
12
+ import slugify
13
+ import torch
14
+ from huggingface_hub import HfApi
15
+
16
+ from app_upload import LoRAModelUploader
17
+ from utils import save_model_card
18
+
19
+ URL_TO_JOIN_LORA_LIBRARY_ORG = 'https://huggingface.co/organizations/lora-library/share/hjetHAcKjnPHXhHfbeEcqnBqmhgilFfpOL'
20
+
21
+
22
+ def pad_image(image: PIL.Image.Image) -> PIL.Image.Image:
23
+ w, h = image.size
24
+ if w == h:
25
+ return image
26
+ elif w > h:
27
+ new_image = PIL.Image.new(image.mode, (w, w), (0, 0, 0))
28
+ new_image.paste(image, (0, (w - h) // 2))
29
+ return new_image
30
+ else:
31
+ new_image = PIL.Image.new(image.mode, (h, h), (0, 0, 0))
32
+ new_image.paste(image, ((h - w) // 2, 0))
33
+ return new_image
34
+
35
+
36
+ class Trainer:
37
+ def __init__(self, hf_token: str | None = None):
38
+ self.hf_token = hf_token
39
+ self.api = HfApi(token=hf_token)
40
+ self.model_uploader = LoRAModelUploader(hf_token)
41
+
42
+ def prepare_dataset(self, instance_images: list, resolution: int,
43
+ instance_data_dir: pathlib.Path) -> None:
44
+ shutil.rmtree(instance_data_dir, ignore_errors=True)
45
+ instance_data_dir.mkdir(parents=True)
46
+ for i, temp_path in enumerate(instance_images):
47
+ image = PIL.Image.open(temp_path.name)
48
+ image = pad_image(image)
49
+ image = image.resize((resolution, resolution))
50
+ image = image.convert('RGB')
51
+ out_path = instance_data_dir / f'{i:03d}.jpg'
52
+ image.save(out_path, format='JPEG', quality=100)
53
+
54
+ def join_lora_library_org(self) -> None:
55
+ subprocess.run(
56
+ shlex.split(
57
+ f'curl -X POST -H "Authorization: Bearer {self.hf_token}" -H "Content-Type: application/json" {URL_TO_JOIN_LORA_LIBRARY_ORG}'
58
+ ))
59
+
60
+ def run(
61
+ self,
62
+ instance_images: list | None,
63
+ instance_prompt: str,
64
+ output_model_name: str,
65
+ overwrite_existing_model: bool,
66
+ validation_prompt: str,
67
+ base_model: str,
68
+ resolution_s: str,
69
+ n_steps: int,
70
+ learning_rate: float,
71
+ gradient_accumulation: int,
72
+ seed: int,
73
+ fp16: bool,
74
+ use_8bit_adam: bool,
75
+ checkpointing_steps: int,
76
+ use_wandb: bool,
77
+ validation_epochs: int,
78
+ upload_to_hub: bool,
79
+ use_private_repo: bool,
80
+ delete_existing_repo: bool,
81
+ upload_to: str,
82
+ remove_gpu_after_training: bool,
83
+ ) -> str:
84
+ if not torch.cuda.is_available():
85
+ raise gr.Error('CUDA is not available.')
86
+ if instance_images is None:
87
+ raise gr.Error('You need to upload images.')
88
+ if not instance_prompt:
89
+ raise gr.Error('The instance prompt is missing.')
90
+ if not validation_prompt:
91
+ raise gr.Error('The validation prompt is missing.')
92
+
93
+ resolution = int(resolution_s)
94
+
95
+ if not output_model_name:
96
+ timestamp = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
97
+ output_model_name = f'lora-dreambooth-{timestamp}'
98
+ output_model_name = slugify.slugify(output_model_name)
99
+
100
+ repo_dir = pathlib.Path(__file__).parent
101
+ output_dir = repo_dir / 'experiments' / output_model_name
102
+ if overwrite_existing_model or upload_to_hub:
103
+ shutil.rmtree(output_dir, ignore_errors=True)
104
+ output_dir.mkdir(parents=True)
105
+
106
+ instance_data_dir = repo_dir / 'training_data' / output_model_name
107
+ self.prepare_dataset(instance_images, resolution, instance_data_dir)
108
+
109
+ if upload_to_hub:
110
+ self.join_lora_library_org()
111
+
112
+ command = f'''
113
+ accelerate launch train_dreambooth_lora.py \
114
+ --pretrained_model_name_or_path={base_model} \
115
+ --instance_data_dir={instance_data_dir} \
116
+ --output_dir={output_dir} \
117
+ --instance_prompt="{instance_prompt}" \
118
+ --resolution={resolution} \
119
+ --train_batch_size=1 \
120
+ --gradient_accumulation_steps={gradient_accumulation} \
121
+ --learning_rate={learning_rate} \
122
+ --lr_scheduler=constant \
123
+ --lr_warmup_steps=0 \
124
+ --max_train_steps={n_steps} \
125
+ --checkpointing_steps={checkpointing_steps} \
126
+ --validation_prompt="{validation_prompt}" \
127
+ --validation_epochs={validation_epochs} \
128
+ --seed={seed}
129
+ '''
130
+ if fp16:
131
+ command += ' --mixed_precision fp16'
132
+ if use_8bit_adam:
133
+ command += ' --use_8bit_adam'
134
+ if use_wandb:
135
+ command += ' --report_to wandb'
136
+
137
+ with open(output_dir / 'train.sh', 'w') as f:
138
+ command_s = ' '.join(command.split())
139
+ f.write(command_s)
140
+ subprocess.run(shlex.split(command))
141
+ save_model_card(save_dir=output_dir,
142
+ base_model=base_model,
143
+ instance_prompt=instance_prompt,
144
+ test_prompt=validation_prompt,
145
+ test_image_dir='test_images')
146
+
147
+ message = 'Training completed!'
148
+ print(message)
149
+
150
+ if upload_to_hub:
151
+ upload_message = self.model_uploader.upload_lora_model(
152
+ folder_path=output_dir.as_posix(),
153
+ repo_name=output_model_name,
154
+ upload_to=upload_to,
155
+ private=use_private_repo,
156
+ delete_existing_repo=delete_existing_repo)
157
+ print(upload_message)
158
+ message = message + '\n' + upload_message
159
+
160
+ if remove_gpu_after_training:
161
+ space_id = os.getenv('SPACE_ID')
162
+ if space_id:
163
+ self.api.request_space_hardware(repo_id=space_id,
164
+ hardware='cpu-basic')
165
+
166
+ return message
uploader.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from huggingface_hub import HfApi
4
+
5
+
6
+ class Uploader:
7
+ def __init__(self, hf_token: str | None):
8
+ self.api = HfApi(token=hf_token)
9
+
10
+ def get_username(self) -> str:
11
+ return self.api.whoami()['name']
12
+
13
+ def upload(self,
14
+ folder_path: str,
15
+ repo_name: str,
16
+ organization: str = '',
17
+ repo_type: str = 'model',
18
+ private: bool = True,
19
+ delete_existing_repo: bool = False) -> str:
20
+ if not folder_path:
21
+ raise ValueError
22
+ if not repo_name:
23
+ raise ValueError
24
+ if not organization:
25
+ organization = self.get_username()
26
+ repo_id = f'{organization}/{repo_name}'
27
+ if delete_existing_repo:
28
+ try:
29
+ self.api.delete_repo(repo_id, repo_type=repo_type)
30
+ except Exception:
31
+ pass
32
+ try:
33
+ self.api.create_repo(repo_id, repo_type=repo_type, private=private)
34
+ self.api.upload_folder(repo_id=repo_id,
35
+ folder_path=folder_path,
36
+ path_in_repo='.',
37
+ repo_type=repo_type)
38
+ url = f'https://huggingface.co/{repo_id}'
39
+ message = f'Your model was successfully uploaded to <a href="{url}" target="_blank">{url}</a>.'
40
+ except Exception as e:
41
+ message = str(e)
42
+ return message
utils.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import pathlib
4
+
5
+
6
+ def find_exp_dirs(ignore_repo: bool = False) -> list[str]:
7
+ repo_dir = pathlib.Path(__file__).parent
8
+ exp_root_dir = repo_dir / 'experiments'
9
+ if not exp_root_dir.exists():
10
+ return []
11
+ exp_dirs = sorted(exp_root_dir.glob('*'))
12
+ exp_dirs = [
13
+ exp_dir for exp_dir in exp_dirs
14
+ if (exp_dir / 'pytorch_lora_weights.bin').exists()
15
+ ]
16
+ if ignore_repo:
17
+ exp_dirs = [
18
+ exp_dir for exp_dir in exp_dirs if not (exp_dir / '.git').exists()
19
+ ]
20
+ return [path.relative_to(repo_dir).as_posix() for path in exp_dirs]
21
+
22
+
23
+ def save_model_card(
24
+ save_dir: pathlib.Path,
25
+ base_model: str,
26
+ instance_prompt: str,
27
+ test_prompt: str = '',
28
+ test_image_dir: str = '',
29
+ ) -> None:
30
+ image_str = ''
31
+ if test_prompt and test_image_dir:
32
+ image_paths = sorted((save_dir / test_image_dir).glob('*'))
33
+ if image_paths:
34
+ image_str = f'Test prompt: {test_prompt}\n'
35
+ for image_path in image_paths:
36
+ rel_path = image_path.relative_to(save_dir)
37
+ image_str += f'![{image_path.stem}]({rel_path})\n'
38
+
39
+ model_card = f'''---
40
+ license: creativeml-openrail-m
41
+ base_model: {base_model}
42
+ instance_prompt: {instance_prompt}
43
+ tags:
44
+ - stable-diffusion
45
+ - stable-diffusion-diffusers
46
+ - text-to-image
47
+ - diffusers
48
+ - lora
49
+ inference: true
50
+ ---
51
+ # LoRA DreamBooth - {save_dir.name}
52
+
53
+ These are LoRA adaption weights for [{base_model}](https://huggingface.co/{base_model}). The weights were trained on the instance prompt "{instance_prompt}" using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following.
54
+
55
+ {image_str}
56
+ '''
57
+
58
+ with open(save_dir / 'README.md', 'w') as f:
59
+ f.write(model_card)