Commit ecd5028
Parent(s): 7c52128

upgrade finetrainers + gradio
Files changed:

- README.md +1 -1
- docs/huggingface/Downloading files from the hub.md +270 -0
- docs/huggingface/HfApi Client API Reference.md +0 -0
- docs/huggingface/Load a dataset from the hub.md +126 -0
- docs/huggingface/Search the Hub.md +61 -0
- finetrainers/args.py +4 -0
- finetrainers/config.py +10 -4
- finetrainers/data/__init__.py +9 -1
- finetrainers/data/dataset.py +154 -20
- finetrainers/data/precomputation.py +222 -9
- finetrainers/functional/image.py +1 -1
- finetrainers/models/cogvideox/base_specification.py +1 -2
- finetrainers/models/cogview4/__init__.py +1 -0
- finetrainers/models/cogview4/base_specification.py +395 -0
- finetrainers/models/hunyuan_video/base_specification.py +1 -4
- finetrainers/models/ltx_video/base_specification.py +2 -6
- finetrainers/models/modeling_utils.py +0 -3
- finetrainers/models/wan/base_specification.py +2 -11
- finetrainers/processors/__init__.py +1 -0
- finetrainers/processors/glm.py +74 -0
- finetrainers/trainer/sft_trainer/trainer.py +116 -61
- finetrainers/utils/__init__.py +1 -0
- finetrainers/utils/diffusion.py +7 -0
- requirements.txt +1 -1
- requirements_without_flash_attention.txt +1 -1
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🎥
 colorFrom: gray
 colorTo: gray
 sdk: gradio
-sdk_version: 5.
+sdk_version: 5.20.1
 app_file: app.py
 pinned: true
 license: apache-2.0
docs/huggingface/Downloading files from the hub.md
ADDED
@@ -0,0 +1,270 @@
Downloading files
=================

Download a single file
----------------------

### hf_hub_download

#### huggingface_hub.hf_hub_download

[< source >](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L663)

    hf_hub_download(
        repo_id: str, filename: str, subfolder: Optional[str] = None, repo_type: Optional[str] = None,
        revision: Optional[str] = None, library_name: Optional[str] = None, library_version: Optional[str] = None,
        cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None,
        user_agent: Union[Dict, str, None] = None, force_download: bool = False, proxies: Optional[Dict] = None,
        etag_timeout: float = 10, token: Union[bool, str, None] = None, local_files_only: bool = False,
        headers: Optional[Dict[str, str]] = None, endpoint: Optional[str] = None,
        resume_download: Optional[bool] = None, force_filename: Optional[str] = None,
        local_dir_use_symlinks: Union[bool, Literal['auto']] = 'auto',
    ) -> str

Parameters

* **repo_id** (`str`) — A user or an organization name and a repo name separated by a `/`.
* **filename** (`str`) — The name of the file in the repo.
* **subfolder** (`str`, _optional_) — An optional value corresponding to a folder inside the model repo.
* **repo_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
* **revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.
* **library_name** (`str`, _optional_) — The name of the library to which the object corresponds.
* **library_version** (`str`, _optional_) — The version of the library.
* **cache_dir** (`str`, `Path`, _optional_) — Path to the folder where cached files are stored.
* **local_dir** (`str` or `Path`, _optional_) — If provided, the downloaded file will be placed under this directory.
* **user_agent** (`dict`, `str`, _optional_) — The user-agent info in the form of a dictionary or a string.
* **force_download** (`bool`, _optional_, defaults to `False`) — Whether the file should be downloaded even if it already exists in the local cache.
* **proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
* **etag_timeout** (`float`, _optional_, defaults to `10`) — When fetching the ETag, how many seconds to wait for the server to send data before giving up; passed to `requests.request`.
* **token** (`str`, `bool`, _optional_) — A token to be used for the download.
  * If `True`, the token is read from the HuggingFace config folder.
  * If a string, it's used as the authentication token.
* **local_files_only** (`bool`, _optional_, defaults to `False`) — If `True`, avoid downloading the file and return the path to the local cached file if it exists.
* **headers** (`dict`, _optional_) — Additional headers to be sent with the request.

Returns

`str` — Local path of file or if networking is off, last version of file cached on disk.

Raises

* [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) — If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access.
* [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) — If the revision to download from cannot be found.
* [EntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.EntryNotFoundError) — If the file to download cannot be found.
* [LocalEntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.LocalEntryNotFoundError) — If network is disabled or unavailable and the file is not found in cache.
* [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) — If `token=True` but the token cannot be found.
* [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) — If the ETag cannot be determined.
* [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) — If some parameter value is invalid.

Download a given file if it's not already present in the local cache.

The new cache file layout looks like this:

* The cache directory contains one subfolder per repo_id (namespaced by repo type)
* inside each repo folder:
  * refs is a list of the latest known revision => commit_hash pairs
  * blobs contains the actual file blobs (identified by their git-sha or sha256, depending on whether they're LFS files or not)
  * snapshots contains one subfolder per commit, each "commit" contains the subset of the files that have been resolved at that particular commit. Each filename is a symlink to the blob at that particular commit.

    [  96]  .
    └── [ 160]  models--julien-c--EsperBERTo-small
        ├── [ 160]  blobs
        │   ├── [321M]  403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
        │   ├── [ 398]  7cb18dc9bafbfcf74629a4b760af1b160957a83e
        │   └── [1.4K]  d7edf6bd2a681fb0175f7735299831ee1b22b812
        ├── [  96]  refs
        │   └── [  40]  main
        └── [ 128]  snapshots
            ├── [ 128]  2439f60ef33a0d46d85da5001d52aeda5b00ce9f
            │   ├── [  52]  README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
            │   └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
            └── [ 128]  bbc77c8132af1cc5cf678da3f1ddf2de43606d48
                ├── [  52]  README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
                └── [  76]  pytorch_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd

If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository.
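A minimal usage sketch, reusing the repo and filename from the cache-layout example above (the `local_dir` value is illustrative):

    >>> from huggingface_hub import hf_hub_download
    >>> # Download into the shared cache and get the local path back
    >>> path = hf_hub_download(repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin")
    >>> # Or materialize the file under a local directory instead of the cache
    >>> path = hf_hub_download(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin", local_dir="./EsperBERTo-small"
    ... )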
### hf_hub_url

#### huggingface_hub.hf_hub_url

[< source >](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L171)

    hf_hub_url(
        repo_id: str, filename: str, subfolder: Optional[str] = None,
        repo_type: Optional[str] = None, revision: Optional[str] = None, endpoint: Optional[str] = None,
    )

Parameters

* **repo_id** (`str`) — A namespace (user or an organization) name and a repo name separated by a `/`.
* **filename** (`str`) — The name of the file in the repo.
* **subfolder** (`str`, _optional_) — An optional value corresponding to a folder inside the repo.
* **repo_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
* **revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.

Construct the URL of a file from the given information.

The resolved address can either be a huggingface.co-hosted url, or a link to Cloudfront (a Content Delivery Network, or CDN) for large files which are more than a few MBs.

Example:

    >>> from huggingface_hub import hf_hub_url

    >>> hf_hub_url(
    ...     repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin"
    ... )
    'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch_model.bin'

Notes:

Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our bandwidth costs).

Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here because we implement a git-based versioning system on huggingface.co, which means that we store the files on S3/Cloudfront in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means the cache can't ever be stale.

In terms of client-side caching from this library, we base our caching on the objects' entity tag (`ETag`), which is an identifier of a specific version of a resource [1]. An object's ETag is: its git-sha1 if stored in git, or its sha256 if stored in git-lfs.

References:

* [1] [https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag)
Download a snapshot of the repo
-------------------------------

#### huggingface_hub.snapshot_download

[< source >](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/_snapshot_download.py#L20)

    snapshot_download(
        repo_id: str, repo_type: Optional[str] = None, revision: Optional[str] = None,
        cache_dir: Union[str, Path, None] = None, local_dir: Union[str, Path, None] = None,
        library_name: Optional[str] = None, library_version: Optional[str] = None,
        user_agent: Union[Dict, str, None] = None, proxies: Optional[Dict] = None, etag_timeout: float = 10,
        force_download: bool = False, token: Union[bool, str, None] = None, local_files_only: bool = False,
        allow_patterns: Union[List[str], str, None] = None, ignore_patterns: Union[List[str], str, None] = None,
        max_workers: int = 8, tqdm_class: Optional[tqdm.asyncio.tqdm_asyncio] = None,
        headers: Optional[Dict[str, str]] = None, endpoint: Optional[str] = None,
        local_dir_use_symlinks: Union[bool, Literal['auto']] = 'auto', resume_download: Optional[bool] = None,
    ) -> str

Parameters

* **repo_id** (`str`) — A user or an organization name and a repo name separated by a `/`.
* **repo_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
* **revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.
* **cache_dir** (`str`, `Path`, _optional_) — Path to the folder where cached files are stored.
* **local_dir** (`str` or `Path`, _optional_) — If provided, the downloaded files will be placed under this directory.
* **library_name** (`str`, _optional_) — The name of the library to which the object corresponds.
* **library_version** (`str`, _optional_) — The version of the library.
* **user_agent** (`str`, `dict`, _optional_) — The user-agent info in the form of a dictionary or a string.
* **proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
* **etag_timeout** (`float`, _optional_, defaults to `10`) — When fetching the ETag, how many seconds to wait for the server to send data before giving up; passed to `requests.request`.
* **force_download** (`bool`, _optional_, defaults to `False`) — Whether the file should be downloaded even if it already exists in the local cache.
* **token** (`str`, `bool`, _optional_) — A token to be used for the download.
  * If `True`, the token is read from the HuggingFace config folder.
  * If a string, it's used as the authentication token.
* **headers** (`dict`, _optional_) — Additional headers to include in the request. Those headers take precedence over the others.
* **local_files_only** (`bool`, _optional_, defaults to `False`) — If `True`, avoid downloading the file and return the path to the local cached file if it exists.
* **allow_patterns** (`List[str]` or `str`, _optional_) — If provided, only files matching at least one pattern are downloaded.
* **ignore_patterns** (`List[str]` or `str`, _optional_) — If provided, files matching any of the patterns are not downloaded.
* **max_workers** (`int`, _optional_) — Number of concurrent threads to download files (1 thread = 1 file download). Defaults to 8.
* **tqdm_class** (`tqdm`, _optional_) — If provided, overwrites the default behavior for the progress bar. Passed argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. Note that the `tqdm_class` is not passed to each individual download. Defaults to the custom HF progress bar that can be disabled by setting the `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.

Returns

`str` — folder path of the repo snapshot.

Raises

* [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) — If the repository to download from cannot be found. This may be because it doesn't exist, or because it is set to `private` and you do not have access.
* [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) — If the revision to download from cannot be found.
* [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) — If `token=True` and the token cannot be found.
* [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) — If the ETag cannot be determined.
* [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) — If some parameter value is invalid.

Download repo files.

Download a whole snapshot of a repo's files at the specified revision. This is useful when you want all files from a repo, because you don't know which ones you will need a priori. All files are nested inside a folder in order to keep their actual filename relative to that folder. You can also filter which files to download using `allow_patterns` and `ignore_patterns`.

If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it's optimized for regularly pulling the latest version of a repository.

An alternative would be to clone the repo, but this requires git and git-lfs to be installed and properly configured. It is also not possible to filter which files to download when cloning a repository using git.
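A minimal sketch of a filtered snapshot download (the repo id and glob patterns below are illustrative):

    >>> from huggingface_hub import snapshot_download
    >>> # Fetch the whole repo at a pinned revision into the cache
    >>> folder = snapshot_download(repo_id="julien-c/EsperBERTo-small", revision="main")
    >>> # Only download the files you need; skip everything matching ignore_patterns
    >>> folder = snapshot_download(
    ...     repo_id="julien-c/EsperBERTo-small",
    ...     allow_patterns=["*.json", "README.md"],
    ...     ignore_patterns=["*.bin"],
    ... )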
Get metadata about a file
-------------------------

### get_hf_file_metadata

#### huggingface_hub.get_hf_file_metadata

[< source >](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L1246)

    get_hf_file_metadata(
        url: str, token: Union[bool, str, None] = None, proxies: Optional[Dict] = None,
        timeout: Optional[float] = 10, library_name: Optional[str] = None, library_version: Optional[str] = None,
        user_agent: Union[Dict, str, None] = None, headers: Optional[Dict[str, str]] = None,
    )

Parameters

* **url** (`str`) — File url, for example returned by [hf_hub_url()](/docs/huggingface_hub/v0.29.2/en/package_reference/file_download#huggingface_hub.hf_hub_url).
* **token** (`str` or `bool`, _optional_) — A token to be used for the download.
  * If `True`, the token is read from the HuggingFace config folder.
  * If `False` or `None`, no token is provided.
  * If a string, it's used as the authentication token.
* **proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
* **timeout** (`float`, _optional_, defaults to 10) — How many seconds to wait for the server to send metadata before giving up.
* **library_name** (`str`, _optional_) — The name of the library to which the object corresponds.
* **library_version** (`str`, _optional_) — The version of the library.
* **user_agent** (`dict`, `str`, _optional_) — The user-agent info in the form of a dictionary or a string.
* **headers** (`dict`, _optional_) — Additional headers to be sent with the request.

Fetch metadata of a file versioned on the Hub for a given url.

### HfFileMetadata

### class huggingface_hub.HfFileMetadata

[< source >](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L147)

    HfFileMetadata(
        commit_hash: Optional[str], etag: Optional[str], location: str, size: Optional[int],
    )

Parameters

* **commit_hash** (`str`, _optional_) — The commit_hash related to the file.
* **etag** (`str`, _optional_) — Etag of the file on the server.
* **location** (`str`) — Location where to download the file. Can be a Hub url or not (CDN).
* **size** (`int`, _optional_) — Size of the file. In case of an LFS file, contains the size of the actual LFS file, not the pointer.

Data structure containing information about a file versioned on the Hub.

Returned by [get_hf_file_metadata()](/docs/huggingface_hub/v0.29.2/en/package_reference/file_download#huggingface_hub.get_hf_file_metadata) based on a URL.
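A short sketch tying the two together (the repo id is reused from the examples above):

    >>> from huggingface_hub import get_hf_file_metadata, hf_hub_url
    >>> # Resolve the file URL first, then fetch its metadata without downloading the file
    >>> url = hf_hub_url(repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin")
    >>> metadata = get_hf_file_metadata(url)
    >>> metadata.commit_hash, metadata.etag, metadata.size  # fields of HfFileMetadata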
Caching
-------

The methods displayed above are designed to work with a caching system that prevents re-downloading files. The caching system was updated in v0.8.0 to become the central cache-system shared across libraries that depend on the Hub.

Read the [cache-system guide](../guides/manage-cache) for a detailed presentation of caching at HF.
docs/huggingface/HfApi Client API Reference.md
ADDED
The diff for this file is too large to render.
See raw diff
docs/huggingface/Load a dataset from the hub.md
ADDED
@@ -0,0 +1,126 @@
Load a dataset from the Hub
===========================

Finding high-quality datasets that are reproducible and accessible can be difficult. One of 🤗 Datasets main goals is to provide a simple way to load a dataset of any format or type. The easiest way to get started is to discover an existing dataset on the [Hugging Face Hub](https://huggingface.co/datasets) - a community-driven collection of datasets for tasks in NLP, computer vision, and audio - and use 🤗 Datasets to download and generate the dataset.

This tutorial uses the [rotten_tomatoes](https://huggingface.co/datasets/rotten_tomatoes) and [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) datasets, but feel free to load any dataset you want and follow along. Head over to the Hub now and find a dataset for your task!

Load a dataset
--------------

Before you take the time to download a dataset, it's often helpful to quickly get some general information about a dataset. A dataset's information is stored inside [DatasetInfo](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.DatasetInfo) and can include information such as the dataset description, features, and dataset size.

Use the [load_dataset_builder()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.load_dataset_builder) function to load a dataset builder and inspect a dataset's attributes without committing to downloading it:

    >>> from datasets import load_dataset_builder
    >>> ds_builder = load_dataset_builder("cornell-movie-review-data/rotten_tomatoes")

    # Inspect dataset description
    >>> ds_builder.info.description
    Movie Review Dataset. This is a dataset of containing 5,331 positive and 5,331 negative processed sentences from Rotten Tomatoes movie reviews. This data was first used in Bo Pang and Lillian Lee, ``Seeing stars: Exploiting class relationships for sentiment categorization with respect to rating scales.'', Proceedings of the ACL, 2005.

    # Inspect dataset features
    >>> ds_builder.info.features
    {'label': ClassLabel(names=['neg', 'pos'], id=None),
     'text': Value(dtype='string', id=None)}

If you're happy with the dataset, then load it with [load_dataset()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.load_dataset):

    >>> from datasets import load_dataset

    >>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")

Splits
------

A split is a specific subset of a dataset like `train` and `test`. List a dataset's split names with the [get_dataset_split_names()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.get_dataset_split_names) function:

    >>> from datasets import get_dataset_split_names

    >>> get_dataset_split_names("cornell-movie-review-data/rotten_tomatoes")
    ['train', 'validation', 'test']

Then you can load a specific split with the `split` parameter. Loading a dataset `split` returns a [Dataset](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.Dataset) object:

    >>> from datasets import load_dataset

    >>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes", split="train")
    >>> dataset
    Dataset({
        features: ['text', 'label'],
        num_rows: 8530
    })

If you don't specify a `split`, 🤗 Datasets returns a [DatasetDict](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.DatasetDict) object instead:

    >>> from datasets import load_dataset

    >>> dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")
    DatasetDict({
        train: Dataset({
            features: ['text', 'label'],
            num_rows: 8530
        })
        validation: Dataset({
            features: ['text', 'label'],
            num_rows: 1066
        })
        test: Dataset({
            features: ['text', 'label'],
            num_rows: 1066
        })
    })
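As a small follow-up sketch using standard 🤗 Datasets indexing (not shown on this page): a `DatasetDict` is indexed by split name, and a `Dataset` by row position.

    >>> train_ds = dataset["train"]   # pick one split out of the DatasetDict
    >>> example = train_ds[0]         # a single row, returned as a dict with 'text' and 'label' keys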
Configurations
--------------

Some datasets contain several sub-datasets. For example, the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset has several sub-datasets, each one containing audio data in a different language. These sub-datasets are known as _configurations_ or _subsets_, and you must explicitly select one when loading the dataset. If you don't provide a configuration name, 🤗 Datasets will raise a `ValueError` and remind you to choose a configuration.

Use the [get_dataset_config_names()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.get_dataset_config_names) function to retrieve a list of all the possible configurations available to your dataset:

    >>> from datasets import get_dataset_config_names

    >>> configs = get_dataset_config_names("PolyAI/minds14")
    >>> print(configs)
    ['cs-CZ', 'de-DE', 'en-AU', 'en-GB', 'en-US', 'es-ES', 'fr-FR', 'it-IT', 'ko-KR', 'nl-NL', 'pl-PL', 'pt-PT', 'ru-RU', 'zh-CN', 'all']

Then load the configuration you want:

    >>> from datasets import load_dataset

    >>> mindsFR = load_dataset("PolyAI/minds14", "fr-FR", split="train")

Remote code
-----------

Certain dataset repositories contain a loading script with the Python code used to generate the dataset. All files and code uploaded to the Hub are scanned for malware (refer to the Hub security documentation for more information), but you should still review the dataset loading scripts and authors to avoid executing malicious code on your machine. You must set `trust_remote_code=True` to use a dataset with a loading script, or you will get an error:

    >>> from datasets import get_dataset_config_names, get_dataset_split_names, load_dataset

    >>> c4 = load_dataset("c4", "en", split="train", trust_remote_code=True)
    >>> get_dataset_config_names("c4", trust_remote_code=True)
    ['en', 'realnewslike', 'en.noblocklist', 'en.noclean']
    >>> get_dataset_split_names("c4", "en", trust_remote_code=True)
    ['train', 'validation']

For security reasons, 🤗 Datasets does not allow running dataset loading scripts by default, and you have to pass `trust_remote_code=True` to load datasets that require running a dataset script.
docs/huggingface/Search the Hub.md
ADDED
@@ -0,0 +1,61 @@
Search the Hub
==============

In this tutorial, you will learn how to search models, datasets and spaces on the Hub using `huggingface_hub`.

How to list repositories?
-------------------------

The `huggingface_hub` library includes an HTTP client [HfApi](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi) to interact with the Hub. Among other things, it can list models, datasets and spaces stored on the Hub:

    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> models = api.list_models()

The output of [list_models()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_models) is an iterator over the models stored on the Hub.

Similarly, you can use [list_datasets()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_datasets) to list datasets and [list_spaces()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_spaces) to list Spaces.

How to filter repositories?
---------------------------

Listing repositories is great, but now you might want to filter your search. The list helpers have several attributes like:

* `filter`
* `author`
* `search`
* …

For example, to get all models on the Hub that do image classification, have been trained on the ImageNet dataset, and run with PyTorch:

    models = api.list_models(
        task="image-classification",
        library="pytorch",
        trained_dataset="imagenet",
    )

While filtering, you can also sort the models and take only the top results. For example, the following example fetches the top 5 most downloaded datasets on the Hub:

    >>> list(list_datasets(sort="downloads", direction=-1, limit=5))
    [DatasetInfo(
        id='argilla/databricks-dolly-15k-curated-en',
        author='argilla',
        sha='4dcd1dedbe148307a833c931b21ca456a1fc4281',
        last_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc),
        private=False,
        downloads=8889377,
        (...)

To explore available filters on the Hub, visit the [models](https://huggingface.co/models) and [datasets](https://huggingface.co/datasets) pages in your browser, search for some parameters and look at the values in the URL.
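As an additional hedged sketch, the same list helpers also accept `author`, `search`, sorting and limits programmatically (the values below are illustrative):

    >>> from huggingface_hub import HfApi
    >>> api = HfApi()
    >>> # Models from a given author whose id mentions "bert", most recently updated first
    >>> results = api.list_models(author="google", search="bert", sort="last_modified", direction=-1, limit=10)
    >>> for model in results:
    ...     print(model.id)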
finetrainers/args.py
CHANGED
@@ -316,6 +316,7 @@ class BaseArgs:
     # Dataset arguments
     dataset_config: str = None
     dataset_shuffle_buffer_size: int = 1
+    enable_precomputation: bool = False
     precomputation_items: int = 512
     precomputation_dir: Optional[str] = None
     precomputation_once: bool = False
@@ -420,6 +421,7 @@ class BaseArgs:
         dataset_arguments = {
             "dataset_config": self.dataset_config,
             "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
+            "enable_precomputation": self.enable_precomputation,
             "precomputation_items": self.precomputation_items,
             "precomputation_dir": self.precomputation_dir,
             "precomputation_once": self.precomputation_once,
@@ -625,6 +627,7 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
 def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
     parser.add_argument("--dataset_config", type=str, required=True)
     parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
+    parser.add_argument("--enable_precomputation", action="store_true")
     parser.add_argument("--precomputation_items", type=int, default=512)
     parser.add_argument("--precomputation_dir", type=str, default=None)
     parser.add_argument("--precomputation_once", action="store_true")
@@ -761,6 +764,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
     # Dataset arguments
     result_args.dataset_config = args.dataset_config
     result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
+    result_args.enable_precomputation = args.enable_precomputation
     result_args.precomputation_items = args.precomputation_items
     result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
     result_args.precomputation_once = args.precomputation_once
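The new `enable_precomputation` switch is wired as a plain `store_true` argparse flag and mirrored into `BaseArgs` and the serialized `dataset_arguments`. A minimal, self-contained sketch of how such a flag behaves (the parser below is built by hand for illustration and is not the finetrainers entry point):

    import argparse

    # Illustrative stand-in for _add_dataset_arguments(): only the flags relevant to this diff.
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset_config", type=str, required=True)
    parser.add_argument("--enable_precomputation", action="store_true")

    args = parser.parse_args(["--dataset_config", "config.json", "--enable_precomputation"])
    print(args.enable_precomputation)  # True; defaults to False when the flag is omitted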
finetrainers/config.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Type

 from .models import ModelSpecification
 from .models.cogvideox import CogVideoXModelSpecification
+from .models.cogview4 import CogView4ModelSpecification
 from .models.hunyuan_video import HunyuanVideoModelSpecification
 from .models.ltx_video import LTXVideoModelSpecification
 from .models.wan import WanModelSpecification
@@ -10,6 +11,7 @@ from .models.wan import WanModelSpecification

 class ModelType(str, Enum):
     COGVIDEOX = "cogvideox"
+    COGVIEW4 = "cogview4"
     HUNYUAN_VIDEO = "hunyuan_video"
     LTX_VIDEO = "ltx_video"
     WAN = "wan"
@@ -21,6 +23,14 @@ class TrainingType(str, Enum):


 SUPPORTED_MODEL_CONFIGS = {
+    ModelType.COGVIDEOX: {
+        TrainingType.LORA: CogVideoXModelSpecification,
+        TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
+    },
+    ModelType.COGVIEW4: {
+        TrainingType.LORA: CogView4ModelSpecification,
+        TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
+    },
     ModelType.HUNYUAN_VIDEO: {
         TrainingType.LORA: HunyuanVideoModelSpecification,
         TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
@@ -29,10 +39,6 @@ SUPPORTED_MODEL_CONFIGS = {
         TrainingType.LORA: LTXVideoModelSpecification,
         TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
     },
-    ModelType.COGVIDEOX: {
-        TrainingType.LORA: CogVideoXModelSpecification,
-        TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
-    },
     ModelType.WAN: {
         TrainingType.LORA: WanModelSpecification,
         TrainingType.FULL_FINETUNE: WanModelSpecification,
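A hedged sketch of how the registry above can be queried to resolve a specification class; the dictionary keys match this diff, while the lookup code itself is illustrative:

    # Illustrative lookup against the registry defined above.
    from finetrainers.config import SUPPORTED_MODEL_CONFIGS, ModelType, TrainingType

    spec_cls = SUPPORTED_MODEL_CONFIGS[ModelType.COGVIEW4][TrainingType.LORA]
    print(spec_cls.__name__)  # CogView4ModelSpecification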
finetrainers/data/__init__.py
CHANGED
@@ -14,6 +14,14 @@ from .dataset import (
     initialize_dataset,
     wrap_iterable_dataset_for_preprocessing,
 )
-from .precomputation import
+from .precomputation import (
+    InMemoryDataIterable,
+    InMemoryDistributedDataPreprocessor,
+    InMemoryOnceDataIterable,
+    PrecomputedDataIterable,
+    PrecomputedDistributedDataPreprocessor,
+    PrecomputedOnceDataIterable,
+    initialize_preprocessor,
+)
 from .sampler import ResolutionSampler
 from .utils import find_files
finetrainers/data/dataset.py
CHANGED
@@ -29,10 +29,13 @@ decord.bridge.set_bridge("torch")
 logger = get_logger()


+# fmt: off
 MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
 COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
 COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
 COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
+COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
+# fmt: on


 class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
@@ -420,22 +423,69 @@ class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):


 class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
-    def __init__(
+    def __init__(
+        self,
+        dataset_name: str,
+        infinite: bool = False,
+        column_names: Union[str, List[str]] = "__auto__",
+        weights: Dict[str, float] = -1,
+        **kwargs,
+    ) -> None:
         super().__init__()

+        assert weights == -1 or isinstance(
+            weights, dict
+        ), "`weights` must be a dictionary of probabilities for each caption column"
+
         self.dataset_name = dataset_name
         self.infinite = infinite

         data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+
+        if column_names == "__auto__":
+            if weights == -1:
+                caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
+                if len(caption_columns) == 0:
+                    raise ValueError(
+                        f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
+                    )
+                weights = [1] * len(caption_columns)
+            else:
+                caption_columns = list(weights.keys())
+                weights = list(weights.values())
+                if not all(column in data.column_names for column in caption_columns):
+                    raise ValueError(
+                        f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+        else:
+            if isinstance(column_names, str):
+                if column_names not in data.column_names:
+                    raise ValueError(
+                        f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+                caption_columns = [column_names]
+                weights = [1] if weights == -1 else [weights.get(column_names)]
+            elif isinstance(column_names, list):
+                if not all(column in data.column_names for column in column_names):
+                    raise ValueError(
+                        f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+                caption_columns = column_names
+                weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
+            else:
+                raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
+
+        for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
+            if column_names in data.column_names:
+                data = data.cast_column(column_names, datasets.Image(mode="RGB"))
+                data = data.rename_column(column_names, "image")
+                break

         self._data = data
         self._sample_index = 0
         self._precomputable_once = False
+        self._caption_columns = caption_columns
+        self._weights = weights

     def _get_data_iter(self):
         if self._sample_index == 0:
@@ -446,6 +496,9 @@ class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
         while True:
             for sample in self._get_data_iter():
                 self._sample_index += 1
+                caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
+                sample["caption"] = sample[caption_column]
+                sample["image"] = _preprocess_image(sample["image"])
                 yield sample

             if not self.infinite:
@@ -464,22 +517,69 @@ class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):


 class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
-    def __init__(
+    def __init__(
+        self,
+        dataset_name: str,
+        infinite: bool = False,
+        column_names: Union[str, List[str]] = "__auto__",
+        weights: Dict[str, float] = -1,
+        **kwargs,
+    ) -> None:
         super().__init__()

+        assert weights == -1 or isinstance(
+            weights, dict
+        ), "`weights` must be a dictionary of probabilities for each caption column"
+
         self.dataset_name = dataset_name
         self.infinite = infinite

         data = datasets.load_dataset(dataset_name, split="train", streaming=True)
+
+        if column_names == "__auto__":
+            if weights == -1:
+                caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
+                if len(caption_columns) == 0:
+                    raise ValueError(
+                        f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
+                    )
+                weights = [1] * len(caption_columns)
+            else:
+                caption_columns = list(weights.keys())
+                weights = list(weights.values())
+                if not all(column in data.column_names for column in caption_columns):
+                    raise ValueError(
+                        f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+        else:
+            if isinstance(column_names, str):
+                if column_names not in data.column_names:
+                    raise ValueError(
+                        f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+                caption_columns = [column_names]
+                weights = [1] if weights == -1 else [weights.get(column_names)]
+            elif isinstance(column_names, list):
+                if not all(column in data.column_names for column in column_names):
+                    raise ValueError(
+                        f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
+                    )
+                caption_columns = column_names
+                weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
+            else:
+                raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
+
+        for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
+            if column_names in data.column_names:
+                data = data.cast_column(column_names, datasets.Video())
+                data = data.rename_column(column_names, "video")
+                break

         self._data = data
         self._sample_index = 0
         self._precomputable_once = False
+        self._caption_columns = caption_columns
+        self._weights = weights

     def _get_data_iter(self):
         if self._sample_index == 0:
@@ -490,6 +590,9 @@ class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
         while True:
             for sample in self._get_data_iter():
                 self._sample_index += 1
+                caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
+                sample["caption"] = sample[caption_column]
+                sample["video"] = _preprocess_video(sample["video"])
                 yield sample

             if not self.infinite:
@@ -600,11 +703,17 @@ class IterableDatasetPreprocessingWrapper(
         for sample in iter(self.dataset):
             if self.dataset_type == "image":
                 if self.image_resolution_buckets:
+                    sample["_original_num_frames"] = 1
+                    sample["_original_height"] = sample["image"].size(1)
+                    sample["_original_width"] = sample["image"].size(2)
                     sample["image"] = FF.resize_to_nearest_bucket_image(
                         sample["image"], self.image_resolution_buckets, self.reshape_mode
                     )
             elif self.dataset_type == "video":
                 if self.video_resolution_buckets:
+                    sample["_original_num_frames"] = sample["video"].size(0)
+                    sample["_original_height"] = sample["video"].size(2)
+                    sample["_original_width"] = sample["video"].size(3)
                     sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
                         sample["video"], self.video_resolution_buckets, self.reshape_mode
                     )
@@ -682,7 +791,12 @@ class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):

 # TODO(aryan): maybe write a test for this
 def initialize_dataset(
-    dataset_name_or_root: str,
 ) -> torch.utils.data.IterableDataset:
     assert dataset_type in ["image", "video"]

@@ -692,7 +806,7 @@ def initialize_dataset(
     does_repo_exist_on_hub = False

     if does_repo_exist_on_hub:
-        return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
     else:
         return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)

@@ -745,14 +859,33 @@ def _initialize_local_dataset(
     return dataset


-def _initialize_hub_dataset(
     repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
     if _has_data_caption_file_pairs(repo_file_list, remote=True):
         return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
     elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
         return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)


 def _initialize_data_caption_file_dataset_from_hub(
@@ -778,13 +911,14 @@ def _initialize_data_file_caption_file_dataset_from_hub(


 def _initialize_webdataset(
-    dataset_name: str, dataset_type: str, infinite: bool = False
 ) -> torch.utils.data.IterableDataset:
     logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
     if dataset_type == "image":
-        return ImageWebDataset(dataset_name, infinite=infinite)
     else:
-        return VideoWebDataset(dataset_name, infinite=infinite)


 def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
|
718 |
sample["video"], self.video_resolution_buckets, self.reshape_mode
|
719 |
)
|
|
|
791 |
|
792 |
# TODO(aryan): maybe write a test for this
|
793 |
def initialize_dataset(
|
794 |
+
dataset_name_or_root: str,
|
795 |
+
dataset_type: str = "video",
|
796 |
+
streaming: bool = True,
|
797 |
+
infinite: bool = False,
|
798 |
+
*,
|
799 |
+
_caption_options: Optional[Dict[str, Any]] = None,
|
800 |
) -> torch.utils.data.IterableDataset:
|
801 |
assert dataset_type in ["image", "video"]
|
802 |
|
|
|
806 |
does_repo_exist_on_hub = False
|
807 |
|
808 |
if does_repo_exist_on_hub:
|
809 |
+
return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
|
810 |
else:
|
811 |
return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
|
812 |
|
|
|
859 |
return dataset
|
860 |
|
861 |
|
862 |
+
def _initialize_hub_dataset(
|
863 |
+
dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
|
864 |
+
):
|
865 |
repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
|
866 |
if _has_data_caption_file_pairs(repo_file_list, remote=True):
|
867 |
return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
|
868 |
elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
|
869 |
return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
|
870 |
+
|
871 |
+
has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
|
872 |
+
if has_tar_files:
|
873 |
+
return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)
|
874 |
+
|
875 |
+
# TODO(aryan): This should be improved
|
876 |
+
caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
|
877 |
+
if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
|
878 |
+
try:
|
879 |
+
dataset_root = snapshot_download(dataset_name, repo_type="dataset")
|
880 |
+
if dataset_type == "image":
|
881 |
+
dataset = ImageFolderDataset(dataset_root, infinite=infinite)
|
882 |
+
else:
|
883 |
+
dataset = VideoFolderDataset(dataset_root, infinite=infinite)
|
884 |
+
return dataset
|
885 |
+
except Exception:
|
886 |
+
pass
|
887 |
+
|
888 |
+
raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")
|
889 |
|
890 |
|
891 |
def _initialize_data_caption_file_dataset_from_hub(
|
|
|
911 |
|
912 |
|
913 |
def _initialize_webdataset(
|
914 |
+
dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
|
915 |
) -> torch.utils.data.IterableDataset:
|
916 |
logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
|
917 |
+
_caption_options = _caption_options or {}
|
918 |
if dataset_type == "image":
|
919 |
+
return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
|
920 |
else:
|
921 |
+
return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)
|
922 |
|
923 |
|
924 |
def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
|
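The caption handling added to `ImageWebDataset` and `VideoWebDataset` above reduces to a weighted draw over the configured caption columns, followed by exposing the chosen text under a unified `caption` key. A minimal sketch of that draw; the column names, weights, and sample values below are made up for illustration and are not taken from any real dataset:

```python
import random

# Hypothetical caption columns and selection probabilities (illustrative only).
caption_columns = ["short_caption", "long_caption"]
weights = [0.3, 0.7]

sample = {"short_caption": "a cat", "long_caption": "a fluffy cat sleeping on a sofa"}

# Same pattern as the __iter__ methods above: pick one column per sample,
# then expose its text under the unified "caption" key.
chosen = random.choices(caption_columns, weights=weights, k=1)[0]
sample["caption"] = sample[chosen]
print(sample["caption"])
```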
finetrainers/data/precomputation.py
CHANGED
@@ -1,13 +1,132 @@
|
|
1 |
import pathlib
|
2 |
-
from typing import Any, Callable, Dict, Iterable, Optional
|
3 |
|
4 |
import torch
|
5 |
from tqdm.auto import tqdm
|
6 |
|
7 |
from .. import utils
|
|
|
8 |
|
9 |
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
def __init__(
|
12 |
self,
|
13 |
rank: int,
|
@@ -15,13 +134,15 @@ class DistributedDataPreprocessor:
|
|
15 |
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
|
16 |
save_dir: str,
|
17 |
) -> None:
|
|
|
|
|
18 |
self._rank = rank
|
19 |
self._num_items = num_items
|
20 |
self._processor_fn = processor_fn
|
21 |
self._save_dir = pathlib.Path(save_dir)
|
22 |
|
23 |
self._cached_samples = []
|
24 |
-
self._preprocessed_iterator: "
|
25 |
|
26 |
self._save_dir.mkdir(parents=True, exist_ok=True)
|
27 |
|
@@ -59,9 +180,8 @@ class DistributedDataPreprocessor:
|
|
59 |
if drop_samples:
|
60 |
del self._cached_samples
|
61 |
self._cached_samples = []
|
62 |
-
utils.free_memory()
|
63 |
|
64 |
-
self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
|
65 |
return iter(self._preprocessed_iterator)
|
66 |
|
67 |
def consume_once(
|
@@ -95,9 +215,8 @@ class DistributedDataPreprocessor:
|
|
95 |
if drop_samples:
|
96 |
del self._cached_samples
|
97 |
self._cached_samples = []
|
98 |
-
utils.free_memory()
|
99 |
|
100 |
-
self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
|
101 |
return iter(self._preprocessed_iterator)
|
102 |
|
103 |
@property
|
@@ -107,7 +226,70 @@ class DistributedDataPreprocessor:
|
|
107 |
return self._preprocessed_iterator.requires_data
|
108 |
|
109 |
|
110 |
-
class PreprocessedDataIterable:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
111 |
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
|
112 |
self._rank = rank
|
113 |
self._save_dir = pathlib.Path(save_dir)
|
@@ -130,7 +312,13 @@ class PreprocessedDataIterable:
|
|
130 |
return self._requires_data
|
131 |
|
132 |
|
133 |
-
class PreprocessedOnceDataIterable:
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
|
135 |
self._rank = rank
|
136 |
self._save_dir = pathlib.Path(save_dir)
|
@@ -153,6 +341,31 @@ class PreprocessedOnceDataIterable:
|
|
153 |
return self._requires_data
|
154 |
|
155 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
156 |
def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
|
157 |
filename = directory / f"{data_type}-{rank}-{index}.pt"
|
158 |
torch.save(item, filename.as_posix())
|
|
|
1 |
import pathlib
|
2 |
+
from typing import Any, Callable, Dict, Iterable, List, Optional, Union
|
3 |
|
4 |
import torch
|
5 |
from tqdm.auto import tqdm
|
6 |
|
7 |
from .. import utils
|
8 |
+
from ..logging import get_logger
|
9 |
|
10 |
|
11 |
+
logger = get_logger()
|
12 |
+
|
13 |
+
|
14 |
+
def initialize_preprocessor(
|
15 |
+
rank: int,
|
16 |
+
num_items: int,
|
17 |
+
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
|
18 |
+
save_dir: Optional[str] = None,
|
19 |
+
enable_precomputation: bool = False,
|
20 |
+
) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
|
21 |
+
if enable_precomputation:
|
22 |
+
return PrecomputedDistributedDataPreprocessor(rank, num_items, processor_fn, save_dir)
|
23 |
+
return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)
|
24 |
+
|
25 |
+
|
26 |
+
class DistributedDataProcessorMixin:
|
27 |
+
def consume(self, *args, **kwargs):
|
28 |
+
raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")
|
29 |
+
|
30 |
+
def consume_once(self, *args, **kwargs):
|
31 |
+
raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")
|
32 |
+
|
33 |
+
@property
|
34 |
+
def requires_data(self):
|
35 |
+
raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")
|
36 |
+
|
37 |
+
|
38 |
+
class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
|
39 |
+
def __init__(
|
40 |
+
self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
|
41 |
+
) -> None:
|
42 |
+
super().__init__()
|
43 |
+
|
44 |
+
self._rank = rank
|
45 |
+
self._num_items = num_items
|
46 |
+
self._processor_fn = processor_fn
|
47 |
+
|
48 |
+
self._cached_samples = []
|
49 |
+
self._buffer = InMemoryDataBuffer(num_items)
|
50 |
+
self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None
|
51 |
+
|
52 |
+
def consume(
|
53 |
+
self,
|
54 |
+
data_type: str,
|
55 |
+
components: Dict[str, Any],
|
56 |
+
data_iterator,
|
57 |
+
generator: Optional[torch.Generator] = None,
|
58 |
+
cache_samples: bool = False,
|
59 |
+
use_cached_samples: bool = False,
|
60 |
+
drop_samples: bool = False,
|
61 |
+
) -> Iterable[Dict[str, Any]]:
|
62 |
+
if data_type not in self._processor_fn.keys():
|
63 |
+
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
|
64 |
+
if cache_samples:
|
65 |
+
if use_cached_samples:
|
66 |
+
raise ValueError("Cannot cache and use cached samples at the same time.")
|
67 |
+
if drop_samples:
|
68 |
+
raise ValueError("Cannot cache and drop samples at the same time.")
|
69 |
+
|
70 |
+
for i in range(self._num_items):
|
71 |
+
if use_cached_samples:
|
72 |
+
item = self._cached_samples[i]
|
73 |
+
else:
|
74 |
+
item = next(data_iterator)
|
75 |
+
if cache_samples:
|
76 |
+
self._cached_samples.append(item)
|
77 |
+
item = self._processor_fn[data_type](**item, **components, generator=generator)
|
78 |
+
self._buffer.add(data_type, item)
|
79 |
+
|
80 |
+
if drop_samples:
|
81 |
+
del self._cached_samples
|
82 |
+
self._cached_samples = []
|
83 |
+
|
84 |
+
self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
|
85 |
+
return iter(self._preprocessed_iterator)
|
86 |
+
|
87 |
+
def consume_once(
|
88 |
+
self,
|
89 |
+
data_type: str,
|
90 |
+
components: Dict[str, Any],
|
91 |
+
data_iterator,
|
92 |
+
generator: Optional[torch.Generator] = None,
|
93 |
+
cache_samples: bool = False,
|
94 |
+
use_cached_samples: bool = False,
|
95 |
+
drop_samples: bool = False,
|
96 |
+
) -> Iterable[Dict[str, Any]]:
|
97 |
+
if data_type not in self._processor_fn.keys():
|
98 |
+
raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
|
99 |
+
if cache_samples:
|
100 |
+
if use_cached_samples:
|
101 |
+
raise ValueError("Cannot cache and use cached samples at the same time.")
|
102 |
+
if drop_samples:
|
103 |
+
raise ValueError("Cannot cache and drop samples at the same time.")
|
104 |
+
|
105 |
+
for i in range(self._num_items):
|
106 |
+
if use_cached_samples:
|
107 |
+
item = self._cached_samples[i]
|
108 |
+
else:
|
109 |
+
item = next(data_iterator)
|
110 |
+
if cache_samples:
|
111 |
+
self._cached_samples.append(item)
|
112 |
+
item = self._processor_fn[data_type](**item, **components, generator=generator)
|
113 |
+
self._buffer.add(data_type, item)
|
114 |
+
|
115 |
+
if drop_samples:
|
116 |
+
del self._cached_samples
|
117 |
+
self._cached_samples = []
|
118 |
+
|
119 |
+
self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
|
120 |
+
return iter(self._preprocessed_iterator)
|
121 |
+
|
122 |
+
@property
|
123 |
+
def requires_data(self):
|
124 |
+
if self._preprocessed_iterator is None:
|
125 |
+
return True
|
126 |
+
return self._preprocessed_iterator.requires_data
|
127 |
+
|
128 |
+
|
129 |
+
class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
|
130 |
def __init__(
|
131 |
self,
|
132 |
rank: int,
|
|
|
134 |
processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
|
135 |
save_dir: str,
|
136 |
) -> None:
|
137 |
+
super().__init__()
|
138 |
+
|
139 |
self._rank = rank
|
140 |
self._num_items = num_items
|
141 |
self._processor_fn = processor_fn
|
142 |
self._save_dir = pathlib.Path(save_dir)
|
143 |
|
144 |
self._cached_samples = []
|
145 |
+
self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None
|
146 |
|
147 |
self._save_dir.mkdir(parents=True, exist_ok=True)
|
148 |
|
|
|
180 |
if drop_samples:
|
181 |
del self._cached_samples
|
182 |
self._cached_samples = []
|
|
|
183 |
|
184 |
+
self._preprocessed_iterator = PrecomputedDataIterable(self._rank, self._save_dir, data_type)
|
185 |
return iter(self._preprocessed_iterator)
|
186 |
|
187 |
def consume_once(
|
|
|
215 |
if drop_samples:
|
216 |
del self._cached_samples
|
217 |
self._cached_samples = []
|
|
|
218 |
|
219 |
+
self._preprocessed_iterator = PrecomputedOnceDataIterable(self._rank, self._save_dir, data_type)
|
220 |
return iter(self._preprocessed_iterator)
|
221 |
|
222 |
@property
|
|
|
226 |
return self._preprocessed_iterator.requires_data
|
227 |
|
228 |
|
229 |
+
class InMemoryDataIterable:
|
230 |
+
"""
|
231 |
+
An iterator that loads data items from an in-memory buffer. Once all the data is consumed,
|
232 |
+
`requires_data` is set to True, indicating that more data is required and the preprocessor's
|
233 |
+
consume method should be called again.
|
234 |
+
"""
|
235 |
+
|
236 |
+
def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
|
237 |
+
self._rank = rank
|
238 |
+
self._data_type = data_type
|
239 |
+
self._buffer = buffer
|
240 |
+
|
241 |
+
self._requires_data = False
|
242 |
+
|
243 |
+
def __iter__(self) -> Iterable[Dict[str, Any]]:
|
244 |
+
while (length := self._buffer.get_length(self._data_type)) > 0:
|
245 |
+
if length <= 1:
|
246 |
+
self._requires_data = True
|
247 |
+
yield self._buffer.get(self._data_type)
|
248 |
+
|
249 |
+
def __len__(self) -> int:
|
250 |
+
return self._buffer.get_length(self._data_type)
|
251 |
+
|
252 |
+
@property
|
253 |
+
def requires_data(self):
|
254 |
+
return self._requires_data
|
255 |
+
|
256 |
+
|
257 |
+
class InMemoryOnceDataIterable:
|
258 |
+
"""
|
259 |
+
An iterator that loads data items from an in-memory buffer. This iterator will never set
|
260 |
+
`requires_data` to True, as it is assumed that all the data was configured to be preprocessed
|
261 |
+
by the user. The data is cycled from the buffer indefinitely.
|
262 |
+
"""
|
263 |
+
|
264 |
+
def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
|
265 |
+
self._rank = rank
|
266 |
+
self._data_type = data_type
|
267 |
+
self._buffer = buffer
|
268 |
+
|
269 |
+
self._requires_data = False
|
270 |
+
|
271 |
+
def __iter__(self) -> Iterable[Dict[str, Any]]:
|
272 |
+
assert len(self) > 0, "No data available in the buffer."
|
273 |
+
while True:
|
274 |
+
item = self._buffer.get(self._data_type)
|
275 |
+
yield item
|
276 |
+
self._buffer.add(self._data_type, item)
|
277 |
+
|
278 |
+
def __len__(self) -> int:
|
279 |
+
return self._buffer.get_length(self._data_type)
|
280 |
+
|
281 |
+
@property
|
282 |
+
def requires_data(self):
|
283 |
+
return self._requires_data
|
284 |
+
|
285 |
+
|
286 |
+
class PrecomputedDataIterable:
|
287 |
+
"""
|
288 |
+
An iterator that loads a preconfigured number of data items from disk. Once all the data is
|
289 |
+
loaded, `requires_data` is set to True, indicating that more data is required and
|
290 |
+
the preprocessor's consume method should be called again.
|
291 |
+
"""
|
292 |
+
|
293 |
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
|
294 |
self._rank = rank
|
295 |
self._save_dir = pathlib.Path(save_dir)
|
|
|
312 |
return self._requires_data
|
313 |
|
314 |
|
315 |
+
class PrecomputedOnceDataIterable:
|
316 |
+
"""
|
317 |
+
An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator
|
318 |
+
will never set `requires_data` to True, as it is assumed that all the data was configured to
|
319 |
+
be preprocessed by the user.
|
320 |
+
"""
|
321 |
+
|
322 |
def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
|
323 |
self._rank = rank
|
324 |
self._save_dir = pathlib.Path(save_dir)
|
|
|
341 |
return self._requires_data
|
342 |
|
343 |
|
344 |
+
class InMemoryDataBuffer:
|
345 |
+
def __init__(self, max_limit: int = -1) -> None:
|
346 |
+
self.max_limit = max_limit
|
347 |
+
self.buffer: Dict[str, List[str]] = {}
|
348 |
+
|
349 |
+
def add(self, data_type: str, item: Dict[str, Any]) -> None:
|
350 |
+
if data_type not in self.buffer:
|
351 |
+
self.buffer[data_type] = []
|
352 |
+
if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit:
|
353 |
+
logger.log_freq(
|
354 |
+
"WARN",
|
355 |
+
"IN_MEMORY_DATA_BUFFER_FULL",
|
356 |
+
"Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
|
357 |
+
64,
|
358 |
+
)
|
359 |
+
self.buffer[data_type].pop(0)
|
360 |
+
self.buffer[data_type].append(item)
|
361 |
+
|
362 |
+
def get(self, data_type: str) -> Dict[str, Any]:
|
363 |
+
return self.buffer[data_type].pop(0)
|
364 |
+
|
365 |
+
def get_length(self, data_type: str) -> int:
|
366 |
+
return len(self.buffer[data_type])
|
367 |
+
|
368 |
+
|
369 |
def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
|
370 |
filename = directory / f"{data_type}-{rank}-{index}.pt"
|
371 |
torch.save(item, filename.as_posix())
|
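To show how the new `initialize_preprocessor` entry point relates to the buffers above: with `enable_precomputation=True` it routes through `PrecomputedDistributedDataPreprocessor` and writes `.pt` files to `save_dir`, otherwise everything stays in an `InMemoryDataBuffer`. A small sketch of the buffer's FIFO behaviour, assuming the class is importable from `finetrainers.data.precomputation` and using toy payloads rather than real latents:

```python
from finetrainers.data.precomputation import InMemoryDataBuffer

# FIFO buffer capped at 2 items per data type (toy payloads).
buffer = InMemoryDataBuffer(max_limit=2)
buffer.add("latent", {"latents": 1})
buffer.add("latent", {"latents": 2})
buffer.add("latent", {"latents": 3})  # cap reached: the oldest item is dropped first

assert buffer.get_length("latent") == 2
print(buffer.get("latent"))  # {'latents': 2} -- items come back in insertion order
```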
finetrainers/functional/image.py
CHANGED
@@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso
|
|
22 |
|
23 |
|
24 |
def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
|
25 |
-
return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
|
26 |
|
27 |
|
28 |
def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
|
|
|
22 |
|
23 |
|
24 |
def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
|
25 |
+
return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]
|
26 |
|
27 |
|
28 |
def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
|
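The `bicubic_resize_image` fix above matters because `torch.nn.functional.interpolate` with `mode="bicubic"` expects a 4D `(N, C, H, W)` batch, while the images passed here are single `(C, H, W)` tensors. A quick sketch with illustrative shapes:

```python
import torch
import torch.nn.functional as F

image = torch.randn(3, 480, 640)  # (C, H, W) -- a single image, no batch dimension

# Add a batch dimension for bicubic interpolation, then drop it again,
# mirroring the unsqueeze(0)/[0] pair in the patched function.
resized = F.interpolate(image.unsqueeze(0), size=(256, 256), mode="bicubic", align_corners=False)[0]
print(resized.shape)  # torch.Size([3, 256, 256])
```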
finetrainers/models/cogvideox/base_specification.py
CHANGED
@@ -105,7 +105,7 @@ class CogVideoXModelSpecification(ModelSpecification):
|
|
105 |
)
|
106 |
|
107 |
if condition_model_processors is None:
|
108 |
-
condition_model_processors = [T5Processor(["
|
109 |
if latent_model_processors is None:
|
110 |
latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
|
111 |
|
@@ -337,7 +337,6 @@ class CogVideoXModelSpecification(ModelSpecification):
|
|
337 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
338 |
latent_model_conditions["image_rotary_emb"] = image_rotary_emb
|
339 |
latent_model_conditions["ofs"] = ofs_emb
|
340 |
-
condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
|
341 |
|
342 |
velocity = transformer(
|
343 |
**latent_model_conditions,
|
|
|
105 |
)
|
106 |
|
107 |
if condition_model_processors is None:
|
108 |
+
condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
|
109 |
if latent_model_processors is None:
|
110 |
latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
|
111 |
|
|
|
337 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
338 |
latent_model_conditions["image_rotary_emb"] = image_rotary_emb
|
339 |
latent_model_conditions["ofs"] = ofs_emb
|
|
|
340 |
|
341 |
velocity = transformer(
|
342 |
**latent_model_conditions,
|
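The renamed `T5Processor` output names remove the key juggling that `forward` previously had to do: the conditions dict now already uses the transformer's keyword-argument names, so it can be splatted straight into the call. A toy sketch of that plumbing; the stand-in processor and transformer below are invented for illustration only:

```python
import torch

def toy_text_processor(output_names):
    # Stand-in for T5Processor: emits dummy tensors under the configured key names.
    return {
        output_names[0]: torch.zeros(1, 8, 16),
        output_names[1]: torch.ones(1, 8, dtype=torch.bool),
    }

def toy_transformer(hidden_states, encoder_hidden_states, prompt_attention_mask, **kwargs):
    return encoder_hidden_states.shape

# Keys already match the transformer's kwargs, so no pop()/rename step is needed.
conditions = toy_text_processor(["encoder_hidden_states", "prompt_attention_mask"])
print(toy_transformer(hidden_states=torch.zeros(1, 4, 16), **conditions))
```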
finetrainers/models/cogview4/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .base_specification import CogView4ModelSpecification
|
finetrainers/models/cogview4/base_specification.py
ADDED
@@ -0,0 +1,395 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import Any, Dict, List, Optional, Tuple
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from accelerate import init_empty_weights
|
6 |
+
from diffusers import (
|
7 |
+
AutoencoderKL,
|
8 |
+
CogView4Pipeline,
|
9 |
+
CogView4Transformer2DModel,
|
10 |
+
FlowMatchEulerDiscreteScheduler,
|
11 |
+
)
|
12 |
+
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
|
13 |
+
from transformers import AutoTokenizer, GlmModel
|
14 |
+
|
15 |
+
from ... import data
|
16 |
+
from ... import functional as FF
|
17 |
+
from ...logging import get_logger
|
18 |
+
from ...processors import CogView4GLMProcessor, ProcessorMixin
|
19 |
+
from ...typing import ArtifactType, SchedulerType
|
20 |
+
from ...utils import get_non_null_items
|
21 |
+
from ..modeling_utils import ModelSpecification
|
22 |
+
|
23 |
+
|
24 |
+
logger = get_logger()
|
25 |
+
|
26 |
+
|
27 |
+
class CogView4LatentEncodeProcessor(ProcessorMixin):
|
28 |
+
r"""
|
29 |
+
Processor to encode image/video into latents using the CogView4 VAE.
|
30 |
+
|
31 |
+
Args:
|
32 |
+
output_names (`List[str]`):
|
33 |
+
The names of the outputs that the processor returns. The outputs are in the following order:
|
34 |
+
- latents: The latents of the input image/video.
|
35 |
+
- original_size: The original size of the input image/video.
|
36 |
+
- target_size: The target size of the input image/video.
|
37 |
+
- crop_coords: The top-left crop coordinates of the input image/video.
|
38 |
+
"""
|
39 |
+
|
40 |
+
def __init__(self, output_names: List[str]):
|
41 |
+
super().__init__()
|
42 |
+
|
43 |
+
self.output_names = output_names
|
44 |
+
assert len(self.output_names) == 4
|
45 |
+
|
46 |
+
def forward(
|
47 |
+
self,
|
48 |
+
vae: AutoencoderKL,
|
49 |
+
image: Optional[torch.Tensor] = None,
|
50 |
+
video: Optional[torch.Tensor] = None,
|
51 |
+
generator: Optional[torch.Generator] = None,
|
52 |
+
compute_posterior: bool = True,
|
53 |
+
_original_height: Optional[int] = None,
|
54 |
+
_original_width: Optional[int] = None,
|
55 |
+
) -> Dict[str, torch.Tensor]:
|
56 |
+
device = vae.device
|
57 |
+
dtype = vae.dtype
|
58 |
+
|
59 |
+
if video is not None:
|
60 |
+
# TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly
|
61 |
+
image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W]
|
62 |
+
|
63 |
+
assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
|
64 |
+
image = image.to(device=device, dtype=vae.dtype)
|
65 |
+
|
66 |
+
if compute_posterior:
|
67 |
+
latents = vae.encode(image).latent_dist.sample(generator=generator)
|
68 |
+
latents = latents.to(dtype=dtype)
|
69 |
+
else:
|
70 |
+
if vae.use_slicing and image.shape[0] > 1:
|
71 |
+
encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)]
|
72 |
+
moments = torch.cat(encoded_slices)
|
73 |
+
else:
|
74 |
+
moments = vae._encode(image)
|
75 |
+
latents = moments.to(dtype=dtype)
|
76 |
+
|
77 |
+
batch_size = latents.size(0)
|
78 |
+
target_height = image.size(2)
|
79 |
+
target_width = image.size(3)
|
80 |
+
original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat(
|
81 |
+
batch_size, 1
|
82 |
+
)
|
83 |
+
target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1)
|
84 |
+
crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1)
|
85 |
+
|
86 |
+
return {
|
87 |
+
self.output_names[0]: latents,
|
88 |
+
self.output_names[1]: original_size,
|
89 |
+
self.output_names[2]: target_size,
|
90 |
+
self.output_names[3]: crop_coords,
|
91 |
+
}
|
92 |
+
|
93 |
+
|
94 |
+
class CogView4ModelSpecification(ModelSpecification):
|
95 |
+
def __init__(
|
96 |
+
self,
|
97 |
+
pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
|
98 |
+
tokenizer_id: Optional[str] = None,
|
99 |
+
text_encoder_id: Optional[str] = None,
|
100 |
+
transformer_id: Optional[str] = None,
|
101 |
+
vae_id: Optional[str] = None,
|
102 |
+
text_encoder_dtype: torch.dtype = torch.bfloat16,
|
103 |
+
transformer_dtype: torch.dtype = torch.bfloat16,
|
104 |
+
vae_dtype: torch.dtype = torch.bfloat16,
|
105 |
+
revision: Optional[str] = None,
|
106 |
+
cache_dir: Optional[str] = None,
|
107 |
+
condition_model_processors: List[ProcessorMixin] = None,
|
108 |
+
latent_model_processors: List[ProcessorMixin] = None,
|
109 |
+
**kwargs,
|
110 |
+
) -> None:
|
111 |
+
super().__init__(
|
112 |
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
113 |
+
tokenizer_id=tokenizer_id,
|
114 |
+
text_encoder_id=text_encoder_id,
|
115 |
+
transformer_id=transformer_id,
|
116 |
+
vae_id=vae_id,
|
117 |
+
text_encoder_dtype=text_encoder_dtype,
|
118 |
+
transformer_dtype=transformer_dtype,
|
119 |
+
vae_dtype=vae_dtype,
|
120 |
+
revision=revision,
|
121 |
+
cache_dir=cache_dir,
|
122 |
+
)
|
123 |
+
|
124 |
+
if condition_model_processors is None:
|
125 |
+
condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
|
126 |
+
if latent_model_processors is None:
|
127 |
+
latent_model_processors = [
|
128 |
+
CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
|
129 |
+
]
|
130 |
+
|
131 |
+
self.condition_model_processors = condition_model_processors
|
132 |
+
self.latent_model_processors = latent_model_processors
|
133 |
+
|
134 |
+
@property
|
135 |
+
def _resolution_dim_keys(self):
|
136 |
+
return {"latents": (2, 3)}
|
137 |
+
|
138 |
+
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
139 |
+
if self.tokenizer_id is not None:
|
140 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
141 |
+
self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
|
142 |
+
)
|
143 |
+
else:
|
144 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
145 |
+
self.pretrained_model_name_or_path,
|
146 |
+
subfolder="tokenizer",
|
147 |
+
revision=self.revision,
|
148 |
+
cache_dir=self.cache_dir,
|
149 |
+
)
|
150 |
+
|
151 |
+
if self.text_encoder_id is not None:
|
152 |
+
text_encoder = GlmModel.from_pretrained(
|
153 |
+
self.text_encoder_id,
|
154 |
+
torch_dtype=self.text_encoder_dtype,
|
155 |
+
revision=self.revision,
|
156 |
+
cache_dir=self.cache_dir,
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
text_encoder = GlmModel.from_pretrained(
|
160 |
+
self.pretrained_model_name_or_path,
|
161 |
+
subfolder="text_encoder",
|
162 |
+
torch_dtype=self.text_encoder_dtype,
|
163 |
+
revision=self.revision,
|
164 |
+
cache_dir=self.cache_dir,
|
165 |
+
)
|
166 |
+
|
167 |
+
return {"tokenizer": tokenizer, "text_encoder": text_encoder}
|
168 |
+
|
169 |
+
def load_latent_models(self) -> Dict[str, torch.nn.Module]:
|
170 |
+
if self.vae_id is not None:
|
171 |
+
vae = AutoencoderKL.from_pretrained(
|
172 |
+
self.vae_id,
|
173 |
+
torch_dtype=self.vae_dtype,
|
174 |
+
revision=self.revision,
|
175 |
+
cache_dir=self.cache_dir,
|
176 |
+
)
|
177 |
+
else:
|
178 |
+
vae = AutoencoderKL.from_pretrained(
|
179 |
+
self.pretrained_model_name_or_path,
|
180 |
+
subfolder="vae",
|
181 |
+
torch_dtype=self.vae_dtype,
|
182 |
+
revision=self.revision,
|
183 |
+
cache_dir=self.cache_dir,
|
184 |
+
)
|
185 |
+
|
186 |
+
return {"vae": vae}
|
187 |
+
|
188 |
+
def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
|
189 |
+
if self.transformer_id is not None:
|
190 |
+
transformer = CogView4Transformer2DModel.from_pretrained(
|
191 |
+
self.transformer_id,
|
192 |
+
torch_dtype=self.transformer_dtype,
|
193 |
+
revision=self.revision,
|
194 |
+
cache_dir=self.cache_dir,
|
195 |
+
)
|
196 |
+
else:
|
197 |
+
transformer = CogView4Transformer2DModel.from_pretrained(
|
198 |
+
self.pretrained_model_name_or_path,
|
199 |
+
subfolder="transformer",
|
200 |
+
torch_dtype=self.transformer_dtype,
|
201 |
+
revision=self.revision,
|
202 |
+
cache_dir=self.cache_dir,
|
203 |
+
)
|
204 |
+
|
205 |
+
scheduler = FlowMatchEulerDiscreteScheduler()
|
206 |
+
|
207 |
+
return {"transformer": transformer, "scheduler": scheduler}
|
208 |
+
|
209 |
+
def load_pipeline(
|
210 |
+
self,
|
211 |
+
tokenizer: Optional[AutoTokenizer] = None,
|
212 |
+
text_encoder: Optional[GlmModel] = None,
|
213 |
+
transformer: Optional[CogView4Transformer2DModel] = None,
|
214 |
+
vae: Optional[AutoencoderKL] = None,
|
215 |
+
scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
|
216 |
+
enable_slicing: bool = False,
|
217 |
+
enable_tiling: bool = False,
|
218 |
+
enable_model_cpu_offload: bool = False,
|
219 |
+
training: bool = False,
|
220 |
+
**kwargs,
|
221 |
+
) -> CogView4Pipeline:
|
222 |
+
components = {
|
223 |
+
"tokenizer": tokenizer,
|
224 |
+
"text_encoder": text_encoder,
|
225 |
+
"transformer": transformer,
|
226 |
+
"vae": vae,
|
227 |
+
# Load the scheduler based on CogView4's config instead of using the default initialization being used for training
|
228 |
+
# "scheduler": scheduler,
|
229 |
+
}
|
230 |
+
components = get_non_null_items(components)
|
231 |
+
|
232 |
+
pipe = CogView4Pipeline.from_pretrained(
|
233 |
+
self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
|
234 |
+
)
|
235 |
+
pipe.text_encoder.to(self.text_encoder_dtype)
|
236 |
+
pipe.vae.to(self.vae_dtype)
|
237 |
+
|
238 |
+
if not training:
|
239 |
+
pipe.transformer.to(self.transformer_dtype)
|
240 |
+
|
241 |
+
if enable_slicing:
|
242 |
+
pipe.vae.enable_slicing()
|
243 |
+
if enable_tiling:
|
244 |
+
pipe.vae.enable_tiling()
|
245 |
+
if enable_model_cpu_offload:
|
246 |
+
pipe.enable_model_cpu_offload()
|
247 |
+
|
248 |
+
return pipe
|
249 |
+
|
250 |
+
@torch.no_grad()
|
251 |
+
def prepare_conditions(
|
252 |
+
self,
|
253 |
+
tokenizer: AutoTokenizer,
|
254 |
+
text_encoder: GlmModel,
|
255 |
+
caption: str,
|
256 |
+
max_sequence_length: int = 1024,
|
257 |
+
**kwargs,
|
258 |
+
) -> Dict[str, Any]:
|
259 |
+
conditions = {
|
260 |
+
"tokenizer": tokenizer,
|
261 |
+
"text_encoder": text_encoder,
|
262 |
+
"caption": caption,
|
263 |
+
"max_sequence_length": max_sequence_length,
|
264 |
+
**kwargs,
|
265 |
+
}
|
266 |
+
input_keys = set(conditions.keys())
|
267 |
+
conditions = super().prepare_conditions(**conditions)
|
268 |
+
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
|
269 |
+
return conditions
|
270 |
+
|
271 |
+
@torch.no_grad()
|
272 |
+
def prepare_latents(
|
273 |
+
self,
|
274 |
+
vae: AutoencoderKL,
|
275 |
+
image: Optional[torch.Tensor] = None,
|
276 |
+
video: Optional[torch.Tensor] = None,
|
277 |
+
generator: Optional[torch.Generator] = None,
|
278 |
+
compute_posterior: bool = True,
|
279 |
+
_original_height: Optional[int] = None,
|
280 |
+
_original_width: Optional[int] = None,
|
281 |
+
**kwargs,
|
282 |
+
) -> Dict[str, torch.Tensor]:
|
283 |
+
conditions = {
|
284 |
+
"vae": vae,
|
285 |
+
"image": image,
|
286 |
+
"video": video,
|
287 |
+
"generator": generator,
|
288 |
+
"compute_posterior": compute_posterior,
|
289 |
+
"_original_height": _original_height,
|
290 |
+
"_original_width": _original_width,
|
291 |
+
**kwargs,
|
292 |
+
}
|
293 |
+
input_keys = set(conditions.keys())
|
294 |
+
conditions = super().prepare_latents(**conditions)
|
295 |
+
conditions = {k: v for k, v in conditions.items() if k not in input_keys}
|
296 |
+
return conditions
|
297 |
+
|
298 |
+
def forward(
|
299 |
+
self,
|
300 |
+
transformer: CogView4Transformer2DModel,
|
301 |
+
condition_model_conditions: Dict[str, torch.Tensor],
|
302 |
+
latent_model_conditions: Dict[str, torch.Tensor],
|
303 |
+
sigmas: torch.Tensor,
|
304 |
+
generator: Optional[torch.Generator] = None,
|
305 |
+
compute_posterior: bool = True,
|
306 |
+
**kwargs,
|
307 |
+
) -> Tuple[torch.Tensor, ...]:
|
308 |
+
if compute_posterior:
|
309 |
+
latents = latent_model_conditions.pop("latents")
|
310 |
+
else:
|
311 |
+
posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
|
312 |
+
latents = posterior.sample(generator=generator)
|
313 |
+
del posterior
|
314 |
+
|
315 |
+
latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
|
316 |
+
noise = torch.zeros_like(latents).normal_(generator=generator)
|
317 |
+
timesteps = (sigmas.flatten() * 1000.0).long()
|
318 |
+
|
319 |
+
base_image_sequence_length = 256
|
320 |
+
base_shift = 0.25
|
321 |
+
max_shift = 0.75
|
322 |
+
|
323 |
+
image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
|
324 |
+
mu = (image_sequence_length / base_image_sequence_length) ** 0.5
|
325 |
+
mu = mu * max_shift + base_shift
|
326 |
+
shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
|
327 |
+
noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
|
328 |
+
|
329 |
+
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
330 |
+
|
331 |
+
pred = transformer(
|
332 |
+
**latent_model_conditions,
|
333 |
+
**condition_model_conditions,
|
334 |
+
timestep=timesteps,
|
335 |
+
return_dict=False,
|
336 |
+
)[0]
|
337 |
+
target = FF.flow_match_target(noise, latents)
|
338 |
+
|
339 |
+
# NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
|
340 |
+
# but let's keep it this way for now. Longer training runs should reveal more insights.
|
341 |
+
# return pred, target, sigmas
|
342 |
+
return pred, target, shifted_sigmas
|
343 |
+
|
344 |
+
def validation(
|
345 |
+
self,
|
346 |
+
pipeline: CogView4Pipeline,
|
347 |
+
prompt: str,
|
348 |
+
height: Optional[int] = None,
|
349 |
+
width: Optional[int] = None,
|
350 |
+
num_inference_steps: int = 50,
|
351 |
+
generator: Optional[torch.Generator] = None,
|
352 |
+
**kwargs,
|
353 |
+
) -> List[ArtifactType]:
|
354 |
+
generation_kwargs = {
|
355 |
+
"prompt": prompt,
|
356 |
+
"height": height,
|
357 |
+
"width": width,
|
358 |
+
"num_inference_steps": num_inference_steps,
|
359 |
+
"generator": generator,
|
360 |
+
"return_dict": True,
|
361 |
+
"output_type": "pil",
|
362 |
+
}
|
363 |
+
generation_kwargs = get_non_null_items(generation_kwargs)
|
364 |
+
image = pipeline(**generation_kwargs).images[0]
|
365 |
+
return [data.ImageArtifact(value=image)]
|
366 |
+
|
367 |
+
def _save_lora_weights(
|
368 |
+
self,
|
369 |
+
directory: str,
|
370 |
+
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
371 |
+
scheduler: Optional[SchedulerType] = None,
|
372 |
+
*args,
|
373 |
+
**kwargs,
|
374 |
+
) -> None:
|
375 |
+
# TODO(aryan): this needs refactoring
|
376 |
+
if transformer_state_dict is not None:
|
377 |
+
CogView4Pipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
|
378 |
+
if scheduler is not None:
|
379 |
+
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
|
380 |
+
|
381 |
+
def _save_model(
|
382 |
+
self,
|
383 |
+
directory: str,
|
384 |
+
transformer: CogView4Transformer2DModel,
|
385 |
+
transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
386 |
+
scheduler: Optional[SchedulerType] = None,
|
387 |
+
) -> None:
|
388 |
+
# TODO(aryan): this needs refactoring
|
389 |
+
if transformer_state_dict is not None:
|
390 |
+
with init_empty_weights():
|
391 |
+
transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
|
392 |
+
transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
|
393 |
+
transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
|
394 |
+
if scheduler is not None:
|
395 |
+
scheduler.save_pretrained(os.path.join(directory, "scheduler"))
|
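The resolution-dependent timestep shift in `CogView4ModelSpecification.forward` can be read as: larger latents give a larger `mu`, which pushes each sigma toward the noisy end of the schedule. A numerical sketch of the same arithmetic, with an illustrative latent size rather than a value from a real run:

```python
import torch

base_image_sequence_length = 256
base_shift, max_shift = 0.25, 0.75
patch_size = 2

latent_height = latent_width = 128  # assumed: a 1024x1024 image through an 8x VAE downsample
image_sequence_length = latent_height * latent_width // patch_size**2  # 4096 tokens

mu = (image_sequence_length / base_image_sequence_length) ** 0.5 * max_shift + base_shift
sigmas = torch.tensor([0.25, 0.50, 0.75])
shifted_sigmas = mu / (mu + (1 / sigmas - 1))
print(shifted_sigmas)  # each value moves closer to 1.0, i.e. toward noisier timesteps
```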
finetrainers/models/hunyuan_video/base_specification.py
CHANGED
@@ -117,10 +117,7 @@ class HunyuanVideoModelSpecification(ModelSpecification):
|
|
117 |
|
118 |
@property
|
119 |
def _resolution_dim_keys(self):
|
120 |
-
|
121 |
-
return {
|
122 |
-
"latents": (2, 3, 4),
|
123 |
-
}
|
124 |
|
125 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
126 |
if self.tokenizer_id is not None:
|
|
|
117 |
|
118 |
@property
|
119 |
def _resolution_dim_keys(self):
|
120 |
+
return {"latents": (2, 3, 4)}
|
|
|
|
|
|
|
121 |
|
122 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
123 |
if self.tokenizer_id is not None:
|
finetrainers/models/ltx_video/base_specification.py
CHANGED
@@ -120,7 +120,7 @@ class LTXVideoModelSpecification(ModelSpecification):
|
|
120 |
)
|
121 |
|
122 |
if condition_model_processors is None:
|
123 |
-
condition_model_processors = [T5Processor(["
|
124 |
if latent_model_processors is None:
|
125 |
latent_model_processors = [
|
126 |
LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
|
@@ -131,9 +131,7 @@ class LTXVideoModelSpecification(ModelSpecification):
|
|
131 |
|
132 |
@property
|
133 |
def _resolution_dim_keys(self):
|
134 |
-
return {
|
135 |
-
"latents": (2, 3, 4),
|
136 |
-
}
|
137 |
|
138 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
139 |
if self.tokenizer_id is not None:
|
@@ -342,8 +340,6 @@ class LTXVideoModelSpecification(ModelSpecification):
|
|
342 |
sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
|
343 |
|
344 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
345 |
-
condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
|
346 |
-
condition_model_conditions["encoder_attention_mask"] = condition_model_conditions.pop("prompt_attention_mask")
|
347 |
|
348 |
# TODO(aryan): make this configurable
|
349 |
frame_rate = 25
|
|
|
120 |
)
|
121 |
|
122 |
if condition_model_processors is None:
|
123 |
+
condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
|
124 |
if latent_model_processors is None:
|
125 |
latent_model_processors = [
|
126 |
LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
|
|
|
131 |
|
132 |
@property
|
133 |
def _resolution_dim_keys(self):
|
134 |
+
return {"latents": (2, 3, 4)}
|
|
|
|
|
135 |
|
136 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
137 |
if self.tokenizer_id is not None:
|
|
|
340 |
sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
|
341 |
|
342 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
|
|
|
|
343 |
|
344 |
# TODO(aryan): make this configurable
|
345 |
frame_rate = 25
|
finetrainers/models/modeling_utils.py
CHANGED
@@ -115,9 +115,6 @@ class ModelSpecification:
|
|
115 |
f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
|
116 |
)
|
117 |
|
118 |
-
def collate_fn(self, batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
|
119 |
-
raise NotImplementedError(f"ModelSpecification::collate_fn is not implemented for {self.__class__.__name__}")
|
120 |
-
|
121 |
def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
|
122 |
for processor in self.condition_model_processors:
|
123 |
result = processor(**kwargs)
|
|
|
115 |
f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
|
116 |
)
|
117 |
|
|
|
|
|
|
|
118 |
def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
|
119 |
for processor in self.condition_model_processors:
|
120 |
result = processor(**kwargs)
|
finetrainers/models/wan/base_specification.py
CHANGED
@@ -34,11 +34,6 @@ class WanLatentEncodeProcessor(ProcessorMixin):
|
|
34 |
output_names (`List[str]`):
|
35 |
The names of the outputs that the processor returns. The outputs are in the following order:
|
36 |
- latents: The latents of the input image/video.
|
37 |
-
- num_frames: The number of frames in the input video.
|
38 |
-
- height: The height of the input image/video.
|
39 |
-
- width: The width of the input image/video.
|
40 |
-
- latents_mean: The latent channel means from the VAE state dict.
|
41 |
-
- latents_std: The latent channel standard deviations from the VAE state dict.
|
42 |
"""
|
43 |
|
44 |
def __init__(self, output_names: List[str]):
|
@@ -111,7 +106,7 @@ class WanModelSpecification(ModelSpecification):
|
|
111 |
)
|
112 |
|
113 |
if condition_model_processors is None:
|
114 |
-
condition_model_processors = [T5Processor(["
|
115 |
if latent_model_processors is None:
|
116 |
latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
|
117 |
|
@@ -120,10 +115,7 @@ class WanModelSpecification(ModelSpecification):
|
|
120 |
|
121 |
@property
|
122 |
def _resolution_dim_keys(self):
|
123 |
-
|
124 |
-
return {
|
125 |
-
"latents": (2, 3, 4),
|
126 |
-
}
|
127 |
|
128 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
129 |
if self.tokenizer_id is not None:
|
@@ -303,7 +295,6 @@ class WanModelSpecification(ModelSpecification):
|
|
303 |
noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
|
304 |
|
305 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
306 |
-
condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
|
307 |
|
308 |
timesteps = (sigmas.flatten() * 1000.0).long()
|
309 |
|
|
|
34 |
output_names (`List[str]`):
|
35 |
The names of the outputs that the processor returns. The outputs are in the following order:
|
36 |
- latents: The latents of the input image/video.
|
|
|
|
|
|
|
|
|
|
|
37 |
"""
|
38 |
|
39 |
def __init__(self, output_names: List[str]):
|
|
|
106 |
)
|
107 |
|
108 |
if condition_model_processors is None:
|
109 |
+
condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
|
110 |
if latent_model_processors is None:
|
111 |
latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
|
112 |
|
|
|
115 |
|
116 |
@property
|
117 |
def _resolution_dim_keys(self):
|
118 |
+
return {"latents": (2, 3, 4)}
|
|
|
|
|
|
|
119 |
|
120 |
def load_condition_models(self) -> Dict[str, torch.nn.Module]:
|
121 |
if self.tokenizer_id is not None:
|
|
|
295 |
noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
|
296 |
|
297 |
latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
|
|
|
298 |
|
299 |
timesteps = (sigmas.flatten() * 1000.0).long()
|
300 |
|
finetrainers/processors/__init__.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
from .base import ProcessorMixin
|
2 |
from .clip import CLIPPooledProcessor
|
|
|
3 |
from .llama import LlamaProcessor
|
4 |
from .t5 import T5Processor
|
5 |
from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
|
|
|
1 |
from .base import ProcessorMixin
|
2 |
from .clip import CLIPPooledProcessor
|
3 |
+
from .glm import CogView4GLMProcessor
|
4 |
from .llama import LlamaProcessor
|
5 |
from .t5 import T5Processor
|
6 |
from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
|
finetrainers/processors/glm.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Tuple, Union
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from transformers import AutoTokenizer, GlmModel
|
5 |
+
|
6 |
+
from .base import ProcessorMixin
|
7 |
+
|
8 |
+
|
9 |
+
class CogView4GLMProcessor(ProcessorMixin):
|
10 |
+
r"""
|
11 |
+
Processor for the GLM family of models. This processor is used to encode text inputs and return the embeddings
|
12 |
+
and attention masks for the input text.
|
13 |
+
|
14 |
+
This processor is specific to CogView4 but can be used with any other model.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
output_names (`List[str]`):
|
18 |
+
The names of the outputs that the processor should return. The single output is the embeddings of the input
|
19 |
+
text.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, output_names: List[str]):
|
23 |
+
super().__init__()
|
24 |
+
|
25 |
+
self.output_names = output_names
|
26 |
+
|
27 |
+
assert len(self.output_names) == 1
|
28 |
+
|
29 |
+
def forward(
|
30 |
+
self,
|
31 |
+
tokenizer: AutoTokenizer,
|
32 |
+
text_encoder: GlmModel,
|
33 |
+
caption: Union[str, List[str]],
|
34 |
+
max_sequence_length: int,
|
35 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
36 |
+
r"""
|
37 |
+
Encode the input text and return the prompt embeddings for the input text.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
tokenizer (`AutoTokenizer`):
|
41 |
+
The tokenizer used to tokenize the input text.
|
42 |
+
text_encoder (`GlmModel`):
|
43 |
+
The text encoder used to encode the input text.
|
44 |
+
caption (`Union[str, List[str]]`):
|
45 |
+
The input text to be encoded.
|
46 |
+
max_sequence_length (`int`):
|
47 |
+
The maximum sequence length of the input text.
|
48 |
+
"""
|
49 |
+
if isinstance(caption, str):
|
50 |
+
caption = [caption]
|
51 |
+
|
52 |
+
device = text_encoder.device
|
53 |
+
dtype = text_encoder.dtype
|
54 |
+
|
55 |
+
text_inputs = tokenizer(
|
56 |
+
caption,
|
57 |
+
padding="longest",
|
58 |
+
max_length=max_sequence_length,
|
59 |
+
truncation=True,
|
60 |
+
add_special_tokens=True,
|
61 |
+
return_tensors="pt",
|
62 |
+
)
|
63 |
+
text_input_ids = text_inputs.input_ids.to(device)
|
64 |
+
|
65 |
+
current_length = text_input_ids.size(1)
|
66 |
+
pad_length = 16 - current_length % 16
|
67 |
+
if pad_length > 0:
|
68 |
+
pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id)
|
69 |
+
text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
|
70 |
+
|
71 |
+
prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
|
72 |
+
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
|
73 |
+
|
74 |
+
return {self.output_names[0]: prompt_embeds}
|
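The padding step in `CogView4GLMProcessor.forward` left-pads the tokenized prompt so its length is a multiple of 16 before it reaches the GLM text encoder. A self-contained sketch of that step with made-up token ids and pad id:

```python
import torch

pad_token_id = 0                                             # illustrative pad id
text_input_ids = torch.tensor([[101, 202, 303, 404, 505]])   # length 5, made-up ids

current_length = text_input_ids.size(1)
pad_length = 16 - current_length % 16                        # 11 pads needed to reach 16
if pad_length > 0:
    pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=pad_token_id)
    text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)  # pads are prepended

print(text_input_ids.shape)  # torch.Size([1, 16])
```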
finetrainers/trainer/sft_trainer/trainer.py
CHANGED
@@ -2,6 +2,7 @@ import functools
|
|
2 |
import json
|
3 |
import math
|
4 |
import os
|
|
|
5 |
from pathlib import Path
|
6 |
from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
|
7 |
|
@@ -33,6 +34,13 @@ logger = logging.get_logger()
|
|
33 |
|
34 |
|
35 |
class SFTTrainer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
|
37 |
self.args = args
|
38 |
self.state = State()
|
@@ -72,6 +80,7 @@ class SFTTrainer:
|
|
72 |
patches.perform_patches_for_training(self.args, self.state.parallel_backend)
|
73 |
|
74 |
self.model_specification = model_specification
|
|
|
75 |
|
76 |
def run(self) -> None:
|
77 |
try:
|
@@ -254,12 +263,15 @@ class SFTTrainer:
|
|
254 |
data_root = config.pop("data_root", None)
|
255 |
dataset_file = config.pop("dataset_file", None)
|
256 |
dataset_type = config.pop("dataset_type")
|
|
|
257 |
|
258 |
if data_root is not None and dataset_file is not None:
|
259 |
raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
|
260 |
|
261 |
dataset_name_or_root = data_root or dataset_file
|
262 |
-
dataset = data.initialize_dataset(
|
|
|
|
|
263 |
|
264 |
if not dataset._precomputable_once and self.args.precomputation_once:
|
265 |
raise ValueError(
|
@@ -369,9 +381,9 @@ class SFTTrainer:
|
|
369 |
self.transformer.train()
|
370 |
data_iterator = iter(self.dataloader)
|
371 |
|
372 |
-
preprocessor = data.DistributedDataPreprocessor(
|
373 |
rank=parallel_backend.rank,
|
374 |
-
num_items=self.args.precomputation_items,
|
375 |
processor_fn={
|
376 |
"condition": self.model_specification.prepare_conditions,
|
377 |
"latent": functools.partial(
|
@@ -379,6 +391,7 @@ class SFTTrainer:
|
|
379 |
),
|
380 |
},
|
381 |
save_dir=self.args.precomputation_dir,
|
|
|
382 |
)
|
383 |
precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
|
384 |
precomputed_latent_iterator: Iterable[Dict[str, Any]] = None
|
@@ -495,7 +508,6 @@ class SFTTrainer:
|
|
495 |
|
496 |
if train_state.step % self.args.gradient_accumulation_steps == 0:
|
497 |
# TODO(aryan): revisit no_sync() for FSDP
|
498 |
-
# TODO(aryan): average the gradients for accumulation?
|
499 |
self.optimizer.step()
|
500 |
self.lr_scheduler.step()
|
501 |
self.optimizer.zero_grad()
|
@@ -651,28 +663,29 @@ class SFTTrainer:
|
|
651 |
# TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
|
652 |
for index, (key, artifact) in enumerate(list(artifacts.items())):
|
653 |
assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
|
|
|
|
|
654 |
filename = "validation-" if not final_validation else "final-"
|
655 |
-
filename += f"{step}-{
|
656 |
output_filename = os.path.join(self.args.output_dir, filename)
|
657 |
|
658 |
if parallel_backend.is_main_process and artifact.file_extension == "mp4":
|
659 |
main_process_prompts_to_filenames[PROMPT] = filename
|
660 |
|
661 |
-
caption = f"{PROMPT} | (filename: {output_filename})"
|
662 |
if artifact.type == "image" and artifact.value is not None:
|
663 |
logger.debug(
|
664 |
f"Saving image from rank={parallel_backend.rank} to {output_filename}",
|
665 |
local_main_process_only=False,
|
666 |
)
|
667 |
artifact.value.save(output_filename)
|
668 |
-
all_processes_artifacts.append(wandb.Image(output_filename, caption=caption))
|
669 |
elif artifact.type == "video" and artifact.value is not None:
|
670 |
logger.debug(
|
671 |
f"Saving video from rank={parallel_backend.rank} to {output_filename}",
|
672 |
local_main_process_only=False,
|
673 |
)
|
674 |
export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
|
675 |
-
all_processes_artifacts.append(wandb.Video(output_filename, caption=caption))
|
676 |
|
677 |
# 3. Cleanup & log artifacts
|
678 |
parallel_backend.wait_for_everyone()
|
@@ -804,24 +817,16 @@ class SFTTrainer:
|
|
804 |
component.to(device)
|
805 |
|
806 |
def _set_components(self, components: Dict[str, Any]) -> None:
|
807 |
-
|
808 |
-
component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
|
809 |
-
# fmt: on
|
810 |
-
|
811 |
-
for component_name in component_names:
|
812 |
existing_component = getattr(self, component_name, None)
|
813 |
new_component = components.get(component_name, existing_component)
|
814 |
setattr(self, component_name, new_component)
|
815 |
|
816 |
def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
|
817 |
if component_names is None:
|
818 |
-
|
819 |
-
component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
|
820 |
-
# fmt: on
|
821 |
-
|
822 |
for component_name in component_names:
|
823 |
setattr(self, component_name, None)
|
824 |
-
|
825 |
utils.free_memory()
|
826 |
utils.synchronize_device()
|
827 |
|
@@ -848,7 +853,6 @@ class SFTTrainer:
|
|
848 |
training=True,
|
849 |
)
|
850 |
else:
|
851 |
-
# TODO(aryan): this branch does not work yet, needs to be implemented
|
852 |
self._delete_components()
|
853 |
|
854 |
# Load the transformer weights from the final checkpoint if performing full-finetune
|
@@ -874,50 +878,101 @@ class SFTTrainer:
|
|
874 |
self._move_components_to_device(list(components.values()))
|
875 |
return pipeline
|
876 |
|
877 |
-
def _prepare_data(
|
878 |
-
|
879 |
-
|
880 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
881 |
else:
|
882 |
-
|
883 |
-
|
884 |
-
|
885 |
-
|
886 |
-
|
887 |
-
|
888 |
-
|
889 |
-
|
890 |
-
|
891 |
-
|
892 |
-
|
893 |
-
|
894 |
-
|
895 |
-
|
896 |
-
|
897 |
-
|
898 |
-
|
899 |
-
|
900 |
-
|
901 |
-
|
902 |
-
|
903 |
-
|
904 |
-
|
905 |
-
|
906 |
-
|
907 |
-
|
908 |
-
|
909 |
-
|
910 |
-
|
911 |
-
|
912 |
-
|
913 |
-
|
914 |
-
|
915 |
-
|
916 |
-
|
917 |
-
|
918 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
919 |
|
920 |
-
return
|
921 |
|
922 |
def _get_training_info(self) -> Dict[str, Any]:
|
923 |
info = self.args.to_dict()
|
|
|
(post-change view of the same file)

2       import json
3       import math
4       import os
5   +   import time
6       from pathlib import Path
7       from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
8

34
35
36      class SFTTrainer:
37  +       # fmt: off
38  +       _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
39  +       _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
40  +       _latent_component_names = ["vae"]
41  +       _diffusion_component_names = ["transformer", "unet", "scheduler"]
42  +       # fmt: on
43  +
44          def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
45              self.args = args
46              self.state = State()

80              patches.perform_patches_for_training(self.args, self.state.parallel_backend)
81
82              self.model_specification = model_specification
83  +           self._are_condition_models_loaded = False
84
85          def run(self) -> None:
86              try:

263             data_root = config.pop("data_root", None)
264             dataset_file = config.pop("dataset_file", None)
265             dataset_type = config.pop("dataset_type")
266 +           caption_options = config.pop("caption_options", {})
267
268             if data_root is not None and dataset_file is not None:
269                 raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
270
271             dataset_name_or_root = data_root or dataset_file
272 +           dataset = data.initialize_dataset(
273 +               dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
274 +           )
275
276             if not dataset._precomputable_once and self.args.precomputation_once:
277                 raise ValueError(

381             self.transformer.train()
382             data_iterator = iter(self.dataloader)
383
384 +           preprocessor = data.initialize_preprocessor(
385                 rank=parallel_backend.rank,
386 +               num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
387                 processor_fn={
388                     "condition": self.model_specification.prepare_conditions,
389                     "latent": functools.partial(
391                     ),
392                 },
393                 save_dir=self.args.precomputation_dir,
394 +               enable_precomputation=self.args.enable_precomputation,
395             )
396             precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
397             precomputed_latent_iterator: Iterable[Dict[str, Any]] = None

508
509             if train_state.step % self.args.gradient_accumulation_steps == 0:
510                 # TODO(aryan): revisit no_sync() for FSDP
511                 self.optimizer.step()
512                 self.lr_scheduler.step()
513                 self.optimizer.zero_grad()

663             # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
664             for index, (key, artifact) in enumerate(list(artifacts.items())):
665                 assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
666 +
667 +               time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
668                 filename = "validation-" if not final_validation else "final-"
669 +               filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
670                 output_filename = os.path.join(self.args.output_dir, filename)
671
672                 if parallel_backend.is_main_process and artifact.file_extension == "mp4":
673                     main_process_prompts_to_filenames[PROMPT] = filename
674
675                 if artifact.type == "image" and artifact.value is not None:
676                     logger.debug(
677                         f"Saving image from rank={parallel_backend.rank} to {output_filename}",
678                         local_main_process_only=False,
679                     )
680                     artifact.value.save(output_filename)
681 +                   all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
682                 elif artifact.type == "video" and artifact.value is not None:
683                     logger.debug(
684                         f"Saving video from rank={parallel_backend.rank} to {output_filename}",
685                         local_main_process_only=False,
686                     )
687                     export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
688 +                   all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))
689
690             # 3. Cleanup & log artifacts
691             parallel_backend.wait_for_everyone()

817             component.to(device)
818
819         def _set_components(self, components: Dict[str, Any]) -> None:
820 +           for component_name in self._all_component_names:
821                 existing_component = getattr(self, component_name, None)
822                 new_component = components.get(component_name, existing_component)
823                 setattr(self, component_name, new_component)
824
825         def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
826             if component_names is None:
827 +               component_names = self._all_component_names
828             for component_name in component_names:
829                 setattr(self, component_name, None)
830             utils.free_memory()
831             utils.synchronize_device()
832

853                 training=True,
854             )
855         else:
856             self._delete_components()
857
858         # Load the transformer weights from the final checkpoint if performing full-finetune

878             self._move_components_to_device(list(components.values()))
879             return pipeline
880
881 +       def _prepare_data(
882 +           self,
883 +           preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
884 +           data_iterator,
885 +       ):
886 +           if not self.args.enable_precomputation:
887 +               if not self._are_condition_models_loaded:
888 +                   logger.info(
889 +                       "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
890 +                   )
891 +                   condition_components = self.model_specification.load_condition_models()
892 +                   latent_components = self.model_specification.load_latent_models()
893 +                   all_components = {**condition_components, **latent_components}
894 +                   self._set_components(all_components)
895 +                   self._move_components_to_device(list(all_components.values()))
896 +                   utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
897 +               else:
898 +                   condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
899 +                   latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}
900 +
901 +               condition_iterator = preprocessor.consume(
902 +                   "condition",
903 +                   components=condition_components,
904 +                   data_iterator=data_iterator,
905 +                   generator=self.state.generator,
906 +                   cache_samples=True,
907 +               )
908 +               latent_iterator = preprocessor.consume(
909 +                   "latent",
910 +                   components=latent_components,
911 +                   data_iterator=data_iterator,
912 +                   generator=self.state.generator,
913 +                   use_cached_samples=True,
914 +                   drop_samples=True,
915 +               )
916 +
917 +               self._are_condition_models_loaded = True
918             else:
919 +               logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
920 +
921 +               # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
922 +               # had become None after this but should have been loaded back from the checkpoint.
923 +               # parallel_backend = self.state.parallel_backend
924 +               # train_state = self.state.train_state
925 +               # self.checkpointer.save(
926 +               #     train_state.step,
927 +               #     force=True,
928 +               #     _device=parallel_backend.device,
929 +               #     _is_main_process=parallel_backend.is_main_process,
930 +               # )
931 +               # self._delete_components(component_names=["transformer", "unet"])
932 +
933 +               if self.args.precomputation_once:
934 +                   consume_fn = preprocessor.consume_once
935 +               else:
936 +                   consume_fn = preprocessor.consume
937 +
938 +               # Prepare condition iterators
939 +               condition_components = self.model_specification.load_condition_models()
940 +               component_names = list(condition_components.keys())
941 +               component_modules = list(condition_components.values())
942 +               self._set_components(condition_components)
943 +               self._move_components_to_device(component_modules)
944 +               condition_iterator = consume_fn(
945 +                   "condition",
946 +                   components=condition_components,
947 +                   data_iterator=data_iterator,
948 +                   generator=self.state.generator,
949 +                   cache_samples=True,
950 +               )
951 +               self._delete_components(component_names)
952 +               del condition_components, component_names, component_modules
953 +
954 +               # Prepare latent iterators
955 +               latent_components = self.model_specification.load_latent_models()
956 +               utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
957 +               component_names = list(latent_components.keys())
958 +               component_modules = list(latent_components.values())
959 +               self._set_components(latent_components)
960 +               self._move_components_to_device(component_modules)
961 +               latent_iterator = consume_fn(
962 +                   "latent",
963 +                   components=latent_components,
964 +                   data_iterator=data_iterator,
965 +                   generator=self.state.generator,
966 +                   use_cached_samples=True,
967 +                   drop_samples=True,
968 +               )
969 +               self._delete_components(component_names)
970 +               del latent_components, component_names, component_modules
971 +
972 +               # self.checkpointer.load()
973 +               # self.transformer = self.checkpointer.states["model"].model[0]
974
975 +           return condition_iterator, latent_iterator
976
977         def _get_training_info(self) -> Dict[str, Any]:
978             info = self.args.to_dict()
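For orientation only, here is a minimal, self-contained sketch of the control flow the new `_prepare_data` implements: keep every encoder resident and encode samples on the fly when precomputation is disabled, otherwise load the condition models, encode a fixed number of cached samples, free them, and repeat with the latent models. All names in this snippet are illustrative and are not finetrainers code.

```python
# Illustrative sketch only; mirrors the branch structure of _prepare_data above.
from typing import Any, Callable, Dict, Iterator, List, Tuple


def prepare_data_sketch(
    enable_precomputation: bool,
    data_iterator: Iterator[Dict[str, Any]],
    load_condition_models: Callable[[], Dict[str, Any]],
    load_latent_models: Callable[[], Dict[str, Any]],
    encode_condition: Callable[[Dict[str, Any], Dict[str, Any]], Any],
    encode_latent: Callable[[Dict[str, Any], Dict[str, Any]], Any],
    num_items: int = 2,
) -> Iterator[Tuple[Any, Any]]:
    if not enable_precomputation:
        # In-memory path: everything stays loaded; each sample is encoded lazily.
        components = {**load_condition_models(), **load_latent_models()}
        return (
            (encode_condition(components, sample), encode_latent(components, sample))
            for sample in data_iterator
        )

    # Precomputation path: cache raw samples once, then encode them group by group
    # so only one set of models needs to be resident at a time.
    samples: List[Dict[str, Any]] = [next(data_iterator) for _ in range(num_items)]

    condition_models = load_condition_models()
    conditions = [encode_condition(condition_models, s) for s in samples]
    del condition_models  # freed before the latent models are loaded

    latent_models = load_latent_models()
    latents = [encode_latent(latent_models, s) for s in samples]
    del latent_models

    return iter(list(zip(conditions, latents)))


if __name__ == "__main__":
    # Tiny demo with dummy "models" and "encoders".
    data = iter([{"text": "a cat", "pixels": [1, 2]}, {"text": "a dog", "pixels": [3, 4]}])
    out = prepare_data_sketch(
        enable_precomputation=True,
        data_iterator=data,
        load_condition_models=lambda: {"text_encoder": "dummy"},
        load_latent_models=lambda: {"vae": "dummy"},
        encode_condition=lambda models, s: s["text"].upper(),
        encode_latent=lambda models, s: sum(s["pixels"]),
    )
    print(list(out))  # [('A CAT', 3), ('A DOG', 7)]
```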
finetrainers/utils/__init__.py
CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
4       from .activation_checkpoint import apply_activation_checkpointing
5       from .data import determine_batch_size, should_perform_precomputation
6       from .diffusion import (
7   +       _enable_vae_memory_optimizations,
8           default_flow_shift,
9           get_scheduler_alphas,
10          get_scheduler_sigmas,
finetrainers/utils/diffusion.py
CHANGED
@@ -143,3 +143,10 @@ def prepare_target(
143         raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
144
145     return target
146 +
147 +
148 +   def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False):
149 +       if hasattr(vae, "enable_slicing") and enable_slicing:
150 +           vae.enable_slicing()
151 +       if hasattr(vae, "enable_tiling") and enable_tiling:
152 +           vae.enable_tiling()
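Usage note: because the helper guards each call with `hasattr`, it can be applied to any loaded VAE and is a no-op for modules that do not expose slicing or tiling. A hypothetical standalone example (the model id and flag values are chosen for illustration, not taken from this commit):

```python
from diffusers import AutoencoderKL

from finetrainers.utils import _enable_vae_memory_optimizations

# Hypothetical example: slicing and tiling trade a little speed for lower peak
# memory when encoding large batches or high resolutions.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
_enable_vae_memory_optimizations(vae, enable_slicing=True, enable_tiling=True)
```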
requirements.txt
CHANGED
@@ -40,5 +40,5 @@ av==14.1.0
40      git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
41
42      # for our frontend
43 -    gradio==5.
43 +    gradio==5.20.1
44      gradio_toggle
requirements_without_flash_attention.txt
CHANGED
@@ -39,5 +39,5 @@ av==14.1.0
39      git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
40
41      # for our frontend
42 -    gradio==5.
42 +    gradio==5.20.1
43      gradio_toggle