rahul7star commited on
Commit
fcc02a2
·
verified ·
1 Parent(s): af6809f

boilerplate

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +2 -0
  2. .github/FUNDING.yml +2 -0
  3. .github/ISSUE_TEMPLATE/bug_report.md +19 -0
  4. .github/ISSUE_TEMPLATE/config.yml +5 -0
  5. .gitignore +183 -0
  6. .gitmodules +0 -0
  7. .vscode/launch.json +28 -0
  8. FAQ.md +10 -0
  9. LICENSE +21 -0
  10. README.md +413 -10
  11. assets/VAE_test1.jpg +3 -0
  12. assets/glif.svg +40 -0
  13. assets/lora_ease_ui.png +3 -0
  14. build_and_push_docker +29 -0
  15. config/examples/extract.example.yml +75 -0
  16. config/examples/generate.example.yaml +60 -0
  17. config/examples/mod_lora_scale.yaml +48 -0
  18. config/examples/modal/modal_train_lora_flux_24gb.yaml +96 -0
  19. config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml +98 -0
  20. config/examples/train_flex_redux.yaml +112 -0
  21. config/examples/train_full_fine_tune_flex.yaml +107 -0
  22. config/examples/train_full_fine_tune_lumina.yaml +99 -0
  23. config/examples/train_lora_chroma_24gb.yaml +104 -0
  24. config/examples/train_lora_flex2_24gb.yaml +165 -0
  25. config/examples/train_lora_flex_24gb.yaml +101 -0
  26. config/examples/train_lora_flux_24gb.yaml +96 -0
  27. config/examples/train_lora_flux_schnell_24gb.yaml +98 -0
  28. config/examples/train_lora_hidream_48.yaml +112 -0
  29. config/examples/train_lora_lumina.yaml +96 -0
  30. config/examples/train_lora_sd35_large_24gb.yaml +97 -0
  31. config/examples/train_lora_wan21_14b_24gb.yaml +101 -0
  32. config/examples/train_lora_wan21_1b_24gb.yaml +90 -0
  33. config/examples/train_slider.example.yml +230 -0
  34. docker-compose.yml +25 -0
  35. docker/Dockerfile +77 -0
  36. docker/start.sh +70 -0
  37. extensions/example/ExampleMergeModels.py +129 -0
  38. extensions/example/__init__.py +25 -0
  39. extensions/example/config/config.example.yaml +48 -0
  40. extensions_built_in/advanced_generator/Img2ImgGenerator.py +256 -0
  41. extensions_built_in/advanced_generator/PureLoraGenerator.py +102 -0
  42. extensions_built_in/advanced_generator/ReferenceGenerator.py +212 -0
  43. extensions_built_in/advanced_generator/__init__.py +59 -0
  44. extensions_built_in/advanced_generator/config/train.example.yaml +91 -0
  45. extensions_built_in/concept_replacer/ConceptReplacer.py +151 -0
  46. extensions_built_in/concept_replacer/__init__.py +26 -0
  47. extensions_built_in/concept_replacer/config/train.example.yaml +91 -0
  48. extensions_built_in/dataset_tools/DatasetTools.py +20 -0
  49. extensions_built_in/dataset_tools/SuperTagger.py +196 -0
  50. extensions_built_in/dataset_tools/SyncFromCollection.py +131 -0
.gitattributes CHANGED
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/lora_ease_ui.png filter=lfs diff=lfs merge=lfs -text
37
+ assets/VAE_test1.jpg filter=lfs diff=lfs merge=lfs -text
.github/FUNDING.yml ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ github: [ostris]
2
+ patreon: ostris
.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug Report
3
+ about: For bugs only. Not for feature requests or questions.
4
+ title: ''
5
+ labels: ''
6
+ assignees: ''
7
+ ---
8
+
9
+ ## This is for bugs only
10
+
11
+ Did you already ask [in the discord](https://discord.gg/VXmU2f5WEU)?
12
+
13
+ Yes/No
14
+
15
+ You verified that this is a bug and not a feature request or question by asking [in the discord](https://discord.gg/VXmU2f5WEU)?
16
+
17
+ Yes/No
18
+
19
+ ## Describe the bug
.github/ISSUE_TEMPLATE/config.yml ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ blank_issues_enabled: false
2
+ contact_links:
3
+ - name: Ask in the Discord BEFORE opening an issue
4
+ url: https://discord.gg/VXmU2f5WEU
5
+ about: Please ask in the discord before opening a github issue.
.gitignore ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+
162
+ /env.sh
163
+ /models
164
+ /datasets
165
+ /custom/*
166
+ !/custom/.gitkeep
167
+ /.tmp
168
+ /venv.bkp
169
+ /venv.*
170
+ /config/*
171
+ !/config/examples
172
+ !/config/_PUT_YOUR_CONFIGS_HERE).txt
173
+ /output/*
174
+ !/output/.gitkeep
175
+ /extensions/*
176
+ !/extensions/example
177
+ /temp
178
+ /wandb
179
+ .vscode/settings.json
180
+ .DS_Store
181
+ ._.DS_Store
182
+ aitk_db.db
183
+ /notes.md
.gitmodules ADDED
File without changes
.vscode/launch.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "version": "0.2.0",
3
+ "configurations": [
4
+ {
5
+ "name": "Run current config",
6
+ "type": "python",
7
+ "request": "launch",
8
+ "program": "${workspaceFolder}/run.py",
9
+ "args": [
10
+ "${file}"
11
+ ],
12
+ "env": {
13
+ "CUDA_LAUNCH_BLOCKING": "1",
14
+ "DEBUG_TOOLKIT": "1"
15
+ },
16
+ "console": "integratedTerminal",
17
+ "justMyCode": false
18
+ },
19
+ {
20
+ "name": "Python: Debug Current File",
21
+ "type": "python",
22
+ "request": "launch",
23
+ "program": "${file}",
24
+ "console": "integratedTerminal",
25
+ "justMyCode": false
26
+ },
27
+ ]
28
+ }
FAQ.md ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ # FAQ
2
+
3
+ WIP. Will continue to add things as they are needed.
4
+
5
+ ## FLUX.1 Training
6
+
7
+ #### How much VRAM is required to train a lora on FLUX.1?
8
+
9
+ 24GB minimum is required.
10
+
LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Ostris, LLC
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md CHANGED
@@ -1,13 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
- title: Ai Toolkit
3
- emoji: 🌍
4
- colorFrom: red
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.29.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: ai-toolkit-training
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # AI Toolkit by Ostris
2
+
3
+ AI Toolkit is an all in one training suite for diffusion models. I try to support all the latest models on consumer grade hardware. Image and video models. It can be run as a GUI or CLI. It is designed to be easy to use but still have every feature imaginable.
4
+
5
+ ## Support My Work
6
+
7
+ If you enjoy my projects or use them commercially, please consider sponsoring me. Every bit helps! 💖
8
+
9
+ [Sponsor on GitHub](https://github.com/orgs/ostris) | [Support on Patreon](https://www.patreon.com/ostris) | [Donate on PayPal](https://www.paypal.com/donate/?hosted_button_id=9GEFUKC8T9R9W)
10
+
11
+ ### Current Sponsors
12
+
13
+ All of these people / organizations are the ones who selflessly make this project possible. Thank you!!
14
+
15
+ _Last updated: 2025-04-22 16:45 UTC_
16
+
17
+ <p align="center">
18
+ <a href="https://github.com/replicate" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/60410876?v=4" alt="Replicate" width="200" height="200" style="border-radius:8px;margin:5px;"></a>
19
+ <a href="https://github.com/josephrocca" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/1167575?u=92d92921b4cb5c8c7e225663fed53c4b41897736&v=4" alt="josephrocca" width="200" height="200" style="border-radius:8px;margin:5px;"></a>
20
+ </p>
21
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
22
+ <p align="center">
23
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162524101/81a72689c3754ac5b9e38612ce5ce914/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=3XLSlLFCWAQ-0wd2_vZMikyotdQNSzKOjoyeoJiZEw0%3D" alt="Prasanth Veerina" width="150" height="150" style="border-radius:8px;margin:5px;">
24
+ <a href="https://github.com/weights-ai" target="_blank" rel="noopener noreferrer"><img src="https://avatars.githubusercontent.com/u/185568492?v=4" alt="Weights" width="150" height="150" style="border-radius:8px;margin:5px;"></a>
25
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/161471720/dd330b4036d44a5985ed5985c12a5def/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=qkRvrEc5gLPxaXxLvcvbYv1W1lcmOoTwhj4A9Cq5BxQ%3D" alt="Vladimir Sotnikov" width="150" height="150" style="border-radius:8px;margin:5px;">
26
+ <img src="https://c8.patreon.com/3/200/33158543" alt="clement Delangue" width="150" height="150" style="border-radius:8px;margin:5px;">
27
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/54890369/45cea21d82974c78bf43956de7fb0e12/eyJ3IjoyMDB9/2.jpeg?token-time=2145916800&token-hash=IK6OT6UpusHgdaC4y8IhK5XxXiP5TuLy3vjvgL77Fho%3D" alt="Eli Slugworth" width="150" height="150" style="border-radius:8px;margin:5px;">
28
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/8654302/b0f5ebedc62a47c4b56222693e1254e9/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=lpeicIh1_S-3Ji3W27gyiRB7iXurp8Bx8HAzDHftOuo%3D" alt="Misch Strotz" width="150" height="150" style="border-radius:8px;margin:5px;">
29
+ <img src="https://c8.patreon.com/3/200/93304" alt="Joseph Rocca" width="150" height="150" style="border-radius:8px;margin:5px;">
30
+ </p>
31
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
32
+ <p align="center">
33
+ <a href="https://x.com/NuxZoe" target="_blank" rel="noopener noreferrer"><img src="https://pbs.twimg.com/profile_images/1714760743273574400/tdvQjNTl_400x400.jpg" alt="tungsten" width="100" height="100" style="border-radius:8px;margin:5px;"></a>
34
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/2298192/1228b69bd7d7481baf3103315183250d/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=1B7dbXy_gAcPT9WXBesLhs7z_9APiz2k1Wx4Vml_-8Q%3D" alt="Mohamed Oumoumad" width="100" height="100" style="border-radius:8px;margin:5px;">
35
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/120239481/49b1ce70d3d24704b8ec34de24ec8f55/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=Dv1NPKwdv9QT8fhYYwbGnQIvfiyqTUlh52bjDW1vYxY%3D" alt="nitish PNR" width="100" height="100" style="border-radius:8px;margin:5px;">
36
+ <img src="https://c8.patreon.com/3/200/548524" alt="Steve Hanff" width="100" height="100" style="border-radius:8px;margin:5px;">
37
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/152118848/3b15a43d71714552b5ed1c9f84e66adf/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=IEKE18CBHVZ3k-08UD7Dkb7HbiFHb84W0FATdLMI0Dg%3D" alt="Kristjan Retter" width="100" height="100" style="border-radius:8px;margin:5px;">
38
+ <img src="https://c8.patreon.com/3/200/83319230" alt="Miguel Lara" width="100" height="100" style="border-radius:8px;margin:5px;">
39
+ </p>
40
+ <hr style="width:100%;border:none;height:2px;background:#ddd;margin:30px 0;">
41
+ <p align="center">
42
+ <img src="https://c8.patreon.com/3/200/8449560" alt="Patron" width="60" height="60" style="border-radius:8px;margin:5px;">
43
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/27288932/6c35d2d961ee4e14a7a368c990791315/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=dpFFssZXZM_KZMKQhl3uDwwusdFw1c_v9x_ChJU7_zc%3D" alt="David Garrido" width="60" height="60" style="border-radius:8px;margin:5px;">
44
+ <img src="https://c8.patreon.com/3/200/2410522" alt="George Gostyshev" width="60" height="60" style="border-radius:8px;margin:5px;">
45
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/16287560/78130de30950410ca528d8a888997081/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=Ok-HSL2MthKXF09SmCOlPFCPfbMctFBZKCuTnPwxZ3A%3D" alt="Vitaly Golubenko" width="60" height="60" style="border-radius:8px;margin:5px;">
46
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/570742/4ceb33453a5a4745b430a216aba9280f/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=wUzsI5cO5Evp2ukIGdSgBbvKeYgv5LSOQMa6Br33Rrs%3D" alt="Al H" width="60" height="60" style="border-radius:8px;margin:5px;">
47
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/131773947/eda3405aa582437db4582fce908c8739/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=S4Bh0sMqTNmJlo3uRr7co5d_kxvBjITemDTfi_1KrCA%3D" alt="Jodh Singh" width="60" height="60" style="border-radius:8px;margin:5px;">
48
+ <img src="https://c8.patreon.com/3/200/22809690" alt="Michael Levine" width="60" height="60" style="border-radius:8px;margin:5px;">
49
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/99036356/7ae9c4d80e604e739b68cca12ee2ed01/eyJ3IjoyMDB9/3.png?token-time=2145916800&token-hash=zK0dHe6A937WtNlrGdefoXFTPPzHUCfn__23HP8-Ui0%3D" alt="Noctre" width="60" height="60" style="border-radius:8px;margin:5px;">
50
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/141098579/1a9f0a1249d447a7a0df718a57343912/eyJ3IjoyMDB9/2.png?token-time=2145916800&token-hash=Rd_AjZGhMATVkZDf8E95ILc0n93gvvFWe1Ig0_dxwf4%3D" alt="The Local Lab" width="60" height="60" style="border-radius:8px;margin:5px;">
51
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/98811435/3a3632d1795b4c2b9f8f0270f2f6a650/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=93w8RMxwXlcM4X74t03u6P5_SrKvlm1IpjnD2SzVpJk%3D" alt="EmmanuelMr18" width="60" height="60" style="border-radius:8px;margin:5px;">
52
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/338551/e8f257d8d3dd46c38272b391a5785948/eyJ3IjoyMDB9/1.jpg?token-time=2145916800&token-hash=GLom1rGgOZjBeO7I1OnjiIgWmjl6PO9ZjBB8YTvc7AM%3D" alt="Plaidam" width="60" height="60" style="border-radius:8px;margin:5px;">
53
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/82763/f99cc484361d4b9d94fe4f0814ada303/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=BpwC020pR3TRZ4r0RSCiSIOh-jmatkrpy1h2XU4sGa4%3D" alt="Doron Adler" width="60" height="60" style="border-radius:8px;margin:5px;">
54
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/103077711/bb215761cc004e80bd9cec7d4bcd636d/eyJ3IjoyMDB9/2.jpeg?token-time=2145916800&token-hash=zvtBie29rRTKTXvAA2KhOI-l3mSMk9xxr-mg_CksLtc%3D" alt="John Dopamine" width="60" height="60" style="border-radius:8px;margin:5px;">
55
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/93348210/5c650f32a0bc481d80900d2674528777/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=PpXK9B_iy288annlNdLOexhiQHbTftPEDeCh-sTQ2KA%3D" alt="Armin Behjati" width="60" height="60" style="border-radius:8px;margin:5px;">
56
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/155963250/6f8fd7075c3b4247bfeb054ba49172d6/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=twmKs4mADF_h7bKh5jBuigYVScMeaeHv2pEPin9K0Dg%3D" alt="Un Defined" width="60" height="60" style="border-radius:8px;margin:5px;">
57
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/45562978/0de33cf52ec642ae8a2f612cddec4ca6/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=hSAvaD4phiLcF0pvX7FP0juI5NQWCon-_TZSNpJzQJg%3D" alt="Jack English" width="60" height="60" style="border-radius:8px;margin:5px;">
58
+ <img src="https://c8.patreon.com/3/200/27791680" alt="Jean-Tristan Marin" width="60" height="60" style="border-radius:8px;margin:5px;">
59
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/60995694/92e0e8f336eb4a5bb8d99b940247d1d1/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=pj6Tm8XRdpGJcAEdnCakqYSNiSjoAYjvZescX7d0ic0%3D" alt="Abraham Irawan" width="60" height="60" style="border-radius:8px;margin:5px;">
60
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/164958178/4eb7a37baa0541bab7a091f2b14615b7/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=_aaum7fBJAGaJhMBhlR8vqYavDhExdVxmO9mwd3_XMw%3D" alt="Austin Robinson" width="60" height="60" style="border-radius:8px;margin:5px;">
61
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/134129880/680c7e14cd1a4d1a9face921fb010f88/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=vNKojv67krNqx7gdpKBX1R_stX2TkMRYvRc0xZrbY6s%3D" alt="Bharat Prabhakar" width="60" height="60" style="border-radius:8px;margin:5px;">
62
+ <img src="https://c8.patreon.com/3/200/70218846" alt="Cosmosis" width="60" height="60" style="border-radius:8px;margin:5px;">
63
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/83054970/13de6cb103ad41a5841edf549e66cd51/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=wU_Eke9VYcfI40FAQvdEV84Xspqlo5VSiafLqhg_FOE%3D" alt="Gili Ben Shahar" width="60" height="60" style="border-radius:8px;margin:5px;">
64
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/30931983/54ab4e4ceab946e79a6418d205f9ed51/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=LBmsSsMQZhO6yRZ_YyRwTgE6a7BVWrGNsAVveLXHXR0%3D" alt="HestoySeghuro ." width="60" height="60" style="border-radius:8px;margin:5px;">
65
+ <img src="https://c8.patreon.com/3/200/4105384" alt="Jack Blakely" width="60" height="60" style="border-radius:8px;margin:5px;">
66
+ <img src="https://c8.patreon.com/3/200/494309" alt="Julian Tsependa" width="60" height="60" style="border-radius:8px;margin:5px;">
67
+ <img src="https://c8.patreon.com/3/200/24653779" alt="RayHell" width="60" height="60" style="border-radius:8px;margin:5px;">
68
+ <img src="https://c8.patreon.com/3/200/4541423" alt="Sören " width="60" height="60" style="border-radius:8px;margin:5px;">
69
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/31950857/c567dc648f6144be9f6234946df05da2/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=3Vx4R1eOfD4X_ZPPd40MsZ-3lyknLM35XmaHRELnWjM%3D" alt="Trent Hunter" width="60" height="60" style="border-radius:8px;margin:5px;">
70
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/110407414/30f9e9d88ef945ddb0f47fd23a8cbac2/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=QQRWOkMyOfDBERHn4O8N2wMB32zeiIEsydVTbSNUw-I%3D" alt="Wesley Reitzfeld" width="60" height="60" style="border-radius:8px;margin:5px;">
71
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/162398691/89d78d89eecb4d6b981ce8c3c6a3d4b8/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=SWhI-0jGpY6Nc_bUQeXz4pa9DRURi9VnnnJ3Mxjg1po%3D" alt="Zoltán-Csaba Nyiró" width="60" height="60" style="border-radius:8px;margin:5px;">
72
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/97985240/3d1d0e6905d045aba713e8132cab4a30/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=pG3X2m-py2lRYI2aoJiXI47_4ArD78ZHdSm6jCAHA_w%3D" alt="עומר מכלוף" width="60" height="60" style="border-radius:8px;margin:5px;">
73
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/140599287/cff037fb93804af28bc3a4f1e91154f8/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=vkscmpmFoM5wq7GnsLmOEgNhvyXe-774kNGNqD0wurE%3D" alt="Lukas" width="60" height="60" style="border-radius:8px;margin:5px;">
74
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/96561218/b0694642d13a49faa75aec9762ff2aeb/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=sLQXomYm1iMYpknvGwKQ49f30TKQ0B1R2W3EZfCJqr8%3D" alt="Ultimate Golf Archives" width="60" height="60" style="border-radius:8px;margin:5px;">
75
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/81275465/1e4148fe9c47452b838949d02dd9a70f/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=uzJzkUq9rte3wx8wDLjGAgvSoxdtZcAnH7HctDhdYEo%3D" alt="Aaron Amortegui" width="60" height="60" style="border-radius:8px;margin:5px;">
76
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/44568304/a9d83a0e786b41b4bdada150f7c9271c/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=SBphTD654nwr-OTrvIBIJBEQho7GE2PtRre8nyaG1Fk%3D" alt="Albert Bukoski" width="60" height="60" style="border-radius:8px;margin:5px;">
77
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/49304261/d0a730de1c3349e585c49288b9f419c6/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=C2BMZ3ci-Ty2nhnSwKZqsR-5hOGsUNDYcvXps0Geq9w%3D" alt="Arvin Flores" width="60" height="60" style="border-radius:8px;margin:5px;">
78
+ <img src="https://c8.patreon.com/3/200/5048649" alt="Ben Ward" width="60" height="60" style="border-radius:8px;margin:5px;">
79
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/130338124/f904a3bb76cd4588ac8d8f595c6cb486/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=k-inISRUtYDu9q7fNAKc3S2S7qcaw26fr1pj7PqU28Q%3D" alt="Bnp" width="60" height="60" style="border-radius:8px;margin:5px;">
80
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/111904990/08b1cf65be6a4de091c9b73b693b3468/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=OAJc9W5Ak0uJfQ2COlo1Upo38K3aj1fMQFCMC7ft5tM%3D" alt="Brian Smith" width="60" height="60" style="border-radius:8px;margin:5px;">
81
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/113207022/d4a67cc113e84fb69032bef71d068720/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=mu-tIg88VwoQdgLEOmxuVkhVm9JT59DdnHXJstmkkLU%3D" alt="Fagem X" width="60" height="60" style="border-radius:8px;margin:5px;">
82
+ <img src="https://c8.patreon.com/3/200/5602036" alt="Kelevra" width="60" height="60" style="border-radius:8px;margin:5px;">
83
+ <img src="https://c8.patreon.com/3/200/358350" alt="L D" width="60" height="60" style="border-radius:8px;margin:5px;">
84
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/159203973/36c817f941ac4fa18103a4b8c0cb9cae/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=9toslDfsO14QyaOiu6vIf--d4marBsWCZWN3gdPqbIU%3D" alt="Marko jak" width="60" height="60" style="border-radius:8px;margin:5px;">
85
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/11198131/e696d9647feb4318bcf16243c2425805/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=o6Hrpzw9rf2Ucd4cZ-hdUkGejLNv44-pqF8smeOF3ts%3D" alt="Nicholas Agranoff" width="60" height="60" style="border-radius:8px;margin:5px;">
86
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/785333/bdb9ede5765d42e5a2021a86eebf0d8f/eyJ3IjoyMDB9/2.jpg?token-time=2145916800&token-hash=dr5eaMg3Ua0wyCy40Qv3F-ZFajWZmuz2fWG55FskREc%3D" alt="Sapjes " width="60" height="60" style="border-radius:8px;margin:5px;">
87
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/44738426/b01ff676da864d4ab9c21f226275b63e/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=54nIkcxFaGszJ3q0jNhtrVSBbV3WNK9e5WX9VzXltYk%3D" alt="Shakeel Saleemi" width="60" height="60" style="border-radius:8px;margin:5px;">
88
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/76566911/6485eaf5ec6249a7b524ee0b979372f0/eyJ3IjoyMDB9/1.jpeg?token-time=2145916800&token-hash=S1QK78ief5byQU7tB_reqnw4V2zhW_cpwTqHThk-tGc%3D" alt="the biitz" width="60" height="60" style="border-radius:8px;margin:5px;">
89
+ <img src="https://c8.patreon.com/3/200/83034" alt="william tatum" width="60" height="60" style="border-radius:8px;margin:5px;">
90
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/32633822/1ab5612efe80417cbebfe91e871fc052/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=RHYMcjr0UGIYw5FBrUfJdKMGuoYWhBQlLIykccEFJvo%3D" alt="Zack Abrams" width="60" height="60" style="border-radius:8px;margin:5px;">
91
+ <img src="https://c10.patreonusercontent.com/4/patreon-media/p/user/138787189/2b5662dcb638466282ac758e3ac651b4/eyJ3IjoyMDB9/1.png?token-time=2145916800&token-hash=IlUAs9JAlVRphfx81V-Jt-nMiSBS8mPewRr9u6pQjaQ%3D" alt="Антон Антонио" width="60" height="60" style="border-radius:8px;margin:5px;">
92
+ </p>
93
+
94
  ---
95
+
96
+
97
+
98
+
99
+
100
+
101
+ ## Installation
102
+
103
+ Requirements:
104
+ - python >3.10
105
+ - Nvidia GPU with enough ram to do what you need
106
+ - python venv
107
+ - git
108
+
109
+
110
+ Linux:
111
+ ```bash
112
+ git clone https://github.com/ostris/ai-toolkit.git
113
+ cd ai-toolkit
114
+ python3 -m venv venv
115
+ source venv/bin/activate
116
+ # install torch first
117
+ pip3 install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
118
+ pip3 install -r requirements.txt
119
+ ```
120
+
121
+ Windows:
122
+ ```bash
123
+ git clone https://github.com/ostris/ai-toolkit.git
124
+ cd ai-toolkit
125
+ python -m venv venv
126
+ .\venv\Scripts\activate
127
+ pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
128
+ pip install -r requirements.txt
129
+ ```
130
+
131
+
132
+ # AI Toolkit UI
133
+
134
+ <img src="https://ostris.com/wp-content/uploads/2025/02/toolkit-ui.jpg" alt="AI Toolkit UI" width="100%">
135
+
136
+ The AI Toolkit UI is a web interface for the AI Toolkit. It allows you to easily start, stop, and monitor jobs. It also allows you to easily train models with a few clicks. It also allows you to set a token for the UI to prevent unauthorized access so it is mostly safe to run on an exposed server.
137
+
138
+ ## Running the UI
139
+
140
+ Requirements:
141
+ - Node.js > 18
142
+
143
+ The UI does not need to be kept running for the jobs to run. It is only needed to start/stop/monitor jobs. The commands below
144
+ will install / update the UI and it's dependencies and start the UI.
145
+
146
+ ```bash
147
+ cd ui
148
+ npm run build_and_start
149
+ ```
150
+
151
+ You can now access the UI at `http://localhost:8675` or `http://<your-ip>:8675` if you are running it on a server.
152
+
153
+ ## Securing the UI
154
+
155
+ If you are hosting the UI on a cloud provider or any network that is not secure, I highly recommend securing it with an auth token.
156
+ You can do this by setting the environment variable `AI_TOOLKIT_AUTH` to super secure password. This token will be required to access
157
+ the UI. You can set this when starting the UI like so:
158
+
159
+ ```bash
160
+ # Linux
161
+ AI_TOOLKIT_AUTH=super_secure_password npm run build_and_start
162
+
163
+ # Windows
164
+ set AI_TOOLKIT_AUTH=super_secure_password && npm run build_and_start
165
+
166
+ # Windows Powershell
167
+ $env:AI_TOOLKIT_AUTH="super_secure_password"; npm run build_and_start
168
+ ```
169
+
170
+
171
+ ## FLUX.1 Training
172
+
173
+ ### Tutorial
174
+
175
+ To get started quickly, check out [@araminta_k](https://x.com/araminta_k) tutorial on [Finetuning Flux Dev on a 3090](https://www.youtube.com/watch?v=HzGW_Kyermg) with 24GB VRAM.
176
+
177
+
178
+ ### Requirements
179
+ You currently need a GPU with **at least 24GB of VRAM** to train FLUX.1. If you are using it as your GPU to control
180
+ your monitors, you probably need to set the flag `low_vram: true` in the config file under `model:`. This will quantize
181
+ the model on CPU and should allow it to train with monitors attached. Users have gotten it to work on Windows with WSL,
182
+ but there are some reports of a bug when running on windows natively.
183
+ I have only tested on linux for now. This is still extremely experimental
184
+ and a lot of quantizing and tricks had to happen to get it to fit on 24GB at all.
185
+
186
+ ### FLUX.1-dev
187
+
188
+ FLUX.1-dev has a non-commercial license. Which means anything you train will inherit the
189
+ non-commercial license. It is also a gated model, so you need to accept the license on HF before using it.
190
+ Otherwise, this will fail. Here are the required steps to setup a license.
191
+
192
+ 1. Sign into HF and accept the model access here [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev)
193
+ 2. Make a file named `.env` in the root on this folder
194
+ 3. [Get a READ key from huggingface](https://huggingface.co/settings/tokens/new?) and add it to the `.env` file like so `HF_TOKEN=your_key_here`
195
+
196
+ ### FLUX.1-schnell
197
+
198
+ FLUX.1-schnell is Apache 2.0. Anything trained on it can be licensed however you want and it does not require a HF_TOKEN to train.
199
+ However, it does require a special adapter to train with it, [ostris/FLUX.1-schnell-training-adapter](https://huggingface.co/ostris/FLUX.1-schnell-training-adapter).
200
+ It is also highly experimental. For best overall quality, training on FLUX.1-dev is recommended.
201
+
202
+ To use it, You just need to add the assistant to the `model` section of your config file like so:
203
+
204
+ ```yaml
205
+ model:
206
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
207
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter"
208
+ is_flux: true
209
+ quantize: true
210
+ ```
211
+
212
+ You also need to adjust your sample steps since schnell does not require as many
213
+
214
+ ```yaml
215
+ sample:
216
+ guidance_scale: 1 # schnell does not do guidance
217
+ sample_steps: 4 # 1 - 4 works well
218
+ ```
219
+
220
+ ### Training
221
+ 1. Copy the example config file located at `config/examples/train_lora_flux_24gb.yaml` (`config/examples/train_lora_flux_schnell_24gb.yaml` for schnell) to the `config` folder and rename it to `whatever_you_want.yml`
222
+ 2. Edit the file following the comments in the file
223
+ 3. Run the file like so `python run.py config/whatever_you_want.yml`
224
+
225
+ A folder with the name and the training folder from the config file will be created when you start. It will have all
226
+ checkpoints and images in it. You can stop the training at any time using ctrl+c and when you resume, it will pick back up
227
+ from the last checkpoint.
228
+
229
+ IMPORTANT. If you press crtl+c while it is saving, it will likely corrupt that checkpoint. So wait until it is done saving
230
+
231
+ ### Need help?
232
+
233
+ Please do not open a bug report unless it is a bug in the code. You are welcome to [Join my Discord](https://discord.gg/VXmU2f5WEU)
234
+ and ask for help there. However, please refrain from PMing me directly with general question or support. Ask in the discord
235
+ and I will answer when I can.
236
+
237
+ ## Gradio UI
238
+
239
+ To get started training locally with a with a custom UI, once you followed the steps above and `ai-toolkit` is installed:
240
+
241
+ ```bash
242
+ cd ai-toolkit #in case you are not yet in the ai-toolkit folder
243
+ huggingface-cli login #provide a `write` token to publish your LoRA at the end
244
+ python flux_train_ui.py
245
+ ```
246
+
247
+ You will instantiate a UI that will let you upload your images, caption them, train and publish your LoRA
248
+ ![image](assets/lora_ease_ui.png)
249
+
250
+
251
+ ## Training in RunPod
252
+ Example RunPod template: **runpod/pytorch:2.2.0-py3.10-cuda12.1.1-devel-ubuntu22.04**
253
+ > You need a minimum of 24GB VRAM, pick a GPU by your preference.
254
+
255
+ #### Example config ($0.5/hr):
256
+ - 1x A40 (48 GB VRAM)
257
+ - 19 vCPU 100 GB RAM
258
+
259
+ #### Custom overrides (you need some storage to clone FLUX.1, store datasets, store trained models and samples):
260
+ - ~120 GB Disk
261
+ - ~120 GB Pod Volume
262
+ - Start Jupyter Notebook
263
+
264
+ ### 1. Setup
265
+ ```
266
+ git clone https://github.com/ostris/ai-toolkit.git
267
+ cd ai-toolkit
268
+ git submodule update --init --recursive
269
+ python -m venv venv
270
+ source venv/bin/activate
271
+ pip install torch
272
+ pip install -r requirements.txt
273
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
274
+ ```
275
+ ### 2. Upload your dataset
276
+ - Create a new folder in the root, name it `dataset` or whatever you like.
277
+ - Drag and drop your .jpg, .jpeg, or .png images and .txt files inside the newly created dataset folder.
278
+
279
+ ### 3. Login into Hugging Face with an Access Token
280
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
281
+ - Run ```huggingface-cli login``` and paste your token.
282
+
283
+ ### 4. Training
284
+ - Copy an example config file located at ```config/examples``` to the config folder and rename it to ```whatever_you_want.yml```.
285
+ - Edit the config following the comments in the file.
286
+ - Change ```folder_path: "/path/to/images/folder"``` to your dataset path like ```folder_path: "/workspace/ai-toolkit/your-dataset"```.
287
+ - Run the file: ```python run.py config/whatever_you_want.yml```.
288
+
289
+ ### Screenshot from RunPod
290
+ <img width="1728" alt="RunPod Training Screenshot" src="https://github.com/user-attachments/assets/53a1b8ef-92fa-4481-81a7-bde45a14a7b5">
291
+
292
+ ## Training in Modal
293
+
294
+ ### 1. Setup
295
+ #### ai-toolkit:
296
+ ```
297
+ git clone https://github.com/ostris/ai-toolkit.git
298
+ cd ai-toolkit
299
+ git submodule update --init --recursive
300
+ python -m venv venv
301
+ source venv/bin/activate
302
+ pip install torch
303
+ pip install -r requirements.txt
304
+ pip install --upgrade accelerate transformers diffusers huggingface_hub #Optional, run it if you run into issues
305
+ ```
306
+ #### Modal:
307
+ - Run `pip install modal` to install the modal Python package.
308
+ - Run `modal setup` to authenticate (if this doesn’t work, try `python -m modal setup`).
309
+
310
+ #### Hugging Face:
311
+ - Get a READ token from [here](https://huggingface.co/settings/tokens) and request access to Flux.1-dev model from [here](https://huggingface.co/black-forest-labs/FLUX.1-dev).
312
+ - Run `huggingface-cli login` and paste your token.
313
+
314
+ ### 2. Upload your dataset
315
+ - Drag and drop your dataset folder containing the .jpg, .jpeg, or .png images and .txt files in `ai-toolkit`.
316
+
317
+ ### 3. Configs
318
+ - Copy an example config file located at ```config/examples/modal``` to the `config` folder and rename it to ```whatever_you_want.yml```.
319
+ - Edit the config following the comments in the file, **<ins>be careful and follow the example `/root/ai-toolkit` paths</ins>**.
320
+
321
+ ### 4. Edit run_modal.py
322
+ - Set your entire local `ai-toolkit` path at `code_mount = modal.Mount.from_local_dir` like:
323
+
324
+ ```
325
+ code_mount = modal.Mount.from_local_dir("/Users/username/ai-toolkit", remote_path="/root/ai-toolkit")
326
+ ```
327
+ - Choose a `GPU` and `Timeout` in `@app.function` _(default is A100 40GB and 2 hour timeout)_.
328
+
329
+ ### 5. Training
330
+ - Run the config file in your terminal: `modal run run_modal.py --config-file-list-str=/root/ai-toolkit/config/whatever_you_want.yml`.
331
+ - You can monitor your training in your local terminal, or on [modal.com](https://modal.com/).
332
+ - Models, samples and optimizer will be stored in `Storage > flux-lora-models`.
333
+
334
+ ### 6. Saving the model
335
+ - Check contents of the volume by running `modal volume ls flux-lora-models`.
336
+ - Download the content by running `modal volume get flux-lora-models your-model-name`.
337
+ - Example: `modal volume get flux-lora-models my_first_flux_lora_v1`.
338
+
339
+ ### Screenshot from Modal
340
+
341
+ <img width="1728" alt="Modal Traning Screenshot" src="https://github.com/user-attachments/assets/7497eb38-0090-49d6-8ad9-9c8ea7b5388b">
342
+
343
  ---
344
 
345
+ ## Dataset Preparation
346
+
347
+ Datasets generally need to be a folder containing images and associated text files. Currently, the only supported
348
+ formats are jpg, jpeg, and png. Webp currently has issues. The text files should be named the same as the images
349
+ but with a `.txt` extension. For example `image2.jpg` and `image2.txt`. The text file should contain only the caption.
350
+ You can add the word `[trigger]` in the caption file and if you have `trigger_word` in your config, it will be automatically
351
+ replaced.
352
+
353
+ Images are never upscaled but they are downscaled and placed in buckets for batching. **You do not need to crop/resize your images**.
354
+ The loader will automatically resize them and can handle varying aspect ratios.
355
+
356
+
357
+ ## Training Specific Layers
358
+
359
+ To train specific layers with LoRA, you can use the `only_if_contains` network kwargs. For instance, if you want to train only the 2 layers
360
+ used by The Last Ben, [mentioned in this post](https://x.com/__TheBen/status/1829554120270987740), you can adjust your
361
+ network kwargs like so:
362
+
363
+ ```yaml
364
+ network:
365
+ type: "lora"
366
+ linear: 128
367
+ linear_alpha: 128
368
+ network_kwargs:
369
+ only_if_contains:
370
+ - "transformer.single_transformer_blocks.7.proj_out"
371
+ - "transformer.single_transformer_blocks.20.proj_out"
372
+ ```
373
+
374
+ The naming conventions of the layers are in diffusers format, so checking the state dict of a model will reveal
375
+ the suffix of the name of the layers you want to train. You can also use this method to only train specific groups of weights.
376
+ For instance to only train the `single_transformer` for FLUX.1, you can use the following:
377
+
378
+ ```yaml
379
+ network:
380
+ type: "lora"
381
+ linear: 128
382
+ linear_alpha: 128
383
+ network_kwargs:
384
+ only_if_contains:
385
+ - "transformer.single_transformer_blocks."
386
+ ```
387
+
388
+ You can also exclude layers by their names by using `ignore_if_contains` network kwarg. So to exclude all the single transformer blocks,
389
+
390
+
391
+ ```yaml
392
+ network:
393
+ type: "lora"
394
+ linear: 128
395
+ linear_alpha: 128
396
+ network_kwargs:
397
+ ignore_if_contains:
398
+ - "transformer.single_transformer_blocks."
399
+ ```
400
+
401
+ `ignore_if_contains` takes priority over `only_if_contains`. So if a weight is covered by both,
402
+ if will be ignored.
403
+
404
+ ## LoKr Training
405
+
406
+ To learn more about LoKr, read more about it at [KohakuBlueleaf/LyCORIS](https://github.com/KohakuBlueleaf/LyCORIS/blob/main/docs/Guidelines.md). To train a LoKr model, you can adjust the network type in the config file like so:
407
+
408
+ ```yaml
409
+ network:
410
+ type: "lokr"
411
+ lokr_full_rank: true
412
+ lokr_factor: 8
413
+ ```
414
+
415
+ Everything else should work the same including layer targeting.
416
+
assets/VAE_test1.jpg ADDED

Git LFS Details

  • SHA256: 879fcb537d039408d7aada297b7397420132684f0106edacc1205fb5cc839476
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
assets/glif.svg ADDED
assets/lora_ease_ui.png ADDED

Git LFS Details

  • SHA256: f647b9fe90cc96db2aa84d1cb25a73b60ffcc5394822f99e9dac27d373f89d79
  • Pointer size: 131 Bytes
  • Size of remote file: 349 kB
build_and_push_docker ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env bash
2
+
3
+ # Extract version from version.py
4
+ if [ -f "version.py" ]; then
5
+ VERSION=$(python3 -c "from version import VERSION; print(VERSION)")
6
+ echo "Building version: $VERSION"
7
+ else
8
+ echo "Error: version.py not found. Please create a version.py file with VERSION defined."
9
+ exit 1
10
+ fi
11
+
12
+ echo "Docker builds from the repo, not this dir. Make sure changes are pushed to the repo."
13
+ echo "Building version: $VERSION and latest"
14
+ # wait 2 seconds
15
+ sleep 2
16
+
17
+ # Build the image with cache busting
18
+ docker build --build-arg CACHEBUST=$(date +%s) -t aitoolkit:$VERSION -f docker/Dockerfile .
19
+
20
+ # Tag with version and latest
21
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:$VERSION
22
+ docker tag aitoolkit:$VERSION ostris/aitoolkit:latest
23
+
24
+ # Push both tags
25
+ echo "Pushing images to Docker Hub..."
26
+ docker push ostris/aitoolkit:$VERSION
27
+ docker push ostris/aitoolkit:latest
28
+
29
+ echo "Successfully built and pushed ostris/aitoolkit:$VERSION and ostris/aitoolkit:latest"
config/examples/extract.example.yml ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # this is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to read and write
4
+ # plus it has comments which is nice for documentation
5
+ job: extract # tells the runner what to do
6
+ config:
7
+ # the name will be used to create a folder in the output folder
8
+ # it will also replace any [name] token in the rest of this config
9
+ name: name_of_your_model
10
+ # can be hugging face model, a .ckpt, or a .safetensors
11
+ base_model: "/path/to/base/model.safetensors"
12
+ # can be hugging face model, a .ckpt, or a .safetensors
13
+ extract_model: "/path/to/model/to/extract/trained.safetensors"
14
+ # we will create folder here with name above so. This will create /path/to/output/folder/name_of_your_model
15
+ output_folder: "/path/to/output/folder"
16
+ is_v2: false
17
+ dtype: fp16 # saved dtype
18
+ device: cpu # cpu, cuda:0, etc
19
+
20
+ # processes can be chained like this to run multiple in a row
21
+ # they must all use same models above, but great for testing different
22
+ # sizes and typed of extractions. It is much faster as we already have the models loaded
23
+ process:
24
+ # process 1
25
+ - type: locon # locon or lora (locon is lycoris)
26
+ filename: "[name]_64_32.safetensors" # will be put in output folder
27
+ dtype: fp16
28
+ mode: fixed
29
+ linear: 64
30
+ conv: 32
31
+
32
+ # process 2
33
+ - type: locon
34
+ output_path: "/absolute/path/for/this/output.safetensors" # can be absolute
35
+ mode: ratio
36
+ linear: 0.2
37
+ conv: 0.2
38
+
39
+ # process 3
40
+ - type: locon
41
+ filename: "[name]_ratio_02.safetensors"
42
+ mode: quantile
43
+ linear: 0.5
44
+ conv: 0.5
45
+
46
+ # process 4
47
+ - type: lora # traditional lora extraction (lierla) with linear layers only
48
+ filename: "[name]_4.safetensors"
49
+ mode: fixed # fixed, ratio, quantile supported for lora as well
50
+ linear: 4 # lora dim or rank
51
+ # no conv for lora
52
+
53
+ # process 5
54
+ - type: lora
55
+ filename: "[name]_q05.safetensors"
56
+ mode: quantile
57
+ linear: 0.5
58
+
59
+ # you can put any information you want here, and it will be saved in the model
60
+ # the below is an example. I recommend doing trigger words at a minimum
61
+ # in the metadata. The software will include this plus some other information
62
+ meta:
63
+ name: "[name]" # [name] gets replaced with the name above
64
+ description: A short description of your model
65
+ trigger_words:
66
+ - put
67
+ - trigger
68
+ - words
69
+ - here
70
+ version: '0.1'
71
+ creator:
72
+ name: Your Name
73
74
+ website: https://yourwebsite.com
75
+ any: All meta data above is arbitrary, it can be whatever you want.
config/examples/generate.example.yaml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+
3
+ job: generate # tells the runner what to do
4
+ config:
5
+ name: "generate" # this is not really used anywhere currently but required by runner
6
+ process:
7
+ # process 1
8
+ - type: to_folder # process images to a folder
9
+ output_folder: "output/gen"
10
+ device: cuda:0 # cpu, cuda:0, etc
11
+ generate:
12
+ # these are your defaults you can override most of them with flags
13
+ sampler: "ddpm" # ignored for now, will add later though ddpm is used regardless for now
14
+ width: 1024
15
+ height: 1024
16
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime"
17
+ seed: -1 # -1 is random
18
+ guidance_scale: 7
19
+ sample_steps: 20
20
+ ext: ".png" # .png, .jpg, .jpeg, .webp
21
+
22
+ # here ate the flags you can use for prompts. Always start with
23
+ # your prompt first then add these flags after. You can use as many
24
+ # like
25
+ # photo of a baseball --n painting, ugly --w 1024 --h 1024 --seed 42 --cfg 7 --steps 20
26
+ # we will try to support all sd-scripts flags where we can
27
+
28
+ # FROM SD-SCRIPTS
29
+ # --n Treat everything until the next option as a negative prompt.
30
+ # --w Specify the width of the generated image.
31
+ # --h Specify the height of the generated image.
32
+ # --d Specify the seed for the generated image.
33
+ # --l Specify the CFG scale for the generated image.
34
+ # --s Specify the number of steps during generation.
35
+
36
+ # OURS and some QOL additions
37
+ # --p2 Prompt for the second text encoder (SDXL only)
38
+ # --n2 Negative prompt for the second text encoder (SDXL only)
39
+ # --gr Specify the guidance rescale for the generated image (SDXL only)
40
+ # --seed Specify the seed for the generated image same as --d
41
+ # --cfg Specify the CFG scale for the generated image same as --l
42
+ # --steps Specify the number of steps during generation same as --s
43
+
44
+ prompt_file: false # if true a txt file will be created next to images with prompt strings used
45
+ # prompts can also be a path to a text file with one prompt per line
46
+ # prompts: "/path/to/prompts.txt"
47
+ prompts:
48
+ - "photo of batman"
49
+ - "photo of superman"
50
+ - "photo of spiderman"
51
+ - "photo of a superhero --n batman superman spiderman"
52
+
53
+ model:
54
+ # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
55
+ # name_or_path: "runwayml/stable-diffusion-v1-5"
56
+ name_or_path: "/mnt/Models/stable-diffusion/models/stable-diffusion/Ostris/Ostris_Real_v1.safetensors"
57
+ is_v2: false # for v2 models
58
+ is_v_pred: false # for v-prediction models (most v2 models)
59
+ is_xl: false # for SDXL models
60
+ dtype: bf16
config/examples/mod_lora_scale.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: mod
3
+ config:
4
+ name: name_of_your_model_v1
5
+ process:
6
+ - type: rescale_lora
7
+ # path to your current lora model
8
+ input_path: "/path/to/lora/lora.safetensors"
9
+ # output path for your new lora model, can be the same as input_path to replace
10
+ output_path: "/path/to/lora/output_lora_v1.safetensors"
11
+ # replaces meta with the meta below (plus minimum meta fields)
12
+ # if false, we will leave the meta alone except for updating hashes (sd-script hashes)
13
+ replace_meta: true
14
+ # how to adjust, we can scale the up_down weights or the alpha
15
+ # up_down is the default and probably the best, they will both net the same outputs
16
+ # would only affect rare NaN cases and maybe merging with old merge tools
17
+ scale_target: 'up_down'
18
+ # precision to save, fp16 is the default and standard
19
+ save_dtype: fp16
20
+ # current_weight is the ideal weight you use as a multiplier when using the lora
21
+ # IE in automatic1111 <lora:my_lora:6.0> the 6.0 is the current_weight
22
+ # you can do negatives here too if you want to flip the lora
23
+ current_weight: 6.0
24
+ # target_weight is the ideal weight you use as a multiplier when using the lora
25
+ # instead of the one above. IE in automatic1111 instead of using <lora:my_lora:6.0>
26
+ # we want to use <lora:my_lora:1.0> so 1.0 is the target_weight
27
+ target_weight: 1.0
28
+
29
+ # base model for the lora
30
+ # this is just used to add meta so automatic111 knows which model it is for
31
+ # assume v1.5 if these are not set
32
+ is_xl: false
33
+ is_v2: false
34
+ meta:
35
+ # this is only used if you set replace_meta to true above
36
+ name: "[name]" # [name] gets replaced with the name above
37
+ description: A short description of your lora
38
+ trigger_words:
39
+ - put
40
+ - trigger
41
+ - words
42
+ - here
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
config/examples/modal/modal_train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the model locally and
65
+ # place it like "/root/ai-toolkit/FLUX.1-dev"
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
config/examples/modal/modal_train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "/root/ai-toolkit/modal_output" # must match MOUNT_DIR from run_modal.py
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ datasets:
25
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
26
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
27
+ # images will automatically be resized and bucketed into the resolution specified
28
+ # on windows, escape back slashes with another backslash so
29
+ # "C:\\path\\to\\images\\folder"
30
+ # your dataset must be placed in /ai-toolkit and /root is for modal to find the dir:
31
+ - folder_path: "/root/ai-toolkit/your-dataset"
32
+ caption_ext: "txt"
33
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
34
+ shuffle_tokens: false # shuffle caption order, split by commas
35
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
36
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
37
+ train:
38
+ batch_size: 1
39
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
40
+ gradient_accumulation_steps: 1
41
+ train_unet: true
42
+ train_text_encoder: false # probably won't work with flux
43
+ gradient_checkpointing: true # need the on unless you have a ton of vram
44
+ noise_scheduler: "flowmatch" # for training only
45
+ optimizer: "adamw8bit"
46
+ lr: 1e-4
47
+ # uncomment this to skip the pre training sample
48
+ # skip_first_sample: true
49
+ # uncomment to completely disable sampling
50
+ # disable_sampling: true
51
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
52
+ # linear_timesteps: true
53
+
54
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
55
+ ema_config:
56
+ use_ema: true
57
+ ema_decay: 0.99
58
+
59
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
60
+ dtype: bf16
61
+ model:
62
+ # huggingface model name or path
63
+ # if you get an error, or get stuck while downloading,
64
+ # check https://github.com/ostris/ai-toolkit/issues/84, download the models locally and
65
+ # place them like "/root/ai-toolkit/FLUX.1-schnell" and "/root/ai-toolkit/FLUX.1-schnell-training-adapter"
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
config/examples/train_flex_redux.yaml ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_redux_finetune_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ adapter:
14
+ type: "redux"
15
+ # you can finetune an existing adapter or start from scratch. Set to null to start from scratch
16
+ name_or_path: '/local/path/to/redux_adapter_to_finetune.safetensors'
17
+ # name_or_path: null
18
+ # image_encoder_path: 'google/siglip-so400m-patch14-384' # Flux.1 redux adapter
19
+ image_encoder_path: 'google/siglip2-so400m-patch16-512' # Flex.1 512 redux adapter
20
+ # image_encoder_arch: 'siglip' # for Flux.1
21
+ image_encoder_arch: 'siglip2'
22
+ # You need a control input for each sample. Best to do squares for both images
23
+ test_img_path:
24
+ - "/path/to/x_01.jpg"
25
+ - "/path/to/x_02.jpg"
26
+ - "/path/to/x_03.jpg"
27
+ - "/path/to/x_04.jpg"
28
+ - "/path/to/x_05.jpg"
29
+ - "/path/to/x_06.jpg"
30
+ - "/path/to/x_07.jpg"
31
+ - "/path/to/x_08.jpg"
32
+ - "/path/to/x_09.jpg"
33
+ - "/path/to/x_10.jpg"
34
+ clip_layer: 'last_hidden_state'
35
+ train: true
36
+ save:
37
+ dtype: bf16 # precision to save
38
+ save_every: 250 # save every this many steps
39
+ max_step_saves_to_keep: 4
40
+ datasets:
41
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
42
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
43
+ # images will automatically be resized and bucketed into the resolution specified
44
+ # on windows, escape back slashes with another backslash so
45
+ # "C:\\path\\to\\images\\folder"
46
+ - folder_path: "/path/to/images/folder"
47
+ # clip_image_path is directory containting your control images. They must have filename as their train image. (extension does not matter)
48
+ # for normal redux, we are just recreating the same image, so you can use the same folder path above
49
+ clip_image_path: "/path/to/control/images/folder"
50
+ caption_ext: "txt"
51
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
52
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
53
+ train:
54
+ # this is what I used for the 24GB card, but feel free to adjust
55
+ # total batch size is 6 here
56
+ batch_size: 3
57
+ gradient_accumulation: 2
58
+
59
+ # captions are not needed for this training, we cache a blank proompt and rely on the vision encoder
60
+ unload_text_encoder: true
61
+
62
+ loss_type: "mse"
63
+ train_unet: true
64
+ train_text_encoder: false
65
+ steps: 4000000 # I set this very high and stop when I like the results
66
+ content_or_style: balanced # content, style, balanced
67
+ gradient_checkpointing: true
68
+ noise_scheduler: "flowmatch" # or "ddpm", "lms", "euler_a"
69
+ timestep_type: "flux_shift"
70
+ optimizer: "adamw8bit"
71
+ lr: 1e-4
72
+
73
+ # this is for Flex.1, comment this out for FLUX.1-dev
74
+ bypass_guidance_embedding: true
75
+
76
+ dtype: bf16
77
+ ema_config:
78
+ use_ema: true
79
+ ema_decay: 0.99
80
+ model:
81
+ name_or_path: "ostris/Flex.1-alpha"
82
+ is_flux: true
83
+ quantize: true
84
+ text_encoder_bits: 8
85
+ sample:
86
+ sampler: "flowmatch" # must match train.noise_scheduler
87
+ sample_every: 250 # sample every this many steps
88
+ width: 1024
89
+ height: 1024
90
+ # I leave half blank to test prompt and unprompted
91
+ prompts:
92
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
93
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
94
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
95
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
96
+ - "a bear building a log cabin in the snow covered mountains"
97
+ - ""
98
+ - ""
99
+ - ""
100
+ - ""
101
+ - ""
102
+ neg: ""
103
+ seed: 42
104
+ walk_seed: true
105
+ guidance_scale: 4
106
+ sample_steps: 25
107
+ network_multiplier: 1.0
108
+
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
config/examples/train_full_fine_tune_flex.yaml ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 48GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_flex_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
37
+ bypass_guidance_embedding: true
38
+
39
+ # can be 'sigmoid', 'linear', or 'lognorm_blend'
40
+ timestep_type: 'sigmoid'
41
+
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flex
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adafactor"
49
+ lr: 3e-5
50
+
51
+ # Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
52
+ # 0.1 is 10% of paramiters active at easc step. Only works with adafactor
53
+
54
+ # do_paramiter_swapping: true
55
+ # paramiter_swapping_factor: 0.9
56
+
57
+ # uncomment this to skip the pre training sample
58
+ # skip_first_sample: true
59
+ # uncomment to completely disable sampling
60
+ # disable_sampling: true
61
+
62
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
63
+ ema_config:
64
+ use_ema: true
65
+ ema_decay: 0.99
66
+
67
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
68
+ dtype: bf16
69
+ model:
70
+ # huggingface model name or path
71
+ name_or_path: "ostris/Flex.1-alpha"
72
+ is_flux: true # flex is flux architecture
73
+ # full finetuning quantized models is a crapshoot and results in subpar outputs
74
+ # quantize: true
75
+ # you can quantize just the T5 text encoder here to save vram
76
+ quantize_te: true
77
+ # only train the transformer blocks
78
+ only_if_contains:
79
+ - "transformer.transformer_blocks."
80
+ - "transformer.single_transformer_blocks."
81
+ sample:
82
+ sampler: "flowmatch" # must match train.noise_scheduler
83
+ sample_every: 250 # sample every this many steps
84
+ width: 1024
85
+ height: 1024
86
+ prompts:
87
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
88
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
89
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
90
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
91
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
92
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
93
+ - "a bear building a log cabin in the snow covered mountains"
94
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
95
+ - "hipster man with a beard, building a chair, in a wood shop"
96
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
97
+ - "a man holding a sign that says, 'this is a sign'"
98
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
99
+ neg: "" # not used on flex
100
+ seed: 42
101
+ walk_seed: true
102
+ guidance_scale: 4
103
+ sample_steps: 25
104
+ # you can add any additional meta info here. [name] is replaced with config name at top
105
+ meta:
106
+ name: "[name]"
107
+ version: '1.0'
config/examples/train_full_fine_tune_lumina.yaml ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 24GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_finetune_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ save:
18
+ dtype: bf16 # precision to save
19
+ save_every: 250 # save every this many steps
20
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
21
+ save_format: 'diffusers' # 'diffusers'
22
+ datasets:
23
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
24
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
25
+ # images will automatically be resized and bucketed into the resolution specified
26
+ # on windows, escape back slashes with another backslash so
27
+ # "C:\\path\\to\\images\\folder"
28
+ - folder_path: "/path/to/images/folder"
29
+ caption_ext: "txt"
30
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
31
+ shuffle_tokens: false # shuffle caption order, split by commas
32
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
33
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
34
+ train:
35
+ batch_size: 1
36
+
37
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
38
+ timestep_type: 'lumina2_shift'
39
+
40
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
41
+ gradient_accumulation: 1
42
+ train_unet: true
43
+ train_text_encoder: false # probably won't work with lumina2
44
+ gradient_checkpointing: true # need the on unless you have a ton of vram
45
+ noise_scheduler: "flowmatch" # for training only
46
+ optimizer: "adafactor"
47
+ lr: 3e-5
48
+
49
+ # Paramiter swapping can reduce vram requirements. Set factor from 1.0 to 0.0.
50
+ # 0.1 is 10% of paramiters active at easc step. Only works with adafactor
51
+
52
+ # do_paramiter_swapping: true
53
+ # paramiter_swapping_factor: 0.9
54
+
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
61
+ # ema_config:
62
+ # use_ema: true
63
+ # ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
70
+ is_lumina2: true # lumina2 architecture
71
+ # you can quantize just the Gemma2 text encoder here to save vram
72
+ quantize_te: true
73
+ sample:
74
+ sampler: "flowmatch" # must match train.noise_scheduler
75
+ sample_every: 250 # sample every this many steps
76
+ width: 1024
77
+ height: 1024
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
82
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
83
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
84
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
85
+ - "a bear building a log cabin in the snow covered mountains"
86
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
87
+ - "hipster man with a beard, building a chair, in a wood shop"
88
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
89
+ - "a man holding a sign that says, 'this is a sign'"
90
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
91
+ neg: ""
92
+ seed: 42
93
+ walk_seed: true
94
+ guidance_scale: 4.0
95
+ sample_steps: 25
96
+ # you can add any additional meta info here. [name] is replaced with config name at top
97
+ meta:
98
+ name: "[name]"
99
+ version: '1.0'
config/examples/train_lora_chroma_24gb.yaml ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_chroma_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # chroma enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with chroma
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for chroma, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # Download the whichever model you prefer from the Chroma repo
66
+ # https://huggingface.co/lodestones/Chroma/tree/main
67
+ # point to it here.
68
+ # name_or_path: "/path/to/chroma/chroma-unlocked-vVERSION.safetensors"
69
+
70
+ # using lodestones/Chroma will automatically use the latest version
71
+ name_or_path: "lodestones/Chroma"
72
+
73
+ # # You can also select a version of Chroma like so
74
+ # name_or_path: "lodestones/Chroma/v28"
75
+
76
+ arch: "chroma"
77
+ quantize: true # run 8bit mixed precision
78
+ sample:
79
+ sampler: "flowmatch" # must match train.noise_scheduler
80
+ sample_every: 250 # sample every this many steps
81
+ width: 1024
82
+ height: 1024
83
+ prompts:
84
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
85
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
86
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
87
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
88
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
89
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
90
+ - "a bear building a log cabin in the snow covered mountains"
91
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
92
+ - "hipster man with a beard, building a chair, in a wood shop"
93
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
94
+ - "a man holding a sign that says, 'this is a sign'"
95
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
96
+ neg: "" # negative prompt, optional
97
+ seed: 42
98
+ walk_seed: true
99
+ guidance_scale: 4
100
+ sample_steps: 25
101
+ # you can add any additional meta info here. [name] is replaced with config name at top
102
+ meta:
103
+ name: "[name]"
104
+ version: '1.0'
config/examples/train_lora_flex2_24gb.yaml ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Note, Flex2 is a highly experimental WIP model. Finetuning a model with built in controls and inpainting has not
2
+ # been done before, so you will be experimenting with me on how to do it. This is my recommended setup, but this is highly
3
+ # subject to change as we learn more about how Flex2 works.
4
+
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_flex2_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # trigger_word: "p3r5on"
20
+ network:
21
+ type: "lora"
22
+ linear: 32
23
+ linear_alpha: 32
24
+ save:
25
+ dtype: float16 # precision to save
26
+ save_every: 250 # save every this many steps
27
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
28
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
29
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
30
+ # hf_repo_id: your-username/your-model-slug
31
+ # hf_private: true #whether the repo is private or public
32
+ datasets:
33
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
34
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
35
+ # images will automatically be resized and bucketed into the resolution specified
36
+ # on windows, escape back slashes with another backslash so
37
+ # "C:\\path\\to\\images\\folder"
38
+ - folder_path: "/path/to/images/folder"
39
+ # Flex2 is trained with controls and inpainting. If you want the model to truely understand how the
40
+ # controls function with your dataset, it is a good idea to keep doing controls during training.
41
+ # this will automatically generate the controls for you before training. The current script is not
42
+ # fully optimized so this could be rather slow for large datasets, but it caches them to disk so it
43
+ # only needs to be done once. If you want to skip this step, you can set the controls to [] and it will
44
+ controls:
45
+ - "depth"
46
+ - "line"
47
+ - "pose"
48
+ - "inpaint"
49
+
50
+ # you can make custom inpainting images as well. These images must be webp or png format with an alpha.
51
+ # just erase the part of the image you want to inpaint and save it as a webp or png. Again, erase your
52
+ # train target. So the person if training a person. The automatic controls above with inpaint will
53
+ # just run a background remover mask and erase the foreground, which works well for subjects.
54
+
55
+ # inpaint_path: "/my/impaint/images"
56
+
57
+ # you can also specify existing control image pairs. It can handle multiple groups and will randomly
58
+ # select one for each step.
59
+
60
+ # control_path:
61
+ # - "/my/custom/control/images"
62
+ # - "/my/custom/control/images2"
63
+
64
+ caption_ext: "txt"
65
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
66
+ resolution: [ 512, 768, 1024 ] # flex2 enjoys multiple resolutions
67
+ train:
68
+ batch_size: 1
69
+ # IMPORTANT! For Flex2, you must bypass the guidance embedder during training
70
+ bypass_guidance_embedding: true
71
+
72
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
73
+ gradient_accumulation: 1
74
+ train_unet: true
75
+ train_text_encoder: false # probably won't work with flex2
76
+ gradient_checkpointing: true # need the on unless you have a ton of vram
77
+ noise_scheduler: "flowmatch" # for training only
78
+ # shift works well for training fast and learning composition and style.
79
+ # for just subject, you may want to change this to sigmoid
80
+ timestep_type: 'shift' # 'linear', 'sigmoid', 'shift'
81
+ optimizer: "adamw8bit"
82
+ lr: 1e-4
83
+
84
+ optimizer_params:
85
+ weight_decay: 1e-5
86
+ # uncomment this to skip the pre training sample
87
+ # skip_first_sample: true
88
+ # uncomment to completely disable sampling
89
+ # disable_sampling: true
90
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
91
+ # linear_timesteps: true
92
+
93
+ # ema will smooth out learning, but could slow it down. Defaults off
94
+ ema_config:
95
+ use_ema: false
96
+ ema_decay: 0.99
97
+
98
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
99
+ dtype: bf16
100
+ model:
101
+ # huggingface model name or path
102
+ name_or_path: "ostris/Flex.2-preview"
103
+ arch: "flex2"
104
+ quantize: true # run 8bit mixed precision
105
+ quantize_te: true
106
+
107
+ # you can pass special training infor for controls to the model here
108
+ # percentages are decimal based so 0.0 is 0% and 1.0 is 100% of the time.
109
+ model_kwargs:
110
+ # inverts the inpainting mask, good to learn outpainting as well, recommended 0.0 for characters
111
+ invert_inpaint_mask_chance: 0.5
112
+ # this will do a normal t2i training step without inpaint when dropped out. REcommended if you want
113
+ # your lora to be able to inference with and without inpainting.
114
+ inpaint_dropout: 0.5
115
+ # randomly drops out the control image. Dropout recvommended if your want it to work without controls as well.
116
+ control_dropout: 0.5
117
+ # does a random inpaint blob. Usually a good idea to keep. Without it, the model will learn to always 100%
118
+ # fill the inpaint area with your subject. This is not always a good thing.
119
+ inpaint_random_chance: 0.5
120
+ # generates random inpaint blobs if you did not provide an inpaint image for your dataset. Inpaint breaks down fast
121
+ # if you are not training with it. Controls are a little more robust and can be left out,
122
+ # but when in doubt, always leave this on
123
+ do_random_inpainting: false
124
+ # does random blurring of the inpaint mask. Helps prevent weird edge artifacts for real workd inpainting. Leave on.
125
+ random_blur_mask: true
126
+ # applies a small amount of random dialition and restriction to the inpaint mask. Helps with edge artifacts.
127
+ # Leave on.
128
+ random_dialate_mask: true
129
+ sample:
130
+ sampler: "flowmatch" # must match train.noise_scheduler
131
+ sample_every: 250 # sample every this many steps
132
+ width: 1024
133
+ height: 1024
134
+ prompts:
135
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
136
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
137
+
138
+ # you can use a single inpaint or single control image on your samples.
139
+ # for controls, the ctrl_idx is 1, the images can be any name and image format.
140
+ # use either a pose/line/depth image or whatever you are training with. An example is
141
+ # - "photo of [trigger] --ctrl_idx 1 --ctrl_img /path/to/control/image.jpg"
142
+
143
+ # for an inpainting image, it must be png/webp. Erase the part of the image you want to inpaint
144
+ # IMPORTANT! the inpaint images must be ctrl_idx 0 and have .inpaint.{ext} in the name for this to work right.
145
+ # - "photo of [trigger] --ctrl_idx 0 --ctrl_img /path/to/inpaint/image.inpaint.png"
146
+
147
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
148
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
149
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
150
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
151
+ - "a bear building a log cabin in the snow covered mountains"
152
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
153
+ - "hipster man with a beard, building a chair, in a wood shop"
154
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
155
+ - "a man holding a sign that says, 'this is a sign'"
156
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
157
+ neg: "" # not used on flex2
158
+ seed: 42
159
+ walk_seed: true
160
+ guidance_scale: 4
161
+ sample_steps: 25
162
+ # you can add any additional meta info here. [name] is replaced with config name at top
163
+ meta:
164
+ name: "[name]"
165
+ version: '1.0'
config/examples/train_lora_flex_24gb.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flex_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flex enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ # IMPORTANT! For Flex, you must bypass the guidance embedder during training
43
+ bypass_guidance_embedding: true
44
+
45
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
46
+ gradient_accumulation: 1
47
+ train_unet: true
48
+ train_text_encoder: false # probably won't work with flex
49
+ gradient_checkpointing: true # need the on unless you have a ton of vram
50
+ noise_scheduler: "flowmatch" # for training only
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ # uncomment this to skip the pre training sample
54
+ # skip_first_sample: true
55
+ # uncomment to completely disable sampling
56
+ # disable_sampling: true
57
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
58
+ # linear_timesteps: true
59
+
60
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
61
+ ema_config:
62
+ use_ema: true
63
+ ema_decay: 0.99
64
+
65
+ # will probably need this if gpu supports it for flex, other dtypes may not work correctly
66
+ dtype: bf16
67
+ model:
68
+ # huggingface model name or path
69
+ name_or_path: "ostris/Flex.1-alpha"
70
+ is_flux: true
71
+ quantize: true # run 8bit mixed precision
72
+ quantize_kwargs:
73
+ exclude:
74
+ - "*time_text_embed*" # exclude the time text embedder from quantization
75
+ sample:
76
+ sampler: "flowmatch" # must match train.noise_scheduler
77
+ sample_every: 250 # sample every this many steps
78
+ width: 1024
79
+ height: 1024
80
+ prompts:
81
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
82
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
83
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
84
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
85
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
86
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
87
+ - "a bear building a log cabin in the snow covered mountains"
88
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
89
+ - "hipster man with a beard, building a chair, in a wood shop"
90
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
91
+ - "a man holding a sign that says, 'this is a sign'"
92
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
93
+ neg: "" # not used on flex
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 4
97
+ sample_steps: 25
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
config/examples/train_lora_flux_24gb.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-dev"
67
+ is_flux: true
68
+ quantize: true # run 8bit mixed precision
69
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: "" # not used on flux
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4
92
+ sample_steps: 20
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
config/examples/train_lora_flux_schnell_24gb.yaml ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_flux_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 16
19
+ linear_alpha: 16
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ - folder_path: "/path/to/images/folder"
35
+ caption_ext: "txt"
36
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
37
+ shuffle_tokens: false # shuffle caption order, split by commas
38
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
39
+ resolution: [ 512, 768, 1024 ] # flux enjoys multiple resolutions
40
+ train:
41
+ batch_size: 1
42
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
43
+ gradient_accumulation_steps: 1
44
+ train_unet: true
45
+ train_text_encoder: false # probably won't work with flux
46
+ gradient_checkpointing: true # need the on unless you have a ton of vram
47
+ noise_scheduler: "flowmatch" # for training only
48
+ optimizer: "adamw8bit"
49
+ lr: 1e-4
50
+ # uncomment this to skip the pre training sample
51
+ # skip_first_sample: true
52
+ # uncomment to completely disable sampling
53
+ # disable_sampling: true
54
+ # uncomment to use new bell curved weighting. Experimental but may produce better results
55
+ # linear_timesteps: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for flux, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "black-forest-labs/FLUX.1-schnell"
67
+ assistant_lora_path: "ostris/FLUX.1-schnell-training-adapter" # Required for flux schnell training
68
+ is_flux: true
69
+ quantize: true # run 8bit mixed precision
70
+ # low_vram is painfully slow to fuse in the adapter avoid it unless absolutely necessary
71
+ # low_vram: true # uncomment this if the GPU is connected to your monitors. It will use less vram to quantize, but is slower.
72
+ sample:
73
+ sampler: "flowmatch" # must match train.noise_scheduler
74
+ sample_every: 250 # sample every this many steps
75
+ width: 1024
76
+ height: 1024
77
+ prompts:
78
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
79
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
80
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
81
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
82
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
83
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
84
+ - "a bear building a log cabin in the snow covered mountains"
85
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
86
+ - "hipster man with a beard, building a chair, in a wood shop"
87
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
88
+ - "a man holding a sign that says, 'this is a sign'"
89
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
90
+ neg: "" # not used on flux
91
+ seed: 42
92
+ walk_seed: true
93
+ guidance_scale: 1 # schnell does not do guidance
94
+ sample_steps: 4 # 1 - 4 works well
95
+ # you can add any additional meta info here. [name] is replaced with config name at top
96
+ meta:
97
+ name: "[name]"
98
+ version: '1.0'
config/examples/train_lora_hidream_48.yaml ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # HiDream training is still highly experimental. The settings here will take ~35.2GB of vram to train.
2
+ # It is not possible to train on a single 24GB card yet, but I am working on it. If you have more VRAM
3
+ # I highly recommend first disabling quantization on the model itself if you can. You can leave the TEs quantized.
4
+ # HiDream has a mixture of experts that may take special training considerations that I do not
5
+ # have implemented properly. The current implementation seems to work well for LoRA training, but
6
+ # may not be effective for longer training runs. The implementation could change in future updates
7
+ # so your results may vary when this happens.
8
+
9
+ ---
10
+ job: extension
11
+ config:
12
+ # this name will be the folder and filename name
13
+ name: "my_first_hidream_lora_v1"
14
+ process:
15
+ - type: 'sd_trainer'
16
+ # root folder to save training sessions/samples/weights
17
+ training_folder: "output"
18
+ # uncomment to see performance stats in the terminal every N steps
19
+ # performance_log_every: 1000
20
+ device: cuda:0
21
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
22
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
23
+ # trigger_word: "p3r5on"
24
+ network:
25
+ type: "lora"
26
+ linear: 32
27
+ linear_alpha: 32
28
+ network_kwargs:
29
+ # it is probably best to ignore the mixture of experts since only 2 are active each block. It works activating it, but I wouldnt.
30
+ # proper training of it is not fully implemented
31
+ ignore_if_contains:
32
+ - "ff_i.experts"
33
+ - "ff_i.gate"
34
+ save:
35
+ dtype: bfloat16 # precision to save
36
+ save_every: 250 # save every this many steps
37
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
38
+ datasets:
39
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
40
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
41
+ # images will automatically be resized and bucketed into the resolution specified
42
+ # on windows, escape back slashes with another backslash so
43
+ # "C:\\path\\to\\images\\folder"
44
+ - folder_path: "/path/to/images/folder"
45
+ caption_ext: "txt"
46
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
47
+ resolution: [ 512, 768, 1024 ] # hidream enjoys multiple resolutions
48
+ train:
49
+ batch_size: 1
50
+ steps: 3000 # total number of steps to train 500 - 4000 is a good range
51
+ gradient_accumulation_steps: 1
52
+ train_unet: true
53
+ train_text_encoder: false # wont work with hidream
54
+ gradient_checkpointing: true # need the on unless you have a ton of vram
55
+ noise_scheduler: "flowmatch" # for training only
56
+ timestep_type: shift # sigmoid, shift, linear
57
+ optimizer: "adamw8bit"
58
+ lr: 2e-4
59
+ # uncomment this to skip the pre training sample
60
+ # skip_first_sample: true
61
+ # uncomment to completely disable sampling
62
+ # disable_sampling: true
63
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
64
+ # linear_timesteps: true
65
+
66
+ # ema will smooth out learning, but could slow it down. Defaults off
67
+ ema_config:
68
+ use_ema: false
69
+ ema_decay: 0.99
70
+
71
+ # will probably need this if gpu supports it for hidream, other dtypes may not work correctly
72
+ dtype: bf16
73
+ model:
74
+ # the transformer will get grabbed from this hf repo
75
+ # warning ONLY train on Full. The dev and fast models are distilled and will break
76
+ name_or_path: "HiDream-ai/HiDream-I1-Full"
77
+ # the extras will be grabbed from this hf repo. (text encoder, vae)
78
+ extras_name_or_path: "HiDream-ai/HiDream-I1-Full"
79
+ arch: "hidream"
80
+ # both need to be quantized to train on 48GB currently
81
+ quantize: true
82
+ quantize_te: true
83
+ model_kwargs:
84
+ # llama is a gated model, It defaults to unsloth version, but you can set the llama path here
85
+ llama_model_path: "unsloth/Meta-Llama-3.1-8B-Instruct"
86
+ sample:
87
+ sampler: "flowmatch" # must match train.noise_scheduler
88
+ sample_every: 250 # sample every this many steps
89
+ width: 1024
90
+ height: 1024
91
+ prompts:
92
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
93
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
94
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
95
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
96
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
97
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
98
+ - "a bear building a log cabin in the snow covered mountains"
99
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
100
+ - "hipster man with a beard, building a chair, in a wood shop"
101
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
102
+ - "a man holding a sign that says, 'this is a sign'"
103
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
104
+ neg: ""
105
+ seed: 42
106
+ walk_seed: true
107
+ guidance_scale: 4
108
+ sample_steps: 25
109
+ # you can add any additional meta info here. [name] is replaced with config name at top
110
+ meta:
111
+ name: "[name]"
112
+ version: '1.0'
config/examples/train_lora_lumina.yaml ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This configuration requires 20GB of VRAM or more to operate
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_lumina_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: bf16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 2 # how many intermittent saves to keep
25
+ save_format: 'diffusers' # 'diffusers'
26
+ datasets:
27
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
28
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
29
+ # images will automatically be resized and bucketed into the resolution specified
30
+ # on windows, escape back slashes with another backslash so
31
+ # "C:\\path\\to\\images\\folder"
32
+ - folder_path: "/path/to/images/folder"
33
+ caption_ext: "txt"
34
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
35
+ shuffle_tokens: false # shuffle caption order, split by commas
36
+ # cache_latents_to_disk: true # leave this true unless you know what you're doing
37
+ resolution: [ 512, 768, 1024 ] # lumina2 enjoys multiple resolutions
38
+ train:
39
+ batch_size: 1
40
+
41
+ # can be 'sigmoid', 'linear', or 'lumina2_shift'
42
+ timestep_type: 'lumina2_shift'
43
+
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with lumina2
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+
57
+ # ema will smooth out learning, but could slow it down. Recommended to leave on if you have the vram
58
+ ema_config:
59
+ use_ema: true
60
+ ema_decay: 0.99
61
+
62
+ # will probably need this if gpu supports it for lumina2, other dtypes may not work correctly
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Alpha-VLLM/Lumina-Image-2.0"
67
+ is_lumina2: true # lumina2 architecture
68
+ # you can quantize just the Gemma2 text encoder here to save vram
69
+ quantize_te: true
70
+ sample:
71
+ sampler: "flowmatch" # must match train.noise_scheduler
72
+ sample_every: 250 # sample every this many steps
73
+ width: 1024
74
+ height: 1024
75
+ prompts:
76
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
77
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
78
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
79
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
80
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
81
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
82
+ - "a bear building a log cabin in the snow covered mountains"
83
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
84
+ - "hipster man with a beard, building a chair, in a wood shop"
85
+ - "photo of a cat that is half black and half orange tabby, split down the middle. The cat has on a blue tophat. They are holding a martini glass with a pink ball of yarn in it with green knitting needles sticking out, in one paw. In the other paw, they are holding a DVD case for a movie titled, \"This is a test\" that has a golden robot on it. In the background is a busy night club with a giant mushroom man dancing with a bear."
86
+ - "a man holding a sign that says, 'this is a sign'"
87
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
88
+ neg: ""
89
+ seed: 42
90
+ walk_seed: true
91
+ guidance_scale: 4.0
92
+ sample_steps: 25
93
+ # you can add any additional meta info here. [name] is replaced with config name at top
94
+ meta:
95
+ name: "[name]"
96
+ version: '1.0'
config/examples/train_lora_sd35_large_24gb.yaml ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # NOTE!! THIS IS CURRENTLY EXPERIMENTAL AND UNDER DEVELOPMENT. SOME THINGS WILL CHANGE
3
+ job: extension
4
+ config:
5
+ # this name will be the folder and filename name
6
+ name: "my_first_sd3l_lora_v1"
7
+ process:
8
+ - type: 'sd_trainer'
9
+ # root folder to save training sessions/samples/weights
10
+ training_folder: "output"
11
+ # uncomment to see performance stats in the terminal every N steps
12
+ # performance_log_every: 1000
13
+ device: cuda:0
14
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
15
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
16
+ # trigger_word: "p3r5on"
17
+ network:
18
+ type: "lora"
19
+ linear: 16
20
+ linear_alpha: 16
21
+ save:
22
+ dtype: float16 # precision to save
23
+ save_every: 250 # save every this many steps
24
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
25
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
26
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
27
+ # hf_repo_id: your-username/your-model-slug
28
+ # hf_private: true #whether the repo is private or public
29
+ datasets:
30
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
31
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
32
+ # images will automatically be resized and bucketed into the resolution specified
33
+ # on windows, escape back slashes with another backslash so
34
+ # "C:\\path\\to\\images\\folder"
35
+ - folder_path: "/path/to/images/folder"
36
+ caption_ext: "txt"
37
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
38
+ shuffle_tokens: false # shuffle caption order, split by commas
39
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
40
+ resolution: [ 1024 ]
41
+ train:
42
+ batch_size: 1
43
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
44
+ gradient_accumulation_steps: 1
45
+ train_unet: true
46
+ train_text_encoder: false # May not fully work with SD3 yet
47
+ gradient_checkpointing: true # need the on unless you have a ton of vram
48
+ noise_scheduler: "flowmatch"
49
+ timestep_type: "linear" # linear or sigmoid
50
+ optimizer: "adamw8bit"
51
+ lr: 1e-4
52
+ # uncomment this to skip the pre training sample
53
+ # skip_first_sample: true
54
+ # uncomment to completely disable sampling
55
+ # disable_sampling: true
56
+ # uncomment to use new vell curved weighting. Experimental but may produce better results
57
+ # linear_timesteps: true
58
+
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+
64
+ # will probably need this if gpu supports it for sd3, other dtypes may not work correctly
65
+ dtype: bf16
66
+ model:
67
+ # huggingface model name or path
68
+ name_or_path: "stabilityai/stable-diffusion-3.5-large"
69
+ is_v3: true
70
+ quantize: true # run 8bit mixed precision
71
+ sample:
72
+ sampler: "flowmatch" # must match train.noise_scheduler
73
+ sample_every: 250 # sample every this many steps
74
+ width: 1024
75
+ height: 1024
76
+ prompts:
77
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
78
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
79
+ - "woman with red hair, playing chess at the park, bomb going off in the background"
80
+ - "a woman holding a coffee cup, in a beanie, sitting at a cafe"
81
+ - "a horse is a DJ at a night club, fish eye lens, smoke machine, lazer lights, holding a martini"
82
+ - "a man showing off his cool new t shirt at the beach, a shark is jumping out of the water in the background"
83
+ - "a bear building a log cabin in the snow covered mountains"
84
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
85
+ - "hipster man with a beard, building a chair, in a wood shop"
86
+ - "photo of a man, white background, medium shot, modeling clothing, studio lighting, white backdrop"
87
+ - "a man holding a sign that says, 'this is a sign'"
88
+ - "a bulldog, in a post apocalyptic world, with a shotgun, in a leather jacket, in a desert, with a motorcycle"
89
+ neg: ""
90
+ seed: 42
91
+ walk_seed: true
92
+ guidance_scale: 4
93
+ sample_steps: 25
94
+ # you can add any additional meta info here. [name] is replaced with config name at top
95
+ meta:
96
+ name: "[name]"
97
+ version: '1.0'
config/examples/train_lora_wan21_14b_24gb.yaml ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # IMPORTANT: The Wan2.1 14B model is huge. This config should work on 24GB GPUs. It cannot
2
+ # support keeping the text encoder on GPU while training with 24GB, so it is only good
3
+ # for training on a single prompt, for example a person with a trigger word.
4
+ # to train on captions, you need more vran for now.
5
+ ---
6
+ job: extension
7
+ config:
8
+ # this name will be the folder and filename name
9
+ name: "my_first_wan21_14b_lora_v1"
10
+ process:
11
+ - type: 'sd_trainer'
12
+ # root folder to save training sessions/samples/weights
13
+ training_folder: "output"
14
+ # uncomment to see performance stats in the terminal every N steps
15
+ # performance_log_every: 1000
16
+ device: cuda:0
17
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
18
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
19
+ # this is probably needed for 24GB cards when offloading TE to CPU
20
+ trigger_word: "p3r5on"
21
+ network:
22
+ type: "lora"
23
+ linear: 32
24
+ linear_alpha: 32
25
+ save:
26
+ dtype: float16 # precision to save
27
+ save_every: 250 # save every this many steps
28
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
29
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
30
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
31
+ # hf_repo_id: your-username/your-model-slug
32
+ # hf_private: true #whether the repo is private or public
33
+ datasets:
34
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
35
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
36
+ # images will automatically be resized and bucketed into the resolution specified
37
+ # on windows, escape back slashes with another backslash so
38
+ # "C:\\path\\to\\images\\folder"
39
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
40
+ # it works well for characters, but not as well for "actions"
41
+ - folder_path: "/path/to/images/folder"
42
+ caption_ext: "txt"
43
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
44
+ shuffle_tokens: false # shuffle caption order, split by commas
45
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
46
+ resolution: [ 632 ] # will be around 480p
47
+ train:
48
+ batch_size: 1
49
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
50
+ gradient_accumulation: 1
51
+ train_unet: true
52
+ train_text_encoder: false # probably won't work with wan
53
+ gradient_checkpointing: true # need the on unless you have a ton of vram
54
+ noise_scheduler: "flowmatch" # for training only
55
+ timestep_type: 'sigmoid'
56
+ optimizer: "adamw8bit"
57
+ lr: 1e-4
58
+ optimizer_params:
59
+ weight_decay: 1e-4
60
+ # uncomment this to skip the pre training sample
61
+ # skip_first_sample: true
62
+ # uncomment to completely disable sampling
63
+ # disable_sampling: true
64
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
65
+ ema_config:
66
+ use_ema: true
67
+ ema_decay: 0.99
68
+ dtype: bf16
69
+ # required for 24GB cards
70
+ # this will encode your trigger word and use those embeddings for every image in the dataset
71
+ unload_text_encoder: true
72
+ model:
73
+ # huggingface model name or path
74
+ name_or_path: "Wan-AI/Wan2.1-T2V-14B-Diffusers"
75
+ arch: 'wan21'
76
+ # these settings will save as much vram as possible
77
+ quantize: true
78
+ quantize_te: true
79
+ low_vram: true
80
+ sample:
81
+ sampler: "flowmatch"
82
+ sample_every: 250 # sample every this many steps
83
+ width: 832
84
+ height: 480
85
+ num_frames: 40
86
+ fps: 15
87
+ # samples take a long time. so use them sparingly
88
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
89
+ prompts:
90
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
91
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
92
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
93
+ neg: ""
94
+ seed: 42
95
+ walk_seed: true
96
+ guidance_scale: 5
97
+ sample_steps: 30
98
+ # you can add any additional meta info here. [name] is replaced with config name at top
99
+ meta:
100
+ name: "[name]"
101
+ version: '1.0'
config/examples/train_lora_wan21_1b_24gb.yaml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ # this name will be the folder and filename name
5
+ name: "my_first_wan21_1b_lora_v1"
6
+ process:
7
+ - type: 'sd_trainer'
8
+ # root folder to save training sessions/samples/weights
9
+ training_folder: "output"
10
+ # uncomment to see performance stats in the terminal every N steps
11
+ # performance_log_every: 1000
12
+ device: cuda:0
13
+ # if a trigger word is specified, it will be added to captions of training data if it does not already exist
14
+ # alternatively, in your captions you can add [trigger] and it will be replaced with the trigger word
15
+ # trigger_word: "p3r5on"
16
+ network:
17
+ type: "lora"
18
+ linear: 32
19
+ linear_alpha: 32
20
+ save:
21
+ dtype: float16 # precision to save
22
+ save_every: 250 # save every this many steps
23
+ max_step_saves_to_keep: 4 # how many intermittent saves to keep
24
+ push_to_hub: false #change this to True to push your trained model to Hugging Face.
25
+ # You can either set up a HF_TOKEN env variable or you'll be prompted to log-in
26
+ # hf_repo_id: your-username/your-model-slug
27
+ # hf_private: true #whether the repo is private or public
28
+ datasets:
29
+ # datasets are a folder of images. captions need to be txt files with the same name as the image
30
+ # for instance image2.jpg and image2.txt. Only jpg, jpeg, and png are supported currently
31
+ # images will automatically be resized and bucketed into the resolution specified
32
+ # on windows, escape back slashes with another backslash so
33
+ # "C:\\path\\to\\images\\folder"
34
+ # AI-Toolkit does not currently support video datasets, we will train on 1 frame at a time
35
+ # it works well for characters, but not as well for "actions"
36
+ - folder_path: "/path/to/images/folder"
37
+ caption_ext: "txt"
38
+ caption_dropout_rate: 0.05 # will drop out the caption 5% of time
39
+ shuffle_tokens: false # shuffle caption order, split by commas
40
+ cache_latents_to_disk: true # leave this true unless you know what you're doing
41
+ resolution: [ 632 ] # will be around 480p
42
+ train:
43
+ batch_size: 1
44
+ steps: 2000 # total number of steps to train 500 - 4000 is a good range
45
+ gradient_accumulation: 1
46
+ train_unet: true
47
+ train_text_encoder: false # probably won't work with wan
48
+ gradient_checkpointing: true # need the on unless you have a ton of vram
49
+ noise_scheduler: "flowmatch" # for training only
50
+ timestep_type: 'sigmoid'
51
+ optimizer: "adamw8bit"
52
+ lr: 1e-4
53
+ optimizer_params:
54
+ weight_decay: 1e-4
55
+ # uncomment this to skip the pre training sample
56
+ # skip_first_sample: true
57
+ # uncomment to completely disable sampling
58
+ # disable_sampling: true
59
+ # ema will smooth out learning, but could slow it down. Recommended to leave on.
60
+ ema_config:
61
+ use_ema: true
62
+ ema_decay: 0.99
63
+ dtype: bf16
64
+ model:
65
+ # huggingface model name or path
66
+ name_or_path: "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
67
+ arch: 'wan21'
68
+ quantize_te: true # saves vram
69
+ sample:
70
+ sampler: "flowmatch"
71
+ sample_every: 250 # sample every this many steps
72
+ width: 832
73
+ height: 480
74
+ num_frames: 40
75
+ fps: 15
76
+ # samples take a long time. so use them sparingly
77
+ # samples will be animated webp files, if you don't see them animated, open in a browser.
78
+ prompts:
79
+ # you can add [trigger] to the prompts here and it will be replaced with the trigger word
80
+ # - "[trigger] holding a sign that says 'I LOVE PROMPTS!'"\
81
+ - "woman playing the guitar, on stage, singing a song, laser lights, punk rocker"
82
+ neg: ""
83
+ seed: 42
84
+ walk_seed: true
85
+ guidance_scale: 5
86
+ sample_steps: 30
87
+ # you can add any additional meta info here. [name] is replaced with config name at top
88
+ meta:
89
+ name: "[name]"
90
+ version: '1.0'
config/examples/train_slider.example.yml ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # This is in yaml format. You can use json if you prefer
3
+ # I like both but yaml is easier to write
4
+ # Plus it has comments which is nice for documentation
5
+ # This is the config I use on my sliders, It is solid and tested
6
+ job: train
7
+ config:
8
+ # the name will be used to create a folder in the output folder
9
+ # it will also replace any [name] token in the rest of this config
10
+ name: detail_slider_v1
11
+ # folder will be created with name above in folder below
12
+ # it can be relative to the project root or absolute
13
+ training_folder: "output/LoRA"
14
+ device: cuda:0 # cpu, cuda:0, etc
15
+ # for tensorboard logging, we will make a subfolder for this job
16
+ log_dir: "output/.tensorboard"
17
+ # you can stack processes for other jobs, It is not tested with sliders though
18
+ # just use one for now
19
+ process:
20
+ - type: slider # tells runner to run the slider process
21
+ # network is the LoRA network for a slider, I recommend to leave this be
22
+ network:
23
+ # network type lierla is traditional LoRA that works everywhere, only linear layers
24
+ type: "lierla"
25
+ # rank / dim of the network. Bigger is not always better. Especially for sliders. 8 is good
26
+ linear: 8
27
+ linear_alpha: 4 # Do about half of rank
28
+ # training config
29
+ train:
30
+ # this is also used in sampling. Stick with ddpm unless you know what you are doing
31
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
32
+ # how many steps to train. More is not always better. I rarely go over 1000
33
+ steps: 500
34
+ # I have had good results with 4e-4 to 1e-4 at 500 steps
35
+ lr: 2e-4
36
+ # enables gradient checkpoint, saves vram, leave it on
37
+ gradient_checkpointing: true
38
+ # train the unet. I recommend leaving this true
39
+ train_unet: true
40
+ # train the text encoder. I don't recommend this unless you have a special use case
41
+ # for sliders we are adjusting representation of the concept (unet),
42
+ # not the description of it (text encoder)
43
+ train_text_encoder: false
44
+ # same as from sd-scripts, not fully tested but should speed up training
45
+ min_snr_gamma: 5.0
46
+ # just leave unless you know what you are doing
47
+ # also supports "dadaptation" but set lr to 1 if you use that,
48
+ # but it learns too fast and I don't recommend it
49
+ optimizer: "adamw"
50
+ # only constant for now
51
+ lr_scheduler: "constant"
52
+ # we randomly denoise random num of steps form 1 to this number
53
+ # while training. Just leave it
54
+ max_denoising_steps: 40
55
+ # works great at 1. I do 1 even with my 4090.
56
+ # higher may not work right with newer single batch stacking code anyway
57
+ batch_size: 1
58
+ # bf16 works best if your GPU supports it (modern)
59
+ dtype: bf16 # fp32, bf16, fp16
60
+ # if you have it, use it. It is faster and better
61
+ # torch 2.0 doesnt need xformers anymore, only use if you have lower version
62
+ # xformers: true
63
+ # I don't recommend using unless you are trying to make a darker lora. Then do 0.1 MAX
64
+ # although, the way we train sliders is comparative, so it probably won't work anyway
65
+ noise_offset: 0.0
66
+ # noise_offset: 0.0357 # SDXL was trained with offset of 0.0357. So use that when training on SDXL
67
+
68
+ # the model to train the LoRA network on
69
+ model:
70
+ # huggingface name, relative prom project path, or absolute path to .safetensors or .ckpt
71
+ name_or_path: "runwayml/stable-diffusion-v1-5"
72
+ is_v2: false # for v2 models
73
+ is_v_pred: false # for v-prediction models (most v2 models)
74
+ # has some issues with the dual text encoder and the way we train sliders
75
+ # it works bit weights need to probably be higher to see it.
76
+ is_xl: false # for SDXL models
77
+
78
+ # saving config
79
+ save:
80
+ dtype: float16 # precision to save. I recommend float16
81
+ save_every: 50 # save every this many steps
82
+ # this will remove step counts more than this number
83
+ # allows you to save more often in case of a crash without filling up your drive
84
+ max_step_saves_to_keep: 2
85
+
86
+ # sampling config
87
+ sample:
88
+ # must match train.noise_scheduler, this is not used here
89
+ # but may be in future and in other processes
90
+ sampler: "ddpm"
91
+ # sample every this many steps
92
+ sample_every: 20
93
+ # image size
94
+ width: 512
95
+ height: 512
96
+ # prompts to use for sampling. Do as many as you want, but it slows down training
97
+ # pick ones that will best represent the concept you are trying to adjust
98
+ # allows some flags after the prompt
99
+ # --m [number] # network multiplier. LoRA weight. -3 for the negative slide, 3 for the positive
100
+ # slide are good tests. will inherit sample.network_multiplier if not set
101
+ # --n [string] # negative prompt, will inherit sample.neg if not set
102
+ # Only 75 tokens allowed currently
103
+ # I like to do a wide positive and negative spread so I can see a good range and stop
104
+ # early if the network is braking down
105
+ prompts:
106
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -5"
107
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m -3"
108
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 3"
109
+ - "a woman in a coffee shop, black hat, blonde hair, blue jacket --m 5"
110
+ - "a golden retriever sitting on a leather couch, --m -5"
111
+ - "a golden retriever sitting on a leather couch --m -3"
112
+ - "a golden retriever sitting on a leather couch --m 3"
113
+ - "a golden retriever sitting on a leather couch --m 5"
114
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -5"
115
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m -3"
116
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 3"
117
+ - "a man with a beard and red flannel shirt, wearing vr goggles, walking into traffic --m 5"
118
+ # negative prompt used on all prompts above as default if they don't have one
119
+ neg: "cartoon, fake, drawing, illustration, cgi, animated, anime, monochrome"
120
+ # seed for sampling. 42 is the answer for everything
121
+ seed: 42
122
+ # walks the seed so s1 is 42, s2 is 43, s3 is 44, etc
123
+ # will start over on next sample_every so s1 is always seed
124
+ # works well if you use same prompt but want different results
125
+ walk_seed: false
126
+ # cfg scale (4 to 10 is good)
127
+ guidance_scale: 7
128
+ # sampler steps (20 to 30 is good)
129
+ sample_steps: 20
130
+ # default network multiplier for all prompts
131
+ # since we are training a slider, I recommend overriding this with --m [number]
132
+ # in the prompts above to get both sides of the slider
133
+ network_multiplier: 1.0
134
+
135
+ # logging information
136
+ logging:
137
+ log_every: 10 # log every this many steps
138
+ use_wandb: false # not supported yet
139
+ verbose: false # probably done need unless you are debugging
140
+
141
+ # slider training config, best for last
142
+ slider:
143
+ # resolutions to train on. [ width, height ]. This is less important for sliders
144
+ # as we are not teaching the model anything it doesn't already know
145
+ # but must be a size it understands [ 512, 512 ] for sd_v1.5 and [ 768, 768 ] for sd_v2.1
146
+ # and [ 1024, 1024 ] for sd_xl
147
+ # you can do as many as you want here
148
+ resolutions:
149
+ - [ 512, 512 ]
150
+ # - [ 512, 768 ]
151
+ # - [ 768, 768 ]
152
+ # slider training uses 4 combined steps for a single round. This will do it in one gradient
153
+ # step. It is highly optimized and shouldn't take anymore vram than doing without it,
154
+ # since we break down batches for gradient accumulation now. so just leave it on.
155
+ batch_full_slide: true
156
+ # These are the concepts to train on. You can do as many as you want here,
157
+ # but they can conflict outweigh each other. Other than experimenting, I recommend
158
+ # just doing one for good results
159
+ targets:
160
+ # target_class is the base concept we are adjusting the representation of
161
+ # for example, if we are adjusting the representation of a person, we would use "person"
162
+ # if we are adjusting the representation of a cat, we would use "cat" It is not
163
+ # a keyword necessarily but what the model understands the concept to represent.
164
+ # "person" will affect men, women, children, etc but will not affect cats, dogs, etc
165
+ # it is the models base general understanding of the concept and everything it represents
166
+ # you can leave it blank to affect everything. In this example, we are adjusting
167
+ # detail, so we will leave it blank to affect everything
168
+ - target_class: ""
169
+ # positive is the prompt for the positive side of the slider.
170
+ # It is the concept that will be excited and amplified in the model when we slide the slider
171
+ # to the positive side and forgotten / inverted when we slide
172
+ # the slider to the negative side. It is generally best to include the target_class in
173
+ # the prompt. You want it to be the extreme of what you want to train on. For example,
174
+ # if you want to train on fat people, you would use "an extremely fat, morbidly obese person"
175
+ # as the prompt. Not just "fat person"
176
+ # max 75 tokens for now
177
+ positive: "high detail, 8k, intricate, detailed, high resolution, high res, high quality"
178
+ # negative is the prompt for the negative side of the slider and works the same as positive
179
+ # it does not necessarily work the same as a negative prompt when generating images
180
+ # these need to be polar opposites.
181
+ # max 76 tokens for now
182
+ negative: "blurry, boring, fuzzy, low detail, low resolution, low res, low quality"
183
+ # the loss for this target is multiplied by this number.
184
+ # if you are doing more than one target it may be good to set less important ones
185
+ # to a lower number like 0.1 so they don't outweigh the primary target
186
+ weight: 1.0
187
+ # shuffle the prompts split by the comma. We will run every combination randomly
188
+ # this will make the LoRA more robust. You probably want this on unless prompt order
189
+ # is important for some reason
190
+ shuffle: true
191
+
192
+
193
+ # anchors are prompts that we will try to hold on to while training the slider
194
+ # these are NOT necessary and can prevent the slider from converging if not done right
195
+ # leave them off if you are having issues, but they can help lock the network
196
+ # on certain concepts to help prevent catastrophic forgetting
197
+ # you want these to generate an image that is not your target_class, but close to it
198
+ # is fine as long as it does not directly overlap it.
199
+ # For example, if you are training on a person smiling,
200
+ # you could use "a person with a face mask" as an anchor. It is a person, the image is the same
201
+ # regardless if they are smiling or not, however, the closer the concept is to the target_class
202
+ # the less the multiplier needs to be. Keep multipliers less than 1.0 for anchors usually
203
+ # for close concepts, you want to be closer to 0.1 or 0.2
204
+ # these will slow down training. I am leaving them off for the demo
205
+
206
+ # anchors:
207
+ # - prompt: "a woman"
208
+ # neg_prompt: "animal"
209
+ # # the multiplier applied to the LoRA when this is run.
210
+ # # higher will give it more weight but also help keep the lora from collapsing
211
+ # multiplier: 1.0
212
+ # - prompt: "a man"
213
+ # neg_prompt: "animal"
214
+ # multiplier: 1.0
215
+ # - prompt: "a person"
216
+ # neg_prompt: "animal"
217
+ # multiplier: 1.0
218
+
219
+ # You can put any information you want here, and it will be saved in the model.
220
+ # The below is an example, but you can put your grocery list in it if you want.
221
+ # It is saved in the model so be aware of that. The software will include this
222
+ # plus some other information for you automatically
223
+ meta:
224
+ # [name] gets replaced with the name above
225
+ name: "[name]"
226
+ # version: '1.0'
227
+ # creator:
228
+ # name: Your Name
229
+ # email: [email protected]
230
+ # website: https://your.website
docker-compose.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: "3.8"
2
+
3
+ services:
4
+ ai-toolkit:
5
+ image: ostris/aitoolkit:latest
6
+ restart: unless-stopped
7
+ ports:
8
+ - "8675:8675"
9
+ volumes:
10
+ - ~/.cache/huggingface/hub:/root/.cache/huggingface/hub
11
+ - ./aitk_db.db:/app/ai-toolkit/aitk_db.db
12
+ - ./datasets:/app/ai-toolkit/datasets
13
+ - ./output:/app/ai-toolkit/output
14
+ - ./config:/app/ai-toolkit/config
15
+ environment:
16
+ - AI_TOOLKIT_AUTH=${AI_TOOLKIT_AUTH:-password}
17
+ - NODE_ENV=production
18
+ - TZ=UTC
19
+ deploy:
20
+ resources:
21
+ reservations:
22
+ devices:
23
+ - driver: nvidia
24
+ count: all
25
+ capabilities: [gpu]
docker/Dockerfile ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.6.3-devel-ubuntu22.04
2
+
3
+ LABEL authors="jaret"
4
+
5
+ # Set noninteractive to avoid timezone prompts
6
+ ENV DEBIAN_FRONTEND=noninteractive
7
+
8
+ # Install dependencies
9
+ RUN apt-get update && apt-get install --no-install-recommends -y \
10
+ git \
11
+ curl \
12
+ build-essential \
13
+ cmake \
14
+ wget \
15
+ python3.10 \
16
+ python3-pip \
17
+ python3-dev \
18
+ python3-setuptools \
19
+ python3-wheel \
20
+ python3-venv \
21
+ ffmpeg \
22
+ tmux \
23
+ htop \
24
+ nvtop \
25
+ python3-opencv \
26
+ openssh-client \
27
+ openssh-server \
28
+ openssl \
29
+ rsync \
30
+ unzip \
31
+ && apt-get clean \
32
+ && rm -rf /var/lib/apt/lists/*
33
+
34
+ # Install nodejs
35
+ WORKDIR /tmp
36
+ RUN curl -sL https://deb.nodesource.com/setup_23.x -o nodesource_setup.sh && \
37
+ bash nodesource_setup.sh && \
38
+ apt-get update && \
39
+ apt-get install -y nodejs && \
40
+ apt-get clean && \
41
+ rm -rf /var/lib/apt/lists/*
42
+
43
+ WORKDIR /app
44
+
45
+ # Set aliases for python and pip
46
+ RUN ln -s /usr/bin/python3 /usr/bin/python
47
+
48
+ # install pytorch before cache bust to avoid redownloading pytorch
49
+ RUN pip install --no-cache-dir torch==2.6.0 torchvision==0.21.0 --index-url https://download.pytorch.org/whl/cu126
50
+
51
+ # Fix cache busting by moving CACHEBUST to right before git clone
52
+ ARG CACHEBUST=1234
53
+ RUN echo "Cache bust: ${CACHEBUST}" && \
54
+ git clone https://github.com/ostris/ai-toolkit.git && \
55
+ cd ai-toolkit
56
+
57
+ WORKDIR /app/ai-toolkit
58
+
59
+ # Install Python dependencies
60
+ RUN pip install --no-cache-dir -r requirements.txt && \
61
+ pip install flash-attn --no-build-isolation --no-cache-dir
62
+
63
+ # Build UI
64
+ WORKDIR /app/ai-toolkit/ui
65
+ RUN npm install && \
66
+ npm run build && \
67
+ npm run update_db
68
+
69
+ # Expose port (assuming the application runs on port 3000)
70
+ EXPOSE 8675
71
+
72
+ WORKDIR /
73
+
74
+ COPY docker/start.sh /start.sh
75
+ RUN chmod +x /start.sh
76
+
77
+ CMD ["/start.sh"]
docker/start.sh ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ set -e # Exit the script if any statement returns a non-true return value
3
+
4
+ # ref https://github.com/runpod/containers/blob/main/container-template/start.sh
5
+
6
+ # ---------------------------------------------------------------------------- #
7
+ # Function Definitions #
8
+ # ---------------------------------------------------------------------------- #
9
+
10
+
11
+ # Setup ssh
12
+ setup_ssh() {
13
+ if [[ $PUBLIC_KEY ]]; then
14
+ echo "Setting up SSH..."
15
+ mkdir -p ~/.ssh
16
+ echo "$PUBLIC_KEY" >> ~/.ssh/authorized_keys
17
+ chmod 700 -R ~/.ssh
18
+
19
+ if [ ! -f /etc/ssh/ssh_host_rsa_key ]; then
20
+ ssh-keygen -t rsa -f /etc/ssh/ssh_host_rsa_key -q -N ''
21
+ echo "RSA key fingerprint:"
22
+ ssh-keygen -lf /etc/ssh/ssh_host_rsa_key.pub
23
+ fi
24
+
25
+ if [ ! -f /etc/ssh/ssh_host_dsa_key ]; then
26
+ ssh-keygen -t dsa -f /etc/ssh/ssh_host_dsa_key -q -N ''
27
+ echo "DSA key fingerprint:"
28
+ ssh-keygen -lf /etc/ssh/ssh_host_dsa_key.pub
29
+ fi
30
+
31
+ if [ ! -f /etc/ssh/ssh_host_ecdsa_key ]; then
32
+ ssh-keygen -t ecdsa -f /etc/ssh/ssh_host_ecdsa_key -q -N ''
33
+ echo "ECDSA key fingerprint:"
34
+ ssh-keygen -lf /etc/ssh/ssh_host_ecdsa_key.pub
35
+ fi
36
+
37
+ if [ ! -f /etc/ssh/ssh_host_ed25519_key ]; then
38
+ ssh-keygen -t ed25519 -f /etc/ssh/ssh_host_ed25519_key -q -N ''
39
+ echo "ED25519 key fingerprint:"
40
+ ssh-keygen -lf /etc/ssh/ssh_host_ed25519_key.pub
41
+ fi
42
+
43
+ service ssh start
44
+
45
+ echo "SSH host keys:"
46
+ for key in /etc/ssh/*.pub; do
47
+ echo "Key: $key"
48
+ ssh-keygen -lf $key
49
+ done
50
+ fi
51
+ }
52
+
53
+ # Export env vars
54
+ export_env_vars() {
55
+ echo "Exporting environment variables..."
56
+ printenv | grep -E '^RUNPOD_|^PATH=|^_=' | awk -F = '{ print "export " $1 "=\"" $2 "\"" }' >> /etc/rp_environment
57
+ echo 'source /etc/rp_environment' >> ~/.bashrc
58
+ }
59
+
60
+ # ---------------------------------------------------------------------------- #
61
+ # Main Program #
62
+ # ---------------------------------------------------------------------------- #
63
+
64
+
65
+ echo "Pod Started"
66
+
67
+ setup_ssh
68
+ export_env_vars
69
+ echo "Starting AI Toolkit UI..."
70
+ cd /app/ai-toolkit/ui && npm run start
extensions/example/ExampleMergeModels.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gc
3
+ from collections import OrderedDict
4
+ from typing import TYPE_CHECKING
5
+ from jobs.process import BaseExtensionProcess
6
+ from toolkit.config_modules import ModelConfig
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ from toolkit.train_tools import get_torch_dtype
9
+ from tqdm import tqdm
10
+
11
+ # Type check imports. Prevents circular imports
12
+ if TYPE_CHECKING:
13
+ from jobs import ExtensionJob
14
+
15
+
16
+ # extend standard config classes to add weight
17
+ class ModelInputConfig(ModelConfig):
18
+ def __init__(self, **kwargs):
19
+ super().__init__(**kwargs)
20
+ self.weight = kwargs.get('weight', 1.0)
21
+ # overwrite default dtype unless user specifies otherwise
22
+ # float 32 will give up better precision on the merging functions
23
+ self.dtype: str = kwargs.get('dtype', 'float32')
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ # this is our main class process
32
+ class ExampleMergeModels(BaseExtensionProcess):
33
+ def __init__(
34
+ self,
35
+ process_id: int,
36
+ job: 'ExtensionJob',
37
+ config: OrderedDict
38
+ ):
39
+ super().__init__(process_id, job, config)
40
+ # this is the setup process, do not do process intensive stuff here, just variable setup and
41
+ # checking requirements. This is called before the run() function
42
+ # no loading models or anything like that, it is just for setting up the process
43
+ # all of your process intensive stuff should be done in the run() function
44
+ # config will have everything from the process item in the config file
45
+
46
+ # convince methods exist on BaseProcess to get config values
47
+ # if required is set to true and the value is not found it will throw an error
48
+ # you can pass a default value to get_conf() as well if it was not in the config file
49
+ # as well as a type to cast the value to
50
+ self.save_path = self.get_conf('save_path', required=True)
51
+ self.save_dtype = self.get_conf('save_dtype', default='float16', as_type=get_torch_dtype)
52
+ self.device = self.get_conf('device', default='cpu', as_type=torch.device)
53
+
54
+ # build models to merge list
55
+ models_to_merge = self.get_conf('models_to_merge', required=True, as_type=list)
56
+ # build list of ModelInputConfig objects. I find it is a good idea to make a class for each config
57
+ # this way you can add methods to it and it is easier to read and code. There are a lot of
58
+ # inbuilt config classes located in toolkit.config_modules as well
59
+ self.models_to_merge = [ModelInputConfig(**model) for model in models_to_merge]
60
+ # setup is complete. Don't load anything else here, just setup variables and stuff
61
+
62
+ # this is the entire run process be sure to call super().run() first
63
+ def run(self):
64
+ # always call first
65
+ super().run()
66
+ print(f"Running process: {self.__class__.__name__}")
67
+
68
+ # let's adjust our weights first to normalize them so the total is 1.0
69
+ total_weight = sum([model.weight for model in self.models_to_merge])
70
+ weight_adjust = 1.0 / total_weight
71
+ for model in self.models_to_merge:
72
+ model.weight *= weight_adjust
73
+
74
+ output_model: StableDiffusion = None
75
+ # let's do the merge, it is a good idea to use tqdm to show progress
76
+ for model_config in tqdm(self.models_to_merge, desc="Merging models"):
77
+ # setup model class with our helper class
78
+ sd_model = StableDiffusion(
79
+ device=self.device,
80
+ model_config=model_config,
81
+ dtype="float32"
82
+ )
83
+ # load the model
84
+ sd_model.load_model()
85
+
86
+ # adjust the weight of the text encoder
87
+ if isinstance(sd_model.text_encoder, list):
88
+ # sdxl model
89
+ for text_encoder in sd_model.text_encoder:
90
+ for key, value in text_encoder.state_dict().items():
91
+ value *= model_config.weight
92
+ else:
93
+ # normal model
94
+ for key, value in sd_model.text_encoder.state_dict().items():
95
+ value *= model_config.weight
96
+ # adjust the weights of the unet
97
+ for key, value in sd_model.unet.state_dict().items():
98
+ value *= model_config.weight
99
+
100
+ if output_model is None:
101
+ # use this one as the base
102
+ output_model = sd_model
103
+ else:
104
+ # merge the models
105
+ # text encoder
106
+ if isinstance(output_model.text_encoder, list):
107
+ # sdxl model
108
+ for i, text_encoder in enumerate(output_model.text_encoder):
109
+ for key, value in text_encoder.state_dict().items():
110
+ value += sd_model.text_encoder[i].state_dict()[key]
111
+ else:
112
+ # normal model
113
+ for key, value in output_model.text_encoder.state_dict().items():
114
+ value += sd_model.text_encoder.state_dict()[key]
115
+ # unet
116
+ for key, value in output_model.unet.state_dict().items():
117
+ value += sd_model.unet.state_dict()[key]
118
+
119
+ # remove the model to free memory
120
+ del sd_model
121
+ flush()
122
+
123
+ # merge loop is done, let's save the model
124
+ print(f"Saving merged model to {self.save_path}")
125
+ output_model.save(self.save_path, meta=self.meta, save_dtype=self.save_dtype)
126
+ print(f"Saved merged model to {self.save_path}")
127
+ # do cleanup here
128
+ del output_model
129
+ flush()
extensions/example/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # We make a subclass of Extension
6
+ class ExampleMergeExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "example_merge_extension"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Example Merge Extension"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ExampleMergeModels import ExampleMergeModels
19
+ return ExampleMergeModels
20
+
21
+
22
+ AI_TOOLKIT_EXTENSIONS = [
23
+ # you can put a list of extensions here
24
+ ExampleMergeExtension
25
+ ]
extensions/example/config/config.example.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ # Always include at least one example config file to show how to use your extension.
3
+ # use plenty of comments so users know how to use it and what everything does
4
+
5
+ # all extensions will use this job name
6
+ job: extension
7
+ config:
8
+ name: 'my_awesome_merge'
9
+ process:
10
+ # Put your example processes here. This will be passed
11
+ # to your extension process in the config argument.
12
+ # the type MUST match your extension uid
13
+ - type: "example_merge_extension"
14
+ # save path for the merged model
15
+ save_path: "output/merge/[name].safetensors"
16
+ # save type
17
+ dtype: fp16
18
+ # device to run it on
19
+ device: cuda:0
20
+ # input models can only be SD1.x and SD2.x models for this example (currently)
21
+ models_to_merge:
22
+ # weights are relative, total weights will be normalized
23
+ # for example. If you have 2 models with weight 1.0, they will
24
+ # both be weighted 0.5. If you have 1 model with weight 1.0 and
25
+ # another with weight 2.0, the first will be weighted 1/3 and the
26
+ # second will be weighted 2/3
27
+ - name_or_path: "input/model1.safetensors"
28
+ weight: 1.0
29
+ - name_or_path: "input/model2.safetensors"
30
+ weight: 1.0
31
+ - name_or_path: "input/model3.safetensors"
32
+ weight: 0.3
33
+ - name_or_path: "input/model4.safetensors"
34
+ weight: 1.0
35
+
36
+
37
+ # you can put any information you want here, and it will be saved in the model
38
+ # the below is an example. I recommend doing trigger words at a minimum
39
+ # in the metadata. The software will include this plus some other information
40
+ meta:
41
+ name: "[name]" # [name] gets replaced with the name above
42
+ description: A short description of your model
43
+ version: '0.1'
44
+ creator:
45
+ name: Your Name
46
47
+ website: https://yourwebsite.com
48
+ any: All meta data above is arbitrary, it can be whatever you want.
extensions_built_in/advanced_generator/Img2ImgGenerator.py ADDED
@@ -0,0 +1,256 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import random
4
+ from collections import OrderedDict
5
+ from typing import List
6
+
7
+ import numpy as np
8
+ from PIL import Image
9
+ from diffusers import T2IAdapter
10
+ from diffusers.utils.torch_utils import randn_tensor
11
+ from torch.utils.data import DataLoader
12
+ from diffusers import StableDiffusionXLImg2ImgPipeline, PixArtSigmaPipeline
13
+ from tqdm import tqdm
14
+
15
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
16
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
17
+ from toolkit.sampler import get_sampler
18
+ from toolkit.stable_diffusion_model import StableDiffusion
19
+ import gc
20
+ import torch
21
+ from jobs.process import BaseExtensionProcess
22
+ from toolkit.data_loader import get_dataloader_from_datasets
23
+ from toolkit.train_tools import get_torch_dtype
24
+ from controlnet_aux.midas import MidasDetector
25
+ from diffusers.utils import load_image
26
+ from torchvision.transforms import ToTensor
27
+
28
+
29
+ def flush():
30
+ torch.cuda.empty_cache()
31
+ gc.collect()
32
+
33
+
34
+
35
+
36
+
37
+ class GenerateConfig:
38
+
39
+ def __init__(self, **kwargs):
40
+ self.prompts: List[str]
41
+ self.sampler = kwargs.get('sampler', 'ddpm')
42
+ self.neg = kwargs.get('neg', '')
43
+ self.seed = kwargs.get('seed', -1)
44
+ self.walk_seed = kwargs.get('walk_seed', False)
45
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
46
+ self.sample_steps = kwargs.get('sample_steps', 20)
47
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
48
+ self.ext = kwargs.get('ext', 'png')
49
+ self.denoise_strength = kwargs.get('denoise_strength', 0.5)
50
+ self.trigger_word = kwargs.get('trigger_word', None)
51
+
52
+
53
+ class Img2ImgGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.copy_inputs_to = self.get_conf('copy_inputs_to', None)
59
+ self.device = self.get_conf('device', 'cuda')
60
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
61
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
62
+ self.is_latents_cached = True
63
+ raw_datasets = self.get_conf('datasets', None)
64
+ if raw_datasets is not None and len(raw_datasets) > 0:
65
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
66
+ self.datasets = None
67
+ self.datasets_reg = None
68
+ self.dtype = self.get_conf('dtype', 'float16')
69
+ self.torch_dtype = get_torch_dtype(self.dtype)
70
+ self.params = []
71
+ if raw_datasets is not None and len(raw_datasets) > 0:
72
+ for raw_dataset in raw_datasets:
73
+ dataset = DatasetConfig(**raw_dataset)
74
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
75
+ if not is_caching:
76
+ self.is_latents_cached = False
77
+ if dataset.is_reg:
78
+ if self.datasets_reg is None:
79
+ self.datasets_reg = []
80
+ self.datasets_reg.append(dataset)
81
+ else:
82
+ if self.datasets is None:
83
+ self.datasets = []
84
+ self.datasets.append(dataset)
85
+
86
+ self.progress_bar = None
87
+ self.sd = StableDiffusion(
88
+ device=self.device,
89
+ model_config=self.model_config,
90
+ dtype=self.dtype,
91
+ )
92
+ print(f"Using device {self.device}")
93
+ self.data_loader: DataLoader = None
94
+ self.adapter: T2IAdapter = None
95
+
96
+ def to_pil(self, img):
97
+ # image comes in -1 to 1. convert to a PIL RGB image
98
+ img = (img + 1) / 2
99
+ img = img.clamp(0, 1)
100
+ img = img[0].permute(1, 2, 0).cpu().numpy()
101
+ img = (img * 255).astype(np.uint8)
102
+ image = Image.fromarray(img)
103
+ return image
104
+
105
+ def run(self):
106
+ with torch.no_grad():
107
+ super().run()
108
+ print("Loading model...")
109
+ self.sd.load_model()
110
+ device = torch.device(self.device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLImg2ImgPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ ).to(device, dtype=self.torch_dtype)
122
+ elif self.model_config.is_pixart:
123
+ pipe = self.sd.pipeline.to(device, dtype=self.torch_dtype)
124
+ else:
125
+ raise NotImplementedError("Only XL models are supported")
126
+ pipe.set_progress_bar_config(disable=True)
127
+
128
+ # pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
129
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
130
+
131
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
132
+
133
+ num_batches = len(self.data_loader)
134
+ pbar = tqdm(total=num_batches, desc="Generating images")
135
+ seed = self.generate_config.seed
136
+ # load images from datasets, use tqdm
137
+ for i, batch in enumerate(self.data_loader):
138
+ batch: DataLoaderBatchDTO = batch
139
+
140
+ gen_seed = seed if seed > 0 else random.randint(0, 2 ** 32 - 1)
141
+ generator = torch.manual_seed(gen_seed)
142
+
143
+ file_item: FileItemDTO = batch.file_items[0]
144
+ img_path = file_item.path
145
+ img_filename = os.path.basename(img_path)
146
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
147
+ img_filename = img_filename_no_ext + '.' + self.generate_config.ext
148
+ output_path = os.path.join(self.output_folder, img_filename)
149
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
150
+
151
+ if self.copy_inputs_to is not None:
152
+ output_inputs_path = os.path.join(self.copy_inputs_to, img_filename)
153
+ output_inputs_caption_path = os.path.join(self.copy_inputs_to, img_filename_no_ext + '.txt')
154
+ else:
155
+ output_inputs_path = None
156
+ output_inputs_caption_path = None
157
+
158
+ caption = batch.get_caption_list()[0]
159
+ if self.generate_config.trigger_word is not None:
160
+ caption = caption.replace('[trigger]', self.generate_config.trigger_word)
161
+
162
+ img: torch.Tensor = batch.tensor.clone()
163
+ image = self.to_pil(img)
164
+
165
+ # image.save(output_depth_path)
166
+ if self.model_config.is_pixart:
167
+ pipe: PixArtSigmaPipeline = pipe
168
+
169
+ # Encode the full image once
170
+ encoded_image = pipe.vae.encode(
171
+ pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype))
172
+ if hasattr(encoded_image, "latent_dist"):
173
+ latents = encoded_image.latent_dist.sample(generator)
174
+ elif hasattr(encoded_image, "latents"):
175
+ latents = encoded_image.latents
176
+ else:
177
+ raise AttributeError("Could not access latents of provided encoder_output")
178
+ latents = pipe.vae.config.scaling_factor * latents
179
+
180
+ # latents = self.sd.encode_images(img)
181
+
182
+ # self.sd.noise_scheduler.set_timesteps(self.generate_config.sample_steps)
183
+ # start_step = math.floor(self.generate_config.sample_steps * self.generate_config.denoise_strength)
184
+ # timestep = self.sd.noise_scheduler.timesteps[start_step].unsqueeze(0)
185
+ # timestep = timestep.to(device, dtype=torch.int32)
186
+ # latent = latent.to(device, dtype=self.torch_dtype)
187
+ # noise = torch.randn_like(latent, device=device, dtype=self.torch_dtype)
188
+ # latent = self.sd.add_noise(latent, noise, timestep)
189
+ # timesteps_to_use = self.sd.noise_scheduler.timesteps[start_step + 1:]
190
+ batch_size = 1
191
+ num_images_per_prompt = 1
192
+
193
+ shape = (batch_size, pipe.transformer.config.in_channels, image.height // pipe.vae_scale_factor,
194
+ image.width // pipe.vae_scale_factor)
195
+ noise = randn_tensor(shape, generator=generator, device=pipe.device, dtype=pipe.dtype)
196
+
197
+ # noise = torch.randn_like(latents, device=device, dtype=self.torch_dtype)
198
+ num_inference_steps = self.generate_config.sample_steps
199
+ strength = self.generate_config.denoise_strength
200
+ # Get timesteps
201
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
202
+ t_start = max(num_inference_steps - init_timestep, 0)
203
+ pipe.scheduler.set_timesteps(num_inference_steps, device="cpu")
204
+ timesteps = pipe.scheduler.timesteps[t_start:]
205
+ timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
206
+ latents = pipe.scheduler.add_noise(latents, noise, timestep)
207
+
208
+ gen_images = pipe.__call__(
209
+ prompt=caption,
210
+ negative_prompt=self.generate_config.neg,
211
+ latents=latents,
212
+ timesteps=timesteps,
213
+ width=image.width,
214
+ height=image.height,
215
+ num_inference_steps=num_inference_steps,
216
+ num_images_per_prompt=num_images_per_prompt,
217
+ guidance_scale=self.generate_config.guidance_scale,
218
+ # strength=self.generate_config.denoise_strength,
219
+ use_resolution_binning=False,
220
+ output_type="np"
221
+ ).images[0]
222
+ gen_images = (gen_images * 255).clip(0, 255).astype(np.uint8)
223
+ gen_images = Image.fromarray(gen_images)
224
+ else:
225
+ pipe: StableDiffusionXLImg2ImgPipeline = pipe
226
+
227
+ gen_images = pipe.__call__(
228
+ prompt=caption,
229
+ negative_prompt=self.generate_config.neg,
230
+ image=image,
231
+ num_inference_steps=self.generate_config.sample_steps,
232
+ guidance_scale=self.generate_config.guidance_scale,
233
+ strength=self.generate_config.denoise_strength,
234
+ ).images[0]
235
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
236
+ gen_images.save(output_path)
237
+
238
+ # save caption
239
+ with open(output_caption_path, 'w') as f:
240
+ f.write(caption)
241
+
242
+ if output_inputs_path is not None:
243
+ os.makedirs(os.path.dirname(output_inputs_path), exist_ok=True)
244
+ image.save(output_inputs_path)
245
+ with open(output_inputs_caption_path, 'w') as f:
246
+ f.write(caption)
247
+
248
+ pbar.update(1)
249
+ batch.cleanup()
250
+
251
+ pbar.close()
252
+ print("Done generating images")
253
+ # cleanup
254
+ del self.sd
255
+ gc.collect()
256
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/PureLoraGenerator.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from collections import OrderedDict
3
+
4
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, SampleConfig, LoRMConfig
5
+ from toolkit.lorm import ExtractMode, convert_diffusers_unet_to_lorm
6
+ from toolkit.sd_device_states_presets import get_train_sd_device_state_preset
7
+ from toolkit.stable_diffusion_model import StableDiffusion
8
+ import gc
9
+ import torch
10
+ from jobs.process import BaseExtensionProcess
11
+ from toolkit.train_tools import get_torch_dtype
12
+
13
+
14
+ def flush():
15
+ torch.cuda.empty_cache()
16
+ gc.collect()
17
+
18
+
19
+ class PureLoraGenerator(BaseExtensionProcess):
20
+
21
+ def __init__(self, process_id: int, job, config: OrderedDict):
22
+ super().__init__(process_id, job, config)
23
+ self.output_folder = self.get_conf('output_folder', required=True)
24
+ self.device = self.get_conf('device', 'cuda')
25
+ self.device_torch = torch.device(self.device)
26
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
27
+ self.generate_config = SampleConfig(**self.get_conf('sample', required=True))
28
+ self.dtype = self.get_conf('dtype', 'float16')
29
+ self.torch_dtype = get_torch_dtype(self.dtype)
30
+ lorm_config = self.get_conf('lorm', None)
31
+ self.lorm_config = LoRMConfig(**lorm_config) if lorm_config is not None else None
32
+
33
+ self.device_state_preset = get_train_sd_device_state_preset(
34
+ device=torch.device(self.device),
35
+ )
36
+
37
+ self.progress_bar = None
38
+ self.sd = StableDiffusion(
39
+ device=self.device,
40
+ model_config=self.model_config,
41
+ dtype=self.dtype,
42
+ )
43
+
44
+ def run(self):
45
+ super().run()
46
+ print("Loading model...")
47
+ with torch.no_grad():
48
+ self.sd.load_model()
49
+ self.sd.unet.eval()
50
+ self.sd.unet.to(self.device_torch)
51
+ if isinstance(self.sd.text_encoder, list):
52
+ for te in self.sd.text_encoder:
53
+ te.eval()
54
+ te.to(self.device_torch)
55
+ else:
56
+ self.sd.text_encoder.eval()
57
+ self.sd.to(self.device_torch)
58
+
59
+ print(f"Converting to LoRM UNet")
60
+ # replace the unet with LoRMUnet
61
+ convert_diffusers_unet_to_lorm(
62
+ self.sd.unet,
63
+ config=self.lorm_config,
64
+ )
65
+
66
+ sample_folder = os.path.join(self.output_folder)
67
+ gen_img_config_list = []
68
+
69
+ sample_config = self.generate_config
70
+ start_seed = sample_config.seed
71
+ current_seed = start_seed
72
+ for i in range(len(sample_config.prompts)):
73
+ if sample_config.walk_seed:
74
+ current_seed = start_seed + i
75
+
76
+ filename = f"[time]_[count].{self.generate_config.ext}"
77
+ output_path = os.path.join(sample_folder, filename)
78
+ prompt = sample_config.prompts[i]
79
+ extra_args = {}
80
+ gen_img_config_list.append(GenerateImageConfig(
81
+ prompt=prompt, # it will autoparse the prompt
82
+ width=sample_config.width,
83
+ height=sample_config.height,
84
+ negative_prompt=sample_config.neg,
85
+ seed=current_seed,
86
+ guidance_scale=sample_config.guidance_scale,
87
+ guidance_rescale=sample_config.guidance_rescale,
88
+ num_inference_steps=sample_config.sample_steps,
89
+ network_multiplier=sample_config.network_multiplier,
90
+ output_path=output_path,
91
+ output_ext=sample_config.ext,
92
+ adapter_conditioning_scale=sample_config.adapter_conditioning_scale,
93
+ **extra_args
94
+ ))
95
+
96
+ # send to be generated
97
+ self.sd.generate_images(gen_img_config_list, sampler=sample_config.sampler)
98
+ print("Done generating images")
99
+ # cleanup
100
+ del self.sd
101
+ gc.collect()
102
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/ReferenceGenerator.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ from collections import OrderedDict
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ from PIL import Image
8
+ from diffusers import T2IAdapter
9
+ from torch.utils.data import DataLoader
10
+ from diffusers import StableDiffusionXLAdapterPipeline, StableDiffusionAdapterPipeline
11
+ from tqdm import tqdm
12
+
13
+ from toolkit.config_modules import ModelConfig, GenerateImageConfig, preprocess_dataset_raw_config, DatasetConfig
14
+ from toolkit.data_transfer_object.data_loader import FileItemDTO, DataLoaderBatchDTO
15
+ from toolkit.sampler import get_sampler
16
+ from toolkit.stable_diffusion_model import StableDiffusion
17
+ import gc
18
+ import torch
19
+ from jobs.process import BaseExtensionProcess
20
+ from toolkit.data_loader import get_dataloader_from_datasets
21
+ from toolkit.train_tools import get_torch_dtype
22
+ from controlnet_aux.midas import MidasDetector
23
+ from diffusers.utils import load_image
24
+
25
+
26
+ def flush():
27
+ torch.cuda.empty_cache()
28
+ gc.collect()
29
+
30
+
31
+ class GenerateConfig:
32
+
33
+ def __init__(self, **kwargs):
34
+ self.prompts: List[str]
35
+ self.sampler = kwargs.get('sampler', 'ddpm')
36
+ self.neg = kwargs.get('neg', '')
37
+ self.seed = kwargs.get('seed', -1)
38
+ self.walk_seed = kwargs.get('walk_seed', False)
39
+ self.t2i_adapter_path = kwargs.get('t2i_adapter_path', None)
40
+ self.guidance_scale = kwargs.get('guidance_scale', 7)
41
+ self.sample_steps = kwargs.get('sample_steps', 20)
42
+ self.prompt_2 = kwargs.get('prompt_2', None)
43
+ self.neg_2 = kwargs.get('neg_2', None)
44
+ self.prompts = kwargs.get('prompts', None)
45
+ self.guidance_rescale = kwargs.get('guidance_rescale', 0.0)
46
+ self.ext = kwargs.get('ext', 'png')
47
+ self.adapter_conditioning_scale = kwargs.get('adapter_conditioning_scale', 1.0)
48
+ if kwargs.get('shuffle', False):
49
+ # shuffle the prompts
50
+ random.shuffle(self.prompts)
51
+
52
+
53
+ class ReferenceGenerator(BaseExtensionProcess):
54
+
55
+ def __init__(self, process_id: int, job, config: OrderedDict):
56
+ super().__init__(process_id, job, config)
57
+ self.output_folder = self.get_conf('output_folder', required=True)
58
+ self.device = self.get_conf('device', 'cuda')
59
+ self.model_config = ModelConfig(**self.get_conf('model', required=True))
60
+ self.generate_config = GenerateConfig(**self.get_conf('generate', required=True))
61
+ self.is_latents_cached = True
62
+ raw_datasets = self.get_conf('datasets', None)
63
+ if raw_datasets is not None and len(raw_datasets) > 0:
64
+ raw_datasets = preprocess_dataset_raw_config(raw_datasets)
65
+ self.datasets = None
66
+ self.datasets_reg = None
67
+ self.dtype = self.get_conf('dtype', 'float16')
68
+ self.torch_dtype = get_torch_dtype(self.dtype)
69
+ self.params = []
70
+ if raw_datasets is not None and len(raw_datasets) > 0:
71
+ for raw_dataset in raw_datasets:
72
+ dataset = DatasetConfig(**raw_dataset)
73
+ is_caching = dataset.cache_latents or dataset.cache_latents_to_disk
74
+ if not is_caching:
75
+ self.is_latents_cached = False
76
+ if dataset.is_reg:
77
+ if self.datasets_reg is None:
78
+ self.datasets_reg = []
79
+ self.datasets_reg.append(dataset)
80
+ else:
81
+ if self.datasets is None:
82
+ self.datasets = []
83
+ self.datasets.append(dataset)
84
+
85
+ self.progress_bar = None
86
+ self.sd = StableDiffusion(
87
+ device=self.device,
88
+ model_config=self.model_config,
89
+ dtype=self.dtype,
90
+ )
91
+ print(f"Using device {self.device}")
92
+ self.data_loader: DataLoader = None
93
+ self.adapter: T2IAdapter = None
94
+
95
+ def run(self):
96
+ super().run()
97
+ print("Loading model...")
98
+ self.sd.load_model()
99
+ device = torch.device(self.device)
100
+
101
+ if self.generate_config.t2i_adapter_path is not None:
102
+ self.adapter = T2IAdapter.from_pretrained(
103
+ self.generate_config.t2i_adapter_path,
104
+ torch_dtype=self.torch_dtype,
105
+ varient="fp16"
106
+ ).to(device)
107
+
108
+ midas_depth = MidasDetector.from_pretrained(
109
+ "valhalla/t2iadapter-aux-models", filename="dpt_large_384.pt", model_type="dpt_large"
110
+ ).to(device)
111
+
112
+ if self.model_config.is_xl:
113
+ pipe = StableDiffusionXLAdapterPipeline(
114
+ vae=self.sd.vae,
115
+ unet=self.sd.unet,
116
+ text_encoder=self.sd.text_encoder[0],
117
+ text_encoder_2=self.sd.text_encoder[1],
118
+ tokenizer=self.sd.tokenizer[0],
119
+ tokenizer_2=self.sd.tokenizer[1],
120
+ scheduler=get_sampler(self.generate_config.sampler),
121
+ adapter=self.adapter,
122
+ ).to(device, dtype=self.torch_dtype)
123
+ else:
124
+ pipe = StableDiffusionAdapterPipeline(
125
+ vae=self.sd.vae,
126
+ unet=self.sd.unet,
127
+ text_encoder=self.sd.text_encoder,
128
+ tokenizer=self.sd.tokenizer,
129
+ scheduler=get_sampler(self.generate_config.sampler),
130
+ safety_checker=None,
131
+ feature_extractor=None,
132
+ requires_safety_checker=False,
133
+ adapter=self.adapter,
134
+ ).to(device, dtype=self.torch_dtype)
135
+ pipe.set_progress_bar_config(disable=True)
136
+
137
+ pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
138
+ # midas_depth = torch.compile(midas_depth, mode="reduce-overhead", fullgraph=True)
139
+
140
+ self.data_loader = get_dataloader_from_datasets(self.datasets, 1, self.sd)
141
+
142
+ num_batches = len(self.data_loader)
143
+ pbar = tqdm(total=num_batches, desc="Generating images")
144
+ seed = self.generate_config.seed
145
+ # load images from datasets, use tqdm
146
+ for i, batch in enumerate(self.data_loader):
147
+ batch: DataLoaderBatchDTO = batch
148
+
149
+ file_item: FileItemDTO = batch.file_items[0]
150
+ img_path = file_item.path
151
+ img_filename = os.path.basename(img_path)
152
+ img_filename_no_ext = os.path.splitext(img_filename)[0]
153
+ output_path = os.path.join(self.output_folder, img_filename)
154
+ output_caption_path = os.path.join(self.output_folder, img_filename_no_ext + '.txt')
155
+ output_depth_path = os.path.join(self.output_folder, img_filename_no_ext + '.depth.png')
156
+
157
+ caption = batch.get_caption_list()[0]
158
+
159
+ img: torch.Tensor = batch.tensor.clone()
160
+ # image comes in -1 to 1. convert to a PIL RGB image
161
+ img = (img + 1) / 2
162
+ img = img.clamp(0, 1)
163
+ img = img[0].permute(1, 2, 0).cpu().numpy()
164
+ img = (img * 255).astype(np.uint8)
165
+ image = Image.fromarray(img)
166
+
167
+ width, height = image.size
168
+ min_res = min(width, height)
169
+
170
+ if self.generate_config.walk_seed:
171
+ seed = seed + 1
172
+
173
+ if self.generate_config.seed == -1:
174
+ # random
175
+ seed = random.randint(0, 1000000)
176
+
177
+ torch.manual_seed(seed)
178
+ torch.cuda.manual_seed(seed)
179
+
180
+ # generate depth map
181
+ image = midas_depth(
182
+ image,
183
+ detect_resolution=min_res, # do 512 ?
184
+ image_resolution=min_res
185
+ )
186
+
187
+ # image.save(output_depth_path)
188
+
189
+ gen_images = pipe(
190
+ prompt=caption,
191
+ negative_prompt=self.generate_config.neg,
192
+ image=image,
193
+ num_inference_steps=self.generate_config.sample_steps,
194
+ adapter_conditioning_scale=self.generate_config.adapter_conditioning_scale,
195
+ guidance_scale=self.generate_config.guidance_scale,
196
+ ).images[0]
197
+ os.makedirs(os.path.dirname(output_path), exist_ok=True)
198
+ gen_images.save(output_path)
199
+
200
+ # save caption
201
+ with open(output_caption_path, 'w') as f:
202
+ f.write(caption)
203
+
204
+ pbar.update(1)
205
+ batch.cleanup()
206
+
207
+ pbar.close()
208
+ print("Done generating images")
209
+ # cleanup
210
+ del self.sd
211
+ gc.collect()
212
+ torch.cuda.empty_cache()
extensions_built_in/advanced_generator/__init__.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
6
+ class AdvancedReferenceGeneratorExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "reference_generator"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Reference Generator"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ReferenceGenerator import ReferenceGenerator
19
+ return ReferenceGenerator
20
+
21
+
22
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
23
+ class PureLoraGenerator(Extension):
24
+ # uid must be unique, it is how the extension is identified
25
+ uid = "pure_lora_generator"
26
+
27
+ # name is the name of the extension for printing
28
+ name = "Pure LoRA Generator"
29
+
30
+ # This is where your process class is loaded
31
+ # keep your imports in here so they don't slow down the rest of the program
32
+ @classmethod
33
+ def get_process(cls):
34
+ # import your process class here so it is only loaded when needed and return it
35
+ from .PureLoraGenerator import PureLoraGenerator
36
+ return PureLoraGenerator
37
+
38
+
39
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
40
+ class Img2ImgGeneratorExtension(Extension):
41
+ # uid must be unique, it is how the extension is identified
42
+ uid = "batch_img2img"
43
+
44
+ # name is the name of the extension for printing
45
+ name = "Img2ImgGeneratorExtension"
46
+
47
+ # This is where your process class is loaded
48
+ # keep your imports in here so they don't slow down the rest of the program
49
+ @classmethod
50
+ def get_process(cls):
51
+ # import your process class here so it is only loaded when needed and return it
52
+ from .Img2ImgGenerator import Img2ImgGenerator
53
+ return Img2ImgGenerator
54
+
55
+
56
+ AI_TOOLKIT_EXTENSIONS = [
57
+ # you can put a list of extensions here
58
+ AdvancedReferenceGeneratorExtension, PureLoraGenerator, Img2ImgGeneratorExtension
59
+ ]
extensions_built_in/advanced_generator/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
extensions_built_in/concept_replacer/ConceptReplacer.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ from collections import OrderedDict
3
+ from torch.utils.data import DataLoader
4
+ from toolkit.prompt_utils import concat_prompt_embeds, split_prompt_embeds
5
+ from toolkit.stable_diffusion_model import StableDiffusion, BlankNetwork
6
+ from toolkit.train_tools import get_torch_dtype, apply_snr_weight
7
+ import gc
8
+ import torch
9
+ from jobs.process import BaseSDTrainProcess
10
+
11
+
12
+ def flush():
13
+ torch.cuda.empty_cache()
14
+ gc.collect()
15
+
16
+
17
+ class ConceptReplacementConfig:
18
+ def __init__(self, **kwargs):
19
+ self.concept: str = kwargs.get('concept', '')
20
+ self.replacement: str = kwargs.get('replacement', '')
21
+
22
+
23
+ class ConceptReplacer(BaseSDTrainProcess):
24
+
25
+ def __init__(self, process_id: int, job, config: OrderedDict, **kwargs):
26
+ super().__init__(process_id, job, config, **kwargs)
27
+ replacement_list = self.config.get('replacements', [])
28
+ self.replacement_list = [ConceptReplacementConfig(**x) for x in replacement_list]
29
+
30
+ def before_model_load(self):
31
+ pass
32
+
33
+ def hook_before_train_loop(self):
34
+ self.sd.vae.eval()
35
+ self.sd.vae.to(self.device_torch)
36
+
37
+ # textual inversion
38
+ if self.embedding is not None:
39
+ # set text encoder to train. Not sure if this is necessary but diffusers example did it
40
+ self.sd.text_encoder.train()
41
+
42
+ def hook_train_loop(self, batch):
43
+ with torch.no_grad():
44
+ dtype = get_torch_dtype(self.train_config.dtype)
45
+ noisy_latents, noise, timesteps, conditioned_prompts, imgs = self.process_general_training_batch(batch)
46
+ network_weight_list = batch.get_network_weight_list()
47
+
48
+ # have a blank network so we can wrap it in a context and set multipliers without checking every time
49
+ if self.network is not None:
50
+ network = self.network
51
+ else:
52
+ network = BlankNetwork()
53
+
54
+ batch_replacement_list = []
55
+ # get a random replacement for each prompt
56
+ for prompt in conditioned_prompts:
57
+ replacement = random.choice(self.replacement_list)
58
+ batch_replacement_list.append(replacement)
59
+
60
+ # build out prompts
61
+ concept_prompts = []
62
+ replacement_prompts = []
63
+ for idx, replacement in enumerate(batch_replacement_list):
64
+ prompt = conditioned_prompts[idx]
65
+
66
+ # insert shuffled concept at beginning and end of prompt
67
+ shuffled_concept = [x.strip() for x in replacement.concept.split(',')]
68
+ random.shuffle(shuffled_concept)
69
+ shuffled_concept = ', '.join(shuffled_concept)
70
+ concept_prompts.append(f"{shuffled_concept}, {prompt}, {shuffled_concept}")
71
+
72
+ # insert replacement at beginning and end of prompt
73
+ shuffled_replacement = [x.strip() for x in replacement.replacement.split(',')]
74
+ random.shuffle(shuffled_replacement)
75
+ shuffled_replacement = ', '.join(shuffled_replacement)
76
+ replacement_prompts.append(f"{shuffled_replacement}, {prompt}, {shuffled_replacement}")
77
+
78
+ # predict the replacement without network
79
+ conditional_embeds = self.sd.encode_prompt(replacement_prompts).to(self.device_torch, dtype=dtype)
80
+
81
+ replacement_pred = self.sd.predict_noise(
82
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
83
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
84
+ timestep=timesteps,
85
+ guidance_scale=1.0,
86
+ )
87
+
88
+ del conditional_embeds
89
+ replacement_pred = replacement_pred.detach()
90
+
91
+ self.optimizer.zero_grad()
92
+ flush()
93
+
94
+ # text encoding
95
+ grad_on_text_encoder = False
96
+ if self.train_config.train_text_encoder:
97
+ grad_on_text_encoder = True
98
+
99
+ if self.embedding:
100
+ grad_on_text_encoder = True
101
+
102
+ # set the weights
103
+ network.multiplier = network_weight_list
104
+
105
+ # activate network if it exits
106
+ with network:
107
+ with torch.set_grad_enabled(grad_on_text_encoder):
108
+ # embed the prompts
109
+ conditional_embeds = self.sd.encode_prompt(concept_prompts).to(self.device_torch, dtype=dtype)
110
+ if not grad_on_text_encoder:
111
+ # detach the embeddings
112
+ conditional_embeds = conditional_embeds.detach()
113
+ self.optimizer.zero_grad()
114
+ flush()
115
+
116
+ noise_pred = self.sd.predict_noise(
117
+ latents=noisy_latents.to(self.device_torch, dtype=dtype),
118
+ conditional_embeddings=conditional_embeds.to(self.device_torch, dtype=dtype),
119
+ timestep=timesteps,
120
+ guidance_scale=1.0,
121
+ )
122
+
123
+ loss = torch.nn.functional.mse_loss(noise_pred.float(), replacement_pred.float(), reduction="none")
124
+ loss = loss.mean([1, 2, 3])
125
+
126
+ if self.train_config.min_snr_gamma is not None and self.train_config.min_snr_gamma > 0.000001:
127
+ # add min_snr_gamma
128
+ loss = apply_snr_weight(loss, timesteps, self.sd.noise_scheduler, self.train_config.min_snr_gamma)
129
+
130
+ loss = loss.mean()
131
+
132
+ # back propagate loss to free ram
133
+ loss.backward()
134
+ flush()
135
+
136
+ # apply gradients
137
+ self.optimizer.step()
138
+ self.optimizer.zero_grad()
139
+ self.lr_scheduler.step()
140
+
141
+ if self.embedding is not None:
142
+ # Let's make sure we don't update any embedding weights besides the newly added token
143
+ self.embedding.restore_embeddings()
144
+
145
+ loss_dict = OrderedDict(
146
+ {'loss': loss.item()}
147
+ )
148
+ # reset network multiplier
149
+ network.multiplier = 1.0
150
+
151
+ return loss_dict
extensions_built_in/concept_replacer/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This is an example extension for custom training. It is great for experimenting with new ideas.
2
+ from toolkit.extension import Extension
3
+
4
+
5
+ # This is for generic training (LoRA, Dreambooth, FineTuning)
6
+ class ConceptReplacerExtension(Extension):
7
+ # uid must be unique, it is how the extension is identified
8
+ uid = "concept_replacer"
9
+
10
+ # name is the name of the extension for printing
11
+ name = "Concept Replacer"
12
+
13
+ # This is where your process class is loaded
14
+ # keep your imports in here so they don't slow down the rest of the program
15
+ @classmethod
16
+ def get_process(cls):
17
+ # import your process class here so it is only loaded when needed and return it
18
+ from .ConceptReplacer import ConceptReplacer
19
+ return ConceptReplacer
20
+
21
+
22
+
23
+ AI_TOOLKIT_EXTENSIONS = [
24
+ # you can put a list of extensions here
25
+ ConceptReplacerExtension,
26
+ ]
extensions_built_in/concept_replacer/config/train.example.yaml ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ job: extension
3
+ config:
4
+ name: test_v1
5
+ process:
6
+ - type: 'textual_inversion_trainer'
7
+ training_folder: "out/TI"
8
+ device: cuda:0
9
+ # for tensorboard logging
10
+ log_dir: "out/.tensorboard"
11
+ embedding:
12
+ trigger: "your_trigger_here"
13
+ tokens: 12
14
+ init_words: "man with short brown hair"
15
+ save_format: "safetensors" # 'safetensors' or 'pt'
16
+ save:
17
+ dtype: float16 # precision to save
18
+ save_every: 100 # save every this many steps
19
+ max_step_saves_to_keep: 5 # only affects step counts
20
+ datasets:
21
+ - folder_path: "/path/to/dataset"
22
+ caption_ext: "txt"
23
+ default_caption: "[trigger]"
24
+ buckets: true
25
+ resolution: 512
26
+ train:
27
+ noise_scheduler: "ddpm" # or "ddpm", "lms", "euler_a"
28
+ steps: 3000
29
+ weight_jitter: 0.0
30
+ lr: 5e-5
31
+ train_unet: false
32
+ gradient_checkpointing: true
33
+ train_text_encoder: false
34
+ optimizer: "adamw"
35
+ # optimizer: "prodigy"
36
+ optimizer_params:
37
+ weight_decay: 1e-2
38
+ lr_scheduler: "constant"
39
+ max_denoising_steps: 1000
40
+ batch_size: 4
41
+ dtype: bf16
42
+ xformers: true
43
+ min_snr_gamma: 5.0
44
+ # skip_first_sample: true
45
+ noise_offset: 0.0 # not needed for this
46
+ model:
47
+ # objective reality v2
48
+ name_or_path: "https://civitai.com/models/128453?modelVersionId=142465"
49
+ is_v2: false # for v2 models
50
+ is_xl: false # for SDXL models
51
+ is_v_pred: false # for v-prediction models (most v2 models)
52
+ sample:
53
+ sampler: "ddpm" # must match train.noise_scheduler
54
+ sample_every: 100 # sample every this many steps
55
+ width: 512
56
+ height: 512
57
+ prompts:
58
+ - "photo of [trigger] laughing"
59
+ - "photo of [trigger] smiling"
60
+ - "[trigger] close up"
61
+ - "dark scene [trigger] frozen"
62
+ - "[trigger] nighttime"
63
+ - "a painting of [trigger]"
64
+ - "a drawing of [trigger]"
65
+ - "a cartoon of [trigger]"
66
+ - "[trigger] pixar style"
67
+ - "[trigger] costume"
68
+ neg: ""
69
+ seed: 42
70
+ walk_seed: false
71
+ guidance_scale: 7
72
+ sample_steps: 20
73
+ network_multiplier: 1.0
74
+
75
+ logging:
76
+ log_every: 10 # log every this many steps
77
+ use_wandb: false # not supported yet
78
+ verbose: false
79
+
80
+ # You can put any information you want here, and it will be saved in the model.
81
+ # The below is an example, but you can put your grocery list in it if you want.
82
+ # It is saved in the model so be aware of that. The software will include this
83
+ # plus some other information for you automatically
84
+ meta:
85
+ # [name] gets replaced with the name above
86
+ name: "[name]"
87
+ # version: '1.0'
88
+ # creator:
89
+ # name: Your Name
90
+ # email: [email protected]
91
+ # website: https://your.website
extensions_built_in/dataset_tools/DatasetTools.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections import OrderedDict
2
+ import gc
3
+ import torch
4
+ from jobs.process import BaseExtensionProcess
5
+
6
+
7
+ def flush():
8
+ torch.cuda.empty_cache()
9
+ gc.collect()
10
+
11
+
12
+ class DatasetTools(BaseExtensionProcess):
13
+
14
+ def __init__(self, process_id: int, job, config: OrderedDict):
15
+ super().__init__(process_id, job, config)
16
+
17
+ def run(self):
18
+ super().run()
19
+
20
+ raise NotImplementedError("This extension is not yet implemented")
extensions_built_in/dataset_tools/SuperTagger.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ from collections import OrderedDict
5
+ import gc
6
+ import traceback
7
+ import torch
8
+ from PIL import Image, ImageOps
9
+ from tqdm import tqdm
10
+
11
+ from .tools.dataset_tools_config_modules import RAW_DIR, TRAIN_DIR, Step, ImgInfo
12
+ from .tools.fuyu_utils import FuyuImageProcessor
13
+ from .tools.image_tools import load_image, ImageProcessor, resize_to_max
14
+ from .tools.llava_utils import LLaVAImageProcessor
15
+ from .tools.caption import default_long_prompt, default_short_prompt, default_replacements
16
+ from jobs.process import BaseExtensionProcess
17
+ from .tools.sync_tools import get_img_paths
18
+
19
+ img_ext = ['.jpg', '.jpeg', '.png', '.webp']
20
+
21
+
22
+ def flush():
23
+ torch.cuda.empty_cache()
24
+ gc.collect()
25
+
26
+
27
+ VERSION = 2
28
+
29
+
30
+ class SuperTagger(BaseExtensionProcess):
31
+
32
+ def __init__(self, process_id: int, job, config: OrderedDict):
33
+ super().__init__(process_id, job, config)
34
+ parent_dir = config.get('parent_dir', None)
35
+ self.dataset_paths: list[str] = config.get('dataset_paths', [])
36
+ self.device = config.get('device', 'cuda')
37
+ self.steps: list[Step] = config.get('steps', [])
38
+ self.caption_method = config.get('caption_method', 'llava:default')
39
+ self.caption_prompt = config.get('caption_prompt', default_long_prompt)
40
+ self.caption_short_prompt = config.get('caption_short_prompt', default_short_prompt)
41
+ self.force_reprocess_img = config.get('force_reprocess_img', False)
42
+ self.caption_replacements = config.get('caption_replacements', default_replacements)
43
+ self.caption_short_replacements = config.get('caption_short_replacements', default_replacements)
44
+ self.master_dataset_dict = OrderedDict()
45
+ self.dataset_master_config_file = config.get('dataset_master_config_file', None)
46
+ if parent_dir is not None and len(self.dataset_paths) == 0:
47
+ # find all folders in the patent_dataset_path
48
+ self.dataset_paths = [
49
+ os.path.join(parent_dir, folder)
50
+ for folder in os.listdir(parent_dir)
51
+ if os.path.isdir(os.path.join(parent_dir, folder))
52
+ ]
53
+ else:
54
+ # make sure they exist
55
+ for dataset_path in self.dataset_paths:
56
+ if not os.path.exists(dataset_path):
57
+ raise ValueError(f"Dataset path does not exist: {dataset_path}")
58
+
59
+ print(f"Found {len(self.dataset_paths)} dataset paths")
60
+
61
+ self.image_processor: ImageProcessor = self.get_image_processor()
62
+
63
+ def get_image_processor(self):
64
+ if self.caption_method.startswith('llava'):
65
+ return LLaVAImageProcessor(device=self.device)
66
+ elif self.caption_method.startswith('fuyu'):
67
+ return FuyuImageProcessor(device=self.device)
68
+ else:
69
+ raise ValueError(f"Unknown caption method: {self.caption_method}")
70
+
71
+ def process_image(self, img_path: str):
72
+ root_img_dir = os.path.dirname(os.path.dirname(img_path))
73
+ filename = os.path.basename(img_path)
74
+ filename_no_ext = os.path.splitext(filename)[0]
75
+ train_dir = os.path.join(root_img_dir, TRAIN_DIR)
76
+ train_img_path = os.path.join(train_dir, filename)
77
+ json_path = os.path.join(train_dir, f"{filename_no_ext}.json")
78
+
79
+ # check if json exists, if it does load it as image info
80
+ if os.path.exists(json_path):
81
+ with open(json_path, 'r') as f:
82
+ img_info = ImgInfo(**json.load(f))
83
+ else:
84
+ img_info = ImgInfo()
85
+
86
+ # always send steps first in case other processes need them
87
+ img_info.add_steps(copy.deepcopy(self.steps))
88
+ img_info.set_version(VERSION)
89
+ img_info.set_caption_method(self.caption_method)
90
+
91
+ image: Image = None
92
+ caption_image: Image = None
93
+
94
+ did_update_image = False
95
+
96
+ # trigger reprocess of steps
97
+ if self.force_reprocess_img:
98
+ img_info.trigger_image_reprocess()
99
+
100
+ # set the image as updated if it does not exist on disk
101
+ if not os.path.exists(train_img_path):
102
+ did_update_image = True
103
+ image = load_image(img_path)
104
+ if img_info.force_image_process:
105
+ did_update_image = True
106
+ image = load_image(img_path)
107
+
108
+ # go through the needed steps
109
+ for step in copy.deepcopy(img_info.state.steps_to_complete):
110
+ if step == 'caption':
111
+ # load image
112
+ if image is None:
113
+ image = load_image(img_path)
114
+ if caption_image is None:
115
+ caption_image = resize_to_max(image, 1024, 1024)
116
+
117
+ if not self.image_processor.is_loaded:
118
+ print('Loading Model. Takes a while, especially the first time')
119
+ self.image_processor.load_model()
120
+
121
+ img_info.caption = self.image_processor.generate_caption(
122
+ image=caption_image,
123
+ prompt=self.caption_prompt,
124
+ replacements=self.caption_replacements
125
+ )
126
+ img_info.mark_step_complete(step)
127
+ elif step == 'caption_short':
128
+ # load image
129
+ if image is None:
130
+ image = load_image(img_path)
131
+
132
+ if caption_image is None:
133
+ caption_image = resize_to_max(image, 1024, 1024)
134
+
135
+ if not self.image_processor.is_loaded:
136
+ print('Loading Model. Takes a while, especially the first time')
137
+ self.image_processor.load_model()
138
+ img_info.caption_short = self.image_processor.generate_caption(
139
+ image=caption_image,
140
+ prompt=self.caption_short_prompt,
141
+ replacements=self.caption_short_replacements
142
+ )
143
+ img_info.mark_step_complete(step)
144
+ elif step == 'contrast_stretch':
145
+ # load image
146
+ if image is None:
147
+ image = load_image(img_path)
148
+ image = ImageOps.autocontrast(image, cutoff=(0.1, 0), preserve_tone=True)
149
+ did_update_image = True
150
+ img_info.mark_step_complete(step)
151
+ else:
152
+ raise ValueError(f"Unknown step: {step}")
153
+
154
+ os.makedirs(os.path.dirname(train_img_path), exist_ok=True)
155
+ if did_update_image:
156
+ image.save(train_img_path)
157
+
158
+ if img_info.is_dirty:
159
+ with open(json_path, 'w') as f:
160
+ json.dump(img_info.to_dict(), f, indent=4)
161
+
162
+ if self.dataset_master_config_file:
163
+ # add to master dict
164
+ self.master_dataset_dict[train_img_path] = img_info.to_dict()
165
+
166
+ def run(self):
167
+ super().run()
168
+ imgs_to_process = []
169
+ # find all images
170
+ for dataset_path in self.dataset_paths:
171
+ raw_dir = os.path.join(dataset_path, RAW_DIR)
172
+ raw_image_paths = get_img_paths(raw_dir)
173
+ for raw_image_path in raw_image_paths:
174
+ imgs_to_process.append(raw_image_path)
175
+
176
+ if len(imgs_to_process) == 0:
177
+ print(f"No images to process")
178
+ else:
179
+ print(f"Found {len(imgs_to_process)} to process")
180
+
181
+ for img_path in tqdm(imgs_to_process, desc="Processing images"):
182
+ try:
183
+ self.process_image(img_path)
184
+ except Exception:
185
+ # print full stack trace
186
+ print(traceback.format_exc())
187
+ continue
188
+ # self.process_image(img_path)
189
+
190
+ if self.dataset_master_config_file is not None:
191
+ # save it as json
192
+ with open(self.dataset_master_config_file, 'w') as f:
193
+ json.dump(self.master_dataset_dict, f, indent=4)
194
+
195
+ del self.image_processor
196
+ flush()
extensions_built_in/dataset_tools/SyncFromCollection.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ from collections import OrderedDict
4
+ import gc
5
+ from typing import List
6
+
7
+ import torch
8
+ from tqdm import tqdm
9
+
10
+ from .tools.dataset_tools_config_modules import DatasetSyncCollectionConfig, RAW_DIR, NEW_DIR
11
+ from .tools.sync_tools import get_unsplash_images, get_pexels_images, get_local_image_file_names, download_image, \
12
+ get_img_paths
13
+ from jobs.process import BaseExtensionProcess
14
+
15
+
16
+ def flush():
17
+ torch.cuda.empty_cache()
18
+ gc.collect()
19
+
20
+
21
+ class SyncFromCollection(BaseExtensionProcess):
22
+
23
+ def __init__(self, process_id: int, job, config: OrderedDict):
24
+ super().__init__(process_id, job, config)
25
+
26
+ self.min_width = config.get('min_width', 1024)
27
+ self.min_height = config.get('min_height', 1024)
28
+
29
+ # add our min_width and min_height to each dataset config if they don't exist
30
+ for dataset_config in config.get('dataset_sync', []):
31
+ if 'min_width' not in dataset_config:
32
+ dataset_config['min_width'] = self.min_width
33
+ if 'min_height' not in dataset_config:
34
+ dataset_config['min_height'] = self.min_height
35
+
36
+ self.dataset_configs: List[DatasetSyncCollectionConfig] = [
37
+ DatasetSyncCollectionConfig(**dataset_config)
38
+ for dataset_config in config.get('dataset_sync', [])
39
+ ]
40
+ print(f"Found {len(self.dataset_configs)} dataset configs")
41
+
42
+ def move_new_images(self, root_dir: str):
43
+ raw_dir = os.path.join(root_dir, RAW_DIR)
44
+ new_dir = os.path.join(root_dir, NEW_DIR)
45
+ new_images = get_img_paths(new_dir)
46
+
47
+ for img_path in new_images:
48
+ # move to raw
49
+ new_path = os.path.join(raw_dir, os.path.basename(img_path))
50
+ shutil.move(img_path, new_path)
51
+
52
+ # remove new dir
53
+ shutil.rmtree(new_dir)
54
+
55
+ def sync_dataset(self, config: DatasetSyncCollectionConfig):
56
+ if config.host == 'unsplash':
57
+ get_images = get_unsplash_images
58
+ elif config.host == 'pexels':
59
+ get_images = get_pexels_images
60
+ else:
61
+ raise ValueError(f"Unknown host: {config.host}")
62
+
63
+ results = {
64
+ 'num_downloaded': 0,
65
+ 'num_skipped': 0,
66
+ 'bad': 0,
67
+ 'total': 0,
68
+ }
69
+
70
+ photos = get_images(config)
71
+ raw_dir = os.path.join(config.directory, RAW_DIR)
72
+ new_dir = os.path.join(config.directory, NEW_DIR)
73
+ raw_images = get_local_image_file_names(raw_dir)
74
+ new_images = get_local_image_file_names(new_dir)
75
+
76
+ for photo in tqdm(photos, desc=f"{config.host}-{config.collection_id}"):
77
+ try:
78
+ if photo.filename not in raw_images and photo.filename not in new_images:
79
+ download_image(photo, new_dir, min_width=self.min_width, min_height=self.min_height)
80
+ results['num_downloaded'] += 1
81
+ else:
82
+ results['num_skipped'] += 1
83
+ except Exception as e:
84
+ print(f" - BAD({photo.id}): {e}")
85
+ results['bad'] += 1
86
+ continue
87
+ results['total'] += 1
88
+
89
+ return results
90
+
91
+ def print_results(self, results):
92
+ print(
93
+ f" - new:{results['num_downloaded']}, old:{results['num_skipped']}, bad:{results['bad']} total:{results['total']}")
94
+
95
+ def run(self):
96
+ super().run()
97
+ print(f"Syncing {len(self.dataset_configs)} datasets")
98
+ all_results = None
99
+ failed_datasets = []
100
+ for dataset_config in tqdm(self.dataset_configs, desc="Syncing datasets", leave=True):
101
+ try:
102
+ results = self.sync_dataset(dataset_config)
103
+ if all_results is None:
104
+ all_results = {**results}
105
+ else:
106
+ for key, value in results.items():
107
+ all_results[key] += value
108
+
109
+ self.print_results(results)
110
+ except Exception as e:
111
+ print(f" - FAILED: {e}")
112
+ if 'response' in e.__dict__:
113
+ error = f"{e.response.status_code}: {e.response.text}"
114
+ print(f" - {error}")
115
+ failed_datasets.append({'dataset': dataset_config, 'error': error})
116
+ else:
117
+ failed_datasets.append({'dataset': dataset_config, 'error': str(e)})
118
+ continue
119
+
120
+ print("Moving new images to raw")
121
+ for dataset_config in self.dataset_configs:
122
+ self.move_new_images(dataset_config.directory)
123
+
124
+ print("Done syncing datasets")
125
+ self.print_results(all_results)
126
+
127
+ if len(failed_datasets) > 0:
128
+ print(f"Failed to sync {len(failed_datasets)} datasets")
129
+ for failed in failed_datasets:
130
+ print(f" - {failed['dataset'].host}-{failed['dataset'].collection_id}")
131
+ print(f" - ERR: {failed['error']}")