jbilcke-hf HF Staff committed on
Commit
ecd5028
·
1 Parent(s): 7c52128

upgrade finetrainers + gradio

README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🎥
4
  colorFrom: gray
5
  colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.15.0
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
 
4
  colorFrom: gray
5
  colorTo: gray
6
  sdk: gradio
7
+ sdk_version: 5.20.1
8
  app_file: app.py
9
  pinned: true
10
  license: apache-2.0
docs/huggingface/Downloading files from the hub.md ADDED
@@ -0,0 +1,270 @@
1
+ [](#downloading-files)Downloading files
2
+ =======================================
3
+
4
+ [](#download-a-single-file)Download a single file
5
+ -------------------------------------------------
6
+
7
+ ### [](#huggingface_hub.hf_hub_download)hf\_hub\_download
8
+
9
+ #### huggingface\_hub.hf\_hub\_download
10
+
11
+ [](#huggingface_hub.hf_hub_download)[< source \>](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L663)
12
+
13
+ ( repo\_id: str, filename: str, subfolder: typing.Optional\[str\] = None, repo\_type: typing.Optional\[str\] = None, revision: typing.Optional\[str\] = None, library\_name: typing.Optional\[str\] = None, library\_version: typing.Optional\[str\] = None, cache\_dir: typing.Union\[str, pathlib.Path, NoneType\] = None, local\_dir: typing.Union\[str, pathlib.Path, NoneType\] = None, user\_agent: typing.Union\[typing.Dict, str, NoneType\] = None, force\_download: bool = False, proxies: typing.Optional\[typing.Dict\] = None, etag\_timeout: float = 10, token: typing.Union\[bool, str, NoneType\] = None, local\_files\_only: bool = False, headers: typing.Optional\[typing.Dict\[str, str\]\] = None, endpoint: typing.Optional\[str\] = None, resume\_download: typing.Optional\[bool\] = None, force\_filename: typing.Optional\[str\] = None, local\_dir\_use\_symlinks: typing.Union\[bool, typing.Literal\['auto'\]\] = 'auto' ) → `str`
14
+
16
+
17
+ Parameters
18
+
19
+ * [](#huggingface_hub.hf_hub_download.repo_id)**repo\_id** (`str`) — A user or an organization name and a repo name separated by a `/`.
20
+ * [](#huggingface_hub.hf_hub_download.filename)**filename** (`str`) — The name of the file in the repo.
21
+ * [](#huggingface_hub.hf_hub_download.subfolder)**subfolder** (`str`, _optional_) — An optional value corresponding to a folder inside the model repo.
22
+ * [](#huggingface_hub.hf_hub_download.repo_type)**repo\_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
23
+ * [](#huggingface_hub.hf_hub_download.revision)**revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.
24
+ * [](#huggingface_hub.hf_hub_download.library_name)**library\_name** (`str`, _optional_) — The name of the library to which the object corresponds.
25
+ * [](#huggingface_hub.hf_hub_download.library_version)**library\_version** (`str`, _optional_) — The version of the library.
26
+ * [](#huggingface_hub.hf_hub_download.cache_dir)**cache\_dir** (`str`, `Path`, _optional_) — Path to the folder where cached files are stored.
27
+ * [](#huggingface_hub.hf_hub_download.local_dir)**local\_dir** (`str` or `Path`, _optional_) — If provided, the downloaded file will be placed under this directory.
28
+ * [](#huggingface_hub.hf_hub_download.user_agent)**user\_agent** (`dict`, `str`, _optional_) — The user-agent info in the form of a dictionary or a string.
29
+ * [](#huggingface_hub.hf_hub_download.force_download)**force\_download** (`bool`, _optional_, defaults to `False`) — Whether the file should be downloaded even if it already exists in the local cache.
30
+ * [](#huggingface_hub.hf_hub_download.proxies)**proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
31
+ * [](#huggingface_hub.hf_hub_download.etag_timeout)**etag\_timeout** (`float`, _optional_, defaults to `10`) — When fetching the ETag, how many seconds to wait for the server to send data before giving up; this value is passed to `requests.request`.
32
+ * [](#huggingface_hub.hf_hub_download.token)**token** (`str`, `bool`, _optional_) — A token to be used for the download.
33
+
34
+ * If `True`, the token is read from the HuggingFace config folder.
35
+ * If a string, it’s used as the authentication token.
36
+
37
+ * [](#huggingface_hub.hf_hub_download.local_files_only)**local\_files\_only** (`bool`, _optional_, defaults to `False`) — If `True`, avoid downloading the file and return the path to the local cached file if it exists.
38
+ * [](#huggingface_hub.hf_hub_download.headers)**headers** (`dict`, _optional_) — Additional headers to be sent with the request.
39
+
40
+ Returns
41
+
43
+
44
+ `str`
45
+
47
+
48
+ Local path of the file, or, if networking is off, the last version of the file cached on disk.
49
+
50
+ Raises
51
+
53
+
54
+ [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) or [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) or [EntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.EntryNotFoundError) or [LocalEntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.LocalEntryNotFoundError) or `EnvironmentError` or `OSError` or `ValueError`
55
+
57
+
58
+ * [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) — If the repository to download from cannot be found. This may be because it doesn’t exist, or because it is set to `private` and you do not have access.
59
+ * [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) — If the revision to download from cannot be found.
60
+ * [EntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.EntryNotFoundError) — If the file to download cannot be found.
61
+ * [LocalEntryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.LocalEntryNotFoundError) — If network is disabled or unavailable and file is not found in cache.
62
+ * [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) — If `token=True` but the token cannot be found.
63
+ * [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) — If ETag cannot be determined.
64
+ * [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) — If some parameter value is invalid.
65
+
66
+ Download a given file if it’s not already present in the local cache.
67
+
68
+ The new cache file layout looks like this:
69
+
70
+ * The cache directory contains one subfolder per repo\_id (namespaced by repo type)
71
+ * inside each repo folder:
72
+ * refs is a list of the latest known revision => commit\_hash pairs
73
+ * blobs contains the actual file blobs (identified by their git-sha or sha256, depending on whether they’re LFS files or not)
74
+ * snapshots contains one subfolder per commit, each “commit” contains the subset of the files that have been resolved at that particular commit. Each filename is a symlink to the blob at that particular commit.
75
+
76
+ [](#huggingface_hub.hf_hub_download.example)
77
+
79
+
80
+ \[ 96\] .
81
+ └── \[ 160\] models\--julien-c--EsperBERTo-small
82
+ ├── \[ 160\] blobs
83
+ │ ├── \[321M\] 403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
84
+ │ ├── \[ 398\] 7cb18dc9bafbfcf74629a4b760af1b160957a83e
85
+ │ └── \[1.4K\] d7edf6bd2a681fb0175f7735299831ee1b22b812
86
+ ├── \[ 96\] refs
87
+ │ └── \[ 40\] main
88
+ └── \[ 128\] snapshots
89
+ ├── \[ 128\] 2439f60ef33a0d46d85da5001d52aeda5b00ce9f
90
+ │ ├── \[ 52\] README.md -> ../../blobs/d7edf6bd2a681fb0175f7735299831ee1b22b812
91
+ │ └── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
92
+ └── \[ 128\] bbc77c8132af1cc5cf678da3f1ddf2de43606d48
93
+ ├── \[ 52\] README.md -> ../../blobs/7cb18dc9bafbfcf74629a4b760af1b160957a83e
94
+ └── \[ 76\] pytorch\_model.bin -> ../../blobs/403450e234d65943a7dcf7e05a771ce3c92faa84dd07db4ac20f592037a1e4bd
95
+
96
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it’s optimized for regularly pulling the latest version of a repository.
97
+
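A minimal usage sketch of both behaviours (the repo id and filename below are the illustrative values used elsewhere on this page; any repo/file pair works the same way):

```python
from huggingface_hub import hf_hub_download

# Download a single file into the shared cache and get its local path.
path = hf_hub_download(repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin")
print(path)  # .../snapshots/<commit_hash>/pytorch_model.bin

# Download into a plain folder instead of the cache (a .cache/huggingface/
# metadata folder is created at the root of local_dir).
path = hf_hub_download(
    repo_id="julien-c/EsperBERTo-small",
    filename="pytorch_model.bin",
    local_dir="./EsperBERTo-small",
)
```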
98
+ ### [](#huggingface_hub.hf_hub_url)hf\_hub\_url
99
+
100
+ #### huggingface\_hub.hf\_hub\_url
101
+
102
+ [](#huggingface_hub.hf_hub_url)[< source \>](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L171)
103
+
104
+ ( repo\_id: str, filename: str, subfolder: typing.Optional\[str\] = None, repo\_type: typing.Optional\[str\] = None, revision: typing.Optional\[str\] = None, endpoint: typing.Optional\[str\] = None )
105
+
106
+ Parameters
107
+
108
+ * [](#huggingface_hub.hf_hub_url.repo_id)**repo\_id** (`str`) — A namespace (user or an organization) name and a repo name separated by a `/`.
109
+ * [](#huggingface_hub.hf_hub_url.filename)**filename** (`str`) — The name of the file in the repo.
110
+ * [](#huggingface_hub.hf_hub_url.subfolder)**subfolder** (`str`, _optional_) — An optional value corresponding to a folder inside the repo.
111
+ * [](#huggingface_hub.hf_hub_url.repo_type)**repo\_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
112
+ * [](#huggingface_hub.hf_hub_url.revision)**revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.
113
+
114
+ Construct the URL of a file from the given information.
115
+
116
+ The resolved address can either be a huggingface.co-hosted url, or a link to Cloudfront (a Content Delivery Network, or CDN) for large files which are more than a few MBs.
117
+
118
+ [](#huggingface_hub.hf_hub_url.example)
119
+
120
+ Example:
121
+
123
+
124
+ \>>> from huggingface\_hub import hf\_hub\_url
125
+
126
+ \>>> hf\_hub\_url(
127
+ ... repo\_id="julien-c/EsperBERTo-small", filename="pytorch\_model.bin"
128
+ ... )
129
+ 'https://huggingface.co/julien-c/EsperBERTo-small/resolve/main/pytorch\_model.bin'
130
+
131
+ Notes:
132
+
133
+ Cloudfront is replicated over the globe so downloads are way faster for the end user (and it also lowers our bandwidth costs).
134
+
135
+ Cloudfront aggressively caches files by default (default TTL is 24 hours), however this is not an issue here because we implement a git-based versioning system on huggingface.co, which means that we store the files on S3/Cloudfront in a content-addressable way (i.e., the file name is its hash). Using content-addressable filenames means cache can’t ever be stale.
136
+
137
+ In terms of client-side caching from this library, we base our caching on the objects’ entity tag (`ETag`), which is an identifier of a specific version of a resource \[1\]. An object’s ETag is: its git-sha1 if stored in git, or its sha256 if stored in git-lfs.
138
+
139
+ References:
140
+
141
+ * \[1\] [https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/ETag)
142
+
143
+ [](#huggingface_hub.snapshot_download)Download a snapshot of the repo
144
+ ---------------------------------------------------------------------
145
+
146
+ #### huggingface\_hub.snapshot\_download
147
+
148
+ [](#huggingface_hub.snapshot_download)[< source \>](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/_snapshot_download.py#L20)
149
+
150
+ ( repo\_id: str, repo\_type: typing.Optional\[str\] = None, revision: typing.Optional\[str\] = None, cache\_dir: typing.Union\[str, pathlib.Path, NoneType\] = None, local\_dir: typing.Union\[str, pathlib.Path, NoneType\] = None, library\_name: typing.Optional\[str\] = None, library\_version: typing.Optional\[str\] = None, user\_agent: typing.Union\[typing.Dict, str, NoneType\] = None, proxies: typing.Optional\[typing.Dict\] = None, etag\_timeout: float = 10, force\_download: bool = False, token: typing.Union\[bool, str, NoneType\] = None, local\_files\_only: bool = False, allow\_patterns: typing.Union\[typing.List\[str\], str, NoneType\] = None, ignore\_patterns: typing.Union\[typing.List\[str\], str, NoneType\] = None, max\_workers: int = 8, tqdm\_class: typing.Optional\[tqdm.asyncio.tqdm\_asyncio\] = None, headers: typing.Optional\[typing.Dict\[str, str\]\] = None, endpoint: typing.Optional\[str\] = None, local\_dir\_use\_symlinks: typing.Union\[bool, typing.Literal\['auto'\]\] = 'auto', resume\_download: typing.Optional\[bool\] = None ) → `str`
151
+
153
+
154
+ Parameters
155
+
156
+ * [](#huggingface_hub.snapshot_download.repo_id)**repo\_id** (`str`) — A user or an organization name and a repo name separated by a `/`.
157
+ * [](#huggingface_hub.snapshot_download.repo_type)**repo\_type** (`str`, _optional_) — Set to `"dataset"` or `"space"` if downloading from a dataset or space, `None` or `"model"` if downloading from a model. Default is `None`.
158
+ * [](#huggingface_hub.snapshot_download.revision)**revision** (`str`, _optional_) — An optional Git revision id which can be a branch name, a tag, or a commit hash.
159
+ * [](#huggingface_hub.snapshot_download.cache_dir)**cache\_dir** (`str`, `Path`, _optional_) — Path to the folder where cached files are stored.
160
+ * [](#huggingface_hub.snapshot_download.local_dir)**local\_dir** (`str` or `Path`, _optional_) — If provided, the downloaded files will be placed under this directory.
161
+ * [](#huggingface_hub.snapshot_download.library_name)**library\_name** (`str`, _optional_) — The name of the library to which the object corresponds.
162
+ * [](#huggingface_hub.snapshot_download.library_version)**library\_version** (`str`, _optional_) — The version of the library.
163
+ * [](#huggingface_hub.snapshot_download.user_agent)**user\_agent** (`str`, `dict`, _optional_) — The user-agent info in the form of a dictionary or a string.
164
+ * [](#huggingface_hub.snapshot_download.proxies)**proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
165
+ * [](#huggingface_hub.snapshot_download.etag_timeout)**etag\_timeout** (`float`, _optional_, defaults to `10`) — When fetching the ETag, how many seconds to wait for the server to send data before giving up; this value is passed to `requests.request`.
166
+ * [](#huggingface_hub.snapshot_download.force_download)**force\_download** (`bool`, _optional_, defaults to `False`) — Whether the file should be downloaded even if it already exists in the local cache.
167
+ * [](#huggingface_hub.snapshot_download.token)**token** (`str`, `bool`, _optional_) — A token to be used for the download.
168
+
169
+ * If `True`, the token is read from the HuggingFace config folder.
170
+ * If a string, it’s used as the authentication token.
171
+
172
+ * [](#huggingface_hub.snapshot_download.headers)**headers** (`dict`, _optional_) — Additional headers to include in the request. Those headers take precedence over the others.
173
+ * [](#huggingface_hub.snapshot_download.local_files_only)**local\_files\_only** (`bool`, _optional_, defaults to `False`) — If `True`, avoid downloading the file and return the path to the local cached file if it exists.
174
+ * [](#huggingface_hub.snapshot_download.allow_patterns)**allow\_patterns** (`List[str]` or `str`, _optional_) — If provided, only files matching at least one pattern are downloaded.
175
+ * [](#huggingface_hub.snapshot_download.ignore_patterns)**ignore\_patterns** (`List[str]` or `str`, _optional_) — If provided, files matching any of the patterns are not downloaded.
176
+ * [](#huggingface_hub.snapshot_download.max_workers)**max\_workers** (`int`, _optional_) — Number of concurrent threads to download files (1 thread = 1 file download). Defaults to 8.
177
+ * [](#huggingface_hub.snapshot_download.tqdm_class)**tqdm\_class** (`tqdm`, _optional_) — If provided, overwrites the default behavior for the progress bar. Passed argument must inherit from `tqdm.auto.tqdm` or at least mimic its behavior. Note that the `tqdm_class` is not passed to each individual download. Defaults to the custom HF progress bar that can be disabled by setting `HF_HUB_DISABLE_PROGRESS_BARS` environment variable.
178
+
179
+ Returns
180
+
182
+
183
+ `str`
184
+
186
+
187
+ folder path of the repo snapshot.
188
+
189
+ Raises
190
+
192
+
193
+ [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) or [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) or `EnvironmentError` or `OSError` or `ValueError`
194
+
196
+
197
+ * [RepositoryNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RepositoryNotFoundError) — If the repository to download from cannot be found. This may be because it doesn’t exist, or because it is set to `private` and you do not have access.
198
+ * [RevisionNotFoundError](/docs/huggingface_hub/v0.29.2/en/package_reference/utilities#huggingface_hub.errors.RevisionNotFoundError) — If the revision to download from cannot be found.
199
+ * [`EnvironmentError`](https://docs.python.org/3/library/exceptions.html#EnvironmentError) — If `token=True` and the token cannot be found.
200
+ * [`OSError`](https://docs.python.org/3/library/exceptions.html#OSError) — if ETag cannot be determined.
201
+ * [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError) — if some parameter value is invalid.
202
+
203
+ Download repo files.
204
+
205
+ Download a whole snapshot of a repo’s files at the specified revision. This is useful when you want all files from a repo, because you don’t know which ones you will need a priori. All files are nested inside a folder in order to keep their actual filename relative to that folder. You can also filter which files to download using `allow_patterns` and `ignore_patterns`.
206
+
207
+ If `local_dir` is provided, the file structure from the repo will be replicated in this location. When using this option, the `cache_dir` will not be used and a `.cache/huggingface/` folder will be created at the root of `local_dir` to store some metadata related to the downloaded files. While this mechanism is not as robust as the main cache-system, it’s optimized for regularly pulling the latest version of a repository.
208
+
209
+ An alternative would be to clone the repo but this requires git and git-lfs to be installed and properly configured. It is also not possible to filter which files to download when cloning a repository using git.
210
+
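A short sketch of both behaviours described above (the repo id and patterns are illustrative):

```python
from huggingface_hub import snapshot_download

# Download the whole repo at its latest revision; the returned path is the snapshot folder.
folder = snapshot_download(repo_id="julien-c/EsperBERTo-small")

# Only fetch small text/config files and skip the large binaries.
folder = snapshot_download(
    repo_id="julien-c/EsperBERTo-small",
    allow_patterns=["*.json", "*.txt"],
    ignore_patterns=["*.bin"],
)
print(folder)
```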
211
+ [](#get-metadata-about-a-file)Get metadata about a file
212
+ -------------------------------------------------------
213
+
214
+ ### [](#huggingface_hub.get_hf_file_metadata)get\_hf\_file\_metadata
215
+
216
+ #### huggingface\_hub.get\_hf\_file\_metadata
217
+
218
+ [](#huggingface_hub.get_hf_file_metadata)[< source \>](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L1246)
219
+
220
+ ( url: str, token: typing.Union\[bool, str, NoneType\] = None, proxies: typing.Optional\[typing.Dict\] = None, timeout: typing.Optional\[float\] = 10, library\_name: typing.Optional\[str\] = None, library\_version: typing.Optional\[str\] = None, user\_agent: typing.Union\[typing.Dict, str, NoneType\] = None, headers: typing.Optional\[typing.Dict\[str, str\]\] = None )
221
+
222
+ Parameters
223
+
224
+ * [](#huggingface_hub.get_hf_file_metadata.url)**url** (`str`) — File url, for example returned by [hf\_hub\_url()](/docs/huggingface_hub/v0.29.2/en/package_reference/file_download#huggingface_hub.hf_hub_url).
225
+ * [](#huggingface_hub.get_hf_file_metadata.token)**token** (`str` or `bool`, _optional_) — A token to be used for the download.
226
+
227
+ * If `True`, the token is read from the HuggingFace config folder.
228
+ * If `False` or `None`, no token is provided.
229
+ * If a string, it’s used as the authentication token.
230
+
231
+ * [](#huggingface_hub.get_hf_file_metadata.proxies)**proxies** (`dict`, _optional_) — Dictionary mapping protocol to the URL of the proxy passed to `requests.request`.
232
+ * [](#huggingface_hub.get_hf_file_metadata.timeout)**timeout** (`float`, _optional_, defaults to 10) — How many seconds to wait for the server to send metadata before giving up.
233
+ * [](#huggingface_hub.get_hf_file_metadata.library_name)**library\_name** (`str`, _optional_) — The name of the library to which the object corresponds.
234
+ * [](#huggingface_hub.get_hf_file_metadata.library_version)**library\_version** (`str`, _optional_) — The version of the library.
235
+ * [](#huggingface_hub.get_hf_file_metadata.user_agent)**user\_agent** (`dict`, `str`, _optional_) — The user-agent info in the form of a dictionary or a string.
236
+ * [](#huggingface_hub.get_hf_file_metadata.headers)**headers** (`dict`, _optional_) — Additional headers to be sent with the request.
237
+
238
+ Fetch metadata of a file versioned on the Hub for a given url.
239
+
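A small sketch combining hf\_hub\_url() and get\_hf\_file\_metadata() (the repo id and filename are illustrative):

```python
from huggingface_hub import get_hf_file_metadata, hf_hub_url

url = hf_hub_url(repo_id="julien-c/EsperBERTo-small", filename="pytorch_model.bin")
metadata = get_hf_file_metadata(url)

# commit_hash identifies the revision, etag the exact file content, size is in bytes.
print(metadata.commit_hash, metadata.etag, metadata.size)
```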
240
+ ### [](#huggingface_hub.HfFileMetadata)HfFileMetadata
241
+
242
+ ### class huggingface\_hub.HfFileMetadata
243
+
244
+ [](#huggingface_hub.HfFileMetadata)[< source \>](https://github.com/huggingface/huggingface_hub/blob/v0.29.2/src/huggingface_hub/file_download.py#L147)
245
+
246
+ ( commit\_hash: typing.Optional\[str\], etag: typing.Optional\[str\], location: str, size: typing.Optional\[int\] )
247
+
248
+ Parameters
249
+
250
+ * [](#huggingface_hub.HfFileMetadata.commit_hash)**commit\_hash** (`str`, _optional_) — The commit\_hash related to the file.
251
+ * [](#huggingface_hub.HfFileMetadata.etag)**etag** (`str`, _optional_) — Etag of the file on the server.
252
+ * [](#huggingface_hub.HfFileMetadata.location)**location** (`str`) — URL from which the file can be downloaded, either a Hub URL or a CDN URL.
253
+ * [](#huggingface_hub.HfFileMetadata.size)**size** (`int`, _optional_) — Size of the file in bytes. In case of an LFS file, contains the size of the actual LFS file, not the pointer.
254
+
255
+ Data structure containing information about a file versioned on the Hub.
256
+
257
+ Returned by [get\_hf\_file\_metadata()](/docs/huggingface_hub/v0.29.2/en/package_reference/file_download#huggingface_hub.get_hf_file_metadata) based on a URL.
258
+
259
+ [](#caching)Caching
260
+ -------------------
261
+
262
+ The methods displayed above are designed to work with a caching system that prevents re-downloading files. The caching system was updated in v0.8.0 to become the central cache-system shared across libraries that depend on the Hub.
263
+
264
+ Read the [cache-system guide](../guides/manage-cache) for a detailed presentation of caching at HF.
265
+
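As a quick illustration of the shared cache (the repo id and filename are illustrative; `scan_cache_dir()` is a `huggingface_hub` helper for inspecting the cache):

```python
from huggingface_hub import hf_hub_download, scan_cache_dir

# The second call is served from the cache and returns the same local path
# without hitting the network again.
first = hf_hub_download("julien-c/EsperBERTo-small", "config.json")
second = hf_hub_download("julien-c/EsperBERTo-small", "config.json")
assert first == second

# Inspect what is currently stored in the cache.
report = scan_cache_dir()
print(report.size_on_disk_str, len(report.repos))
```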
docs/huggingface/HfApi Client API Reference.md ADDED
The diff for this file is too large to render. See raw diff
 
docs/huggingface/Load a dataset from the hub.md ADDED
@@ -0,0 +1,126 @@
1
+ [](#load-a-dataset-from-the-hub)Load a dataset from the Hub
2
+ ===========================================================
3
+
4
+ Finding high-quality datasets that are reproducible and accessible can be difficult. One of 🤗 Datasets’ main goals is to provide a simple way to load a dataset of any format or type. The easiest way to get started is to discover an existing dataset on the [Hugging Face Hub](https://huggingface.co/datasets) - a community-driven collection of datasets for tasks in NLP, computer vision, and audio - and use 🤗 Datasets to download and generate the dataset.
5
+
6
+ This tutorial uses the [rotten\_tomatoes](https://huggingface.co/datasets/rotten_tomatoes) and [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) datasets, but feel free to load any dataset you want and follow along. Head over to the Hub now and find a dataset for your task!
7
+
8
+ [](#load-a-dataset)Load a dataset
9
+ ---------------------------------
10
+
11
+ Before you take the time to download a dataset, it’s often helpful to quickly get some general information about a dataset. A dataset’s information is stored inside [DatasetInfo](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.DatasetInfo) and can include information such as the dataset description, features, and dataset size.
12
+
13
+ Use the [load\_dataset\_builder()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.load_dataset_builder) function to load a dataset builder and inspect a dataset’s attributes without committing to downloading it:
14
+
16
+
17
+ \>>> from datasets import load\_dataset\_builder
18
+ \>>> ds\_builder = load\_dataset\_builder("cornell-movie-review-data/rotten\_tomatoes")
19
+
20
+ \# Inspect dataset description
21
+ \>>> ds\_builder.info.description
22
+ Movie Review Dataset. This is a dataset of containing 5,331 positive and 5,331 negative processed sentences from Rotten Tomatoes movie reviews. This data was first used in Bo Pang and Lillian Lee, \`\`Seeing stars: Exploiting class relationships for sentiment categorization with respect to rating scales.'', Proceedings of the ACL, 2005.
23
+
24
+ \# Inspect dataset features
25
+ \>>> ds\_builder.info.features
26
+ {'label': ClassLabel(names=\['neg', 'pos'\], id\=None),
27
+ 'text': Value(dtype='string', id\=None)}
28
+
29
+ If you’re happy with the dataset, then load it with [load\_dataset()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.load_dataset):
30
+
32
+
33
+ \>>> from datasets import load\_dataset
34
+
35
+ \>>> dataset = load\_dataset("cornell-movie-review-data/rotten\_tomatoes", split="train")
36
+
37
+ [](#splits)Splits
38
+ -----------------
39
+
40
+ A split is a specific subset of a dataset like `train` and `test`. List a dataset’s split names with the [get\_dataset\_split\_names()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.get_dataset_split_names) function:
41
+
43
+
44
+ \>>> from datasets import get\_dataset\_split\_names
45
+
46
+ \>>> get\_dataset\_split\_names("cornell-movie-review-data/rotten\_tomatoes")
47
+ \['train', 'validation', 'test'\]
48
+
49
+ Then you can load a specific split with the `split` parameter. Loading a dataset `split` returns a [Dataset](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.Dataset) object:
50
+
52
+
53
+ \>>> from datasets import load\_dataset
54
+
55
+ \>>> dataset = load\_dataset("cornell-movie-review-data/rotten\_tomatoes", split="train")
56
+ \>>> dataset
57
+ Dataset({
58
+ features: \['text', 'label'\],
59
+ num\_rows: 8530
60
+ })
61
+
62
+ If you don’t specify a `split`, 🤗 Datasets returns a [DatasetDict](/docs/datasets/v3.3.2/en/package_reference/main_classes#datasets.DatasetDict) object instead:
63
+
65
+
66
+ \>>> from datasets import load\_dataset
67
+
68
+ \>>> dataset = load\_dataset("cornell-movie-review-data/rotten\_tomatoes")
69
+ DatasetDict({
70
+ train: Dataset({
71
+ features: \['text', 'label'\],
72
+ num\_rows: 8530
73
+ })
74
+ validation: Dataset({
75
+ features: \['text', 'label'\],
76
+ num\_rows: 1066
77
+ })
78
+ test: Dataset({
79
+ features: \['text', 'label'\],
80
+ num\_rows: 1066
81
+ })
82
+ })
83
+
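A `DatasetDict` behaves like a dictionary of splits, so you can index it to get a `Dataset` and then index rows (a small sketch based on the load above):

```python
from datasets import load_dataset

dataset = load_dataset("cornell-movie-review-data/rotten_tomatoes")

train = dataset["train"]   # pick a split -> Dataset
print(train.num_rows)      # 8530
print(train[0])            # first row: {'text': ..., 'label': ...}
```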
84
+ [](#configurations)Configurations
85
+ ---------------------------------
86
+
87
+ Some datasets contain several sub-datasets. For example, the [MInDS-14](https://huggingface.co/datasets/PolyAI/minds14) dataset has several sub-datasets, each one containing audio data in a different language. These sub-datasets are known as _configurations_ or _subsets_, and you must explicitly select one when loading the dataset. If you don’t provide a configuration name, 🤗 Datasets will raise a `ValueError` and remind you to choose a configuration.
88
+
89
+ Use the [get\_dataset\_config\_names()](/docs/datasets/v3.3.2/en/package_reference/loading_methods#datasets.get_dataset_config_names) function to retrieve a list of all the possible configurations available to your dataset:
90
+
92
+
93
+ \>>> from datasets import get\_dataset\_config\_names
94
+
95
+ \>>> configs = get\_dataset\_config\_names("PolyAI/minds14")
96
+ \>>> print(configs)
97
+ \['cs-CZ', 'de-DE', 'en-AU', 'en-GB', 'en-US', 'es-ES', 'fr-FR', 'it-IT', 'ko-KR', 'nl-NL', 'pl-PL', 'pt-PT', 'ru-RU', 'zh-CN', 'all'\]
98
+
99
+ Then load the configuration you want:
100
+
102
+
103
+ \>>> from datasets import load\_dataset
104
+
105
+ \>>> mindsFR = load\_dataset("PolyAI/minds14", "fr-FR", split="train")
106
+
107
+ [](#remote-code)Remote code
108
+ ---------------------------
109
+
110
+ Certain dataset repositories contain a loading script with the Python code used to generate the dataset. All files and code uploaded to the Hub are scanned for malware (refer to the Hub security documentation for more information), but you should still review the dataset loading scripts and authors to avoid executing malicious code on your machine. You should set `trust_remote_code=True` to use a dataset with a loading script, or you will get an error:
111
+
113
+
114
+ \>>> from datasets import get\_dataset\_config\_names, get\_dataset\_split\_names, load\_dataset
115
+
116
+ \>>> c4 = load\_dataset("c4", "en", split="train", trust\_remote\_code=True)
117
+ \>>> get\_dataset\_config\_names("c4", trust\_remote\_code=True)
118
+ \['en', 'realnewslike', 'en.noblocklist', 'en.noclean'\]
119
+ \>>> get\_dataset\_split\_names("c4", "en", trust\_remote\_code=True)
120
+ \['train', 'validation'\]
121
+
122
+ For security reasons, 🤗 Datasets does not allow running dataset loading scripts by default, and you have to pass `trust_remote_code=True` to load datasets that require running a dataset script.
123
+
docs/huggingface/Search the Hub.md ADDED
@@ -0,0 +1,61 @@
1
+ [](#search-the-hub)Search the Hub
2
+ =================================
3
+
4
+ In this tutorial, you will learn how to search models, datasets and spaces on the Hub using `huggingface_hub`.
5
+
6
+ [](#how-to-list-repositories-)How to list repositories?
7
+ --------------------------------------------------------
8
+
9
+ The `huggingface_hub` library includes an HTTP client [HfApi](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi) to interact with the Hub. Among other things, it can list models, datasets, and Spaces stored on the Hub:
10
+
12
+
13
+ \>>> from huggingface\_hub import HfApi
14
+ \>>> api = HfApi()
15
+ \>>> models = api.list\_models()
16
+
17
+ The output of [list\_models()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_models) is an iterator over the models stored on the Hub.
18
+
19
+ Similarly, you can use [list\_datasets()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_datasets) to list datasets and [list\_spaces()](/docs/huggingface_hub/v0.29.2/en/package_reference/hf_api#huggingface_hub.HfApi.list_spaces) to list Spaces.
20
+
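A minimal sketch of all three helpers; since the listings are lazy iterators, the `limit` argument keeps the number of fetched results small (values here are illustrative):

```python
from huggingface_hub import HfApi

api = HfApi()

# Iterate over a handful of models instead of the full listing.
for model in api.list_models(limit=5):
    print(model.id)

# The same pattern works for datasets and Spaces.
datasets = list(api.list_datasets(limit=5))
spaces = list(api.list_spaces(limit=5))
```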
21
+ [](#how-to-filter-repositories-)How to filter repositories?
22
+ ------------------------------------------------------------
23
+
24
+ Listing repositories is great but now you might want to filter your search. The list helpers have several attributes like:
25
+
26
+ * `filter`
27
+ * `author`
28
+ * `search`
29
+ * …
30
+
31
+ Let’s see an example that gets all models on the Hub that do image classification, have been trained on the imagenet dataset, and run with PyTorch.
32
+
34
+
35
+ models = api.list\_models(
36
+ task="image-classification",
37
+ library="pytorch",
38
+ trained\_dataset="imagenet",
39
+ )
40
+
41
+ While filtering, you can also sort the results and take only the top ones. For example, the following snippet fetches the top 5 most downloaded datasets on the Hub:
42
+
44
+
45
+ \>>> list(list\_datasets(sort="downloads", direction=-1, limit=5))
46
+ \[DatasetInfo(
47
+ id\='argilla/databricks-dolly-15k-curated-en',
48
+ author='argilla',
49
+ sha='4dcd1dedbe148307a833c931b21ca456a1fc4281',
50
+ last\_modified=datetime.datetime(2023, 10, 2, 12, 32, 53, tzinfo=datetime.timezone.utc),
51
+ private=False,
52
+ downloads=8889377,
53
+ (...)
54
+
55
+ To explore available filters on the Hub, visit [models](https://huggingface.co/models) and [datasets](https://huggingface.co/datasets) pages in your browser, search for some parameters and look at the values in the URL.
56
+
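For instance, the `search` and `author` parameters mirror the search box and namespace filter on those pages (the values below are only illustrative):

```python
from huggingface_hub import HfApi

api = HfApi()

# Free-text search on repo names, restricted to one author/organization.
models = api.list_models(search="bert", author="google", limit=5)
for model in models:
    print(model.id, model.downloads)
```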
finetrainers/args.py CHANGED
@@ -316,6 +316,7 @@ class BaseArgs:
316
  # Dataset arguments
317
  dataset_config: str = None
318
  dataset_shuffle_buffer_size: int = 1
 
319
  precomputation_items: int = 512
320
  precomputation_dir: Optional[str] = None
321
  precomputation_once: bool = False
@@ -420,6 +421,7 @@ class BaseArgs:
420
  dataset_arguments = {
421
  "dataset_config": self.dataset_config,
422
  "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
 
423
  "precomputation_items": self.precomputation_items,
424
  "precomputation_dir": self.precomputation_dir,
425
  "precomputation_once": self.precomputation_once,
@@ -625,6 +627,7 @@ def _add_model_arguments(parser: argparse.ArgumentParser) -> None:
625
  def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
626
  parser.add_argument("--dataset_config", type=str, required=True)
627
  parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
 
628
  parser.add_argument("--precomputation_items", type=int, default=512)
629
  parser.add_argument("--precomputation_dir", type=str, default=None)
630
  parser.add_argument("--precomputation_once", action="store_true")
@@ -761,6 +764,7 @@ def _map_to_args_type(args: Dict[str, Any]) -> BaseArgs:
761
  # Dataset arguments
762
  result_args.dataset_config = args.dataset_config
763
  result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
 
764
  result_args.precomputation_items = args.precomputation_items
765
  result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
766
  result_args.precomputation_once = args.precomputation_once
 
316
  # Dataset arguments
317
  dataset_config: str = None
318
  dataset_shuffle_buffer_size: int = 1
319
+ enable_precomputation: bool = False
320
  precomputation_items: int = 512
321
  precomputation_dir: Optional[str] = None
322
  precomputation_once: bool = False
 
421
  dataset_arguments = {
422
  "dataset_config": self.dataset_config,
423
  "dataset_shuffle_buffer_size": self.dataset_shuffle_buffer_size,
424
+ "enable_precomputation": self.enable_precomputation,
425
  "precomputation_items": self.precomputation_items,
426
  "precomputation_dir": self.precomputation_dir,
427
  "precomputation_once": self.precomputation_once,
 
627
  def _add_dataset_arguments(parser: argparse.ArgumentParser) -> None:
628
  parser.add_argument("--dataset_config", type=str, required=True)
629
  parser.add_argument("--dataset_shuffle_buffer_size", type=int, default=1)
630
+ parser.add_argument("--enable_precomputation", action="store_true")
631
  parser.add_argument("--precomputation_items", type=int, default=512)
632
  parser.add_argument("--precomputation_dir", type=str, default=None)
633
  parser.add_argument("--precomputation_once", action="store_true")
 
764
  # Dataset arguments
765
  result_args.dataset_config = args.dataset_config
766
  result_args.dataset_shuffle_buffer_size = args.dataset_shuffle_buffer_size
767
+ result_args.enable_precomputation = args.enable_precomputation
768
  result_args.precomputation_items = args.precomputation_items
769
  result_args.precomputation_dir = args.precomputation_dir or os.path.join(args.output_dir, "precomputed")
770
  result_args.precomputation_once = args.precomputation_once
finetrainers/config.py CHANGED
@@ -3,6 +3,7 @@ from typing import Type
3
 
4
  from .models import ModelSpecification
5
  from .models.cogvideox import CogVideoXModelSpecification
 
6
  from .models.hunyuan_video import HunyuanVideoModelSpecification
7
  from .models.ltx_video import LTXVideoModelSpecification
8
  from .models.wan import WanModelSpecification
@@ -10,6 +11,7 @@ from .models.wan import WanModelSpecification
10
 
11
  class ModelType(str, Enum):
12
  COGVIDEOX = "cogvideox"
 
13
  HUNYUAN_VIDEO = "hunyuan_video"
14
  LTX_VIDEO = "ltx_video"
15
  WAN = "wan"
@@ -21,6 +23,14 @@ class TrainingType(str, Enum):
21
 
22
 
23
  SUPPORTED_MODEL_CONFIGS = {
 
 
 
 
 
 
 
 
24
  ModelType.HUNYUAN_VIDEO: {
25
  TrainingType.LORA: HunyuanVideoModelSpecification,
26
  TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
@@ -29,10 +39,6 @@ SUPPORTED_MODEL_CONFIGS = {
29
  TrainingType.LORA: LTXVideoModelSpecification,
30
  TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
31
  },
32
- ModelType.COGVIDEOX: {
33
- TrainingType.LORA: CogVideoXModelSpecification,
34
- TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
35
- },
36
  ModelType.WAN: {
37
  TrainingType.LORA: WanModelSpecification,
38
  TrainingType.FULL_FINETUNE: WanModelSpecification,
 
3
 
4
  from .models import ModelSpecification
5
  from .models.cogvideox import CogVideoXModelSpecification
6
+ from .models.cogview4 import CogView4ModelSpecification
7
  from .models.hunyuan_video import HunyuanVideoModelSpecification
8
  from .models.ltx_video import LTXVideoModelSpecification
9
  from .models.wan import WanModelSpecification
 
11
 
12
  class ModelType(str, Enum):
13
  COGVIDEOX = "cogvideox"
14
+ COGVIEW4 = "cogview4"
15
  HUNYUAN_VIDEO = "hunyuan_video"
16
  LTX_VIDEO = "ltx_video"
17
  WAN = "wan"
 
23
 
24
 
25
  SUPPORTED_MODEL_CONFIGS = {
26
+ ModelType.COGVIDEOX: {
27
+ TrainingType.LORA: CogVideoXModelSpecification,
28
+ TrainingType.FULL_FINETUNE: CogVideoXModelSpecification,
29
+ },
30
+ ModelType.COGVIEW4: {
31
+ TrainingType.LORA: CogView4ModelSpecification,
32
+ TrainingType.FULL_FINETUNE: CogView4ModelSpecification,
33
+ },
34
  ModelType.HUNYUAN_VIDEO: {
35
  TrainingType.LORA: HunyuanVideoModelSpecification,
36
  TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
 
39
  TrainingType.LORA: LTXVideoModelSpecification,
40
  TrainingType.FULL_FINETUNE: LTXVideoModelSpecification,
41
  },
 
 
 
 
42
  ModelType.WAN: {
43
  TrainingType.LORA: WanModelSpecification,
44
  TrainingType.FULL_FINETUNE: WanModelSpecification,
finetrainers/data/__init__.py CHANGED
@@ -14,6 +14,14 @@ from .dataset import (
14
  initialize_dataset,
15
  wrap_iterable_dataset_for_preprocessing,
16
  )
17
- from .precomputation import DistributedDataPreprocessor, PreprocessedDataIterable
 
 
 
 
 
 
 
 
18
  from .sampler import ResolutionSampler
19
  from .utils import find_files
 
14
  initialize_dataset,
15
  wrap_iterable_dataset_for_preprocessing,
16
  )
17
+ from .precomputation import (
18
+ InMemoryDataIterable,
19
+ InMemoryDistributedDataPreprocessor,
20
+ InMemoryOnceDataIterable,
21
+ PrecomputedDataIterable,
22
+ PrecomputedDistributedDataPreprocessor,
23
+ PrecomputedOnceDataIterable,
24
+ initialize_preprocessor,
25
+ )
26
  from .sampler import ResolutionSampler
27
  from .utils import find_files
finetrainers/data/dataset.py CHANGED
@@ -29,10 +29,13 @@ decord.bridge.set_bridge("torch")
29
  logger = get_logger()
30
 
31
 
 
32
  MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
33
  COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
34
  COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
35
  COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
 
 
36
 
37
 
38
  class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
@@ -420,22 +423,69 @@ class VideoFolderDataset(torch.utils.data.IterableDataset, torch.distributed.che
420
 
421
 
422
  class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
423
- def __init__(self, dataset_name: str, infinite: bool = False) -> None:
 
 
 
 
 
 
 
424
  super().__init__()
425
 
 
 
 
 
426
  self.dataset_name = dataset_name
427
  self.infinite = infinite
428
 
429
  data = datasets.load_dataset(dataset_name, split="train", streaming=True)
430
- data = data.rename_column("txt", "caption")
431
- for column_name in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
432
- if column_name in data.column_names:
433
- data = data.cast_column(column_name, datasets.Image(mode="RGB"))
434
- data = data.rename_column(column_name, "image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
 
436
  self._data = data
437
  self._sample_index = 0
438
  self._precomputable_once = False
 
 
439
 
440
  def _get_data_iter(self):
441
  if self._sample_index == 0:
@@ -446,6 +496,9 @@ class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkp
446
  while True:
447
  for sample in self._get_data_iter():
448
  self._sample_index += 1
 
 
 
449
  yield sample
450
 
451
  if not self.infinite:
@@ -464,22 +517,69 @@ class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkp
464
 
465
 
466
  class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
467
- def __init__(self, dataset_name: str, infinite: bool = False) -> None:
 
 
 
 
 
 
 
468
  super().__init__()
469
 
 
 
 
 
470
  self.dataset_name = dataset_name
471
  self.infinite = infinite
472
 
473
  data = datasets.load_dataset(dataset_name, split="train", streaming=True)
474
- data = data.rename_column("txt", "caption")
475
- for column_name in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
476
- if column_name in data.column_names:
477
- data = data.cast_column(column_name, datasets.Video())
478
- data = data.rename_column(column_name, "video")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
479
 
480
  self._data = data
481
  self._sample_index = 0
482
  self._precomputable_once = False
 
 
483
 
484
  def _get_data_iter(self):
485
  if self._sample_index == 0:
@@ -490,6 +590,9 @@ class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkp
490
  while True:
491
  for sample in self._get_data_iter():
492
  self._sample_index += 1
 
 
 
493
  yield sample
494
 
495
  if not self.infinite:
@@ -600,11 +703,17 @@ class IterableDatasetPreprocessingWrapper(
600
  for sample in iter(self.dataset):
601
  if self.dataset_type == "image":
602
  if self.image_resolution_buckets:
 
 
 
603
  sample["image"] = FF.resize_to_nearest_bucket_image(
604
  sample["image"], self.image_resolution_buckets, self.reshape_mode
605
  )
606
  elif self.dataset_type == "video":
607
  if self.video_resolution_buckets:
 
 
 
608
  sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
609
  sample["video"], self.video_resolution_buckets, self.reshape_mode
610
  )
@@ -682,7 +791,12 @@ class IterableCombinedDataset(torch.utils.data.IterableDataset, torch.distribute
682
 
683
  # TODO(aryan): maybe write a test for this
684
  def initialize_dataset(
685
- dataset_name_or_root: str, dataset_type: str = "video", streaming: bool = True, infinite: bool = False
 
 
 
 
 
686
  ) -> torch.utils.data.IterableDataset:
687
  assert dataset_type in ["image", "video"]
688
 
@@ -692,7 +806,7 @@ def initialize_dataset(
692
  does_repo_exist_on_hub = False
693
 
694
  if does_repo_exist_on_hub:
695
- return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite)
696
  else:
697
  return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
698
 
@@ -745,14 +859,33 @@ def _initialize_local_dataset(dataset_name_or_root: str, dataset_type: str, infi
745
  return dataset
746
 
747
 
748
- def _initialize_hub_dataset(dataset_name: str, dataset_type: str, infinite: bool = False):
 
 
749
  repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
750
  if _has_data_caption_file_pairs(repo_file_list, remote=True):
751
  return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
752
  elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
753
  return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
754
- else:
755
- return _initialize_webdataset(dataset_name, dataset_type, infinite)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
756
 
757
 
758
  def _initialize_data_caption_file_dataset_from_hub(
@@ -778,13 +911,14 @@ def _initialize_data_file_caption_file_dataset_from_hub(
778
 
779
 
780
  def _initialize_webdataset(
781
- dataset_name: str, dataset_type: str, infinite: bool = False
782
  ) -> torch.utils.data.IterableDataset:
783
  logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
 
784
  if dataset_type == "image":
785
- return ImageWebDataset(dataset_name, infinite=infinite)
786
  else:
787
- return VideoWebDataset(dataset_name, infinite=infinite)
788
 
789
 
790
  def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
 
29
  logger = get_logger()
30
 
31
 
32
+ # fmt: off
33
  MAX_PRECOMPUTABLE_ITEMS_LIMIT = 1024
34
  COMMON_CAPTION_FILES = ["prompt.txt", "prompts.txt", "caption.txt", "captions.txt"]
35
  COMMON_VIDEO_FILES = ["video.txt", "videos.txt"]
36
  COMMON_IMAGE_FILES = ["image.txt", "images.txt"]
37
+ COMMON_WDS_CAPTION_COLUMN_NAMES = ["txt", "text", "caption", "captions", "short_caption", "long_caption", "prompt", "prompts", "short_prompt", "long_prompt", "description", "descriptions", "alt_text", "alt_texts", "alt_caption", "alt_captions", "alt_prompt", "alt_prompts", "alt_description", "alt_descriptions", "image_description", "image_descriptions", "image_caption", "image_captions", "image_prompt", "image_prompts", "image_alt_text", "image_alt_texts", "image_alt_caption", "image_alt_captions", "image_alt_prompt", "image_alt_prompts", "image_alt_description", "image_alt_descriptions", "video_description", "video_descriptions", "video_caption", "video_captions", "video_prompt", "video_prompts", "video_alt_text", "video_alt_texts", "video_alt_caption", "video_alt_captions", "video_alt_prompt", "video_alt_prompts", "video_alt_description"]
38
+ # fmt: on
39
 
40
 
41
  class ImageCaptionFilePairDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
 
423
 
424
 
425
  class ImageWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
426
+ def __init__(
427
+ self,
428
+ dataset_name: str,
429
+ infinite: bool = False,
430
+ column_names: Union[str, List[str]] = "__auto__",
431
+ weights: Dict[str, float] = -1,
432
+ **kwargs,
433
+ ) -> None:
434
  super().__init__()
435
 
436
+ assert weights == -1 or isinstance(
437
+ weights, dict
438
+ ), "`weights` must be a dictionary of probabilities for each caption column"
439
+
440
  self.dataset_name = dataset_name
441
  self.infinite = infinite
442
 
443
  data = datasets.load_dataset(dataset_name, split="train", streaming=True)
444
+
445
+ if column_names == "__auto__":
446
+ if weights == -1:
447
+ caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
448
+ if len(caption_columns) == 0:
449
+ raise ValueError(
450
+ f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
451
+ )
452
+ weights = [1] * len(caption_columns)
453
+ else:
454
+ caption_columns = list(weights.keys())
455
+ weights = list(weights.values())
456
+ if not all(column in data.column_names for column in caption_columns):
457
+ raise ValueError(
458
+ f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
459
+ )
460
+ else:
461
+ if isinstance(column_names, str):
462
+ if column_names not in data.column_names:
463
+ raise ValueError(
464
+ f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
465
+ )
466
+ caption_columns = [column_names]
467
+ weights = [1] if weights == -1 else [weights.get(column_names)]
468
+ elif isinstance(column_names, list):
469
+ if not all(column in data.column_names for column in column_names):
470
+ raise ValueError(
471
+ f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
472
+ )
473
+ caption_columns = column_names
474
+ weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
475
+ else:
476
+ raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
477
+
478
+ for column_names in constants.SUPPORTED_IMAGE_FILE_EXTENSIONS:
479
+ if column_names in data.column_names:
480
+ data = data.cast_column(column_names, datasets.Image(mode="RGB"))
481
+ data = data.rename_column(column_names, "image")
482
+ break
483
 
484
  self._data = data
485
  self._sample_index = 0
486
  self._precomputable_once = False
487
+ self._caption_columns = caption_columns
488
+ self._weights = weights
489
 
490
  def _get_data_iter(self):
491
  if self._sample_index == 0:
 
496
  while True:
497
  for sample in self._get_data_iter():
498
  self._sample_index += 1
499
+ caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
500
+ sample["caption"] = sample[caption_column]
501
+ sample["image"] = _preprocess_image(sample["image"])
502
  yield sample
503
 
504
  if not self.infinite:
 
517
 
518
 
519
  class VideoWebDataset(torch.utils.data.IterableDataset, torch.distributed.checkpoint.stateful.Stateful):
520
+ def __init__(
521
+ self,
522
+ dataset_name: str,
523
+ infinite: bool = False,
524
+ column_names: Union[str, List[str]] = "__auto__",
525
+ weights: Dict[str, float] = -1,
526
+ **kwargs,
527
+ ) -> None:
528
  super().__init__()
529
 
530
+ assert weights == -1 or isinstance(
531
+ weights, dict
532
+ ), "`weights` must be a dictionary of probabilities for each caption column"
533
+
534
  self.dataset_name = dataset_name
535
  self.infinite = infinite
536
 
537
  data = datasets.load_dataset(dataset_name, split="train", streaming=True)
538
+
539
+ if column_names == "__auto__":
540
+ if weights == -1:
541
+ caption_columns = [column for column in data.column_names if column in COMMON_WDS_CAPTION_COLUMN_NAMES]
542
+ if len(caption_columns) == 0:
543
+ raise ValueError(
544
+ f"No common caption column found in the dataset. Supported columns are: {COMMON_WDS_CAPTION_COLUMN_NAMES}"
545
+ )
546
+ weights = [1] * len(caption_columns)
547
+ else:
548
+ caption_columns = list(weights.keys())
549
+ weights = list(weights.values())
550
+ if not all(column in data.column_names for column in caption_columns):
551
+ raise ValueError(
552
+ f"Caption columns {caption_columns} not found in the dataset. Available columns are: {data.column_names}"
553
+ )
554
+ else:
555
+ if isinstance(column_names, str):
556
+ if column_names not in data.column_names:
557
+ raise ValueError(
558
+ f"Caption column {column_names} not found in the dataset. Available columns are: {data.column_names}"
559
+ )
560
+ caption_columns = [column_names]
561
+ weights = [1] if weights == -1 else [weights.get(column_names)]
562
+ elif isinstance(column_names, list):
563
+ if not all(column in data.column_names for column in column_names):
564
+ raise ValueError(
565
+ f"Caption columns {column_names} not found in the dataset. Available columns are: {data.column_names}"
566
+ )
567
+ caption_columns = column_names
568
+ weights = [1] if weights == -1 else [weights.get(column) for column in column_names]
569
+ else:
570
+ raise ValueError(f"Unsupported type for column_name: {type(column_names)}")
571
+
572
+ for column_names in constants.SUPPORTED_VIDEO_FILE_EXTENSIONS:
573
+ if column_names in data.column_names:
574
+ data = data.cast_column(column_names, datasets.Video())
575
+ data = data.rename_column(column_names, "video")
576
+ break
577
 
578
  self._data = data
579
  self._sample_index = 0
580
  self._precomputable_once = False
581
+ self._caption_columns = caption_columns
582
+ self._weights = weights
583
 
584
  def _get_data_iter(self):
585
  if self._sample_index == 0:
 
590
  while True:
591
  for sample in self._get_data_iter():
592
  self._sample_index += 1
593
+ caption_column = random.choices(self._caption_columns, weights=self._weights, k=1)[0]
594
+ sample["caption"] = sample[caption_column]
595
+ sample["video"] = _preprocess_video(sample["video"])
596
  yield sample
597
 
598
  if not self.infinite:
 
703
  for sample in iter(self.dataset):
704
  if self.dataset_type == "image":
705
  if self.image_resolution_buckets:
706
+ sample["_original_num_frames"] = 1
707
+ sample["_original_height"] = sample["image"].size(1)
708
+ sample["_original_width"] = sample["image"].size(2)
709
  sample["image"] = FF.resize_to_nearest_bucket_image(
710
  sample["image"], self.image_resolution_buckets, self.reshape_mode
711
  )
712
  elif self.dataset_type == "video":
713
  if self.video_resolution_buckets:
714
+ sample["_original_num_frames"] = sample["video"].size(0)
715
+ sample["_original_height"] = sample["video"].size(2)
716
+ sample["_original_width"] = sample["video"].size(3)
717
  sample["video"], _first_frame_only = FF.resize_to_nearest_bucket_video(
718
  sample["video"], self.video_resolution_buckets, self.reshape_mode
719
  )
 
791
 
792
  # TODO(aryan): maybe write a test for this
793
  def initialize_dataset(
794
+ dataset_name_or_root: str,
795
+ dataset_type: str = "video",
796
+ streaming: bool = True,
797
+ infinite: bool = False,
798
+ *,
799
+ _caption_options: Optional[Dict[str, Any]] = None,
800
  ) -> torch.utils.data.IterableDataset:
801
  assert dataset_type in ["image", "video"]
802
 
 
806
  does_repo_exist_on_hub = False
807
 
808
  if does_repo_exist_on_hub:
809
+ return _initialize_hub_dataset(dataset_name_or_root, dataset_type, infinite, _caption_options=_caption_options)
810
  else:
811
  return _initialize_local_dataset(dataset_name_or_root, dataset_type, infinite)
812
 
 
859
  return dataset
860
 
861
 
862
+ def _initialize_hub_dataset(
863
+ dataset_name: str, dataset_type: str, infinite: bool = False, *, _caption_options: Optional[Dict[str, Any]] = None
864
+ ):
865
  repo_file_list = list_repo_files(dataset_name, repo_type="dataset")
866
  if _has_data_caption_file_pairs(repo_file_list, remote=True):
867
  return _initialize_data_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
868
  elif _has_data_file_caption_file_lists(repo_file_list, remote=True):
869
  return _initialize_data_file_caption_file_dataset_from_hub(dataset_name, dataset_type, infinite)
870
+
871
+ has_tar_files = any(file.endswith(".tar") or file.endswith(".parquet") for file in repo_file_list)
872
+ if has_tar_files:
873
+ return _initialize_webdataset(dataset_name, dataset_type, infinite, _caption_options=_caption_options)
874
+
875
+ # TODO(aryan): This should be improved
876
+ caption_files = [pathlib.Path(file).name for file in repo_file_list if file.endswith(".txt")]
877
+ if len(caption_files) < MAX_PRECOMPUTABLE_ITEMS_LIMIT:
878
+ try:
879
+ dataset_root = snapshot_download(dataset_name, repo_type="dataset")
880
+ if dataset_type == "image":
881
+ dataset = ImageFolderDataset(dataset_root, infinite=infinite)
882
+ else:
883
+ dataset = VideoFolderDataset(dataset_root, infinite=infinite)
884
+ return dataset
885
+ except Exception:
886
+ pass
887
+
888
+ raise ValueError(f"Could not load dataset {dataset_name} from the HF Hub")
889
 
890
 
891
  def _initialize_data_caption_file_dataset_from_hub(
 
911
 
912
 
913
  def _initialize_webdataset(
914
+ dataset_name: str, dataset_type: str, infinite: bool = False, _caption_options: Optional[Dict[str, Any]] = None
915
  ) -> torch.utils.data.IterableDataset:
916
  logger.info(f"Streaming webdataset {dataset_name} from the HF Hub")
917
+ _caption_options = _caption_options or {}
918
  if dataset_type == "image":
919
+ return ImageWebDataset(dataset_name, infinite=infinite, **_caption_options)
920
  else:
921
+ return VideoWebDataset(dataset_name, infinite=infinite, **_caption_options)
922
 
923
 
924
  def _has_data_caption_file_pairs(root: Union[pathlib.Path, List[str]], remote: bool = False) -> bool:
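The caption handling added above lets a dataset expose several caption columns and sample one per example according to user-provided weights. A minimal, self-contained sketch of that sampling behavior (the column names and weights are hypothetical, not the finetrainers API):

    import random

    # Hypothetical example: two caption columns sampled at a 3:1 ratio.
    caption_columns = ["short_caption", "long_caption"]
    weights = [3, 1]

    sample = {"short_caption": "a cat", "long_caption": "a fluffy orange cat on a sofa"}

    # Pick one caption column per sample, proportionally to its weight, and expose it
    # under the unified "caption" key, mirroring the iterator changes above.
    chosen = random.choices(caption_columns, weights=weights, k=1)[0]
    sample["caption"] = sample[chosen]
    print(sample["caption"])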
finetrainers/data/precomputation.py CHANGED
@@ -1,13 +1,132 @@
1
  import pathlib
2
- from typing import Any, Callable, Dict, Iterable, Optional
3
 
4
  import torch
5
  from tqdm.auto import tqdm
6
 
7
  from .. import utils
 
8
 
9
 
10
- class DistributedDataPreprocessor:
11
  def __init__(
12
  self,
13
  rank: int,
@@ -15,13 +134,15 @@ class DistributedDataPreprocessor:
15
  processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
16
  save_dir: str,
17
  ) -> None:
 
 
18
  self._rank = rank
19
  self._num_items = num_items
20
  self._processor_fn = processor_fn
21
  self._save_dir = pathlib.Path(save_dir)
22
 
23
  self._cached_samples = []
24
- self._preprocessed_iterator: "PreprocessedDataIterable" = None
25
 
26
  self._save_dir.mkdir(parents=True, exist_ok=True)
27
 
@@ -59,9 +180,8 @@ class DistributedDataPreprocessor:
59
  if drop_samples:
60
  del self._cached_samples
61
  self._cached_samples = []
62
- utils.free_memory()
63
 
64
- self._preprocessed_iterator = PreprocessedDataIterable(self._rank, self._save_dir, data_type)
65
  return iter(self._preprocessed_iterator)
66
 
67
  def consume_once(
@@ -95,9 +215,8 @@ class DistributedDataPreprocessor:
95
  if drop_samples:
96
  del self._cached_samples
97
  self._cached_samples = []
98
- utils.free_memory()
99
 
100
- self._preprocessed_iterator = PreprocessedOnceDataIterable(self._rank, self._save_dir, data_type)
101
  return iter(self._preprocessed_iterator)
102
 
103
  @property
@@ -107,7 +226,70 @@ class DistributedDataPreprocessor:
107
  return self._preprocessed_iterator.requires_data
108
 
109
 
110
- class PreprocessedDataIterable:
111
  def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
112
  self._rank = rank
113
  self._save_dir = pathlib.Path(save_dir)
@@ -130,7 +312,13 @@ class PreprocessedDataIterable:
130
  return self._requires_data
131
 
132
 
133
- class PreprocessedOnceDataIterable:
 
 
 
 
 
 
134
  def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
135
  self._rank = rank
136
  self._save_dir = pathlib.Path(save_dir)
@@ -153,6 +341,31 @@ class PreprocessedOnceDataIterable:
153
  return self._requires_data
154
 
155
 
156
  def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
157
  filename = directory / f"{data_type}-{rank}-{index}.pt"
158
  torch.save(item, filename.as_posix())
 
1
  import pathlib
2
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Union
3
 
4
  import torch
5
  from tqdm.auto import tqdm
6
 
7
  from .. import utils
8
+ from ..logging import get_logger
9
 
10
 
11
+ logger = get_logger()
12
+
13
+
14
+ def initialize_preprocessor(
15
+ rank: int,
16
+ num_items: int,
17
+ processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
18
+ save_dir: Optional[str] = None,
19
+ enable_precomputation: bool = False,
20
+ ) -> Union["InMemoryDistributedDataPreprocessor", "PrecomputedDistributedDataPreprocessor"]:
21
+ if enable_precomputation:
22
+ return PrecomputedDistributedDataPreprocessor(rank, num_items, processor_fn, save_dir)
23
+ return InMemoryDistributedDataPreprocessor(rank, num_items, processor_fn)
24
+
25
+
26
+ class DistributedDataProcessorMixin:
27
+ def consume(self, *args, **kwargs):
28
+ raise NotImplementedError("DistributedDataProcessorMixin::consume must be implemented by the subclass.")
29
+
30
+ def consume_once(self, *args, **kwargs):
31
+ raise NotImplementedError("DistributedDataProcessorMixin::consume_once must be implemented by the subclass.")
32
+
33
+ @property
34
+ def requires_data(self):
35
+ raise NotImplementedError("DistributedDataProcessorMixin::requires_data must be implemented by the subclass.")
36
+
37
+
38
+ class InMemoryDistributedDataPreprocessor(DistributedDataProcessorMixin):
39
+ def __init__(
40
+ self, rank: int, num_items: int, processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]]
41
+ ) -> None:
42
+ super().__init__()
43
+
44
+ self._rank = rank
45
+ self._num_items = num_items
46
+ self._processor_fn = processor_fn
47
+
48
+ self._cached_samples = []
49
+ self._buffer = InMemoryDataBuffer(num_items)
50
+ self._preprocessed_iterator: Union["InMemoryDataIterable", "InMemoryOnceDataIterable"] = None
51
+
52
+ def consume(
53
+ self,
54
+ data_type: str,
55
+ components: Dict[str, Any],
56
+ data_iterator,
57
+ generator: Optional[torch.Generator] = None,
58
+ cache_samples: bool = False,
59
+ use_cached_samples: bool = False,
60
+ drop_samples: bool = False,
61
+ ) -> Iterable[Dict[str, Any]]:
62
+ if data_type not in self._processor_fn.keys():
63
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
64
+ if cache_samples:
65
+ if use_cached_samples:
66
+ raise ValueError("Cannot cache and use cached samples at the same time.")
67
+ if drop_samples:
68
+ raise ValueError("Cannot cache and drop samples at the same time.")
69
+
70
+ for i in range(self._num_items):
71
+ if use_cached_samples:
72
+ item = self._cached_samples[i]
73
+ else:
74
+ item = next(data_iterator)
75
+ if cache_samples:
76
+ self._cached_samples.append(item)
77
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
78
+ self._buffer.add(data_type, item)
79
+
80
+ if drop_samples:
81
+ del self._cached_samples
82
+ self._cached_samples = []
83
+
84
+ self._preprocessed_iterator = InMemoryDataIterable(self._rank, data_type, self._buffer)
85
+ return iter(self._preprocessed_iterator)
86
+
87
+ def consume_once(
88
+ self,
89
+ data_type: str,
90
+ components: Dict[str, Any],
91
+ data_iterator,
92
+ generator: Optional[torch.Generator] = None,
93
+ cache_samples: bool = False,
94
+ use_cached_samples: bool = False,
95
+ drop_samples: bool = False,
96
+ ) -> Iterable[Dict[str, Any]]:
97
+ if data_type not in self._processor_fn.keys():
98
+ raise ValueError(f"Invalid data type: {data_type}. Supported types: {list(self._processor_fn.keys())}")
99
+ if cache_samples:
100
+ if use_cached_samples:
101
+ raise ValueError("Cannot cache and use cached samples at the same time.")
102
+ if drop_samples:
103
+ raise ValueError("Cannot cache and drop samples at the same time.")
104
+
105
+ for i in range(self._num_items):
106
+ if use_cached_samples:
107
+ item = self._cached_samples[i]
108
+ else:
109
+ item = next(data_iterator)
110
+ if cache_samples:
111
+ self._cached_samples.append(item)
112
+ item = self._processor_fn[data_type](**item, **components, generator=generator)
113
+ self._buffer.add(data_type, item)
114
+
115
+ if drop_samples:
116
+ del self._cached_samples
117
+ self._cached_samples = []
118
+
119
+ self._preprocessed_iterator = InMemoryOnceDataIterable(self._rank, data_type, self._buffer)
120
+ return iter(self._preprocessed_iterator)
121
+
122
+ @property
123
+ def requires_data(self):
124
+ if self._preprocessed_iterator is None:
125
+ return True
126
+ return self._preprocessed_iterator.requires_data
127
+
128
+
129
+ class PrecomputedDistributedDataPreprocessor(DistributedDataProcessorMixin):
130
  def __init__(
131
  self,
132
  rank: int,
 
134
  processor_fn: Dict[str, Callable[[Dict[str, Any]], Dict[str, Any]]],
135
  save_dir: str,
136
  ) -> None:
137
+ super().__init__()
138
+
139
  self._rank = rank
140
  self._num_items = num_items
141
  self._processor_fn = processor_fn
142
  self._save_dir = pathlib.Path(save_dir)
143
 
144
  self._cached_samples = []
145
+ self._preprocessed_iterator: Union["PrecomputedDataIterable", "PrecomputedOnceDataIterable"] = None
146
 
147
  self._save_dir.mkdir(parents=True, exist_ok=True)
148
 
 
180
  if drop_samples:
181
  del self._cached_samples
182
  self._cached_samples = []
 
183
 
184
+ self._preprocessed_iterator = PrecomputedDataIterable(self._rank, self._save_dir, data_type)
185
  return iter(self._preprocessed_iterator)
186
 
187
  def consume_once(
 
215
  if drop_samples:
216
  del self._cached_samples
217
  self._cached_samples = []
 
218
 
219
+ self._preprocessed_iterator = PrecomputedOnceDataIterable(self._rank, self._save_dir, data_type)
220
  return iter(self._preprocessed_iterator)
221
 
222
  @property
 
226
  return self._preprocessed_iterator.requires_data
227
 
228
 
229
+ class InMemoryDataIterable:
230
+ """
231
+ An iterator that loads data items from an in-memory buffer. Once all the data is consumed,
232
+ `requires_data` is set to True, indicating that more data is required and the preprocessor's
233
+ consume method should be called again.
234
+ """
235
+
236
+ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
237
+ self._rank = rank
238
+ self._data_type = data_type
239
+ self._buffer = buffer
240
+
241
+ self._requires_data = False
242
+
243
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
244
+ while (length := self._buffer.get_length(self._data_type)) > 0:
245
+ if length <= 1:
246
+ self._requires_data = True
247
+ yield self._buffer.get(self._data_type)
248
+
249
+ def __len__(self) -> int:
250
+ return self._buffer.get_length(self._data_type)
251
+
252
+ @property
253
+ def requires_data(self):
254
+ return self._requires_data
255
+
256
+
257
+ class InMemoryOnceDataIterable:
258
+ """
259
+ An iterator that loads data items from an in-memory buffer. This iterator will never set
260
+ `requires_data` to True, as it is assumed that all the data was configured to be preprocessed
261
+ by the user. The data will be cycled from the buffer indefinitely.
262
+ """
263
+
264
+ def __init__(self, rank: int, data_type: str, buffer: "InMemoryDataBuffer") -> None:
265
+ self._rank = rank
266
+ self._data_type = data_type
267
+ self._buffer = buffer
268
+
269
+ self._requires_data = False
270
+
271
+ def __iter__(self) -> Iterable[Dict[str, Any]]:
272
+ assert len(self) > 0, "No data available in the buffer."
273
+ while True:
274
+ item = self._buffer.get(self._data_type)
275
+ yield item
276
+ self._buffer.add(self._data_type, item)
277
+
278
+ def __len__(self) -> int:
279
+ return self._buffer.get_length(self._data_type)
280
+
281
+ @property
282
+ def requires_data(self):
283
+ return self._requires_data
284
+
285
+
286
+ class PrecomputedDataIterable:
287
+ """
288
+ An iterator that loads a preconfigured number of data items from disk. Once all the data is
289
+ loaded, `requires_data` is set to True, indicating that more data is required and
290
+ the preprocessor's consume method should be called again.
291
+ """
292
+
293
  def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
294
  self._rank = rank
295
  self._save_dir = pathlib.Path(save_dir)
 
312
  return self._requires_data
313
 
314
 
315
+ class PrecomputedOnceDataIterable:
316
+ """
317
+ An infinite iterator that loads preprocessed data from disk. Once initialized, this iterator
318
+ will never set `requires_data` to True, as it is assumed that all the data was configured to
319
+ be preprocessed by the user.
320
+ """
321
+
322
  def __init__(self, rank: int, save_dir: str, data_type: str) -> None:
323
  self._rank = rank
324
  self._save_dir = pathlib.Path(save_dir)
 
341
  return self._requires_data
342
 
343
 
344
+ class InMemoryDataBuffer:
345
+ def __init__(self, max_limit: int = -1) -> None:
346
+ self.max_limit = max_limit
347
+ self.buffer: Dict[str, List[Dict[str, Any]]] = {}
348
+
349
+ def add(self, data_type: str, item: Dict[str, Any]) -> None:
350
+ if data_type not in self.buffer:
351
+ self.buffer[data_type] = []
352
+ if self.max_limit != -1 and len(self.buffer[data_type]) >= self.max_limit:
353
+ logger.log_freq(
354
+ "WARN",
355
+ "IN_MEMORY_DATA_BUFFER_FULL",
356
+ "Buffer is full. Dropping the oldest item. This message will be logged every 64th time this happens.",
357
+ 64,
358
+ )
359
+ self.buffer[data_type].pop(0)
360
+ self.buffer[data_type].append(item)
361
+
362
+ def get(self, data_type: str) -> Dict[str, Any]:
363
+ return self.buffer[data_type].pop(0)
364
+
365
+ def get_length(self, data_type: str) -> int:
366
+ return len(self.buffer[data_type])
367
+
368
+
369
  def _save_item(rank: int, index: int, item: Dict[str, Any], directory: pathlib.Path, data_type: str) -> None:
370
  filename = directory / f"{data_type}-{rank}-{index}.pt"
371
  torch.save(item, filename.as_posix())
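The in-memory path added above revolves around a small FIFO buffer keyed by data type, selected via initialize_preprocessor when precomputation is disabled; when the buffer is full, the oldest item is dropped. A rough standalone sketch of that buffer behavior (simplified, no logging; names shortened and not part of the library):

    from typing import Any, Dict, List

    class TinyBuffer:
        """FIFO buffer keyed by data type; drops the oldest item when full."""
        def __init__(self, max_limit: int = -1) -> None:
            self.max_limit = max_limit
            self.buffer: Dict[str, List[Any]] = {}

        def add(self, data_type: str, item: Any) -> None:
            queue = self.buffer.setdefault(data_type, [])
            if self.max_limit != -1 and len(queue) >= self.max_limit:
                queue.pop(0)  # drop the oldest item
            queue.append(item)

        def get(self, data_type: str) -> Any:
            return self.buffer[data_type].pop(0)

    buf = TinyBuffer(max_limit=2)
    for i in range(3):
        buf.add("latent", {"step": i})
    print(buf.get("latent"))  # {'step': 1} -- item 0 was dropped when the buffer filled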
finetrainers/functional/image.py CHANGED
@@ -22,7 +22,7 @@ def resize_crop_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tenso
22
 
23
 
24
  def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
25
- return F.interpolate(image, size=size, mode="bicubic", align_corners=False)
26
 
27
 
28
  def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
 
22
 
23
 
24
  def bicubic_resize_image(image: torch.Tensor, size: Tuple[int, int]) -> torch.Tensor:
25
+ return F.interpolate(image.unsqueeze(0), size=size, mode="bicubic", align_corners=False)[0]
26
 
27
 
28
  def find_nearest_resolution_image(image: torch.Tensor, resolution_buckets: List[Tuple[int, int]]) -> Tuple[int, int]:
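The one-line change to bicubic_resize_image adds the batch dimension that torch.nn.functional.interpolate requires: bicubic interpolation in PyTorch only supports 4D (N, C, H, W) input, so calling it on a single 3D (C, H, W) image raises an error. A small sketch of the fixed call (assuming a CHW input tensor, as the new code does):

    import torch
    import torch.nn.functional as F

    image = torch.rand(3, 480, 640)  # (C, H, W), no batch dimension

    # mode="bicubic" needs 4D input, so the image is unsqueezed to (1, C, H, W)
    # before interpolation and the batch dimension is indexed away afterwards.
    resized = F.interpolate(image.unsqueeze(0), size=(256, 256), mode="bicubic", align_corners=False)[0]
    print(resized.shape)  # torch.Size([3, 256, 256])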
finetrainers/models/cogvideox/base_specification.py CHANGED
@@ -105,7 +105,7 @@ class CogVideoXModelSpecification(ModelSpecification):
105
  )
106
 
107
  if condition_model_processors is None:
108
- condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
109
  if latent_model_processors is None:
110
  latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
111
 
@@ -337,7 +337,6 @@ class CogVideoXModelSpecification(ModelSpecification):
337
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
338
  latent_model_conditions["image_rotary_emb"] = image_rotary_emb
339
  latent_model_conditions["ofs"] = ofs_emb
340
- condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
341
 
342
  velocity = transformer(
343
  **latent_model_conditions,
 
105
  )
106
 
107
  if condition_model_processors is None:
108
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
109
  if latent_model_processors is None:
110
  latent_model_processors = [CogVideoXLatentEncodeProcessor(["latents"])]
111
 
 
337
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
338
  latent_model_conditions["image_rotary_emb"] = image_rotary_emb
339
  latent_model_conditions["ofs"] = ofs_emb
 
340
 
341
  velocity = transformer(
342
  **latent_model_conditions,
finetrainers/models/cogview4/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .base_specification import CogView4ModelSpecification
finetrainers/models/cogview4/base_specification.py ADDED
@@ -0,0 +1,395 @@
1
+ import os
2
+ from typing import Any, Dict, List, Optional, Tuple
3
+
4
+ import torch
5
+ from accelerate import init_empty_weights
6
+ from diffusers import (
7
+ AutoencoderKL,
8
+ CogView4Pipeline,
9
+ CogView4Transformer2DModel,
10
+ FlowMatchEulerDiscreteScheduler,
11
+ )
12
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
13
+ from transformers import AutoTokenizer, GlmModel
14
+
15
+ from ... import data
16
+ from ... import functional as FF
17
+ from ...logging import get_logger
18
+ from ...processors import CogView4GLMProcessor, ProcessorMixin
19
+ from ...typing import ArtifactType, SchedulerType
20
+ from ...utils import get_non_null_items
21
+ from ..modeling_utils import ModelSpecification
22
+
23
+
24
+ logger = get_logger()
25
+
26
+
27
+ class CogView4LatentEncodeProcessor(ProcessorMixin):
28
+ r"""
29
+ Processor to encode image/video into latents using the CogView4 VAE.
30
+
31
+ Args:
32
+ output_names (`List[str]`):
33
+ The names of the outputs that the processor returns. The outputs are in the following order:
34
+ - latents: The latents of the input image/video.
35
+ - original_size: The original size of the input image/video.
36
+ - target_size: The target size of the input image/video.
37
+ - crop_coords: The top-left crop coordinates of the input image/video.
38
+ """
39
+
40
+ def __init__(self, output_names: List[str]):
41
+ super().__init__()
42
+
43
+ self.output_names = output_names
44
+ assert len(self.output_names) == 4
45
+
46
+ def forward(
47
+ self,
48
+ vae: AutoencoderKL,
49
+ image: Optional[torch.Tensor] = None,
50
+ video: Optional[torch.Tensor] = None,
51
+ generator: Optional[torch.Generator] = None,
52
+ compute_posterior: bool = True,
53
+ _original_height: Optional[int] = None,
54
+ _original_width: Optional[int] = None,
55
+ ) -> Dict[str, torch.Tensor]:
56
+ device = vae.device
57
+ dtype = vae.dtype
58
+
59
+ if video is not None:
60
+ # TODO(aryan): perhaps better would be to flatten(0, 1), but need to account for reshaping sigmas accordingly
61
+ image = video[:, 0] # [B, F, C, H, W] -> [B, 1, C, H, W]
62
+
63
+ assert image.ndim == 4, f"Expected 4D tensor, got {image.ndim}D tensor"
64
+ image = image.to(device=device, dtype=vae.dtype)
65
+
66
+ if compute_posterior:
67
+ latents = vae.encode(image).latent_dist.sample(generator=generator)
68
+ latents = latents.to(dtype=dtype)
69
+ else:
70
+ if vae.use_slicing and image.shape[0] > 1:
71
+ encoded_slices = [vae._encode(x_slice) for x_slice in image.split(1)]
72
+ moments = torch.cat(encoded_slices)
73
+ else:
74
+ moments = vae._encode(image)
75
+ latents = moments.to(dtype=dtype)
76
+
77
+ batch_size = latents.size(0)
78
+ target_height = image.size(2)
79
+ target_width = image.size(3)
80
+ original_size = torch.tensor([(_original_height, _original_width)], device=device, dtype=dtype).repeat(
81
+ batch_size, 1
82
+ )
83
+ target_size = torch.tensor([(target_height, target_width)], device=device, dtype=dtype).repeat(batch_size, 1)
84
+ crop_coords = torch.tensor([(0, 0)], device=device, dtype=dtype).repeat(batch_size, 1)
85
+
86
+ return {
87
+ self.output_names[0]: latents,
88
+ self.output_names[1]: original_size,
89
+ self.output_names[2]: target_size,
90
+ self.output_names[3]: crop_coords,
91
+ }
92
+
93
+
94
+ class CogView4ModelSpecification(ModelSpecification):
95
+ def __init__(
96
+ self,
97
+ pretrained_model_name_or_path: str = "THUDM/CogView4-6B",
98
+ tokenizer_id: Optional[str] = None,
99
+ text_encoder_id: Optional[str] = None,
100
+ transformer_id: Optional[str] = None,
101
+ vae_id: Optional[str] = None,
102
+ text_encoder_dtype: torch.dtype = torch.bfloat16,
103
+ transformer_dtype: torch.dtype = torch.bfloat16,
104
+ vae_dtype: torch.dtype = torch.bfloat16,
105
+ revision: Optional[str] = None,
106
+ cache_dir: Optional[str] = None,
107
+ condition_model_processors: List[ProcessorMixin] = None,
108
+ latent_model_processors: List[ProcessorMixin] = None,
109
+ **kwargs,
110
+ ) -> None:
111
+ super().__init__(
112
+ pretrained_model_name_or_path=pretrained_model_name_or_path,
113
+ tokenizer_id=tokenizer_id,
114
+ text_encoder_id=text_encoder_id,
115
+ transformer_id=transformer_id,
116
+ vae_id=vae_id,
117
+ text_encoder_dtype=text_encoder_dtype,
118
+ transformer_dtype=transformer_dtype,
119
+ vae_dtype=vae_dtype,
120
+ revision=revision,
121
+ cache_dir=cache_dir,
122
+ )
123
+
124
+ if condition_model_processors is None:
125
+ condition_model_processors = [CogView4GLMProcessor(["encoder_hidden_states"])]
126
+ if latent_model_processors is None:
127
+ latent_model_processors = [
128
+ CogView4LatentEncodeProcessor(["latents", "original_size", "target_size", "crop_coords"])
129
+ ]
130
+
131
+ self.condition_model_processors = condition_model_processors
132
+ self.latent_model_processors = latent_model_processors
133
+
134
+ @property
135
+ def _resolution_dim_keys(self):
136
+ return {"latents": (2, 3)}
137
+
138
+ def load_condition_models(self) -> Dict[str, torch.nn.Module]:
139
+ if self.tokenizer_id is not None:
140
+ tokenizer = AutoTokenizer.from_pretrained(
141
+ self.tokenizer_id, revision=self.revision, cache_dir=self.cache_dir
142
+ )
143
+ else:
144
+ tokenizer = AutoTokenizer.from_pretrained(
145
+ self.pretrained_model_name_or_path,
146
+ subfolder="tokenizer",
147
+ revision=self.revision,
148
+ cache_dir=self.cache_dir,
149
+ )
150
+
151
+ if self.text_encoder_id is not None:
152
+ text_encoder = GlmModel.from_pretrained(
153
+ self.text_encoder_id,
154
+ torch_dtype=self.text_encoder_dtype,
155
+ revision=self.revision,
156
+ cache_dir=self.cache_dir,
157
+ )
158
+ else:
159
+ text_encoder = GlmModel.from_pretrained(
160
+ self.pretrained_model_name_or_path,
161
+ subfolder="text_encoder",
162
+ torch_dtype=self.text_encoder_dtype,
163
+ revision=self.revision,
164
+ cache_dir=self.cache_dir,
165
+ )
166
+
167
+ return {"tokenizer": tokenizer, "text_encoder": text_encoder}
168
+
169
+ def load_latent_models(self) -> Dict[str, torch.nn.Module]:
170
+ if self.vae_id is not None:
171
+ vae = AutoencoderKL.from_pretrained(
172
+ self.vae_id,
173
+ torch_dtype=self.vae_dtype,
174
+ revision=self.revision,
175
+ cache_dir=self.cache_dir,
176
+ )
177
+ else:
178
+ vae = AutoencoderKL.from_pretrained(
179
+ self.pretrained_model_name_or_path,
180
+ subfolder="vae",
181
+ torch_dtype=self.vae_dtype,
182
+ revision=self.revision,
183
+ cache_dir=self.cache_dir,
184
+ )
185
+
186
+ return {"vae": vae}
187
+
188
+ def load_diffusion_models(self) -> Dict[str, torch.nn.Module]:
189
+ if self.transformer_id is not None:
190
+ transformer = CogView4Transformer2DModel.from_pretrained(
191
+ self.transformer_id,
192
+ torch_dtype=self.transformer_dtype,
193
+ revision=self.revision,
194
+ cache_dir=self.cache_dir,
195
+ )
196
+ else:
197
+ transformer = CogView4Transformer2DModel.from_pretrained(
198
+ self.pretrained_model_name_or_path,
199
+ subfolder="transformer",
200
+ torch_dtype=self.transformer_dtype,
201
+ revision=self.revision,
202
+ cache_dir=self.cache_dir,
203
+ )
204
+
205
+ scheduler = FlowMatchEulerDiscreteScheduler()
206
+
207
+ return {"transformer": transformer, "scheduler": scheduler}
208
+
209
+ def load_pipeline(
210
+ self,
211
+ tokenizer: Optional[AutoTokenizer] = None,
212
+ text_encoder: Optional[GlmModel] = None,
213
+ transformer: Optional[CogView4Transformer2DModel] = None,
214
+ vae: Optional[AutoencoderKL] = None,
215
+ scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
216
+ enable_slicing: bool = False,
217
+ enable_tiling: bool = False,
218
+ enable_model_cpu_offload: bool = False,
219
+ training: bool = False,
220
+ **kwargs,
221
+ ) -> CogView4Pipeline:
222
+ components = {
223
+ "tokenizer": tokenizer,
224
+ "text_encoder": text_encoder,
225
+ "transformer": transformer,
226
+ "vae": vae,
227
+ # Load the scheduler based on CogView4's config instead of using the default initialization being used for training
228
+ # "scheduler": scheduler,
229
+ }
230
+ components = get_non_null_items(components)
231
+
232
+ pipe = CogView4Pipeline.from_pretrained(
233
+ self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
234
+ )
235
+ pipe.text_encoder.to(self.text_encoder_dtype)
236
+ pipe.vae.to(self.vae_dtype)
237
+
238
+ if not training:
239
+ pipe.transformer.to(self.transformer_dtype)
240
+
241
+ if enable_slicing:
242
+ pipe.vae.enable_slicing()
243
+ if enable_tiling:
244
+ pipe.vae.enable_tiling()
245
+ if enable_model_cpu_offload:
246
+ pipe.enable_model_cpu_offload()
247
+
248
+ return pipe
249
+
250
+ @torch.no_grad()
251
+ def prepare_conditions(
252
+ self,
253
+ tokenizer: AutoTokenizer,
254
+ text_encoder: GlmModel,
255
+ caption: str,
256
+ max_sequence_length: int = 1024,
257
+ **kwargs,
258
+ ) -> Dict[str, Any]:
259
+ conditions = {
260
+ "tokenizer": tokenizer,
261
+ "text_encoder": text_encoder,
262
+ "caption": caption,
263
+ "max_sequence_length": max_sequence_length,
264
+ **kwargs,
265
+ }
266
+ input_keys = set(conditions.keys())
267
+ conditions = super().prepare_conditions(**conditions)
268
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
269
+ return conditions
270
+
271
+ @torch.no_grad()
272
+ def prepare_latents(
273
+ self,
274
+ vae: AutoencoderKL,
275
+ image: Optional[torch.Tensor] = None,
276
+ video: Optional[torch.Tensor] = None,
277
+ generator: Optional[torch.Generator] = None,
278
+ compute_posterior: bool = True,
279
+ _original_height: Optional[int] = None,
280
+ _original_width: Optional[int] = None,
281
+ **kwargs,
282
+ ) -> Dict[str, torch.Tensor]:
283
+ conditions = {
284
+ "vae": vae,
285
+ "image": image,
286
+ "video": video,
287
+ "generator": generator,
288
+ "compute_posterior": compute_posterior,
289
+ "_original_height": _original_height,
290
+ "_original_width": _original_width,
291
+ **kwargs,
292
+ }
293
+ input_keys = set(conditions.keys())
294
+ conditions = super().prepare_latents(**conditions)
295
+ conditions = {k: v for k, v in conditions.items() if k not in input_keys}
296
+ return conditions
297
+
298
+ def forward(
299
+ self,
300
+ transformer: CogView4Transformer2DModel,
301
+ condition_model_conditions: Dict[str, torch.Tensor],
302
+ latent_model_conditions: Dict[str, torch.Tensor],
303
+ sigmas: torch.Tensor,
304
+ generator: Optional[torch.Generator] = None,
305
+ compute_posterior: bool = True,
306
+ **kwargs,
307
+ ) -> Tuple[torch.Tensor, ...]:
308
+ if compute_posterior:
309
+ latents = latent_model_conditions.pop("latents")
310
+ else:
311
+ posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
312
+ latents = posterior.sample(generator=generator)
313
+ del posterior
314
+
315
+ latents = (latents - self.vae_config.shift_factor) * self.vae_config.scaling_factor
316
+ noise = torch.zeros_like(latents).normal_(generator=generator)
317
+ timesteps = (sigmas.flatten() * 1000.0).long()
318
+
319
+ base_image_sequence_length = 256
320
+ base_shift = 0.25
321
+ max_shift = 0.75
322
+
323
+ image_sequence_length = latents.size(2) * latents.size(3) // self.transformer_config.patch_size**2
324
+ mu = (image_sequence_length / base_image_sequence_length) ** 0.5
325
+ mu = mu * max_shift + base_shift
326
+ shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
327
+ noisy_latents = FF.flow_match_xt(latents, noise, shifted_sigmas)
328
+
329
+ latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
330
+
331
+ pred = transformer(
332
+ **latent_model_conditions,
333
+ **condition_model_conditions,
334
+ timestep=timesteps,
335
+ return_dict=False,
336
+ )[0]
337
+ target = FF.flow_match_target(noise, latents)
338
+
339
+ # NOTE: shifted_sigmas loss weighting seems to work better than sigmas. Needs more investigation
340
+ # but let's keep it this way for now. Longer training runs should reveal more insights.
341
+ # return pred, target, sigmas
342
+ return pred, target, shifted_sigmas
343
+
344
+ def validation(
345
+ self,
346
+ pipeline: CogView4Pipeline,
347
+ prompt: str,
348
+ height: Optional[int] = None,
349
+ width: Optional[int] = None,
350
+ num_inference_steps: int = 50,
351
+ generator: Optional[torch.Generator] = None,
352
+ **kwargs,
353
+ ) -> List[ArtifactType]:
354
+ generation_kwargs = {
355
+ "prompt": prompt,
356
+ "height": height,
357
+ "width": width,
358
+ "num_inference_steps": num_inference_steps,
359
+ "generator": generator,
360
+ "return_dict": True,
361
+ "output_type": "pil",
362
+ }
363
+ generation_kwargs = get_non_null_items(generation_kwargs)
364
+ image = pipeline(**generation_kwargs).images[0]
365
+ return [data.ImageArtifact(value=image)]
366
+
367
+ def _save_lora_weights(
368
+ self,
369
+ directory: str,
370
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
371
+ scheduler: Optional[SchedulerType] = None,
372
+ *args,
373
+ **kwargs,
374
+ ) -> None:
375
+ # TODO(aryan): this needs refactoring
376
+ if transformer_state_dict is not None:
377
+ CogView4Pipeline.save_lora_weights(directory, transformer_state_dict, safe_serialization=True)
378
+ if scheduler is not None:
379
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
380
+
381
+ def _save_model(
382
+ self,
383
+ directory: str,
384
+ transformer: CogView4Transformer2DModel,
385
+ transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
386
+ scheduler: Optional[SchedulerType] = None,
387
+ ) -> None:
388
+ # TODO(aryan): this needs refactoring
389
+ if transformer_state_dict is not None:
390
+ with init_empty_weights():
391
+ transformer_copy = CogView4Transformer2DModel.from_config(transformer.config)
392
+ transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
393
+ transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
394
+ if scheduler is not None:
395
+ scheduler.save_pretrained(os.path.join(directory, "scheduler"))
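The CogView4 forward pass above applies a resolution-dependent shift to the flow-matching sigmas before noising the latents: the shift grows with the image token count relative to a base sequence length of 256. A numeric sketch of that computation (the constants mirror the code above; the resolution example assumes an 8x VAE downsample and patch size 2, which are illustrative):

    import torch

    base_image_sequence_length = 256
    base_shift, max_shift = 0.25, 0.75

    # e.g. 1024x1024 pixels -> 128x128 latents -> (128 // 2) ** 2 = 4096 image tokens
    image_sequence_length = 4096
    sigmas = torch.tensor([0.25, 0.5, 0.75])

    mu = (image_sequence_length / base_image_sequence_length) ** 0.5  # 4.0
    mu = mu * max_shift + base_shift                                  # 3.25
    shifted_sigmas = mu / (mu + (1 / sigmas - 1) ** 1.0)
    print(shifted_sigmas)  # larger resolutions push the sigmas toward 1 (more noise)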
finetrainers/models/hunyuan_video/base_specification.py CHANGED
@@ -117,10 +117,7 @@ class HunyuanVideoModelSpecification(ModelSpecification):
117
 
118
  @property
119
  def _resolution_dim_keys(self):
120
- # TODO
121
- return {
122
- "latents": (2, 3, 4),
123
- }
124
 
125
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
126
  if self.tokenizer_id is not None:
 
117
 
118
  @property
119
  def _resolution_dim_keys(self):
120
+ return {"latents": (2, 3, 4)}
 
 
 
121
 
122
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
123
  if self.tokenizer_id is not None:
finetrainers/models/ltx_video/base_specification.py CHANGED
@@ -120,7 +120,7 @@ class LTXVideoModelSpecification(ModelSpecification):
120
  )
121
 
122
  if condition_model_processors is None:
123
- condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
124
  if latent_model_processors is None:
125
  latent_model_processors = [
126
  LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
@@ -131,9 +131,7 @@ class LTXVideoModelSpecification(ModelSpecification):
131
 
132
  @property
133
  def _resolution_dim_keys(self):
134
- return {
135
- "latents": (2, 3, 4),
136
- }
137
 
138
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
139
  if self.tokenizer_id is not None:
@@ -342,8 +340,6 @@ class LTXVideoModelSpecification(ModelSpecification):
342
  sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
343
 
344
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
345
- condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
346
- condition_model_conditions["encoder_attention_mask"] = condition_model_conditions.pop("prompt_attention_mask")
347
 
348
  # TODO(aryan): make this configurable
349
  frame_rate = 25
 
120
  )
121
 
122
  if condition_model_processors is None:
123
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "encoder_attention_mask"])]
124
  if latent_model_processors is None:
125
  latent_model_processors = [
126
  LTXLatentEncodeProcessor(["latents", "num_frames", "height", "width", "latents_mean", "latents_std"])
 
131
 
132
  @property
133
  def _resolution_dim_keys(self):
134
+ return {"latents": (2, 3, 4)}
 
 
135
 
136
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
137
  if self.tokenizer_id is not None:
 
340
  sigmas = sigmas.view(-1, 1, 1).expand(-1, *noisy_latents.shape[1:-1], -1)
341
 
342
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
 
343
 
344
  # TODO(aryan): make this configurable
345
  frame_rate = 25
finetrainers/models/modeling_utils.py CHANGED
@@ -115,9 +115,6 @@ class ModelSpecification:
115
  f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
116
  )
117
 
118
- def collate_fn(self, batch: List[List[Dict[str, torch.Tensor]]]) -> Dict[str, torch.Tensor]:
119
- raise NotImplementedError(f"ModelSpecification::collate_fn is not implemented for {self.__class__.__name__}")
120
-
121
  def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
122
  for processor in self.condition_model_processors:
123
  result = processor(**kwargs)
 
115
  f"ModelSpecification::load_pipeline is not implemented for {self.__class__.__name__}"
116
  )
117
 
 
 
 
118
  def prepare_conditions(self, **kwargs) -> Dict[str, Any]:
119
  for processor in self.condition_model_processors:
120
  result = processor(**kwargs)
finetrainers/models/wan/base_specification.py CHANGED
@@ -34,11 +34,6 @@ class WanLatentEncodeProcessor(ProcessorMixin):
34
  output_names (`List[str]`):
35
  The names of the outputs that the processor returns. The outputs are in the following order:
36
  - latents: The latents of the input image/video.
37
- - num_frames: The number of frames in the input video.
38
- - height: The height of the input image/video.
39
- - width: The width of the input image/video.
40
- - latents_mean: The latent channel means from the VAE state dict.
41
- - latents_std: The latent channel standard deviations from the VAE state dict.
42
  """
43
 
44
  def __init__(self, output_names: List[str]):
@@ -111,7 +106,7 @@ class WanModelSpecification(ModelSpecification):
111
  )
112
 
113
  if condition_model_processors is None:
114
- condition_model_processors = [T5Processor(["prompt_embeds", "prompt_attention_mask"])]
115
  if latent_model_processors is None:
116
  latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
117
 
@@ -120,10 +115,7 @@ class WanModelSpecification(ModelSpecification):
120
 
121
  @property
122
  def _resolution_dim_keys(self):
123
- # TODO
124
- return {
125
- "latents": (2, 3, 4),
126
- }
127
 
128
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
129
  if self.tokenizer_id is not None:
@@ -303,7 +295,6 @@ class WanModelSpecification(ModelSpecification):
303
  noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
304
 
305
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
306
- condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
307
 
308
  timesteps = (sigmas.flatten() * 1000.0).long()
309
 
 
34
  output_names (`List[str]`):
35
  The names of the outputs that the processor returns. The outputs are in the following order:
36
  - latents: The latents of the input image/video.
 
 
 
 
 
37
  """
38
 
39
  def __init__(self, output_names: List[str]):
 
106
  )
107
 
108
  if condition_model_processors is None:
109
+ condition_model_processors = [T5Processor(["encoder_hidden_states", "prompt_attention_mask"])]
110
  if latent_model_processors is None:
111
  latent_model_processors = [WanLatentEncodeProcessor(["latents"])]
112
 
 
115
 
116
  @property
117
  def _resolution_dim_keys(self):
118
+ return {"latents": (2, 3, 4)}
 
 
 
119
 
120
  def load_condition_models(self) -> Dict[str, torch.nn.Module]:
121
  if self.tokenizer_id is not None:
 
295
  noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
296
 
297
  latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
 
298
 
299
  timesteps = (sigmas.flatten() * 1000.0).long()
300
 
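Across the CogVideoX, LTX-Video, and Wan specifications above, the text-encoder processors now emit "encoder_hidden_states" (and, for LTX, "encoder_attention_mask") up front, so the condition dict can be splatted straight into the transformer call instead of being popped and renamed inside forward. A toy illustration of the difference (placeholder values, not the real model call):

    condition_model_conditions = {"encoder_hidden_states": "<T5 embeddings>"}

    # Before: processors emitted "prompt_embeds", so forward() had to rename the key first:
    #     condition_model_conditions["encoder_hidden_states"] = condition_model_conditions.pop("prompt_embeds")
    # After: the processor already uses the name the transformer expects, so the dict
    # passes through unchanged via transformer(**condition_model_conditions, ...).
    def fake_transformer(encoder_hidden_states=None, **kwargs):
        return encoder_hidden_states

    print(fake_transformer(**condition_model_conditions))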
finetrainers/processors/__init__.py CHANGED
@@ -1,5 +1,6 @@
1
  from .base import ProcessorMixin
2
  from .clip import CLIPPooledProcessor
 
3
  from .llama import LlamaProcessor
4
  from .t5 import T5Processor
5
  from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
 
1
  from .base import ProcessorMixin
2
  from .clip import CLIPPooledProcessor
3
+ from .glm import CogView4GLMProcessor
4
  from .llama import LlamaProcessor
5
  from .t5 import T5Processor
6
  from .text import CaptionEmbeddingDropoutProcessor, CaptionTextDropoutProcessor
finetrainers/processors/glm.py ADDED
@@ -0,0 +1,74 @@
1
+ from typing import Dict, List, Union
2
+
3
+ import torch
4
+ from transformers import AutoTokenizer, GlmModel
5
+
6
+ from .base import ProcessorMixin
7
+
8
+
9
+ class CogView4GLMProcessor(ProcessorMixin):
10
+ r"""
11
+ Processor for the GLM family of models. This processor is used to encode text inputs and return the embeddings
12
+ and attention masks for the input text.
13
+
14
+ This processor is specific to CogView4 but can be used with any other model.
15
+
16
+ Args:
17
+ output_names (`List[str]`):
18
+ The names of the outputs that the processor should return. The first output is the embeddings of the input
19
+ text and the second output is the attention mask for the input text.
20
+ """
21
+
22
+ def __init__(self, output_names: List[str]):
23
+ super().__init__()
24
+
25
+ self.output_names = output_names
26
+
27
+ assert len(self.output_names) == 1
28
+
29
+ def forward(
30
+ self,
31
+ tokenizer: AutoTokenizer,
32
+ text_encoder: GlmModel,
33
+ caption: Union[str, List[str]],
34
+ max_sequence_length: int,
35
+ ) -> Dict[str, torch.Tensor]:
36
+ r"""
37
+ Encode the input text and return the embeddings for the input text.
38
+
39
+ Args:
40
+ tokenizer (`AutoTokenizer`):
41
+ The tokenizer used to tokenize the input text.
42
+ text_encoder (`GlmModel`):
43
+ The text encoder used to encode the input text.
44
+ caption (`Union[str, List[str]]`):
45
+ The input text to be encoded.
46
+ max_sequence_length (`int`):
47
+ The maximum sequence length of the input text.
48
+ """
49
+ if isinstance(caption, str):
50
+ caption = [caption]
51
+
52
+ device = text_encoder.device
53
+ dtype = text_encoder.dtype
54
+
55
+ text_inputs = tokenizer(
56
+ caption,
57
+ padding="longest",
58
+ max_length=max_sequence_length,
59
+ truncation=True,
60
+ add_special_tokens=True,
61
+ return_tensors="pt",
62
+ )
63
+ text_input_ids = text_inputs.input_ids.to(device)
64
+
65
+ current_length = text_input_ids.size(1)
66
+ pad_length = 16 - current_length % 16
67
+ if pad_length > 0:
68
+ pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id)
69
+ text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
70
+
71
+ prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
72
+ prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
73
+
74
+ return {self.output_names[0]: prompt_embeds}
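The new GLM processor left-pads the tokenized prompt so its length is a multiple of 16 before encoding. A standalone sketch of just that padding arithmetic (note that, as written above, a length that is already a multiple of 16 gains a full extra block of 16 pad tokens, since pad_length is never reduced modulo 16):

    import torch

    pad_token_id = 0  # hypothetical pad id; the real one comes from the tokenizer

    def pad_to_multiple_of_16(text_input_ids: torch.Tensor) -> torch.Tensor:
        current_length = text_input_ids.size(1)
        pad_length = 16 - current_length % 16  # equals 16 when already aligned (see note above)
        if pad_length > 0:
            pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=pad_token_id)
            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)
        return text_input_ids

    ids = torch.ones(1, 23, dtype=torch.long)
    print(pad_to_multiple_of_16(ids).shape)  # torch.Size([1, 32])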
finetrainers/trainer/sft_trainer/trainer.py CHANGED
@@ -2,6 +2,7 @@ import functools
2
  import json
3
  import math
4
  import os
 
5
  from pathlib import Path
6
  from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
7
 
@@ -33,6 +34,13 @@ logger = logging.get_logger()
33
 
34
 
35
  class SFTTrainer:
 
 
 
 
 
 
 
36
  def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
37
  self.args = args
38
  self.state = State()
@@ -72,6 +80,7 @@ class SFTTrainer:
72
  patches.perform_patches_for_training(self.args, self.state.parallel_backend)
73
 
74
  self.model_specification = model_specification
 
75
 
76
  def run(self) -> None:
77
  try:
@@ -254,12 +263,15 @@ class SFTTrainer:
254
  data_root = config.pop("data_root", None)
255
  dataset_file = config.pop("dataset_file", None)
256
  dataset_type = config.pop("dataset_type")
 
257
 
258
  if data_root is not None and dataset_file is not None:
259
  raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
260
 
261
  dataset_name_or_root = data_root or dataset_file
262
- dataset = data.initialize_dataset(dataset_name_or_root, dataset_type, streaming=True, infinite=True)
 
 
263
 
264
  if not dataset._precomputable_once and self.args.precomputation_once:
265
  raise ValueError(
@@ -369,9 +381,9 @@ class SFTTrainer:
369
  self.transformer.train()
370
  data_iterator = iter(self.dataloader)
371
 
372
- preprocessor = data.DistributedDataPreprocessor(
373
  rank=parallel_backend.rank,
374
- num_items=self.args.precomputation_items,
375
  processor_fn={
376
  "condition": self.model_specification.prepare_conditions,
377
  "latent": functools.partial(
@@ -379,6 +391,7 @@ class SFTTrainer:
379
  ),
380
  },
381
  save_dir=self.args.precomputation_dir,
 
382
  )
383
  precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
384
  precomputed_latent_iterator: Iterable[Dict[str, Any]] = None
@@ -495,7 +508,6 @@ class SFTTrainer:
495
 
496
  if train_state.step % self.args.gradient_accumulation_steps == 0:
497
  # TODO(aryan): revisit no_sync() for FSDP
498
- # TODO(aryan): average the gradients for accumulation?
499
  self.optimizer.step()
500
  self.lr_scheduler.step()
501
  self.optimizer.zero_grad()
@@ -651,28 +663,29 @@ class SFTTrainer:
651
  # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
652
  for index, (key, artifact) in enumerate(list(artifacts.items())):
653
  assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
 
 
654
  filename = "validation-" if not final_validation else "final-"
655
- filename += f"{step}-{parallel_backend.rank}-{index}-{prompt_filename}.{artifact.file_extension}"
656
  output_filename = os.path.join(self.args.output_dir, filename)
657
 
658
  if parallel_backend.is_main_process and artifact.file_extension == "mp4":
659
  main_process_prompts_to_filenames[PROMPT] = filename
660
 
661
- caption = f"{PROMPT} | (filename: {output_filename})"
662
  if artifact.type == "image" and artifact.value is not None:
663
  logger.debug(
664
  f"Saving image from rank={parallel_backend.rank} to {output_filename}",
665
  local_main_process_only=False,
666
  )
667
  artifact.value.save(output_filename)
668
- all_processes_artifacts.append(wandb.Image(output_filename, caption=caption))
669
  elif artifact.type == "video" and artifact.value is not None:
670
  logger.debug(
671
  f"Saving video from rank={parallel_backend.rank} to {output_filename}",
672
  local_main_process_only=False,
673
  )
674
  export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
675
- all_processes_artifacts.append(wandb.Video(output_filename, caption=caption))
676
 
677
  # 3. Cleanup & log artifacts
678
  parallel_backend.wait_for_everyone()
@@ -804,24 +817,16 @@ class SFTTrainer:
804
  component.to(device)
805
 
806
  def _set_components(self, components: Dict[str, Any]) -> None:
807
- # fmt: off
808
- component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
809
- # fmt: on
810
-
811
- for component_name in component_names:
812
  existing_component = getattr(self, component_name, None)
813
  new_component = components.get(component_name, existing_component)
814
  setattr(self, component_name, new_component)
815
 
816
  def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
817
  if component_names is None:
818
- # fmt: off
819
- component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
820
- # fmt: on
821
-
822
  for component_name in component_names:
823
  setattr(self, component_name, None)
824
-
825
  utils.free_memory()
826
  utils.synchronize_device()
827
 
@@ -848,7 +853,6 @@ class SFTTrainer:
848
  training=True,
849
  )
850
  else:
851
- # TODO(aryan): this branch does not work yet, needs to be implemented
852
  self._delete_components()
853
 
854
  # Load the transformer weights from the final checkpoint if performing full-finetune
@@ -874,50 +878,101 @@ class SFTTrainer:
874
  self._move_components_to_device(list(components.values()))
875
  return pipeline
876
 
877
- def _prepare_data(self, preprocessor: data.DistributedDataPreprocessor, data_iterator):
878
- logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
879
- if self.args.precomputation_once:
880
- consume_fn = preprocessor.consume_once
  else:
882
- consume_fn = preprocessor.consume
883
-
884
- condition_components = self.model_specification.load_condition_models()
885
- component_names = list(condition_components.keys())
886
- component_modules = list(condition_components.values())
887
- self._set_components(condition_components)
888
- self._move_components_to_device(component_modules)
889
- precomputed_condition_iterator = consume_fn(
890
- "condition",
891
- components=condition_components,
892
- data_iterator=data_iterator,
893
- generator=self.state.generator,
894
- cache_samples=True,
895
- )
896
- self._delete_components(component_names)
897
- del condition_components, component_names, component_modules
898
-
899
- latent_components = self.model_specification.load_latent_models()
900
- if self.vae is not None:
901
- if self.args.enable_slicing:
902
- self.vae.enable_slicing()
903
- if self.args.enable_tiling:
904
- self.vae.enable_tiling()
905
- component_names = list(latent_components.keys())
906
- component_modules = list(latent_components.values())
907
- self._set_components(latent_components)
908
- self._move_components_to_device(component_modules)
909
- precomputed_latent_iterator = consume_fn(
910
- "latent",
911
- components=latent_components,
912
- data_iterator=data_iterator,
913
- generator=self.state.generator,
914
- use_cached_samples=True,
915
- drop_samples=True,
916
- )
917
- self._delete_components(component_names)
918
- del latent_components, component_names, component_modules
919
 
920
- return precomputed_condition_iterator, precomputed_latent_iterator
921
 
922
  def _get_training_info(self) -> Dict[str, Any]:
923
  info = self.args.to_dict()
 
2
  import json
3
  import math
4
  import os
5
+ import time
6
  from pathlib import Path
7
  from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Union
8
 
 
34
 
35
 
36
  class SFTTrainer:
37
+ # fmt: off
38
+ _all_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3", "transformer", "unet", "vae", "scheduler"]
39
+ _condition_component_names = ["tokenizer", "tokenizer_2", "tokenizer_3", "text_encoder", "text_encoder_2", "text_encoder_3"]
40
+ _latent_component_names = ["vae"]
41
+ _diffusion_component_names = ["transformer", "unet", "scheduler"]
42
+ # fmt: on
43
+
44
  def __init__(self, args: "BaseArgs", model_specification: "ModelSpecification") -> None:
45
  self.args = args
46
  self.state = State()
 
80
  patches.perform_patches_for_training(self.args, self.state.parallel_backend)
81
 
82
  self.model_specification = model_specification
83
+ self._are_condition_models_loaded = False
84
 
85
  def run(self) -> None:
86
  try:
 
263
  data_root = config.pop("data_root", None)
264
  dataset_file = config.pop("dataset_file", None)
265
  dataset_type = config.pop("dataset_type")
266
+ caption_options = config.pop("caption_options", {})
267
 
268
  if data_root is not None and dataset_file is not None:
269
  raise ValueError("Both data_root and dataset_file cannot be provided in the same dataset config.")
270
 
271
  dataset_name_or_root = data_root or dataset_file
272
+ dataset = data.initialize_dataset(
273
+ dataset_name_or_root, dataset_type, streaming=True, infinite=True, _caption_options=caption_options
274
+ )
275
 
276
  if not dataset._precomputable_once and self.args.precomputation_once:
277
  raise ValueError(
 
381
  self.transformer.train()
382
  data_iterator = iter(self.dataloader)
383
 
384
+ preprocessor = data.initialize_preprocessor(
385
  rank=parallel_backend.rank,
386
+ num_items=self.args.precomputation_items if self.args.enable_precomputation else 1,
387
  processor_fn={
388
  "condition": self.model_specification.prepare_conditions,
389
  "latent": functools.partial(
 
391
  ),
392
  },
393
  save_dir=self.args.precomputation_dir,
394
+ enable_precomputation=self.args.enable_precomputation,
395
  )
396
  precomputed_condition_iterator: Iterable[Dict[str, Any]] = None
397
  precomputed_latent_iterator: Iterable[Dict[str, Any]] = None
 
508
 
509
  if train_state.step % self.args.gradient_accumulation_steps == 0:
510
  # TODO(aryan): revisit no_sync() for FSDP
 
511
  self.optimizer.step()
512
  self.lr_scheduler.step()
513
  self.optimizer.zero_grad()
 
663
  # TODO(aryan): Currently, we only support WandB so we've hardcoded it here. Needs to be revisited.
664
  for index, (key, artifact) in enumerate(list(artifacts.items())):
665
  assert isinstance(artifact, (data.ImageArtifact, data.VideoArtifact))
666
+
667
+ time_, rank, ext = int(time.time()), parallel_backend.rank, artifact.file_extension
668
  filename = "validation-" if not final_validation else "final-"
669
+ filename += f"{step}-{rank}-{index}-{prompt_filename}-{time_}.{ext}"
670
  output_filename = os.path.join(self.args.output_dir, filename)
671
 
672
  if parallel_backend.is_main_process and artifact.file_extension == "mp4":
673
  main_process_prompts_to_filenames[PROMPT] = filename
674
 
 
675
  if artifact.type == "image" and artifact.value is not None:
676
  logger.debug(
677
  f"Saving image from rank={parallel_backend.rank} to {output_filename}",
678
  local_main_process_only=False,
679
  )
680
  artifact.value.save(output_filename)
681
+ all_processes_artifacts.append(wandb.Image(output_filename, caption=PROMPT))
682
  elif artifact.type == "video" and artifact.value is not None:
683
  logger.debug(
684
  f"Saving video from rank={parallel_backend.rank} to {output_filename}",
685
  local_main_process_only=False,
686
  )
687
  export_to_video(artifact.value, output_filename, fps=EXPORT_FPS)
688
+ all_processes_artifacts.append(wandb.Video(output_filename, caption=PROMPT))
689
 
690
  # 3. Cleanup & log artifacts
691
  parallel_backend.wait_for_everyone()
 
817
  component.to(device)
818
 
819
  def _set_components(self, components: Dict[str, Any]) -> None:
820
+ for component_name in self._all_component_names:
 
 
 
 
821
  existing_component = getattr(self, component_name, None)
822
  new_component = components.get(component_name, existing_component)
823
  setattr(self, component_name, new_component)
824
 
825
  def _delete_components(self, component_names: Optional[List[str]] = None) -> None:
826
  if component_names is None:
827
+ component_names = self._all_component_names
 
 
 
828
  for component_name in component_names:
829
  setattr(self, component_name, None)
 
830
  utils.free_memory()
831
  utils.synchronize_device()
832
 
 
853
  training=True,
854
  )
855
  else:
 
856
  self._delete_components()
857
 
858
  # Load the transformer weights from the final checkpoint if performing full-finetune
 
878
  self._move_components_to_device(list(components.values()))
879
  return pipeline
880
 
881
+ def _prepare_data(
882
+ self,
883
+ preprocessor: Union[data.InMemoryDistributedDataPreprocessor, data.PrecomputedDistributedDataPreprocessor],
884
+ data_iterator,
885
+ ):
886
+ if not self.args.enable_precomputation:
887
+ if not self._are_condition_models_loaded:
888
+ logger.info(
889
+ "Precomputation disabled. Loading in-memory data loaders. All components will be loaded on GPUs."
890
+ )
891
+ condition_components = self.model_specification.load_condition_models()
892
+ latent_components = self.model_specification.load_latent_models()
893
+ all_components = {**condition_components, **latent_components}
894
+ self._set_components(all_components)
895
+ self._move_components_to_device(list(all_components.values()))
896
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
897
+ else:
898
+ condition_components = {k: v for k in self._condition_component_names if (v := getattr(self, k, None))}
899
+ latent_components = {k: v for k in self._latent_component_names if (v := getattr(self, k, None))}
900
+
901
+ condition_iterator = preprocessor.consume(
902
+ "condition",
903
+ components=condition_components,
904
+ data_iterator=data_iterator,
905
+ generator=self.state.generator,
906
+ cache_samples=True,
907
+ )
908
+ latent_iterator = preprocessor.consume(
909
+ "latent",
910
+ components=latent_components,
911
+ data_iterator=data_iterator,
912
+ generator=self.state.generator,
913
+ use_cached_samples=True,
914
+ drop_samples=True,
915
+ )
916
+
917
+ self._are_condition_models_loaded = True
918
  else:
919
+ logger.info("Precomputed condition & latent data exhausted. Loading & preprocessing new data.")
920
+
921
+ # TODO(aryan): This needs to be revisited. For some reason, the tests did not detect that self.transformer
922
+ # had become None after this but should have been loaded back from the checkpoint.
923
+ # parallel_backend = self.state.parallel_backend
924
+ # train_state = self.state.train_state
925
+ # self.checkpointer.save(
926
+ # train_state.step,
927
+ # force=True,
928
+ # _device=parallel_backend.device,
929
+ # _is_main_process=parallel_backend.is_main_process,
930
+ # )
931
+ # self._delete_components(component_names=["transformer", "unet"])
932
+
933
+ if self.args.precomputation_once:
934
+ consume_fn = preprocessor.consume_once
935
+ else:
936
+ consume_fn = preprocessor.consume
937
+
938
+ # Prepare condition iterators
939
+ condition_components = self.model_specification.load_condition_models()
940
+ component_names = list(condition_components.keys())
941
+ component_modules = list(condition_components.values())
942
+ self._set_components(condition_components)
943
+ self._move_components_to_device(component_modules)
944
+ condition_iterator = consume_fn(
945
+ "condition",
946
+ components=condition_components,
947
+ data_iterator=data_iterator,
948
+ generator=self.state.generator,
949
+ cache_samples=True,
950
+ )
951
+ self._delete_components(component_names)
952
+ del condition_components, component_names, component_modules
953
+
954
+ # Prepare latent iterators
955
+ latent_components = self.model_specification.load_latent_models()
956
+ utils._enable_vae_memory_optimizations(self.vae, self.args.enable_slicing, self.args.enable_tiling)
957
+ component_names = list(latent_components.keys())
958
+ component_modules = list(latent_components.values())
959
+ self._set_components(latent_components)
960
+ self._move_components_to_device(component_modules)
961
+ latent_iterator = consume_fn(
962
+ "latent",
963
+ components=latent_components,
964
+ data_iterator=data_iterator,
965
+ generator=self.state.generator,
966
+ use_cached_samples=True,
967
+ drop_samples=True,
968
+ )
969
+ self._delete_components(component_names)
970
+ del latent_components, component_names, component_modules
971
+
972
+ # self.checkpointer.load()
973
+ # self.transformer = self.checkpointer.states["model"].model[0]
974
 
975
+ return condition_iterator, latent_iterator
976
 
977
  def _get_training_info(self) -> Dict[str, Any]:
978
  info = self.args.to_dict()
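The trainer now forwards an optional caption_options entry from each dataset config into data.initialize_dataset. A hypothetical config fragment showing the shape such an entry might take (shown as a Python dict; the on-disk format and the key names inside caption_options are assumptions, mirroring the weighted caption-column sampling earlier in this commit):

    # Hypothetical dataset config entry -- placeholder dataset id and assumed keys.
    dataset_config = {
        "data_root": "hf-org/my-webdataset",
        "dataset_type": "video",
        "caption_options": {  # forwarded to initialize_dataset as _caption_options
            "column_names": ["short_caption", "long_caption"],
            "weights": {"short_caption": 3, "long_caption": 1},
        },
    }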
finetrainers/utils/__init__.py CHANGED
@@ -4,6 +4,7 @@ from typing import Any, Dict, List, Optional, Set, Tuple, Union
4
  from .activation_checkpoint import apply_activation_checkpointing
5
  from .data import determine_batch_size, should_perform_precomputation
6
  from .diffusion import (
 
7
  default_flow_shift,
8
  get_scheduler_alphas,
9
  get_scheduler_sigmas,
 
4
  from .activation_checkpoint import apply_activation_checkpointing
5
  from .data import determine_batch_size, should_perform_precomputation
6
  from .diffusion import (
7
+ _enable_vae_memory_optimizations,
8
  default_flow_shift,
9
  get_scheduler_alphas,
10
  get_scheduler_sigmas,
finetrainers/utils/diffusion.py CHANGED
@@ -143,3 +143,10 @@ def prepare_target(
143
  raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
144
 
145
  return target
 
 
 
 
 
 
 
 
143
  raise ValueError(f"Unsupported scheduler type {type(scheduler)}")
144
 
145
  return target
146
+
147
+
148
+ def _enable_vae_memory_optimizations(vae, enable_slicing: bool = False, enable_tiling: bool = False):
149
+ if hasattr(vae, "enable_slicing") and enable_slicing:
150
+ vae.enable_slicing()
151
+ if hasattr(vae, "enable_tiling") and enable_tiling:
152
+ vae.enable_tiling()
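The new helper centralizes the VAE slicing/tiling toggles that were previously enabled inline in the trainer, and it only calls the methods the given VAE actually exposes. A short usage sketch (the checkpoint id is a placeholder):

    from diffusers import AutoencoderKL
    from finetrainers.utils import _enable_vae_memory_optimizations

    vae = AutoencoderKL.from_pretrained("some/checkpoint", subfolder="vae")  # placeholder id
    # Calls vae.enable_slicing() / vae.enable_tiling() only if the VAE has those methods.
    _enable_vae_memory_optimizations(vae, enable_slicing=True, enable_tiling=True)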
requirements.txt CHANGED
@@ -40,5 +40,5 @@ av==14.1.0
40
  git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
41
 
42
  # for our frontend
43
- gradio==5.15.0
44
  gradio_toggle
 
40
  git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
41
 
42
  # for our frontend
43
+ gradio==5.20.1
44
  gradio_toggle
requirements_without_flash_attention.txt CHANGED
@@ -39,5 +39,5 @@ av==14.1.0
39
  git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
40
 
41
  # for our frontend
42
- gradio==5.15.0
43
  gradio_toggle
 
39
  git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
40
 
41
  # for our frontend
42
+ gradio==5.20.1
43
  gradio_toggle