Spaces:
Running
on
Zero
Running
on
Zero
Upload 80 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- LICENSE.md +173 -0
- README.md +106 -10
- configs/default_config.yaml +20 -0
- configs/self_forcing_dmd.yaml +51 -0
- configs/self_forcing_sid.yaml +53 -0
- demo.py +631 -0
- demo_utils/constant.py +41 -0
- demo_utils/memory.py +135 -0
- demo_utils/taehv.py +313 -0
- demo_utils/utils.py +616 -0
- demo_utils/vae.py +390 -0
- demo_utils/vae_block3.py +291 -0
- demo_utils/vae_torch2trt.py +308 -0
- images/.gitkeep +0 -0
- inference.py +179 -0
- model/__init__.py +14 -0
- model/base.py +222 -0
- model/causvid.py +391 -0
- model/diffusion.py +125 -0
- model/dmd.py +332 -0
- model/gan.py +295 -0
- model/ode_regression.py +138 -0
- model/sid.py +283 -0
- pipeline/__init__.py +13 -0
- pipeline/bidirectional_diffusion_inference.py +110 -0
- pipeline/bidirectional_inference.py +71 -0
- pipeline/causal_diffusion_inference.py +342 -0
- pipeline/causal_inference.py +305 -0
- pipeline/self_forcing_training.py +267 -0
- prompts/MovieGenVideoBench.txt +0 -0
- prompts/MovieGenVideoBench_extended.txt +0 -0
- prompts/vbench/all_dimension.txt +946 -0
- prompts/vbench/all_dimension_extended.txt +0 -0
- requirements.txt +40 -0
- scripts/create_lmdb_14b_shards.py +101 -0
- scripts/create_lmdb_iterative.py +60 -0
- scripts/generate_ode_pairs.py +120 -0
- setup.py +6 -0
- templates/demo.html +615 -0
- train.py +47 -0
- trainer/__init__.py +11 -0
- trainer/diffusion.py +265 -0
- trainer/distillation.py +388 -0
- trainer/gan.py +464 -0
- trainer/ode.py +242 -0
- utils/dataset.py +220 -0
- utils/distributed.py +125 -0
- utils/lmdb.py +72 -0
- utils/loss.py +81 -0
- utils/misc.py +39 -0
LICENSE.md
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Attribution-NonCommercial-ShareAlike 4.0 International
|
2 |
+
|
3 |
+
Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
|
4 |
+
|
5 |
+
### Using Creative Commons Public Licenses
|
6 |
+
|
7 |
+
Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
|
8 |
+
|
9 |
+
* __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
|
10 |
+
|
11 |
+
* __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
|
12 |
+
|
13 |
+
## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
14 |
+
|
15 |
+
By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
|
16 |
+
|
17 |
+
### Section 1 – Definitions.
|
18 |
+
|
19 |
+
a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
|
20 |
+
|
21 |
+
b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
|
22 |
+
|
23 |
+
c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
|
24 |
+
|
25 |
+
d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
|
26 |
+
|
27 |
+
e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
|
28 |
+
|
29 |
+
f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
|
30 |
+
|
31 |
+
g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
|
32 |
+
|
33 |
+
h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
|
34 |
+
|
35 |
+
i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
|
36 |
+
|
37 |
+
j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
|
38 |
+
|
39 |
+
k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
|
40 |
+
|
41 |
+
l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
|
42 |
+
|
43 |
+
m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
|
44 |
+
|
45 |
+
n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
|
46 |
+
|
47 |
+
### Section 2 – Scope.
|
48 |
+
|
49 |
+
a. ___License grant.___
|
50 |
+
|
51 |
+
1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
|
52 |
+
|
53 |
+
A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
|
54 |
+
|
55 |
+
B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
|
56 |
+
|
57 |
+
2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
|
58 |
+
|
59 |
+
3. __Term.__ The term of this Public License is specified in Section 6(a).
|
60 |
+
|
61 |
+
4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
|
62 |
+
|
63 |
+
5. __Downstream recipients.__
|
64 |
+
|
65 |
+
A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
|
66 |
+
|
67 |
+
B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
|
68 |
+
|
69 |
+
C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
|
70 |
+
|
71 |
+
6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
|
72 |
+
|
73 |
+
b. ___Other rights.___
|
74 |
+
|
75 |
+
1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
|
76 |
+
|
77 |
+
2. Patent and trademark rights are not licensed under this Public License.
|
78 |
+
|
79 |
+
3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
|
80 |
+
|
81 |
+
### Section 3 – License Conditions.
|
82 |
+
|
83 |
+
Your exercise of the Licensed Rights is expressly made subject to the following conditions.
|
84 |
+
|
85 |
+
a. ___Attribution.___
|
86 |
+
|
87 |
+
1. If You Share the Licensed Material (including in modified form), You must:
|
88 |
+
|
89 |
+
A. retain the following if it is supplied by the Licensor with the Licensed Material:
|
90 |
+
|
91 |
+
i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
|
92 |
+
|
93 |
+
ii. a copyright notice;
|
94 |
+
|
95 |
+
iii. a notice that refers to this Public License;
|
96 |
+
|
97 |
+
iv. a notice that refers to the disclaimer of warranties;
|
98 |
+
|
99 |
+
v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
|
100 |
+
|
101 |
+
B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
|
102 |
+
|
103 |
+
C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
|
104 |
+
|
105 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
|
106 |
+
|
107 |
+
3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
|
108 |
+
|
109 |
+
b. ___ShareAlike.___
|
110 |
+
|
111 |
+
In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
|
112 |
+
|
113 |
+
1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
|
114 |
+
|
115 |
+
2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
|
116 |
+
|
117 |
+
3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
|
118 |
+
|
119 |
+
### Section 4 – Sui Generis Database Rights.
|
120 |
+
|
121 |
+
Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
|
122 |
+
|
123 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
|
124 |
+
|
125 |
+
b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
|
126 |
+
|
127 |
+
c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
|
128 |
+
|
129 |
+
For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
|
130 |
+
|
131 |
+
### Section 5 – Disclaimer of Warranties and Limitation of Liability.
|
132 |
+
|
133 |
+
a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
|
134 |
+
|
135 |
+
b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
|
136 |
+
|
137 |
+
c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
|
138 |
+
|
139 |
+
### Section 6 – Term and Termination.
|
140 |
+
|
141 |
+
a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
|
142 |
+
|
143 |
+
b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
|
144 |
+
|
145 |
+
1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
|
146 |
+
|
147 |
+
2. upon express reinstatement by the Licensor.
|
148 |
+
|
149 |
+
For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
|
150 |
+
|
151 |
+
c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
|
152 |
+
|
153 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
|
154 |
+
|
155 |
+
### Section 7 – Other Terms and Conditions.
|
156 |
+
|
157 |
+
a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
|
158 |
+
|
159 |
+
b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
|
160 |
+
|
161 |
+
### Section 8 – Interpretation.
|
162 |
+
|
163 |
+
a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
|
164 |
+
|
165 |
+
b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
|
166 |
+
|
167 |
+
c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
|
168 |
+
|
169 |
+
d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
|
170 |
+
|
171 |
+
> Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
|
172 |
+
>
|
173 |
+
> Creative Commons may be contacted at creativecommons.org
|
README.md
CHANGED
@@ -1,13 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
colorTo: gray
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.34.1
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
-
short_description: Real-Time video generation with Self-Forcing Wan 1.3B
|
11 |
---
|
12 |
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<p align="center">
|
2 |
+
<h1 align="center">Self Forcing</h1>
|
3 |
+
<h3 align="center">Bridging the Train-Test Gap in Autoregressive Video Diffusion</h3>
|
4 |
+
</p>
|
5 |
+
<p align="center">
|
6 |
+
<p align="center">
|
7 |
+
<a href="https://www.xunhuang.me/">Xun Huang</a><sup>1</sup>
|
8 |
+
·
|
9 |
+
<a href="https://zhengqili.github.io/">Zhengqi Li</a><sup>1</sup>
|
10 |
+
·
|
11 |
+
<a href="https://guandehe.github.io/">Guande He</a><sup>2</sup>
|
12 |
+
·
|
13 |
+
<a href="https://mingyuanzhou.github.io/">Mingyuan Zhou</a><sup>2</sup>
|
14 |
+
·
|
15 |
+
<a href="https://research.adobe.com/person/eli-shechtman/">Eli Shechtman</a><sup>1</sup><br>
|
16 |
+
<sup>1</sup>Adobe Research <sup>2</sup>UT Austin
|
17 |
+
</p>
|
18 |
+
<h3 align="center"><a href="https://arxiv.org/abs/2506.08009">Paper</a> | <a href="https://self-forcing.github.io">Website</a> | <a href="https://huggingface.co/gdhe17/Self-Forcing/tree/main">Models (HuggingFace)</a></h3>
|
19 |
+
</p>
|
20 |
+
|
21 |
---
|
22 |
+
|
23 |
+
Self Forcing trains autoregressive video diffusion models by **simulating the inference process during training**, performing autoregressive rollout with KV caching. It resolves the train-test distribution mismatch and enables **real-time, streaming video generation on a single RTX 4090** while matching the quality of state-of-the-art diffusion models.
|
24 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
---
|
26 |
|
27 |
+
|
28 |
+
https://github.com/user-attachments/assets/7548c2db-fe03-4ba8-8dd3-52d2c6160739
|
29 |
+
|
30 |
+
|
31 |
+
## Requirements
|
32 |
+
We tested this repo on the following setup:
|
33 |
+
* Nvidia GPU with at least 24 GB memory (RTX 4090, A100, and H100 are tested).
|
34 |
+
* Linux operating system.
|
35 |
+
* 64 GB RAM.
|
36 |
+
|
37 |
+
Other hardware setup could also work but hasn't been tested.
|
38 |
+
|
39 |
+
## Installation
|
40 |
+
Create a conda environment and install dependencies:
|
41 |
+
```
|
42 |
+
conda create -n self_forcing python=3.10 -y
|
43 |
+
conda activate self_forcing
|
44 |
+
pip install -r requirements.txt
|
45 |
+
pip install flash-attn --no-build-isolation
|
46 |
+
python setup.py develop
|
47 |
+
```
|
48 |
+
|
49 |
+
## Quick Start
|
50 |
+
### Download checkpoints
|
51 |
+
```
|
52 |
+
huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B
|
53 |
+
huggingface-cli download gdhe17/Self-Forcing checkpoints/self_forcing_dmd.pt --local-dir .
|
54 |
+
```
|
55 |
+
|
56 |
+
### GUI demo
|
57 |
+
```
|
58 |
+
python demo.py
|
59 |
+
```
|
60 |
+
Note:
|
61 |
+
* **Our model works better with long, detailed prompts** since it's trained with such prompts. We will integrate prompt extension into the codebase (similar to [Wan2.1](https://github.com/Wan-Video/Wan2.1/tree/main?tab=readme-ov-file#2-using-prompt-extention)) in the future. For now, it is recommended to use third-party LLMs (such as GPT-4o) to extend your prompt before providing to the model.
|
62 |
+
* You may want to adjust FPS so it plays smoothly on your device.
|
63 |
+
* The speed can be improved by enabling `torch.compile`, [TAEHV-VAE](https://github.com/madebyollin/taehv/), or using FP8 Linear layers, although the latter two options may sacrifice quality. It is recommended to use `torch.compile` if possible and enable TAEHV-VAE if further speedup is needed.
|
64 |
+
|
65 |
+
### CLI Inference
|
66 |
+
Example inference script using the chunk-wise autoregressive checkpoint trained with DMD:
|
67 |
+
```
|
68 |
+
python inference.py \
|
69 |
+
--config_path configs/self_forcing_dmd.yaml \
|
70 |
+
--output_folder videos/self_forcing_dmd \
|
71 |
+
--checkpoint_path checkpoints/self_forcing_dmd.pt \
|
72 |
+
--data_path prompts/MovieGenVideoBench_extended.txt \
|
73 |
+
--use_ema
|
74 |
+
```
|
75 |
+
Other config files and corresponding checkpoints can be found in [configs](configs) folder and our [huggingface repo](https://huggingface.co/gdhe17/Self-Forcing/tree/main/checkpoints).
|
76 |
+
|
77 |
+
## Training
|
78 |
+
### Download text prompts and ODE initialized checkpoint
|
79 |
+
```
|
80 |
+
huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt --local-dir .
|
81 |
+
huggingface-cli download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
|
82 |
+
```
|
83 |
+
Note: Our training algorithm (except for the GAN version) is data-free (**no video data is needed**). For now, we directly provide the ODE initialization checkpoint and will add more instructions on how to perform ODE initialization in the future (which is identical to the process described in the [CausVid](https://github.com/tianweiy/CausVid) repo).
|
84 |
+
|
85 |
+
### Self Forcing Training with DMD
|
86 |
+
```
|
87 |
+
torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
|
88 |
+
--rdzv_backend=c10d \
|
89 |
+
--rdzv_endpoint $MASTER_ADDR \
|
90 |
+
train.py \
|
91 |
+
--config_path configs/self_forcing_dmd.yaml \
|
92 |
+
--logdir logs/self_forcing_dmd \
|
93 |
+
--disable-wandb
|
94 |
+
```
|
95 |
+
Our training run uses 600 iterations and completes in under 2 hours using 64 H100 GPUs. By implementing gradient accumulation, it should be possible to reproduce the results in less than 16 hours using 8 H100 GPUs.
|
96 |
+
|
97 |
+
## Acknowledgements
|
98 |
+
This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid) by [Tianwei Yin](https://tianweiy.github.io/) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
|
99 |
+
|
100 |
+
## Citation
|
101 |
+
If you find this codebase useful for your research, please kindly cite our paper:
|
102 |
+
```
|
103 |
+
@article{huang2025selfforcing,
|
104 |
+
title={Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion},
|
105 |
+
author={Huang, Xun and Li, Zhengqi and He, Guande and Zhou, Mingyuan and Shechtman, Eli},
|
106 |
+
journal={arXiv preprint arXiv:2506.08009},
|
107 |
+
year={2025}
|
108 |
+
}
|
109 |
+
```
|
configs/default_config.yaml
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
independent_first_frame: false
|
2 |
+
warp_denoising_step: false
|
3 |
+
weight_decay: 0.01
|
4 |
+
same_step_across_blocks: true
|
5 |
+
discriminator_lr_multiplier: 1.0
|
6 |
+
last_step_only: false
|
7 |
+
i2v: false
|
8 |
+
num_training_frames: 21
|
9 |
+
gc_interval: 100
|
10 |
+
context_noise: 0
|
11 |
+
causal: true
|
12 |
+
|
13 |
+
ckpt_step: 0
|
14 |
+
prompt_name: MovieGenVideoBench
|
15 |
+
prompt_path: prompts/MovieGenVideoBench.txt
|
16 |
+
eval_first_n: 64
|
17 |
+
num_samples: 1
|
18 |
+
height: 480
|
19 |
+
width: 832
|
20 |
+
num_frames: 81
|
configs/self_forcing_dmd.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_ckpt: checkpoints/ode_init.pt
|
2 |
+
generator_fsdp_wrap_strategy: size
|
3 |
+
real_score_fsdp_wrap_strategy: size
|
4 |
+
fake_score_fsdp_wrap_strategy: size
|
5 |
+
real_name: Wan2.1-T2V-14B
|
6 |
+
text_encoder_fsdp_wrap_strategy: size
|
7 |
+
denoising_step_list:
|
8 |
+
- 1000
|
9 |
+
- 750
|
10 |
+
- 500
|
11 |
+
- 250
|
12 |
+
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
|
13 |
+
ts_schedule: false
|
14 |
+
num_train_timestep: 1000
|
15 |
+
timestep_shift: 5.0
|
16 |
+
guidance_scale: 3.0
|
17 |
+
denoising_loss_type: flow
|
18 |
+
mixed_precision: true
|
19 |
+
seed: 0
|
20 |
+
wandb_host: WANDB_HOST
|
21 |
+
wandb_key: WANDB_KEY
|
22 |
+
wandb_entity: WANDB_ENTITY
|
23 |
+
wandb_project: WANDB_PROJECT
|
24 |
+
sharding_strategy: hybrid_full
|
25 |
+
lr: 2.0e-06
|
26 |
+
lr_critic: 4.0e-07
|
27 |
+
beta1: 0.0
|
28 |
+
beta2: 0.999
|
29 |
+
beta1_critic: 0.0
|
30 |
+
beta2_critic: 0.999
|
31 |
+
data_path: prompts/vidprom_filtered_extended.txt
|
32 |
+
batch_size: 1
|
33 |
+
ema_weight: 0.99
|
34 |
+
ema_start_step: 200
|
35 |
+
total_batch_size: 64
|
36 |
+
log_iters: 50
|
37 |
+
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
38 |
+
dfake_gen_update_ratio: 5
|
39 |
+
image_or_video_shape:
|
40 |
+
- 1
|
41 |
+
- 21
|
42 |
+
- 16
|
43 |
+
- 60
|
44 |
+
- 104
|
45 |
+
distribution_loss: dmd
|
46 |
+
trainer: score_distillation
|
47 |
+
gradient_checkpointing: true
|
48 |
+
num_frame_per_block: 3
|
49 |
+
load_raw_video: false
|
50 |
+
model_kwargs:
|
51 |
+
timestep_shift: 5.0
|
configs/self_forcing_sid.yaml
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
generator_ckpt: checkpoints/ode_init.pt
|
2 |
+
generator_fsdp_wrap_strategy: size
|
3 |
+
real_score_fsdp_wrap_strategy: size
|
4 |
+
fake_score_fsdp_wrap_strategy: size
|
5 |
+
real_name: Wan2.1-T2V-1.3B
|
6 |
+
text_encoder_fsdp_wrap_strategy: size
|
7 |
+
denoising_step_list:
|
8 |
+
- 1000
|
9 |
+
- 750
|
10 |
+
- 500
|
11 |
+
- 250
|
12 |
+
warp_denoising_step: true # need to remove - 0 in denoising_step_list if warp_denoising_step is true
|
13 |
+
ts_schedule: false
|
14 |
+
num_train_timestep: 1000
|
15 |
+
timestep_shift: 5.0
|
16 |
+
guidance_scale: 3.0
|
17 |
+
denoising_loss_type: flow
|
18 |
+
mixed_precision: true
|
19 |
+
seed: 0
|
20 |
+
wandb_host: WANDB_HOST
|
21 |
+
wandb_key: WANDB_KEY
|
22 |
+
wandb_entity: WANDB_ENTITY
|
23 |
+
wandb_project: WANDB_PROJECT
|
24 |
+
sharding_strategy: hybrid_full
|
25 |
+
lr: 2.0e-06
|
26 |
+
lr_critic: 2.0e-06
|
27 |
+
beta1: 0.0
|
28 |
+
beta2: 0.999
|
29 |
+
beta1_critic: 0.0
|
30 |
+
beta2_critic: 0.999
|
31 |
+
weight_decay: 0.0
|
32 |
+
data_path: prompts/vidprom_filtered_extended.txt
|
33 |
+
batch_size: 1
|
34 |
+
sid_alpha: 1.0
|
35 |
+
ema_weight: 0.99
|
36 |
+
ema_start_step: 200
|
37 |
+
total_batch_size: 64
|
38 |
+
log_iters: 50
|
39 |
+
negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
40 |
+
dfake_gen_update_ratio: 5
|
41 |
+
image_or_video_shape:
|
42 |
+
- 1
|
43 |
+
- 21
|
44 |
+
- 16
|
45 |
+
- 60
|
46 |
+
- 104
|
47 |
+
distribution_loss: dmd
|
48 |
+
trainer: score_distillation
|
49 |
+
gradient_checkpointing: true
|
50 |
+
num_frame_per_block: 3
|
51 |
+
load_raw_video: false
|
52 |
+
model_kwargs:
|
53 |
+
timestep_shift: 5.0
|
demo.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Demo for Self-Forcing.
|
3 |
+
"""
|
4 |
+
|
5 |
+
import os
|
6 |
+
import re
|
7 |
+
import random
|
8 |
+
import time
|
9 |
+
import base64
|
10 |
+
import argparse
|
11 |
+
import hashlib
|
12 |
+
import subprocess
|
13 |
+
import urllib.request
|
14 |
+
from io import BytesIO
|
15 |
+
from PIL import Image
|
16 |
+
import numpy as np
|
17 |
+
import torch
|
18 |
+
from omegaconf import OmegaConf
|
19 |
+
from flask import Flask, render_template, jsonify
|
20 |
+
from flask_socketio import SocketIO, emit
|
21 |
+
import queue
|
22 |
+
from threading import Thread, Event
|
23 |
+
|
24 |
+
from pipeline import CausalInferencePipeline
|
25 |
+
from demo_utils.constant import ZERO_VAE_CACHE
|
26 |
+
from demo_utils.vae_block3 import VAEDecoderWrapper
|
27 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
|
28 |
+
from demo_utils.utils import generate_timestamp
|
29 |
+
from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
|
30 |
+
|
31 |
+
# Parse arguments
|
32 |
+
parser = argparse.ArgumentParser()
|
33 |
+
parser.add_argument('--port', type=int, default=5001)
|
34 |
+
parser.add_argument('--host', type=str, default='0.0.0.0')
|
35 |
+
parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
|
36 |
+
parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
|
37 |
+
parser.add_argument('--trt', action='store_true')
|
38 |
+
args = parser.parse_args()
|
39 |
+
|
40 |
+
print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
|
41 |
+
low_memory = get_cuda_free_memory_gb(gpu) < 40
|
42 |
+
|
43 |
+
# Load models
|
44 |
+
config = OmegaConf.load(args.config_path)
|
45 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
46 |
+
config = OmegaConf.merge(default_config, config)
|
47 |
+
|
48 |
+
text_encoder = WanTextEncoder()
|
49 |
+
|
50 |
+
# Global variables for dynamic model switching
|
51 |
+
current_vae_decoder = None
|
52 |
+
current_use_taehv = False
|
53 |
+
fp8_applied = False
|
54 |
+
torch_compile_applied = False
|
55 |
+
global frame_number
|
56 |
+
frame_number = 0
|
57 |
+
anim_name = ""
|
58 |
+
frame_rate = 6
|
59 |
+
|
60 |
+
def initialize_vae_decoder(use_taehv=False, use_trt=False):
|
61 |
+
"""Initialize VAE decoder based on the selected option"""
|
62 |
+
global current_vae_decoder, current_use_taehv
|
63 |
+
|
64 |
+
if use_trt:
|
65 |
+
from demo_utils.vae import VAETRTWrapper
|
66 |
+
current_vae_decoder = VAETRTWrapper()
|
67 |
+
return current_vae_decoder
|
68 |
+
|
69 |
+
if use_taehv:
|
70 |
+
from demo_utils.taehv import TAEHV
|
71 |
+
# Check if taew2_1.pth exists in checkpoints folder, download if missing
|
72 |
+
taehv_checkpoint_path = "checkpoints/taew2_1.pth"
|
73 |
+
if not os.path.exists(taehv_checkpoint_path):
|
74 |
+
print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
|
75 |
+
os.makedirs("checkpoints", exist_ok=True)
|
76 |
+
download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
|
77 |
+
try:
|
78 |
+
urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
|
79 |
+
print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
|
80 |
+
except Exception as e:
|
81 |
+
print(f"Failed to download taew2_1.pth: {e}")
|
82 |
+
raise
|
83 |
+
|
84 |
+
class DotDict(dict):
|
85 |
+
__getattr__ = dict.__getitem__
|
86 |
+
__setattr__ = dict.__setitem__
|
87 |
+
|
88 |
+
class TAEHVDiffusersWrapper(torch.nn.Module):
|
89 |
+
def __init__(self):
|
90 |
+
super().__init__()
|
91 |
+
self.dtype = torch.float16
|
92 |
+
self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
|
93 |
+
self.config = DotDict(scaling_factor=1.0)
|
94 |
+
|
95 |
+
def decode(self, latents, return_dict=None):
|
96 |
+
# n, c, t, h, w = latents.shape
|
97 |
+
# low-memory, set parallel=True for faster + higher memory
|
98 |
+
return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
|
99 |
+
|
100 |
+
current_vae_decoder = TAEHVDiffusersWrapper()
|
101 |
+
else:
|
102 |
+
current_vae_decoder = VAEDecoderWrapper()
|
103 |
+
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
|
104 |
+
decoder_state_dict = {}
|
105 |
+
for key, value in vae_state_dict.items():
|
106 |
+
if 'decoder.' in key or 'conv2' in key:
|
107 |
+
decoder_state_dict[key] = value
|
108 |
+
current_vae_decoder.load_state_dict(decoder_state_dict)
|
109 |
+
|
110 |
+
current_vae_decoder.eval()
|
111 |
+
current_vae_decoder.to(dtype=torch.float16)
|
112 |
+
current_vae_decoder.requires_grad_(False)
|
113 |
+
current_vae_decoder.to(gpu)
|
114 |
+
current_use_taehv = use_taehv
|
115 |
+
|
116 |
+
print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
|
117 |
+
return current_vae_decoder
|
118 |
+
|
119 |
+
|
120 |
+
# Initialize with default VAE
|
121 |
+
vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
|
122 |
+
|
123 |
+
transformer = WanDiffusionWrapper(is_causal=True)
|
124 |
+
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
|
125 |
+
transformer.load_state_dict(state_dict['generator_ema'])
|
126 |
+
|
127 |
+
text_encoder.eval()
|
128 |
+
transformer.eval()
|
129 |
+
|
130 |
+
transformer.to(dtype=torch.float16)
|
131 |
+
text_encoder.to(dtype=torch.bfloat16)
|
132 |
+
|
133 |
+
text_encoder.requires_grad_(False)
|
134 |
+
transformer.requires_grad_(False)
|
135 |
+
|
136 |
+
pipeline = CausalInferencePipeline(
|
137 |
+
config,
|
138 |
+
device=gpu,
|
139 |
+
generator=transformer,
|
140 |
+
text_encoder=text_encoder,
|
141 |
+
vae=vae_decoder
|
142 |
+
)
|
143 |
+
|
144 |
+
if low_memory:
|
145 |
+
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
|
146 |
+
else:
|
147 |
+
text_encoder.to(gpu)
|
148 |
+
transformer.to(gpu)
|
149 |
+
|
150 |
+
# Flask and SocketIO setup
|
151 |
+
app = Flask(__name__)
|
152 |
+
app.config['SECRET_KEY'] = 'frontend_buffered_demo'
|
153 |
+
socketio = SocketIO(app, cors_allowed_origins="*")
|
154 |
+
|
155 |
+
generation_active = False
|
156 |
+
stop_event = Event()
|
157 |
+
frame_send_queue = queue.Queue()
|
158 |
+
sender_thread = None
|
159 |
+
models_compiled = False
|
160 |
+
|
161 |
+
|
162 |
+
def tensor_to_base64_frame(frame_tensor):
|
163 |
+
"""Convert a single frame tensor to base64 image string."""
|
164 |
+
global frame_number, anim_name
|
165 |
+
# Clamp and normalize to 0-255
|
166 |
+
frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
|
167 |
+
frame = frame.to(torch.uint8).cpu().numpy()
|
168 |
+
|
169 |
+
# CHW -> HWC
|
170 |
+
if len(frame.shape) == 3:
|
171 |
+
frame = np.transpose(frame, (1, 2, 0))
|
172 |
+
|
173 |
+
# Convert to PIL Image
|
174 |
+
if frame.shape[2] == 3: # RGB
|
175 |
+
image = Image.fromarray(frame, 'RGB')
|
176 |
+
else: # Handle other formats
|
177 |
+
image = Image.fromarray(frame)
|
178 |
+
|
179 |
+
# Convert to base64
|
180 |
+
buffer = BytesIO()
|
181 |
+
image.save(buffer, format='JPEG', quality=100)
|
182 |
+
if not os.path.exists("./images/%s" % anim_name):
|
183 |
+
os.makedirs("./images/%s" % anim_name)
|
184 |
+
frame_number += 1
|
185 |
+
image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
|
186 |
+
img_str = base64.b64encode(buffer.getvalue()).decode()
|
187 |
+
return f"data:image/jpeg;base64,{img_str}"
|
188 |
+
|
189 |
+
|
190 |
+
def frame_sender_worker():
|
191 |
+
"""Background thread that processes frame send queue non-blocking."""
|
192 |
+
global frame_send_queue, generation_active, stop_event
|
193 |
+
|
194 |
+
print("📡 Frame sender thread started")
|
195 |
+
|
196 |
+
while True:
|
197 |
+
frame_data = None
|
198 |
+
try:
|
199 |
+
# Get frame data from queue
|
200 |
+
frame_data = frame_send_queue.get(timeout=1.0)
|
201 |
+
|
202 |
+
if frame_data is None: # Shutdown signal
|
203 |
+
frame_send_queue.task_done() # Mark shutdown signal as done
|
204 |
+
break
|
205 |
+
|
206 |
+
frame_tensor, frame_index, block_index, job_id = frame_data
|
207 |
+
|
208 |
+
# Convert tensor to base64
|
209 |
+
base64_frame = tensor_to_base64_frame(frame_tensor)
|
210 |
+
|
211 |
+
# Send via SocketIO
|
212 |
+
try:
|
213 |
+
socketio.emit('frame_ready', {
|
214 |
+
'data': base64_frame,
|
215 |
+
'frame_index': frame_index,
|
216 |
+
'block_index': block_index,
|
217 |
+
'job_id': job_id
|
218 |
+
})
|
219 |
+
except Exception as e:
|
220 |
+
print(f"⚠️ Failed to send frame {frame_index}: {e}")
|
221 |
+
|
222 |
+
frame_send_queue.task_done()
|
223 |
+
|
224 |
+
except queue.Empty:
|
225 |
+
# Check if we should continue running
|
226 |
+
if not generation_active and frame_send_queue.empty():
|
227 |
+
break
|
228 |
+
except Exception as e:
|
229 |
+
print(f"❌ Frame sender error: {e}")
|
230 |
+
# Make sure to mark task as done even if there's an error
|
231 |
+
if frame_data is not None:
|
232 |
+
try:
|
233 |
+
frame_send_queue.task_done()
|
234 |
+
except Exception as e:
|
235 |
+
print(f"❌ Failed to mark frame task as done: {e}")
|
236 |
+
break
|
237 |
+
|
238 |
+
print("📡 Frame sender thread stopped")
|
239 |
+
|
240 |
+
|
241 |
+
@torch.no_grad()
|
242 |
+
def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
|
243 |
+
"""Generate video and push frames immediately to frontend."""
|
244 |
+
global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name
|
245 |
+
|
246 |
+
try:
|
247 |
+
generation_active = True
|
248 |
+
stop_event.clear()
|
249 |
+
job_id = generate_timestamp()
|
250 |
+
|
251 |
+
# Start frame sender thread if not already running
|
252 |
+
if sender_thread is None or not sender_thread.is_alive():
|
253 |
+
sender_thread = Thread(target=frame_sender_worker, daemon=True)
|
254 |
+
sender_thread.start()
|
255 |
+
|
256 |
+
# Emit progress updates
|
257 |
+
def emit_progress(message, progress):
|
258 |
+
try:
|
259 |
+
socketio.emit('progress', {
|
260 |
+
'message': message,
|
261 |
+
'progress': progress,
|
262 |
+
'job_id': job_id
|
263 |
+
})
|
264 |
+
except Exception as e:
|
265 |
+
print(f"❌ Failed to emit progress: {e}")
|
266 |
+
|
267 |
+
emit_progress('Starting generation...', 0)
|
268 |
+
|
269 |
+
# Handle VAE decoder switching
|
270 |
+
if use_taehv != current_use_taehv:
|
271 |
+
emit_progress('Switching VAE decoder...', 2)
|
272 |
+
print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
|
273 |
+
current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
|
274 |
+
# Update pipeline with new VAE decoder
|
275 |
+
pipeline.vae = current_vae_decoder
|
276 |
+
|
277 |
+
# Handle FP8 quantization
|
278 |
+
if enable_fp8 and not fp8_applied:
|
279 |
+
emit_progress('Applying FP8 quantization...', 3)
|
280 |
+
print("🔧 Applying FP8 quantization to transformer")
|
281 |
+
from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
|
282 |
+
quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
|
283 |
+
fp8_applied = True
|
284 |
+
|
285 |
+
# Text encoding
|
286 |
+
emit_progress('Encoding text prompt...', 8)
|
287 |
+
conditional_dict = text_encoder(text_prompts=[prompt])
|
288 |
+
for key, value in conditional_dict.items():
|
289 |
+
conditional_dict[key] = value.to(dtype=torch.float16)
|
290 |
+
if low_memory:
|
291 |
+
gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
|
292 |
+
move_model_to_device_with_memory_preservation(
|
293 |
+
text_encoder,target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
|
294 |
+
|
295 |
+
# Handle torch.compile if enabled
|
296 |
+
torch_compile_applied = enable_torch_compile
|
297 |
+
if enable_torch_compile and not models_compiled:
|
298 |
+
# Compile transformer and decoder
|
299 |
+
transformer.compile(mode="max-autotune-no-cudagraphs")
|
300 |
+
if not current_use_taehv and not low_memory and not args.trt:
|
301 |
+
current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
|
302 |
+
|
303 |
+
# Initialize generation
|
304 |
+
emit_progress('Initializing generation...', 12)
|
305 |
+
|
306 |
+
rnd = torch.Generator(gpu).manual_seed(seed)
|
307 |
+
# all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
|
308 |
+
|
309 |
+
pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
|
310 |
+
pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
|
311 |
+
|
312 |
+
noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
|
313 |
+
|
314 |
+
# Generation parameters
|
315 |
+
num_blocks = 7
|
316 |
+
current_start_frame = 0
|
317 |
+
num_input_frames = 0
|
318 |
+
all_num_frames = [pipeline.num_frame_per_block] * num_blocks
|
319 |
+
if current_use_taehv:
|
320 |
+
vae_cache = None
|
321 |
+
else:
|
322 |
+
vae_cache = ZERO_VAE_CACHE
|
323 |
+
for i in range(len(vae_cache)):
|
324 |
+
vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
|
325 |
+
|
326 |
+
total_frames_sent = 0
|
327 |
+
generation_start_time = time.time()
|
328 |
+
|
329 |
+
emit_progress('Generating frames... (frontend handles timing)', 15)
|
330 |
+
|
331 |
+
for idx, current_num_frames in enumerate(all_num_frames):
|
332 |
+
if not generation_active or stop_event.is_set():
|
333 |
+
break
|
334 |
+
|
335 |
+
progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
|
336 |
+
|
337 |
+
# Special message for first block with torch.compile
|
338 |
+
if idx == 0 and torch_compile_applied and not models_compiled:
|
339 |
+
emit_progress(
|
340 |
+
f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
|
341 |
+
print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
|
342 |
+
models_compiled = True
|
343 |
+
else:
|
344 |
+
emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
|
345 |
+
print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
|
346 |
+
|
347 |
+
block_start_time = time.time()
|
348 |
+
|
349 |
+
noisy_input = noise[:, current_start_frame -
|
350 |
+
num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
351 |
+
|
352 |
+
# Denoising loop
|
353 |
+
denoising_start = time.time()
|
354 |
+
for index, current_timestep in enumerate(pipeline.denoising_step_list):
|
355 |
+
if not generation_active or stop_event.is_set():
|
356 |
+
break
|
357 |
+
|
358 |
+
timestep = torch.ones([1, current_num_frames], device=noise.device,
|
359 |
+
dtype=torch.int64) * current_timestep
|
360 |
+
|
361 |
+
if index < len(pipeline.denoising_step_list) - 1:
|
362 |
+
_, denoised_pred = transformer(
|
363 |
+
noisy_image_or_video=noisy_input,
|
364 |
+
conditional_dict=conditional_dict,
|
365 |
+
timestep=timestep,
|
366 |
+
kv_cache=pipeline.kv_cache1,
|
367 |
+
crossattn_cache=pipeline.crossattn_cache,
|
368 |
+
current_start=current_start_frame * pipeline.frame_seq_length
|
369 |
+
)
|
370 |
+
next_timestep = pipeline.denoising_step_list[index + 1]
|
371 |
+
noisy_input = pipeline.scheduler.add_noise(
|
372 |
+
denoised_pred.flatten(0, 1),
|
373 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
374 |
+
next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
|
375 |
+
).unflatten(0, denoised_pred.shape[:2])
|
376 |
+
else:
|
377 |
+
_, denoised_pred = transformer(
|
378 |
+
noisy_image_or_video=noisy_input,
|
379 |
+
conditional_dict=conditional_dict,
|
380 |
+
timestep=timestep,
|
381 |
+
kv_cache=pipeline.kv_cache1,
|
382 |
+
crossattn_cache=pipeline.crossattn_cache,
|
383 |
+
current_start=current_start_frame * pipeline.frame_seq_length
|
384 |
+
)
|
385 |
+
|
386 |
+
if not generation_active or stop_event.is_set():
|
387 |
+
break
|
388 |
+
|
389 |
+
denoising_time = time.time() - denoising_start
|
390 |
+
print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
|
391 |
+
|
392 |
+
# Record output
|
393 |
+
# all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
394 |
+
|
395 |
+
# Update KV cache for next block
|
396 |
+
if idx != len(all_num_frames) - 1:
|
397 |
+
transformer(
|
398 |
+
noisy_image_or_video=denoised_pred,
|
399 |
+
conditional_dict=conditional_dict,
|
400 |
+
timestep=torch.zeros_like(timestep),
|
401 |
+
kv_cache=pipeline.kv_cache1,
|
402 |
+
crossattn_cache=pipeline.crossattn_cache,
|
403 |
+
current_start=current_start_frame * pipeline.frame_seq_length,
|
404 |
+
)
|
405 |
+
|
406 |
+
# Decode to pixels and send frames immediately
|
407 |
+
print(f"🎨 Decoding block {idx+1} to pixels...")
|
408 |
+
decode_start = time.time()
|
409 |
+
if args.trt:
|
410 |
+
all_current_pixels = []
|
411 |
+
for i in range(denoised_pred.shape[1]):
|
412 |
+
is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
|
413 |
+
torch.tensor(0.0).cuda().half()
|
414 |
+
outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
|
415 |
+
# outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
|
416 |
+
current_pixels, vae_cache = outputs[0], outputs[1:]
|
417 |
+
print(current_pixels.max(), current_pixels.min())
|
418 |
+
all_current_pixels.append(current_pixels.clone())
|
419 |
+
pixels = torch.cat(all_current_pixels, dim=1)
|
420 |
+
if idx == 0:
|
421 |
+
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
422 |
+
else:
|
423 |
+
if current_use_taehv:
|
424 |
+
if vae_cache is None:
|
425 |
+
vae_cache = denoised_pred
|
426 |
+
else:
|
427 |
+
denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
|
428 |
+
vae_cache = denoised_pred[:, -3:, :, :, :]
|
429 |
+
pixels = current_vae_decoder.decode(denoised_pred)
|
430 |
+
print(f"denoised_pred shape: {denoised_pred.shape}")
|
431 |
+
print(f"pixels shape: {pixels.shape}")
|
432 |
+
if idx == 0:
|
433 |
+
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
434 |
+
else:
|
435 |
+
pixels = pixels[:, 12:, :, :, :]
|
436 |
+
|
437 |
+
else:
|
438 |
+
pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
|
439 |
+
if idx == 0:
|
440 |
+
pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
|
441 |
+
|
442 |
+
decode_time = time.time() - decode_start
|
443 |
+
print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
|
444 |
+
|
445 |
+
# Queue frames for non-blocking sending
|
446 |
+
block_frames = pixels.shape[1]
|
447 |
+
print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
|
448 |
+
queue_start = time.time()
|
449 |
+
|
450 |
+
for frame_idx in range(block_frames):
|
451 |
+
if not generation_active or stop_event.is_set():
|
452 |
+
break
|
453 |
+
|
454 |
+
frame_tensor = pixels[0, frame_idx].cpu()
|
455 |
+
|
456 |
+
# Queue frame data in non-blocking way
|
457 |
+
frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
|
458 |
+
total_frames_sent += 1
|
459 |
+
|
460 |
+
queue_time = time.time() - queue_start
|
461 |
+
block_time = time.time() - block_start_time
|
462 |
+
print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
|
463 |
+
|
464 |
+
current_start_frame += current_num_frames
|
465 |
+
|
466 |
+
generation_time = time.time() - generation_start_time
|
467 |
+
print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
|
468 |
+
|
469 |
+
# Wait for all frames to be sent before completing
|
470 |
+
emit_progress('Waiting for all frames to be sent...', 97)
|
471 |
+
print("⏳ Waiting for all frames to be sent...")
|
472 |
+
frame_send_queue.join() # Wait for all queued frames to be processed
|
473 |
+
print("✅ All frames sent successfully!")
|
474 |
+
|
475 |
+
generate_mp4_from_images("./images","./videos/"+anim_name+".mp4", frame_rate )
|
476 |
+
# Final progress update
|
477 |
+
emit_progress('Generation complete!', 100)
|
478 |
+
|
479 |
+
try:
|
480 |
+
socketio.emit('generation_complete', {
|
481 |
+
'message': 'Video generation completed!',
|
482 |
+
'total_frames': total_frames_sent,
|
483 |
+
'generation_time': f"{generation_time:.2f}s",
|
484 |
+
'job_id': job_id
|
485 |
+
})
|
486 |
+
except Exception as e:
|
487 |
+
print(f"❌ Failed to emit generation complete: {e}")
|
488 |
+
|
489 |
+
except Exception as e:
|
490 |
+
print(f"❌ Generation failed: {e}")
|
491 |
+
try:
|
492 |
+
socketio.emit('error', {
|
493 |
+
'message': f'Generation failed: {str(e)}',
|
494 |
+
'job_id': job_id
|
495 |
+
})
|
496 |
+
except Exception as e:
|
497 |
+
print(f"❌ Failed to emit error: {e}")
|
498 |
+
finally:
|
499 |
+
generation_active = False
|
500 |
+
stop_event.set()
|
501 |
+
|
502 |
+
# Clean up sender thread
|
503 |
+
try:
|
504 |
+
frame_send_queue.put(None)
|
505 |
+
except Exception as e:
|
506 |
+
print(f"❌ Failed to put None in frame_send_queue: {e}")
|
507 |
+
|
508 |
+
|
509 |
+
def generate_mp4_from_images(image_directory, output_video_path, fps=24):
|
510 |
+
"""
|
511 |
+
Generate an MP4 video from a directory of images ordered alphabetically.
|
512 |
+
|
513 |
+
:param image_directory: Path to the directory containing images.
|
514 |
+
:param output_video_path: Path where the output MP4 will be saved.
|
515 |
+
:param fps: Frames per second for the output video.
|
516 |
+
"""
|
517 |
+
global anim_name
|
518 |
+
# Construct the ffmpeg command
|
519 |
+
cmd = [
|
520 |
+
'ffmpeg',
|
521 |
+
'-framerate', str(fps),
|
522 |
+
'-i', os.path.join(image_directory, anim_name+'/'+anim_name+'_%03d.jpg'), # Adjust the pattern if necessary
|
523 |
+
'-c:v', 'libx264',
|
524 |
+
'-pix_fmt', 'yuv420p',
|
525 |
+
output_video_path
|
526 |
+
]
|
527 |
+
try:
|
528 |
+
subprocess.run(cmd, check=True)
|
529 |
+
print(f"Video saved to {output_video_path}")
|
530 |
+
except subprocess.CalledProcessError as e:
|
531 |
+
print(f"An error occurred: {e}")
|
532 |
+
|
533 |
+
def calculate_sha256(data):
|
534 |
+
# Convert data to bytes if it's not already
|
535 |
+
if isinstance(data, str):
|
536 |
+
data = data.encode()
|
537 |
+
# Calculate SHA-256 hash
|
538 |
+
sha256_hash = hashlib.sha256(data).hexdigest()
|
539 |
+
return sha256_hash
|
540 |
+
|
541 |
+
# Socket.IO event handlers
|
542 |
+
@socketio.on('connect')
|
543 |
+
def handle_connect():
|
544 |
+
print('Client connected')
|
545 |
+
emit('status', {'message': 'Connected to frontend-buffered demo server'})
|
546 |
+
|
547 |
+
|
548 |
+
@socketio.on('disconnect')
|
549 |
+
def handle_disconnect():
|
550 |
+
print('Client disconnected')
|
551 |
+
|
552 |
+
|
553 |
+
@socketio.on('start_generation')
|
554 |
+
def handle_start_generation(data):
|
555 |
+
global generation_active, frame_number, anim_name, frame_rate
|
556 |
+
|
557 |
+
frame_number = 0
|
558 |
+
if generation_active:
|
559 |
+
emit('error', {'message': 'Generation already in progress'})
|
560 |
+
return
|
561 |
+
|
562 |
+
prompt = data.get('prompt', '')
|
563 |
+
|
564 |
+
seed = data.get('seed', -1)
|
565 |
+
if seed==-1:
|
566 |
+
seed = random.randint(0, 2**32)
|
567 |
+
|
568 |
+
# Extract words up to the first punctuation or newline
|
569 |
+
words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
|
570 |
+
if not words_up_to_punctuation:
|
571 |
+
words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
|
572 |
+
|
573 |
+
# Calculate SHA-256 hash of the entire prompt
|
574 |
+
sha256_hash = calculate_sha256(prompt)
|
575 |
+
|
576 |
+
# Create anim_name with the extracted words and first 10 characters of the hash
|
577 |
+
anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"
|
578 |
+
|
579 |
+
generation_active = True
|
580 |
+
generation_start_time = time.time()
|
581 |
+
enable_torch_compile = data.get('enable_torch_compile', False)
|
582 |
+
enable_fp8 = data.get('enable_fp8', False)
|
583 |
+
use_taehv = data.get('use_taehv', False)
|
584 |
+
frame_rate = data.get('fps', 6)
|
585 |
+
|
586 |
+
if not prompt:
|
587 |
+
emit('error', {'message': 'Prompt is required'})
|
588 |
+
return
|
589 |
+
|
590 |
+
# Start generation in background thread
|
591 |
+
socketio.start_background_task(generate_video_stream, prompt, seed,
|
592 |
+
enable_torch_compile, enable_fp8, use_taehv)
|
593 |
+
emit('status', {'message': 'Generation started - frames will be sent immediately'})
|
594 |
+
|
595 |
+
|
596 |
+
@socketio.on('stop_generation')
|
597 |
+
def handle_stop_generation():
|
598 |
+
global generation_active, stop_event, frame_send_queue
|
599 |
+
generation_active = False
|
600 |
+
stop_event.set()
|
601 |
+
|
602 |
+
# Signal sender thread to stop (will be processed after current frames)
|
603 |
+
try:
|
604 |
+
frame_send_queue.put(None)
|
605 |
+
except Exception as e:
|
606 |
+
print(f"❌ Failed to put None in frame_send_queue: {e}")
|
607 |
+
|
608 |
+
emit('status', {'message': 'Generation stopped'})
|
609 |
+
|
610 |
+
# Web routes
|
611 |
+
|
612 |
+
|
613 |
+
@app.route('/')
|
614 |
+
def index():
|
615 |
+
return render_template('demo.html')
|
616 |
+
|
617 |
+
|
618 |
+
@app.route('/api/status')
|
619 |
+
def api_status():
|
620 |
+
return jsonify({
|
621 |
+
'generation_active': generation_active,
|
622 |
+
'free_vram_gb': get_cuda_free_memory_gb(gpu),
|
623 |
+
'fp8_applied': fp8_applied,
|
624 |
+
'torch_compile_applied': torch_compile_applied,
|
625 |
+
'current_use_taehv': current_use_taehv
|
626 |
+
})
|
627 |
+
|
628 |
+
|
629 |
+
if __name__ == '__main__':
|
630 |
+
print(f"🚀 Starting demo on http://{args.host}:{args.port}")
|
631 |
+
socketio.run(app, host=args.host, port=args.port, debug=False)
|
demo_utils/constant.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
ZERO_VAE_CACHE = [
|
6 |
+
torch.zeros(1, 16, 2, 60, 104),
|
7 |
+
torch.zeros(1, 384, 2, 60, 104),
|
8 |
+
torch.zeros(1, 384, 2, 60, 104),
|
9 |
+
torch.zeros(1, 384, 2, 60, 104),
|
10 |
+
torch.zeros(1, 384, 2, 60, 104),
|
11 |
+
torch.zeros(1, 384, 2, 60, 104),
|
12 |
+
torch.zeros(1, 384, 2, 60, 104),
|
13 |
+
torch.zeros(1, 384, 2, 60, 104),
|
14 |
+
torch.zeros(1, 384, 2, 60, 104),
|
15 |
+
torch.zeros(1, 384, 2, 60, 104),
|
16 |
+
torch.zeros(1, 384, 2, 60, 104),
|
17 |
+
torch.zeros(1, 384, 2, 60, 104),
|
18 |
+
torch.zeros(1, 192, 2, 120, 208),
|
19 |
+
torch.zeros(1, 384, 2, 120, 208),
|
20 |
+
torch.zeros(1, 384, 2, 120, 208),
|
21 |
+
torch.zeros(1, 384, 2, 120, 208),
|
22 |
+
torch.zeros(1, 384, 2, 120, 208),
|
23 |
+
torch.zeros(1, 384, 2, 120, 208),
|
24 |
+
torch.zeros(1, 384, 2, 120, 208),
|
25 |
+
torch.zeros(1, 192, 2, 240, 416),
|
26 |
+
torch.zeros(1, 192, 2, 240, 416),
|
27 |
+
torch.zeros(1, 192, 2, 240, 416),
|
28 |
+
torch.zeros(1, 192, 2, 240, 416),
|
29 |
+
torch.zeros(1, 192, 2, 240, 416),
|
30 |
+
torch.zeros(1, 192, 2, 240, 416),
|
31 |
+
torch.zeros(1, 96, 2, 480, 832),
|
32 |
+
torch.zeros(1, 96, 2, 480, 832),
|
33 |
+
torch.zeros(1, 96, 2, 480, 832),
|
34 |
+
torch.zeros(1, 96, 2, 480, 832),
|
35 |
+
torch.zeros(1, 96, 2, 480, 832),
|
36 |
+
torch.zeros(1, 96, 2, 480, 832),
|
37 |
+
torch.zeros(1, 96, 2, 480, 832)
|
38 |
+
]
|
39 |
+
|
40 |
+
feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
|
41 |
+
ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
|
demo_utils/memory.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
|
2 |
+
# Apache-2.0 License
|
3 |
+
# By lllyasviel
|
4 |
+
|
5 |
+
import torch
|
6 |
+
|
7 |
+
|
8 |
+
cpu = torch.device('cpu')
|
9 |
+
gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
|
10 |
+
gpu_complete_modules = []
|
11 |
+
|
12 |
+
|
13 |
+
class DynamicSwapInstaller:
|
14 |
+
@staticmethod
|
15 |
+
def _install_module(module: torch.nn.Module, **kwargs):
|
16 |
+
original_class = module.__class__
|
17 |
+
module.__dict__['forge_backup_original_class'] = original_class
|
18 |
+
|
19 |
+
def hacked_get_attr(self, name: str):
|
20 |
+
if '_parameters' in self.__dict__:
|
21 |
+
_parameters = self.__dict__['_parameters']
|
22 |
+
if name in _parameters:
|
23 |
+
p = _parameters[name]
|
24 |
+
if p is None:
|
25 |
+
return None
|
26 |
+
if p.__class__ == torch.nn.Parameter:
|
27 |
+
return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
|
28 |
+
else:
|
29 |
+
return p.to(**kwargs)
|
30 |
+
if '_buffers' in self.__dict__:
|
31 |
+
_buffers = self.__dict__['_buffers']
|
32 |
+
if name in _buffers:
|
33 |
+
return _buffers[name].to(**kwargs)
|
34 |
+
return super(original_class, self).__getattr__(name)
|
35 |
+
|
36 |
+
module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
|
37 |
+
'__getattr__': hacked_get_attr,
|
38 |
+
})
|
39 |
+
|
40 |
+
return
|
41 |
+
|
42 |
+
@staticmethod
|
43 |
+
def _uninstall_module(module: torch.nn.Module):
|
44 |
+
if 'forge_backup_original_class' in module.__dict__:
|
45 |
+
module.__class__ = module.__dict__.pop('forge_backup_original_class')
|
46 |
+
return
|
47 |
+
|
48 |
+
@staticmethod
|
49 |
+
def install_model(model: torch.nn.Module, **kwargs):
|
50 |
+
for m in model.modules():
|
51 |
+
DynamicSwapInstaller._install_module(m, **kwargs)
|
52 |
+
return
|
53 |
+
|
54 |
+
@staticmethod
|
55 |
+
def uninstall_model(model: torch.nn.Module):
|
56 |
+
for m in model.modules():
|
57 |
+
DynamicSwapInstaller._uninstall_module(m)
|
58 |
+
return
|
59 |
+
|
60 |
+
|
61 |
+
def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
|
62 |
+
if hasattr(model, 'scale_shift_table'):
|
63 |
+
model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
|
64 |
+
return
|
65 |
+
|
66 |
+
for k, p in model.named_modules():
|
67 |
+
if hasattr(p, 'weight'):
|
68 |
+
p.to(target_device)
|
69 |
+
return
|
70 |
+
|
71 |
+
|
72 |
+
def get_cuda_free_memory_gb(device=None):
|
73 |
+
if device is None:
|
74 |
+
device = gpu
|
75 |
+
|
76 |
+
memory_stats = torch.cuda.memory_stats(device)
|
77 |
+
bytes_active = memory_stats['active_bytes.all.current']
|
78 |
+
bytes_reserved = memory_stats['reserved_bytes.all.current']
|
79 |
+
bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
|
80 |
+
bytes_inactive_reserved = bytes_reserved - bytes_active
|
81 |
+
bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
|
82 |
+
return bytes_total_available / (1024 ** 3)
|
83 |
+
|
84 |
+
|
85 |
+
def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
|
86 |
+
print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
|
87 |
+
|
88 |
+
for m in model.modules():
|
89 |
+
if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
|
90 |
+
torch.cuda.empty_cache()
|
91 |
+
return
|
92 |
+
|
93 |
+
if hasattr(m, 'weight'):
|
94 |
+
m.to(device=target_device)
|
95 |
+
|
96 |
+
model.to(device=target_device)
|
97 |
+
torch.cuda.empty_cache()
|
98 |
+
return
|
99 |
+
|
100 |
+
|
101 |
+
def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
|
102 |
+
print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
|
103 |
+
|
104 |
+
for m in model.modules():
|
105 |
+
if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
|
106 |
+
torch.cuda.empty_cache()
|
107 |
+
return
|
108 |
+
|
109 |
+
if hasattr(m, 'weight'):
|
110 |
+
m.to(device=cpu)
|
111 |
+
|
112 |
+
model.to(device=cpu)
|
113 |
+
torch.cuda.empty_cache()
|
114 |
+
return
|
115 |
+
|
116 |
+
|
117 |
+
def unload_complete_models(*args):
|
118 |
+
for m in gpu_complete_modules + list(args):
|
119 |
+
m.to(device=cpu)
|
120 |
+
print(f'Unloaded {m.__class__.__name__} as complete.')
|
121 |
+
|
122 |
+
gpu_complete_modules.clear()
|
123 |
+
torch.cuda.empty_cache()
|
124 |
+
return
|
125 |
+
|
126 |
+
|
127 |
+
def load_model_as_complete(model, target_device, unload=True):
|
128 |
+
if unload:
|
129 |
+
unload_complete_models()
|
130 |
+
|
131 |
+
model.to(device=target_device)
|
132 |
+
print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
|
133 |
+
|
134 |
+
gpu_complete_modules.append(model)
|
135 |
+
return
|
demo_utils/taehv.py
ADDED
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
"""
|
3 |
+
Tiny AutoEncoder for Hunyuan Video
|
4 |
+
(DNN for encoding / decoding videos to Hunyuan Video's latent space)
|
5 |
+
"""
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
from tqdm.auto import tqdm
|
10 |
+
from collections import namedtuple
|
11 |
+
|
12 |
+
DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
|
13 |
+
TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
|
14 |
+
|
15 |
+
|
16 |
+
def conv(n_in, n_out, **kwargs):
|
17 |
+
return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
|
18 |
+
|
19 |
+
|
20 |
+
class Clamp(nn.Module):
|
21 |
+
def forward(self, x):
|
22 |
+
return torch.tanh(x / 3) * 3
|
23 |
+
|
24 |
+
|
25 |
+
class MemBlock(nn.Module):
|
26 |
+
def __init__(self, n_in, n_out):
|
27 |
+
super().__init__()
|
28 |
+
self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
|
29 |
+
conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
|
30 |
+
self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
|
31 |
+
self.act = nn.ReLU(inplace=True)
|
32 |
+
|
33 |
+
def forward(self, x, past):
|
34 |
+
return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
|
35 |
+
|
36 |
+
|
37 |
+
class TPool(nn.Module):
|
38 |
+
def __init__(self, n_f, stride):
|
39 |
+
super().__init__()
|
40 |
+
self.stride = stride
|
41 |
+
self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
|
42 |
+
|
43 |
+
def forward(self, x):
|
44 |
+
_NT, C, H, W = x.shape
|
45 |
+
return self.conv(x.reshape(-1, self.stride * C, H, W))
|
46 |
+
|
47 |
+
|
48 |
+
class TGrow(nn.Module):
|
49 |
+
def __init__(self, n_f, stride):
|
50 |
+
super().__init__()
|
51 |
+
self.stride = stride
|
52 |
+
self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
|
53 |
+
|
54 |
+
def forward(self, x):
|
55 |
+
_NT, C, H, W = x.shape
|
56 |
+
x = self.conv(x)
|
57 |
+
return x.reshape(-1, C, H, W)
|
58 |
+
|
59 |
+
|
60 |
+
def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
|
61 |
+
"""
|
62 |
+
Apply a sequential model with memblocks to the given input.
|
63 |
+
Args:
|
64 |
+
- model: nn.Sequential of blocks to apply
|
65 |
+
- x: input data, of dimensions NTCHW
|
66 |
+
- parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
|
67 |
+
if False, each timestep will be processed sequentially (slow but uses O(1) memory)
|
68 |
+
- show_progress_bar: if True, enables tqdm progressbar display
|
69 |
+
|
70 |
+
Returns NTCHW tensor of output data.
|
71 |
+
"""
|
72 |
+
assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
|
73 |
+
N, T, C, H, W = x.shape
|
74 |
+
if parallel:
|
75 |
+
x = x.reshape(N * T, C, H, W)
|
76 |
+
# parallel over input timesteps, iterate over blocks
|
77 |
+
for b in tqdm(model, disable=not show_progress_bar):
|
78 |
+
if isinstance(b, MemBlock):
|
79 |
+
NT, C, H, W = x.shape
|
80 |
+
T = NT // N
|
81 |
+
_x = x.reshape(N, T, C, H, W)
|
82 |
+
mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
|
83 |
+
x = b(x, mem)
|
84 |
+
else:
|
85 |
+
x = b(x)
|
86 |
+
NT, C, H, W = x.shape
|
87 |
+
T = NT // N
|
88 |
+
x = x.view(N, T, C, H, W)
|
89 |
+
else:
|
90 |
+
# TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
|
91 |
+
# need to fix :(
|
92 |
+
out = []
|
93 |
+
# iterate over input timesteps and also iterate over blocks.
|
94 |
+
# because of the cursed TPool/TGrow blocks, this is not a nested loop,
|
95 |
+
# it's actually a ***graph traversal*** problem! so let's make a queue
|
96 |
+
work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
|
97 |
+
# in addition to manually managing our queue, we also need to manually manage our progressbar.
|
98 |
+
# we'll update it for every source node that we consume.
|
99 |
+
progress_bar = tqdm(range(T), disable=not show_progress_bar)
|
100 |
+
# we'll also need a separate addressable memory per node as well
|
101 |
+
mem = [None] * len(model)
|
102 |
+
while work_queue:
|
103 |
+
xt, i = work_queue.pop(0)
|
104 |
+
if i == 0:
|
105 |
+
# new source node consumed
|
106 |
+
progress_bar.update(1)
|
107 |
+
if i == len(model):
|
108 |
+
# reached end of the graph, append result to output list
|
109 |
+
out.append(xt)
|
110 |
+
else:
|
111 |
+
# fetch the block to process
|
112 |
+
b = model[i]
|
113 |
+
if isinstance(b, MemBlock):
|
114 |
+
# mem blocks are simple since we're visiting the graph in causal order
|
115 |
+
if mem[i] is None:
|
116 |
+
xt_new = b(xt, xt * 0)
|
117 |
+
mem[i] = xt
|
118 |
+
else:
|
119 |
+
xt_new = b(xt, mem[i])
|
120 |
+
mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
|
121 |
+
# add successor to work queue
|
122 |
+
work_queue.insert(0, TWorkItem(xt_new, i + 1))
|
123 |
+
elif isinstance(b, TPool):
|
124 |
+
# pool blocks are miserable
|
125 |
+
if mem[i] is None:
|
126 |
+
mem[i] = [] # pool memory is itself a queue of inputs to pool
|
127 |
+
mem[i].append(xt)
|
128 |
+
if len(mem[i]) > b.stride:
|
129 |
+
# pool mem is in invalid state, we should have pooled before this
|
130 |
+
raise ValueError("???")
|
131 |
+
elif len(mem[i]) < b.stride:
|
132 |
+
# pool mem is not yet full, go back to processing the work queue
|
133 |
+
pass
|
134 |
+
else:
|
135 |
+
# pool mem is ready, run the pool block
|
136 |
+
N, C, H, W = xt.shape
|
137 |
+
xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
|
138 |
+
# reset the pool mem
|
139 |
+
mem[i] = []
|
140 |
+
# add successor to work queue
|
141 |
+
work_queue.insert(0, TWorkItem(xt, i + 1))
|
142 |
+
elif isinstance(b, TGrow):
|
143 |
+
xt = b(xt)
|
144 |
+
NT, C, H, W = xt.shape
|
145 |
+
# each tgrow has multiple successor nodes
|
146 |
+
for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
|
147 |
+
# add successor to work queue
|
148 |
+
work_queue.insert(0, TWorkItem(xt_next, i + 1))
|
149 |
+
else:
|
150 |
+
# normal block with no funny business
|
151 |
+
xt = b(xt)
|
152 |
+
# add successor to work queue
|
153 |
+
work_queue.insert(0, TWorkItem(xt, i + 1))
|
154 |
+
progress_bar.close()
|
155 |
+
x = torch.stack(out, 1)
|
156 |
+
return x
|
157 |
+
|
158 |
+
|
159 |
+
class TAEHV(nn.Module):
|
160 |
+
latent_channels = 16
|
161 |
+
image_channels = 3
|
162 |
+
|
163 |
+
def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
|
164 |
+
"""Initialize pretrained TAEHV from the given checkpoint.
|
165 |
+
|
166 |
+
Arg:
|
167 |
+
checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
|
168 |
+
decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
|
169 |
+
decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
|
170 |
+
"""
|
171 |
+
super().__init__()
|
172 |
+
self.encoder = nn.Sequential(
|
173 |
+
conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
|
174 |
+
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
|
175 |
+
TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
|
176 |
+
TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
|
177 |
+
conv(64, TAEHV.latent_channels),
|
178 |
+
)
|
179 |
+
n_f = [256, 128, 64, 64]
|
180 |
+
self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
|
181 |
+
self.decoder = nn.Sequential(
|
182 |
+
Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
|
183 |
+
MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(
|
184 |
+
scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
|
185 |
+
MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(
|
186 |
+
scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
|
187 |
+
MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(
|
188 |
+
scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
|
189 |
+
nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
|
190 |
+
)
|
191 |
+
if checkpoint_path is not None:
|
192 |
+
self.load_state_dict(self.patch_tgrow_layers(torch.load(
|
193 |
+
checkpoint_path, map_location="cpu", weights_only=True)))
|
194 |
+
|
195 |
+
def patch_tgrow_layers(self, sd):
|
196 |
+
"""Patch TGrow layers to use a smaller kernel if needed.
|
197 |
+
|
198 |
+
Args:
|
199 |
+
sd: state dict to patch
|
200 |
+
"""
|
201 |
+
new_sd = self.state_dict()
|
202 |
+
for i, layer in enumerate(self.decoder):
|
203 |
+
if isinstance(layer, TGrow):
|
204 |
+
key = f"decoder.{i}.conv.weight"
|
205 |
+
if sd[key].shape[0] > new_sd[key].shape[0]:
|
206 |
+
# take the last-timestep output channels
|
207 |
+
sd[key] = sd[key][-new_sd[key].shape[0]:]
|
208 |
+
return sd
|
209 |
+
|
210 |
+
def encode_video(self, x, parallel=True, show_progress_bar=True):
|
211 |
+
"""Encode a sequence of frames.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
x: input NTCHW RGB (C=3) tensor with values in [0, 1].
|
215 |
+
parallel: if True, all frames will be processed at once.
|
216 |
+
(this is faster but may require more memory).
|
217 |
+
if False, frames will be processed sequentially.
|
218 |
+
Returns NTCHW latent tensor with ~Gaussian values.
|
219 |
+
"""
|
220 |
+
return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
|
221 |
+
|
222 |
+
def decode_video(self, x, parallel=True, show_progress_bar=False):
|
223 |
+
"""Decode a sequence of frames.
|
224 |
+
|
225 |
+
Args:
|
226 |
+
x: input NTCHW latent (C=12) tensor with ~Gaussian values.
|
227 |
+
parallel: if True, all frames will be processed at once.
|
228 |
+
(this is faster but may require more memory).
|
229 |
+
if False, frames will be processed sequentially.
|
230 |
+
Returns NTCHW RGB tensor with ~[0, 1] values.
|
231 |
+
"""
|
232 |
+
x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
|
233 |
+
# return x[:, self.frames_to_trim:]
|
234 |
+
return x
|
235 |
+
|
236 |
+
def forward(self, x):
|
237 |
+
return self.c(x)
|
238 |
+
|
239 |
+
|
240 |
+
@torch.no_grad()
|
241 |
+
def main():
|
242 |
+
"""Run TAEHV roundtrip reconstruction on the given video paths."""
|
243 |
+
import os
|
244 |
+
import sys
|
245 |
+
import cv2 # no highly esteemed deed is commemorated here
|
246 |
+
|
247 |
+
class VideoTensorReader:
|
248 |
+
def __init__(self, video_file_path):
|
249 |
+
self.cap = cv2.VideoCapture(video_file_path)
|
250 |
+
assert self.cap.isOpened(), f"Could not load {video_file_path}"
|
251 |
+
self.fps = self.cap.get(cv2.CAP_PROP_FPS)
|
252 |
+
|
253 |
+
def __iter__(self):
|
254 |
+
return self
|
255 |
+
|
256 |
+
def __next__(self):
|
257 |
+
ret, frame = self.cap.read()
|
258 |
+
if not ret:
|
259 |
+
self.cap.release()
|
260 |
+
raise StopIteration # End of video or error
|
261 |
+
return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
|
262 |
+
|
263 |
+
class VideoTensorWriter:
|
264 |
+
def __init__(self, video_file_path, width_height, fps=30):
|
265 |
+
self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
|
266 |
+
assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
|
267 |
+
|
268 |
+
def write(self, frame_tensor):
|
269 |
+
assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
|
270 |
+
self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
|
271 |
+
cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
|
272 |
+
|
273 |
+
def __del__(self):
|
274 |
+
if hasattr(self, 'writer'):
|
275 |
+
self.writer.release()
|
276 |
+
|
277 |
+
dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
|
278 |
+
dtype = torch.float16
|
279 |
+
checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
|
280 |
+
checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
|
281 |
+
print(
|
282 |
+
f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
|
283 |
+
taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
|
284 |
+
for video_path in sys.argv[1:]:
|
285 |
+
print(f"Processing {video_path}...")
|
286 |
+
video_in = VideoTensorReader(video_path)
|
287 |
+
video = torch.stack(list(video_in), 0)[None]
|
288 |
+
vid_dev = video.to(dev, dtype).div_(255.0)
|
289 |
+
# convert to device tensor
|
290 |
+
if video.numel() < 100_000_000:
|
291 |
+
print(f" {video_path} seems small enough, will process all frames in parallel")
|
292 |
+
# convert to device tensor
|
293 |
+
vid_enc = taehv.encode_video(vid_dev)
|
294 |
+
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
|
295 |
+
vid_dec = taehv.decode_video(vid_enc)
|
296 |
+
print(f" Decoded {video_path} -> {vid_dec.shape}")
|
297 |
+
else:
|
298 |
+
print(f" {video_path} seems large, will process each frame sequentially")
|
299 |
+
# convert to device tensor
|
300 |
+
vid_enc = taehv.encode_video(vid_dev, parallel=False)
|
301 |
+
print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
|
302 |
+
vid_dec = taehv.decode_video(vid_enc, parallel=False)
|
303 |
+
print(f" Decoded {video_path} -> {vid_dec.shape}")
|
304 |
+
video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
|
305 |
+
video_out = VideoTensorWriter(
|
306 |
+
video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
|
307 |
+
for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
|
308 |
+
video_out.write(frame)
|
309 |
+
print(f" Saved to {video_out_path}")
|
310 |
+
|
311 |
+
|
312 |
+
if __name__ == "__main__":
|
313 |
+
main()
|
demo_utils/utils.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
|
2 |
+
# Apache-2.0 License
|
3 |
+
# By lllyasviel
|
4 |
+
|
5 |
+
import os
|
6 |
+
import cv2
|
7 |
+
import json
|
8 |
+
import random
|
9 |
+
import glob
|
10 |
+
import torch
|
11 |
+
import einops
|
12 |
+
import numpy as np
|
13 |
+
import datetime
|
14 |
+
import torchvision
|
15 |
+
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def min_resize(x, m):
|
20 |
+
if x.shape[0] < x.shape[1]:
|
21 |
+
s0 = m
|
22 |
+
s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
|
23 |
+
else:
|
24 |
+
s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
|
25 |
+
s1 = m
|
26 |
+
new_max = max(s1, s0)
|
27 |
+
raw_max = max(x.shape[0], x.shape[1])
|
28 |
+
if new_max < raw_max:
|
29 |
+
interpolation = cv2.INTER_AREA
|
30 |
+
else:
|
31 |
+
interpolation = cv2.INTER_LANCZOS4
|
32 |
+
y = cv2.resize(x, (s1, s0), interpolation=interpolation)
|
33 |
+
return y
|
34 |
+
|
35 |
+
|
36 |
+
def d_resize(x, y):
|
37 |
+
H, W, C = y.shape
|
38 |
+
new_min = min(H, W)
|
39 |
+
raw_min = min(x.shape[0], x.shape[1])
|
40 |
+
if new_min < raw_min:
|
41 |
+
interpolation = cv2.INTER_AREA
|
42 |
+
else:
|
43 |
+
interpolation = cv2.INTER_LANCZOS4
|
44 |
+
y = cv2.resize(x, (W, H), interpolation=interpolation)
|
45 |
+
return y
|
46 |
+
|
47 |
+
|
48 |
+
def resize_and_center_crop(image, target_width, target_height):
|
49 |
+
if target_height == image.shape[0] and target_width == image.shape[1]:
|
50 |
+
return image
|
51 |
+
|
52 |
+
pil_image = Image.fromarray(image)
|
53 |
+
original_width, original_height = pil_image.size
|
54 |
+
scale_factor = max(target_width / original_width, target_height / original_height)
|
55 |
+
resized_width = int(round(original_width * scale_factor))
|
56 |
+
resized_height = int(round(original_height * scale_factor))
|
57 |
+
resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
|
58 |
+
left = (resized_width - target_width) / 2
|
59 |
+
top = (resized_height - target_height) / 2
|
60 |
+
right = (resized_width + target_width) / 2
|
61 |
+
bottom = (resized_height + target_height) / 2
|
62 |
+
cropped_image = resized_image.crop((left, top, right, bottom))
|
63 |
+
return np.array(cropped_image)
|
64 |
+
|
65 |
+
|
66 |
+
def resize_and_center_crop_pytorch(image, target_width, target_height):
|
67 |
+
B, C, H, W = image.shape
|
68 |
+
|
69 |
+
if H == target_height and W == target_width:
|
70 |
+
return image
|
71 |
+
|
72 |
+
scale_factor = max(target_width / W, target_height / H)
|
73 |
+
resized_width = int(round(W * scale_factor))
|
74 |
+
resized_height = int(round(H * scale_factor))
|
75 |
+
|
76 |
+
resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
|
77 |
+
|
78 |
+
top = (resized_height - target_height) // 2
|
79 |
+
left = (resized_width - target_width) // 2
|
80 |
+
cropped = resized[:, :, top:top + target_height, left:left + target_width]
|
81 |
+
|
82 |
+
return cropped
|
83 |
+
|
84 |
+
|
85 |
+
def resize_without_crop(image, target_width, target_height):
|
86 |
+
if target_height == image.shape[0] and target_width == image.shape[1]:
|
87 |
+
return image
|
88 |
+
|
89 |
+
pil_image = Image.fromarray(image)
|
90 |
+
resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
|
91 |
+
return np.array(resized_image)
|
92 |
+
|
93 |
+
|
94 |
+
def just_crop(image, w, h):
|
95 |
+
if h == image.shape[0] and w == image.shape[1]:
|
96 |
+
return image
|
97 |
+
|
98 |
+
original_height, original_width = image.shape[:2]
|
99 |
+
k = min(original_height / h, original_width / w)
|
100 |
+
new_width = int(round(w * k))
|
101 |
+
new_height = int(round(h * k))
|
102 |
+
x_start = (original_width - new_width) // 2
|
103 |
+
y_start = (original_height - new_height) // 2
|
104 |
+
cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
|
105 |
+
return cropped_image
|
106 |
+
|
107 |
+
|
108 |
+
def write_to_json(data, file_path):
|
109 |
+
temp_file_path = file_path + ".tmp"
|
110 |
+
with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
|
111 |
+
json.dump(data, temp_file, indent=4)
|
112 |
+
os.replace(temp_file_path, file_path)
|
113 |
+
return
|
114 |
+
|
115 |
+
|
116 |
+
def read_from_json(file_path):
|
117 |
+
with open(file_path, 'rt', encoding='utf-8') as file:
|
118 |
+
data = json.load(file)
|
119 |
+
return data
|
120 |
+
|
121 |
+
|
122 |
+
def get_active_parameters(m):
|
123 |
+
return {k: v for k, v in m.named_parameters() if v.requires_grad}
|
124 |
+
|
125 |
+
|
126 |
+
def cast_training_params(m, dtype=torch.float32):
|
127 |
+
result = {}
|
128 |
+
for n, param in m.named_parameters():
|
129 |
+
if param.requires_grad:
|
130 |
+
param.data = param.to(dtype)
|
131 |
+
result[n] = param
|
132 |
+
return result
|
133 |
+
|
134 |
+
|
135 |
+
def separate_lora_AB(parameters, B_patterns=None):
|
136 |
+
parameters_normal = {}
|
137 |
+
parameters_B = {}
|
138 |
+
|
139 |
+
if B_patterns is None:
|
140 |
+
B_patterns = ['.lora_B.', '__zero__']
|
141 |
+
|
142 |
+
for k, v in parameters.items():
|
143 |
+
if any(B_pattern in k for B_pattern in B_patterns):
|
144 |
+
parameters_B[k] = v
|
145 |
+
else:
|
146 |
+
parameters_normal[k] = v
|
147 |
+
|
148 |
+
return parameters_normal, parameters_B
|
149 |
+
|
150 |
+
|
151 |
+
def set_attr_recursive(obj, attr, value):
|
152 |
+
attrs = attr.split(".")
|
153 |
+
for name in attrs[:-1]:
|
154 |
+
obj = getattr(obj, name)
|
155 |
+
setattr(obj, attrs[-1], value)
|
156 |
+
return
|
157 |
+
|
158 |
+
|
159 |
+
def print_tensor_list_size(tensors):
|
160 |
+
total_size = 0
|
161 |
+
total_elements = 0
|
162 |
+
|
163 |
+
if isinstance(tensors, dict):
|
164 |
+
tensors = tensors.values()
|
165 |
+
|
166 |
+
for tensor in tensors:
|
167 |
+
total_size += tensor.nelement() * tensor.element_size()
|
168 |
+
total_elements += tensor.nelement()
|
169 |
+
|
170 |
+
total_size_MB = total_size / (1024 ** 2)
|
171 |
+
total_elements_B = total_elements / 1e9
|
172 |
+
|
173 |
+
print(f"Total number of tensors: {len(tensors)}")
|
174 |
+
print(f"Total size of tensors: {total_size_MB:.2f} MB")
|
175 |
+
print(f"Total number of parameters: {total_elements_B:.3f} billion")
|
176 |
+
return
|
177 |
+
|
178 |
+
|
179 |
+
@torch.no_grad()
|
180 |
+
def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
|
181 |
+
batch_size = a.size(0)
|
182 |
+
|
183 |
+
if b is None:
|
184 |
+
b = torch.zeros_like(a)
|
185 |
+
|
186 |
+
if mask_a is None:
|
187 |
+
mask_a = torch.rand(batch_size) < probability_a
|
188 |
+
|
189 |
+
mask_a = mask_a.to(a.device)
|
190 |
+
mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
|
191 |
+
result = torch.where(mask_a, a, b)
|
192 |
+
return result
|
193 |
+
|
194 |
+
|
195 |
+
@torch.no_grad()
|
196 |
+
def zero_module(module):
|
197 |
+
for p in module.parameters():
|
198 |
+
p.detach().zero_()
|
199 |
+
return module
|
200 |
+
|
201 |
+
|
202 |
+
@torch.no_grad()
|
203 |
+
def supress_lower_channels(m, k, alpha=0.01):
|
204 |
+
data = m.weight.data.clone()
|
205 |
+
|
206 |
+
assert int(data.shape[1]) >= k
|
207 |
+
|
208 |
+
data[:, :k] = data[:, :k] * alpha
|
209 |
+
m.weight.data = data.contiguous().clone()
|
210 |
+
return m
|
211 |
+
|
212 |
+
|
213 |
+
def freeze_module(m):
|
214 |
+
if not hasattr(m, '_forward_inside_frozen_module'):
|
215 |
+
m._forward_inside_frozen_module = m.forward
|
216 |
+
m.requires_grad_(False)
|
217 |
+
m.forward = torch.no_grad()(m.forward)
|
218 |
+
return m
|
219 |
+
|
220 |
+
|
221 |
+
def get_latest_safetensors(folder_path):
|
222 |
+
safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
|
223 |
+
|
224 |
+
if not safetensors_files:
|
225 |
+
raise ValueError('No file to resume!')
|
226 |
+
|
227 |
+
latest_file = max(safetensors_files, key=os.path.getmtime)
|
228 |
+
latest_file = os.path.abspath(os.path.realpath(latest_file))
|
229 |
+
return latest_file
|
230 |
+
|
231 |
+
|
232 |
+
def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
|
233 |
+
tags = tags_str.split(', ')
|
234 |
+
tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
|
235 |
+
prompt = ', '.join(tags)
|
236 |
+
return prompt
|
237 |
+
|
238 |
+
|
239 |
+
def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
|
240 |
+
numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
|
241 |
+
if round_to_int:
|
242 |
+
numbers = np.round(numbers).astype(int)
|
243 |
+
return numbers.tolist()
|
244 |
+
|
245 |
+
|
246 |
+
def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
|
247 |
+
edges = np.linspace(0, 1, n + 1)
|
248 |
+
points = np.random.uniform(edges[:-1], edges[1:])
|
249 |
+
numbers = inclusive + (exclusive - inclusive) * points
|
250 |
+
if round_to_int:
|
251 |
+
numbers = np.round(numbers).astype(int)
|
252 |
+
return numbers.tolist()
|
253 |
+
|
254 |
+
|
255 |
+
def soft_append_bcthw(history, current, overlap=0):
|
256 |
+
if overlap <= 0:
|
257 |
+
return torch.cat([history, current], dim=2)
|
258 |
+
|
259 |
+
assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
|
260 |
+
assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
|
261 |
+
|
262 |
+
weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
|
263 |
+
blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
|
264 |
+
output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
|
265 |
+
|
266 |
+
return output.to(history)
|
267 |
+
|
268 |
+
|
269 |
+
def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
|
270 |
+
b, c, t, h, w = x.shape
|
271 |
+
|
272 |
+
per_row = b
|
273 |
+
for p in [6, 5, 4, 3, 2]:
|
274 |
+
if b % p == 0:
|
275 |
+
per_row = p
|
276 |
+
break
|
277 |
+
|
278 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
279 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
280 |
+
x = x.detach().cpu().to(torch.uint8)
|
281 |
+
x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
|
282 |
+
torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
|
283 |
+
return x
|
284 |
+
|
285 |
+
|
286 |
+
def save_bcthw_as_png(x, output_filename):
|
287 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
288 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
289 |
+
x = x.detach().cpu().to(torch.uint8)
|
290 |
+
x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
|
291 |
+
torchvision.io.write_png(x, output_filename)
|
292 |
+
return output_filename
|
293 |
+
|
294 |
+
|
295 |
+
def save_bchw_as_png(x, output_filename):
|
296 |
+
os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
|
297 |
+
x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
|
298 |
+
x = x.detach().cpu().to(torch.uint8)
|
299 |
+
x = einops.rearrange(x, 'b c h w -> c h (b w)')
|
300 |
+
torchvision.io.write_png(x, output_filename)
|
301 |
+
return output_filename
|
302 |
+
|
303 |
+
|
304 |
+
def add_tensors_with_padding(tensor1, tensor2):
|
305 |
+
if tensor1.shape == tensor2.shape:
|
306 |
+
return tensor1 + tensor2
|
307 |
+
|
308 |
+
shape1 = tensor1.shape
|
309 |
+
shape2 = tensor2.shape
|
310 |
+
|
311 |
+
new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
|
312 |
+
|
313 |
+
padded_tensor1 = torch.zeros(new_shape)
|
314 |
+
padded_tensor2 = torch.zeros(new_shape)
|
315 |
+
|
316 |
+
padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
|
317 |
+
padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
|
318 |
+
|
319 |
+
result = padded_tensor1 + padded_tensor2
|
320 |
+
return result
|
321 |
+
|
322 |
+
|
323 |
+
def print_free_mem():
|
324 |
+
torch.cuda.empty_cache()
|
325 |
+
free_mem, total_mem = torch.cuda.mem_get_info(0)
|
326 |
+
free_mem_mb = free_mem / (1024 ** 2)
|
327 |
+
total_mem_mb = total_mem / (1024 ** 2)
|
328 |
+
print(f"Free memory: {free_mem_mb:.2f} MB")
|
329 |
+
print(f"Total memory: {total_mem_mb:.2f} MB")
|
330 |
+
return
|
331 |
+
|
332 |
+
|
333 |
+
def print_gpu_parameters(device, state_dict, log_count=1):
|
334 |
+
summary = {"device": device, "keys_count": len(state_dict)}
|
335 |
+
|
336 |
+
logged_params = {}
|
337 |
+
for i, (key, tensor) in enumerate(state_dict.items()):
|
338 |
+
if i >= log_count:
|
339 |
+
break
|
340 |
+
logged_params[key] = tensor.flatten()[:3].tolist()
|
341 |
+
|
342 |
+
summary["params"] = logged_params
|
343 |
+
|
344 |
+
print(str(summary))
|
345 |
+
return
|
346 |
+
|
347 |
+
|
348 |
+
def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
|
349 |
+
from PIL import Image, ImageDraw, ImageFont
|
350 |
+
|
351 |
+
txt = Image.new("RGB", (width, height), color="white")
|
352 |
+
draw = ImageDraw.Draw(txt)
|
353 |
+
font = ImageFont.truetype(font_path, size=size)
|
354 |
+
|
355 |
+
if text == '':
|
356 |
+
return np.array(txt)
|
357 |
+
|
358 |
+
# Split text into lines that fit within the image width
|
359 |
+
lines = []
|
360 |
+
words = text.split()
|
361 |
+
current_line = words[0]
|
362 |
+
|
363 |
+
for word in words[1:]:
|
364 |
+
line_with_word = f"{current_line} {word}"
|
365 |
+
if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
|
366 |
+
current_line = line_with_word
|
367 |
+
else:
|
368 |
+
lines.append(current_line)
|
369 |
+
current_line = word
|
370 |
+
|
371 |
+
lines.append(current_line)
|
372 |
+
|
373 |
+
# Draw the text line by line
|
374 |
+
y = 0
|
375 |
+
line_height = draw.textbbox((0, 0), "A", font=font)[3]
|
376 |
+
|
377 |
+
for line in lines:
|
378 |
+
if y + line_height > height:
|
379 |
+
break # stop drawing if the next line will be outside the image
|
380 |
+
draw.text((0, y), line, fill="black", font=font)
|
381 |
+
y += line_height
|
382 |
+
|
383 |
+
return np.array(txt)
|
384 |
+
|
385 |
+
|
386 |
+
def blue_mark(x):
|
387 |
+
x = x.copy()
|
388 |
+
c = x[:, :, 2]
|
389 |
+
b = cv2.blur(c, (9, 9))
|
390 |
+
x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
|
391 |
+
return x
|
392 |
+
|
393 |
+
|
394 |
+
def green_mark(x):
|
395 |
+
x = x.copy()
|
396 |
+
x[:, :, 2] = -1
|
397 |
+
x[:, :, 0] = -1
|
398 |
+
return x
|
399 |
+
|
400 |
+
|
401 |
+
def frame_mark(x):
|
402 |
+
x = x.copy()
|
403 |
+
x[:64] = -1
|
404 |
+
x[-64:] = -1
|
405 |
+
x[:, :8] = 1
|
406 |
+
x[:, -8:] = 1
|
407 |
+
return x
|
408 |
+
|
409 |
+
|
410 |
+
@torch.inference_mode()
|
411 |
+
def pytorch2numpy(imgs):
|
412 |
+
results = []
|
413 |
+
for x in imgs:
|
414 |
+
y = x.movedim(0, -1)
|
415 |
+
y = y * 127.5 + 127.5
|
416 |
+
y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
|
417 |
+
results.append(y)
|
418 |
+
return results
|
419 |
+
|
420 |
+
|
421 |
+
@torch.inference_mode()
|
422 |
+
def numpy2pytorch(imgs):
|
423 |
+
h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
|
424 |
+
h = h.movedim(-1, 1)
|
425 |
+
return h
|
426 |
+
|
427 |
+
|
428 |
+
@torch.no_grad()
|
429 |
+
def duplicate_prefix_to_suffix(x, count, zero_out=False):
|
430 |
+
if zero_out:
|
431 |
+
return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
|
432 |
+
else:
|
433 |
+
return torch.cat([x, x[:count]], dim=0)
|
434 |
+
|
435 |
+
|
436 |
+
def weighted_mse(a, b, weight):
|
437 |
+
return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
|
438 |
+
|
439 |
+
|
440 |
+
def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
|
441 |
+
x = (x - x_min) / (x_max - x_min)
|
442 |
+
x = max(0.0, min(x, 1.0))
|
443 |
+
x = x ** sigma
|
444 |
+
return y_min + x * (y_max - y_min)
|
445 |
+
|
446 |
+
|
447 |
+
def expand_to_dims(x, target_dims):
|
448 |
+
return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
|
449 |
+
|
450 |
+
|
451 |
+
def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
|
452 |
+
if tensor is None:
|
453 |
+
return None
|
454 |
+
|
455 |
+
first_dim = tensor.shape[0]
|
456 |
+
|
457 |
+
if first_dim == batch_size:
|
458 |
+
return tensor
|
459 |
+
|
460 |
+
if batch_size % first_dim != 0:
|
461 |
+
raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
|
462 |
+
|
463 |
+
repeat_times = batch_size // first_dim
|
464 |
+
|
465 |
+
return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
|
466 |
+
|
467 |
+
|
468 |
+
def dim5(x):
|
469 |
+
return expand_to_dims(x, 5)
|
470 |
+
|
471 |
+
|
472 |
+
def dim4(x):
|
473 |
+
return expand_to_dims(x, 4)
|
474 |
+
|
475 |
+
|
476 |
+
def dim3(x):
|
477 |
+
return expand_to_dims(x, 3)
|
478 |
+
|
479 |
+
|
480 |
+
def crop_or_pad_yield_mask(x, length):
|
481 |
+
B, F, C = x.shape
|
482 |
+
device = x.device
|
483 |
+
dtype = x.dtype
|
484 |
+
|
485 |
+
if F < length:
|
486 |
+
y = torch.zeros((B, length, C), dtype=dtype, device=device)
|
487 |
+
mask = torch.zeros((B, length), dtype=torch.bool, device=device)
|
488 |
+
y[:, :F, :] = x
|
489 |
+
mask[:, :F] = True
|
490 |
+
return y, mask
|
491 |
+
|
492 |
+
return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
|
493 |
+
|
494 |
+
|
495 |
+
def extend_dim(x, dim, minimal_length, zero_pad=False):
|
496 |
+
original_length = int(x.shape[dim])
|
497 |
+
|
498 |
+
if original_length >= minimal_length:
|
499 |
+
return x
|
500 |
+
|
501 |
+
if zero_pad:
|
502 |
+
padding_shape = list(x.shape)
|
503 |
+
padding_shape[dim] = minimal_length - original_length
|
504 |
+
padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
|
505 |
+
else:
|
506 |
+
idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
|
507 |
+
last_element = x[idx]
|
508 |
+
padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
|
509 |
+
|
510 |
+
return torch.cat([x, padding], dim=dim)
|
511 |
+
|
512 |
+
|
513 |
+
def lazy_positional_encoding(t, repeats=None):
|
514 |
+
if not isinstance(t, list):
|
515 |
+
t = [t]
|
516 |
+
|
517 |
+
from diffusers.models.embeddings import get_timestep_embedding
|
518 |
+
|
519 |
+
te = torch.tensor(t)
|
520 |
+
te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
|
521 |
+
|
522 |
+
if repeats is None:
|
523 |
+
return te
|
524 |
+
|
525 |
+
te = te[:, None, :].expand(-1, repeats, -1)
|
526 |
+
|
527 |
+
return te
|
528 |
+
|
529 |
+
|
530 |
+
def state_dict_offset_merge(A, B, C=None):
|
531 |
+
result = {}
|
532 |
+
keys = A.keys()
|
533 |
+
|
534 |
+
for key in keys:
|
535 |
+
A_value = A[key]
|
536 |
+
B_value = B[key].to(A_value)
|
537 |
+
|
538 |
+
if C is None:
|
539 |
+
result[key] = A_value + B_value
|
540 |
+
else:
|
541 |
+
C_value = C[key].to(A_value)
|
542 |
+
result[key] = A_value + B_value - C_value
|
543 |
+
|
544 |
+
return result
|
545 |
+
|
546 |
+
|
547 |
+
def state_dict_weighted_merge(state_dicts, weights):
|
548 |
+
if len(state_dicts) != len(weights):
|
549 |
+
raise ValueError("Number of state dictionaries must match number of weights")
|
550 |
+
|
551 |
+
if not state_dicts:
|
552 |
+
return {}
|
553 |
+
|
554 |
+
total_weight = sum(weights)
|
555 |
+
|
556 |
+
if total_weight == 0:
|
557 |
+
raise ValueError("Sum of weights cannot be zero")
|
558 |
+
|
559 |
+
normalized_weights = [w / total_weight for w in weights]
|
560 |
+
|
561 |
+
keys = state_dicts[0].keys()
|
562 |
+
result = {}
|
563 |
+
|
564 |
+
for key in keys:
|
565 |
+
result[key] = state_dicts[0][key] * normalized_weights[0]
|
566 |
+
|
567 |
+
for i in range(1, len(state_dicts)):
|
568 |
+
state_dict_value = state_dicts[i][key].to(result[key])
|
569 |
+
result[key] += state_dict_value * normalized_weights[i]
|
570 |
+
|
571 |
+
return result
|
572 |
+
|
573 |
+
|
574 |
+
def group_files_by_folder(all_files):
|
575 |
+
grouped_files = {}
|
576 |
+
|
577 |
+
for file in all_files:
|
578 |
+
folder_name = os.path.basename(os.path.dirname(file))
|
579 |
+
if folder_name not in grouped_files:
|
580 |
+
grouped_files[folder_name] = []
|
581 |
+
grouped_files[folder_name].append(file)
|
582 |
+
|
583 |
+
list_of_lists = list(grouped_files.values())
|
584 |
+
return list_of_lists
|
585 |
+
|
586 |
+
|
587 |
+
def generate_timestamp():
|
588 |
+
now = datetime.datetime.now()
|
589 |
+
timestamp = now.strftime('%y%m%d_%H%M%S')
|
590 |
+
milliseconds = f"{int(now.microsecond / 1000):03d}"
|
591 |
+
random_number = random.randint(0, 9999)
|
592 |
+
return f"{timestamp}_{milliseconds}_{random_number}"
|
593 |
+
|
594 |
+
|
595 |
+
def write_PIL_image_with_png_info(image, metadata, path):
|
596 |
+
from PIL.PngImagePlugin import PngInfo
|
597 |
+
|
598 |
+
png_info = PngInfo()
|
599 |
+
for key, value in metadata.items():
|
600 |
+
png_info.add_text(key, value)
|
601 |
+
|
602 |
+
image.save(path, "PNG", pnginfo=png_info)
|
603 |
+
return image
|
604 |
+
|
605 |
+
|
606 |
+
def torch_safe_save(content, path):
|
607 |
+
torch.save(content, path + '_tmp')
|
608 |
+
os.replace(path + '_tmp', path)
|
609 |
+
return path
|
610 |
+
|
611 |
+
|
612 |
+
def move_optimizer_to_device(optimizer, device):
|
613 |
+
for state in optimizer.state.values():
|
614 |
+
for k, v in state.items():
|
615 |
+
if isinstance(v, torch.Tensor):
|
616 |
+
state[k] = v.to(device)
|
demo_utils/vae.py
ADDED
@@ -0,0 +1,390 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from einops import rearrange
|
3 |
+
import tensorrt as trt
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
|
8 |
+
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample
|
9 |
+
|
10 |
+
CACHE_T = 2
|
11 |
+
|
12 |
+
|
13 |
+
class ResidualBlock(nn.Module):
|
14 |
+
|
15 |
+
def __init__(self, in_dim, out_dim, dropout=0.0):
|
16 |
+
super().__init__()
|
17 |
+
self.in_dim = in_dim
|
18 |
+
self.out_dim = out_dim
|
19 |
+
|
20 |
+
# layers
|
21 |
+
self.residual = nn.Sequential(
|
22 |
+
RMS_norm(in_dim, images=False), nn.SiLU(),
|
23 |
+
CausalConv3d(in_dim, out_dim, 3, padding=1),
|
24 |
+
RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
|
25 |
+
CausalConv3d(out_dim, out_dim, 3, padding=1))
|
26 |
+
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
27 |
+
if in_dim != out_dim else nn.Identity()
|
28 |
+
|
29 |
+
def forward(self, x, feat_cache_1, feat_cache_2):
|
30 |
+
h = self.shortcut(x)
|
31 |
+
feat_cache = feat_cache_1
|
32 |
+
out_feat_cache = []
|
33 |
+
for layer in self.residual:
|
34 |
+
if isinstance(layer, CausalConv3d):
|
35 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
36 |
+
if cache_x.shape[2] < 2 and feat_cache is not None:
|
37 |
+
# cache last frame of last two chunk
|
38 |
+
cache_x = torch.cat([
|
39 |
+
feat_cache[:, :, -1, :, :].unsqueeze(2).to(
|
40 |
+
cache_x.device), cache_x
|
41 |
+
],
|
42 |
+
dim=2)
|
43 |
+
x = layer(x, feat_cache)
|
44 |
+
out_feat_cache.append(cache_x)
|
45 |
+
feat_cache = feat_cache_2
|
46 |
+
else:
|
47 |
+
x = layer(x)
|
48 |
+
return x + h, *out_feat_cache
|
49 |
+
|
50 |
+
|
51 |
+
class Resample(nn.Module):
|
52 |
+
|
53 |
+
def __init__(self, dim, mode):
|
54 |
+
assert mode in ('none', 'upsample2d', 'upsample3d')
|
55 |
+
super().__init__()
|
56 |
+
self.dim = dim
|
57 |
+
self.mode = mode
|
58 |
+
|
59 |
+
# layers
|
60 |
+
if mode == 'upsample2d':
|
61 |
+
self.resample = nn.Sequential(
|
62 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
63 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
64 |
+
elif mode == 'upsample3d':
|
65 |
+
self.resample = nn.Sequential(
|
66 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
67 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
68 |
+
self.time_conv = CausalConv3d(
|
69 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
70 |
+
else:
|
71 |
+
self.resample = nn.Identity()
|
72 |
+
|
73 |
+
def forward(self, x, is_first_frame, feat_cache):
|
74 |
+
if self.mode == 'upsample3d':
|
75 |
+
b, c, t, h, w = x.size()
|
76 |
+
# x, out_feat_cache = torch.cond(
|
77 |
+
# is_first_frame,
|
78 |
+
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
|
79 |
+
# lambda: self.temporal_conv(x, feat_cache),
|
80 |
+
# )
|
81 |
+
# x, out_feat_cache = torch.cond(
|
82 |
+
# is_first_frame,
|
83 |
+
# lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
|
84 |
+
# lambda: self.temporal_conv(x, feat_cache),
|
85 |
+
# )
|
86 |
+
x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
|
87 |
+
out_feat_cache = torch.cond(
|
88 |
+
is_first_frame,
|
89 |
+
lambda: feat_cache.clone().contiguous(),
|
90 |
+
lambda: out_feat_cache.clone().contiguous(),
|
91 |
+
)
|
92 |
+
# if is_first_frame:
|
93 |
+
# x = torch.cat([torch.zeros_like(x), x], dim=2)
|
94 |
+
# out_feat_cache = feat_cache.clone()
|
95 |
+
# else:
|
96 |
+
# x, out_feat_cache = self.temporal_conv(x, feat_cache)
|
97 |
+
else:
|
98 |
+
out_feat_cache = None
|
99 |
+
t = x.shape[2]
|
100 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
101 |
+
x = self.resample(x)
|
102 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
103 |
+
return x, out_feat_cache
|
104 |
+
|
105 |
+
def temporal_conv(self, x, is_first_frame, feat_cache):
|
106 |
+
b, c, t, h, w = x.size()
|
107 |
+
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
108 |
+
if cache_x.shape[2] < 2 and feat_cache is not None:
|
109 |
+
cache_x = torch.cat([
|
110 |
+
torch.zeros_like(cache_x),
|
111 |
+
cache_x
|
112 |
+
], dim=2)
|
113 |
+
x = torch.cond(
|
114 |
+
is_first_frame,
|
115 |
+
lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
|
116 |
+
lambda: self.time_conv(x, feat_cache).contiguous(),
|
117 |
+
)
|
118 |
+
# x = self.time_conv(x, feat_cache)
|
119 |
+
out_feat_cache = cache_x
|
120 |
+
|
121 |
+
x = x.reshape(b, 2, c, t, h, w)
|
122 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
123 |
+
3)
|
124 |
+
x = x.reshape(b, c, t * 2, h, w)
|
125 |
+
return x.contiguous(), out_feat_cache.contiguous()
|
126 |
+
|
127 |
+
def init_weight(self, conv):
|
128 |
+
conv_weight = conv.weight
|
129 |
+
nn.init.zeros_(conv_weight)
|
130 |
+
c1, c2, t, h, w = conv_weight.size()
|
131 |
+
one_matrix = torch.eye(c1, c2)
|
132 |
+
init_matrix = one_matrix
|
133 |
+
nn.init.zeros_(conv_weight)
|
134 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
135 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
136 |
+
conv.weight.data.copy_(conv_weight)
|
137 |
+
nn.init.zeros_(conv.bias.data)
|
138 |
+
|
139 |
+
def init_weight2(self, conv):
|
140 |
+
conv_weight = conv.weight.data
|
141 |
+
nn.init.zeros_(conv_weight)
|
142 |
+
c1, c2, t, h, w = conv_weight.size()
|
143 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
144 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
145 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
146 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
147 |
+
conv.weight.data.copy_(conv_weight)
|
148 |
+
nn.init.zeros_(conv.bias.data)
|
149 |
+
|
150 |
+
|
151 |
+
class VAEDecoderWrapperSingle(nn.Module):
|
152 |
+
def __init__(self):
|
153 |
+
super().__init__()
|
154 |
+
self.decoder = VAEDecoder3d()
|
155 |
+
mean = [
|
156 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
157 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
158 |
+
]
|
159 |
+
std = [
|
160 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
161 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
162 |
+
]
|
163 |
+
self.mean = torch.tensor(mean, dtype=torch.float32)
|
164 |
+
self.std = torch.tensor(std, dtype=torch.float32)
|
165 |
+
self.z_dim = 16
|
166 |
+
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
|
167 |
+
|
168 |
+
def forward(
|
169 |
+
self,
|
170 |
+
z: torch.Tensor,
|
171 |
+
is_first_frame: torch.Tensor,
|
172 |
+
*feat_cache: List[torch.Tensor]
|
173 |
+
):
|
174 |
+
# from [batch_size, num_frames, num_channels, height, width]
|
175 |
+
# to [batch_size, num_channels, num_frames, height, width]
|
176 |
+
z = z.permute(0, 2, 1, 3, 4)
|
177 |
+
assert z.shape[2] == 1
|
178 |
+
feat_cache = list(feat_cache)
|
179 |
+
is_first_frame = is_first_frame.bool()
|
180 |
+
|
181 |
+
device, dtype = z.device, z.dtype
|
182 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
183 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
184 |
+
|
185 |
+
if isinstance(scale[0], torch.Tensor):
|
186 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
187 |
+
1, self.z_dim, 1, 1, 1)
|
188 |
+
else:
|
189 |
+
z = z / scale[1] + scale[0]
|
190 |
+
x = self.conv2(z)
|
191 |
+
out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
|
192 |
+
out = out.clamp_(-1, 1)
|
193 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
194 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
195 |
+
out = out.permute(0, 2, 1, 3, 4)
|
196 |
+
return out, feat_cache
|
197 |
+
|
198 |
+
|
199 |
+
class VAEDecoder3d(nn.Module):
|
200 |
+
def __init__(self,
|
201 |
+
dim=96,
|
202 |
+
z_dim=16,
|
203 |
+
dim_mult=[1, 2, 4, 4],
|
204 |
+
num_res_blocks=2,
|
205 |
+
attn_scales=[],
|
206 |
+
temperal_upsample=[True, True, False],
|
207 |
+
dropout=0.0):
|
208 |
+
super().__init__()
|
209 |
+
self.dim = dim
|
210 |
+
self.z_dim = z_dim
|
211 |
+
self.dim_mult = dim_mult
|
212 |
+
self.num_res_blocks = num_res_blocks
|
213 |
+
self.attn_scales = attn_scales
|
214 |
+
self.temperal_upsample = temperal_upsample
|
215 |
+
self.cache_t = 2
|
216 |
+
self.decoder_conv_num = 32
|
217 |
+
|
218 |
+
# dimensions
|
219 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
220 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
221 |
+
|
222 |
+
# init block
|
223 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
224 |
+
|
225 |
+
# middle blocks
|
226 |
+
self.middle = nn.Sequential(
|
227 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
228 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
229 |
+
|
230 |
+
# upsample blocks
|
231 |
+
upsamples = []
|
232 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
233 |
+
# residual (+attention) blocks
|
234 |
+
if i == 1 or i == 2 or i == 3:
|
235 |
+
in_dim = in_dim // 2
|
236 |
+
for _ in range(num_res_blocks + 1):
|
237 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
238 |
+
if scale in attn_scales:
|
239 |
+
upsamples.append(AttentionBlock(out_dim))
|
240 |
+
in_dim = out_dim
|
241 |
+
|
242 |
+
# upsample block
|
243 |
+
if i != len(dim_mult) - 1:
|
244 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
245 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
246 |
+
scale *= 2.0
|
247 |
+
self.upsamples = nn.Sequential(*upsamples)
|
248 |
+
|
249 |
+
# output blocks
|
250 |
+
self.head = nn.Sequential(
|
251 |
+
RMS_norm(out_dim, images=False), nn.SiLU(),
|
252 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
253 |
+
|
254 |
+
def forward(
|
255 |
+
self,
|
256 |
+
x: torch.Tensor,
|
257 |
+
is_first_frame: torch.Tensor,
|
258 |
+
feat_cache: List[torch.Tensor]
|
259 |
+
):
|
260 |
+
idx = 0
|
261 |
+
out_feat_cache = []
|
262 |
+
|
263 |
+
# conv1
|
264 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
265 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
266 |
+
# cache last frame of last two chunk
|
267 |
+
cache_x = torch.cat([
|
268 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
269 |
+
cache_x.device), cache_x
|
270 |
+
],
|
271 |
+
dim=2)
|
272 |
+
x = self.conv1(x, feat_cache[idx])
|
273 |
+
out_feat_cache.append(cache_x)
|
274 |
+
idx += 1
|
275 |
+
|
276 |
+
# middle
|
277 |
+
for layer in self.middle:
|
278 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
279 |
+
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
|
280 |
+
idx += 2
|
281 |
+
out_feat_cache.append(out_feat_cache_1)
|
282 |
+
out_feat_cache.append(out_feat_cache_2)
|
283 |
+
else:
|
284 |
+
x = layer(x)
|
285 |
+
|
286 |
+
# upsamples
|
287 |
+
for layer in self.upsamples:
|
288 |
+
if isinstance(layer, Resample):
|
289 |
+
x, cache_x = layer(x, is_first_frame, feat_cache[idx])
|
290 |
+
if cache_x is not None:
|
291 |
+
out_feat_cache.append(cache_x)
|
292 |
+
idx += 1
|
293 |
+
else:
|
294 |
+
x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
|
295 |
+
idx += 2
|
296 |
+
out_feat_cache.append(out_feat_cache_1)
|
297 |
+
out_feat_cache.append(out_feat_cache_2)
|
298 |
+
|
299 |
+
# head
|
300 |
+
for layer in self.head:
|
301 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
302 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
303 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
304 |
+
# cache last frame of last two chunk
|
305 |
+
cache_x = torch.cat([
|
306 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
307 |
+
cache_x.device), cache_x
|
308 |
+
],
|
309 |
+
dim=2)
|
310 |
+
x = layer(x, feat_cache[idx])
|
311 |
+
out_feat_cache.append(cache_x)
|
312 |
+
idx += 1
|
313 |
+
else:
|
314 |
+
x = layer(x)
|
315 |
+
return x, out_feat_cache
|
316 |
+
|
317 |
+
|
318 |
+
class VAETRTWrapper():
|
319 |
+
def __init__(self):
|
320 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
321 |
+
with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
|
322 |
+
self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())
|
323 |
+
|
324 |
+
self.context: trt.IExecutionContext = self.engine.create_execution_context()
|
325 |
+
self.stream = torch.cuda.current_stream().cuda_stream
|
326 |
+
|
327 |
+
# ──────────────────────────────
|
328 |
+
# 2️⃣ Feed the engine with tensors
|
329 |
+
# (name-based API in TRT ≥10)
|
330 |
+
# ──────────────────────────────
|
331 |
+
self.dtype_map = {
|
332 |
+
trt.float32: torch.float32,
|
333 |
+
trt.float16: torch.float16,
|
334 |
+
trt.int8: torch.int8,
|
335 |
+
trt.int32: torch.int32,
|
336 |
+
}
|
337 |
+
test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
|
338 |
+
is_first_frame = torch.tensor(1.0).cuda().half()
|
339 |
+
test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
|
340 |
+
test_inputs = [test_input, is_first_frame] + test_cache_inputs
|
341 |
+
|
342 |
+
# keep references so buffers stay alive
|
343 |
+
self.device_buffers, self.outputs = {}, []
|
344 |
+
|
345 |
+
# ---- inputs ----
|
346 |
+
for i, name in enumerate(ALL_INPUTS_NAMES):
|
347 |
+
tensor, scale = test_inputs[i], 1 / 127
|
348 |
+
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
|
349 |
+
|
350 |
+
# dynamic shapes
|
351 |
+
if -1 in self.engine.get_tensor_shape(name):
|
352 |
+
# new API :contentReference[oaicite:0]{index=0}
|
353 |
+
self.context.set_input_shape(name, tuple(tensor.shape))
|
354 |
+
|
355 |
+
# replaces bindings[] :contentReference[oaicite:1]{index=1}
|
356 |
+
self.context.set_tensor_address(name, int(tensor.data_ptr()))
|
357 |
+
self.device_buffers[name] = tensor # keep pointer alive
|
358 |
+
|
359 |
+
# ---- (after all input shapes are known) infer output shapes ----
|
360 |
+
# propagates shapes :contentReference[oaicite:2]{index=2}
|
361 |
+
self.context.infer_shapes()
|
362 |
+
|
363 |
+
for i in range(self.engine.num_io_tensors):
|
364 |
+
name = self.engine.get_tensor_name(i)
|
365 |
+
# replaces binding_is_input :contentReference[oaicite:3]{index=3}
|
366 |
+
if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
367 |
+
shape = tuple(self.context.get_tensor_shape(name))
|
368 |
+
dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
|
369 |
+
out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()
|
370 |
+
|
371 |
+
self.context.set_tensor_address(name, int(out.data_ptr()))
|
372 |
+
self.outputs.append(out)
|
373 |
+
self.device_buffers[name] = out
|
374 |
+
|
375 |
+
# helper to quant-convert on the fly
|
376 |
+
def quantize_if_needed(self, t, expected_dtype, scale):
|
377 |
+
if expected_dtype == trt.int8 and t.dtype != torch.int8:
|
378 |
+
t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
|
379 |
+
return t # keep pointer alive
|
380 |
+
|
381 |
+
def forward(self, *test_inputs):
|
382 |
+
for i, name in enumerate(ALL_INPUTS_NAMES):
|
383 |
+
tensor, scale = test_inputs[i], 1 / 127
|
384 |
+
tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
|
385 |
+
self.context.set_tensor_address(name, int(tensor.data_ptr()))
|
386 |
+
self.device_buffers[name] = tensor
|
387 |
+
|
388 |
+
self.context.execute_async_v3(stream_handle=self.stream)
|
389 |
+
torch.cuda.current_stream().synchronize()
|
390 |
+
return self.outputs
|
demo_utils/vae_block3.py
ADDED
@@ -0,0 +1,291 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
from einops import rearrange
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
|
6 |
+
from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample
|
7 |
+
|
8 |
+
|
9 |
+
class Resample(nn.Module):
|
10 |
+
|
11 |
+
def __init__(self, dim, mode):
|
12 |
+
assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
|
13 |
+
'downsample3d')
|
14 |
+
super().__init__()
|
15 |
+
self.dim = dim
|
16 |
+
self.mode = mode
|
17 |
+
self.cache_t = 2
|
18 |
+
|
19 |
+
# layers
|
20 |
+
if mode == 'upsample2d':
|
21 |
+
self.resample = nn.Sequential(
|
22 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
23 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
24 |
+
elif mode == 'upsample3d':
|
25 |
+
self.resample = nn.Sequential(
|
26 |
+
Upsample(scale_factor=(2., 2.), mode='nearest'),
|
27 |
+
nn.Conv2d(dim, dim // 2, 3, padding=1))
|
28 |
+
self.time_conv = CausalConv3d(
|
29 |
+
dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
30 |
+
|
31 |
+
elif mode == 'downsample2d':
|
32 |
+
self.resample = nn.Sequential(
|
33 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
34 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
35 |
+
elif mode == 'downsample3d':
|
36 |
+
self.resample = nn.Sequential(
|
37 |
+
nn.ZeroPad2d((0, 1, 0, 1)),
|
38 |
+
nn.Conv2d(dim, dim, 3, stride=(2, 2)))
|
39 |
+
self.time_conv = CausalConv3d(
|
40 |
+
dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
|
41 |
+
|
42 |
+
else:
|
43 |
+
self.resample = nn.Identity()
|
44 |
+
|
45 |
+
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
46 |
+
b, c, t, h, w = x.size()
|
47 |
+
if self.mode == 'upsample3d':
|
48 |
+
if feat_cache is not None:
|
49 |
+
idx = feat_idx[0]
|
50 |
+
if feat_cache[idx] is None:
|
51 |
+
feat_cache[idx] = 'Rep'
|
52 |
+
feat_idx[0] += 1
|
53 |
+
else:
|
54 |
+
|
55 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
56 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
57 |
+
idx] is not None and feat_cache[idx] != 'Rep':
|
58 |
+
# cache last frame of last two chunk
|
59 |
+
cache_x = torch.cat([
|
60 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
61 |
+
cache_x.device), cache_x
|
62 |
+
],
|
63 |
+
dim=2)
|
64 |
+
if cache_x.shape[2] < 2 and feat_cache[
|
65 |
+
idx] is not None and feat_cache[idx] == 'Rep':
|
66 |
+
cache_x = torch.cat([
|
67 |
+
torch.zeros_like(cache_x).to(cache_x.device),
|
68 |
+
cache_x
|
69 |
+
],
|
70 |
+
dim=2)
|
71 |
+
if feat_cache[idx] == 'Rep':
|
72 |
+
x = self.time_conv(x)
|
73 |
+
else:
|
74 |
+
x = self.time_conv(x, feat_cache[idx])
|
75 |
+
feat_cache[idx] = cache_x
|
76 |
+
feat_idx[0] += 1
|
77 |
+
|
78 |
+
x = x.reshape(b, 2, c, t, h, w)
|
79 |
+
x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
|
80 |
+
3)
|
81 |
+
x = x.reshape(b, c, t * 2, h, w)
|
82 |
+
t = x.shape[2]
|
83 |
+
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
84 |
+
x = self.resample(x)
|
85 |
+
x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
|
86 |
+
|
87 |
+
if self.mode == 'downsample3d':
|
88 |
+
if feat_cache is not None:
|
89 |
+
idx = feat_idx[0]
|
90 |
+
if feat_cache[idx] is None:
|
91 |
+
feat_cache[idx] = x.clone()
|
92 |
+
feat_idx[0] += 1
|
93 |
+
else:
|
94 |
+
|
95 |
+
cache_x = x[:, :, -1:, :, :].clone()
|
96 |
+
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
97 |
+
# # cache last frame of last two chunk
|
98 |
+
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
99 |
+
|
100 |
+
x = self.time_conv(
|
101 |
+
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
102 |
+
feat_cache[idx] = cache_x
|
103 |
+
feat_idx[0] += 1
|
104 |
+
return x
|
105 |
+
|
106 |
+
def init_weight(self, conv):
|
107 |
+
conv_weight = conv.weight
|
108 |
+
nn.init.zeros_(conv_weight)
|
109 |
+
c1, c2, t, h, w = conv_weight.size()
|
110 |
+
one_matrix = torch.eye(c1, c2)
|
111 |
+
init_matrix = one_matrix
|
112 |
+
nn.init.zeros_(conv_weight)
|
113 |
+
# conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
|
114 |
+
conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
|
115 |
+
conv.weight.data.copy_(conv_weight)
|
116 |
+
nn.init.zeros_(conv.bias.data)
|
117 |
+
|
118 |
+
def init_weight2(self, conv):
|
119 |
+
conv_weight = conv.weight.data
|
120 |
+
nn.init.zeros_(conv_weight)
|
121 |
+
c1, c2, t, h, w = conv_weight.size()
|
122 |
+
init_matrix = torch.eye(c1 // 2, c2)
|
123 |
+
# init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
|
124 |
+
conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
|
125 |
+
conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
|
126 |
+
conv.weight.data.copy_(conv_weight)
|
127 |
+
nn.init.zeros_(conv.bias.data)
|
128 |
+
|
129 |
+
|
130 |
+
class VAEDecoderWrapper(nn.Module):
|
131 |
+
def __init__(self):
|
132 |
+
super().__init__()
|
133 |
+
self.decoder = VAEDecoder3d()
|
134 |
+
mean = [
|
135 |
+
-0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
|
136 |
+
0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
|
137 |
+
]
|
138 |
+
std = [
|
139 |
+
2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
|
140 |
+
3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
|
141 |
+
]
|
142 |
+
self.mean = torch.tensor(mean, dtype=torch.float32)
|
143 |
+
self.std = torch.tensor(std, dtype=torch.float32)
|
144 |
+
self.z_dim = 16
|
145 |
+
self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
|
146 |
+
|
147 |
+
def forward(
|
148 |
+
self,
|
149 |
+
z: torch.Tensor,
|
150 |
+
*feat_cache: List[torch.Tensor]
|
151 |
+
):
|
152 |
+
# from [batch_size, num_frames, num_channels, height, width]
|
153 |
+
# to [batch_size, num_channels, num_frames, height, width]
|
154 |
+
z = z.permute(0, 2, 1, 3, 4)
|
155 |
+
feat_cache = list(feat_cache)
|
156 |
+
print("Length of feat_cache: ", len(feat_cache))
|
157 |
+
|
158 |
+
device, dtype = z.device, z.dtype
|
159 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
160 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
161 |
+
|
162 |
+
if isinstance(scale[0], torch.Tensor):
|
163 |
+
z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
|
164 |
+
1, self.z_dim, 1, 1, 1)
|
165 |
+
else:
|
166 |
+
z = z / scale[1] + scale[0]
|
167 |
+
iter_ = z.shape[2]
|
168 |
+
x = self.conv2(z)
|
169 |
+
for i in range(iter_):
|
170 |
+
if i == 0:
|
171 |
+
out, feat_cache = self.decoder(
|
172 |
+
x[:, :, i:i + 1, :, :],
|
173 |
+
feat_cache=feat_cache)
|
174 |
+
else:
|
175 |
+
out_, feat_cache = self.decoder(
|
176 |
+
x[:, :, i:i + 1, :, :],
|
177 |
+
feat_cache=feat_cache)
|
178 |
+
out = torch.cat([out, out_], 2)
|
179 |
+
|
180 |
+
out = out.float().clamp_(-1, 1)
|
181 |
+
# from [batch_size, num_channels, num_frames, height, width]
|
182 |
+
# to [batch_size, num_frames, num_channels, height, width]
|
183 |
+
out = out.permute(0, 2, 1, 3, 4)
|
184 |
+
return out, feat_cache
|
185 |
+
|
186 |
+
|
187 |
+
class VAEDecoder3d(nn.Module):
|
188 |
+
def __init__(self,
|
189 |
+
dim=96,
|
190 |
+
z_dim=16,
|
191 |
+
dim_mult=[1, 2, 4, 4],
|
192 |
+
num_res_blocks=2,
|
193 |
+
attn_scales=[],
|
194 |
+
temperal_upsample=[True, True, False],
|
195 |
+
dropout=0.0):
|
196 |
+
super().__init__()
|
197 |
+
self.dim = dim
|
198 |
+
self.z_dim = z_dim
|
199 |
+
self.dim_mult = dim_mult
|
200 |
+
self.num_res_blocks = num_res_blocks
|
201 |
+
self.attn_scales = attn_scales
|
202 |
+
self.temperal_upsample = temperal_upsample
|
203 |
+
self.cache_t = 2
|
204 |
+
self.decoder_conv_num = 32
|
205 |
+
|
206 |
+
# dimensions
|
207 |
+
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
208 |
+
scale = 1.0 / 2**(len(dim_mult) - 2)
|
209 |
+
|
210 |
+
# init block
|
211 |
+
self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
|
212 |
+
|
213 |
+
# middle blocks
|
214 |
+
self.middle = nn.Sequential(
|
215 |
+
ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
|
216 |
+
ResidualBlock(dims[0], dims[0], dropout))
|
217 |
+
|
218 |
+
# upsample blocks
|
219 |
+
upsamples = []
|
220 |
+
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
221 |
+
# residual (+attention) blocks
|
222 |
+
if i == 1 or i == 2 or i == 3:
|
223 |
+
in_dim = in_dim // 2
|
224 |
+
for _ in range(num_res_blocks + 1):
|
225 |
+
upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
|
226 |
+
if scale in attn_scales:
|
227 |
+
upsamples.append(AttentionBlock(out_dim))
|
228 |
+
in_dim = out_dim
|
229 |
+
|
230 |
+
# upsample block
|
231 |
+
if i != len(dim_mult) - 1:
|
232 |
+
mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
|
233 |
+
upsamples.append(Resample(out_dim, mode=mode))
|
234 |
+
scale *= 2.0
|
235 |
+
self.upsamples = nn.Sequential(*upsamples)
|
236 |
+
|
237 |
+
# output blocks
|
238 |
+
self.head = nn.Sequential(
|
239 |
+
RMS_norm(out_dim, images=False), nn.SiLU(),
|
240 |
+
CausalConv3d(out_dim, 3, 3, padding=1))
|
241 |
+
|
242 |
+
def forward(
|
243 |
+
self,
|
244 |
+
x: torch.Tensor,
|
245 |
+
feat_cache: List[torch.Tensor]
|
246 |
+
):
|
247 |
+
feat_idx = [0]
|
248 |
+
|
249 |
+
# conv1
|
250 |
+
idx = feat_idx[0]
|
251 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
252 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
253 |
+
# cache last frame of last two chunk
|
254 |
+
cache_x = torch.cat([
|
255 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
256 |
+
cache_x.device), cache_x
|
257 |
+
],
|
258 |
+
dim=2)
|
259 |
+
x = self.conv1(x, feat_cache[idx])
|
260 |
+
feat_cache[idx] = cache_x
|
261 |
+
feat_idx[0] += 1
|
262 |
+
|
263 |
+
# middle
|
264 |
+
for layer in self.middle:
|
265 |
+
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
266 |
+
x = layer(x, feat_cache, feat_idx)
|
267 |
+
else:
|
268 |
+
x = layer(x)
|
269 |
+
|
270 |
+
# upsamples
|
271 |
+
for layer in self.upsamples:
|
272 |
+
x = layer(x, feat_cache, feat_idx)
|
273 |
+
|
274 |
+
# head
|
275 |
+
for layer in self.head:
|
276 |
+
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
277 |
+
idx = feat_idx[0]
|
278 |
+
cache_x = x[:, :, -self.cache_t:, :, :].clone()
|
279 |
+
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
280 |
+
# cache last frame of last two chunk
|
281 |
+
cache_x = torch.cat([
|
282 |
+
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
283 |
+
cache_x.device), cache_x
|
284 |
+
],
|
285 |
+
dim=2)
|
286 |
+
x = layer(x, feat_cache[idx])
|
287 |
+
feat_cache[idx] = cache_x
|
288 |
+
feat_idx[0] += 1
|
289 |
+
else:
|
290 |
+
x = layer(x)
|
291 |
+
return x, feat_cache
|
demo_utils/vae_torch2trt.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ---- INT8 (optional) ----
|
2 |
+
from demo_utils.vae import (
|
3 |
+
VAEDecoderWrapperSingle, # main nn.Module
|
4 |
+
ZERO_VAE_CACHE # helper constants shipped with your code base
|
5 |
+
)
|
6 |
+
import pycuda.driver as cuda # ← add
|
7 |
+
import pycuda.autoinit # noqa
|
8 |
+
|
9 |
+
import sys
|
10 |
+
from pathlib import Path
|
11 |
+
|
12 |
+
import torch
|
13 |
+
import tensorrt as trt
|
14 |
+
|
15 |
+
from utils.dataset import ShardingLMDBDataset
|
16 |
+
|
17 |
+
data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
|
18 |
+
dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
|
19 |
+
dataloader = torch.utils.data.DataLoader(
|
20 |
+
dataset,
|
21 |
+
batch_size=1,
|
22 |
+
num_workers=0
|
23 |
+
)
|
24 |
+
|
25 |
+
# ─────────────────────────────────────────────────────────
|
26 |
+
# 1️⃣ Bring the PyTorch model into scope
|
27 |
+
# (all code you pasted lives in `vae_decoder.py`)
|
28 |
+
# ─────────────────────────────────────────────────────────
|
29 |
+
|
30 |
+
# --- dummy tensors (exact shapes you posted) ---
|
31 |
+
dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
|
32 |
+
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
|
33 |
+
dummy_cache_input = [
|
34 |
+
torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
|
35 |
+
for s in ZERO_VAE_CACHE # keep exactly the same ordering
|
36 |
+
]
|
37 |
+
inputs = [dummy_input, is_first_frame, *dummy_cache_input]
|
38 |
+
|
39 |
+
# ─────────────────────────────────────────────────────────
|
40 |
+
# 2️⃣ Export → ONNX
|
41 |
+
# ─────────────────────────────────────────────────────────
|
42 |
+
model = VAEDecoderWrapperSingle().half().cuda().eval()
|
43 |
+
|
44 |
+
vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
|
45 |
+
decoder_state_dict = {}
|
46 |
+
for key, value in vae_state_dict.items():
|
47 |
+
if 'decoder.' in key or 'conv2' in key:
|
48 |
+
decoder_state_dict[key] = value
|
49 |
+
model.load_state_dict(decoder_state_dict)
|
50 |
+
model = model.half().cuda().eval() # only batch dim dynamic
|
51 |
+
|
52 |
+
onnx_path = Path("vae_decoder.onnx")
|
53 |
+
feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
|
54 |
+
all_inputs_names = ["z", "use_cache"] + feat_names
|
55 |
+
|
56 |
+
with torch.inference_mode():
|
57 |
+
torch.onnx.export(
|
58 |
+
model,
|
59 |
+
tuple(inputs), # must be a tuple
|
60 |
+
onnx_path.as_posix(),
|
61 |
+
input_names=all_inputs_names,
|
62 |
+
output_names=["rgb_out", "cache_out"],
|
63 |
+
opset_version=17,
|
64 |
+
do_constant_folding=True,
|
65 |
+
dynamo=True
|
66 |
+
)
|
67 |
+
print(f"✅ ONNX graph saved to {onnx_path.resolve()}")
|
68 |
+
|
69 |
+
# (Optional) quick sanity-check with ONNX-Runtime
|
70 |
+
try:
|
71 |
+
import onnxruntime as ort
|
72 |
+
sess = ort.InferenceSession(onnx_path.as_posix(),
|
73 |
+
providers=["CUDAExecutionProvider"])
|
74 |
+
ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
|
75 |
+
_ = sess.run(None, ort_inputs)
|
76 |
+
print("✅ ONNX graph is executable")
|
77 |
+
except Exception as e:
|
78 |
+
print("⚠️ ONNX check failed:", e)
|
79 |
+
|
80 |
+
# ─────────────────────────────────────────────────────────
|
81 |
+
# 3️⃣ Build the TensorRT engine
|
82 |
+
# ─────────────────────────────────────────────────────────
|
83 |
+
TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
|
84 |
+
builder = trt.Builder(TRT_LOGGER)
|
85 |
+
network = builder.create_network(
|
86 |
+
1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
|
87 |
+
parser = trt.OnnxParser(network, TRT_LOGGER)
|
88 |
+
|
89 |
+
with open(onnx_path, "rb") as f:
|
90 |
+
if not parser.parse(f.read()):
|
91 |
+
for i in range(parser.num_errors):
|
92 |
+
print(parser.get_error(i))
|
93 |
+
sys.exit("❌ ONNX → TRT parsing failed")
|
94 |
+
|
95 |
+
config = builder.create_builder_config()
|
96 |
+
|
97 |
+
|
98 |
+
def set_workspace(config, bytes_):
|
99 |
+
"""Version-agnostic workspace limit."""
|
100 |
+
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
|
101 |
+
config.max_workspace_size = bytes_
|
102 |
+
else: # TRT 10+
|
103 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_)
|
104 |
+
|
105 |
+
|
106 |
+
# …
|
107 |
+
config = builder.create_builder_config()
|
108 |
+
set_workspace(config, 4 << 30) # 4 GB
|
109 |
+
# 4 GB
|
110 |
+
|
111 |
+
if builder.platform_has_fast_fp16:
|
112 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
113 |
+
|
114 |
+
# ---- INT8 (optional) ----
|
115 |
+
# provide a calibrator if you need an INT8 engine; comment this
|
116 |
+
# block if you only care about FP16.
|
117 |
+
# ─────────────────────────────────────────────────────────
|
118 |
+
# helper: version-agnostic workspace limit
|
119 |
+
# ─────────────────────────────────────────────────────────
|
120 |
+
|
121 |
+
|
122 |
+
def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
|
123 |
+
"""
|
124 |
+
TRT < 10.x → config.max_workspace_size
|
125 |
+
TRT ≥ 10.x → config.set_memory_pool_limit(...)
|
126 |
+
"""
|
127 |
+
if hasattr(config, "max_workspace_size"): # TRT 8 / 9
|
128 |
+
config.max_workspace_size = bytes_
|
129 |
+
else: # TRT 10+
|
130 |
+
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
|
131 |
+
bytes_)
|
132 |
+
|
133 |
+
# ─────────────────────────────────────────────────────────
|
134 |
+
# (optional) INT-8 calibrator
|
135 |
+
# ─────────────────────────────────────────────────────────
|
136 |
+
# ‼ Only keep this block if you really need INT-8 ‼ # gracefully skip if PyCUDA not present
|
137 |
+
|
138 |
+
|
139 |
+
class VAECalibrator(trt.IInt8EntropyCalibrator2):
|
140 |
+
def __init__(self, loader, cache="calibration.cache", max_batches=10):
|
141 |
+
super().__init__()
|
142 |
+
self.loader = iter(loader)
|
143 |
+
self.batch_size = loader.batch_size or 1
|
144 |
+
self.max_batches = max_batches
|
145 |
+
self.count = 0
|
146 |
+
self.cache_file = cache
|
147 |
+
self.stream = cuda.Stream()
|
148 |
+
self.dev_ptrs = {}
|
149 |
+
|
150 |
+
# --- TRT 10 needs BOTH spellings ---
|
151 |
+
def get_batch_size(self):
|
152 |
+
return self.batch_size
|
153 |
+
|
154 |
+
def getBatchSize(self):
|
155 |
+
return self.batch_size
|
156 |
+
|
157 |
+
def get_batch(self, names):
|
158 |
+
if self.count >= self.max_batches:
|
159 |
+
return None
|
160 |
+
|
161 |
+
# Randomly sample a number from 1 to 10
|
162 |
+
import random
|
163 |
+
vae_idx = random.randint(0, 10)
|
164 |
+
data = next(self.loader)
|
165 |
+
|
166 |
+
latent = data['ode_latent'][0][:, :1]
|
167 |
+
is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
|
168 |
+
feat_cache = ZERO_VAE_CACHE
|
169 |
+
for i in range(vae_idx):
|
170 |
+
inputs = [latent, is_first_frame, *feat_cache]
|
171 |
+
with torch.inference_mode():
|
172 |
+
outputs = model(*inputs)
|
173 |
+
latent = data['ode_latent'][0][:, i + 1:i + 2]
|
174 |
+
is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
|
175 |
+
feat_cache = outputs[1:]
|
176 |
+
|
177 |
+
# -------- ensure context is current --------
|
178 |
+
z_np = latent.cpu().numpy().astype('float32')
|
179 |
+
|
180 |
+
ptrs = [] # list[int] – one entry per name
|
181 |
+
for name in names: # <-- match TRT's binding order
|
182 |
+
if name == "z":
|
183 |
+
arr = z_np
|
184 |
+
elif name == "use_cache":
|
185 |
+
arr = is_first_frame.cpu().numpy().astype('float32')
|
186 |
+
else:
|
187 |
+
idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
|
188 |
+
arr = feat_cache[idx].cpu().numpy().astype('float32')
|
189 |
+
|
190 |
+
if name not in self.dev_ptrs:
|
191 |
+
self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
|
192 |
+
|
193 |
+
cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
|
194 |
+
ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
|
195 |
+
|
196 |
+
self.stream.synchronize()
|
197 |
+
self.count += 1
|
198 |
+
print(f"Calibration batch {self.count}/{self.max_batches}")
|
199 |
+
return ptrs
|
200 |
+
|
201 |
+
# --- calibration-cache helpers (both spellings) ---
|
202 |
+
def read_calibration_cache(self):
|
203 |
+
try:
|
204 |
+
with open(self.cache_file, "rb") as f:
|
205 |
+
return f.read()
|
206 |
+
except FileNotFoundError:
|
207 |
+
return None
|
208 |
+
|
209 |
+
def readCalibrationCache(self):
|
210 |
+
return self.read_calibration_cache()
|
211 |
+
|
212 |
+
def write_calibration_cache(self, cache):
|
213 |
+
with open(self.cache_file, "wb") as f:
|
214 |
+
f.write(cache)
|
215 |
+
|
216 |
+
def writeCalibrationCache(self, cache):
|
217 |
+
self.write_calibration_cache(cache)
|
218 |
+
|
219 |
+
|
220 |
+
# ─────────────────────────────────────────────────────────
|
221 |
+
# Builder-config + optimisation profile
|
222 |
+
# ─────────────────────────────────────────────────────────
|
223 |
+
config = builder.create_builder_config()
|
224 |
+
set_workspace(config, 4 << 30) # 4 GB
|
225 |
+
|
226 |
+
# ► enable FP16 if possible
|
227 |
+
if builder.platform_has_fast_fp16:
|
228 |
+
config.set_flag(trt.BuilderFlag.FP16)
|
229 |
+
|
230 |
+
# ► enable INT-8 (delete this block if you don’t need it)
|
231 |
+
if cuda is not None:
|
232 |
+
config.set_flag(trt.BuilderFlag.INT8)
|
233 |
+
# supply any representative batch you like – here we reuse the latent z
|
234 |
+
calib = VAECalibrator(dataloader)
|
235 |
+
# TRT-10 renamed the setter:
|
236 |
+
if hasattr(config, "set_int8_calibrator"): # TRT 10+
|
237 |
+
config.set_int8_calibrator(calib)
|
238 |
+
else: # TRT ≤ 9
|
239 |
+
config.int8_calibrator = calib
|
240 |
+
|
241 |
+
# ---- optimisation profile ----
|
242 |
+
profile = builder.create_optimization_profile()
|
243 |
+
profile.set_shape(all_inputs_names[0], # latent z
|
244 |
+
min=(1, 1, 16, 60, 104),
|
245 |
+
opt=(1, 1, 16, 60, 104),
|
246 |
+
max=(1, 1, 16, 60, 104))
|
247 |
+
profile.set_shape("use_cache", # scalar flag
|
248 |
+
min=(1,), opt=(1,), max=(1,))
|
249 |
+
for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
|
250 |
+
profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
|
251 |
+
|
252 |
+
config.add_optimization_profile(profile)
|
253 |
+
|
254 |
+
# ─────────────────────────────────────────────────────────
|
255 |
+
# Build the engine (API changed in TRT-10)
|
256 |
+
# ─────────────────────────────────────────────────────────
|
257 |
+
print("⚙️ Building engine … (can take a minute)")
|
258 |
+
|
259 |
+
if hasattr(builder, "build_serialized_network"): # TRT 10+
|
260 |
+
serialized_engine = builder.build_serialized_network(network, config)
|
261 |
+
assert serialized_engine is not None, "build_serialized_network() failed"
|
262 |
+
plan_path = Path("checkpoints/vae_decoder_int8.trt")
|
263 |
+
plan_path.write_bytes(serialized_engine)
|
264 |
+
engine_bytes = serialized_engine # keep for smoke-test
|
265 |
+
else: # TRT ≤ 9
|
266 |
+
engine = builder.build_engine(network, config)
|
267 |
+
assert engine is not None, "build_engine() returned None"
|
268 |
+
plan_path = Path("checkpoints/vae_decoder_int8.trt")
|
269 |
+
plan_path.write_bytes(engine.serialize())
|
270 |
+
engine_bytes = engine.serialize()
|
271 |
+
|
272 |
+
print(f"✅ TensorRT engine written to {plan_path.resolve()}")
|
273 |
+
|
274 |
+
# ─────────────────────────────────────────────────────────
|
275 |
+
# 4️⃣ Quick smoke-test with the brand-new engine
|
276 |
+
# ─────────────────────────────────────────────────────────
|
277 |
+
with trt.Runtime(TRT_LOGGER) as rt:
|
278 |
+
engine = rt.deserialize_cuda_engine(engine_bytes)
|
279 |
+
context = engine.create_execution_context()
|
280 |
+
stream = torch.cuda.current_stream().cuda_stream
|
281 |
+
|
282 |
+
# pre-allocate device buffers once
|
283 |
+
device_buffers, outputs = {}, []
|
284 |
+
dtype_map = {trt.float32: torch.float32,
|
285 |
+
trt.float16: torch.float16,
|
286 |
+
trt.int8: torch.int8,
|
287 |
+
trt.int32: torch.int32}
|
288 |
+
|
289 |
+
for name, tensor in zip(all_inputs_names, inputs):
|
290 |
+
if -1 in engine.get_tensor_shape(name): # dynamic input
|
291 |
+
context.set_input_shape(name, tensor.shape)
|
292 |
+
context.set_tensor_address(name, int(tensor.data_ptr()))
|
293 |
+
device_buffers[name] = tensor
|
294 |
+
|
295 |
+
context.infer_shapes() # propagate ⇢ outputs
|
296 |
+
for i in range(engine.num_io_tensors):
|
297 |
+
name = engine.get_tensor_name(i)
|
298 |
+
if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
|
299 |
+
shape = tuple(context.get_tensor_shape(name))
|
300 |
+
dtype = dtype_map[engine.get_tensor_dtype(name)]
|
301 |
+
out = torch.empty(shape, dtype=dtype, device="cuda")
|
302 |
+
context.set_tensor_address(name, int(out.data_ptr()))
|
303 |
+
outputs.append(out)
|
304 |
+
print(f"output {name} shape: {shape}")
|
305 |
+
|
306 |
+
context.execute_async_v3(stream_handle=stream)
|
307 |
+
torch.cuda.current_stream().synchronize()
|
308 |
+
print("✅ TRT execution OK – first output shape:", outputs[0].shape)
|
images/.gitkeep
ADDED
File without changes
|
inference.py
ADDED
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
from omegaconf import OmegaConf
|
5 |
+
from tqdm import tqdm
|
6 |
+
from torchvision import transforms
|
7 |
+
from torchvision.io import write_video
|
8 |
+
from einops import rearrange
|
9 |
+
import torch.distributed as dist
|
10 |
+
from torch.utils.data import DataLoader, SequentialSampler
|
11 |
+
from torch.utils.data.distributed import DistributedSampler
|
12 |
+
|
13 |
+
from pipeline import (
|
14 |
+
CausalDiffusionInferencePipeline,
|
15 |
+
CausalInferencePipeline
|
16 |
+
)
|
17 |
+
from utils.dataset import TextDataset, TextImagePairDataset
|
18 |
+
from utils.misc import set_seed
|
19 |
+
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--config_path", type=str, help="Path to the config file")
|
22 |
+
parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
|
23 |
+
parser.add_argument("--data_path", type=str, help="Path to the dataset")
|
24 |
+
parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
|
25 |
+
parser.add_argument("--output_folder", type=str, help="Output folder")
|
26 |
+
parser.add_argument("--num_output_frames", type=int, default=21,
|
27 |
+
help="Number of overlap frames between sliding windows")
|
28 |
+
parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
|
29 |
+
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
|
30 |
+
parser.add_argument("--seed", type=int, default=0, help="Random seed")
|
31 |
+
parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
|
32 |
+
parser.add_argument("--save_with_index", action="store_true",
|
33 |
+
help="Whether to save the video using the index or prompt as the filename")
|
34 |
+
args = parser.parse_args()
|
35 |
+
|
36 |
+
# Initialize distributed inference
|
37 |
+
if "LOCAL_RANK" in os.environ:
|
38 |
+
dist.init_process_group(backend='nccl')
|
39 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
40 |
+
torch.cuda.set_device(local_rank)
|
41 |
+
device = torch.device(f"cuda:{local_rank}")
|
42 |
+
world_size = dist.get_world_size()
|
43 |
+
set_seed(args.seed + local_rank)
|
44 |
+
else:
|
45 |
+
device = torch.device("cuda")
|
46 |
+
local_rank = 0
|
47 |
+
world_size = 1
|
48 |
+
set_seed(args.seed)
|
49 |
+
|
50 |
+
torch.set_grad_enabled(False)
|
51 |
+
|
52 |
+
config = OmegaConf.load(args.config_path)
|
53 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
54 |
+
config = OmegaConf.merge(default_config, config)
|
55 |
+
|
56 |
+
# Initialize pipeline
|
57 |
+
if hasattr(config, 'denoising_step_list'):
|
58 |
+
# Few-step inference
|
59 |
+
pipeline = CausalInferencePipeline(config, device=device)
|
60 |
+
else:
|
61 |
+
# Multi-step diffusion inference
|
62 |
+
pipeline = CausalDiffusionInferencePipeline(config, device=device)
|
63 |
+
|
64 |
+
if args.checkpoint_path:
|
65 |
+
state_dict = torch.load(args.checkpoint_path, map_location="cpu")
|
66 |
+
pipeline.generator.load_state_dict(state_dict['generator' if not args.use_ema else 'generator_ema'])
|
67 |
+
|
68 |
+
pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
|
69 |
+
|
70 |
+
# Create dataset
|
71 |
+
if args.i2v:
|
72 |
+
assert not dist.is_initialized(), "I2V does not support distributed inference yet"
|
73 |
+
transform = transforms.Compose([
|
74 |
+
transforms.Resize((480, 832)),
|
75 |
+
transforms.ToTensor(),
|
76 |
+
transforms.Normalize([0.5], [0.5])
|
77 |
+
])
|
78 |
+
dataset = TextImagePairDataset(args.data_path, transform=transform)
|
79 |
+
else:
|
80 |
+
dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
|
81 |
+
num_prompts = len(dataset)
|
82 |
+
print(f"Number of prompts: {num_prompts}")
|
83 |
+
|
84 |
+
if dist.is_initialized():
|
85 |
+
sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
|
86 |
+
else:
|
87 |
+
sampler = SequentialSampler(dataset)
|
88 |
+
dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
|
89 |
+
|
90 |
+
# Create output directory (only on main process to avoid race conditions)
|
91 |
+
if local_rank == 0:
|
92 |
+
os.makedirs(args.output_folder, exist_ok=True)
|
93 |
+
|
94 |
+
if dist.is_initialized():
|
95 |
+
dist.barrier()
|
96 |
+
|
97 |
+
|
98 |
+
def encode(self, videos: torch.Tensor) -> torch.Tensor:
|
99 |
+
device, dtype = videos[0].device, videos[0].dtype
|
100 |
+
scale = [self.mean.to(device=device, dtype=dtype),
|
101 |
+
1.0 / self.std.to(device=device, dtype=dtype)]
|
102 |
+
output = [
|
103 |
+
self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
|
104 |
+
for u in videos
|
105 |
+
]
|
106 |
+
|
107 |
+
output = torch.stack(output, dim=0)
|
108 |
+
return output
|
109 |
+
|
110 |
+
|
111 |
+
for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
|
112 |
+
idx = batch_data['idx'].item()
|
113 |
+
|
114 |
+
# For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
|
115 |
+
# Unpack the batch data for convenience
|
116 |
+
if isinstance(batch_data, dict):
|
117 |
+
batch = batch_data
|
118 |
+
elif isinstance(batch_data, list):
|
119 |
+
batch = batch_data[0] # First (and only) item in the batch
|
120 |
+
|
121 |
+
all_video = []
|
122 |
+
num_generated_frames = 0 # Number of generated (latent) frames
|
123 |
+
|
124 |
+
if args.i2v:
|
125 |
+
# For image-to-video, batch contains image and caption
|
126 |
+
prompt = batch['prompts'][0] # Get caption from batch
|
127 |
+
prompts = [prompt] * args.num_samples
|
128 |
+
|
129 |
+
# Process the image
|
130 |
+
image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
|
131 |
+
|
132 |
+
# Encode the input image as the first latent
|
133 |
+
initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
|
134 |
+
initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
|
135 |
+
|
136 |
+
sampled_noise = torch.randn(
|
137 |
+
[args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
|
138 |
+
)
|
139 |
+
else:
|
140 |
+
# For text-to-video, batch is just the text prompt
|
141 |
+
prompt = batch['prompts'][0]
|
142 |
+
extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
|
143 |
+
if extended_prompt is not None:
|
144 |
+
prompts = [extended_prompt] * args.num_samples
|
145 |
+
else:
|
146 |
+
prompts = [prompt] * args.num_samples
|
147 |
+
initial_latent = None
|
148 |
+
|
149 |
+
sampled_noise = torch.randn(
|
150 |
+
[args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
|
151 |
+
)
|
152 |
+
|
153 |
+
# Generate 81 frames
|
154 |
+
video, latents = pipeline.inference(
|
155 |
+
noise=sampled_noise,
|
156 |
+
text_prompts=prompts,
|
157 |
+
return_latents=True,
|
158 |
+
initial_latent=initial_latent,
|
159 |
+
)
|
160 |
+
current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
|
161 |
+
all_video.append(current_video)
|
162 |
+
num_generated_frames += latents.shape[1]
|
163 |
+
|
164 |
+
# Final output video
|
165 |
+
video = 255.0 * torch.cat(all_video, dim=1)
|
166 |
+
|
167 |
+
# Clear VAE cache
|
168 |
+
pipeline.vae.model.clear_cache()
|
169 |
+
|
170 |
+
# Save the video if the current prompt is not a dummy prompt
|
171 |
+
if idx < num_prompts:
|
172 |
+
model = "regular" if not args.use_ema else "ema"
|
173 |
+
for seed_idx in range(args.num_samples):
|
174 |
+
# All processes save their videos
|
175 |
+
if args.save_with_index:
|
176 |
+
output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
|
177 |
+
else:
|
178 |
+
output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
|
179 |
+
write_video(output_path, video[seed_idx], fps=16)
|
model/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffusion import CausalDiffusion
|
2 |
+
from .causvid import CausVid
|
3 |
+
from .dmd import DMD
|
4 |
+
from .gan import GAN
|
5 |
+
from .sid import SiD
|
6 |
+
from .ode_regression import ODERegression
|
7 |
+
__all__ = [
|
8 |
+
"CausalDiffusion",
|
9 |
+
"CausVid",
|
10 |
+
"DMD",
|
11 |
+
"GAN",
|
12 |
+
"SiD",
|
13 |
+
"ODERegression"
|
14 |
+
]
|
model/base.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
from einops import rearrange
|
3 |
+
from torch import nn
|
4 |
+
import torch.distributed as dist
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from pipeline import SelfForcingTrainingPipeline
|
8 |
+
from utils.loss import get_denoising_loss
|
9 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
10 |
+
|
11 |
+
|
12 |
+
class BaseModel(nn.Module):
|
13 |
+
def __init__(self, args, device):
|
14 |
+
super().__init__()
|
15 |
+
self._initialize_models(args, device)
|
16 |
+
|
17 |
+
self.device = device
|
18 |
+
self.args = args
|
19 |
+
self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
|
20 |
+
if hasattr(args, "denoising_step_list"):
|
21 |
+
self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
|
22 |
+
if args.warp_denoising_step:
|
23 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
24 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
25 |
+
|
26 |
+
def _initialize_models(self, args, device):
|
27 |
+
self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
|
28 |
+
self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")
|
29 |
+
|
30 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
31 |
+
self.generator.model.requires_grad_(True)
|
32 |
+
|
33 |
+
self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
|
34 |
+
self.real_score.model.requires_grad_(False)
|
35 |
+
|
36 |
+
self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
|
37 |
+
self.fake_score.model.requires_grad_(True)
|
38 |
+
|
39 |
+
self.text_encoder = WanTextEncoder()
|
40 |
+
self.text_encoder.requires_grad_(False)
|
41 |
+
|
42 |
+
self.vae = WanVAEWrapper()
|
43 |
+
self.vae.requires_grad_(False)
|
44 |
+
|
45 |
+
self.scheduler = self.generator.get_scheduler()
|
46 |
+
self.scheduler.timesteps = self.scheduler.timesteps.to(device)
|
47 |
+
|
48 |
+
def _get_timestep(
|
49 |
+
self,
|
50 |
+
min_timestep: int,
|
51 |
+
max_timestep: int,
|
52 |
+
batch_size: int,
|
53 |
+
num_frame: int,
|
54 |
+
num_frame_per_block: int,
|
55 |
+
uniform_timestep: bool = False
|
56 |
+
) -> torch.Tensor:
|
57 |
+
"""
|
58 |
+
Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
|
59 |
+
from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
|
60 |
+
- If uniform_timestep, it will use the same timestep for all frames.
|
61 |
+
- If not uniform_timestep, it will use a different timestep for each block.
|
62 |
+
"""
|
63 |
+
if uniform_timestep:
|
64 |
+
timestep = torch.randint(
|
65 |
+
min_timestep,
|
66 |
+
max_timestep,
|
67 |
+
[batch_size, 1],
|
68 |
+
device=self.device,
|
69 |
+
dtype=torch.long
|
70 |
+
).repeat(1, num_frame)
|
71 |
+
return timestep
|
72 |
+
else:
|
73 |
+
timestep = torch.randint(
|
74 |
+
min_timestep,
|
75 |
+
max_timestep,
|
76 |
+
[batch_size, num_frame],
|
77 |
+
device=self.device,
|
78 |
+
dtype=torch.long
|
79 |
+
)
|
80 |
+
# make the noise level the same within every block
|
81 |
+
if self.independent_first_frame:
|
82 |
+
# the first frame is always kept the same
|
83 |
+
timestep_from_second = timestep[:, 1:]
|
84 |
+
timestep_from_second = timestep_from_second.reshape(
|
85 |
+
timestep_from_second.shape[0], -1, num_frame_per_block)
|
86 |
+
timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
|
87 |
+
timestep_from_second = timestep_from_second.reshape(
|
88 |
+
timestep_from_second.shape[0], -1)
|
89 |
+
timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
|
90 |
+
else:
|
91 |
+
timestep = timestep.reshape(
|
92 |
+
timestep.shape[0], -1, num_frame_per_block)
|
93 |
+
timestep[:, :, 1:] = timestep[:, :, 0:1]
|
94 |
+
timestep = timestep.reshape(timestep.shape[0], -1)
|
95 |
+
return timestep
|
96 |
+
|
97 |
+
|
98 |
+
class SelfForcingModel(BaseModel):
|
99 |
+
def __init__(self, args, device):
|
100 |
+
super().__init__(args, device)
|
101 |
+
self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
|
102 |
+
|
103 |
+
def _run_generator(
|
104 |
+
self,
|
105 |
+
image_or_video_shape,
|
106 |
+
conditional_dict: dict,
|
107 |
+
initial_latent: torch.tensor = None
|
108 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
109 |
+
"""
|
110 |
+
Optionally simulate the generator's input from noise using backward simulation
|
111 |
+
and then run the generator for one-step.
|
112 |
+
Input:
|
113 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
114 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
115 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
116 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
117 |
+
- initial_latent: a tensor containing the initial latents [B, F, C, H, W].
|
118 |
+
Output:
|
119 |
+
- pred_image: a tensor with shape [B, F, C, H, W].
|
120 |
+
- denoised_timestep: an integer
|
121 |
+
"""
|
122 |
+
# Step 1: Sample noise and backward simulate the generator's input
|
123 |
+
assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
|
124 |
+
if initial_latent is not None:
|
125 |
+
conditional_dict["initial_latent"] = initial_latent
|
126 |
+
if self.args.i2v:
|
127 |
+
noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
|
128 |
+
else:
|
129 |
+
noise_shape = image_or_video_shape.copy()
|
130 |
+
|
131 |
+
# During training, the number of generated frames should be uniformly sampled from
|
132 |
+
# [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block
|
133 |
+
min_num_frames = 20 if self.args.independent_first_frame else 21
|
134 |
+
max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
|
135 |
+
assert max_num_frames % self.num_frame_per_block == 0
|
136 |
+
assert min_num_frames % self.num_frame_per_block == 0
|
137 |
+
max_num_blocks = max_num_frames // self.num_frame_per_block
|
138 |
+
min_num_blocks = min_num_frames // self.num_frame_per_block
|
139 |
+
num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
|
140 |
+
dist.broadcast(num_generated_blocks, src=0)
|
141 |
+
num_generated_blocks = num_generated_blocks.item()
|
142 |
+
num_generated_frames = num_generated_blocks * self.num_frame_per_block
|
143 |
+
if self.args.independent_first_frame and initial_latent is None:
|
144 |
+
num_generated_frames += 1
|
145 |
+
min_num_frames += 1
|
146 |
+
# Sync num_generated_frames across all processes
|
147 |
+
noise_shape[1] = num_generated_frames
|
148 |
+
|
149 |
+
pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
|
150 |
+
noise=torch.randn(noise_shape,
|
151 |
+
device=self.device, dtype=self.dtype),
|
152 |
+
**conditional_dict,
|
153 |
+
)
|
154 |
+
# Slice last 21 frames
|
155 |
+
if pred_image_or_video.shape[1] > 21:
|
156 |
+
with torch.no_grad():
|
157 |
+
# Reencode to get image latent
|
158 |
+
latent_to_decode = pred_image_or_video[:, :-20, ...]
|
159 |
+
# Deccode to video
|
160 |
+
pixels = self.vae.decode_to_pixel(latent_to_decode)
|
161 |
+
frame = pixels[:, -1:, ...].to(self.dtype)
|
162 |
+
frame = rearrange(frame, "b t c h w -> b c t h w")
|
163 |
+
# Encode frame to get image latent
|
164 |
+
image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
|
165 |
+
pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
|
166 |
+
else:
|
167 |
+
pred_image_or_video_last_21 = pred_image_or_video
|
168 |
+
|
169 |
+
if num_generated_frames != min_num_frames:
|
170 |
+
# Currently, we do not use gradient for the first chunk, since it contains image latents
|
171 |
+
gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
|
172 |
+
if self.args.independent_first_frame:
|
173 |
+
gradient_mask[:, :1] = False
|
174 |
+
else:
|
175 |
+
gradient_mask[:, :self.num_frame_per_block] = False
|
176 |
+
else:
|
177 |
+
gradient_mask = None
|
178 |
+
|
179 |
+
pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
|
180 |
+
return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
|
181 |
+
|
182 |
+
def _consistency_backward_simulation(
|
183 |
+
self,
|
184 |
+
noise: torch.Tensor,
|
185 |
+
**conditional_dict: dict
|
186 |
+
) -> torch.Tensor:
|
187 |
+
"""
|
188 |
+
Simulate the generator's input from noise to avoid training/inference mismatch.
|
189 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
190 |
+
Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
|
191 |
+
Input:
|
192 |
+
- noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
193 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
194 |
+
Output:
|
195 |
+
- output: a tensor with shape [B, T, F, C, H, W].
|
196 |
+
T is the total number of timesteps. output[0] is a pure noise and output[i] and i>0
|
197 |
+
represents the x0 prediction at each timestep.
|
198 |
+
"""
|
199 |
+
if self.inference_pipeline is None:
|
200 |
+
self._initialize_inference_pipeline()
|
201 |
+
|
202 |
+
return self.inference_pipeline.inference_with_trajectory(
|
203 |
+
noise=noise, **conditional_dict
|
204 |
+
)
|
205 |
+
|
206 |
+
def _initialize_inference_pipeline(self):
|
207 |
+
"""
|
208 |
+
Lazy initialize the inference pipeline during the first backward simulation run.
|
209 |
+
Here we encapsulate the inference code with a model-dependent outside function.
|
210 |
+
We pass our FSDP-wrapped modules into the pipeline to save memory.
|
211 |
+
"""
|
212 |
+
self.inference_pipeline = SelfForcingTrainingPipeline(
|
213 |
+
denoising_step_list=self.denoising_step_list,
|
214 |
+
scheduler=self.scheduler,
|
215 |
+
generator=self.generator,
|
216 |
+
num_frame_per_block=self.num_frame_per_block,
|
217 |
+
independent_first_frame=self.args.independent_first_frame,
|
218 |
+
same_step_across_blocks=self.args.same_step_across_blocks,
|
219 |
+
last_step_only=self.args.last_step_only,
|
220 |
+
num_max_frames=self.num_training_frames,
|
221 |
+
context_noise=self.args.context_noise
|
222 |
+
)
|
model/causvid.py
ADDED
@@ -0,0 +1,391 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
from typing import Tuple
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.base import BaseModel
|
6 |
+
|
7 |
+
|
8 |
+
class CausVid(BaseModel):
|
9 |
+
def __init__(self, args, device):
|
10 |
+
"""
|
11 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
12 |
+
This class is self-contained and compute generator and fake score losses
|
13 |
+
in the forward pass.
|
14 |
+
"""
|
15 |
+
super().__init__(args, device)
|
16 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
17 |
+
self.num_training_frames = getattr(args, "num_training_frames", 21)
|
18 |
+
|
19 |
+
if self.num_frame_per_block > 1:
|
20 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
21 |
+
|
22 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
23 |
+
if self.independent_first_frame:
|
24 |
+
self.generator.model.independent_first_frame = True
|
25 |
+
if args.gradient_checkpointing:
|
26 |
+
self.generator.enable_gradient_checkpointing()
|
27 |
+
self.fake_score.enable_gradient_checkpointing()
|
28 |
+
|
29 |
+
# Step 2: Initialize all dmd hyperparameters
|
30 |
+
self.num_train_timestep = args.num_train_timestep
|
31 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
32 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
33 |
+
if hasattr(args, "real_guidance_scale"):
|
34 |
+
self.real_guidance_scale = args.real_guidance_scale
|
35 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
36 |
+
else:
|
37 |
+
self.real_guidance_scale = args.guidance_scale
|
38 |
+
self.fake_guidance_scale = 0.0
|
39 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
40 |
+
self.teacher_forcing = getattr(args, "teacher_forcing", False)
|
41 |
+
|
42 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
43 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
44 |
+
else:
|
45 |
+
self.scheduler.alphas_cumprod = None
|
46 |
+
|
47 |
+
def _compute_kl_grad(
|
48 |
+
self, noisy_image_or_video: torch.Tensor,
|
49 |
+
estimated_clean_image_or_video: torch.Tensor,
|
50 |
+
timestep: torch.Tensor,
|
51 |
+
conditional_dict: dict, unconditional_dict: dict,
|
52 |
+
normalization: bool = True
|
53 |
+
) -> Tuple[torch.Tensor, dict]:
|
54 |
+
"""
|
55 |
+
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
|
56 |
+
Input:
|
57 |
+
- noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
58 |
+
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
|
59 |
+
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
|
60 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
61 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
62 |
+
- normalization: a boolean indicating whether to normalize the gradient.
|
63 |
+
Output:
|
64 |
+
- kl_grad: a tensor representing the KL grad.
|
65 |
+
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
|
66 |
+
"""
|
67 |
+
# Step 1: Compute the fake score
|
68 |
+
_, pred_fake_image_cond = self.fake_score(
|
69 |
+
noisy_image_or_video=noisy_image_or_video,
|
70 |
+
conditional_dict=conditional_dict,
|
71 |
+
timestep=timestep
|
72 |
+
)
|
73 |
+
|
74 |
+
if self.fake_guidance_scale != 0.0:
|
75 |
+
_, pred_fake_image_uncond = self.fake_score(
|
76 |
+
noisy_image_or_video=noisy_image_or_video,
|
77 |
+
conditional_dict=unconditional_dict,
|
78 |
+
timestep=timestep
|
79 |
+
)
|
80 |
+
pred_fake_image = pred_fake_image_cond + (
|
81 |
+
pred_fake_image_cond - pred_fake_image_uncond
|
82 |
+
) * self.fake_guidance_scale
|
83 |
+
else:
|
84 |
+
pred_fake_image = pred_fake_image_cond
|
85 |
+
|
86 |
+
# Step 2: Compute the real score
|
87 |
+
# We compute the conditional and unconditional prediction
|
88 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
89 |
+
_, pred_real_image_cond = self.real_score(
|
90 |
+
noisy_image_or_video=noisy_image_or_video,
|
91 |
+
conditional_dict=conditional_dict,
|
92 |
+
timestep=timestep
|
93 |
+
)
|
94 |
+
|
95 |
+
_, pred_real_image_uncond = self.real_score(
|
96 |
+
noisy_image_or_video=noisy_image_or_video,
|
97 |
+
conditional_dict=unconditional_dict,
|
98 |
+
timestep=timestep
|
99 |
+
)
|
100 |
+
|
101 |
+
pred_real_image = pred_real_image_cond + (
|
102 |
+
pred_real_image_cond - pred_real_image_uncond
|
103 |
+
) * self.real_guidance_scale
|
104 |
+
|
105 |
+
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
|
106 |
+
grad = (pred_fake_image - pred_real_image)
|
107 |
+
|
108 |
+
# TODO: Change the normalizer for causal teacher
|
109 |
+
if normalization:
|
110 |
+
# Step 4: Gradient normalization (DMD paper eq. 8).
|
111 |
+
p_real = (estimated_clean_image_or_video - pred_real_image)
|
112 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
113 |
+
grad = grad / normalizer
|
114 |
+
grad = torch.nan_to_num(grad)
|
115 |
+
|
116 |
+
return grad, {
|
117 |
+
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
|
118 |
+
"timestep": timestep.detach()
|
119 |
+
}
|
120 |
+
|
121 |
+
def compute_distribution_matching_loss(
|
122 |
+
self,
|
123 |
+
image_or_video: torch.Tensor,
|
124 |
+
conditional_dict: dict,
|
125 |
+
unconditional_dict: dict,
|
126 |
+
gradient_mask: torch.Tensor = None,
|
127 |
+
) -> Tuple[torch.Tensor, dict]:
|
128 |
+
"""
|
129 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
130 |
+
Input:
|
131 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
132 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
133 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
134 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
135 |
+
Output:
|
136 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
137 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
138 |
+
"""
|
139 |
+
original_latent = image_or_video
|
140 |
+
|
141 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
142 |
+
|
143 |
+
with torch.no_grad():
|
144 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
145 |
+
timestep = self._get_timestep(
|
146 |
+
0,
|
147 |
+
self.num_train_timestep,
|
148 |
+
batch_size,
|
149 |
+
num_frame,
|
150 |
+
self.num_frame_per_block,
|
151 |
+
uniform_timestep=True
|
152 |
+
)
|
153 |
+
|
154 |
+
if self.timestep_shift > 1:
|
155 |
+
timestep = self.timestep_shift * \
|
156 |
+
(timestep / 1000) / \
|
157 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
158 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
159 |
+
|
160 |
+
noise = torch.randn_like(image_or_video)
|
161 |
+
noisy_latent = self.scheduler.add_noise(
|
162 |
+
image_or_video.flatten(0, 1),
|
163 |
+
noise.flatten(0, 1),
|
164 |
+
timestep.flatten(0, 1)
|
165 |
+
).detach().unflatten(0, (batch_size, num_frame))
|
166 |
+
|
167 |
+
# Step 2: Compute the KL grad
|
168 |
+
grad, dmd_log_dict = self._compute_kl_grad(
|
169 |
+
noisy_image_or_video=noisy_latent,
|
170 |
+
estimated_clean_image_or_video=original_latent,
|
171 |
+
timestep=timestep,
|
172 |
+
conditional_dict=conditional_dict,
|
173 |
+
unconditional_dict=unconditional_dict
|
174 |
+
)
|
175 |
+
|
176 |
+
if gradient_mask is not None:
|
177 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
178 |
+
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
|
179 |
+
else:
|
180 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
181 |
+
), (original_latent.double() - grad.double()).detach(), reduction="mean")
|
182 |
+
return dmd_loss, dmd_log_dict
|
183 |
+
|
184 |
+
def _run_generator(
|
185 |
+
self,
|
186 |
+
image_or_video_shape,
|
187 |
+
conditional_dict: dict,
|
188 |
+
clean_latent: torch.tensor
|
189 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
190 |
+
"""
|
191 |
+
Optionally simulate the generator's input from noise using backward simulation
|
192 |
+
and then run the generator for one-step.
|
193 |
+
Input:
|
194 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
195 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
196 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
197 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
198 |
+
- initial_latent: a tensor containing the initial latents [B, F, C, H, W].
|
199 |
+
Output:
|
200 |
+
- pred_image: a tensor with shape [B, F, C, H, W].
|
201 |
+
"""
|
202 |
+
simulated_noisy_input = []
|
203 |
+
for timestep in self.denoising_step_list:
|
204 |
+
noise = torch.randn(
|
205 |
+
image_or_video_shape, device=self.device, dtype=self.dtype)
|
206 |
+
|
207 |
+
noisy_timestep = timestep * torch.ones(
|
208 |
+
image_or_video_shape[:2], device=self.device, dtype=torch.long)
|
209 |
+
|
210 |
+
if timestep != 0:
|
211 |
+
noisy_image = self.scheduler.add_noise(
|
212 |
+
clean_latent.flatten(0, 1),
|
213 |
+
noise.flatten(0, 1),
|
214 |
+
noisy_timestep.flatten(0, 1)
|
215 |
+
).unflatten(0, image_or_video_shape[:2])
|
216 |
+
else:
|
217 |
+
noisy_image = clean_latent
|
218 |
+
|
219 |
+
simulated_noisy_input.append(noisy_image)
|
220 |
+
|
221 |
+
simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)
|
222 |
+
|
223 |
+
# Step 2: Randomly sample a timestep and pick the corresponding input
|
224 |
+
index = self._get_timestep(
|
225 |
+
0,
|
226 |
+
len(self.denoising_step_list),
|
227 |
+
image_or_video_shape[0],
|
228 |
+
image_or_video_shape[1],
|
229 |
+
self.num_frame_per_block,
|
230 |
+
uniform_timestep=False
|
231 |
+
)
|
232 |
+
|
233 |
+
# select the corresponding timestep's noisy input from the stacked tensor [B, T, F, C, H, W]
|
234 |
+
noisy_input = torch.gather(
|
235 |
+
simulated_noisy_input, dim=1,
|
236 |
+
index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
|
237 |
+
-1, -1, -1, *image_or_video_shape[2:]).to(self.device)
|
238 |
+
).squeeze(1)
|
239 |
+
|
240 |
+
timestep = self.denoising_step_list[index].to(self.device)
|
241 |
+
|
242 |
+
_, pred_image_or_video = self.generator(
|
243 |
+
noisy_image_or_video=noisy_input,
|
244 |
+
conditional_dict=conditional_dict,
|
245 |
+
timestep=timestep,
|
246 |
+
clean_x=clean_latent if self.teacher_forcing else None,
|
247 |
+
)
|
248 |
+
|
249 |
+
gradient_mask = None # timestep != 0
|
250 |
+
|
251 |
+
pred_image_or_video = pred_image_or_video.type_as(noisy_input)
|
252 |
+
|
253 |
+
return pred_image_or_video, gradient_mask
|
254 |
+
|
255 |
+
def generator_loss(
|
256 |
+
self,
|
257 |
+
image_or_video_shape,
|
258 |
+
conditional_dict: dict,
|
259 |
+
unconditional_dict: dict,
|
260 |
+
clean_latent: torch.Tensor,
|
261 |
+
initial_latent: torch.Tensor = None
|
262 |
+
) -> Tuple[torch.Tensor, dict]:
|
263 |
+
"""
|
264 |
+
Generate image/videos from noise and compute the DMD loss.
|
265 |
+
The noisy input to the generator is backward simulated.
|
266 |
+
This removes the need of any datasets during distillation.
|
267 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
268 |
+
Input:
|
269 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
270 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
271 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
272 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
273 |
+
Output:
|
274 |
+
- loss: a scalar tensor representing the generator loss.
|
275 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
276 |
+
"""
|
277 |
+
# Step 1: Run generator on backward simulated noisy input
|
278 |
+
pred_image, gradient_mask = self._run_generator(
|
279 |
+
image_or_video_shape=image_or_video_shape,
|
280 |
+
conditional_dict=conditional_dict,
|
281 |
+
clean_latent=clean_latent
|
282 |
+
)
|
283 |
+
|
284 |
+
# Step 2: Compute the DMD loss
|
285 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
286 |
+
image_or_video=pred_image,
|
287 |
+
conditional_dict=conditional_dict,
|
288 |
+
unconditional_dict=unconditional_dict,
|
289 |
+
gradient_mask=gradient_mask
|
290 |
+
)
|
291 |
+
|
292 |
+
# Step 3: TODO: Implement the GAN loss
|
293 |
+
|
294 |
+
return dmd_loss, dmd_log_dict
|
295 |
+
|
296 |
+
def critic_loss(
|
297 |
+
self,
|
298 |
+
image_or_video_shape,
|
299 |
+
conditional_dict: dict,
|
300 |
+
unconditional_dict: dict,
|
301 |
+
clean_latent: torch.Tensor,
|
302 |
+
initial_latent: torch.Tensor = None
|
303 |
+
) -> Tuple[torch.Tensor, dict]:
|
304 |
+
"""
|
305 |
+
Generate image/videos from noise and train the critic with generated samples.
|
306 |
+
The noisy input to the generator is backward simulated.
|
307 |
+
This removes the need of any datasets during distillation.
|
308 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
309 |
+
Input:
|
310 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
311 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
312 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
313 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
314 |
+
Output:
|
315 |
+
- loss: a scalar tensor representing the generator loss.
|
316 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
317 |
+
"""
|
318 |
+
|
319 |
+
# Step 1: Run generator on backward simulated noisy input
|
320 |
+
with torch.no_grad():
|
321 |
+
generated_image, _ = self._run_generator(
|
322 |
+
image_or_video_shape=image_or_video_shape,
|
323 |
+
conditional_dict=conditional_dict,
|
324 |
+
clean_latent=clean_latent
|
325 |
+
)
|
326 |
+
|
327 |
+
# Step 2: Compute the fake prediction
|
328 |
+
critic_timestep = self._get_timestep(
|
329 |
+
0,
|
330 |
+
self.num_train_timestep,
|
331 |
+
image_or_video_shape[0],
|
332 |
+
image_or_video_shape[1],
|
333 |
+
self.num_frame_per_block,
|
334 |
+
uniform_timestep=True
|
335 |
+
)
|
336 |
+
|
337 |
+
if self.timestep_shift > 1:
|
338 |
+
critic_timestep = self.timestep_shift * \
|
339 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
340 |
+
|
341 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
342 |
+
|
343 |
+
critic_noise = torch.randn_like(generated_image)
|
344 |
+
noisy_generated_image = self.scheduler.add_noise(
|
345 |
+
generated_image.flatten(0, 1),
|
346 |
+
critic_noise.flatten(0, 1),
|
347 |
+
critic_timestep.flatten(0, 1)
|
348 |
+
).unflatten(0, image_or_video_shape[:2])
|
349 |
+
|
350 |
+
_, pred_fake_image = self.fake_score(
|
351 |
+
noisy_image_or_video=noisy_generated_image,
|
352 |
+
conditional_dict=conditional_dict,
|
353 |
+
timestep=critic_timestep
|
354 |
+
)
|
355 |
+
|
356 |
+
# Step 3: Compute the denoising loss for the fake critic
|
357 |
+
if self.args.denoising_loss_type == "flow":
|
358 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
359 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
360 |
+
scheduler=self.scheduler,
|
361 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
362 |
+
xt=noisy_generated_image.flatten(0, 1),
|
363 |
+
timestep=critic_timestep.flatten(0, 1)
|
364 |
+
)
|
365 |
+
pred_fake_noise = None
|
366 |
+
else:
|
367 |
+
flow_pred = None
|
368 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
369 |
+
x0=pred_fake_image.flatten(0, 1),
|
370 |
+
xt=noisy_generated_image.flatten(0, 1),
|
371 |
+
timestep=critic_timestep.flatten(0, 1)
|
372 |
+
).unflatten(0, image_or_video_shape[:2])
|
373 |
+
|
374 |
+
denoising_loss = self.denoising_loss_func(
|
375 |
+
x=generated_image.flatten(0, 1),
|
376 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
377 |
+
noise=critic_noise.flatten(0, 1),
|
378 |
+
noise_pred=pred_fake_noise,
|
379 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
380 |
+
timestep=critic_timestep.flatten(0, 1),
|
381 |
+
flow_pred=flow_pred
|
382 |
+
)
|
383 |
+
|
384 |
+
# Step 4: TODO: Compute the GAN loss
|
385 |
+
|
386 |
+
# Step 5: Debugging Log
|
387 |
+
critic_log_dict = {
|
388 |
+
"critic_timestep": critic_timestep.detach()
|
389 |
+
}
|
390 |
+
|
391 |
+
return denoising_loss, critic_log_dict
|
model/diffusion.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from model.base import BaseModel
|
5 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
6 |
+
|
7 |
+
|
8 |
+
class CausalDiffusion(BaseModel):
|
9 |
+
def __init__(self, args, device):
|
10 |
+
"""
|
11 |
+
Initialize the Diffusion loss module.
|
12 |
+
"""
|
13 |
+
super().__init__(args, device)
|
14 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
15 |
+
if self.num_frame_per_block > 1:
|
16 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
17 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
18 |
+
if self.independent_first_frame:
|
19 |
+
self.generator.model.independent_first_frame = True
|
20 |
+
|
21 |
+
if args.gradient_checkpointing:
|
22 |
+
self.generator.enable_gradient_checkpointing()
|
23 |
+
|
24 |
+
# Step 2: Initialize all hyperparameters
|
25 |
+
self.num_train_timestep = args.num_train_timestep
|
26 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
27 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
28 |
+
self.guidance_scale = args.guidance_scale
|
29 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
30 |
+
self.teacher_forcing = getattr(args, "teacher_forcing", False)
|
31 |
+
# Noise augmentation in teacher forcing, we add small noise to clean context latents
|
32 |
+
self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
|
33 |
+
|
34 |
+
def _initialize_models(self, args):
|
35 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
36 |
+
self.generator.model.requires_grad_(True)
|
37 |
+
|
38 |
+
self.text_encoder = WanTextEncoder()
|
39 |
+
self.text_encoder.requires_grad_(False)
|
40 |
+
|
41 |
+
self.vae = WanVAEWrapper()
|
42 |
+
self.vae.requires_grad_(False)
|
43 |
+
|
44 |
+
def generator_loss(
|
45 |
+
self,
|
46 |
+
image_or_video_shape,
|
47 |
+
conditional_dict: dict,
|
48 |
+
unconditional_dict: dict,
|
49 |
+
clean_latent: torch.Tensor,
|
50 |
+
initial_latent: torch.Tensor = None
|
51 |
+
) -> Tuple[torch.Tensor, dict]:
|
52 |
+
"""
|
53 |
+
Generate image/videos from noise and compute the DMD loss.
|
54 |
+
The noisy input to the generator is backward simulated.
|
55 |
+
This removes the need of any datasets during distillation.
|
56 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
57 |
+
Input:
|
58 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
59 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
60 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
61 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
62 |
+
Output:
|
63 |
+
- loss: a scalar tensor representing the generator loss.
|
64 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
65 |
+
"""
|
66 |
+
noise = torch.randn_like(clean_latent)
|
67 |
+
batch_size, num_frame = image_or_video_shape[:2]
|
68 |
+
|
69 |
+
# Step 2: Randomly sample a timestep and add noise to denoiser inputs
|
70 |
+
index = self._get_timestep(
|
71 |
+
0,
|
72 |
+
self.scheduler.num_train_timesteps,
|
73 |
+
image_or_video_shape[0],
|
74 |
+
image_or_video_shape[1],
|
75 |
+
self.num_frame_per_block,
|
76 |
+
uniform_timestep=False
|
77 |
+
)
|
78 |
+
timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
|
79 |
+
noisy_latents = self.scheduler.add_noise(
|
80 |
+
clean_latent.flatten(0, 1),
|
81 |
+
noise.flatten(0, 1),
|
82 |
+
timestep.flatten(0, 1)
|
83 |
+
).unflatten(0, (batch_size, num_frame))
|
84 |
+
training_target = self.scheduler.training_target(clean_latent, noise, timestep)
|
85 |
+
|
86 |
+
# Step 3: Noise augmentation, also add small noise to clean context latents
|
87 |
+
if self.noise_augmentation_max_timestep > 0:
|
88 |
+
index_clean_aug = self._get_timestep(
|
89 |
+
0,
|
90 |
+
self.noise_augmentation_max_timestep,
|
91 |
+
image_or_video_shape[0],
|
92 |
+
image_or_video_shape[1],
|
93 |
+
self.num_frame_per_block,
|
94 |
+
uniform_timestep=False
|
95 |
+
)
|
96 |
+
timestep_clean_aug = self.scheduler.timesteps[index_clean_aug].to(dtype=self.dtype, device=self.device)
|
97 |
+
clean_latent_aug = self.scheduler.add_noise(
|
98 |
+
clean_latent.flatten(0, 1),
|
99 |
+
noise.flatten(0, 1),
|
100 |
+
timestep_clean_aug.flatten(0, 1)
|
101 |
+
).unflatten(0, (batch_size, num_frame))
|
102 |
+
else:
|
103 |
+
clean_latent_aug = clean_latent
|
104 |
+
timestep_clean_aug = None
|
105 |
+
|
106 |
+
# Compute loss
|
107 |
+
flow_pred, x0_pred = self.generator(
|
108 |
+
noisy_image_or_video=noisy_latents,
|
109 |
+
conditional_dict=conditional_dict,
|
110 |
+
timestep=timestep,
|
111 |
+
clean_x=clean_latent_aug if self.teacher_forcing else None,
|
112 |
+
aug_t=timestep_clean_aug if self.teacher_forcing else None
|
113 |
+
)
|
114 |
+
# loss = torch.nn.functional.mse_loss(flow_pred.float(), training_target.float())
|
115 |
+
loss = torch.nn.functional.mse_loss(
|
116 |
+
flow_pred.float(), training_target.float(), reduction='none'
|
117 |
+
).mean(dim=(2, 3, 4))
|
118 |
+
loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
|
119 |
+
loss = loss.mean()
|
120 |
+
|
121 |
+
log_dict = {
|
122 |
+
"x0": clean_latent.detach(),
|
123 |
+
"x0_pred": x0_pred.detach()
|
124 |
+
}
|
125 |
+
return loss, log_dict
|
model/dmd.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pipeline import SelfForcingTrainingPipeline
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from typing import Optional, Tuple
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from model.base import SelfForcingModel
|
7 |
+
|
8 |
+
|
9 |
+
class DMD(SelfForcingModel):
|
10 |
+
def __init__(self, args, device):
|
11 |
+
"""
|
12 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
13 |
+
This class is self-contained and compute generator and fake score losses
|
14 |
+
in the forward pass.
|
15 |
+
"""
|
16 |
+
super().__init__(args, device)
|
17 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
18 |
+
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
|
19 |
+
self.num_training_frames = getattr(args, "num_training_frames", 21)
|
20 |
+
|
21 |
+
if self.num_frame_per_block > 1:
|
22 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
23 |
+
|
24 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
25 |
+
if self.independent_first_frame:
|
26 |
+
self.generator.model.independent_first_frame = True
|
27 |
+
if args.gradient_checkpointing:
|
28 |
+
self.generator.enable_gradient_checkpointing()
|
29 |
+
self.fake_score.enable_gradient_checkpointing()
|
30 |
+
|
31 |
+
# this will be init later with fsdp-wrapped modules
|
32 |
+
self.inference_pipeline: SelfForcingTrainingPipeline = None
|
33 |
+
|
34 |
+
# Step 2: Initialize all dmd hyperparameters
|
35 |
+
self.num_train_timestep = args.num_train_timestep
|
36 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
37 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
38 |
+
if hasattr(args, "real_guidance_scale"):
|
39 |
+
self.real_guidance_scale = args.real_guidance_scale
|
40 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
41 |
+
else:
|
42 |
+
self.real_guidance_scale = args.guidance_scale
|
43 |
+
self.fake_guidance_scale = 0.0
|
44 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
45 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
46 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
47 |
+
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
|
48 |
+
|
49 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
50 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
51 |
+
else:
|
52 |
+
self.scheduler.alphas_cumprod = None
|
53 |
+
|
54 |
+
def _compute_kl_grad(
|
55 |
+
self, noisy_image_or_video: torch.Tensor,
|
56 |
+
estimated_clean_image_or_video: torch.Tensor,
|
57 |
+
timestep: torch.Tensor,
|
58 |
+
conditional_dict: dict, unconditional_dict: dict,
|
59 |
+
normalization: bool = True
|
60 |
+
) -> Tuple[torch.Tensor, dict]:
|
61 |
+
"""
|
62 |
+
Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
|
63 |
+
Input:
|
64 |
+
- noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
65 |
+
- estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
|
66 |
+
- timestep: a tensor with shape [B, F] containing the randomly generated timestep.
|
67 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
68 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
69 |
+
- normalization: a boolean indicating whether to normalize the gradient.
|
70 |
+
Output:
|
71 |
+
- kl_grad: a tensor representing the KL grad.
|
72 |
+
- kl_log_dict: a dictionary containing the intermediate tensors for logging.
|
73 |
+
"""
|
74 |
+
# Step 1: Compute the fake score
|
75 |
+
_, pred_fake_image_cond = self.fake_score(
|
76 |
+
noisy_image_or_video=noisy_image_or_video,
|
77 |
+
conditional_dict=conditional_dict,
|
78 |
+
timestep=timestep
|
79 |
+
)
|
80 |
+
|
81 |
+
if self.fake_guidance_scale != 0.0:
|
82 |
+
_, pred_fake_image_uncond = self.fake_score(
|
83 |
+
noisy_image_or_video=noisy_image_or_video,
|
84 |
+
conditional_dict=unconditional_dict,
|
85 |
+
timestep=timestep
|
86 |
+
)
|
87 |
+
pred_fake_image = pred_fake_image_cond + (
|
88 |
+
pred_fake_image_cond - pred_fake_image_uncond
|
89 |
+
) * self.fake_guidance_scale
|
90 |
+
else:
|
91 |
+
pred_fake_image = pred_fake_image_cond
|
92 |
+
|
93 |
+
# Step 2: Compute the real score
|
94 |
+
# We compute the conditional and unconditional prediction
|
95 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
96 |
+
_, pred_real_image_cond = self.real_score(
|
97 |
+
noisy_image_or_video=noisy_image_or_video,
|
98 |
+
conditional_dict=conditional_dict,
|
99 |
+
timestep=timestep
|
100 |
+
)
|
101 |
+
|
102 |
+
_, pred_real_image_uncond = self.real_score(
|
103 |
+
noisy_image_or_video=noisy_image_or_video,
|
104 |
+
conditional_dict=unconditional_dict,
|
105 |
+
timestep=timestep
|
106 |
+
)
|
107 |
+
|
108 |
+
pred_real_image = pred_real_image_cond + (
|
109 |
+
pred_real_image_cond - pred_real_image_uncond
|
110 |
+
) * self.real_guidance_scale
|
111 |
+
|
112 |
+
# Step 3: Compute the DMD gradient (DMD paper eq. 7).
|
113 |
+
grad = (pred_fake_image - pred_real_image)
|
114 |
+
|
115 |
+
# TODO: Change the normalizer for causal teacher
|
116 |
+
if normalization:
|
117 |
+
# Step 4: Gradient normalization (DMD paper eq. 8).
|
118 |
+
p_real = (estimated_clean_image_or_video - pred_real_image)
|
119 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
120 |
+
grad = grad / normalizer
|
121 |
+
grad = torch.nan_to_num(grad)
|
122 |
+
|
123 |
+
return grad, {
|
124 |
+
"dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
|
125 |
+
"timestep": timestep.detach()
|
126 |
+
}
|
127 |
+
|
128 |
+
def compute_distribution_matching_loss(
|
129 |
+
self,
|
130 |
+
image_or_video: torch.Tensor,
|
131 |
+
conditional_dict: dict,
|
132 |
+
unconditional_dict: dict,
|
133 |
+
gradient_mask: Optional[torch.Tensor] = None,
|
134 |
+
denoised_timestep_from: int = 0,
|
135 |
+
denoised_timestep_to: int = 0
|
136 |
+
) -> Tuple[torch.Tensor, dict]:
|
137 |
+
"""
|
138 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
139 |
+
Input:
|
140 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
141 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
142 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
143 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
144 |
+
Output:
|
145 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
146 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
147 |
+
"""
|
148 |
+
original_latent = image_or_video
|
149 |
+
|
150 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
151 |
+
|
152 |
+
with torch.no_grad():
|
153 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
154 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
155 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
156 |
+
timestep = self._get_timestep(
|
157 |
+
min_timestep,
|
158 |
+
max_timestep,
|
159 |
+
batch_size,
|
160 |
+
num_frame,
|
161 |
+
self.num_frame_per_block,
|
162 |
+
uniform_timestep=True
|
163 |
+
)
|
164 |
+
|
165 |
+
# TODO:should we change it to `timestep = self.scheduler.timesteps[timestep]`?
|
166 |
+
if self.timestep_shift > 1:
|
167 |
+
timestep = self.timestep_shift * \
|
168 |
+
(timestep / 1000) / \
|
169 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
170 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
171 |
+
|
172 |
+
noise = torch.randn_like(image_or_video)
|
173 |
+
noisy_latent = self.scheduler.add_noise(
|
174 |
+
image_or_video.flatten(0, 1),
|
175 |
+
noise.flatten(0, 1),
|
176 |
+
timestep.flatten(0, 1)
|
177 |
+
).detach().unflatten(0, (batch_size, num_frame))
|
178 |
+
|
179 |
+
# Step 2: Compute the KL grad
|
180 |
+
grad, dmd_log_dict = self._compute_kl_grad(
|
181 |
+
noisy_image_or_video=noisy_latent,
|
182 |
+
estimated_clean_image_or_video=original_latent,
|
183 |
+
timestep=timestep,
|
184 |
+
conditional_dict=conditional_dict,
|
185 |
+
unconditional_dict=unconditional_dict
|
186 |
+
)
|
187 |
+
|
188 |
+
if gradient_mask is not None:
|
189 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
190 |
+
)[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
|
191 |
+
else:
|
192 |
+
dmd_loss = 0.5 * F.mse_loss(original_latent.double(
|
193 |
+
), (original_latent.double() - grad.double()).detach(), reduction="mean")
|
194 |
+
return dmd_loss, dmd_log_dict
|
195 |
+
|
196 |
+
def generator_loss(
|
197 |
+
self,
|
198 |
+
image_or_video_shape,
|
199 |
+
conditional_dict: dict,
|
200 |
+
unconditional_dict: dict,
|
201 |
+
clean_latent: torch.Tensor,
|
202 |
+
initial_latent: torch.Tensor = None
|
203 |
+
) -> Tuple[torch.Tensor, dict]:
|
204 |
+
"""
|
205 |
+
Generate image/videos from noise and compute the DMD loss.
|
206 |
+
The noisy input to the generator is backward simulated.
|
207 |
+
This removes the need of any datasets during distillation.
|
208 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
209 |
+
Input:
|
210 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
211 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
212 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
213 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
214 |
+
Output:
|
215 |
+
- loss: a scalar tensor representing the generator loss.
|
216 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
217 |
+
"""
|
218 |
+
# Step 1: Unroll generator to obtain fake videos
|
219 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
220 |
+
image_or_video_shape=image_or_video_shape,
|
221 |
+
conditional_dict=conditional_dict,
|
222 |
+
initial_latent=initial_latent
|
223 |
+
)
|
224 |
+
|
225 |
+
# Step 2: Compute the DMD loss
|
226 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
227 |
+
image_or_video=pred_image,
|
228 |
+
conditional_dict=conditional_dict,
|
229 |
+
unconditional_dict=unconditional_dict,
|
230 |
+
gradient_mask=gradient_mask,
|
231 |
+
denoised_timestep_from=denoised_timestep_from,
|
232 |
+
denoised_timestep_to=denoised_timestep_to
|
233 |
+
)
|
234 |
+
|
235 |
+
return dmd_loss, dmd_log_dict
|
236 |
+
|
237 |
+
def critic_loss(
|
238 |
+
self,
|
239 |
+
image_or_video_shape,
|
240 |
+
conditional_dict: dict,
|
241 |
+
unconditional_dict: dict,
|
242 |
+
clean_latent: torch.Tensor,
|
243 |
+
initial_latent: torch.Tensor = None
|
244 |
+
) -> Tuple[torch.Tensor, dict]:
|
245 |
+
"""
|
246 |
+
Generate image/videos from noise and train the critic with generated samples.
|
247 |
+
The noisy input to the generator is backward simulated.
|
248 |
+
This removes the need of any datasets during distillation.
|
249 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
250 |
+
Input:
|
251 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
252 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
253 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
254 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
255 |
+
Output:
|
256 |
+
- loss: a scalar tensor representing the generator loss.
|
257 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
258 |
+
"""
|
259 |
+
|
260 |
+
# Step 1: Run generator on backward simulated noisy input
|
261 |
+
with torch.no_grad():
|
262 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
263 |
+
image_or_video_shape=image_or_video_shape,
|
264 |
+
conditional_dict=conditional_dict,
|
265 |
+
initial_latent=initial_latent
|
266 |
+
)
|
267 |
+
|
268 |
+
# Step 2: Compute the fake prediction
|
269 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
270 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
271 |
+
critic_timestep = self._get_timestep(
|
272 |
+
min_timestep,
|
273 |
+
max_timestep,
|
274 |
+
image_or_video_shape[0],
|
275 |
+
image_or_video_shape[1],
|
276 |
+
self.num_frame_per_block,
|
277 |
+
uniform_timestep=True
|
278 |
+
)
|
279 |
+
|
280 |
+
if self.timestep_shift > 1:
|
281 |
+
critic_timestep = self.timestep_shift * \
|
282 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
283 |
+
|
284 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
285 |
+
|
286 |
+
critic_noise = torch.randn_like(generated_image)
|
287 |
+
noisy_generated_image = self.scheduler.add_noise(
|
288 |
+
generated_image.flatten(0, 1),
|
289 |
+
critic_noise.flatten(0, 1),
|
290 |
+
critic_timestep.flatten(0, 1)
|
291 |
+
).unflatten(0, image_or_video_shape[:2])
|
292 |
+
|
293 |
+
_, pred_fake_image = self.fake_score(
|
294 |
+
noisy_image_or_video=noisy_generated_image,
|
295 |
+
conditional_dict=conditional_dict,
|
296 |
+
timestep=critic_timestep
|
297 |
+
)
|
298 |
+
|
299 |
+
# Step 3: Compute the denoising loss for the fake critic
|
300 |
+
if self.args.denoising_loss_type == "flow":
|
301 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
302 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
303 |
+
scheduler=self.scheduler,
|
304 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
305 |
+
xt=noisy_generated_image.flatten(0, 1),
|
306 |
+
timestep=critic_timestep.flatten(0, 1)
|
307 |
+
)
|
308 |
+
pred_fake_noise = None
|
309 |
+
else:
|
310 |
+
flow_pred = None
|
311 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
312 |
+
x0=pred_fake_image.flatten(0, 1),
|
313 |
+
xt=noisy_generated_image.flatten(0, 1),
|
314 |
+
timestep=critic_timestep.flatten(0, 1)
|
315 |
+
).unflatten(0, image_or_video_shape[:2])
|
316 |
+
|
317 |
+
denoising_loss = self.denoising_loss_func(
|
318 |
+
x=generated_image.flatten(0, 1),
|
319 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
320 |
+
noise=critic_noise.flatten(0, 1),
|
321 |
+
noise_pred=pred_fake_noise,
|
322 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
323 |
+
timestep=critic_timestep.flatten(0, 1),
|
324 |
+
flow_pred=flow_pred
|
325 |
+
)
|
326 |
+
|
327 |
+
# Step 5: Debugging Log
|
328 |
+
critic_log_dict = {
|
329 |
+
"critic_timestep": critic_timestep.detach()
|
330 |
+
}
|
331 |
+
|
332 |
+
return denoising_loss, critic_log_dict
|
model/gan.py
ADDED
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
from pipeline import SelfForcingTrainingPipeline
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from typing import Tuple
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from model.base import SelfForcingModel
|
8 |
+
|
9 |
+
|
10 |
+
class GAN(SelfForcingModel):
|
11 |
+
def __init__(self, args, device):
|
12 |
+
"""
|
13 |
+
Initialize the GAN module.
|
14 |
+
This class is self-contained and compute generator and fake score losses
|
15 |
+
in the forward pass.
|
16 |
+
"""
|
17 |
+
super().__init__(args, device)
|
18 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
19 |
+
self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
|
20 |
+
self.concat_time_embeddings = getattr(args, "concat_time_embeddings", False)
|
21 |
+
self.num_class = args.num_class
|
22 |
+
self.relativistic_discriminator = getattr(args, "relativistic_discriminator", False)
|
23 |
+
|
24 |
+
if self.num_frame_per_block > 1:
|
25 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
26 |
+
|
27 |
+
self.fake_score.adding_cls_branch(
|
28 |
+
atten_dim=1536, num_class=args.num_class, time_embed_dim=1536 if self.concat_time_embeddings else 0)
|
29 |
+
self.fake_score.model.requires_grad_(True)
|
30 |
+
|
31 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
32 |
+
if self.independent_first_frame:
|
33 |
+
self.generator.model.independent_first_frame = True
|
34 |
+
if args.gradient_checkpointing:
|
35 |
+
self.generator.enable_gradient_checkpointing()
|
36 |
+
self.fake_score.enable_gradient_checkpointing()
|
37 |
+
|
38 |
+
# this will be init later with fsdp-wrapped modules
|
39 |
+
self.inference_pipeline: SelfForcingTrainingPipeline = None
|
40 |
+
|
41 |
+
# Step 2: Initialize all dmd hyperparameters
|
42 |
+
self.num_train_timestep = args.num_train_timestep
|
43 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
44 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
45 |
+
if hasattr(args, "real_guidance_scale"):
|
46 |
+
self.real_guidance_scale = args.real_guidance_scale
|
47 |
+
self.fake_guidance_scale = args.fake_guidance_scale
|
48 |
+
else:
|
49 |
+
self.real_guidance_scale = args.guidance_scale
|
50 |
+
self.fake_guidance_scale = 0.0
|
51 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
52 |
+
self.critic_timestep_shift = getattr(args, "critic_timestep_shift", self.timestep_shift)
|
53 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
54 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
55 |
+
self.min_score_timestep = getattr(args, "min_score_timestep", 0)
|
56 |
+
|
57 |
+
self.gan_g_weight = getattr(args, "gan_g_weight", 1e-2)
|
58 |
+
self.gan_d_weight = getattr(args, "gan_d_weight", 1e-2)
|
59 |
+
self.r1_weight = getattr(args, "r1_weight", 0.0)
|
60 |
+
self.r2_weight = getattr(args, "r2_weight", 0.0)
|
61 |
+
self.r1_sigma = getattr(args, "r1_sigma", 0.01)
|
62 |
+
self.r2_sigma = getattr(args, "r2_sigma", 0.01)
|
63 |
+
|
64 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
65 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
66 |
+
else:
|
67 |
+
self.scheduler.alphas_cumprod = None
|
68 |
+
|
69 |
+
def _run_cls_pred_branch(self,
|
70 |
+
noisy_image_or_video: torch.Tensor,
|
71 |
+
conditional_dict: dict,
|
72 |
+
timestep: torch.Tensor) -> torch.Tensor:
|
73 |
+
"""
|
74 |
+
Run the classifier prediction branch on the generated image or video.
|
75 |
+
Input:
|
76 |
+
- image_or_video: a tensor with shape [B, F, C, H, W].
|
77 |
+
Output:
|
78 |
+
- cls_pred: a tensor with shape [B, 1, 1, 1, 1] representing the feature map for classification.
|
79 |
+
"""
|
80 |
+
_, _, noisy_logit = self.fake_score(
|
81 |
+
noisy_image_or_video=noisy_image_or_video,
|
82 |
+
conditional_dict=conditional_dict,
|
83 |
+
timestep=timestep,
|
84 |
+
classify_mode=True,
|
85 |
+
concat_time_embeddings=self.concat_time_embeddings
|
86 |
+
)
|
87 |
+
|
88 |
+
return noisy_logit
|
89 |
+
|
90 |
+
def generator_loss(
|
91 |
+
self,
|
92 |
+
image_or_video_shape,
|
93 |
+
conditional_dict: dict,
|
94 |
+
unconditional_dict: dict,
|
95 |
+
clean_latent: torch.Tensor,
|
96 |
+
initial_latent: torch.Tensor = None
|
97 |
+
) -> Tuple[torch.Tensor, dict]:
|
98 |
+
"""
|
99 |
+
Generate image/videos from noise and compute the DMD loss.
|
100 |
+
The noisy input to the generator is backward simulated.
|
101 |
+
This removes the need of any datasets during distillation.
|
102 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
103 |
+
Input:
|
104 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
105 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
106 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
107 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
108 |
+
Output:
|
109 |
+
- loss: a scalar tensor representing the generator loss.
|
110 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
111 |
+
"""
|
112 |
+
# Step 1: Unroll generator to obtain fake videos
|
113 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
114 |
+
image_or_video_shape=image_or_video_shape,
|
115 |
+
conditional_dict=conditional_dict,
|
116 |
+
initial_latent=initial_latent
|
117 |
+
)
|
118 |
+
|
119 |
+
# Step 2: Get timestep and add noise to generated/real latents
|
120 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
121 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
122 |
+
critic_timestep = self._get_timestep(
|
123 |
+
min_timestep,
|
124 |
+
max_timestep,
|
125 |
+
image_or_video_shape[0],
|
126 |
+
image_or_video_shape[1],
|
127 |
+
self.num_frame_per_block,
|
128 |
+
uniform_timestep=True
|
129 |
+
)
|
130 |
+
|
131 |
+
if self.critic_timestep_shift > 1:
|
132 |
+
critic_timestep = self.critic_timestep_shift * \
|
133 |
+
(critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
134 |
+
|
135 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
136 |
+
|
137 |
+
critic_noise = torch.randn_like(pred_image)
|
138 |
+
noisy_fake_latent = self.scheduler.add_noise(
|
139 |
+
pred_image.flatten(0, 1),
|
140 |
+
critic_noise.flatten(0, 1),
|
141 |
+
critic_timestep.flatten(0, 1)
|
142 |
+
).unflatten(0, image_or_video_shape[:2])
|
143 |
+
|
144 |
+
# Step 4: Compute the real GAN discriminator loss
|
145 |
+
real_image_or_video = clean_latent.clone()
|
146 |
+
critic_noise = torch.randn_like(real_image_or_video)
|
147 |
+
noisy_real_latent = self.scheduler.add_noise(
|
148 |
+
real_image_or_video.flatten(0, 1),
|
149 |
+
critic_noise.flatten(0, 1),
|
150 |
+
critic_timestep.flatten(0, 1)
|
151 |
+
).unflatten(0, image_or_video_shape[:2])
|
152 |
+
|
153 |
+
conditional_dict["prompt_embeds"] = torch.concatenate(
|
154 |
+
(conditional_dict["prompt_embeds"], conditional_dict["prompt_embeds"]), dim=0)
|
155 |
+
critic_timestep = torch.concatenate((critic_timestep, critic_timestep), dim=0)
|
156 |
+
noisy_latent = torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0)
|
157 |
+
_, _, noisy_logit = self.fake_score(
|
158 |
+
noisy_image_or_video=noisy_latent,
|
159 |
+
conditional_dict=conditional_dict,
|
160 |
+
timestep=critic_timestep,
|
161 |
+
classify_mode=True,
|
162 |
+
concat_time_embeddings=self.concat_time_embeddings
|
163 |
+
)
|
164 |
+
noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
|
165 |
+
|
166 |
+
if not self.relativistic_discriminator:
|
167 |
+
gan_G_loss = F.softplus(-noisy_fake_logit.float()).mean() * self.gan_g_weight
|
168 |
+
else:
|
169 |
+
relative_fake_logit = noisy_fake_logit - noisy_real_logit
|
170 |
+
gan_G_loss = F.softplus(-relative_fake_logit.float()).mean() * self.gan_g_weight
|
171 |
+
|
172 |
+
return gan_G_loss
|
173 |
+
|
174 |
+
def critic_loss(
|
175 |
+
self,
|
176 |
+
image_or_video_shape,
|
177 |
+
conditional_dict: dict,
|
178 |
+
unconditional_dict: dict,
|
179 |
+
clean_latent: torch.Tensor,
|
180 |
+
real_image_or_video: torch.Tensor,
|
181 |
+
initial_latent: torch.Tensor = None
|
182 |
+
) -> Tuple[torch.Tensor, dict]:
|
183 |
+
"""
|
184 |
+
Generate image/videos from noise and train the critic with generated samples.
|
185 |
+
The noisy input to the generator is backward simulated.
|
186 |
+
This removes the need of any datasets during distillation.
|
187 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
188 |
+
Input:
|
189 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
190 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
191 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
192 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
193 |
+
Output:
|
194 |
+
- loss: a scalar tensor representing the generator loss.
|
195 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
196 |
+
"""
|
197 |
+
|
198 |
+
# Step 1: Run generator on backward simulated noisy input
|
199 |
+
with torch.no_grad():
|
200 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to, num_sim_steps = self._run_generator(
|
201 |
+
image_or_video_shape=image_or_video_shape,
|
202 |
+
conditional_dict=conditional_dict,
|
203 |
+
initial_latent=initial_latent
|
204 |
+
)
|
205 |
+
|
206 |
+
# Step 2: Get timestep and add noise to generated/real latents
|
207 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
208 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
209 |
+
critic_timestep = self._get_timestep(
|
210 |
+
min_timestep,
|
211 |
+
max_timestep,
|
212 |
+
image_or_video_shape[0],
|
213 |
+
image_or_video_shape[1],
|
214 |
+
self.num_frame_per_block,
|
215 |
+
uniform_timestep=True
|
216 |
+
)
|
217 |
+
|
218 |
+
if self.critic_timestep_shift > 1:
|
219 |
+
critic_timestep = self.critic_timestep_shift * \
|
220 |
+
(critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
221 |
+
|
222 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
223 |
+
|
224 |
+
critic_noise = torch.randn_like(generated_image)
|
225 |
+
noisy_fake_latent = self.scheduler.add_noise(
|
226 |
+
generated_image.flatten(0, 1),
|
227 |
+
critic_noise.flatten(0, 1),
|
228 |
+
critic_timestep.flatten(0, 1)
|
229 |
+
).unflatten(0, image_or_video_shape[:2])
|
230 |
+
|
231 |
+
# Step 4: Compute the real GAN discriminator loss
|
232 |
+
noisy_real_latent = self.scheduler.add_noise(
|
233 |
+
real_image_or_video.flatten(0, 1),
|
234 |
+
critic_noise.flatten(0, 1),
|
235 |
+
critic_timestep.flatten(0, 1)
|
236 |
+
).unflatten(0, image_or_video_shape[:2])
|
237 |
+
|
238 |
+
conditional_dict_cloned = copy.deepcopy(conditional_dict)
|
239 |
+
conditional_dict_cloned["prompt_embeds"] = torch.concatenate(
|
240 |
+
(conditional_dict_cloned["prompt_embeds"], conditional_dict_cloned["prompt_embeds"]), dim=0)
|
241 |
+
_, _, noisy_logit = self.fake_score(
|
242 |
+
noisy_image_or_video=torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0),
|
243 |
+
conditional_dict=conditional_dict_cloned,
|
244 |
+
timestep=torch.concatenate((critic_timestep, critic_timestep), dim=0),
|
245 |
+
classify_mode=True,
|
246 |
+
concat_time_embeddings=self.concat_time_embeddings
|
247 |
+
)
|
248 |
+
noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
|
249 |
+
|
250 |
+
if not self.relativistic_discriminator:
|
251 |
+
gan_D_loss = F.softplus(-noisy_real_logit.float()).mean() + F.softplus(noisy_fake_logit.float()).mean()
|
252 |
+
else:
|
253 |
+
relative_real_logit = noisy_real_logit - noisy_fake_logit
|
254 |
+
gan_D_loss = F.softplus(-relative_real_logit.float()).mean()
|
255 |
+
gan_D_loss = gan_D_loss * self.gan_d_weight
|
256 |
+
|
257 |
+
# R1 regularization
|
258 |
+
if self.r1_weight > 0.:
|
259 |
+
noisy_real_latent_perturbed = noisy_real_latent.clone()
|
260 |
+
epison_real = self.r1_sigma * torch.randn_like(noisy_real_latent_perturbed)
|
261 |
+
noisy_real_latent_perturbed = noisy_real_latent_perturbed + epison_real
|
262 |
+
noisy_real_logit_perturbed = self._run_cls_pred_branch(
|
263 |
+
noisy_image_or_video=noisy_real_latent_perturbed,
|
264 |
+
conditional_dict=conditional_dict,
|
265 |
+
timestep=critic_timestep
|
266 |
+
)
|
267 |
+
|
268 |
+
r1_grad = (noisy_real_logit_perturbed - noisy_real_logit) / self.r1_sigma
|
269 |
+
r1_loss = self.r1_weight * torch.mean((r1_grad)**2)
|
270 |
+
else:
|
271 |
+
r1_loss = torch.zeros_like(gan_D_loss)
|
272 |
+
|
273 |
+
# R2 regularization
|
274 |
+
if self.r2_weight > 0.:
|
275 |
+
noisy_fake_latent_perturbed = noisy_fake_latent.clone()
|
276 |
+
epison_generated = self.r2_sigma * torch.randn_like(noisy_fake_latent_perturbed)
|
277 |
+
noisy_fake_latent_perturbed = noisy_fake_latent_perturbed + epison_generated
|
278 |
+
noisy_fake_logit_perturbed = self._run_cls_pred_branch(
|
279 |
+
noisy_image_or_video=noisy_fake_latent_perturbed,
|
280 |
+
conditional_dict=conditional_dict,
|
281 |
+
timestep=critic_timestep
|
282 |
+
)
|
283 |
+
|
284 |
+
r2_grad = (noisy_fake_logit_perturbed - noisy_fake_logit) / self.r2_sigma
|
285 |
+
r2_loss = self.r2_weight * torch.mean((r2_grad)**2)
|
286 |
+
else:
|
287 |
+
r2_loss = torch.zeros_like(r2_loss)
|
288 |
+
|
289 |
+
critic_log_dict = {
|
290 |
+
"critic_timestep": critic_timestep.detach(),
|
291 |
+
'noisy_real_logit': noisy_real_logit.detach(),
|
292 |
+
'noisy_fake_logit': noisy_fake_logit.detach(),
|
293 |
+
}
|
294 |
+
|
295 |
+
return (gan_D_loss, r1_loss, r2_loss), critic_log_dict
|
model/ode_regression.py
ADDED
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn.functional as F
|
2 |
+
from typing import Tuple
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.base import BaseModel
|
6 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
7 |
+
|
8 |
+
|
9 |
+
class ODERegression(BaseModel):
|
10 |
+
def __init__(self, args, device):
|
11 |
+
"""
|
12 |
+
Initialize the ODERegression module.
|
13 |
+
This class is self-contained and compute generator losses
|
14 |
+
in the forward pass given precomputed ode solution pairs.
|
15 |
+
This class supports the ode regression loss for both causal and bidirectional models.
|
16 |
+
See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details
|
17 |
+
"""
|
18 |
+
super().__init__(args, device)
|
19 |
+
|
20 |
+
# Step 1: Initialize all models
|
21 |
+
|
22 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
23 |
+
self.generator.model.requires_grad_(True)
|
24 |
+
if getattr(args, "generator_ckpt", False):
|
25 |
+
print(f"Loading pretrained generator from {args.generator_ckpt}")
|
26 |
+
state_dict = torch.load(args.generator_ckpt, map_location="cpu")[
|
27 |
+
'generator']
|
28 |
+
self.generator.load_state_dict(
|
29 |
+
state_dict, strict=True
|
30 |
+
)
|
31 |
+
|
32 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
33 |
+
|
34 |
+
if self.num_frame_per_block > 1:
|
35 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
36 |
+
|
37 |
+
self.independent_first_frame = getattr(args, "independent_first_frame", False)
|
38 |
+
if self.independent_first_frame:
|
39 |
+
self.generator.model.independent_first_frame = True
|
40 |
+
if args.gradient_checkpointing:
|
41 |
+
self.generator.enable_gradient_checkpointing()
|
42 |
+
|
43 |
+
# Step 2: Initialize all hyperparameters
|
44 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
45 |
+
|
46 |
+
def _initialize_models(self, args):
|
47 |
+
self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
|
48 |
+
self.generator.model.requires_grad_(True)
|
49 |
+
|
50 |
+
self.text_encoder = WanTextEncoder()
|
51 |
+
self.text_encoder.requires_grad_(False)
|
52 |
+
|
53 |
+
self.vae = WanVAEWrapper()
|
54 |
+
self.vae.requires_grad_(False)
|
55 |
+
|
56 |
+
@torch.no_grad()
|
57 |
+
def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
58 |
+
"""
|
59 |
+
Given a tensor containing the whole ODE sampling trajectories,
|
60 |
+
randomly choose an intermediate timestep and return the latent as well as the corresponding timestep.
|
61 |
+
Input:
|
62 |
+
- ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
|
63 |
+
Output:
|
64 |
+
- noisy_input: a tensor containing the selected latent [batch_size, num_frames, num_channels, height, width].
|
65 |
+
- timestep: a tensor containing the corresponding timestep [batch_size].
|
66 |
+
"""
|
67 |
+
batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape
|
68 |
+
|
69 |
+
# Step 1: Randomly choose a timestep for each frame
|
70 |
+
index = self._get_timestep(
|
71 |
+
0,
|
72 |
+
len(self.denoising_step_list),
|
73 |
+
batch_size,
|
74 |
+
num_frames,
|
75 |
+
self.num_frame_per_block,
|
76 |
+
uniform_timestep=False
|
77 |
+
)
|
78 |
+
if self.args.i2v:
|
79 |
+
index[:, 0] = len(self.denoising_step_list) - 1
|
80 |
+
|
81 |
+
noisy_input = torch.gather(
|
82 |
+
ode_latent, dim=1,
|
83 |
+
index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
|
84 |
+
-1, -1, -1, num_channels, height, width).to(self.device)
|
85 |
+
).squeeze(1)
|
86 |
+
|
87 |
+
timestep = self.denoising_step_list[index].to(self.device)
|
88 |
+
|
89 |
+
# if self.extra_noise_step > 0:
|
90 |
+
# random_timestep = torch.randint(0, self.extra_noise_step, [
|
91 |
+
# batch_size, num_frames], device=self.device, dtype=torch.long)
|
92 |
+
# perturbed_noisy_input = self.scheduler.add_noise(
|
93 |
+
# noisy_input.flatten(0, 1),
|
94 |
+
# torch.randn_like(noisy_input.flatten(0, 1)),
|
95 |
+
# random_timestep.flatten(0, 1)
|
96 |
+
# ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
|
97 |
+
|
98 |
+
# noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]
|
99 |
+
|
100 |
+
return noisy_input, timestep
|
101 |
+
|
102 |
+
def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
|
103 |
+
"""
|
104 |
+
Generate image/videos from noisy latents and compute the ODE regression loss.
|
105 |
+
Input:
|
106 |
+
- ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
|
107 |
+
They are ordered from most noisy to clean latents.
|
108 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
109 |
+
Output:
|
110 |
+
- loss: a scalar tensor representing the generator loss.
|
111 |
+
- log_dict: a dictionary containing additional information for loss timestep breakdown.
|
112 |
+
"""
|
113 |
+
# Step 1: Run generator on noisy latents
|
114 |
+
target_latent = ode_latent[:, -1]
|
115 |
+
|
116 |
+
noisy_input, timestep = self._prepare_generator_input(
|
117 |
+
ode_latent=ode_latent)
|
118 |
+
|
119 |
+
_, pred_image_or_video = self.generator(
|
120 |
+
noisy_image_or_video=noisy_input,
|
121 |
+
conditional_dict=conditional_dict,
|
122 |
+
timestep=timestep
|
123 |
+
)
|
124 |
+
|
125 |
+
# Step 2: Compute the regression loss
|
126 |
+
mask = timestep != 0
|
127 |
+
|
128 |
+
loss = F.mse_loss(
|
129 |
+
pred_image_or_video[mask], target_latent[mask], reduction="mean")
|
130 |
+
|
131 |
+
log_dict = {
|
132 |
+
"unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
|
133 |
+
"timestep": timestep.float().mean(dim=1).detach(),
|
134 |
+
"input": noisy_input.detach(),
|
135 |
+
"output": pred_image_or_video.detach(),
|
136 |
+
}
|
137 |
+
|
138 |
+
return loss, log_dict
|
model/sid.py
ADDED
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pipeline import SelfForcingTrainingPipeline
|
2 |
+
from typing import Optional, Tuple
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from model.base import SelfForcingModel
|
6 |
+
|
7 |
+
|
8 |
+
class SiD(SelfForcingModel):
|
9 |
+
def __init__(self, args, device):
|
10 |
+
"""
|
11 |
+
Initialize the DMD (Distribution Matching Distillation) module.
|
12 |
+
This class is self-contained and compute generator and fake score losses
|
13 |
+
in the forward pass.
|
14 |
+
"""
|
15 |
+
super().__init__(args, device)
|
16 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
17 |
+
|
18 |
+
if self.num_frame_per_block > 1:
|
19 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
20 |
+
|
21 |
+
if args.gradient_checkpointing:
|
22 |
+
self.generator.enable_gradient_checkpointing()
|
23 |
+
self.fake_score.enable_gradient_checkpointing()
|
24 |
+
self.real_score.enable_gradient_checkpointing()
|
25 |
+
|
26 |
+
# this will be init later with fsdp-wrapped modules
|
27 |
+
self.inference_pipeline: SelfForcingTrainingPipeline = None
|
28 |
+
|
29 |
+
# Step 2: Initialize all dmd hyperparameters
|
30 |
+
self.num_train_timestep = args.num_train_timestep
|
31 |
+
self.min_step = int(0.02 * self.num_train_timestep)
|
32 |
+
self.max_step = int(0.98 * self.num_train_timestep)
|
33 |
+
if hasattr(args, "real_guidance_scale"):
|
34 |
+
self.real_guidance_scale = args.real_guidance_scale
|
35 |
+
else:
|
36 |
+
self.real_guidance_scale = args.guidance_scale
|
37 |
+
self.timestep_shift = getattr(args, "timestep_shift", 1.0)
|
38 |
+
self.sid_alpha = getattr(args, "sid_alpha", 1.0)
|
39 |
+
self.ts_schedule = getattr(args, "ts_schedule", True)
|
40 |
+
self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
|
41 |
+
|
42 |
+
if getattr(self.scheduler, "alphas_cumprod", None) is not None:
|
43 |
+
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
|
44 |
+
else:
|
45 |
+
self.scheduler.alphas_cumprod = None
|
46 |
+
|
47 |
+
def compute_distribution_matching_loss(
|
48 |
+
self,
|
49 |
+
image_or_video: torch.Tensor,
|
50 |
+
conditional_dict: dict,
|
51 |
+
unconditional_dict: dict,
|
52 |
+
gradient_mask: Optional[torch.Tensor] = None,
|
53 |
+
denoised_timestep_from: int = 0,
|
54 |
+
denoised_timestep_to: int = 0
|
55 |
+
) -> Tuple[torch.Tensor, dict]:
|
56 |
+
"""
|
57 |
+
Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
|
58 |
+
Input:
|
59 |
+
- image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
|
60 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
61 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
62 |
+
- gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
|
63 |
+
Output:
|
64 |
+
- dmd_loss: a scalar tensor representing the DMD loss.
|
65 |
+
- dmd_log_dict: a dictionary containing the intermediate tensors for logging.
|
66 |
+
"""
|
67 |
+
original_latent = image_or_video
|
68 |
+
|
69 |
+
batch_size, num_frame = image_or_video.shape[:2]
|
70 |
+
|
71 |
+
# Step 1: Randomly sample timestep based on the given schedule and corresponding noise
|
72 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
73 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
74 |
+
timestep = self._get_timestep(
|
75 |
+
min_timestep,
|
76 |
+
max_timestep,
|
77 |
+
batch_size,
|
78 |
+
num_frame,
|
79 |
+
self.num_frame_per_block,
|
80 |
+
uniform_timestep=True
|
81 |
+
)
|
82 |
+
|
83 |
+
if self.timestep_shift > 1:
|
84 |
+
timestep = self.timestep_shift * \
|
85 |
+
(timestep / 1000) / \
|
86 |
+
(1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
|
87 |
+
timestep = timestep.clamp(self.min_step, self.max_step)
|
88 |
+
|
89 |
+
noise = torch.randn_like(image_or_video)
|
90 |
+
noisy_latent = self.scheduler.add_noise(
|
91 |
+
image_or_video.flatten(0, 1),
|
92 |
+
noise.flatten(0, 1),
|
93 |
+
timestep.flatten(0, 1)
|
94 |
+
).unflatten(0, (batch_size, num_frame))
|
95 |
+
|
96 |
+
# Step 2: SiD (May be wrap it?)
|
97 |
+
noisy_image_or_video = noisy_latent
|
98 |
+
# Step 2.1: Compute the fake score
|
99 |
+
_, pred_fake_image = self.fake_score(
|
100 |
+
noisy_image_or_video=noisy_image_or_video,
|
101 |
+
conditional_dict=conditional_dict,
|
102 |
+
timestep=timestep
|
103 |
+
)
|
104 |
+
# Step 2.2: Compute the real score
|
105 |
+
# We compute the conditional and unconditional prediction
|
106 |
+
# and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
|
107 |
+
# NOTE: This step may cause OOM issue, which can be addressed by the CFG-free technique
|
108 |
+
|
109 |
+
_, pred_real_image_cond = self.real_score(
|
110 |
+
noisy_image_or_video=noisy_image_or_video,
|
111 |
+
conditional_dict=conditional_dict,
|
112 |
+
timestep=timestep
|
113 |
+
)
|
114 |
+
|
115 |
+
_, pred_real_image_uncond = self.real_score(
|
116 |
+
noisy_image_or_video=noisy_image_or_video,
|
117 |
+
conditional_dict=unconditional_dict,
|
118 |
+
timestep=timestep
|
119 |
+
)
|
120 |
+
|
121 |
+
pred_real_image = pred_real_image_cond + (
|
122 |
+
pred_real_image_cond - pred_real_image_uncond
|
123 |
+
) * self.real_guidance_scale
|
124 |
+
|
125 |
+
# Step 2.3: SiD Loss
|
126 |
+
# TODO: Add alpha
|
127 |
+
# TODO: Double?
|
128 |
+
sid_loss = (pred_real_image.double() - pred_fake_image.double()) * ((pred_real_image.double() - original_latent.double()) - self.sid_alpha * (pred_real_image.double() - pred_fake_image.double()))
|
129 |
+
|
130 |
+
# Step 2.4: Loss normalizer
|
131 |
+
with torch.no_grad():
|
132 |
+
p_real = (original_latent - pred_real_image)
|
133 |
+
normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
|
134 |
+
sid_loss = sid_loss / normalizer
|
135 |
+
|
136 |
+
sid_loss = torch.nan_to_num(sid_loss)
|
137 |
+
num_frame = sid_loss.shape[1]
|
138 |
+
sid_loss = sid_loss.mean()
|
139 |
+
|
140 |
+
sid_log_dict = {
|
141 |
+
"dmdtrain_gradient_norm": torch.zeros_like(sid_loss),
|
142 |
+
"timestep": timestep.detach()
|
143 |
+
}
|
144 |
+
|
145 |
+
return sid_loss, sid_log_dict
|
146 |
+
|
147 |
+
def generator_loss(
|
148 |
+
self,
|
149 |
+
image_or_video_shape,
|
150 |
+
conditional_dict: dict,
|
151 |
+
unconditional_dict: dict,
|
152 |
+
clean_latent: torch.Tensor,
|
153 |
+
initial_latent: torch.Tensor = None
|
154 |
+
) -> Tuple[torch.Tensor, dict]:
|
155 |
+
"""
|
156 |
+
Generate image/videos from noise and compute the DMD loss.
|
157 |
+
The noisy input to the generator is backward simulated.
|
158 |
+
This removes the need of any datasets during distillation.
|
159 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
160 |
+
Input:
|
161 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
162 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
163 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
164 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
165 |
+
Output:
|
166 |
+
- loss: a scalar tensor representing the generator loss.
|
167 |
+
- generator_log_dict: a dictionary containing the intermediate tensors for logging.
|
168 |
+
"""
|
169 |
+
# Step 1: Unroll generator to obtain fake videos
|
170 |
+
pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
171 |
+
image_or_video_shape=image_or_video_shape,
|
172 |
+
conditional_dict=conditional_dict,
|
173 |
+
initial_latent=initial_latent
|
174 |
+
)
|
175 |
+
|
176 |
+
# Step 2: Compute the DMD loss
|
177 |
+
dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
|
178 |
+
image_or_video=pred_image,
|
179 |
+
conditional_dict=conditional_dict,
|
180 |
+
unconditional_dict=unconditional_dict,
|
181 |
+
gradient_mask=gradient_mask,
|
182 |
+
denoised_timestep_from=denoised_timestep_from,
|
183 |
+
denoised_timestep_to=denoised_timestep_to
|
184 |
+
)
|
185 |
+
|
186 |
+
return dmd_loss, dmd_log_dict
|
187 |
+
|
188 |
+
def critic_loss(
|
189 |
+
self,
|
190 |
+
image_or_video_shape,
|
191 |
+
conditional_dict: dict,
|
192 |
+
unconditional_dict: dict,
|
193 |
+
clean_latent: torch.Tensor,
|
194 |
+
initial_latent: torch.Tensor = None
|
195 |
+
) -> Tuple[torch.Tensor, dict]:
|
196 |
+
"""
|
197 |
+
Generate image/videos from noise and train the critic with generated samples.
|
198 |
+
The noisy input to the generator is backward simulated.
|
199 |
+
This removes the need of any datasets during distillation.
|
200 |
+
See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
|
201 |
+
Input:
|
202 |
+
- image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
|
203 |
+
- conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
|
204 |
+
- unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
|
205 |
+
- clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
|
206 |
+
Output:
|
207 |
+
- loss: a scalar tensor representing the generator loss.
|
208 |
+
- critic_log_dict: a dictionary containing the intermediate tensors for logging.
|
209 |
+
"""
|
210 |
+
|
211 |
+
# Step 1: Run generator on backward simulated noisy input
|
212 |
+
with torch.no_grad():
|
213 |
+
generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
|
214 |
+
image_or_video_shape=image_or_video_shape,
|
215 |
+
conditional_dict=conditional_dict,
|
216 |
+
initial_latent=initial_latent
|
217 |
+
)
|
218 |
+
|
219 |
+
# Step 2: Compute the fake prediction
|
220 |
+
min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
|
221 |
+
max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
|
222 |
+
critic_timestep = self._get_timestep(
|
223 |
+
min_timestep,
|
224 |
+
max_timestep,
|
225 |
+
image_or_video_shape[0],
|
226 |
+
image_or_video_shape[1],
|
227 |
+
self.num_frame_per_block,
|
228 |
+
uniform_timestep=True
|
229 |
+
)
|
230 |
+
|
231 |
+
if self.timestep_shift > 1:
|
232 |
+
critic_timestep = self.timestep_shift * \
|
233 |
+
(critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
|
234 |
+
|
235 |
+
critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
|
236 |
+
|
237 |
+
critic_noise = torch.randn_like(generated_image)
|
238 |
+
noisy_generated_image = self.scheduler.add_noise(
|
239 |
+
generated_image.flatten(0, 1),
|
240 |
+
critic_noise.flatten(0, 1),
|
241 |
+
critic_timestep.flatten(0, 1)
|
242 |
+
).unflatten(0, image_or_video_shape[:2])
|
243 |
+
|
244 |
+
_, pred_fake_image = self.fake_score(
|
245 |
+
noisy_image_or_video=noisy_generated_image,
|
246 |
+
conditional_dict=conditional_dict,
|
247 |
+
timestep=critic_timestep
|
248 |
+
)
|
249 |
+
|
250 |
+
# Step 3: Compute the denoising loss for the fake critic
|
251 |
+
if self.args.denoising_loss_type == "flow":
|
252 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
253 |
+
flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
|
254 |
+
scheduler=self.scheduler,
|
255 |
+
x0_pred=pred_fake_image.flatten(0, 1),
|
256 |
+
xt=noisy_generated_image.flatten(0, 1),
|
257 |
+
timestep=critic_timestep.flatten(0, 1)
|
258 |
+
)
|
259 |
+
pred_fake_noise = None
|
260 |
+
else:
|
261 |
+
flow_pred = None
|
262 |
+
pred_fake_noise = self.scheduler.convert_x0_to_noise(
|
263 |
+
x0=pred_fake_image.flatten(0, 1),
|
264 |
+
xt=noisy_generated_image.flatten(0, 1),
|
265 |
+
timestep=critic_timestep.flatten(0, 1)
|
266 |
+
).unflatten(0, image_or_video_shape[:2])
|
267 |
+
|
268 |
+
denoising_loss = self.denoising_loss_func(
|
269 |
+
x=generated_image.flatten(0, 1),
|
270 |
+
x_pred=pred_fake_image.flatten(0, 1),
|
271 |
+
noise=critic_noise.flatten(0, 1),
|
272 |
+
noise_pred=pred_fake_noise,
|
273 |
+
alphas_cumprod=self.scheduler.alphas_cumprod,
|
274 |
+
timestep=critic_timestep.flatten(0, 1),
|
275 |
+
flow_pred=flow_pred
|
276 |
+
)
|
277 |
+
|
278 |
+
# Step 5: Debugging Log
|
279 |
+
critic_log_dict = {
|
280 |
+
"critic_timestep": critic_timestep.detach()
|
281 |
+
}
|
282 |
+
|
283 |
+
return denoising_loss, critic_log_dict
|
pipeline/__init__.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline
|
2 |
+
from .bidirectional_inference import BidirectionalInferencePipeline
|
3 |
+
from .causal_diffusion_inference import CausalDiffusionInferencePipeline
|
4 |
+
from .causal_inference import CausalInferencePipeline
|
5 |
+
from .self_forcing_training import SelfForcingTrainingPipeline
|
6 |
+
|
7 |
+
__all__ = [
|
8 |
+
"BidirectionalDiffusionInferencePipeline",
|
9 |
+
"BidirectionalInferencePipeline",
|
10 |
+
"CausalDiffusionInferencePipeline",
|
11 |
+
"CausalInferencePipeline",
|
12 |
+
"SelfForcingTrainingPipeline"
|
13 |
+
]
|
pipeline/bidirectional_diffusion_inference.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
from typing import List
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
6 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
7 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
8 |
+
|
9 |
+
|
10 |
+
class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
args,
|
14 |
+
device,
|
15 |
+
generator=None,
|
16 |
+
text_encoder=None,
|
17 |
+
vae=None
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
# Step 1: Initialize all models
|
21 |
+
self.generator = WanDiffusionWrapper(
|
22 |
+
**getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
|
23 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
24 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
25 |
+
|
26 |
+
# Step 2: Initialize scheduler
|
27 |
+
self.num_train_timesteps = args.num_train_timestep
|
28 |
+
self.sampling_steps = 50
|
29 |
+
self.sample_solver = 'unipc'
|
30 |
+
self.shift = 8.0
|
31 |
+
|
32 |
+
self.args = args
|
33 |
+
|
34 |
+
def inference(
|
35 |
+
self,
|
36 |
+
noise: torch.Tensor,
|
37 |
+
text_prompts: List[str],
|
38 |
+
return_latents=False
|
39 |
+
) -> torch.Tensor:
|
40 |
+
"""
|
41 |
+
Perform inference on the given noise and text prompts.
|
42 |
+
Inputs:
|
43 |
+
noise (torch.Tensor): The input noise tensor of shape
|
44 |
+
(batch_size, num_frames, num_channels, height, width).
|
45 |
+
text_prompts (List[str]): The list of text prompts.
|
46 |
+
Outputs:
|
47 |
+
video (torch.Tensor): The generated video tensor of shape
|
48 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
49 |
+
"""
|
50 |
+
|
51 |
+
conditional_dict = self.text_encoder(
|
52 |
+
text_prompts=text_prompts
|
53 |
+
)
|
54 |
+
unconditional_dict = self.text_encoder(
|
55 |
+
text_prompts=[self.args.negative_prompt] * len(text_prompts)
|
56 |
+
)
|
57 |
+
|
58 |
+
latents = noise
|
59 |
+
|
60 |
+
sample_scheduler = self._initialize_sample_scheduler(noise)
|
61 |
+
for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
|
62 |
+
latent_model_input = latents
|
63 |
+
timestep = t * torch.ones([latents.shape[0], 21], device=noise.device, dtype=torch.float32)
|
64 |
+
|
65 |
+
flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
|
66 |
+
flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)
|
67 |
+
|
68 |
+
flow_pred = flow_pred_uncond + self.args.guidance_scale * (
|
69 |
+
flow_pred_cond - flow_pred_uncond)
|
70 |
+
|
71 |
+
temp_x0 = sample_scheduler.step(
|
72 |
+
flow_pred.unsqueeze(0),
|
73 |
+
t,
|
74 |
+
latents.unsqueeze(0),
|
75 |
+
return_dict=False)[0]
|
76 |
+
latents = temp_x0.squeeze(0)
|
77 |
+
|
78 |
+
x0 = latents
|
79 |
+
video = self.vae.decode_to_pixel(x0)
|
80 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
81 |
+
|
82 |
+
del sample_scheduler
|
83 |
+
|
84 |
+
if return_latents:
|
85 |
+
return video, latents
|
86 |
+
else:
|
87 |
+
return video
|
88 |
+
|
89 |
+
def _initialize_sample_scheduler(self, noise):
|
90 |
+
if self.sample_solver == 'unipc':
|
91 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
92 |
+
num_train_timesteps=self.num_train_timesteps,
|
93 |
+
shift=1,
|
94 |
+
use_dynamic_shifting=False)
|
95 |
+
sample_scheduler.set_timesteps(
|
96 |
+
self.sampling_steps, device=noise.device, shift=self.shift)
|
97 |
+
self.timesteps = sample_scheduler.timesteps
|
98 |
+
elif self.sample_solver == 'dpm++':
|
99 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
100 |
+
num_train_timesteps=self.num_train_timesteps,
|
101 |
+
shift=1,
|
102 |
+
use_dynamic_shifting=False)
|
103 |
+
sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
|
104 |
+
self.timesteps, _ = retrieve_timesteps(
|
105 |
+
sample_scheduler,
|
106 |
+
device=noise.device,
|
107 |
+
sigmas=sampling_sigmas)
|
108 |
+
else:
|
109 |
+
raise NotImplementedError("Unsupported solver.")
|
110 |
+
return sample_scheduler
|
pipeline/bidirectional_inference.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
5 |
+
|
6 |
+
|
7 |
+
class BidirectionalInferencePipeline(torch.nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
args,
|
11 |
+
device,
|
12 |
+
generator=None,
|
13 |
+
text_encoder=None,
|
14 |
+
vae=None
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
# Step 1: Initialize all models
|
18 |
+
self.generator = WanDiffusionWrapper(
|
19 |
+
**getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
|
20 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
21 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
22 |
+
|
23 |
+
# Step 2: Initialize all bidirectional wan hyperparmeters
|
24 |
+
self.scheduler = self.generator.get_scheduler()
|
25 |
+
self.denoising_step_list = torch.tensor(
|
26 |
+
args.denoising_step_list, dtype=torch.long, device=device)
|
27 |
+
if self.denoising_step_list[-1] == 0:
|
28 |
+
self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
|
29 |
+
if args.warp_denoising_step:
|
30 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
31 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
32 |
+
|
33 |
+
def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
|
34 |
+
"""
|
35 |
+
Perform inference on the given noise and text prompts.
|
36 |
+
Inputs:
|
37 |
+
noise (torch.Tensor): The input noise tensor of shape
|
38 |
+
(batch_size, num_frames, num_channels, height, width).
|
39 |
+
text_prompts (List[str]): The list of text prompts.
|
40 |
+
Outputs:
|
41 |
+
video (torch.Tensor): The generated video tensor of shape
|
42 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
43 |
+
"""
|
44 |
+
conditional_dict = self.text_encoder(
|
45 |
+
text_prompts=text_prompts
|
46 |
+
)
|
47 |
+
|
48 |
+
# initial point
|
49 |
+
noisy_image_or_video = noise
|
50 |
+
|
51 |
+
# use the last n-1 timesteps to simulate the generator's input
|
52 |
+
for index, current_timestep in enumerate(self.denoising_step_list[:-1]):
|
53 |
+
_, pred_image_or_video = self.generator(
|
54 |
+
noisy_image_or_video=noisy_image_or_video,
|
55 |
+
conditional_dict=conditional_dict,
|
56 |
+
timestep=torch.ones(
|
57 |
+
noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
|
58 |
+
) # [B, F, C, H, W]
|
59 |
+
|
60 |
+
next_timestep = self.denoising_step_list[index + 1] * torch.ones(
|
61 |
+
noise.shape[:2], dtype=torch.long, device=noise.device)
|
62 |
+
|
63 |
+
noisy_image_or_video = self.scheduler.add_noise(
|
64 |
+
pred_image_or_video.flatten(0, 1),
|
65 |
+
torch.randn_like(pred_image_or_video.flatten(0, 1)),
|
66 |
+
next_timestep.flatten(0, 1)
|
67 |
+
).unflatten(0, noise.shape[:2])
|
68 |
+
|
69 |
+
video = self.vae.decode_to_pixel(pred_image_or_video)
|
70 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
71 |
+
return video
|
pipeline/causal_diffusion_inference.py
ADDED
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
from typing import List, Optional
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
|
6 |
+
from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
|
7 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
8 |
+
|
9 |
+
|
10 |
+
class CausalDiffusionInferencePipeline(torch.nn.Module):
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
args,
|
14 |
+
device,
|
15 |
+
generator=None,
|
16 |
+
text_encoder=None,
|
17 |
+
vae=None
|
18 |
+
):
|
19 |
+
super().__init__()
|
20 |
+
# Step 1: Initialize all models
|
21 |
+
self.generator = WanDiffusionWrapper(
|
22 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
23 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
24 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
25 |
+
|
26 |
+
# Step 2: Initialize scheduler
|
27 |
+
self.num_train_timesteps = args.num_train_timestep
|
28 |
+
self.sampling_steps = 50
|
29 |
+
self.sample_solver = 'unipc'
|
30 |
+
self.shift = args.timestep_shift
|
31 |
+
|
32 |
+
self.num_transformer_blocks = 30
|
33 |
+
self.frame_seq_length = 1560
|
34 |
+
|
35 |
+
self.kv_cache_pos = None
|
36 |
+
self.kv_cache_neg = None
|
37 |
+
self.crossattn_cache_pos = None
|
38 |
+
self.crossattn_cache_neg = None
|
39 |
+
self.args = args
|
40 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
41 |
+
self.independent_first_frame = args.independent_first_frame
|
42 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
43 |
+
|
44 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
45 |
+
|
46 |
+
if self.num_frame_per_block > 1:
|
47 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
48 |
+
|
49 |
+
def inference(
|
50 |
+
self,
|
51 |
+
noise: torch.Tensor,
|
52 |
+
text_prompts: List[str],
|
53 |
+
initial_latent: Optional[torch.Tensor] = None,
|
54 |
+
return_latents: bool = False,
|
55 |
+
start_frame_index: Optional[int] = 0
|
56 |
+
) -> torch.Tensor:
|
57 |
+
"""
|
58 |
+
Perform inference on the given noise and text prompts.
|
59 |
+
Inputs:
|
60 |
+
noise (torch.Tensor): The input noise tensor of shape
|
61 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
62 |
+
text_prompts (List[str]): The list of text prompts.
|
63 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
64 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
65 |
+
If num_input_frames is 1, perform image to video.
|
66 |
+
If num_input_frames is greater than 1, perform video extension.
|
67 |
+
return_latents (bool): Whether to return the latents.
|
68 |
+
start_frame_index (int): In long video generation, where does the current window start?
|
69 |
+
Outputs:
|
70 |
+
video (torch.Tensor): The generated video tensor of shape
|
71 |
+
(batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
|
72 |
+
"""
|
73 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
74 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
75 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
76 |
+
# noise should still be a multiple of num_frame_per_block
|
77 |
+
assert num_frames % self.num_frame_per_block == 0
|
78 |
+
num_blocks = num_frames // self.num_frame_per_block
|
79 |
+
elif self.independent_first_frame and initial_latent is None:
|
80 |
+
# Using a [1, 4, 4, 4, 4, 4] model to generate a video without image conditioning
|
81 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
82 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
83 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
84 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
85 |
+
conditional_dict = self.text_encoder(
|
86 |
+
text_prompts=text_prompts
|
87 |
+
)
|
88 |
+
unconditional_dict = self.text_encoder(
|
89 |
+
text_prompts=[self.args.negative_prompt] * len(text_prompts)
|
90 |
+
)
|
91 |
+
|
92 |
+
output = torch.zeros(
|
93 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
94 |
+
device=noise.device,
|
95 |
+
dtype=noise.dtype
|
96 |
+
)
|
97 |
+
|
98 |
+
# Step 1: Initialize KV cache to all zeros
|
99 |
+
if self.kv_cache_pos is None:
|
100 |
+
self._initialize_kv_cache(
|
101 |
+
batch_size=batch_size,
|
102 |
+
dtype=noise.dtype,
|
103 |
+
device=noise.device
|
104 |
+
)
|
105 |
+
self._initialize_crossattn_cache(
|
106 |
+
batch_size=batch_size,
|
107 |
+
dtype=noise.dtype,
|
108 |
+
device=noise.device
|
109 |
+
)
|
110 |
+
else:
|
111 |
+
# reset cross attn cache
|
112 |
+
for block_index in range(self.num_transformer_blocks):
|
113 |
+
self.crossattn_cache_pos[block_index]["is_init"] = False
|
114 |
+
self.crossattn_cache_neg[block_index]["is_init"] = False
|
115 |
+
# reset kv cache
|
116 |
+
for block_index in range(len(self.kv_cache_pos)):
|
117 |
+
self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor(
|
118 |
+
[0], dtype=torch.long, device=noise.device)
|
119 |
+
self.kv_cache_pos[block_index]["local_end_index"] = torch.tensor(
|
120 |
+
[0], dtype=torch.long, device=noise.device)
|
121 |
+
self.kv_cache_neg[block_index]["global_end_index"] = torch.tensor(
|
122 |
+
[0], dtype=torch.long, device=noise.device)
|
123 |
+
self.kv_cache_neg[block_index]["local_end_index"] = torch.tensor(
|
124 |
+
[0], dtype=torch.long, device=noise.device)
|
125 |
+
|
126 |
+
# Step 2: Cache context feature
|
127 |
+
current_start_frame = start_frame_index
|
128 |
+
cache_start_frame = 0
|
129 |
+
if initial_latent is not None:
|
130 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
131 |
+
if self.independent_first_frame:
|
132 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
133 |
+
assert (num_input_frames - 1) % self.num_frame_per_block == 0
|
134 |
+
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
|
135 |
+
output[:, :1] = initial_latent[:, :1]
|
136 |
+
self.generator(
|
137 |
+
noisy_image_or_video=initial_latent[:, :1],
|
138 |
+
conditional_dict=conditional_dict,
|
139 |
+
timestep=timestep * 0,
|
140 |
+
kv_cache=self.kv_cache_pos,
|
141 |
+
crossattn_cache=self.crossattn_cache_pos,
|
142 |
+
current_start=current_start_frame * self.frame_seq_length,
|
143 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
144 |
+
)
|
145 |
+
self.generator(
|
146 |
+
noisy_image_or_video=initial_latent[:, :1],
|
147 |
+
conditional_dict=unconditional_dict,
|
148 |
+
timestep=timestep * 0,
|
149 |
+
kv_cache=self.kv_cache_neg,
|
150 |
+
crossattn_cache=self.crossattn_cache_neg,
|
151 |
+
current_start=current_start_frame * self.frame_seq_length,
|
152 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
153 |
+
)
|
154 |
+
current_start_frame += 1
|
155 |
+
cache_start_frame += 1
|
156 |
+
else:
|
157 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
158 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
159 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
160 |
+
|
161 |
+
for block_index in range(num_input_blocks):
|
162 |
+
current_ref_latents = \
|
163 |
+
initial_latent[:, cache_start_frame:cache_start_frame + self.num_frame_per_block]
|
164 |
+
output[:, cache_start_frame:cache_start_frame + self.num_frame_per_block] = current_ref_latents
|
165 |
+
self.generator(
|
166 |
+
noisy_image_or_video=current_ref_latents,
|
167 |
+
conditional_dict=conditional_dict,
|
168 |
+
timestep=timestep * 0,
|
169 |
+
kv_cache=self.kv_cache_pos,
|
170 |
+
crossattn_cache=self.crossattn_cache_pos,
|
171 |
+
current_start=current_start_frame * self.frame_seq_length,
|
172 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
173 |
+
)
|
174 |
+
self.generator(
|
175 |
+
noisy_image_or_video=current_ref_latents,
|
176 |
+
conditional_dict=unconditional_dict,
|
177 |
+
timestep=timestep * 0,
|
178 |
+
kv_cache=self.kv_cache_neg,
|
179 |
+
crossattn_cache=self.crossattn_cache_neg,
|
180 |
+
current_start=current_start_frame * self.frame_seq_length,
|
181 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
182 |
+
)
|
183 |
+
current_start_frame += self.num_frame_per_block
|
184 |
+
cache_start_frame += self.num_frame_per_block
|
185 |
+
|
186 |
+
# Step 3: Temporal denoising loop
|
187 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
188 |
+
if self.independent_first_frame and initial_latent is None:
|
189 |
+
all_num_frames = [1] + all_num_frames
|
190 |
+
for current_num_frames in all_num_frames:
|
191 |
+
noisy_input = noise[
|
192 |
+
:, cache_start_frame - num_input_frames:cache_start_frame + current_num_frames - num_input_frames]
|
193 |
+
latents = noisy_input
|
194 |
+
|
195 |
+
# Step 3.1: Spatial denoising loop
|
196 |
+
sample_scheduler = self._initialize_sample_scheduler(noise)
|
197 |
+
for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
|
198 |
+
latent_model_input = latents
|
199 |
+
timestep = t * torch.ones(
|
200 |
+
[batch_size, current_num_frames], device=noise.device, dtype=torch.float32
|
201 |
+
)
|
202 |
+
|
203 |
+
flow_pred_cond, _ = self.generator(
|
204 |
+
noisy_image_or_video=latent_model_input,
|
205 |
+
conditional_dict=conditional_dict,
|
206 |
+
timestep=timestep,
|
207 |
+
kv_cache=self.kv_cache_pos,
|
208 |
+
crossattn_cache=self.crossattn_cache_pos,
|
209 |
+
current_start=current_start_frame * self.frame_seq_length,
|
210 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
211 |
+
)
|
212 |
+
flow_pred_uncond, _ = self.generator(
|
213 |
+
noisy_image_or_video=latent_model_input,
|
214 |
+
conditional_dict=unconditional_dict,
|
215 |
+
timestep=timestep,
|
216 |
+
kv_cache=self.kv_cache_neg,
|
217 |
+
crossattn_cache=self.crossattn_cache_neg,
|
218 |
+
current_start=current_start_frame * self.frame_seq_length,
|
219 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
220 |
+
)
|
221 |
+
|
222 |
+
flow_pred = flow_pred_uncond + self.args.guidance_scale * (
|
223 |
+
flow_pred_cond - flow_pred_uncond)
|
224 |
+
|
225 |
+
temp_x0 = sample_scheduler.step(
|
226 |
+
flow_pred,
|
227 |
+
t,
|
228 |
+
latents,
|
229 |
+
return_dict=False)[0]
|
230 |
+
latents = temp_x0
|
231 |
+
print(f"kv_cache['local_end_index']: {self.kv_cache_pos[0]['local_end_index']}")
|
232 |
+
print(f"kv_cache['global_end_index']: {self.kv_cache_pos[0]['global_end_index']}")
|
233 |
+
|
234 |
+
# Step 3.2: record the model's output
|
235 |
+
output[:, cache_start_frame:cache_start_frame + current_num_frames] = latents
|
236 |
+
|
237 |
+
# Step 3.3: rerun with timestep zero to update KV cache using clean context
|
238 |
+
self.generator(
|
239 |
+
noisy_image_or_video=latents,
|
240 |
+
conditional_dict=conditional_dict,
|
241 |
+
timestep=timestep * 0,
|
242 |
+
kv_cache=self.kv_cache_pos,
|
243 |
+
crossattn_cache=self.crossattn_cache_pos,
|
244 |
+
current_start=current_start_frame * self.frame_seq_length,
|
245 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
246 |
+
)
|
247 |
+
self.generator(
|
248 |
+
noisy_image_or_video=latents,
|
249 |
+
conditional_dict=unconditional_dict,
|
250 |
+
timestep=timestep * 0,
|
251 |
+
kv_cache=self.kv_cache_neg,
|
252 |
+
crossattn_cache=self.crossattn_cache_neg,
|
253 |
+
current_start=current_start_frame * self.frame_seq_length,
|
254 |
+
cache_start=cache_start_frame * self.frame_seq_length
|
255 |
+
)
|
256 |
+
|
257 |
+
# Step 3.4: update the start and end frame indices
|
258 |
+
current_start_frame += current_num_frames
|
259 |
+
cache_start_frame += current_num_frames
|
260 |
+
|
261 |
+
# Step 4: Decode the output
|
262 |
+
video = self.vae.decode_to_pixel(output)
|
263 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
264 |
+
|
265 |
+
if return_latents:
|
266 |
+
return video, output
|
267 |
+
else:
|
268 |
+
return video
|
269 |
+
|
270 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
271 |
+
"""
|
272 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
273 |
+
"""
|
274 |
+
kv_cache_pos = []
|
275 |
+
kv_cache_neg = []
|
276 |
+
if self.local_attn_size != -1:
|
277 |
+
# Use the local attention size to compute the KV cache size
|
278 |
+
kv_cache_size = self.local_attn_size * self.frame_seq_length
|
279 |
+
else:
|
280 |
+
# Use the default KV cache size
|
281 |
+
kv_cache_size = 32760
|
282 |
+
|
283 |
+
for _ in range(self.num_transformer_blocks):
|
284 |
+
kv_cache_pos.append({
|
285 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
286 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
287 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
288 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
289 |
+
})
|
290 |
+
kv_cache_neg.append({
|
291 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
292 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
293 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
294 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
295 |
+
})
|
296 |
+
|
297 |
+
self.kv_cache_pos = kv_cache_pos # always store the clean cache
|
298 |
+
self.kv_cache_neg = kv_cache_neg # always store the clean cache
|
299 |
+
|
300 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
301 |
+
"""
|
302 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
303 |
+
"""
|
304 |
+
crossattn_cache_pos = []
|
305 |
+
crossattn_cache_neg = []
|
306 |
+
for _ in range(self.num_transformer_blocks):
|
307 |
+
crossattn_cache_pos.append({
|
308 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
309 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
310 |
+
"is_init": False
|
311 |
+
})
|
312 |
+
crossattn_cache_neg.append({
|
313 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
314 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
315 |
+
"is_init": False
|
316 |
+
})
|
317 |
+
|
318 |
+
self.crossattn_cache_pos = crossattn_cache_pos # always store the clean cache
|
319 |
+
self.crossattn_cache_neg = crossattn_cache_neg # always store the clean cache
|
320 |
+
|
321 |
+
def _initialize_sample_scheduler(self, noise):
|
322 |
+
if self.sample_solver == 'unipc':
|
323 |
+
sample_scheduler = FlowUniPCMultistepScheduler(
|
324 |
+
num_train_timesteps=self.num_train_timesteps,
|
325 |
+
shift=1,
|
326 |
+
use_dynamic_shifting=False)
|
327 |
+
sample_scheduler.set_timesteps(
|
328 |
+
self.sampling_steps, device=noise.device, shift=self.shift)
|
329 |
+
self.timesteps = sample_scheduler.timesteps
|
330 |
+
elif self.sample_solver == 'dpm++':
|
331 |
+
sample_scheduler = FlowDPMSolverMultistepScheduler(
|
332 |
+
num_train_timesteps=self.num_train_timesteps,
|
333 |
+
shift=1,
|
334 |
+
use_dynamic_shifting=False)
|
335 |
+
sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
|
336 |
+
self.timesteps, _ = retrieve_timesteps(
|
337 |
+
sample_scheduler,
|
338 |
+
device=noise.device,
|
339 |
+
sigmas=sampling_sigmas)
|
340 |
+
else:
|
341 |
+
raise NotImplementedError("Unsupported solver.")
|
342 |
+
return sample_scheduler
|
pipeline/causal_inference.py
ADDED
@@ -0,0 +1,305 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
|
5 |
+
|
6 |
+
|
7 |
+
class CausalInferencePipeline(torch.nn.Module):
|
8 |
+
def __init__(
|
9 |
+
self,
|
10 |
+
args,
|
11 |
+
device,
|
12 |
+
generator=None,
|
13 |
+
text_encoder=None,
|
14 |
+
vae=None
|
15 |
+
):
|
16 |
+
super().__init__()
|
17 |
+
# Step 1: Initialize all models
|
18 |
+
self.generator = WanDiffusionWrapper(
|
19 |
+
**getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
|
20 |
+
self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
|
21 |
+
self.vae = WanVAEWrapper() if vae is None else vae
|
22 |
+
|
23 |
+
# Step 2: Initialize all causal hyperparmeters
|
24 |
+
self.scheduler = self.generator.get_scheduler()
|
25 |
+
self.denoising_step_list = torch.tensor(
|
26 |
+
args.denoising_step_list, dtype=torch.long)
|
27 |
+
if args.warp_denoising_step:
|
28 |
+
timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
|
29 |
+
self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
|
30 |
+
|
31 |
+
self.num_transformer_blocks = 30
|
32 |
+
self.frame_seq_length = 1560
|
33 |
+
|
34 |
+
self.kv_cache1 = None
|
35 |
+
self.args = args
|
36 |
+
self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
|
37 |
+
self.independent_first_frame = args.independent_first_frame
|
38 |
+
self.local_attn_size = self.generator.model.local_attn_size
|
39 |
+
|
40 |
+
print(f"KV inference with {self.num_frame_per_block} frames per block")
|
41 |
+
|
42 |
+
if self.num_frame_per_block > 1:
|
43 |
+
self.generator.model.num_frame_per_block = self.num_frame_per_block
|
44 |
+
|
45 |
+
def inference(
|
46 |
+
self,
|
47 |
+
noise: torch.Tensor,
|
48 |
+
text_prompts: List[str],
|
49 |
+
initial_latent: Optional[torch.Tensor] = None,
|
50 |
+
return_latents: bool = False,
|
51 |
+
profile: bool = False
|
52 |
+
) -> torch.Tensor:
|
53 |
+
"""
|
54 |
+
Perform inference on the given noise and text prompts.
|
55 |
+
Inputs:
|
56 |
+
noise (torch.Tensor): The input noise tensor of shape
|
57 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
58 |
+
text_prompts (List[str]): The list of text prompts.
|
59 |
+
initial_latent (torch.Tensor): The initial latent tensor of shape
|
60 |
+
(batch_size, num_input_frames, num_channels, height, width).
|
61 |
+
If num_input_frames is 1, perform image to video.
|
62 |
+
If num_input_frames is greater than 1, perform video extension.
|
63 |
+
return_latents (bool): Whether to return the latents.
|
64 |
+
Outputs:
|
65 |
+
video (torch.Tensor): The generated video tensor of shape
|
66 |
+
(batch_size, num_output_frames, num_channels, height, width).
|
67 |
+
It is normalized to be in the range [0, 1].
|
68 |
+
"""
|
69 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
70 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
71 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
72 |
+
# noise should still be a multiple of num_frame_per_block
|
73 |
+
assert num_frames % self.num_frame_per_block == 0
|
74 |
+
num_blocks = num_frames // self.num_frame_per_block
|
75 |
+
else:
|
76 |
+
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
|
77 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
78 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
79 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
80 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
81 |
+
conditional_dict = self.text_encoder(
|
82 |
+
text_prompts=text_prompts
|
83 |
+
)
|
84 |
+
|
85 |
+
output = torch.zeros(
|
86 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
87 |
+
device=noise.device,
|
88 |
+
dtype=noise.dtype
|
89 |
+
)
|
90 |
+
|
91 |
+
# Set up profiling if requested
|
92 |
+
if profile:
|
93 |
+
init_start = torch.cuda.Event(enable_timing=True)
|
94 |
+
init_end = torch.cuda.Event(enable_timing=True)
|
95 |
+
diffusion_start = torch.cuda.Event(enable_timing=True)
|
96 |
+
diffusion_end = torch.cuda.Event(enable_timing=True)
|
97 |
+
vae_start = torch.cuda.Event(enable_timing=True)
|
98 |
+
vae_end = torch.cuda.Event(enable_timing=True)
|
99 |
+
block_times = []
|
100 |
+
block_start = torch.cuda.Event(enable_timing=True)
|
101 |
+
block_end = torch.cuda.Event(enable_timing=True)
|
102 |
+
init_start.record()
|
103 |
+
|
104 |
+
# Step 1: Initialize KV cache to all zeros
|
105 |
+
if self.kv_cache1 is None:
|
106 |
+
self._initialize_kv_cache(
|
107 |
+
batch_size=batch_size,
|
108 |
+
dtype=noise.dtype,
|
109 |
+
device=noise.device
|
110 |
+
)
|
111 |
+
self._initialize_crossattn_cache(
|
112 |
+
batch_size=batch_size,
|
113 |
+
dtype=noise.dtype,
|
114 |
+
device=noise.device
|
115 |
+
)
|
116 |
+
else:
|
117 |
+
# reset cross attn cache
|
118 |
+
for block_index in range(self.num_transformer_blocks):
|
119 |
+
self.crossattn_cache[block_index]["is_init"] = False
|
120 |
+
# reset kv cache
|
121 |
+
for block_index in range(len(self.kv_cache1)):
|
122 |
+
self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
|
123 |
+
[0], dtype=torch.long, device=noise.device)
|
124 |
+
self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
|
125 |
+
[0], dtype=torch.long, device=noise.device)
|
126 |
+
|
127 |
+
# Step 2: Cache context feature
|
128 |
+
current_start_frame = 0
|
129 |
+
if initial_latent is not None:
|
130 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
131 |
+
if self.independent_first_frame:
|
132 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
133 |
+
assert (num_input_frames - 1) % self.num_frame_per_block == 0
|
134 |
+
num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
|
135 |
+
output[:, :1] = initial_latent[:, :1]
|
136 |
+
self.generator(
|
137 |
+
noisy_image_or_video=initial_latent[:, :1],
|
138 |
+
conditional_dict=conditional_dict,
|
139 |
+
timestep=timestep * 0,
|
140 |
+
kv_cache=self.kv_cache1,
|
141 |
+
crossattn_cache=self.crossattn_cache,
|
142 |
+
current_start=current_start_frame * self.frame_seq_length,
|
143 |
+
)
|
144 |
+
current_start_frame += 1
|
145 |
+
else:
|
146 |
+
# Assume num_input_frames is self.num_frame_per_block * num_input_blocks
|
147 |
+
assert num_input_frames % self.num_frame_per_block == 0
|
148 |
+
num_input_blocks = num_input_frames // self.num_frame_per_block
|
149 |
+
|
150 |
+
for _ in range(num_input_blocks):
|
151 |
+
current_ref_latents = \
|
152 |
+
initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
|
153 |
+
output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
|
154 |
+
self.generator(
|
155 |
+
noisy_image_or_video=current_ref_latents,
|
156 |
+
conditional_dict=conditional_dict,
|
157 |
+
timestep=timestep * 0,
|
158 |
+
kv_cache=self.kv_cache1,
|
159 |
+
crossattn_cache=self.crossattn_cache,
|
160 |
+
current_start=current_start_frame * self.frame_seq_length,
|
161 |
+
)
|
162 |
+
current_start_frame += self.num_frame_per_block
|
163 |
+
|
164 |
+
if profile:
|
165 |
+
init_end.record()
|
166 |
+
torch.cuda.synchronize()
|
167 |
+
diffusion_start.record()
|
168 |
+
|
169 |
+
# Step 3: Temporal denoising loop
|
170 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
171 |
+
if self.independent_first_frame and initial_latent is None:
|
172 |
+
all_num_frames = [1] + all_num_frames
|
173 |
+
for current_num_frames in all_num_frames:
|
174 |
+
if profile:
|
175 |
+
block_start.record()
|
176 |
+
|
177 |
+
noisy_input = noise[
|
178 |
+
:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
179 |
+
|
180 |
+
# Step 3.1: Spatial denoising loop
|
181 |
+
for index, current_timestep in enumerate(self.denoising_step_list):
|
182 |
+
print(f"current_timestep: {current_timestep}")
|
183 |
+
# set current timestep
|
184 |
+
timestep = torch.ones(
|
185 |
+
[batch_size, current_num_frames],
|
186 |
+
device=noise.device,
|
187 |
+
dtype=torch.int64) * current_timestep
|
188 |
+
|
189 |
+
if index < len(self.denoising_step_list) - 1:
|
190 |
+
_, denoised_pred = self.generator(
|
191 |
+
noisy_image_or_video=noisy_input,
|
192 |
+
conditional_dict=conditional_dict,
|
193 |
+
timestep=timestep,
|
194 |
+
kv_cache=self.kv_cache1,
|
195 |
+
crossattn_cache=self.crossattn_cache,
|
196 |
+
current_start=current_start_frame * self.frame_seq_length
|
197 |
+
)
|
198 |
+
next_timestep = self.denoising_step_list[index + 1]
|
199 |
+
noisy_input = self.scheduler.add_noise(
|
200 |
+
denoised_pred.flatten(0, 1),
|
201 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
202 |
+
next_timestep * torch.ones(
|
203 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
204 |
+
).unflatten(0, denoised_pred.shape[:2])
|
205 |
+
else:
|
206 |
+
# for getting real output
|
207 |
+
_, denoised_pred = self.generator(
|
208 |
+
noisy_image_or_video=noisy_input,
|
209 |
+
conditional_dict=conditional_dict,
|
210 |
+
timestep=timestep,
|
211 |
+
kv_cache=self.kv_cache1,
|
212 |
+
crossattn_cache=self.crossattn_cache,
|
213 |
+
current_start=current_start_frame * self.frame_seq_length
|
214 |
+
)
|
215 |
+
|
216 |
+
# Step 3.2: record the model's output
|
217 |
+
output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
218 |
+
|
219 |
+
# Step 3.3: rerun with timestep zero to update KV cache using clean context
|
220 |
+
context_timestep = torch.ones_like(timestep) * self.args.context_noise
|
221 |
+
self.generator(
|
222 |
+
noisy_image_or_video=denoised_pred,
|
223 |
+
conditional_dict=conditional_dict,
|
224 |
+
timestep=context_timestep,
|
225 |
+
kv_cache=self.kv_cache1,
|
226 |
+
crossattn_cache=self.crossattn_cache,
|
227 |
+
current_start=current_start_frame * self.frame_seq_length,
|
228 |
+
)
|
229 |
+
|
230 |
+
if profile:
|
231 |
+
block_end.record()
|
232 |
+
torch.cuda.synchronize()
|
233 |
+
block_time = block_start.elapsed_time(block_end)
|
234 |
+
block_times.append(block_time)
|
235 |
+
|
236 |
+
# Step 3.4: update the start and end frame indices
|
237 |
+
current_start_frame += current_num_frames
|
238 |
+
|
239 |
+
if profile:
|
240 |
+
# End diffusion timing and synchronize CUDA
|
241 |
+
diffusion_end.record()
|
242 |
+
torch.cuda.synchronize()
|
243 |
+
diffusion_time = diffusion_start.elapsed_time(diffusion_end)
|
244 |
+
init_time = init_start.elapsed_time(init_end)
|
245 |
+
vae_start.record()
|
246 |
+
|
247 |
+
# Step 4: Decode the output
|
248 |
+
video = self.vae.decode_to_pixel(output, use_cache=False)
|
249 |
+
video = (video * 0.5 + 0.5).clamp(0, 1)
|
250 |
+
|
251 |
+
if profile:
|
252 |
+
# End VAE timing and synchronize CUDA
|
253 |
+
vae_end.record()
|
254 |
+
torch.cuda.synchronize()
|
255 |
+
vae_time = vae_start.elapsed_time(vae_end)
|
256 |
+
total_time = init_time + diffusion_time + vae_time
|
257 |
+
|
258 |
+
print("Profiling results:")
|
259 |
+
print(f" - Initialization/caching time: {init_time:.2f} ms ({100 * init_time / total_time:.2f}%)")
|
260 |
+
print(f" - Diffusion generation time: {diffusion_time:.2f} ms ({100 * diffusion_time / total_time:.2f}%)")
|
261 |
+
for i, block_time in enumerate(block_times):
|
262 |
+
print(f" - Block {i} generation time: {block_time:.2f} ms ({100 * block_time / diffusion_time:.2f}% of diffusion)")
|
263 |
+
print(f" - VAE decoding time: {vae_time:.2f} ms ({100 * vae_time / total_time:.2f}%)")
|
264 |
+
print(f" - Total time: {total_time:.2f} ms")
|
265 |
+
|
266 |
+
if return_latents:
|
267 |
+
return video, output
|
268 |
+
else:
|
269 |
+
return video
|
270 |
+
|
271 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
272 |
+
"""
|
273 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
274 |
+
"""
|
275 |
+
kv_cache1 = []
|
276 |
+
if self.local_attn_size != -1:
|
277 |
+
# Use the local attention size to compute the KV cache size
|
278 |
+
kv_cache_size = self.local_attn_size * self.frame_seq_length
|
279 |
+
else:
|
280 |
+
# Use the default KV cache size
|
281 |
+
kv_cache_size = 32760
|
282 |
+
|
283 |
+
for _ in range(self.num_transformer_blocks):
|
284 |
+
kv_cache1.append({
|
285 |
+
"k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
286 |
+
"v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
|
287 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
288 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
289 |
+
})
|
290 |
+
|
291 |
+
self.kv_cache1 = kv_cache1 # always store the clean cache
|
292 |
+
|
293 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
294 |
+
"""
|
295 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
296 |
+
"""
|
297 |
+
crossattn_cache = []
|
298 |
+
|
299 |
+
for _ in range(self.num_transformer_blocks):
|
300 |
+
crossattn_cache.append({
|
301 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
302 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
303 |
+
"is_init": False
|
304 |
+
})
|
305 |
+
self.crossattn_cache = crossattn_cache
|
pipeline/self_forcing_training.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.wan_wrapper import WanDiffusionWrapper
|
2 |
+
from utils.scheduler import SchedulerInterface
|
3 |
+
from typing import List, Optional
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
|
7 |
+
|
8 |
+
class SelfForcingTrainingPipeline:
|
9 |
+
def __init__(self,
|
10 |
+
denoising_step_list: List[int],
|
11 |
+
scheduler: SchedulerInterface,
|
12 |
+
generator: WanDiffusionWrapper,
|
13 |
+
num_frame_per_block=3,
|
14 |
+
independent_first_frame: bool = False,
|
15 |
+
same_step_across_blocks: bool = False,
|
16 |
+
last_step_only: bool = False,
|
17 |
+
num_max_frames: int = 21,
|
18 |
+
context_noise: int = 0,
|
19 |
+
**kwargs):
|
20 |
+
super().__init__()
|
21 |
+
self.scheduler = scheduler
|
22 |
+
self.generator = generator
|
23 |
+
self.denoising_step_list = denoising_step_list
|
24 |
+
if self.denoising_step_list[-1] == 0:
|
25 |
+
self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
|
26 |
+
|
27 |
+
# Wan specific hyperparameters
|
28 |
+
self.num_transformer_blocks = 30
|
29 |
+
self.frame_seq_length = 1560
|
30 |
+
self.num_frame_per_block = num_frame_per_block
|
31 |
+
self.context_noise = context_noise
|
32 |
+
self.i2v = False
|
33 |
+
|
34 |
+
self.kv_cache1 = None
|
35 |
+
self.kv_cache2 = None
|
36 |
+
self.independent_first_frame = independent_first_frame
|
37 |
+
self.same_step_across_blocks = same_step_across_blocks
|
38 |
+
self.last_step_only = last_step_only
|
39 |
+
self.kv_cache_size = num_max_frames * self.frame_seq_length
|
40 |
+
|
41 |
+
def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
|
42 |
+
rank = dist.get_rank() if dist.is_initialized() else 0
|
43 |
+
|
44 |
+
if rank == 0:
|
45 |
+
# Generate random indices
|
46 |
+
indices = torch.randint(
|
47 |
+
low=0,
|
48 |
+
high=num_denoising_steps,
|
49 |
+
size=(num_blocks,),
|
50 |
+
device=device
|
51 |
+
)
|
52 |
+
if self.last_step_only:
|
53 |
+
indices = torch.ones_like(indices) * (num_denoising_steps - 1)
|
54 |
+
else:
|
55 |
+
indices = torch.empty(num_blocks, dtype=torch.long, device=device)
|
56 |
+
|
57 |
+
dist.broadcast(indices, src=0) # Broadcast the random indices to all ranks
|
58 |
+
return indices.tolist()
|
59 |
+
|
60 |
+
def inference_with_trajectory(
|
61 |
+
self,
|
62 |
+
noise: torch.Tensor,
|
63 |
+
initial_latent: Optional[torch.Tensor] = None,
|
64 |
+
return_sim_step: bool = False,
|
65 |
+
**conditional_dict
|
66 |
+
) -> torch.Tensor:
|
67 |
+
batch_size, num_frames, num_channels, height, width = noise.shape
|
68 |
+
if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
|
69 |
+
# If the first frame is independent and the first frame is provided, then the number of frames in the
|
70 |
+
# noise should still be a multiple of num_frame_per_block
|
71 |
+
assert num_frames % self.num_frame_per_block == 0
|
72 |
+
num_blocks = num_frames // self.num_frame_per_block
|
73 |
+
else:
|
74 |
+
# Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
|
75 |
+
assert (num_frames - 1) % self.num_frame_per_block == 0
|
76 |
+
num_blocks = (num_frames - 1) // self.num_frame_per_block
|
77 |
+
num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
|
78 |
+
num_output_frames = num_frames + num_input_frames # add the initial latent frames
|
79 |
+
output = torch.zeros(
|
80 |
+
[batch_size, num_output_frames, num_channels, height, width],
|
81 |
+
device=noise.device,
|
82 |
+
dtype=noise.dtype
|
83 |
+
)
|
84 |
+
|
85 |
+
# Step 1: Initialize KV cache to all zeros
|
86 |
+
self._initialize_kv_cache(
|
87 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
88 |
+
)
|
89 |
+
self._initialize_crossattn_cache(
|
90 |
+
batch_size=batch_size, dtype=noise.dtype, device=noise.device
|
91 |
+
)
|
92 |
+
# if self.kv_cache1 is None:
|
93 |
+
# self._initialize_kv_cache(
|
94 |
+
# batch_size=batch_size,
|
95 |
+
# dtype=noise.dtype,
|
96 |
+
# device=noise.device,
|
97 |
+
# )
|
98 |
+
# self._initialize_crossattn_cache(
|
99 |
+
# batch_size=batch_size,
|
100 |
+
# dtype=noise.dtype,
|
101 |
+
# device=noise.device
|
102 |
+
# )
|
103 |
+
# else:
|
104 |
+
# # reset cross attn cache
|
105 |
+
# for block_index in range(self.num_transformer_blocks):
|
106 |
+
# self.crossattn_cache[block_index]["is_init"] = False
|
107 |
+
# # reset kv cache
|
108 |
+
# for block_index in range(len(self.kv_cache1)):
|
109 |
+
# self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
|
110 |
+
# [0], dtype=torch.long, device=noise.device)
|
111 |
+
# self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
|
112 |
+
# [0], dtype=torch.long, device=noise.device)
|
113 |
+
|
114 |
+
# Step 2: Cache context feature
|
115 |
+
current_start_frame = 0
|
116 |
+
if initial_latent is not None:
|
117 |
+
timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
|
118 |
+
# Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
|
119 |
+
output[:, :1] = initial_latent
|
120 |
+
with torch.no_grad():
|
121 |
+
self.generator(
|
122 |
+
noisy_image_or_video=initial_latent,
|
123 |
+
conditional_dict=conditional_dict,
|
124 |
+
timestep=timestep * 0,
|
125 |
+
kv_cache=self.kv_cache1,
|
126 |
+
crossattn_cache=self.crossattn_cache,
|
127 |
+
current_start=current_start_frame * self.frame_seq_length
|
128 |
+
)
|
129 |
+
current_start_frame += 1
|
130 |
+
|
131 |
+
# Step 3: Temporal denoising loop
|
132 |
+
all_num_frames = [self.num_frame_per_block] * num_blocks
|
133 |
+
if self.independent_first_frame and initial_latent is None:
|
134 |
+
all_num_frames = [1] + all_num_frames
|
135 |
+
num_denoising_steps = len(self.denoising_step_list)
|
136 |
+
exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
|
137 |
+
start_gradient_frame_index = num_output_frames - 21
|
138 |
+
|
139 |
+
# for block_index in range(num_blocks):
|
140 |
+
for block_index, current_num_frames in enumerate(all_num_frames):
|
141 |
+
noisy_input = noise[
|
142 |
+
:, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
|
143 |
+
|
144 |
+
# Step 3.1: Spatial denoising loop
|
145 |
+
for index, current_timestep in enumerate(self.denoising_step_list):
|
146 |
+
if self.same_step_across_blocks:
|
147 |
+
exit_flag = (index == exit_flags[0])
|
148 |
+
else:
|
149 |
+
exit_flag = (index == exit_flags[block_index]) # Only backprop at the randomly selected timestep (consistent across all ranks)
|
150 |
+
timestep = torch.ones(
|
151 |
+
[batch_size, current_num_frames],
|
152 |
+
device=noise.device,
|
153 |
+
dtype=torch.int64) * current_timestep
|
154 |
+
|
155 |
+
if not exit_flag:
|
156 |
+
with torch.no_grad():
|
157 |
+
_, denoised_pred = self.generator(
|
158 |
+
noisy_image_or_video=noisy_input,
|
159 |
+
conditional_dict=conditional_dict,
|
160 |
+
timestep=timestep,
|
161 |
+
kv_cache=self.kv_cache1,
|
162 |
+
crossattn_cache=self.crossattn_cache,
|
163 |
+
current_start=current_start_frame * self.frame_seq_length
|
164 |
+
)
|
165 |
+
next_timestep = self.denoising_step_list[index + 1]
|
166 |
+
noisy_input = self.scheduler.add_noise(
|
167 |
+
denoised_pred.flatten(0, 1),
|
168 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
169 |
+
next_timestep * torch.ones(
|
170 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
171 |
+
).unflatten(0, denoised_pred.shape[:2])
|
172 |
+
else:
|
173 |
+
# for getting real output
|
174 |
+
# with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
|
175 |
+
if current_start_frame < start_gradient_frame_index:
|
176 |
+
with torch.no_grad():
|
177 |
+
_, denoised_pred = self.generator(
|
178 |
+
noisy_image_or_video=noisy_input,
|
179 |
+
conditional_dict=conditional_dict,
|
180 |
+
timestep=timestep,
|
181 |
+
kv_cache=self.kv_cache1,
|
182 |
+
crossattn_cache=self.crossattn_cache,
|
183 |
+
current_start=current_start_frame * self.frame_seq_length
|
184 |
+
)
|
185 |
+
else:
|
186 |
+
_, denoised_pred = self.generator(
|
187 |
+
noisy_image_or_video=noisy_input,
|
188 |
+
conditional_dict=conditional_dict,
|
189 |
+
timestep=timestep,
|
190 |
+
kv_cache=self.kv_cache1,
|
191 |
+
crossattn_cache=self.crossattn_cache,
|
192 |
+
current_start=current_start_frame * self.frame_seq_length
|
193 |
+
)
|
194 |
+
break
|
195 |
+
|
196 |
+
# Step 3.2: record the model's output
|
197 |
+
output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
|
198 |
+
|
199 |
+
# Step 3.3: rerun with timestep zero to update the cache
|
200 |
+
context_timestep = torch.ones_like(timestep) * self.context_noise
|
201 |
+
# add context noise
|
202 |
+
denoised_pred = self.scheduler.add_noise(
|
203 |
+
denoised_pred.flatten(0, 1),
|
204 |
+
torch.randn_like(denoised_pred.flatten(0, 1)),
|
205 |
+
context_timestep * torch.ones(
|
206 |
+
[batch_size * current_num_frames], device=noise.device, dtype=torch.long)
|
207 |
+
).unflatten(0, denoised_pred.shape[:2])
|
208 |
+
with torch.no_grad():
|
209 |
+
self.generator(
|
210 |
+
noisy_image_or_video=denoised_pred,
|
211 |
+
conditional_dict=conditional_dict,
|
212 |
+
timestep=context_timestep,
|
213 |
+
kv_cache=self.kv_cache1,
|
214 |
+
crossattn_cache=self.crossattn_cache,
|
215 |
+
current_start=current_start_frame * self.frame_seq_length
|
216 |
+
)
|
217 |
+
|
218 |
+
# Step 3.4: update the start and end frame indices
|
219 |
+
current_start_frame += current_num_frames
|
220 |
+
|
221 |
+
# Step 3.5: Return the denoised timestep
|
222 |
+
if not self.same_step_across_blocks:
|
223 |
+
denoised_timestep_from, denoised_timestep_to = None, None
|
224 |
+
elif exit_flags[0] == len(self.denoising_step_list) - 1:
|
225 |
+
denoised_timestep_to = 0
|
226 |
+
denoised_timestep_from = 1000 - torch.argmin(
|
227 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
|
228 |
+
else:
|
229 |
+
denoised_timestep_to = 1000 - torch.argmin(
|
230 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
|
231 |
+
denoised_timestep_from = 1000 - torch.argmin(
|
232 |
+
(self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
|
233 |
+
|
234 |
+
if return_sim_step:
|
235 |
+
return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1
|
236 |
+
|
237 |
+
return output, denoised_timestep_from, denoised_timestep_to
|
238 |
+
|
239 |
+
def _initialize_kv_cache(self, batch_size, dtype, device):
|
240 |
+
"""
|
241 |
+
Initialize a Per-GPU KV cache for the Wan model.
|
242 |
+
"""
|
243 |
+
kv_cache1 = []
|
244 |
+
|
245 |
+
for _ in range(self.num_transformer_blocks):
|
246 |
+
kv_cache1.append({
|
247 |
+
"k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
|
248 |
+
"v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
|
249 |
+
"global_end_index": torch.tensor([0], dtype=torch.long, device=device),
|
250 |
+
"local_end_index": torch.tensor([0], dtype=torch.long, device=device)
|
251 |
+
})
|
252 |
+
|
253 |
+
self.kv_cache1 = kv_cache1 # always store the clean cache
|
254 |
+
|
255 |
+
def _initialize_crossattn_cache(self, batch_size, dtype, device):
|
256 |
+
"""
|
257 |
+
Initialize a Per-GPU cross-attention cache for the Wan model.
|
258 |
+
"""
|
259 |
+
crossattn_cache = []
|
260 |
+
|
261 |
+
for _ in range(self.num_transformer_blocks):
|
262 |
+
crossattn_cache.append({
|
263 |
+
"k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
264 |
+
"v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
|
265 |
+
"is_init": False
|
266 |
+
})
|
267 |
+
self.crossattn_cache = crossattn_cache
|
prompts/MovieGenVideoBench.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
prompts/MovieGenVideoBench_extended.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
prompts/vbench/all_dimension.txt
ADDED
@@ -0,0 +1,946 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
In a still frame, a stop sign
|
2 |
+
a toilet, frozen in time
|
3 |
+
a laptop, frozen in time
|
4 |
+
A tranquil tableau of alley
|
5 |
+
A tranquil tableau of bar
|
6 |
+
A tranquil tableau of barn
|
7 |
+
A tranquil tableau of bathroom
|
8 |
+
A tranquil tableau of bedroom
|
9 |
+
A tranquil tableau of cliff
|
10 |
+
In a still frame, courtyard
|
11 |
+
In a still frame, gas station
|
12 |
+
A tranquil tableau of house
|
13 |
+
indoor gymnasium, frozen in time
|
14 |
+
A tranquil tableau of indoor library
|
15 |
+
A tranquil tableau of kitchen
|
16 |
+
A tranquil tableau of palace
|
17 |
+
In a still frame, parking lot
|
18 |
+
In a still frame, phone booth
|
19 |
+
A tranquil tableau of restaurant
|
20 |
+
A tranquil tableau of tower
|
21 |
+
A tranquil tableau of a bowl
|
22 |
+
A tranquil tableau of an apple
|
23 |
+
A tranquil tableau of a bench
|
24 |
+
A tranquil tableau of a bed
|
25 |
+
A tranquil tableau of a chair
|
26 |
+
A tranquil tableau of a cup
|
27 |
+
A tranquil tableau of a dining table
|
28 |
+
In a still frame, a pear
|
29 |
+
A tranquil tableau of a bunch of grapes
|
30 |
+
A tranquil tableau of a bowl on the kitchen counter
|
31 |
+
A tranquil tableau of a beautiful, handcrafted ceramic bowl
|
32 |
+
A tranquil tableau of an antique bowl
|
33 |
+
A tranquil tableau of an exquisite mahogany dining table
|
34 |
+
A tranquil tableau of a wooden bench in the park
|
35 |
+
A tranquil tableau of a beautiful wrought-iron bench surrounded by blooming flowers
|
36 |
+
In a still frame, a park bench with a view of the lake
|
37 |
+
A tranquil tableau of a vintage rocking chair was placed on the porch
|
38 |
+
A tranquil tableau of the jail cell was small and dimly lit, with cold, steel bars
|
39 |
+
A tranquil tableau of the phone booth was tucked away in a quiet alley
|
40 |
+
a dilapidated phone booth stood as a relic of a bygone era on the sidewalk, frozen in time
|
41 |
+
A tranquil tableau of the old red barn stood weathered and iconic against the backdrop of the countryside
|
42 |
+
A tranquil tableau of a picturesque barn was painted a warm shade of red and nestled in a picturesque meadow
|
43 |
+
In a still frame, within the desolate desert, an oasis unfolded, characterized by the stoic presence of palm trees and a motionless, glassy pool of water
|
44 |
+
In a still frame, the Parthenon's majestic Doric columns stand in serene solitude atop the Acropolis, framed by the tranquil Athenian landscape
|
45 |
+
In a still frame, the Temple of Hephaestus, with its timeless Doric grace, stands stoically against the backdrop of a quiet Athens
|
46 |
+
In a still frame, the ornate Victorian streetlamp stands solemnly, adorned with intricate ironwork and stained glass panels
|
47 |
+
A tranquil tableau of the Stonehenge presented itself as an enigmatic puzzle, each colossal stone meticulously placed against the backdrop of tranquility
|
48 |
+
In a still frame, in the vast desert, an oasis nestled among dunes, featuring tall palm trees and an air of serenity
|
49 |
+
static view on a desert scene with an oasis, palm trees, and a clear, calm pool of water
|
50 |
+
A tranquil tableau of an ornate Victorian streetlamp standing on a cobblestone street corner, illuminating the empty night
|
51 |
+
A tranquil tableau of a tranquil lakeside cabin nestled among tall pines, its reflection mirrored perfectly in the calm water
|
52 |
+
In a still frame, a vintage gas lantern, adorned with intricate details, gracing a historic cobblestone square
|
53 |
+
In a still frame, a tranquil Japanese tea ceremony room, with tatami mats, a delicate tea set, and a bonsai tree in the corner
|
54 |
+
A tranquil tableau of the Parthenon stands resolute in its classical elegance, a timeless symbol of Athens' cultural legacy
|
55 |
+
A tranquil tableau of in the heart of Plaka, the neoclassical architecture of the old city harmonizes with the ancient ruins
|
56 |
+
A tranquil tableau of in the desolate beauty of the American Southwest, Chaco Canyon's ancient ruins whispered tales of an enigmatic civilization that once thrived amidst the arid landscapes
|
57 |
+
A tranquil tableau of at the edge of the Arabian Desert, the ancient city of Petra beckoned with its enigmatic rock-carved façades
|
58 |
+
In a still frame, amidst the cobblestone streets, an Art Nouveau lamppost stood tall
|
59 |
+
A tranquil tableau of in the quaint village square, a traditional wrought-iron streetlamp featured delicate filigree patterns and amber-hued glass panels
|
60 |
+
A tranquil tableau of the lampposts were adorned with Art Deco motifs, their geometric shapes and frosted glass creating a sense of vintage glamour
|
61 |
+
In a still frame, in the picturesque square, a Gothic-style lamppost adorned with intricate stone carvings added a touch of medieval charm to the setting
|
62 |
+
In a still frame, in the heart of the old city, a row of ornate lantern-style streetlamps bathed the narrow alleyway in a warm, welcoming light
|
63 |
+
A tranquil tableau of in the heart of the Utah desert, a massive sandstone arch spanned the horizon
|
64 |
+
A tranquil tableau of in the Arizona desert, a massive stone bridge arched across a rugged canyon
|
65 |
+
A tranquil tableau of in the corner of the minimalist tea room, a bonsai tree added a touch of nature's beauty to the otherwise simple and elegant space
|
66 |
+
In a still frame, amidst the hushed ambiance of the traditional tea room, a meticulously arranged tea set awaited, with porcelain cups, a bamboo whisk
|
67 |
+
In a still frame, nestled in the Zen garden, a rustic teahouse featured tatami seating and a traditional charcoal brazier
|
68 |
+
A tranquil tableau of a country estate's library featured elegant wooden shelves
|
69 |
+
A tranquil tableau of beneath the shade of a solitary oak tree, an old wooden park bench sat patiently
|
70 |
+
A tranquil tableau of beside a tranquil pond, a weeping willow tree draped its branches gracefully over the water's surface, creating a serene tableau of reflection and calm
|
71 |
+
A tranquil tableau of in the Zen garden, a perfectly raked gravel path led to a serene rock garden
|
72 |
+
In a still frame, a tranquil pond was fringed by weeping cherry trees, their blossoms drifting lazily onto the glassy surface
|
73 |
+
In a still frame, within the historic library's reading room, rows of antique leather chairs and mahogany tables offered a serene haven for literary contemplation
|
74 |
+
A tranquil tableau of a peaceful orchid garden showcased a variety of delicate blooms
|
75 |
+
A tranquil tableau of in the serene courtyard, a centuries-old stone well stood as a symbol of a bygone era, its mossy stones bearing witness to the passage of time
|
76 |
+
a bird and a cat
|
77 |
+
a cat and a dog
|
78 |
+
a dog and a horse
|
79 |
+
a horse and a sheep
|
80 |
+
a sheep and a cow
|
81 |
+
a cow and an elephant
|
82 |
+
an elephant and a bear
|
83 |
+
a bear and a zebra
|
84 |
+
a zebra and a giraffe
|
85 |
+
a giraffe and a bird
|
86 |
+
a chair and a couch
|
87 |
+
a couch and a potted plant
|
88 |
+
a potted plant and a tv
|
89 |
+
a tv and a laptop
|
90 |
+
a laptop and a remote
|
91 |
+
a remote and a keyboard
|
92 |
+
a keyboard and a cell phone
|
93 |
+
a cell phone and a book
|
94 |
+
a book and a clock
|
95 |
+
a clock and a backpack
|
96 |
+
a backpack and an umbrella
|
97 |
+
an umbrella and a handbag
|
98 |
+
a handbag and a tie
|
99 |
+
a tie and a suitcase
|
100 |
+
a suitcase and a vase
|
101 |
+
a vase and scissors
|
102 |
+
scissors and a teddy bear
|
103 |
+
a teddy bear and a frisbee
|
104 |
+
a frisbee and skis
|
105 |
+
skis and a snowboard
|
106 |
+
a snowboard and a sports ball
|
107 |
+
a sports ball and a kite
|
108 |
+
a kite and a baseball bat
|
109 |
+
a baseball bat and a baseball glove
|
110 |
+
a baseball glove and a skateboard
|
111 |
+
a skateboard and a surfboard
|
112 |
+
a surfboard and a tennis racket
|
113 |
+
a tennis racket and a bottle
|
114 |
+
a bottle and a chair
|
115 |
+
an airplane and a train
|
116 |
+
a train and a boat
|
117 |
+
a boat and an airplane
|
118 |
+
a bicycle and a car
|
119 |
+
a car and a motorcycle
|
120 |
+
a motorcycle and a bus
|
121 |
+
a bus and a traffic light
|
122 |
+
a traffic light and a fire hydrant
|
123 |
+
a fire hydrant and a stop sign
|
124 |
+
a stop sign and a parking meter
|
125 |
+
a parking meter and a truck
|
126 |
+
a truck and a bicycle
|
127 |
+
a toilet and a hair drier
|
128 |
+
a hair drier and a toothbrush
|
129 |
+
a toothbrush and a sink
|
130 |
+
a sink and a toilet
|
131 |
+
a wine glass and a chair
|
132 |
+
a cup and a couch
|
133 |
+
a fork and a potted plant
|
134 |
+
a knife and a tv
|
135 |
+
a spoon and a laptop
|
136 |
+
a bowl and a remote
|
137 |
+
a banana and a keyboard
|
138 |
+
an apple and a cell phone
|
139 |
+
a sandwich and a book
|
140 |
+
an orange and a clock
|
141 |
+
broccoli and a backpack
|
142 |
+
a carrot and an umbrella
|
143 |
+
a hot dog and a handbag
|
144 |
+
a pizza and a tie
|
145 |
+
a donut and a suitcase
|
146 |
+
a cake and a vase
|
147 |
+
an oven and scissors
|
148 |
+
a toaster and a teddy bear
|
149 |
+
a microwave and a frisbee
|
150 |
+
a refrigerator and skis
|
151 |
+
a bicycle and an airplane
|
152 |
+
a car and a train
|
153 |
+
a motorcycle and a boat
|
154 |
+
a person and a toilet
|
155 |
+
a person and a hair drier
|
156 |
+
a person and a toothbrush
|
157 |
+
a person and a sink
|
158 |
+
A person is riding a bike
|
159 |
+
A person is marching
|
160 |
+
A person is roller skating
|
161 |
+
A person is tasting beer
|
162 |
+
A person is clapping
|
163 |
+
A person is drawing
|
164 |
+
A person is petting animal (not cat)
|
165 |
+
A person is eating watermelon
|
166 |
+
A person is playing harp
|
167 |
+
A person is wrestling
|
168 |
+
A person is riding scooter
|
169 |
+
A person is sweeping floor
|
170 |
+
A person is skateboarding
|
171 |
+
A person is dunking basketball
|
172 |
+
A person is playing flute
|
173 |
+
A person is stretching leg
|
174 |
+
A person is tying tie
|
175 |
+
A person is skydiving
|
176 |
+
A person is shooting goal (soccer)
|
177 |
+
A person is playing piano
|
178 |
+
A person is finger snapping
|
179 |
+
A person is canoeing or kayaking
|
180 |
+
A person is laughing
|
181 |
+
A person is digging
|
182 |
+
A person is clay pottery making
|
183 |
+
A person is shooting basketball
|
184 |
+
A person is bending back
|
185 |
+
A person is shaking hands
|
186 |
+
A person is bandaging
|
187 |
+
A person is push up
|
188 |
+
A person is catching or throwing frisbee
|
189 |
+
A person is playing trumpet
|
190 |
+
A person is flying kite
|
191 |
+
A person is filling eyebrows
|
192 |
+
A person is shuffling cards
|
193 |
+
A person is folding clothes
|
194 |
+
A person is smoking
|
195 |
+
A person is tai chi
|
196 |
+
A person is squat
|
197 |
+
A person is playing controller
|
198 |
+
A person is throwing axe
|
199 |
+
A person is giving or receiving award
|
200 |
+
A person is air drumming
|
201 |
+
A person is taking a shower
|
202 |
+
A person is planting trees
|
203 |
+
A person is sharpening knives
|
204 |
+
A person is robot dancing
|
205 |
+
A person is rock climbing
|
206 |
+
A person is hula hooping
|
207 |
+
A person is writing
|
208 |
+
A person is bungee jumping
|
209 |
+
A person is pushing cart
|
210 |
+
A person is cleaning windows
|
211 |
+
A person is cutting watermelon
|
212 |
+
A person is cheerleading
|
213 |
+
A person is washing hands
|
214 |
+
A person is ironing
|
215 |
+
A person is cutting nails
|
216 |
+
A person is hugging
|
217 |
+
A person is trimming or shaving beard
|
218 |
+
A person is jogging
|
219 |
+
A person is making bed
|
220 |
+
A person is washing dishes
|
221 |
+
A person is grooming dog
|
222 |
+
A person is doing laundry
|
223 |
+
A person is knitting
|
224 |
+
A person is reading book
|
225 |
+
A person is baby waking up
|
226 |
+
A person is massaging legs
|
227 |
+
A person is brushing teeth
|
228 |
+
A person is crawling baby
|
229 |
+
A person is motorcycling
|
230 |
+
A person is driving car
|
231 |
+
A person is sticking tongue out
|
232 |
+
A person is shaking head
|
233 |
+
A person is sword fighting
|
234 |
+
A person is doing aerobics
|
235 |
+
A person is strumming guitar
|
236 |
+
A person is riding or walking with horse
|
237 |
+
A person is archery
|
238 |
+
A person is catching or throwing baseball
|
239 |
+
A person is playing chess
|
240 |
+
A person is rock scissors paper
|
241 |
+
A person is using computer
|
242 |
+
A person is arranging flowers
|
243 |
+
A person is bending metal
|
244 |
+
A person is ice skating
|
245 |
+
A person is climbing a rope
|
246 |
+
A person is crying
|
247 |
+
A person is dancing ballet
|
248 |
+
A person is getting a haircut
|
249 |
+
A person is running on treadmill
|
250 |
+
A person is kissing
|
251 |
+
A person is counting money
|
252 |
+
A person is barbequing
|
253 |
+
A person is peeling apples
|
254 |
+
A person is milking cow
|
255 |
+
A person is shining shoes
|
256 |
+
A person is making snowman
|
257 |
+
A person is sailing
|
258 |
+
a person swimming in ocean
|
259 |
+
a person giving a presentation to a room full of colleagues
|
260 |
+
a person washing the dishes
|
261 |
+
a person eating a burger
|
262 |
+
a person walking in the snowstorm
|
263 |
+
a person drinking coffee in a cafe
|
264 |
+
a person playing guitar
|
265 |
+
a bicycle leaning against a tree
|
266 |
+
a bicycle gliding through a snowy field
|
267 |
+
a bicycle slowing down to stop
|
268 |
+
a bicycle accelerating to gain speed
|
269 |
+
a car stuck in traffic during rush hour
|
270 |
+
a car turning a corner
|
271 |
+
a car slowing down to stop
|
272 |
+
a car accelerating to gain speed
|
273 |
+
a motorcycle cruising along a coastal highway
|
274 |
+
a motorcycle turning a corner
|
275 |
+
a motorcycle slowing down to stop
|
276 |
+
a motorcycle gliding through a snowy field
|
277 |
+
a motorcycle accelerating to gain speed
|
278 |
+
an airplane soaring through a clear blue sky
|
279 |
+
an airplane taking off
|
280 |
+
an airplane landing smoothly on a runway
|
281 |
+
an airplane accelerating to gain speed
|
282 |
+
a bus turning a corner
|
283 |
+
a bus stuck in traffic during rush hour
|
284 |
+
a bus accelerating to gain speed
|
285 |
+
a train speeding down the tracks
|
286 |
+
a train crossing over a tall bridge
|
287 |
+
a train accelerating to gain speed
|
288 |
+
a truck turning a corner
|
289 |
+
a truck anchored in a tranquil bay
|
290 |
+
a truck stuck in traffic during rush hour
|
291 |
+
a truck slowing down to stop
|
292 |
+
a truck accelerating to gain speed
|
293 |
+
a boat sailing smoothly on a calm lake
|
294 |
+
a boat slowing down to stop
|
295 |
+
a boat accelerating to gain speed
|
296 |
+
a bird soaring gracefully in the sky
|
297 |
+
a bird building a nest from twigs and leaves
|
298 |
+
a bird flying over a snowy forest
|
299 |
+
a cat grooming itself meticulously with its tongue
|
300 |
+
a cat playing in park
|
301 |
+
a cat drinking water
|
302 |
+
a cat running happily
|
303 |
+
a dog enjoying a peaceful walk
|
304 |
+
a dog playing in park
|
305 |
+
a dog drinking water
|
306 |
+
a dog running happily
|
307 |
+
a horse bending down to drink water from a river
|
308 |
+
a horse galloping across an open field
|
309 |
+
a horse taking a peaceful walk
|
310 |
+
a horse running to join a herd of its kind
|
311 |
+
a sheep bending down to drink water from a river
|
312 |
+
a sheep taking a peaceful walk
|
313 |
+
a sheep running to join a herd of its kind
|
314 |
+
a cow bending down to drink water from a river
|
315 |
+
a cow chewing cud while resting in a tranquil barn
|
316 |
+
a cow running to join a herd of its kind
|
317 |
+
an elephant spraying itself with water using its trunk to cool down
|
318 |
+
an elephant taking a peaceful walk
|
319 |
+
an elephant running to join a herd of its kind
|
320 |
+
a bear catching a salmon in its powerful jaws
|
321 |
+
a bear sniffing the air for scents of food
|
322 |
+
a bear climbing a tree
|
323 |
+
a bear hunting for prey
|
324 |
+
a zebra bending down to drink water from a river
|
325 |
+
a zebra running to join a herd of its kind
|
326 |
+
a zebra taking a peaceful walk
|
327 |
+
a giraffe bending down to drink water from a river
|
328 |
+
a giraffe taking a peaceful walk
|
329 |
+
a giraffe running to join a herd of its kind
|
330 |
+
a person
|
331 |
+
a bicycle
|
332 |
+
a car
|
333 |
+
a motorcycle
|
334 |
+
an airplane
|
335 |
+
a bus
|
336 |
+
a train
|
337 |
+
a truck
|
338 |
+
a boat
|
339 |
+
a traffic light
|
340 |
+
a fire hydrant
|
341 |
+
a stop sign
|
342 |
+
a parking meter
|
343 |
+
a bench
|
344 |
+
a bird
|
345 |
+
a cat
|
346 |
+
a dog
|
347 |
+
a horse
|
348 |
+
a sheep
|
349 |
+
a cow
|
350 |
+
an elephant
|
351 |
+
a bear
|
352 |
+
a zebra
|
353 |
+
a giraffe
|
354 |
+
a backpack
|
355 |
+
an umbrella
|
356 |
+
a handbag
|
357 |
+
a tie
|
358 |
+
a suitcase
|
359 |
+
a frisbee
|
360 |
+
skis
|
361 |
+
a snowboard
|
362 |
+
a sports ball
|
363 |
+
a kite
|
364 |
+
a baseball bat
|
365 |
+
a baseball glove
|
366 |
+
a skateboard
|
367 |
+
a surfboard
|
368 |
+
a tennis racket
|
369 |
+
a bottle
|
370 |
+
a wine glass
|
371 |
+
a cup
|
372 |
+
a fork
|
373 |
+
a knife
|
374 |
+
a spoon
|
375 |
+
a bowl
|
376 |
+
a banana
|
377 |
+
an apple
|
378 |
+
a sandwich
|
379 |
+
an orange
|
380 |
+
broccoli
|
381 |
+
a carrot
|
382 |
+
a hot dog
|
383 |
+
a pizza
|
384 |
+
a donut
|
385 |
+
a cake
|
386 |
+
a chair
|
387 |
+
a couch
|
388 |
+
a potted plant
|
389 |
+
a bed
|
390 |
+
a dining table
|
391 |
+
a toilet
|
392 |
+
a tv
|
393 |
+
a laptop
|
394 |
+
a remote
|
395 |
+
a keyboard
|
396 |
+
a cell phone
|
397 |
+
a microwave
|
398 |
+
an oven
|
399 |
+
a toaster
|
400 |
+
a sink
|
401 |
+
a refrigerator
|
402 |
+
a book
|
403 |
+
a clock
|
404 |
+
a vase
|
405 |
+
scissors
|
406 |
+
a teddy bear
|
407 |
+
a hair drier
|
408 |
+
a toothbrush
|
409 |
+
a red bicycle
|
410 |
+
a green bicycle
|
411 |
+
a blue bicycle
|
412 |
+
a yellow bicycle
|
413 |
+
an orange bicycle
|
414 |
+
a purple bicycle
|
415 |
+
a pink bicycle
|
416 |
+
a black bicycle
|
417 |
+
a white bicycle
|
418 |
+
a red car
|
419 |
+
a green car
|
420 |
+
a blue car
|
421 |
+
a yellow car
|
422 |
+
an orange car
|
423 |
+
a purple car
|
424 |
+
a pink car
|
425 |
+
a black car
|
426 |
+
a white car
|
427 |
+
a red bird
|
428 |
+
a green bird
|
429 |
+
a blue bird
|
430 |
+
a yellow bird
|
431 |
+
an orange bird
|
432 |
+
a purple bird
|
433 |
+
a pink bird
|
434 |
+
a black bird
|
435 |
+
a white bird
|
436 |
+
a black cat
|
437 |
+
a white cat
|
438 |
+
an orange cat
|
439 |
+
a yellow cat
|
440 |
+
a red umbrella
|
441 |
+
a green umbrella
|
442 |
+
a blue umbrella
|
443 |
+
a yellow umbrella
|
444 |
+
an orange umbrella
|
445 |
+
a purple umbrella
|
446 |
+
a pink umbrella
|
447 |
+
a black umbrella
|
448 |
+
a white umbrella
|
449 |
+
a red suitcase
|
450 |
+
a green suitcase
|
451 |
+
a blue suitcase
|
452 |
+
a yellow suitcase
|
453 |
+
an orange suitcase
|
454 |
+
a purple suitcase
|
455 |
+
a pink suitcase
|
456 |
+
a black suitcase
|
457 |
+
a white suitcase
|
458 |
+
a red bowl
|
459 |
+
a green bowl
|
460 |
+
a blue bowl
|
461 |
+
a yellow bowl
|
462 |
+
an orange bowl
|
463 |
+
a purple bowl
|
464 |
+
a pink bowl
|
465 |
+
a black bowl
|
466 |
+
a white bowl
|
467 |
+
a red chair
|
468 |
+
a green chair
|
469 |
+
a blue chair
|
470 |
+
a yellow chair
|
471 |
+
an orange chair
|
472 |
+
a purple chair
|
473 |
+
a pink chair
|
474 |
+
a black chair
|
475 |
+
a white chair
|
476 |
+
a red clock
|
477 |
+
a green clock
|
478 |
+
a blue clock
|
479 |
+
a yellow clock
|
480 |
+
an orange clock
|
481 |
+
a purple clock
|
482 |
+
a pink clock
|
483 |
+
a black clock
|
484 |
+
a white clock
|
485 |
+
a red vase
|
486 |
+
a green vase
|
487 |
+
a blue vase
|
488 |
+
a yellow vase
|
489 |
+
an orange vase
|
490 |
+
a purple vase
|
491 |
+
a pink vase
|
492 |
+
a black vase
|
493 |
+
a white vase
|
494 |
+
A beautiful coastal beach in spring, waves lapping on sand, Van Gogh style
|
495 |
+
A beautiful coastal beach in spring, waves lapping on sand, oil painting
|
496 |
+
A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo
|
497 |
+
A beautiful coastal beach in spring, waves lapping on sand, black and white
|
498 |
+
A beautiful coastal beach in spring, waves lapping on sand, pixel art
|
499 |
+
A beautiful coastal beach in spring, waves lapping on sand, in cyberpunk style
|
500 |
+
A beautiful coastal beach in spring, waves lapping on sand, animated style
|
501 |
+
A beautiful coastal beach in spring, waves lapping on sand, watercolor painting
|
502 |
+
A beautiful coastal beach in spring, waves lapping on sand, surrealism style
|
503 |
+
The bund Shanghai, Van Gogh style
|
504 |
+
The bund Shanghai, oil painting
|
505 |
+
The bund Shanghai by Hokusai, in the style of Ukiyo
|
506 |
+
The bund Shanghai, black and white
|
507 |
+
The bund Shanghai, pixel art
|
508 |
+
The bund Shanghai, in cyberpunk style
|
509 |
+
The bund Shanghai, animated style
|
510 |
+
The bund Shanghai, watercolor painting
|
511 |
+
The bund Shanghai, surrealism style
|
512 |
+
a shark is swimming in the ocean, Van Gogh style
|
513 |
+
a shark is swimming in the ocean, oil painting
|
514 |
+
a shark is swimming in the ocean by Hokusai, in the style of Ukiyo
|
515 |
+
a shark is swimming in the ocean, black and white
|
516 |
+
a shark is swimming in the ocean, pixel art
|
517 |
+
a shark is swimming in the ocean, in cyberpunk style
|
518 |
+
a shark is swimming in the ocean, animated style
|
519 |
+
a shark is swimming in the ocean, watercolor painting
|
520 |
+
a shark is swimming in the ocean, surrealism style
|
521 |
+
A panda drinking coffee in a cafe in Paris, Van Gogh style
|
522 |
+
A panda drinking coffee in a cafe in Paris, oil painting
|
523 |
+
A panda drinking coffee in a cafe in Paris by Hokusai, in the style of Ukiyo
|
524 |
+
A panda drinking coffee in a cafe in Paris, black and white
|
525 |
+
A panda drinking coffee in a cafe in Paris, pixel art
|
526 |
+
A panda drinking coffee in a cafe in Paris, in cyberpunk style
|
527 |
+
A panda drinking coffee in a cafe in Paris, animated style
|
528 |
+
A panda drinking coffee in a cafe in Paris, watercolor painting
|
529 |
+
A panda drinking coffee in a cafe in Paris, surrealism style
|
530 |
+
A cute happy Corgi playing in park, sunset, Van Gogh style
|
531 |
+
A cute happy Corgi playing in park, sunset, oil painting
|
532 |
+
A cute happy Corgi playing in park, sunset by Hokusai, in the style of Ukiyo
|
533 |
+
A cute happy Corgi playing in park, sunset, black and white
|
534 |
+
A cute happy Corgi playing in park, sunset, pixel art
|
535 |
+
A cute happy Corgi playing in park, sunset, in cyberpunk style
|
536 |
+
A cute happy Corgi playing in park, sunset, animated style
|
537 |
+
A cute happy Corgi playing in park, sunset, watercolor painting
|
538 |
+
A cute happy Corgi playing in park, sunset, surrealism style
|
539 |
+
Gwen Stacy reading a book, Van Gogh style
|
540 |
+
Gwen Stacy reading a book, oil painting
|
541 |
+
Gwen Stacy reading a book by Hokusai, in the style of Ukiyo
|
542 |
+
Gwen Stacy reading a book, black and white
|
543 |
+
Gwen Stacy reading a book, pixel art
|
544 |
+
Gwen Stacy reading a book, in cyberpunk style
|
545 |
+
Gwen Stacy reading a book, animated style
|
546 |
+
Gwen Stacy reading a book, watercolor painting
|
547 |
+
Gwen Stacy reading a book, surrealism style
|
548 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, Van Gogh style
|
549 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, oil painting
|
550 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Hokusai, in the style of Ukiyo
|
551 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, black and white
|
552 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pixel art
|
553 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in cyberpunk style
|
554 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, animated style
|
555 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, watercolor painting
|
556 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, surrealism style
|
557 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, Van Gogh style
|
558 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, oil painting
|
559 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas by Hokusai, in the style of Ukiyo
|
560 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, black and white
|
561 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pixel art
|
562 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in cyberpunk style
|
563 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, animated style
|
564 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, watercolor painting
|
565 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, surrealism style
|
566 |
+
An astronaut flying in space, Van Gogh style
|
567 |
+
An astronaut flying in space, oil painting
|
568 |
+
An astronaut flying in space by Hokusai, in the style of Ukiyo
|
569 |
+
An astronaut flying in space, black and white
|
570 |
+
An astronaut flying in space, pixel art
|
571 |
+
An astronaut flying in space, in cyberpunk style
|
572 |
+
An astronaut flying in space, animated style
|
573 |
+
An astronaut flying in space, watercolor painting
|
574 |
+
An astronaut flying in space, surrealism style
|
575 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, Van Gogh style
|
576 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, oil painting
|
577 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks by Hokusai, in the style of Ukiyo
|
578 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, black and white
|
579 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pixel art
|
580 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in cyberpunk style
|
581 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, animated style
|
582 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, watercolor painting
|
583 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, surrealism style
|
584 |
+
A beautiful coastal beach in spring, waves lapping on sand, in super slow motion
|
585 |
+
A beautiful coastal beach in spring, waves lapping on sand, zoom in
|
586 |
+
A beautiful coastal beach in spring, waves lapping on sand, zoom out
|
587 |
+
A beautiful coastal beach in spring, waves lapping on sand, pan left
|
588 |
+
A beautiful coastal beach in spring, waves lapping on sand, pan right
|
589 |
+
A beautiful coastal beach in spring, waves lapping on sand, tilt up
|
590 |
+
A beautiful coastal beach in spring, waves lapping on sand, tilt down
|
591 |
+
A beautiful coastal beach in spring, waves lapping on sand, with an intense shaking effect
|
592 |
+
A beautiful coastal beach in spring, waves lapping on sand, featuring a steady and smooth perspective
|
593 |
+
A beautiful coastal beach in spring, waves lapping on sand, racking focus
|
594 |
+
The bund Shanghai, in super slow motion
|
595 |
+
The bund Shanghai, zoom in
|
596 |
+
The bund Shanghai, zoom out
|
597 |
+
The bund Shanghai, pan left
|
598 |
+
The bund Shanghai, pan right
|
599 |
+
The bund Shanghai, tilt up
|
600 |
+
The bund Shanghai, tilt down
|
601 |
+
The bund Shanghai, with an intense shaking effect
|
602 |
+
The bund Shanghai, featuring a steady and smooth perspective
|
603 |
+
The bund Shanghai, racking focus
|
604 |
+
a shark is swimming in the ocean, in super slow motion
|
605 |
+
a shark is swimming in the ocean, zoom in
|
606 |
+
a shark is swimming in the ocean, zoom out
|
607 |
+
a shark is swimming in the ocean, pan left
|
608 |
+
a shark is swimming in the ocean, pan right
|
609 |
+
a shark is swimming in the ocean, tilt up
|
610 |
+
a shark is swimming in the ocean, tilt down
|
611 |
+
a shark is swimming in the ocean, with an intense shaking effect
|
612 |
+
a shark is swimming in the ocean, featuring a steady and smooth perspective
|
613 |
+
a shark is swimming in the ocean, racking focus
|
614 |
+
A panda drinking coffee in a cafe in Paris, in super slow motion
|
615 |
+
A panda drinking coffee in a cafe in Paris, zoom in
|
616 |
+
A panda drinking coffee in a cafe in Paris, zoom out
|
617 |
+
A panda drinking coffee in a cafe in Paris, pan left
|
618 |
+
A panda drinking coffee in a cafe in Paris, pan right
|
619 |
+
A panda drinking coffee in a cafe in Paris, tilt up
|
620 |
+
A panda drinking coffee in a cafe in Paris, tilt down
|
621 |
+
A panda drinking coffee in a cafe in Paris, with an intense shaking effect
|
622 |
+
A panda drinking coffee in a cafe in Paris, featuring a steady and smooth perspective
|
623 |
+
A panda drinking coffee in a cafe in Paris, racking focus
|
624 |
+
A cute happy Corgi playing in park, sunset, in super slow motion
|
625 |
+
A cute happy Corgi playing in park, sunset, zoom in
|
626 |
+
A cute happy Corgi playing in park, sunset, zoom out
|
627 |
+
A cute happy Corgi playing in park, sunset, pan left
|
628 |
+
A cute happy Corgi playing in park, sunset, pan right
|
629 |
+
A cute happy Corgi playing in park, sunset, tilt up
|
630 |
+
A cute happy Corgi playing in park, sunset, tilt down
|
631 |
+
A cute happy Corgi playing in park, sunset, with an intense shaking effect
|
632 |
+
A cute happy Corgi playing in park, sunset, featuring a steady and smooth perspective
|
633 |
+
A cute happy Corgi playing in park, sunset, racking focus
|
634 |
+
Gwen Stacy reading a book, in super slow motion
|
635 |
+
Gwen Stacy reading a book, zoom in
|
636 |
+
Gwen Stacy reading a book, zoom out
|
637 |
+
Gwen Stacy reading a book, pan left
|
638 |
+
Gwen Stacy reading a book, pan right
|
639 |
+
Gwen Stacy reading a book, tilt up
|
640 |
+
Gwen Stacy reading a book, tilt down
|
641 |
+
Gwen Stacy reading a book, with an intense shaking effect
|
642 |
+
Gwen Stacy reading a book, featuring a steady and smooth perspective
|
643 |
+
Gwen Stacy reading a book, racking focus
|
644 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in super slow motion
|
645 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom in
|
646 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom out
|
647 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan left
|
648 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan right
|
649 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt up
|
650 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt down
|
651 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, with an intense shaking effect
|
652 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, featuring a steady and smooth perspective
|
653 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background, racking focus
|
654 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in super slow motion
|
655 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom in
|
656 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom out
|
657 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan left
|
658 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan right
|
659 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt up
|
660 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt down
|
661 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, with an intense shaking effect
|
662 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, featuring a steady and smooth perspective
|
663 |
+
A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, racking focus
|
664 |
+
An astronaut flying in space, in super slow motion
|
665 |
+
An astronaut flying in space, zoom in
|
666 |
+
An astronaut flying in space, zoom out
|
667 |
+
An astronaut flying in space, pan left
|
668 |
+
An astronaut flying in space, pan right
|
669 |
+
An astronaut flying in space, tilt up
|
670 |
+
An astronaut flying in space, tilt down
|
671 |
+
An astronaut flying in space, with an intense shaking effect
|
672 |
+
An astronaut flying in space, featuring a steady and smooth perspective
|
673 |
+
An astronaut flying in space, racking focus
|
674 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in super slow motion
|
675 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom in
|
676 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom out
|
677 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan left
|
678 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan right
|
679 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt up
|
680 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt down
|
681 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, with an intense shaking effect
|
682 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, featuring a steady and smooth perspective
|
683 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, racking focus
|
684 |
+
Close up of grapes on a rotating table.
|
685 |
+
Turtle swimming in ocean.
|
686 |
+
A storm trooper vacuuming the beach.
|
687 |
+
A panda standing on a surfboard in the ocean in sunset.
|
688 |
+
An astronaut feeding ducks on a sunny afternoon, reflection from the water.
|
689 |
+
Two pandas discussing an academic paper.
|
690 |
+
Sunset time lapse at the beach with moving clouds and colors in the sky.
|
691 |
+
A fat rabbit wearing a purple robe walking through a fantasy landscape.
|
692 |
+
A koala bear playing piano in the forest.
|
693 |
+
An astronaut flying in space.
|
694 |
+
Fireworks.
|
695 |
+
An animated painting of fluffy white clouds moving in sky.
|
696 |
+
Flying through fantasy landscapes.
|
697 |
+
A bigfoot walking in the snowstorm.
|
698 |
+
A squirrel eating a burger.
|
699 |
+
A cat wearing sunglasses and working as a lifeguard at a pool.
|
700 |
+
Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks.
|
701 |
+
Splash of turquoise water in extreme slow motion, alpha channel included.
|
702 |
+
an ice cream is melting on the table.
|
703 |
+
a drone flying over a snowy forest.
|
704 |
+
a shark is swimming in the ocean.
|
705 |
+
Aerial panoramic video from a drone of a fantasy land.
|
706 |
+
a teddy bear is swimming in the ocean.
|
707 |
+
time lapse of sunrise on mars.
|
708 |
+
golden fish swimming in the ocean.
|
709 |
+
An artist brush painting on a canvas close up.
|
710 |
+
A drone view of celebration with Christmas tree and fireworks, starry sky - background.
|
711 |
+
happy dog wearing a yellow turtleneck, studio, portrait, facing camera, dark background
|
712 |
+
Origami dancers in white paper, 3D render, on white background, studio shot, dancing modern dance.
|
713 |
+
Campfire at night in a snowy forest with starry sky in the background.
|
714 |
+
a fantasy landscape
|
715 |
+
A 3D model of a 1800s victorian house.
|
716 |
+
this is how I do makeup in the morning.
|
717 |
+
A raccoon that looks like a turtle, digital art.
|
718 |
+
Robot dancing in Times Square.
|
719 |
+
Busy freeway at night.
|
720 |
+
Balloon full of water exploding in extreme slow motion.
|
721 |
+
An astronaut is riding a horse in the space in a photorealistic style.
|
722 |
+
Macro slo-mo. Slow motion cropped closeup of roasted coffee beans falling into an empty bowl.
|
723 |
+
Sewing machine, old sewing machine working.
|
724 |
+
Motion colour drop in water, ink swirling in water, colourful ink in water, abstraction fancy dream cloud of ink.
|
725 |
+
Few big purple plums rotating on the turntable. water drops appear on the skin during rotation. isolated on the white background. close-up. macro.
|
726 |
+
Vampire makeup face of beautiful girl, red contact lenses.
|
727 |
+
Ashtray full of butts on table, smoke flowing on black background, close-up
|
728 |
+
Pacific coast, carmel by the sea ocean and waves.
|
729 |
+
A teddy bear is playing drum kit in NYC Times Square.
|
730 |
+
A corgi is playing drum kit.
|
731 |
+
An Iron man is playing the electronic guitar, high electronic guitar.
|
732 |
+
A raccoon is playing the electronic guitar.
|
733 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Vincent van Gogh
|
734 |
+
A corgi's head depicted as an explosion of a nebula
|
735 |
+
A fantasy landscape
|
736 |
+
A future where humans have achieved teleportation technology
|
737 |
+
A jellyfish floating through the ocean, with bioluminescent tentacles
|
738 |
+
A Mars rover moving on Mars
|
739 |
+
A panda drinking coffee in a cafe in Paris
|
740 |
+
A space shuttle launching into orbit, with flames and smoke billowing out from the engines
|
741 |
+
A steam train moving on a mountainside
|
742 |
+
A super cool giant robot in Cyberpunk Beijing
|
743 |
+
A tropical beach at sunrise, with palm trees and crystal-clear water in the foreground
|
744 |
+
Cinematic shot of Van Gogh's selfie, Van Gogh style
|
745 |
+
Gwen Stacy reading a book
|
746 |
+
Iron Man flying in the sky
|
747 |
+
The bund Shanghai, oil painting
|
748 |
+
Yoda playing guitar on the stage
|
749 |
+
A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo
|
750 |
+
A beautiful coastal beach in spring, waves lapping on sand by Vincent van Gogh
|
751 |
+
A boat sailing leisurely along the Seine River with the Eiffel Tower in background
|
752 |
+
A car moving slowly on an empty street, rainy evening
|
753 |
+
A cat eating food out of a bowl
|
754 |
+
A cat wearing sunglasses at a pool
|
755 |
+
A confused panda in calculus class
|
756 |
+
A cute fluffy panda eating Chinese food in a restaurant
|
757 |
+
A cute happy Corgi playing in park, sunset
|
758 |
+
A cute raccoon playing guitar in a boat on the ocean
|
759 |
+
A happy fuzzy panda playing guitar nearby a campfire, snow mountain in the background
|
760 |
+
A lightning striking atop of eiffel tower, dark clouds in the sky
|
761 |
+
A modern art museum, with colorful paintings
|
762 |
+
A panda cooking in the kitchen
|
763 |
+
A panda playing on a swing set
|
764 |
+
A polar bear is playing guitar
|
765 |
+
A raccoon dressed in suit playing the trumpet, stage background
|
766 |
+
A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy
|
767 |
+
A shark swimming in clear Caribbean ocean
|
768 |
+
A super robot protecting city
|
769 |
+
A teddy bear washing the dishes
|
770 |
+
An epic tornado attacking above a glowing city at night, the tornado is made of smoke
|
771 |
+
An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with umbrellas
|
772 |
+
Clown fish swimming through the coral reef
|
773 |
+
Hyper-realistic spaceship landing on Mars
|
774 |
+
The bund Shanghai, vibrant color
|
775 |
+
Vincent van Gogh is painting in the room
|
776 |
+
Yellow flowers swing in the wind
|
777 |
+
alley
|
778 |
+
amusement park
|
779 |
+
aquarium
|
780 |
+
arch
|
781 |
+
art gallery
|
782 |
+
bathroom
|
783 |
+
bakery shop
|
784 |
+
ballroom
|
785 |
+
bar
|
786 |
+
barn
|
787 |
+
basement
|
788 |
+
beach
|
789 |
+
bedroom
|
790 |
+
bridge
|
791 |
+
botanical garden
|
792 |
+
cafeteria
|
793 |
+
campsite
|
794 |
+
campus
|
795 |
+
carrousel
|
796 |
+
castle
|
797 |
+
cemetery
|
798 |
+
classroom
|
799 |
+
cliff
|
800 |
+
crosswalk
|
801 |
+
construction site
|
802 |
+
corridor
|
803 |
+
courtyard
|
804 |
+
desert
|
805 |
+
downtown
|
806 |
+
driveway
|
807 |
+
farm
|
808 |
+
food court
|
809 |
+
football field
|
810 |
+
forest road
|
811 |
+
fountain
|
812 |
+
gas station
|
813 |
+
glacier
|
814 |
+
golf course
|
815 |
+
indoor gymnasium
|
816 |
+
harbor
|
817 |
+
highway
|
818 |
+
hospital
|
819 |
+
house
|
820 |
+
iceberg
|
821 |
+
industrial area
|
822 |
+
jail cell
|
823 |
+
junkyard
|
824 |
+
kitchen
|
825 |
+
indoor library
|
826 |
+
lighthouse
|
827 |
+
laboratory
|
828 |
+
mansion
|
829 |
+
marsh
|
830 |
+
mountain
|
831 |
+
indoor movie theater
|
832 |
+
indoor museum
|
833 |
+
music studio
|
834 |
+
nursery
|
835 |
+
ocean
|
836 |
+
office
|
837 |
+
palace
|
838 |
+
parking lot
|
839 |
+
pharmacy
|
840 |
+
phone booth
|
841 |
+
raceway
|
842 |
+
restaurant
|
843 |
+
river
|
844 |
+
science museum
|
845 |
+
shower
|
846 |
+
ski slope
|
847 |
+
sky
|
848 |
+
skyscraper
|
849 |
+
baseball stadium
|
850 |
+
staircase
|
851 |
+
street
|
852 |
+
supermarket
|
853 |
+
indoor swimming pool
|
854 |
+
tower
|
855 |
+
outdoor track
|
856 |
+
train railway
|
857 |
+
train station platform
|
858 |
+
underwater coral reef
|
859 |
+
valley
|
860 |
+
volcano
|
861 |
+
waterfall
|
862 |
+
windmill
|
863 |
+
a bicycle on the left of a car, front view
|
864 |
+
a car on the right of a motorcycle, front view
|
865 |
+
a motorcycle on the left of a bus, front view
|
866 |
+
a bus on the right of a traffic light, front view
|
867 |
+
a traffic light on the left of a fire hydrant, front view
|
868 |
+
a fire hydrant on the right of a stop sign, front view
|
869 |
+
a stop sign on the left of a parking meter, front view
|
870 |
+
a parking meter on the right of a bench, front view
|
871 |
+
a bench on the left of a truck, front view
|
872 |
+
a truck on the right of a bicycle, front view
|
873 |
+
a bird on the left of a cat, front view
|
874 |
+
a cat on the right of a dog, front view
|
875 |
+
a dog on the left of a horse, front view
|
876 |
+
a horse on the right of a sheep, front view
|
877 |
+
a sheep on the left of a cow, front view
|
878 |
+
a cow on the right of an elephant, front view
|
879 |
+
an elephant on the left of a bear, front view
|
880 |
+
a bear on the right of a zebra, front view
|
881 |
+
a zebra on the left of a giraffe, front view
|
882 |
+
a giraffe on the right of a bird, front view
|
883 |
+
a bottle on the left of a wine glass, front view
|
884 |
+
a wine glass on the right of a cup, front view
|
885 |
+
a cup on the left of a fork, front view
|
886 |
+
a fork on the right of a knife, front view
|
887 |
+
a knife on the left of a spoon, front view
|
888 |
+
a spoon on the right of a bowl, front view
|
889 |
+
a bowl on the left of a bottle, front view
|
890 |
+
a potted plant on the left of a remote, front view
|
891 |
+
a remote on the right of a clock, front view
|
892 |
+
a clock on the left of a vase, front view
|
893 |
+
a vase on the right of scissors, front view
|
894 |
+
scissors on the left of a teddy bear, front view
|
895 |
+
a teddy bear on the right of a potted plant, front view
|
896 |
+
a frisbee on the left of a sports ball, front view
|
897 |
+
a sports ball on the right of a baseball bat, front view
|
898 |
+
a baseball bat on the left of a baseball glove, front view
|
899 |
+
a baseball glove on the right of a tennis racket, front view
|
900 |
+
a tennis racket on the left of a frisbee, front view
|
901 |
+
a toilet on the left of a hair drier, front view
|
902 |
+
a hair drier on the right of a toothbrush, front view
|
903 |
+
a toothbrush on the left of a sink, front view
|
904 |
+
a sink on the right of a toilet, front view
|
905 |
+
a chair on the left of a couch, front view
|
906 |
+
a couch on the right of a bed, front view
|
907 |
+
a bed on the left of a tv, front view
|
908 |
+
a tv on the right of a dining table, front view
|
909 |
+
a dining table on the left of a chair, front view
|
910 |
+
an airplane on the left of a train, front view
|
911 |
+
a train on the right of a boat, front view
|
912 |
+
a boat on the left of an airplane, front view
|
913 |
+
an oven on the top of a toaster, front view
|
914 |
+
an oven on the bottom of a toaster, front view
|
915 |
+
a toaster on the top of a microwave, front view
|
916 |
+
a toaster on the bottom of a microwave, front view
|
917 |
+
a microwave on the top of an oven, front view
|
918 |
+
a microwave on the bottom of an oven, front view
|
919 |
+
a banana on the top of an apple, front view
|
920 |
+
a banana on the bottom of an apple, front view
|
921 |
+
an apple on the top of a sandwich, front view
|
922 |
+
an apple on the bottom of a sandwich, front view
|
923 |
+
a sandwich on the top of an orange, front view
|
924 |
+
a sandwich on the bottom of an orange, front view
|
925 |
+
an orange on the top of a carrot, front view
|
926 |
+
an orange on the bottom of a carrot, front view
|
927 |
+
a carrot on the top of a hot dog, front view
|
928 |
+
a carrot on the bottom of a hot dog, front view
|
929 |
+
a hot dog on the top of a pizza, front view
|
930 |
+
a hot dog on the bottom of a pizza, front view
|
931 |
+
a pizza on the top of a donut, front view
|
932 |
+
a pizza on the bottom of a donut, front view
|
933 |
+
a donut on the top of broccoli, front view
|
934 |
+
a donut on the bottom of broccoli, front view
|
935 |
+
broccoli on the top of a banana, front view
|
936 |
+
broccoli on the bottom of a banana, front view
|
937 |
+
skis on the top of a snowboard, front view
|
938 |
+
skis on the bottom of a snowboard, front view
|
939 |
+
a snowboard on the top of a kite, front view
|
940 |
+
a snowboard on the bottom of a kite, front view
|
941 |
+
a kite on the top of a skateboard, front view
|
942 |
+
a kite on the bottom of a skateboard, front view
|
943 |
+
a skateboard on the top of a surfboard, front view
|
944 |
+
a skateboard on the bottom of a surfboard, front view
|
945 |
+
a surfboard on the top of skis, front view
|
946 |
+
a surfboard on the bottom of skis, front view
|
prompts/vbench/all_dimension_extended.txt
ADDED
The diff for this file is too large to render.
See raw diff
|
|
requirements.txt
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch>=2.4.0
|
2 |
+
torchvision>=0.19.0
|
3 |
+
opencv-python>=4.9.0.80
|
4 |
+
diffusers==0.31.0
|
5 |
+
transformers>=4.49.0
|
6 |
+
tokenizers>=0.20.3
|
7 |
+
accelerate>=1.1.1
|
8 |
+
tqdm
|
9 |
+
imageio
|
10 |
+
easydict
|
11 |
+
ftfy
|
12 |
+
dashscope
|
13 |
+
imageio-ffmpeg
|
14 |
+
numpy==1.24.4
|
15 |
+
wandb
|
16 |
+
omegaconf
|
17 |
+
einops
|
18 |
+
av==13.1.0
|
19 |
+
opencv-python
|
20 |
+
git+https://github.com/openai/CLIP.git
|
21 |
+
open_clip_torch
|
22 |
+
starlette
|
23 |
+
pycocotools
|
24 |
+
lmdb
|
25 |
+
matplotlib
|
26 |
+
sentencepiece
|
27 |
+
pydantic==2.10.6
|
28 |
+
scikit-image
|
29 |
+
huggingface_hub[cli]
|
30 |
+
dominate
|
31 |
+
nvidia-pyindex
|
32 |
+
nvidia-tensorrt
|
33 |
+
pycuda
|
34 |
+
onnx
|
35 |
+
onnxruntime
|
36 |
+
onnxscript
|
37 |
+
onnxconverter_common
|
38 |
+
flask
|
39 |
+
flask-socketio
|
40 |
+
torchao
|
scripts/create_lmdb_14b_shards.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
python create_lmdb_14b_shards.py \
|
3 |
+
--data_path /mnt/localssd/wanx_14b_data \
|
4 |
+
--lmdb_path /mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb
|
5 |
+
"""
|
6 |
+
from tqdm import tqdm
|
7 |
+
import numpy as np
|
8 |
+
import argparse
|
9 |
+
import torch
|
10 |
+
import lmdb
|
11 |
+
import glob
|
12 |
+
import os
|
13 |
+
|
14 |
+
from utils.lmdb import store_arrays_to_lmdb, process_data_dict
|
15 |
+
|
16 |
+
|
17 |
+
def main():
|
18 |
+
"""
|
19 |
+
Aggregate all ode pairs inside a folder into a lmdb dataset.
|
20 |
+
Each pt file should contain a (key, value) pair representing a
|
21 |
+
video's ODE trajectories.
|
22 |
+
"""
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument("--data_path", type=str,
|
25 |
+
required=True, help="path to ode pairs")
|
26 |
+
parser.add_argument("--lmdb_path", type=str,
|
27 |
+
required=True, help="path to lmdb")
|
28 |
+
parser.add_argument("--num_shards", type=int,
|
29 |
+
default=16, help="num_shards")
|
30 |
+
|
31 |
+
args = parser.parse_args()
|
32 |
+
|
33 |
+
all_dirs = sorted(os.listdir(args.data_path))
|
34 |
+
|
35 |
+
# figure out the maximum map size needed
|
36 |
+
map_size = int(1e12) # adapt to your need, set to 1TB by default
|
37 |
+
os.makedirs(args.lmdb_path, exist_ok=True)
|
38 |
+
# 1) Open one LMDB env per shard
|
39 |
+
envs = []
|
40 |
+
num_shards = args.num_shards
|
41 |
+
for shard_id in range(num_shards):
|
42 |
+
print("shard_id ", shard_id)
|
43 |
+
path = os.path.join(args.lmdb_path, f"shard_{shard_id}")
|
44 |
+
env = lmdb.open(path,
|
45 |
+
map_size=map_size,
|
46 |
+
subdir=True, # set to True if you want a directory per env
|
47 |
+
readonly=False,
|
48 |
+
metasync=True,
|
49 |
+
sync=True,
|
50 |
+
lock=True,
|
51 |
+
readahead=False,
|
52 |
+
meminit=False)
|
53 |
+
envs.append(env)
|
54 |
+
|
55 |
+
counters = [0] * num_shards
|
56 |
+
seen_prompts = set() # for deduplication
|
57 |
+
total_samples = 0
|
58 |
+
all_files = []
|
59 |
+
|
60 |
+
for part_dir in all_dirs:
|
61 |
+
all_files += sorted(glob.glob(os.path.join(args.data_path, part_dir, "*.pt")))
|
62 |
+
|
63 |
+
# 2) Prepare a write transaction for each shard
|
64 |
+
for idx, file in tqdm(enumerate(all_files)):
|
65 |
+
try:
|
66 |
+
data_dict = torch.load(file)
|
67 |
+
data_dict = process_data_dict(data_dict, seen_prompts)
|
68 |
+
except Exception as e:
|
69 |
+
print(f"Error processing {file}: {e}")
|
70 |
+
continue
|
71 |
+
|
72 |
+
if data_dict["latents"].shape != (1, 21, 16, 60, 104):
|
73 |
+
continue
|
74 |
+
|
75 |
+
shard_id = idx % num_shards
|
76 |
+
# write to lmdb file
|
77 |
+
store_arrays_to_lmdb(envs[shard_id], data_dict, start_index=counters[shard_id])
|
78 |
+
counters[shard_id] += len(data_dict['prompts'])
|
79 |
+
data_shape = data_dict["latents"].shape
|
80 |
+
|
81 |
+
total_samples += len(all_files)
|
82 |
+
|
83 |
+
print(len(seen_prompts))
|
84 |
+
|
85 |
+
# save each entry's shape to lmdb
|
86 |
+
for shard_id, env in enumerate(envs):
|
87 |
+
with env.begin(write=True) as txn:
|
88 |
+
for key, val in (data_dict.items()):
|
89 |
+
assert len(data_shape) == 5
|
90 |
+
array_shape = np.array(data_shape) # val.shape)
|
91 |
+
array_shape[0] = counters[shard_id]
|
92 |
+
shape_key = f"{key}_shape".encode()
|
93 |
+
print(shape_key, array_shape)
|
94 |
+
shape_str = " ".join(map(str, array_shape))
|
95 |
+
txn.put(shape_key, shape_str.encode())
|
96 |
+
|
97 |
+
print(f"Finished writing {total_samples} examples into {num_shards} shards under {args.lmdb_path}")
|
98 |
+
|
99 |
+
|
100 |
+
if __name__ == "__main__":
|
101 |
+
main()
|
scripts/create_lmdb_iterative.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import numpy as np
|
3 |
+
import argparse
|
4 |
+
import torch
|
5 |
+
import lmdb
|
6 |
+
import glob
|
7 |
+
import os
|
8 |
+
|
9 |
+
from utils.lmdb import store_arrays_to_lmdb, process_data_dict
|
10 |
+
|
11 |
+
|
12 |
+
def main():
|
13 |
+
"""
|
14 |
+
Aggregate all ode pairs inside a folder into a lmdb dataset.
|
15 |
+
Each pt file should contain a (key, value) pair representing a
|
16 |
+
video's ODE trajectories.
|
17 |
+
"""
|
18 |
+
parser = argparse.ArgumentParser()
|
19 |
+
parser.add_argument("--data_path", type=str,
|
20 |
+
required=True, help="path to ode pairs")
|
21 |
+
parser.add_argument("--lmdb_path", type=str,
|
22 |
+
required=True, help="path to lmdb")
|
23 |
+
|
24 |
+
args = parser.parse_args()
|
25 |
+
|
26 |
+
all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))
|
27 |
+
|
28 |
+
# figure out the maximum map size needed
|
29 |
+
total_array_size = 5000000000000 # adapt to your need, set to 5TB by default
|
30 |
+
|
31 |
+
env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)
|
32 |
+
|
33 |
+
counter = 0
|
34 |
+
|
35 |
+
seen_prompts = set() # for deduplication
|
36 |
+
|
37 |
+
for index, file in tqdm(enumerate(all_files)):
|
38 |
+
# read from disk
|
39 |
+
data_dict = torch.load(file)
|
40 |
+
|
41 |
+
data_dict = process_data_dict(data_dict, seen_prompts)
|
42 |
+
|
43 |
+
# write to lmdb file
|
44 |
+
store_arrays_to_lmdb(env, data_dict, start_index=counter)
|
45 |
+
counter += len(data_dict['prompts'])
|
46 |
+
|
47 |
+
# save each entry's shape to lmdb
|
48 |
+
with env.begin(write=True) as txn:
|
49 |
+
for key, val in data_dict.items():
|
50 |
+
print(key, val)
|
51 |
+
array_shape = np.array(val.shape)
|
52 |
+
array_shape[0] = counter
|
53 |
+
|
54 |
+
shape_key = f"{key}_shape".encode()
|
55 |
+
shape_str = " ".join(map(str, array_shape))
|
56 |
+
txn.put(shape_key, shape_str.encode())
|
57 |
+
|
58 |
+
|
59 |
+
if __name__ == "__main__":
|
60 |
+
main()
|
scripts/generate_ode_pairs.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.distributed import launch_distributed_job
|
2 |
+
from utils.scheduler import FlowMatchScheduler
|
3 |
+
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
|
4 |
+
from utils.dataset import TextDataset
|
5 |
+
import torch.distributed as dist
|
6 |
+
from tqdm import tqdm
|
7 |
+
import argparse
|
8 |
+
import torch
|
9 |
+
import math
|
10 |
+
import os
|
11 |
+
|
12 |
+
|
13 |
+
def init_model(device):
|
14 |
+
model = WanDiffusionWrapper().to(device).to(torch.float32)
|
15 |
+
encoder = WanTextEncoder().to(device).to(torch.float32)
|
16 |
+
model.model.requires_grad_(False)
|
17 |
+
|
18 |
+
scheduler = FlowMatchScheduler(
|
19 |
+
shift=8.0, sigma_min=0.0, extra_one_step=True)
|
20 |
+
scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
|
21 |
+
scheduler.sigmas = scheduler.sigmas.to(device)
|
22 |
+
|
23 |
+
sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'
|
24 |
+
|
25 |
+
unconditional_dict = encoder(
|
26 |
+
text_prompts=[sample_neg_prompt]
|
27 |
+
)
|
28 |
+
|
29 |
+
return model, encoder, scheduler, unconditional_dict
|
30 |
+
|
31 |
+
|
32 |
+
def main():
|
33 |
+
parser = argparse.ArgumentParser()
|
34 |
+
parser.add_argument("--local_rank", type=int, default=-1)
|
35 |
+
parser.add_argument("--output_folder", type=str)
|
36 |
+
parser.add_argument("--caption_path", type=str)
|
37 |
+
parser.add_argument("--guidance_scale", type=float, default=6.0)
|
38 |
+
|
39 |
+
args = parser.parse_args()
|
40 |
+
|
41 |
+
# launch_distributed_job()
|
42 |
+
launch_distributed_job()
|
43 |
+
|
44 |
+
device = torch.cuda.current_device()
|
45 |
+
|
46 |
+
torch.set_grad_enabled(False)
|
47 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
48 |
+
torch.backends.cudnn.allow_tf32 = True
|
49 |
+
|
50 |
+
model, encoder, scheduler, unconditional_dict = init_model(device=device)
|
51 |
+
|
52 |
+
dataset = TextDataset(args.caption_path)
|
53 |
+
|
54 |
+
# if global_rank == 0:
|
55 |
+
os.makedirs(args.output_folder, exist_ok=True)
|
56 |
+
|
57 |
+
for index in tqdm(range(int(math.ceil(len(dataset) / dist.get_world_size()))), disable=dist.get_rank() != 0):
|
58 |
+
prompt_index = index * dist.get_world_size() + dist.get_rank()
|
59 |
+
if prompt_index >= len(dataset):
|
60 |
+
continue
|
61 |
+
prompt = dataset[prompt_index]
|
62 |
+
|
63 |
+
conditional_dict = encoder(text_prompts=prompt)
|
64 |
+
|
65 |
+
latents = torch.randn(
|
66 |
+
[1, 21, 16, 60, 104], dtype=torch.float32, device=device
|
67 |
+
)
|
68 |
+
|
69 |
+
noisy_input = []
|
70 |
+
|
71 |
+
for progress_id, t in enumerate(tqdm(scheduler.timesteps)):
|
72 |
+
timestep = t * \
|
73 |
+
torch.ones([1, 21], device=device, dtype=torch.float32)
|
74 |
+
|
75 |
+
noisy_input.append(latents)
|
76 |
+
|
77 |
+
_, x0_pred_cond = model(
|
78 |
+
latents, conditional_dict, timestep
|
79 |
+
)
|
80 |
+
|
81 |
+
_, x0_pred_uncond = model(
|
82 |
+
latents, unconditional_dict, timestep
|
83 |
+
)
|
84 |
+
|
85 |
+
x0_pred = x0_pred_uncond + args.guidance_scale * (
|
86 |
+
x0_pred_cond - x0_pred_uncond
|
87 |
+
)
|
88 |
+
|
89 |
+
flow_pred = model._convert_x0_to_flow_pred(
|
90 |
+
scheduler=scheduler,
|
91 |
+
x0_pred=x0_pred.flatten(0, 1),
|
92 |
+
xt=latents.flatten(0, 1),
|
93 |
+
timestep=timestep.flatten(0, 1)
|
94 |
+
).unflatten(0, x0_pred.shape[:2])
|
95 |
+
|
96 |
+
latents = scheduler.step(
|
97 |
+
flow_pred.flatten(0, 1),
|
98 |
+
scheduler.timesteps[progress_id] * torch.ones(
|
99 |
+
[1, 21], device=device, dtype=torch.long).flatten(0, 1),
|
100 |
+
latents.flatten(0, 1)
|
101 |
+
).unflatten(dim=0, sizes=flow_pred.shape[:2])
|
102 |
+
|
103 |
+
noisy_input.append(latents)
|
104 |
+
|
105 |
+
noisy_inputs = torch.stack(noisy_input, dim=1)
|
106 |
+
|
107 |
+
noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -1]]
|
108 |
+
|
109 |
+
stored_data = noisy_inputs
|
110 |
+
|
111 |
+
torch.save(
|
112 |
+
{prompt: stored_data.cpu().detach()},
|
113 |
+
os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
|
114 |
+
)
|
115 |
+
|
116 |
+
dist.barrier()
|
117 |
+
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
main()
|
setup.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from setuptools import setup, find_packages
|
2 |
+
setup(
|
3 |
+
name="self_forcing",
|
4 |
+
version="0.0.1",
|
5 |
+
packages=find_packages(),
|
6 |
+
)
|
templates/demo.html
ADDED
@@ -0,0 +1,615 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<!DOCTYPE html>
|
2 |
+
<html lang="en">
|
3 |
+
<head>
|
4 |
+
<meta charset="UTF-8">
|
5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
6 |
+
<title>Self Forcing</title>
|
7 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.js"></script>
|
8 |
+
<style>
|
9 |
+
body {
|
10 |
+
font-family: Arial, sans-serif;
|
11 |
+
max-width: 1400px;
|
12 |
+
margin: 0 auto;
|
13 |
+
padding: 20px;
|
14 |
+
background-color: #f5f5f5;
|
15 |
+
}
|
16 |
+
.container {
|
17 |
+
background: white;
|
18 |
+
padding: 20px;
|
19 |
+
border-radius: 10px;
|
20 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
21 |
+
}
|
22 |
+
.main-layout {
|
23 |
+
display: grid;
|
24 |
+
grid-template-columns: 1fr 1fr;
|
25 |
+
gap: 30px;
|
26 |
+
margin-top: 20px;
|
27 |
+
}
|
28 |
+
.left-column {
|
29 |
+
padding-right: 15px;
|
30 |
+
}
|
31 |
+
.right-column {
|
32 |
+
padding-left: 15px;
|
33 |
+
}
|
34 |
+
@media (max-width: 768px) {
|
35 |
+
.main-layout {
|
36 |
+
grid-template-columns: 1fr;
|
37 |
+
gap: 20px;
|
38 |
+
}
|
39 |
+
.left-column, .right-column {
|
40 |
+
padding: 0;
|
41 |
+
}
|
42 |
+
}
|
43 |
+
.controls {
|
44 |
+
margin-bottom: 20px;
|
45 |
+
}
|
46 |
+
.control-group {
|
47 |
+
margin-bottom: 15px;
|
48 |
+
}
|
49 |
+
label {
|
50 |
+
display: block;
|
51 |
+
margin-bottom: 5px;
|
52 |
+
font-weight: bold;
|
53 |
+
}
|
54 |
+
input, textarea, button, select {
|
55 |
+
padding: 8px;
|
56 |
+
border: 1px solid #ddd;
|
57 |
+
border-radius: 4px;
|
58 |
+
}
|
59 |
+
textarea {
|
60 |
+
width: 100%;
|
61 |
+
height: 90px;
|
62 |
+
resize: vertical;
|
63 |
+
}
|
64 |
+
input[type="range"] {
|
65 |
+
width: 200px;
|
66 |
+
}
|
67 |
+
button {
|
68 |
+
background-color: #007bff;
|
69 |
+
color: white;
|
70 |
+
border: none;
|
71 |
+
padding: 10px 20px;
|
72 |
+
cursor: pointer;
|
73 |
+
margin-right: 10px;
|
74 |
+
}
|
75 |
+
button:hover {
|
76 |
+
background-color: #0056b3;
|
77 |
+
}
|
78 |
+
button:disabled {
|
79 |
+
background-color: #6c757d;
|
80 |
+
cursor: not-allowed;
|
81 |
+
}
|
82 |
+
.stop-btn {
|
83 |
+
background-color: #dc3545;
|
84 |
+
}
|
85 |
+
.stop-btn:hover {
|
86 |
+
background-color: #c82333;
|
87 |
+
}
|
88 |
+
.video-container {
|
89 |
+
text-align: center;
|
90 |
+
background: #000;
|
91 |
+
border-radius: 8px;
|
92 |
+
padding: 20px;
|
93 |
+
margin: 20px auto;
|
94 |
+
display: flex;
|
95 |
+
flex-direction: column;
|
96 |
+
align-items: center;
|
97 |
+
justify-content: center;
|
98 |
+
}
|
99 |
+
#videoFrame {
|
100 |
+
max-width: 100%;
|
101 |
+
height: auto;
|
102 |
+
border-radius: 4px;
|
103 |
+
}
|
104 |
+
.progress-container {
|
105 |
+
margin: 20px 0;
|
106 |
+
}
|
107 |
+
.progress-bar {
|
108 |
+
width: 100%;
|
109 |
+
height: 20px;
|
110 |
+
background-color: #e9ecef;
|
111 |
+
border-radius: 10px;
|
112 |
+
overflow: hidden;
|
113 |
+
}
|
114 |
+
.progress-fill {
|
115 |
+
height: 100%;
|
116 |
+
background-color: #007bff;
|
117 |
+
transition: width 0.3s ease;
|
118 |
+
}
|
119 |
+
.status {
|
120 |
+
margin: 10px 0;
|
121 |
+
padding: 10px;
|
122 |
+
border-radius: 4px;
|
123 |
+
}
|
124 |
+
.status.info {
|
125 |
+
background-color: #d1ecf1;
|
126 |
+
color: #0c5460;
|
127 |
+
}
|
128 |
+
.status.error {
|
129 |
+
background-color: #f8d7da;
|
130 |
+
color: #721c24;
|
131 |
+
}
|
132 |
+
.status.success {
|
133 |
+
background-color: #d4edda;
|
134 |
+
color: #155724;
|
135 |
+
}
|
136 |
+
.frame-info {
|
137 |
+
color: #666;
|
138 |
+
font-size: 0.9em;
|
139 |
+
margin-top: 10px;
|
140 |
+
}
|
141 |
+
.buffer-info {
|
142 |
+
background-color: #e3f2fd;
|
143 |
+
padding: 15px;
|
144 |
+
border-radius: 4px;
|
145 |
+
margin: 15px 0;
|
146 |
+
color: #1976d2;
|
147 |
+
}
|
148 |
+
.playback-controls {
|
149 |
+
margin: 15px 0;
|
150 |
+
display: flex;
|
151 |
+
align-items: center;
|
152 |
+
justify-content: center;
|
153 |
+
gap: 10px;
|
154 |
+
}
|
155 |
+
.playback-controls button {
|
156 |
+
margin: 0 5px;
|
157 |
+
padding: 8px 15px;
|
158 |
+
}
|
159 |
+
#playbackSpeed {
|
160 |
+
width: 80px;
|
161 |
+
}
|
162 |
+
.torch-compile-toggle {
|
163 |
+
background-color: #f8f9fa;
|
164 |
+
border: 1px solid #dee2e6;
|
165 |
+
border-radius: 6px;
|
166 |
+
padding: 10px;
|
167 |
+
margin: 0;
|
168 |
+
flex: 1;
|
169 |
+
min-width: 120px;
|
170 |
+
}
|
171 |
+
.torch-compile-toggle label {
|
172 |
+
display: flex;
|
173 |
+
align-items: center;
|
174 |
+
font-weight: bold;
|
175 |
+
color: #495057;
|
176 |
+
margin-bottom: 0;
|
177 |
+
font-size: 0.9em;
|
178 |
+
}
|
179 |
+
.torch-compile-toggle input[type="checkbox"] {
|
180 |
+
transform: scale(1.1);
|
181 |
+
margin-right: 8px;
|
182 |
+
}
|
183 |
+
</style>
|
184 |
+
</head>
|
185 |
+
<body>
|
186 |
+
<div class="container">
|
187 |
+
<h1>🚀 Self Forcing</h1>
|
188 |
+
|
189 |
+
<div class="main-layout">
|
190 |
+
<div class="left-column">
|
191 |
+
<div class="controls">
|
192 |
+
<div class="control-group">
|
193 |
+
<label for="prompt">Prompt (long, detailed prompts work better):</label>
|
194 |
+
<textarea id="prompt" placeholder="Describe the video you want to generate..."></textarea>
|
195 |
+
|
196 |
+
<div style="margin-top: 10px;">
|
197 |
+
<label>Quick Prompts:</label>
|
198 |
+
<div style="display: flex; flex-direction: column; gap: 8px; margin-top: 5px;">
|
199 |
+
<button type="button" onclick="setQuickPrompt('quick-demo-1')" style="background-color: #28a745; font-size: 11px; padding: 8px; width: 100%; text-align: left; white-space: pre-wrap; line-height: 1.3; min-height: 60px; border-radius: 4px; color: white; border: none; cursor: pointer;">A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.</button>
|
200 |
+
<button type="button" onclick="setQuickPrompt('quick-demo-2')" style="background-color: #17a2b8; font-size: 11px; padding: 8px; width: 100%; text-align: left; white-space: pre-wrap; line-height: 1.3; min-height: 60px; border-radius: 4px; color: white; border: none; cursor: pointer;">A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field.</button>
|
201 |
+
</div>
|
202 |
+
</div>
|
203 |
+
</div>
|
204 |
+
|
205 |
+
<div style="display: flex; gap: 20px;">
|
206 |
+
<div class="control-group">
|
207 |
+
<label for="seed">Seed:</label>
|
208 |
+
<input type="number" id="seed" value="-1" min="0" max="999999">
|
209 |
+
</div>
|
210 |
+
|
211 |
+
<div class="control-group">
|
212 |
+
<label for="fps">Target FPS: <span id="fpsValue">6</span></label>
|
213 |
+
<input type="range" id="fps" min="2" max="16" value="6" step="0.5">
|
214 |
+
</div>
|
215 |
+
|
216 |
+
<!-- <div class="control-group">
|
217 |
+
<label for="blocks">Total Blocks: <span id="blocksValue">7</span></label>
|
218 |
+
<input type="range" id="blocks" min="3" max="10" value="7" step="1">
|
219 |
+
</div> -->
|
220 |
+
</div>
|
221 |
+
|
222 |
+
<div class="control-group">
|
223 |
+
<div style="display: flex; gap: 15px; align-items: flex-start; flex-wrap: wrap;">
|
224 |
+
<div class="torch-compile-toggle">
|
225 |
+
<label>
|
226 |
+
<input type="checkbox" id="torchCompile">
|
227 |
+
🔥 torch.compile
|
228 |
+
</label>
|
229 |
+
</div>
|
230 |
+
<div class="torch-compile-toggle">
|
231 |
+
<label>
|
232 |
+
<input type="checkbox" id="fp8Toggle">
|
233 |
+
⚡ FP8 Quantization
|
234 |
+
</label>
|
235 |
+
</div>
|
236 |
+
<div class="torch-compile-toggle">
|
237 |
+
<label>
|
238 |
+
<input type="checkbox" id="taehvToggle">
|
239 |
+
⚡ TAEHV VAE
|
240 |
+
</label>
|
241 |
+
</div>
|
242 |
+
</div>
|
243 |
+
<!-- <div style="font-size: 0.85em; color: #666; margin-top: 5px;">
|
244 |
+
<strong>Note:</strong> torch.compile and FP8 are one-time toggles (cannot be changed once applied)
|
245 |
+
</div> -->
|
246 |
+
</div>
|
247 |
+
|
248 |
+
<div class="control-group">
|
249 |
+
<button id="startBtn" onclick="startGeneration()">🚀 Start Generation</button>
|
250 |
+
<button id="stopBtn" onclick="stopGeneration()" disabled class="stop-btn">⏹️ Stop</button>
|
251 |
+
</div>
|
252 |
+
</div>
|
253 |
+
|
254 |
+
<div class="progress-container">
|
255 |
+
<div class="progress-bar">
|
256 |
+
<div id="progressFill" class="progress-fill" style="width: 0%"></div>
|
257 |
+
</div>
|
258 |
+
<div id="progressText">Ready to generate</div>
|
259 |
+
</div>
|
260 |
+
</div>
|
261 |
+
|
262 |
+
<div class="right-column">
|
263 |
+
<div class="buffer-info">
|
264 |
+
<strong>📦 Frame Buffer:</strong> <span id="bufferCount">0</span> frames ready |
|
265 |
+
<strong>📺 Displayed:</strong> <span id="displayedCount">0</span> frames
|
266 |
+
<!-- <strong>⚡ Receive Rate:</strong> <span id="receiveRate">0</span> fps -->
|
267 |
+
</div>
|
268 |
+
|
269 |
+
<div class="playback-controls">
|
270 |
+
<button id="playBtn" onclick="togglePlayback()" disabled>▶️ Play</button>
|
271 |
+
<button id="resetBtn" onclick="resetPlayback()" disabled>⏮️ Reset</button>
|
272 |
+
<label for="playbackSpeed">Speed:</label>
|
273 |
+
<select id="playbackSpeed" onchange="updatePlaybackSpeed()">
|
274 |
+
<option value="0.25">0.25x</option>
|
275 |
+
<option value="0.5">0.5x</option>
|
276 |
+
<option value="0.75">0.75x</option>
|
277 |
+
<option value="1" selected>1x</option>
|
278 |
+
<option value="1.25">1.25x</option>
|
279 |
+
<option value="1.5">1.5x</option>
|
280 |
+
<option value="2">2x</option>
|
281 |
+
</select>
|
282 |
+
</div>
|
283 |
+
|
284 |
+
<div id="statusContainer"></div>
|
285 |
+
|
286 |
+
<div class="video-container">
|
287 |
+
<img id="videoFrame" src="" alt="Video frames will appear here" style="display: none;">
|
288 |
+
<div id="placeholderText">Click "Start Generation" to begin</div>
|
289 |
+
<div id="frameInfo" class="frame-info"></div>
|
290 |
+
</div>
|
291 |
+
</div>
|
292 |
+
</div>
|
293 |
+
</div>
|
294 |
+
|
295 |
+
<script>
|
296 |
+
const socket = io();
|
297 |
+
let frameBuffer = []; // Store all received frames
|
298 |
+
let currentFrameIndex = 0;
|
299 |
+
let isPlaying = false;
|
300 |
+
let playbackInterval = null;
|
301 |
+
let targetFps = 6;
|
302 |
+
let playbackSpeed = 1.0;
|
303 |
+
let startTime = null;
|
304 |
+
let lastReceiveTime = null;
|
305 |
+
let receiveCount = 0;
|
306 |
+
let receiveRate = 0;
|
307 |
+
|
308 |
+
// State tracking for one-time toggles
|
309 |
+
let torchCompileApplied = false;
|
310 |
+
let fp8Applied = false;
|
311 |
+
|
312 |
+
// Update slider values
|
313 |
+
document.getElementById('fps').oninput = function() {
|
314 |
+
targetFps = parseFloat(this.value);
|
315 |
+
document.getElementById('fpsValue').textContent = this.value;
|
316 |
+
updatePlaybackTiming();
|
317 |
+
};
|
318 |
+
|
319 |
+
// document.getElementById('blocks').oninput = function() {
|
320 |
+
// document.getElementById('blocksValue').textContent = this.value;
|
321 |
+
// };
|
322 |
+
|
323 |
+
// Handle toggle behavior and fetch current status
|
324 |
+
function updateToggleStates() {
|
325 |
+
fetch('/api/status')
|
326 |
+
.then(response => response.json())
|
327 |
+
.then(data => {
|
328 |
+
torchCompileApplied = data.torch_compile_applied;
|
329 |
+
fp8Applied = data.fp8_applied;
|
330 |
+
|
331 |
+
// Update UI based on current state
|
332 |
+
const torchToggle = document.getElementById('torchCompile');
|
333 |
+
const fp8Toggle = document.getElementById('fp8Toggle');
|
334 |
+
const taehvToggle = document.getElementById('taehvToggle');
|
335 |
+
|
336 |
+
// Disable one-time toggles if already applied
|
337 |
+
if (torchCompileApplied) {
|
338 |
+
torchToggle.checked = true;
|
339 |
+
torchToggle.disabled = true;
|
340 |
+
torchToggle.parentElement.style.opacity = '0.6';
|
341 |
+
}
|
342 |
+
|
343 |
+
if (fp8Applied) {
|
344 |
+
fp8Toggle.checked = true;
|
345 |
+
fp8Toggle.disabled = true;
|
346 |
+
fp8Toggle.parentElement.style.opacity = '0.6';
|
347 |
+
}
|
348 |
+
|
349 |
+
// Set TAEHV toggle based on current state
|
350 |
+
taehvToggle.checked = data.current_use_taehv;
|
351 |
+
})
|
352 |
+
.catch(err => console.log('Status check failed:', err));
|
353 |
+
}
|
354 |
+
|
355 |
+
// Handle torch.compile toggle
|
356 |
+
document.getElementById('torchCompile').onchange = function() {
|
357 |
+
if (torchCompileApplied && !this.checked) {
|
358 |
+
this.checked = true; // Prevent unchecking
|
359 |
+
alert('torch.compile cannot be disabled once applied');
|
360 |
+
}
|
361 |
+
};
|
362 |
+
|
363 |
+
// Handle FP8 toggle
|
364 |
+
document.getElementById('fp8Toggle').onchange = function() {
|
365 |
+
if (fp8Applied && !this.checked) {
|
366 |
+
this.checked = true; // Prevent unchecking
|
367 |
+
alert('FP8 quantization cannot be disabled once applied');
|
368 |
+
}
|
369 |
+
};
|
370 |
+
|
371 |
+
// Update toggle states on page load
|
372 |
+
updateToggleStates();
|
373 |
+
|
374 |
+
// Socket event handlers
|
375 |
+
socket.on('connect', function() {
|
376 |
+
// showStatus('Connected to frontend-buffered server', 'info');
|
377 |
+
});
|
378 |
+
|
379 |
+
socket.on('status', function(data) {
|
380 |
+
// showStatus(data.message, 'info');
|
381 |
+
});
|
382 |
+
|
383 |
+
socket.on('progress', function(data) {
|
384 |
+
updateProgress(data.progress, data.message);
|
385 |
+
});
|
386 |
+
|
387 |
+
socket.on('frame_ready', function(data) {
|
388 |
+
// Add frame to buffer immediately
|
389 |
+
frameBuffer.push(data);
|
390 |
+
receiveCount++;
|
391 |
+
|
392 |
+
// Calculate receive rate
|
393 |
+
const now = Date.now();
|
394 |
+
if (lastReceiveTime) {
|
395 |
+
const interval = (now - lastReceiveTime) / 1000;
|
396 |
+
receiveRate = (1 / interval).toFixed(1);
|
397 |
+
}
|
398 |
+
lastReceiveTime = now;
|
399 |
+
|
400 |
+
updateBufferInfo();
|
401 |
+
|
402 |
+
// Auto-start playback when we have some frames
|
403 |
+
if (frameBuffer.length === 5 && !isPlaying) {
|
404 |
+
// showStatus('Auto-starting playback with buffer of 5 frames', 'info');
|
405 |
+
startPlayback();
|
406 |
+
}
|
407 |
+
});
|
408 |
+
|
409 |
+
socket.on('generation_complete', function(data) {
|
410 |
+
// showStatus(data.message + ` (Generated in ${data.generation_time})`, 'success');
|
411 |
+
enableControls(true);
|
412 |
+
const duration = startTime ? ((Date.now() - startTime) / 1000).toFixed(1) : 'unknown';
|
413 |
+
updateFrameInfo(`Generation complete! ${data.total_frames} frames in ${duration}s`);
|
414 |
+
|
415 |
+
// Update toggle states after generation
|
416 |
+
updateToggleStates();
|
417 |
+
});
|
418 |
+
|
419 |
+
socket.on('error', function(data) {
|
420 |
+
// showStatus(`Error: ${data.message}`, 'error');
|
421 |
+
enableControls(true);
|
422 |
+
});
|
423 |
+
|
424 |
+
function startGeneration() {
|
425 |
+
const prompt = document.getElementById('prompt').value.trim();
|
426 |
+
if (!prompt) {
|
427 |
+
alert('Please enter a prompt');
|
428 |
+
return;
|
429 |
+
}
|
430 |
+
|
431 |
+
const seed = parseInt(document.getElementById('seed').value) || 31337;
|
432 |
+
// const totalBlocks = parseInt(document.getElementById('blocks').value) || 7;
|
433 |
+
const enableTorchCompile = document.getElementById('torchCompile').checked;
|
434 |
+
const enableFp8 = document.getElementById('fp8Toggle').checked;
|
435 |
+
const useTaehv = document.getElementById('taehvToggle').checked;
|
436 |
+
|
437 |
+
// Reset state
|
438 |
+
frameBuffer = [];
|
439 |
+
currentFrameIndex = 0;
|
440 |
+
receiveCount = 0;
|
441 |
+
receiveRate = 0;
|
442 |
+
stopPlayback();
|
443 |
+
|
444 |
+
enableControls(false);
|
445 |
+
startTime = Date.now();
|
446 |
+
|
447 |
+
socket.emit('start_generation', {
|
448 |
+
prompt: prompt,
|
449 |
+
seed: seed,
|
450 |
+
enable_torch_compile: enableTorchCompile,
|
451 |
+
enable_fp8: enableFp8,
|
452 |
+
use_taehv: useTaehv
|
453 |
+
});
|
454 |
+
}
|
455 |
+
|
456 |
+
function stopGeneration() {
|
457 |
+
socket.emit('stop_generation');
|
458 |
+
enableControls(true);
|
459 |
+
}
|
460 |
+
|
461 |
+
function togglePlayback() {
|
462 |
+
if (isPlaying) {
|
463 |
+
stopPlayback();
|
464 |
+
} else {
|
465 |
+
startPlayback();
|
466 |
+
}
|
467 |
+
}
|
468 |
+
|
469 |
+
function startPlayback() {
|
470 |
+
if (frameBuffer.length === 0) return;
|
471 |
+
|
472 |
+
isPlaying = true;
|
473 |
+
document.getElementById('playBtn').textContent = '⏸️ Pause';
|
474 |
+
document.getElementById('playBtn').disabled = false;
|
475 |
+
document.getElementById('resetBtn').disabled = false;
|
476 |
+
|
477 |
+
updatePlaybackTiming();
|
478 |
+
// showStatus('Playback started', 'info');
|
479 |
+
}
|
480 |
+
|
481 |
+
function stopPlayback() {
|
482 |
+
isPlaying = false;
|
483 |
+
if (playbackInterval) {
|
484 |
+
clearInterval(playbackInterval);
|
485 |
+
playbackInterval = null;
|
486 |
+
}
|
487 |
+
document.getElementById('playBtn').textContent = '▶️ Play';
|
488 |
+
}
|
489 |
+
|
490 |
+
function resetPlayback() {
|
491 |
+
stopPlayback();
|
492 |
+
|
493 |
+
// Clear the entire frame buffer
|
494 |
+
frameBuffer = [];
|
495 |
+
currentFrameIndex = 0;
|
496 |
+
receiveCount = 0;
|
497 |
+
receiveRate = 0;
|
498 |
+
|
499 |
+
// Reset video display to initial state
|
500 |
+
const img = document.getElementById('videoFrame');
|
501 |
+
const placeholder = document.getElementById('placeholderText');
|
502 |
+
|
503 |
+
img.src = '';
|
504 |
+
img.style.display = 'none';
|
505 |
+
placeholder.style.display = 'block';
|
506 |
+
|
507 |
+
// Update UI
|
508 |
+
updateBufferInfo();
|
509 |
+
updateFrameInfo('Reset - buffer cleared');
|
510 |
+
|
511 |
+
// Disable playback controls since there's no content
|
512 |
+
document.getElementById('playBtn').disabled = true;
|
513 |
+
document.getElementById('resetBtn').disabled = true;
|
514 |
+
}
|
515 |
+
|
516 |
+
function updatePlaybackSpeed() {
|
517 |
+
playbackSpeed = parseFloat(document.getElementById('playbackSpeed').value);
|
518 |
+
if (isPlaying) {
|
519 |
+
updatePlaybackTiming();
|
520 |
+
}
|
521 |
+
}
|
522 |
+
|
523 |
+
function updatePlaybackTiming() {
|
524 |
+
if (playbackInterval) {
|
525 |
+
clearInterval(playbackInterval);
|
526 |
+
}
|
527 |
+
|
528 |
+
if (isPlaying) {
|
529 |
+
const interval = (1000 / targetFps) / playbackSpeed;
|
530 |
+
playbackInterval = setInterval(displayNextFrame, interval);
|
531 |
+
}
|
532 |
+
}
|
533 |
+
|
534 |
+
function displayNextFrame() {
|
535 |
+
if (currentFrameIndex >= frameBuffer.length) {
|
536 |
+
// Reached end of buffer
|
537 |
+
if (document.querySelector('#progressFill').style.width === '100%') {
|
538 |
+
// Generation complete, stop playback
|
539 |
+
stopPlayback();
|
540 |
+
// showStatus('Playback complete', 'success');
|
541 |
+
}
|
542 |
+
return;
|
543 |
+
}
|
544 |
+
|
545 |
+
const frameData = frameBuffer[currentFrameIndex];
|
546 |
+
displayFrame(frameData);
|
547 |
+
currentFrameIndex++;
|
548 |
+
|
549 |
+
updateBufferInfo();
|
550 |
+
}
|
551 |
+
|
552 |
+
function displayFrame(frameData) {
|
553 |
+
const img = document.getElementById('videoFrame');
|
554 |
+
const placeholder = document.getElementById('placeholderText');
|
555 |
+
|
556 |
+
img.src = frameData.data;
|
557 |
+
img.style.display = 'block';
|
558 |
+
placeholder.style.display = 'none';
|
559 |
+
|
560 |
+
const elapsed = startTime ? ((Date.now() - startTime) / 1000).toFixed(1) : '0';
|
561 |
+
updateFrameInfo(`Frame ${frameData.frame_index + 1} | Block ${frameData.block_index + 1} | ${elapsed}s elapsed | ${targetFps} FPS @ ${playbackSpeed}x speed`);
|
562 |
+
}
|
563 |
+
|
564 |
+
function updateBufferInfo() {
|
565 |
+
document.getElementById('bufferCount').textContent = frameBuffer.length;
|
566 |
+
document.getElementById('displayedCount').textContent = currentFrameIndex;
|
567 |
+
// document.getElementById('receiveRate').textContent = receiveRate;
|
568 |
+
}
|
569 |
+
|
570 |
+
function setQuickPrompt(type) {
|
571 |
+
const promptBox = document.getElementById('prompt');
|
572 |
+
if (type === 'quick-demo-1') {
|
573 |
+
promptBox.value = 'A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.';
|
574 |
+
} else if (type === 'quick-demo-2') {
|
575 |
+
promptBox.value = 'A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field.';
|
576 |
+
}
|
577 |
+
}
|
578 |
+
|
579 |
+
function enableControls(enabled) {
|
580 |
+
document.getElementById('startBtn').disabled = !enabled;
|
581 |
+
document.getElementById('stopBtn').disabled = enabled;
|
582 |
+
}
|
583 |
+
|
584 |
+
function updateProgress(progress, message) {
|
585 |
+
document.getElementById('progressFill').style.width = progress + '%';
|
586 |
+
document.getElementById('progressText').textContent = message;
|
587 |
+
}
|
588 |
+
|
589 |
+
function updateFrameInfo(text) {
|
590 |
+
document.getElementById('frameInfo').textContent = text;
|
591 |
+
}
|
592 |
+
|
593 |
+
function showStatus(message, type) {
|
594 |
+
const container = document.getElementById('statusContainer');
|
595 |
+
const statusDiv = document.createElement('div');
|
596 |
+
statusDiv.className = `status ${type}`;
|
597 |
+
statusDiv.textContent = message;
|
598 |
+
|
599 |
+
container.insertBefore(statusDiv, container.firstChild);
|
600 |
+
|
601 |
+
// Remove old status messages (keep only last 3)
|
602 |
+
while (container.children.length > 3) {
|
603 |
+
container.removeChild(container.lastChild);
|
604 |
+
}
|
605 |
+
|
606 |
+
// Auto-remove after 5 seconds
|
607 |
+
setTimeout(() => {
|
608 |
+
if (statusDiv.parentNode) {
|
609 |
+
statusDiv.parentNode.removeChild(statusDiv);
|
610 |
+
}
|
611 |
+
}, 5000);
|
612 |
+
}
|
613 |
+
</script>
|
614 |
+
</body>
|
615 |
+
</html>
|
train.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
from omegaconf import OmegaConf
|
4 |
+
import wandb
|
5 |
+
|
6 |
+
from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer
|
7 |
+
|
8 |
+
|
9 |
+
def main():
|
10 |
+
parser = argparse.ArgumentParser()
|
11 |
+
parser.add_argument("--config_path", type=str, required=True)
|
12 |
+
parser.add_argument("--no_save", action="store_true")
|
13 |
+
parser.add_argument("--no_visualize", action="store_true")
|
14 |
+
parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
|
15 |
+
parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
|
16 |
+
parser.add_argument("--disable-wandb", action="store_true")
|
17 |
+
|
18 |
+
args = parser.parse_args()
|
19 |
+
|
20 |
+
config = OmegaConf.load(args.config_path)
|
21 |
+
default_config = OmegaConf.load("configs/default_config.yaml")
|
22 |
+
config = OmegaConf.merge(default_config, config)
|
23 |
+
config.no_save = args.no_save
|
24 |
+
config.no_visualize = args.no_visualize
|
25 |
+
|
26 |
+
# get the filename of config_path
|
27 |
+
config_name = os.path.basename(args.config_path).split(".")[0]
|
28 |
+
config.config_name = config_name
|
29 |
+
config.logdir = args.logdir
|
30 |
+
config.wandb_save_dir = args.wandb_save_dir
|
31 |
+
config.disable_wandb = args.disable_wandb
|
32 |
+
|
33 |
+
if config.trainer == "diffusion":
|
34 |
+
trainer = DiffusionTrainer(config)
|
35 |
+
elif config.trainer == "gan":
|
36 |
+
trainer = GANTrainer(config)
|
37 |
+
elif config.trainer == "ode":
|
38 |
+
trainer = ODETrainer(config)
|
39 |
+
elif config.trainer == "score_distillation":
|
40 |
+
trainer = ScoreDistillationTrainer(config)
|
41 |
+
trainer.train()
|
42 |
+
|
43 |
+
wandb.finish()
|
44 |
+
|
45 |
+
|
46 |
+
if __name__ == "__main__":
|
47 |
+
main()
|
trainer/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .diffusion import Trainer as DiffusionTrainer
|
2 |
+
from .gan import Trainer as GANTrainer
|
3 |
+
from .ode import Trainer as ODETrainer
|
4 |
+
from .distillation import Trainer as ScoreDistillationTrainer
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
"DiffusionTrainer",
|
8 |
+
"GANTrainer",
|
9 |
+
"ODETrainer",
|
10 |
+
"ScoreDistillationTrainer"
|
11 |
+
]
|
trainer/diffusion.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from model import CausalDiffusion
|
5 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
6 |
+
from utils.misc import set_seed
|
7 |
+
import torch.distributed as dist
|
8 |
+
from omegaconf import OmegaConf
|
9 |
+
import torch
|
10 |
+
import wandb
|
11 |
+
import time
|
12 |
+
import os
|
13 |
+
|
14 |
+
from utils.distributed import EMA_FSDP, barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
15 |
+
|
16 |
+
|
17 |
+
class Trainer:
|
18 |
+
def __init__(self, config):
|
19 |
+
self.config = config
|
20 |
+
self.step = 0
|
21 |
+
|
22 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
23 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
24 |
+
torch.backends.cudnn.allow_tf32 = True
|
25 |
+
|
26 |
+
launch_distributed_job()
|
27 |
+
global_rank = dist.get_rank()
|
28 |
+
|
29 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
30 |
+
self.device = torch.cuda.current_device()
|
31 |
+
self.is_main_process = global_rank == 0
|
32 |
+
self.causal = config.causal
|
33 |
+
self.disable_wandb = config.disable_wandb
|
34 |
+
|
35 |
+
# use a random seed for the training
|
36 |
+
if config.seed == 0:
|
37 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
38 |
+
dist.broadcast(random_seed, src=0)
|
39 |
+
config.seed = random_seed.item()
|
40 |
+
|
41 |
+
set_seed(config.seed + global_rank)
|
42 |
+
|
43 |
+
if self.is_main_process and not self.disable_wandb:
|
44 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
45 |
+
wandb.init(
|
46 |
+
config=OmegaConf.to_container(config, resolve=True),
|
47 |
+
name=config.config_name,
|
48 |
+
mode="online",
|
49 |
+
entity=config.wandb_entity,
|
50 |
+
project=config.wandb_project,
|
51 |
+
dir=config.wandb_save_dir
|
52 |
+
)
|
53 |
+
|
54 |
+
self.output_path = config.logdir
|
55 |
+
|
56 |
+
# Step 2: Initialize the model and optimizer
|
57 |
+
self.model = CausalDiffusion(config, device=self.device)
|
58 |
+
self.model.generator = fsdp_wrap(
|
59 |
+
self.model.generator,
|
60 |
+
sharding_strategy=config.sharding_strategy,
|
61 |
+
mixed_precision=config.mixed_precision,
|
62 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
63 |
+
)
|
64 |
+
|
65 |
+
self.model.text_encoder = fsdp_wrap(
|
66 |
+
self.model.text_encoder,
|
67 |
+
sharding_strategy=config.sharding_strategy,
|
68 |
+
mixed_precision=config.mixed_precision,
|
69 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy
|
70 |
+
)
|
71 |
+
|
72 |
+
if not config.no_visualize or config.load_raw_video:
|
73 |
+
self.model.vae = self.model.vae.to(
|
74 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
75 |
+
|
76 |
+
self.generator_optimizer = torch.optim.AdamW(
|
77 |
+
[param for param in self.model.generator.parameters()
|
78 |
+
if param.requires_grad],
|
79 |
+
lr=config.lr,
|
80 |
+
betas=(config.beta1, config.beta2),
|
81 |
+
weight_decay=config.weight_decay
|
82 |
+
)
|
83 |
+
|
84 |
+
# Step 3: Initialize the dataloader
|
85 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
86 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
87 |
+
dataset, shuffle=True, drop_last=True)
|
88 |
+
dataloader = torch.utils.data.DataLoader(
|
89 |
+
dataset,
|
90 |
+
batch_size=config.batch_size,
|
91 |
+
sampler=sampler,
|
92 |
+
num_workers=8)
|
93 |
+
|
94 |
+
if dist.get_rank() == 0:
|
95 |
+
print("DATASET SIZE %d" % len(dataset))
|
96 |
+
self.dataloader = cycle(dataloader)
|
97 |
+
|
98 |
+
##############################################################################################################
|
99 |
+
# 6. Set up EMA parameter containers
|
100 |
+
rename_param = (
|
101 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
102 |
+
.replace("_checkpoint_wrapped_module.", "")
|
103 |
+
.replace("_orig_mod.", "")
|
104 |
+
)
|
105 |
+
self.name_to_trainable_params = {}
|
106 |
+
for n, p in self.model.generator.named_parameters():
|
107 |
+
if not p.requires_grad:
|
108 |
+
continue
|
109 |
+
|
110 |
+
renamed_n = rename_param(n)
|
111 |
+
self.name_to_trainable_params[renamed_n] = p
|
112 |
+
ema_weight = config.ema_weight
|
113 |
+
self.generator_ema = None
|
114 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
115 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
116 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
117 |
+
|
118 |
+
##############################################################################################################
|
119 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
120 |
+
if getattr(config, "generator_ckpt", False):
|
121 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
122 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
123 |
+
if "generator" in state_dict:
|
124 |
+
state_dict = state_dict["generator"]
|
125 |
+
elif "model" in state_dict:
|
126 |
+
state_dict = state_dict["model"]
|
127 |
+
self.model.generator.load_state_dict(
|
128 |
+
state_dict, strict=True
|
129 |
+
)
|
130 |
+
|
131 |
+
##############################################################################################################
|
132 |
+
|
133 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
134 |
+
if self.step < config.ema_start_step:
|
135 |
+
self.generator_ema = None
|
136 |
+
|
137 |
+
self.max_grad_norm = 10.0
|
138 |
+
self.previous_time = None
|
139 |
+
|
140 |
+
def save(self):
|
141 |
+
print("Start gathering distributed model states...")
|
142 |
+
generator_state_dict = fsdp_state_dict(
|
143 |
+
self.model.generator)
|
144 |
+
|
145 |
+
if self.config.ema_start_step < self.step:
|
146 |
+
state_dict = {
|
147 |
+
"generator": generator_state_dict,
|
148 |
+
"generator_ema": self.generator_ema.state_dict(),
|
149 |
+
}
|
150 |
+
else:
|
151 |
+
state_dict = {
|
152 |
+
"generator": generator_state_dict,
|
153 |
+
}
|
154 |
+
|
155 |
+
if self.is_main_process:
|
156 |
+
os.makedirs(os.path.join(self.output_path,
|
157 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
158 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
159 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
160 |
+
print("Model saved to", os.path.join(self.output_path,
|
161 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
162 |
+
|
163 |
+
def train_one_step(self, batch):
|
164 |
+
self.log_iters = 1
|
165 |
+
|
166 |
+
if self.step % 20 == 0:
|
167 |
+
torch.cuda.empty_cache()
|
168 |
+
|
169 |
+
# Step 1: Get the next batch of text prompts
|
170 |
+
text_prompts = batch["prompts"]
|
171 |
+
if not self.config.load_raw_video: # precomputed latent
|
172 |
+
clean_latent = batch["ode_latent"][:, -1].to(
|
173 |
+
device=self.device, dtype=self.dtype)
|
174 |
+
else: # encode raw video to latent
|
175 |
+
frames = batch["frames"].to(
|
176 |
+
device=self.device, dtype=self.dtype)
|
177 |
+
with torch.no_grad():
|
178 |
+
clean_latent = self.model.vae.encode_to_latent(
|
179 |
+
frames).to(device=self.device, dtype=self.dtype)
|
180 |
+
image_latent = clean_latent[:, 0:1, ]
|
181 |
+
|
182 |
+
batch_size = len(text_prompts)
|
183 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
184 |
+
image_or_video_shape[0] = batch_size
|
185 |
+
|
186 |
+
# Step 2: Extract the conditional infos
|
187 |
+
with torch.no_grad():
|
188 |
+
conditional_dict = self.model.text_encoder(
|
189 |
+
text_prompts=text_prompts)
|
190 |
+
|
191 |
+
if not getattr(self, "unconditional_dict", None):
|
192 |
+
unconditional_dict = self.model.text_encoder(
|
193 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
194 |
+
unconditional_dict = {k: v.detach()
|
195 |
+
for k, v in unconditional_dict.items()}
|
196 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
197 |
+
else:
|
198 |
+
unconditional_dict = self.unconditional_dict
|
199 |
+
|
200 |
+
# Step 3: Train the generator
|
201 |
+
generator_loss, log_dict = self.model.generator_loss(
|
202 |
+
image_or_video_shape=image_or_video_shape,
|
203 |
+
conditional_dict=conditional_dict,
|
204 |
+
unconditional_dict=unconditional_dict,
|
205 |
+
clean_latent=clean_latent,
|
206 |
+
initial_latent=image_latent
|
207 |
+
)
|
208 |
+
self.generator_optimizer.zero_grad()
|
209 |
+
generator_loss.backward()
|
210 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
211 |
+
self.max_grad_norm)
|
212 |
+
self.generator_optimizer.step()
|
213 |
+
|
214 |
+
# Increment the step since we finished gradient update
|
215 |
+
self.step += 1
|
216 |
+
|
217 |
+
wandb_loss_dict = {
|
218 |
+
"generator_loss": generator_loss.item(),
|
219 |
+
"generator_grad_norm": generator_grad_norm.item(),
|
220 |
+
}
|
221 |
+
|
222 |
+
# Step 4: Logging
|
223 |
+
if self.is_main_process:
|
224 |
+
if not self.disable_wandb:
|
225 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
226 |
+
|
227 |
+
if self.step % self.config.gc_interval == 0:
|
228 |
+
if dist.get_rank() == 0:
|
229 |
+
logging.info("DistGarbageCollector: Running GC.")
|
230 |
+
gc.collect()
|
231 |
+
|
232 |
+
# Step 5. Create EMA params
|
233 |
+
# TODO: Implement EMA
|
234 |
+
|
235 |
+
def generate_video(self, pipeline, prompts, image=None):
|
236 |
+
batch_size = len(prompts)
|
237 |
+
sampled_noise = torch.randn(
|
238 |
+
[batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
|
239 |
+
)
|
240 |
+
video, _ = pipeline.inference(
|
241 |
+
noise=sampled_noise,
|
242 |
+
text_prompts=prompts,
|
243 |
+
return_latents=True
|
244 |
+
)
|
245 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
246 |
+
return current_video
|
247 |
+
|
248 |
+
def train(self):
|
249 |
+
while True:
|
250 |
+
batch = next(self.dataloader)
|
251 |
+
self.train_one_step(batch)
|
252 |
+
if (not self.config.no_save) and self.step % self.config.log_iters == 0:
|
253 |
+
torch.cuda.empty_cache()
|
254 |
+
self.save()
|
255 |
+
torch.cuda.empty_cache()
|
256 |
+
|
257 |
+
barrier()
|
258 |
+
if self.is_main_process:
|
259 |
+
current_time = time.time()
|
260 |
+
if self.previous_time is None:
|
261 |
+
self.previous_time = current_time
|
262 |
+
else:
|
263 |
+
if not self.disable_wandb:
|
264 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
265 |
+
self.previous_time = current_time
|
trainer/distillation.py
ADDED
@@ -0,0 +1,388 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
5 |
+
from utils.dataset import TextDataset
|
6 |
+
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
7 |
+
from utils.misc import (
|
8 |
+
set_seed,
|
9 |
+
merge_dict_list
|
10 |
+
)
|
11 |
+
import torch.distributed as dist
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from model import CausVid, DMD, SiD
|
14 |
+
import torch
|
15 |
+
import wandb
|
16 |
+
import time
|
17 |
+
import os
|
18 |
+
|
19 |
+
|
20 |
+
class Trainer:
|
21 |
+
def __init__(self, config):
|
22 |
+
self.config = config
|
23 |
+
self.step = 0
|
24 |
+
|
25 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
26 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
27 |
+
torch.backends.cudnn.allow_tf32 = True
|
28 |
+
|
29 |
+
launch_distributed_job()
|
30 |
+
global_rank = dist.get_rank()
|
31 |
+
self.world_size = dist.get_world_size()
|
32 |
+
|
33 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
34 |
+
self.device = torch.cuda.current_device()
|
35 |
+
self.is_main_process = global_rank == 0
|
36 |
+
self.causal = config.causal
|
37 |
+
self.disable_wandb = config.disable_wandb
|
38 |
+
|
39 |
+
# use a random seed for the training
|
40 |
+
if config.seed == 0:
|
41 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
42 |
+
dist.broadcast(random_seed, src=0)
|
43 |
+
config.seed = random_seed.item()
|
44 |
+
|
45 |
+
set_seed(config.seed + global_rank)
|
46 |
+
|
47 |
+
if self.is_main_process and not self.disable_wandb:
|
48 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
49 |
+
wandb.init(
|
50 |
+
config=OmegaConf.to_container(config, resolve=True),
|
51 |
+
name=config.config_name,
|
52 |
+
mode="online",
|
53 |
+
entity=config.wandb_entity,
|
54 |
+
project=config.wandb_project,
|
55 |
+
dir=config.wandb_save_dir
|
56 |
+
)
|
57 |
+
|
58 |
+
self.output_path = config.logdir
|
59 |
+
|
60 |
+
# Step 2: Initialize the model and optimizer
|
61 |
+
if config.distribution_loss == "causvid":
|
62 |
+
self.model = CausVid(config, device=self.device)
|
63 |
+
elif config.distribution_loss == "dmd":
|
64 |
+
self.model = DMD(config, device=self.device)
|
65 |
+
elif config.distribution_loss == "sid":
|
66 |
+
self.model = SiD(config, device=self.device)
|
67 |
+
else:
|
68 |
+
raise ValueError("Invalid distribution matching loss")
|
69 |
+
|
70 |
+
# Save pretrained model state_dicts to CPU
|
71 |
+
self.fake_score_state_dict_cpu = self.model.fake_score.state_dict()
|
72 |
+
|
73 |
+
self.model.generator = fsdp_wrap(
|
74 |
+
self.model.generator,
|
75 |
+
sharding_strategy=config.sharding_strategy,
|
76 |
+
mixed_precision=config.mixed_precision,
|
77 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
78 |
+
)
|
79 |
+
|
80 |
+
self.model.real_score = fsdp_wrap(
|
81 |
+
self.model.real_score,
|
82 |
+
sharding_strategy=config.sharding_strategy,
|
83 |
+
mixed_precision=config.mixed_precision,
|
84 |
+
wrap_strategy=config.real_score_fsdp_wrap_strategy
|
85 |
+
)
|
86 |
+
|
87 |
+
self.model.fake_score = fsdp_wrap(
|
88 |
+
self.model.fake_score,
|
89 |
+
sharding_strategy=config.sharding_strategy,
|
90 |
+
mixed_precision=config.mixed_precision,
|
91 |
+
wrap_strategy=config.fake_score_fsdp_wrap_strategy
|
92 |
+
)
|
93 |
+
|
94 |
+
self.model.text_encoder = fsdp_wrap(
|
95 |
+
self.model.text_encoder,
|
96 |
+
sharding_strategy=config.sharding_strategy,
|
97 |
+
mixed_precision=config.mixed_precision,
|
98 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
99 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
100 |
+
)
|
101 |
+
|
102 |
+
if not config.no_visualize or config.load_raw_video:
|
103 |
+
self.model.vae = self.model.vae.to(
|
104 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
105 |
+
|
106 |
+
self.generator_optimizer = torch.optim.AdamW(
|
107 |
+
[param for param in self.model.generator.parameters()
|
108 |
+
if param.requires_grad],
|
109 |
+
lr=config.lr,
|
110 |
+
betas=(config.beta1, config.beta2),
|
111 |
+
weight_decay=config.weight_decay
|
112 |
+
)
|
113 |
+
|
114 |
+
self.critic_optimizer = torch.optim.AdamW(
|
115 |
+
[param for param in self.model.fake_score.parameters()
|
116 |
+
if param.requires_grad],
|
117 |
+
lr=config.lr_critic if hasattr(config, "lr_critic") else config.lr,
|
118 |
+
betas=(config.beta1_critic, config.beta2_critic),
|
119 |
+
weight_decay=config.weight_decay
|
120 |
+
)
|
121 |
+
|
122 |
+
# Step 3: Initialize the dataloader
|
123 |
+
if self.config.i2v:
|
124 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
125 |
+
else:
|
126 |
+
dataset = TextDataset(config.data_path)
|
127 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
128 |
+
dataset, shuffle=True, drop_last=True)
|
129 |
+
dataloader = torch.utils.data.DataLoader(
|
130 |
+
dataset,
|
131 |
+
batch_size=config.batch_size,
|
132 |
+
sampler=sampler,
|
133 |
+
num_workers=8)
|
134 |
+
|
135 |
+
if dist.get_rank() == 0:
|
136 |
+
print("DATASET SIZE %d" % len(dataset))
|
137 |
+
self.dataloader = cycle(dataloader)
|
138 |
+
|
139 |
+
##############################################################################################################
|
140 |
+
# 6. Set up EMA parameter containers
|
141 |
+
rename_param = (
|
142 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
143 |
+
.replace("_checkpoint_wrapped_module.", "")
|
144 |
+
.replace("_orig_mod.", "")
|
145 |
+
)
|
146 |
+
self.name_to_trainable_params = {}
|
147 |
+
for n, p in self.model.generator.named_parameters():
|
148 |
+
if not p.requires_grad:
|
149 |
+
continue
|
150 |
+
|
151 |
+
renamed_n = rename_param(n)
|
152 |
+
self.name_to_trainable_params[renamed_n] = p
|
153 |
+
ema_weight = config.ema_weight
|
154 |
+
self.generator_ema = None
|
155 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
156 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
157 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
158 |
+
|
159 |
+
##############################################################################################################
|
160 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
161 |
+
if getattr(config, "generator_ckpt", False):
|
162 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
163 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
164 |
+
if "generator" in state_dict:
|
165 |
+
state_dict = state_dict["generator"]
|
166 |
+
elif "model" in state_dict:
|
167 |
+
state_dict = state_dict["model"]
|
168 |
+
self.model.generator.load_state_dict(
|
169 |
+
state_dict, strict=True
|
170 |
+
)
|
171 |
+
|
172 |
+
##############################################################################################################
|
173 |
+
|
174 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
175 |
+
if self.step < config.ema_start_step:
|
176 |
+
self.generator_ema = None
|
177 |
+
|
178 |
+
self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
|
179 |
+
self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
|
180 |
+
self.previous_time = None
|
181 |
+
|
182 |
+
def save(self):
|
183 |
+
print("Start gathering distributed model states...")
|
184 |
+
generator_state_dict = fsdp_state_dict(
|
185 |
+
self.model.generator)
|
186 |
+
critic_state_dict = fsdp_state_dict(
|
187 |
+
self.model.fake_score)
|
188 |
+
|
189 |
+
if self.config.ema_start_step < self.step:
|
190 |
+
state_dict = {
|
191 |
+
"generator": generator_state_dict,
|
192 |
+
"critic": critic_state_dict,
|
193 |
+
"generator_ema": self.generator_ema.state_dict(),
|
194 |
+
}
|
195 |
+
else:
|
196 |
+
state_dict = {
|
197 |
+
"generator": generator_state_dict,
|
198 |
+
"critic": critic_state_dict,
|
199 |
+
}
|
200 |
+
|
201 |
+
if self.is_main_process:
|
202 |
+
os.makedirs(os.path.join(self.output_path,
|
203 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
204 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
205 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
206 |
+
print("Model saved to", os.path.join(self.output_path,
|
207 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
208 |
+
|
209 |
+
def fwdbwd_one_step(self, batch, train_generator):
|
210 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
211 |
+
|
212 |
+
if self.step % 20 == 0:
|
213 |
+
torch.cuda.empty_cache()
|
214 |
+
|
215 |
+
# Step 1: Get the next batch of text prompts
|
216 |
+
text_prompts = batch["prompts"]
|
217 |
+
if self.config.i2v:
|
218 |
+
clean_latent = None
|
219 |
+
image_latent = batch["ode_latent"][:, -1][:, 0:1, ].to(
|
220 |
+
device=self.device, dtype=self.dtype)
|
221 |
+
else:
|
222 |
+
clean_latent = None
|
223 |
+
image_latent = None
|
224 |
+
|
225 |
+
batch_size = len(text_prompts)
|
226 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
227 |
+
image_or_video_shape[0] = batch_size
|
228 |
+
|
229 |
+
# Step 2: Extract the conditional infos
|
230 |
+
with torch.no_grad():
|
231 |
+
conditional_dict = self.model.text_encoder(
|
232 |
+
text_prompts=text_prompts)
|
233 |
+
|
234 |
+
if not getattr(self, "unconditional_dict", None):
|
235 |
+
unconditional_dict = self.model.text_encoder(
|
236 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
237 |
+
unconditional_dict = {k: v.detach()
|
238 |
+
for k, v in unconditional_dict.items()}
|
239 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
240 |
+
else:
|
241 |
+
unconditional_dict = self.unconditional_dict
|
242 |
+
|
243 |
+
# Step 3: Store gradients for the generator (if training the generator)
|
244 |
+
if train_generator:
|
245 |
+
generator_loss, generator_log_dict = self.model.generator_loss(
|
246 |
+
image_or_video_shape=image_or_video_shape,
|
247 |
+
conditional_dict=conditional_dict,
|
248 |
+
unconditional_dict=unconditional_dict,
|
249 |
+
clean_latent=clean_latent,
|
250 |
+
initial_latent=image_latent if self.config.i2v else None
|
251 |
+
)
|
252 |
+
|
253 |
+
generator_loss.backward()
|
254 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
255 |
+
self.max_grad_norm_generator)
|
256 |
+
|
257 |
+
generator_log_dict.update({"generator_loss": generator_loss,
|
258 |
+
"generator_grad_norm": generator_grad_norm})
|
259 |
+
|
260 |
+
return generator_log_dict
|
261 |
+
else:
|
262 |
+
generator_log_dict = {}
|
263 |
+
|
264 |
+
# Step 4: Store gradients for the critic (if training the critic)
|
265 |
+
critic_loss, critic_log_dict = self.model.critic_loss(
|
266 |
+
image_or_video_shape=image_or_video_shape,
|
267 |
+
conditional_dict=conditional_dict,
|
268 |
+
unconditional_dict=unconditional_dict,
|
269 |
+
clean_latent=clean_latent,
|
270 |
+
initial_latent=image_latent if self.config.i2v else None
|
271 |
+
)
|
272 |
+
|
273 |
+
critic_loss.backward()
|
274 |
+
critic_grad_norm = self.model.fake_score.clip_grad_norm_(
|
275 |
+
self.max_grad_norm_critic)
|
276 |
+
|
277 |
+
critic_log_dict.update({"critic_loss": critic_loss,
|
278 |
+
"critic_grad_norm": critic_grad_norm})
|
279 |
+
|
280 |
+
return critic_log_dict
|
281 |
+
|
282 |
+
def generate_video(self, pipeline, prompts, image=None):
|
283 |
+
batch_size = len(prompts)
|
284 |
+
if image is not None:
|
285 |
+
image = image.squeeze(0).unsqueeze(0).unsqueeze(2).to(device="cuda", dtype=torch.bfloat16)
|
286 |
+
|
287 |
+
# Encode the input image as the first latent
|
288 |
+
initial_latent = pipeline.vae.encode_to_latent(image).to(device="cuda", dtype=torch.bfloat16)
|
289 |
+
initial_latent = initial_latent.repeat(batch_size, 1, 1, 1, 1)
|
290 |
+
sampled_noise = torch.randn(
|
291 |
+
[batch_size, self.model.num_training_frames - 1, 16, 60, 104],
|
292 |
+
device="cuda",
|
293 |
+
dtype=self.dtype
|
294 |
+
)
|
295 |
+
else:
|
296 |
+
initial_latent = None
|
297 |
+
sampled_noise = torch.randn(
|
298 |
+
[batch_size, self.model.num_training_frames, 16, 60, 104],
|
299 |
+
device="cuda",
|
300 |
+
dtype=self.dtype
|
301 |
+
)
|
302 |
+
|
303 |
+
video, _ = pipeline.inference(
|
304 |
+
noise=sampled_noise,
|
305 |
+
text_prompts=prompts,
|
306 |
+
return_latents=True,
|
307 |
+
initial_latent=initial_latent
|
308 |
+
)
|
309 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
310 |
+
return current_video
|
311 |
+
|
312 |
+
def train(self):
|
313 |
+
start_step = self.step
|
314 |
+
|
315 |
+
while True:
|
316 |
+
TRAIN_GENERATOR = self.step % self.config.dfake_gen_update_ratio == 0
|
317 |
+
|
318 |
+
# Train the generator
|
319 |
+
if TRAIN_GENERATOR:
|
320 |
+
self.generator_optimizer.zero_grad(set_to_none=True)
|
321 |
+
extras_list = []
|
322 |
+
batch = next(self.dataloader)
|
323 |
+
extra = self.fwdbwd_one_step(batch, True)
|
324 |
+
extras_list.append(extra)
|
325 |
+
generator_log_dict = merge_dict_list(extras_list)
|
326 |
+
self.generator_optimizer.step()
|
327 |
+
if self.generator_ema is not None:
|
328 |
+
self.generator_ema.update(self.model.generator)
|
329 |
+
|
330 |
+
# Train the critic
|
331 |
+
self.critic_optimizer.zero_grad(set_to_none=True)
|
332 |
+
extras_list = []
|
333 |
+
batch = next(self.dataloader)
|
334 |
+
extra = self.fwdbwd_one_step(batch, False)
|
335 |
+
extras_list.append(extra)
|
336 |
+
critic_log_dict = merge_dict_list(extras_list)
|
337 |
+
self.critic_optimizer.step()
|
338 |
+
|
339 |
+
# Increment the step since we finished gradient update
|
340 |
+
self.step += 1
|
341 |
+
|
342 |
+
# Create EMA params (if not already created)
|
343 |
+
if (self.step >= self.config.ema_start_step) and \
|
344 |
+
(self.generator_ema is None) and (self.config.ema_weight > 0):
|
345 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
|
346 |
+
|
347 |
+
# Save the model
|
348 |
+
if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
|
349 |
+
torch.cuda.empty_cache()
|
350 |
+
self.save()
|
351 |
+
torch.cuda.empty_cache()
|
352 |
+
|
353 |
+
# Logging
|
354 |
+
if self.is_main_process:
|
355 |
+
wandb_loss_dict = {}
|
356 |
+
if TRAIN_GENERATOR:
|
357 |
+
wandb_loss_dict.update(
|
358 |
+
{
|
359 |
+
"generator_loss": generator_log_dict["generator_loss"].mean().item(),
|
360 |
+
"generator_grad_norm": generator_log_dict["generator_grad_norm"].mean().item(),
|
361 |
+
"dmdtrain_gradient_norm": generator_log_dict["dmdtrain_gradient_norm"].mean().item()
|
362 |
+
}
|
363 |
+
)
|
364 |
+
|
365 |
+
wandb_loss_dict.update(
|
366 |
+
{
|
367 |
+
"critic_loss": critic_log_dict["critic_loss"].mean().item(),
|
368 |
+
"critic_grad_norm": critic_log_dict["critic_grad_norm"].mean().item()
|
369 |
+
}
|
370 |
+
)
|
371 |
+
|
372 |
+
if not self.disable_wandb:
|
373 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
374 |
+
|
375 |
+
if self.step % self.config.gc_interval == 0:
|
376 |
+
if dist.get_rank() == 0:
|
377 |
+
logging.info("DistGarbageCollector: Running GC.")
|
378 |
+
gc.collect()
|
379 |
+
torch.cuda.empty_cache()
|
380 |
+
|
381 |
+
if self.is_main_process:
|
382 |
+
current_time = time.time()
|
383 |
+
if self.previous_time is None:
|
384 |
+
self.previous_time = current_time
|
385 |
+
else:
|
386 |
+
if not self.disable_wandb:
|
387 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
388 |
+
self.previous_time = current_time
|
trainer/gan.py
ADDED
@@ -0,0 +1,464 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
|
4 |
+
from utils.dataset import ShardingLMDBDataset, cycle
|
5 |
+
from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
6 |
+
from utils.misc import (
|
7 |
+
set_seed,
|
8 |
+
merge_dict_list
|
9 |
+
)
|
10 |
+
import torch.distributed as dist
|
11 |
+
from omegaconf import OmegaConf
|
12 |
+
from model import GAN
|
13 |
+
import torch
|
14 |
+
import wandb
|
15 |
+
import time
|
16 |
+
import os
|
17 |
+
|
18 |
+
|
19 |
+
class Trainer:
|
20 |
+
def __init__(self, config):
|
21 |
+
self.config = config
|
22 |
+
self.step = 0
|
23 |
+
|
24 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
+
torch.backends.cudnn.allow_tf32 = True
|
27 |
+
|
28 |
+
launch_distributed_job()
|
29 |
+
global_rank = dist.get_rank()
|
30 |
+
self.world_size = dist.get_world_size()
|
31 |
+
|
32 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
33 |
+
self.device = torch.cuda.current_device()
|
34 |
+
self.is_main_process = global_rank == 0
|
35 |
+
self.causal = config.causal
|
36 |
+
self.disable_wandb = config.disable_wandb
|
37 |
+
|
38 |
+
# Configuration for discriminator warmup
|
39 |
+
self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
|
40 |
+
self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
|
41 |
+
if self.in_discriminator_warmup and self.is_main_process:
|
42 |
+
print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
|
43 |
+
self.loss_scale = getattr(config, "loss_scale", 1.0)
|
44 |
+
|
45 |
+
# use a random seed for the training
|
46 |
+
if config.seed == 0:
|
47 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
48 |
+
dist.broadcast(random_seed, src=0)
|
49 |
+
config.seed = random_seed.item()
|
50 |
+
|
51 |
+
set_seed(config.seed + global_rank)
|
52 |
+
|
53 |
+
if self.is_main_process and not self.disable_wandb:
|
54 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
55 |
+
wandb.init(
|
56 |
+
config=OmegaConf.to_container(config, resolve=True),
|
57 |
+
name=config.config_name,
|
58 |
+
mode="online",
|
59 |
+
entity=config.wandb_entity,
|
60 |
+
project=config.wandb_project,
|
61 |
+
dir=config.wandb_save_dir
|
62 |
+
)
|
63 |
+
|
64 |
+
self.output_path = config.logdir
|
65 |
+
|
66 |
+
# Step 2: Initialize the model and optimizer
|
67 |
+
self.model = GAN(config, device=self.device)
|
68 |
+
|
69 |
+
self.model.generator = fsdp_wrap(
|
70 |
+
self.model.generator,
|
71 |
+
sharding_strategy=config.sharding_strategy,
|
72 |
+
mixed_precision=config.mixed_precision,
|
73 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
74 |
+
)
|
75 |
+
|
76 |
+
self.model.fake_score = fsdp_wrap(
|
77 |
+
self.model.fake_score,
|
78 |
+
sharding_strategy=config.sharding_strategy,
|
79 |
+
mixed_precision=config.mixed_precision,
|
80 |
+
wrap_strategy=config.fake_score_fsdp_wrap_strategy
|
81 |
+
)
|
82 |
+
|
83 |
+
self.model.text_encoder = fsdp_wrap(
|
84 |
+
self.model.text_encoder,
|
85 |
+
sharding_strategy=config.sharding_strategy,
|
86 |
+
mixed_precision=config.mixed_precision,
|
87 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
88 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
89 |
+
)
|
90 |
+
|
91 |
+
if not config.no_visualize or config.load_raw_video:
|
92 |
+
self.model.vae = self.model.vae.to(
|
93 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
94 |
+
|
95 |
+
self.generator_optimizer = torch.optim.AdamW(
|
96 |
+
[param for param in self.model.generator.parameters()
|
97 |
+
if param.requires_grad],
|
98 |
+
lr=config.gen_lr,
|
99 |
+
betas=(config.beta1, config.beta2)
|
100 |
+
)
|
101 |
+
|
102 |
+
# Create separate parameter groups for the fake_score network
|
103 |
+
# One group for parameters with "_cls_pred_branch" or "_gan_ca_blocks" in the name
|
104 |
+
# and another group for all other parameters
|
105 |
+
fake_score_params = []
|
106 |
+
discriminator_params = []
|
107 |
+
|
108 |
+
for name, param in self.model.fake_score.named_parameters():
|
109 |
+
if param.requires_grad:
|
110 |
+
if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
|
111 |
+
discriminator_params.append(param)
|
112 |
+
else:
|
113 |
+
fake_score_params.append(param)
|
114 |
+
|
115 |
+
# Use the special learning rate for the special parameter group
|
116 |
+
# and the default critic learning rate for other parameters
|
117 |
+
self.critic_param_groups = [
|
118 |
+
{'params': fake_score_params, 'lr': config.critic_lr},
|
119 |
+
{'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
|
120 |
+
]
|
121 |
+
if self.in_discriminator_warmup:
|
122 |
+
self.critic_optimizer = torch.optim.AdamW(
|
123 |
+
self.critic_param_groups,
|
124 |
+
betas=(0.9, config.beta2_critic)
|
125 |
+
)
|
126 |
+
else:
|
127 |
+
self.critic_optimizer = torch.optim.AdamW(
|
128 |
+
self.critic_param_groups,
|
129 |
+
betas=(config.beta1_critic, config.beta2_critic)
|
130 |
+
)
|
131 |
+
|
132 |
+
# Step 3: Initialize the dataloader
|
133 |
+
self.data_path = config.data_path
|
134 |
+
dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
|
135 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
136 |
+
dataset, shuffle=True, drop_last=True)
|
137 |
+
dataloader = torch.utils.data.DataLoader(
|
138 |
+
dataset,
|
139 |
+
batch_size=config.batch_size,
|
140 |
+
sampler=sampler,
|
141 |
+
num_workers=8)
|
142 |
+
|
143 |
+
if dist.get_rank() == 0:
|
144 |
+
print("DATASET SIZE %d" % len(dataset))
|
145 |
+
|
146 |
+
self.dataloader = cycle(dataloader)
|
147 |
+
|
148 |
+
##############################################################################################################
|
149 |
+
# 6. Set up EMA parameter containers
|
150 |
+
rename_param = (
|
151 |
+
lambda name: name.replace("_fsdp_wrapped_module.", "")
|
152 |
+
.replace("_checkpoint_wrapped_module.", "")
|
153 |
+
.replace("_orig_mod.", "")
|
154 |
+
)
|
155 |
+
self.name_to_trainable_params = {}
|
156 |
+
for n, p in self.model.generator.named_parameters():
|
157 |
+
if not p.requires_grad:
|
158 |
+
continue
|
159 |
+
|
160 |
+
renamed_n = rename_param(n)
|
161 |
+
self.name_to_trainable_params[renamed_n] = p
|
162 |
+
ema_weight = config.ema_weight
|
163 |
+
self.generator_ema = None
|
164 |
+
if (ema_weight is not None) and (ema_weight > 0.0):
|
165 |
+
print(f"Setting up EMA with weight {ema_weight}")
|
166 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
|
167 |
+
|
168 |
+
##############################################################################################################
|
169 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
170 |
+
if getattr(config, "generator_ckpt", False):
|
171 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
172 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")
|
173 |
+
if "generator" in state_dict:
|
174 |
+
state_dict = state_dict["generator"]
|
175 |
+
elif "model" in state_dict:
|
176 |
+
state_dict = state_dict["model"]
|
177 |
+
self.model.generator.load_state_dict(
|
178 |
+
state_dict, strict=True
|
179 |
+
)
|
180 |
+
if hasattr(config, "load"):
|
181 |
+
resume_ckpt_path_critic = os.path.join(config.load, "critic")
|
182 |
+
resume_ckpt_path_generator = os.path.join(config.load, "generator")
|
183 |
+
else:
|
184 |
+
resume_ckpt_path_critic = "none"
|
185 |
+
resume_ckpt_path_generator = "none"
|
186 |
+
|
187 |
+
_, _ = self.checkpointer_critic.try_best_load(
|
188 |
+
resume_ckpt_path=resume_ckpt_path_critic,
|
189 |
+
)
|
190 |
+
self.step, _ = self.checkpointer_generator.try_best_load(
|
191 |
+
resume_ckpt_path=resume_ckpt_path_generator,
|
192 |
+
force_start_w_ema=config.force_start_w_ema,
|
193 |
+
force_reset_zero_step=config.force_reset_zero_step,
|
194 |
+
force_reinit_ema=config.force_reinit_ema,
|
195 |
+
skip_optimizer_scheduler=config.skip_optimizer_scheduler,
|
196 |
+
)
|
197 |
+
|
198 |
+
##############################################################################################################
|
199 |
+
|
200 |
+
# Let's delete EMA params for early steps to save some computes at training and inference
|
201 |
+
if self.step < config.ema_start_step:
|
202 |
+
self.generator_ema = None
|
203 |
+
|
204 |
+
self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
|
205 |
+
self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
|
206 |
+
self.previous_time = None
|
207 |
+
|
208 |
+
def save(self):
|
209 |
+
print("Start gathering distributed model states...")
|
210 |
+
generator_state_dict = fsdp_state_dict(
|
211 |
+
self.model.generator)
|
212 |
+
critic_state_dict = fsdp_state_dict(
|
213 |
+
self.model.fake_score)
|
214 |
+
|
215 |
+
if self.config.ema_start_step < self.step:
|
216 |
+
state_dict = {
|
217 |
+
"generator": generator_state_dict,
|
218 |
+
"critic": critic_state_dict,
|
219 |
+
"generator_ema": self.generator_ema.state_dict(),
|
220 |
+
}
|
221 |
+
else:
|
222 |
+
state_dict = {
|
223 |
+
"generator": generator_state_dict,
|
224 |
+
"critic": critic_state_dict,
|
225 |
+
}
|
226 |
+
|
227 |
+
if self.is_main_process:
|
228 |
+
os.makedirs(os.path.join(self.output_path,
|
229 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
230 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
231 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
232 |
+
print("Model saved to", os.path.join(self.output_path,
|
233 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
234 |
+
|
235 |
+
def fwdbwd_one_step(self, batch, train_generator):
|
236 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
237 |
+
|
238 |
+
if self.step % 20 == 0:
|
239 |
+
torch.cuda.empty_cache()
|
240 |
+
|
241 |
+
# Step 1: Get the next batch of text prompts
|
242 |
+
text_prompts = batch["prompts"] # next(self.dataloader)
|
243 |
+
if "ode_latent" in batch:
|
244 |
+
clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
|
245 |
+
else:
|
246 |
+
frames = batch["frames"].to(device=self.device, dtype=self.dtype)
|
247 |
+
with torch.no_grad():
|
248 |
+
clean_latent = self.model.vae.encode_to_latent(
|
249 |
+
frames).to(device=self.device, dtype=self.dtype)
|
250 |
+
|
251 |
+
image_latent = clean_latent[:, 0:1, ]
|
252 |
+
|
253 |
+
batch_size = len(text_prompts)
|
254 |
+
image_or_video_shape = list(self.config.image_or_video_shape)
|
255 |
+
image_or_video_shape[0] = batch_size
|
256 |
+
|
257 |
+
# Step 2: Extract the conditional infos
|
258 |
+
with torch.no_grad():
|
259 |
+
conditional_dict = self.model.text_encoder(
|
260 |
+
text_prompts=text_prompts)
|
261 |
+
|
262 |
+
if not getattr(self, "unconditional_dict", None):
|
263 |
+
unconditional_dict = self.model.text_encoder(
|
264 |
+
text_prompts=[self.config.negative_prompt] * batch_size)
|
265 |
+
unconditional_dict = {k: v.detach()
|
266 |
+
for k, v in unconditional_dict.items()}
|
267 |
+
self.unconditional_dict = unconditional_dict # cache the unconditional_dict
|
268 |
+
else:
|
269 |
+
unconditional_dict = self.unconditional_dict
|
270 |
+
|
271 |
+
mini_bs, full_bs = (
|
272 |
+
batch["mini_bs"],
|
273 |
+
batch["full_bs"],
|
274 |
+
)
|
275 |
+
|
276 |
+
# Step 3: Store gradients for the generator (if training the generator)
|
277 |
+
if train_generator:
|
278 |
+
gan_G_loss = self.model.generator_loss(
|
279 |
+
image_or_video_shape=image_or_video_shape,
|
280 |
+
conditional_dict=conditional_dict,
|
281 |
+
unconditional_dict=unconditional_dict,
|
282 |
+
clean_latent=clean_latent,
|
283 |
+
initial_latent=image_latent if self.config.i2v else None
|
284 |
+
)
|
285 |
+
|
286 |
+
loss_ratio = mini_bs * self.world_size / full_bs
|
287 |
+
total_loss = gan_G_loss * loss_ratio * self.loss_scale
|
288 |
+
|
289 |
+
total_loss.backward()
|
290 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
291 |
+
self.max_grad_norm_generator)
|
292 |
+
|
293 |
+
generator_log_dict = {"generator_grad_norm": generator_grad_norm,
|
294 |
+
"gan_G_loss": gan_G_loss}
|
295 |
+
|
296 |
+
return generator_log_dict
|
297 |
+
else:
|
298 |
+
generator_log_dict = {}
|
299 |
+
|
300 |
+
# Step 4: Store gradients for the critic (if training the critic)
|
301 |
+
(gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
|
302 |
+
image_or_video_shape=image_or_video_shape,
|
303 |
+
conditional_dict=conditional_dict,
|
304 |
+
unconditional_dict=unconditional_dict,
|
305 |
+
clean_latent=clean_latent,
|
306 |
+
real_image_or_video=clean_latent,
|
307 |
+
initial_latent=image_latent if self.config.i2v else None
|
308 |
+
)
|
309 |
+
|
310 |
+
loss_ratio = mini_bs * dist.get_world_size() / full_bs
|
311 |
+
total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale
|
312 |
+
|
313 |
+
total_loss.backward()
|
314 |
+
critic_grad_norm = self.model.fake_score.clip_grad_norm_(
|
315 |
+
self.max_grad_norm_critic)
|
316 |
+
|
317 |
+
critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
|
318 |
+
"gan_D_loss": gan_D_loss,
|
319 |
+
"r1_loss": r1_loss,
|
320 |
+
"r2_loss": r2_loss})
|
321 |
+
|
322 |
+
return critic_log_dict
|
323 |
+
|
324 |
+
def generate_video(self, pipeline, prompts, image=None):
|
325 |
+
batch_size = len(prompts)
|
326 |
+
sampled_noise = torch.randn(
|
327 |
+
[batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
|
328 |
+
)
|
329 |
+
video, _ = pipeline.inference(
|
330 |
+
noise=sampled_noise,
|
331 |
+
text_prompts=prompts,
|
332 |
+
return_latents=True
|
333 |
+
)
|
334 |
+
current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
|
335 |
+
return current_video
|
336 |
+
|
337 |
+
def train(self):
|
338 |
+
start_step = self.step
|
339 |
+
|
340 |
+
while True:
|
341 |
+
if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
|
342 |
+
print("Resetting critic optimizer")
|
343 |
+
del self.critic_optimizer
|
344 |
+
torch.cuda.empty_cache()
|
345 |
+
# Create new optimizers
|
346 |
+
self.critic_optimizer = torch.optim.AdamW(
|
347 |
+
self.critic_param_groups,
|
348 |
+
betas=(self.config.beta1_critic, self.config.beta2_critic)
|
349 |
+
)
|
350 |
+
# Update checkpointer references
|
351 |
+
self.checkpointer_critic.optimizer = self.critic_optimizer
|
352 |
+
# Check if we're in the discriminator warmup phase
|
353 |
+
self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
|
354 |
+
|
355 |
+
# Only update generator and critic outside the warmup phase
|
356 |
+
TRAIN_GENERATOR = not self.in_discriminator_warmup and self.step % self.config.dfake_gen_update_ratio == 0
|
357 |
+
|
358 |
+
# Train the generator (only outside warmup phase)
|
359 |
+
if TRAIN_GENERATOR:
|
360 |
+
self.model.fake_score.requires_grad_(False)
|
361 |
+
self.model.generator.requires_grad_(True)
|
362 |
+
self.generator_optimizer.zero_grad(set_to_none=True)
|
363 |
+
extras_list = []
|
364 |
+
for ii, mini_batch in enumerate(self.dataloader.next()):
|
365 |
+
extra = self.fwdbwd_one_step(mini_batch, True)
|
366 |
+
extras_list.append(extra)
|
367 |
+
generator_log_dict = merge_dict_list(extras_list)
|
368 |
+
self.generator_optimizer.step()
|
369 |
+
if self.generator_ema is not None:
|
370 |
+
self.generator_ema.update(self.model.generator)
|
371 |
+
else:
|
372 |
+
generator_log_dict = {}
|
373 |
+
|
374 |
+
# Train the critic/discriminator
|
375 |
+
if self.in_discriminator_warmup:
|
376 |
+
# During warmup, only allow gradient for discriminator params
|
377 |
+
self.model.generator.requires_grad_(False)
|
378 |
+
self.model.fake_score.requires_grad_(False)
|
379 |
+
|
380 |
+
# Enable gradient only for discriminator params
|
381 |
+
for name, param in self.model.fake_score.named_parameters():
|
382 |
+
if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
|
383 |
+
param.requires_grad_(True)
|
384 |
+
else:
|
385 |
+
# Normal training mode
|
386 |
+
self.model.generator.requires_grad_(False)
|
387 |
+
self.model.fake_score.requires_grad_(True)
|
388 |
+
|
389 |
+
self.critic_optimizer.zero_grad(set_to_none=True)
|
390 |
+
extras_list = []
|
391 |
+
batch = next(self.dataloader)
|
392 |
+
extra = self.fwdbwd_one_step(batch, False)
|
393 |
+
extras_list.append(extra)
|
394 |
+
critic_log_dict = merge_dict_list(extras_list)
|
395 |
+
self.critic_optimizer.step()
|
396 |
+
|
397 |
+
# Increment the step since we finished gradient update
|
398 |
+
self.step += 1
|
399 |
+
|
400 |
+
# If we just finished warmup, print a message
|
401 |
+
if self.is_main_process and self.step == self.discriminator_warmup_steps:
|
402 |
+
print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")
|
403 |
+
|
404 |
+
# Create EMA params (if not already created)
|
405 |
+
if (self.step >= self.config.ema_start_step) and \
|
406 |
+
(self.generator_ema is None) and (self.config.ema_weight > 0):
|
407 |
+
self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
|
408 |
+
|
409 |
+
# Save the model
|
410 |
+
if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
|
411 |
+
torch.cuda.empty_cache()
|
412 |
+
self.save()
|
413 |
+
torch.cuda.empty_cache()
|
414 |
+
|
415 |
+
# Logging
|
416 |
+
wandb_loss_dict = {
|
417 |
+
"generator_grad_norm": generator_log_dict["generator_grad_norm"],
|
418 |
+
"critic_grad_norm": critic_log_dict["critic_grad_norm"],
|
419 |
+
"real_logit": critic_log_dict["noisy_real_logit"],
|
420 |
+
"fake_logit": critic_log_dict["noisy_fake_logit"],
|
421 |
+
"r1_loss": critic_log_dict["r1_loss"],
|
422 |
+
"r2_loss": critic_log_dict["r2_loss"],
|
423 |
+
}
|
424 |
+
if TRAIN_GENERATOR:
|
425 |
+
wandb_loss_dict.update({
|
426 |
+
"generator_grad_norm": generator_log_dict["generator_grad_norm"],
|
427 |
+
})
|
428 |
+
self.all_gather_dict(wandb_loss_dict)
|
429 |
+
wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
|
430 |
+
wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])
|
431 |
+
|
432 |
+
if self.is_main_process:
|
433 |
+
if self.in_discriminator_warmup:
|
434 |
+
warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
|
435 |
+
print(warmup_status)
|
436 |
+
if not self.disable_wandb:
|
437 |
+
wandb_loss_dict.update({"warmup_status": 1.0})
|
438 |
+
|
439 |
+
if not self.disable_wandb:
|
440 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
441 |
+
|
442 |
+
if self.step % self.config.gc_interval == 0:
|
443 |
+
if dist.get_rank() == 0:
|
444 |
+
logging.info("DistGarbageCollector: Running GC.")
|
445 |
+
gc.collect()
|
446 |
+
torch.cuda.empty_cache()
|
447 |
+
|
448 |
+
if self.is_main_process:
|
449 |
+
current_time = time.time()
|
450 |
+
if self.previous_time is None:
|
451 |
+
self.previous_time = current_time
|
452 |
+
else:
|
453 |
+
if not self.disable_wandb:
|
454 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
455 |
+
self.previous_time = current_time
|
456 |
+
|
457 |
+
def all_gather_dict(self, target_dict):
|
458 |
+
for key, value in target_dict.items():
|
459 |
+
gathered_value = torch.zeros(
|
460 |
+
[self.world_size, *value.shape],
|
461 |
+
dtype=value.dtype, device=self.device)
|
462 |
+
dist.all_gather_into_tensor(gathered_value, value)
|
463 |
+
avg_value = gathered_value.mean().item()
|
464 |
+
target_dict[key] = avg_value
|
trainer/ode.py
ADDED
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gc
|
2 |
+
import logging
|
3 |
+
from utils.dataset import ODERegressionLMDBDataset, cycle
|
4 |
+
from model import ODERegression
|
5 |
+
from collections import defaultdict
|
6 |
+
from utils.misc import (
|
7 |
+
set_seed
|
8 |
+
)
|
9 |
+
import torch.distributed as dist
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
import torch
|
12 |
+
import wandb
|
13 |
+
import time
|
14 |
+
import os
|
15 |
+
|
16 |
+
from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
|
17 |
+
|
18 |
+
|
19 |
+
class Trainer:
|
20 |
+
def __init__(self, config):
|
21 |
+
self.config = config
|
22 |
+
self.step = 0
|
23 |
+
|
24 |
+
# Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
|
25 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
26 |
+
torch.backends.cudnn.allow_tf32 = True
|
27 |
+
|
28 |
+
launch_distributed_job()
|
29 |
+
global_rank = dist.get_rank()
|
30 |
+
self.world_size = dist.get_world_size()
|
31 |
+
|
32 |
+
self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
|
33 |
+
self.device = torch.cuda.current_device()
|
34 |
+
self.is_main_process = global_rank == 0
|
35 |
+
self.disable_wandb = config.disable_wandb
|
36 |
+
|
37 |
+
# use a random seed for the training
|
38 |
+
if config.seed == 0:
|
39 |
+
random_seed = torch.randint(0, 10000000, (1,), device=self.device)
|
40 |
+
dist.broadcast(random_seed, src=0)
|
41 |
+
config.seed = random_seed.item()
|
42 |
+
|
43 |
+
set_seed(config.seed + global_rank)
|
44 |
+
|
45 |
+
if self.is_main_process and not self.disable_wandb:
|
46 |
+
wandb.login(host=config.wandb_host, key=config.wandb_key)
|
47 |
+
wandb.init(
|
48 |
+
config=OmegaConf.to_container(config, resolve=True),
|
49 |
+
name=config.config_name,
|
50 |
+
mode="online",
|
51 |
+
entity=config.wandb_entity,
|
52 |
+
project=config.wandb_project,
|
53 |
+
dir=config.wandb_save_dir
|
54 |
+
)
|
55 |
+
|
56 |
+
self.output_path = config.logdir
|
57 |
+
|
58 |
+
# Step 2: Initialize the model and optimizer
|
59 |
+
|
60 |
+
assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
|
61 |
+
self.model = ODERegression(config, device=self.device)
|
62 |
+
|
63 |
+
self.model.generator = fsdp_wrap(
|
64 |
+
self.model.generator,
|
65 |
+
sharding_strategy=config.sharding_strategy,
|
66 |
+
mixed_precision=config.mixed_precision,
|
67 |
+
wrap_strategy=config.generator_fsdp_wrap_strategy
|
68 |
+
)
|
69 |
+
self.model.text_encoder = fsdp_wrap(
|
70 |
+
self.model.text_encoder,
|
71 |
+
sharding_strategy=config.sharding_strategy,
|
72 |
+
mixed_precision=config.mixed_precision,
|
73 |
+
wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
|
74 |
+
cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
|
75 |
+
)
|
76 |
+
|
77 |
+
if not config.no_visualize or config.load_raw_video:
|
78 |
+
self.model.vae = self.model.vae.to(
|
79 |
+
device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
|
80 |
+
|
81 |
+
self.generator_optimizer = torch.optim.AdamW(
|
82 |
+
[param for param in self.model.generator.parameters()
|
83 |
+
if param.requires_grad],
|
84 |
+
lr=config.lr,
|
85 |
+
betas=(config.beta1, config.beta2),
|
86 |
+
weight_decay=config.weight_decay
|
87 |
+
)
|
88 |
+
|
89 |
+
# Step 3: Initialize the dataloader
|
90 |
+
dataset = ODERegressionLMDBDataset(
|
91 |
+
config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
|
92 |
+
sampler = torch.utils.data.distributed.DistributedSampler(
|
93 |
+
dataset, shuffle=True, drop_last=True)
|
94 |
+
dataloader = torch.utils.data.DataLoader(
|
95 |
+
dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)
|
96 |
+
total_batch_size = getattr(config, "total_batch_size", None)
|
97 |
+
if total_batch_size is not None:
|
98 |
+
assert total_batch_size == config.batch_size * self.world_size, "Gradient accumulation is not supported for ODE training"
|
99 |
+
self.dataloader = cycle(dataloader)
|
100 |
+
|
101 |
+
self.step = 0
|
102 |
+
|
103 |
+
##############################################################################################################
|
104 |
+
# 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
|
105 |
+
if getattr(config, "generator_ckpt", False):
|
106 |
+
print(f"Loading pretrained generator from {config.generator_ckpt}")
|
107 |
+
state_dict = torch.load(config.generator_ckpt, map_location="cpu")[
|
108 |
+
'generator']
|
109 |
+
self.model.generator.load_state_dict(
|
110 |
+
state_dict, strict=True
|
111 |
+
)
|
112 |
+
|
113 |
+
##############################################################################################################
|
114 |
+
|
115 |
+
self.max_grad_norm = 10.0
|
116 |
+
self.previous_time = None
|
117 |
+
|
118 |
+
def save(self):
|
119 |
+
print("Start gathering distributed model states...")
|
120 |
+
generator_state_dict = fsdp_state_dict(
|
121 |
+
self.model.generator)
|
122 |
+
state_dict = {
|
123 |
+
"generator": generator_state_dict
|
124 |
+
}
|
125 |
+
|
126 |
+
if self.is_main_process:
|
127 |
+
os.makedirs(os.path.join(self.output_path,
|
128 |
+
f"checkpoint_model_{self.step:06d}"), exist_ok=True)
|
129 |
+
torch.save(state_dict, os.path.join(self.output_path,
|
130 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
131 |
+
print("Model saved to", os.path.join(self.output_path,
|
132 |
+
f"checkpoint_model_{self.step:06d}", "model.pt"))
|
133 |
+
|
134 |
+
def train_one_step(self):
|
135 |
+
VISUALIZE = self.step % 100 == 0
|
136 |
+
self.model.eval() # prevent any randomness (e.g. dropout)
|
137 |
+
|
138 |
+
# Step 1: Get the next batch of text prompts
|
139 |
+
batch = next(self.dataloader)
|
140 |
+
text_prompts = batch["prompts"]
|
141 |
+
ode_latent = batch["ode_latent"].to(
|
142 |
+
device=self.device, dtype=self.dtype)
|
143 |
+
|
144 |
+
# Step 2: Extract the conditional infos
|
145 |
+
with torch.no_grad():
|
146 |
+
conditional_dict = self.model.text_encoder(
|
147 |
+
text_prompts=text_prompts)
|
148 |
+
|
149 |
+
# Step 3: Train the generator
|
150 |
+
generator_loss, log_dict = self.model.generator_loss(
|
151 |
+
ode_latent=ode_latent,
|
152 |
+
conditional_dict=conditional_dict
|
153 |
+
)
|
154 |
+
|
155 |
+
unnormalized_loss = log_dict["unnormalized_loss"]
|
156 |
+
timestep = log_dict["timestep"]
|
157 |
+
|
158 |
+
if self.world_size > 1:
|
159 |
+
gathered_unnormalized_loss = torch.zeros(
|
160 |
+
[self.world_size, *unnormalized_loss.shape],
|
161 |
+
dtype=unnormalized_loss.dtype, device=self.device)
|
162 |
+
gathered_timestep = torch.zeros(
|
163 |
+
[self.world_size, *timestep.shape],
|
164 |
+
dtype=timestep.dtype, device=self.device)
|
165 |
+
|
166 |
+
dist.all_gather_into_tensor(
|
167 |
+
gathered_unnormalized_loss, unnormalized_loss)
|
168 |
+
dist.all_gather_into_tensor(gathered_timestep, timestep)
|
169 |
+
else:
|
170 |
+
gathered_unnormalized_loss = unnormalized_loss
|
171 |
+
gathered_timestep = timestep
|
172 |
+
|
173 |
+
loss_breakdown = defaultdict(list)
|
174 |
+
stats = {}
|
175 |
+
|
176 |
+
for index, t in enumerate(timestep):
|
177 |
+
loss_breakdown[str(int(t.item()) // 250 * 250)].append(
|
178 |
+
unnormalized_loss[index].item())
|
179 |
+
|
180 |
+
for key_t in loss_breakdown.keys():
|
181 |
+
stats["loss_at_time_" + key_t] = sum(loss_breakdown[key_t]) / \
|
182 |
+
len(loss_breakdown[key_t])
|
183 |
+
|
184 |
+
self.generator_optimizer.zero_grad()
|
185 |
+
generator_loss.backward()
|
186 |
+
generator_grad_norm = self.model.generator.clip_grad_norm_(
|
187 |
+
self.max_grad_norm)
|
188 |
+
self.generator_optimizer.step()
|
189 |
+
|
190 |
+
# Step 4: Visualization
|
191 |
+
if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
|
192 |
+
# Visualize the input, output, and ground truth
|
193 |
+
input = log_dict["input"]
|
194 |
+
output = log_dict["output"]
|
195 |
+
ground_truth = ode_latent[:, -1]
|
196 |
+
|
197 |
+
input_video = self.model.vae.decode_to_pixel(input)
|
198 |
+
output_video = self.model.vae.decode_to_pixel(output)
|
199 |
+
ground_truth_video = self.model.vae.decode_to_pixel(ground_truth)
|
200 |
+
input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
|
201 |
+
output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
|
202 |
+
ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)
|
203 |
+
|
204 |
+
# Visualize the input, output, and ground truth
|
205 |
+
wandb.log({
|
206 |
+
"input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
|
207 |
+
"output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
|
208 |
+
"ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
|
209 |
+
}, step=self.step)
|
210 |
+
|
211 |
+
# Step 5: Logging
|
212 |
+
if self.is_main_process and not self.disable_wandb:
|
213 |
+
wandb_loss_dict = {
|
214 |
+
"generator_loss": generator_loss.item(),
|
215 |
+
"generator_grad_norm": generator_grad_norm.item(),
|
216 |
+
**stats
|
217 |
+
}
|
218 |
+
wandb.log(wandb_loss_dict, step=self.step)
|
219 |
+
|
220 |
+
if self.step % self.config.gc_interval == 0:
|
221 |
+
if dist.get_rank() == 0:
|
222 |
+
logging.info("DistGarbageCollector: Running GC.")
|
223 |
+
gc.collect()
|
224 |
+
|
225 |
+
def train(self):
|
226 |
+
while True:
|
227 |
+
self.train_one_step()
|
228 |
+
if (not self.config.no_save) and self.step % self.config.log_iters == 0:
|
229 |
+
self.save()
|
230 |
+
torch.cuda.empty_cache()
|
231 |
+
|
232 |
+
barrier()
|
233 |
+
if self.is_main_process:
|
234 |
+
current_time = time.time()
|
235 |
+
if self.previous_time is None:
|
236 |
+
self.previous_time = current_time
|
237 |
+
else:
|
238 |
+
if not self.disable_wandb:
|
239 |
+
wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
|
240 |
+
self.previous_time = current_time
|
241 |
+
|
242 |
+
self.step += 1
|
utils/dataset.py
ADDED
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import lmdb
|
6 |
+
import json
|
7 |
+
from pathlib import Path
|
8 |
+
from PIL import Image
|
9 |
+
import os
|
10 |
+
|
11 |
+
|
12 |
+
class TextDataset(Dataset):
|
13 |
+
def __init__(self, prompt_path, extended_prompt_path=None):
|
14 |
+
with open(prompt_path, encoding="utf-8") as f:
|
15 |
+
self.prompt_list = [line.rstrip() for line in f]
|
16 |
+
|
17 |
+
if extended_prompt_path is not None:
|
18 |
+
with open(extended_prompt_path, encoding="utf-8") as f:
|
19 |
+
self.extended_prompt_list = [line.rstrip() for line in f]
|
20 |
+
assert len(self.extended_prompt_list) == len(self.prompt_list)
|
21 |
+
else:
|
22 |
+
self.extended_prompt_list = None
|
23 |
+
|
24 |
+
def __len__(self):
|
25 |
+
return len(self.prompt_list)
|
26 |
+
|
27 |
+
def __getitem__(self, idx):
|
28 |
+
batch = {
|
29 |
+
"prompts": self.prompt_list[idx],
|
30 |
+
"idx": idx,
|
31 |
+
}
|
32 |
+
if self.extended_prompt_list is not None:
|
33 |
+
batch["extended_prompts"] = self.extended_prompt_list[idx]
|
34 |
+
return batch
|
35 |
+
|
36 |
+
|
37 |
+
class ODERegressionLMDBDataset(Dataset):
|
38 |
+
def __init__(self, data_path: str, max_pair: int = int(1e8)):
|
39 |
+
self.env = lmdb.open(data_path, readonly=True,
|
40 |
+
lock=False, readahead=False, meminit=False)
|
41 |
+
|
42 |
+
self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
|
43 |
+
self.max_pair = max_pair
|
44 |
+
|
45 |
+
def __len__(self):
|
46 |
+
return min(self.latents_shape[0], self.max_pair)
|
47 |
+
|
48 |
+
def __getitem__(self, idx):
|
49 |
+
"""
|
50 |
+
Outputs:
|
51 |
+
- prompts: List of Strings
|
52 |
+
- latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
|
53 |
+
"""
|
54 |
+
latents = retrieve_row_from_lmdb(
|
55 |
+
self.env,
|
56 |
+
"latents", np.float16, idx, shape=self.latents_shape[1:]
|
57 |
+
)
|
58 |
+
|
59 |
+
if len(latents.shape) == 4:
|
60 |
+
latents = latents[None, ...]
|
61 |
+
|
62 |
+
prompts = retrieve_row_from_lmdb(
|
63 |
+
self.env,
|
64 |
+
"prompts", str, idx
|
65 |
+
)
|
66 |
+
return {
|
67 |
+
"prompts": prompts,
|
68 |
+
"ode_latent": torch.tensor(latents, dtype=torch.float32)
|
69 |
+
}
|
70 |
+
|
71 |
+
|
72 |
+
class ShardingLMDBDataset(Dataset):
|
73 |
+
def __init__(self, data_path: str, max_pair: int = int(1e8)):
|
74 |
+
self.envs = []
|
75 |
+
self.index = []
|
76 |
+
|
77 |
+
for fname in sorted(os.listdir(data_path)):
|
78 |
+
path = os.path.join(data_path, fname)
|
79 |
+
env = lmdb.open(path,
|
80 |
+
readonly=True,
|
81 |
+
lock=False,
|
82 |
+
readahead=False,
|
83 |
+
meminit=False)
|
84 |
+
self.envs.append(env)
|
85 |
+
|
86 |
+
self.latents_shape = [None] * len(self.envs)
|
87 |
+
for shard_id, env in enumerate(self.envs):
|
88 |
+
self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
|
89 |
+
for local_i in range(self.latents_shape[shard_id][0]):
|
90 |
+
self.index.append((shard_id, local_i))
|
91 |
+
|
92 |
+
# print("shard_id ", shard_id, " local_i ", local_i)
|
93 |
+
|
94 |
+
self.max_pair = max_pair
|
95 |
+
|
96 |
+
def __len__(self):
|
97 |
+
return len(self.index)
|
98 |
+
|
99 |
+
def __getitem__(self, idx):
|
100 |
+
"""
|
101 |
+
Outputs:
|
102 |
+
- prompts: List of Strings
|
103 |
+
- latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
|
104 |
+
"""
|
105 |
+
shard_id, local_idx = self.index[idx]
|
106 |
+
|
107 |
+
latents = retrieve_row_from_lmdb(
|
108 |
+
self.envs[shard_id],
|
109 |
+
"latents", np.float16, local_idx,
|
110 |
+
shape=self.latents_shape[shard_id][1:]
|
111 |
+
)
|
112 |
+
|
113 |
+
if len(latents.shape) == 4:
|
114 |
+
latents = latents[None, ...]
|
115 |
+
|
116 |
+
prompts = retrieve_row_from_lmdb(
|
117 |
+
self.envs[shard_id],
|
118 |
+
"prompts", str, local_idx
|
119 |
+
)
|
120 |
+
|
121 |
+
return {
|
122 |
+
"prompts": prompts,
|
123 |
+
"ode_latent": torch.tensor(latents, dtype=torch.float32)
|
124 |
+
}
|
125 |
+
|
126 |
+
|
127 |
+
class TextImagePairDataset(Dataset):
|
128 |
+
def __init__(
|
129 |
+
self,
|
130 |
+
data_dir,
|
131 |
+
transform=None,
|
132 |
+
eval_first_n=-1,
|
133 |
+
pad_to_multiple_of=None
|
134 |
+
):
|
135 |
+
"""
|
136 |
+
Args:
|
137 |
+
data_dir (str): Path to the directory containing:
|
138 |
+
- target_crop_info_*.json (metadata file)
|
139 |
+
- */ (subdirectory containing images with matching aspect ratio)
|
140 |
+
transform (callable, optional): Optional transform to be applied on the image
|
141 |
+
"""
|
142 |
+
self.transform = transform
|
143 |
+
data_dir = Path(data_dir)
|
144 |
+
|
145 |
+
# Find the metadata JSON file
|
146 |
+
metadata_files = list(data_dir.glob('target_crop_info_*.json'))
|
147 |
+
if not metadata_files:
|
148 |
+
raise FileNotFoundError(f"No metadata file found in {data_dir}")
|
149 |
+
if len(metadata_files) > 1:
|
150 |
+
raise ValueError(f"Multiple metadata files found in {data_dir}")
|
151 |
+
|
152 |
+
metadata_path = metadata_files[0]
|
153 |
+
# Extract aspect ratio from metadata filename (e.g. target_crop_info_26-15.json -> 26-15)
|
154 |
+
aspect_ratio = metadata_path.stem.split('_')[-1]
|
155 |
+
|
156 |
+
# Use aspect ratio subfolder for images
|
157 |
+
self.image_dir = data_dir / aspect_ratio
|
158 |
+
if not self.image_dir.exists():
|
159 |
+
raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
|
160 |
+
|
161 |
+
# Load metadata
|
162 |
+
with open(metadata_path, 'r') as f:
|
163 |
+
self.metadata = json.load(f)
|
164 |
+
|
165 |
+
eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
|
166 |
+
self.metadata = self.metadata[:eval_first_n]
|
167 |
+
|
168 |
+
# Verify all images exist
|
169 |
+
for item in self.metadata:
|
170 |
+
image_path = self.image_dir / item['file_name']
|
171 |
+
if not image_path.exists():
|
172 |
+
raise FileNotFoundError(f"Image not found: {image_path}")
|
173 |
+
|
174 |
+
self.dummy_prompt = "DUMMY PROMPT"
|
175 |
+
self.pre_pad_len = len(self.metadata)
|
176 |
+
if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
|
177 |
+
# Duplicate the last entry
|
178 |
+
self.metadata += [self.metadata[-1]] * (
|
179 |
+
pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
|
180 |
+
)
|
181 |
+
|
182 |
+
def __len__(self):
|
183 |
+
return len(self.metadata)
|
184 |
+
|
185 |
+
def __getitem__(self, idx):
|
186 |
+
"""
|
187 |
+
Returns:
|
188 |
+
dict: A dictionary containing:
|
189 |
+
- image: PIL Image
|
190 |
+
- caption: str
|
191 |
+
- target_bbox: list of int [x1, y1, x2, y2]
|
192 |
+
- target_ratio: str
|
193 |
+
- type: str
|
194 |
+
- origin_size: tuple of int (width, height)
|
195 |
+
"""
|
196 |
+
item = self.metadata[idx]
|
197 |
+
|
198 |
+
# Load image
|
199 |
+
image_path = self.image_dir / item['file_name']
|
200 |
+
image = Image.open(image_path).convert('RGB')
|
201 |
+
|
202 |
+
# Apply transform if specified
|
203 |
+
if self.transform:
|
204 |
+
image = self.transform(image)
|
205 |
+
|
206 |
+
return {
|
207 |
+
'image': image,
|
208 |
+
'prompts': item['caption'],
|
209 |
+
'target_bbox': item['target_crop']['target_bbox'],
|
210 |
+
'target_ratio': item['target_crop']['target_ratio'],
|
211 |
+
'type': item['type'],
|
212 |
+
'origin_size': (item['origin_width'], item['origin_height']),
|
213 |
+
'idx': idx
|
214 |
+
}
|
215 |
+
|
216 |
+
|
217 |
+
def cycle(dl):
|
218 |
+
while True:
|
219 |
+
for data in dl:
|
220 |
+
yield data
|
utils/distributed.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datetime import timedelta
|
2 |
+
from functools import partial
|
3 |
+
import os
|
4 |
+
import torch
|
5 |
+
import torch.distributed as dist
|
6 |
+
from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, StateDictType
|
7 |
+
from torch.distributed.fsdp.api import CPUOffload
|
8 |
+
from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
|
9 |
+
|
10 |
+
|
11 |
+
def fsdp_state_dict(model):
|
12 |
+
fsdp_fullstate_save_policy = FullStateDictConfig(
|
13 |
+
offload_to_cpu=True, rank0_only=True
|
14 |
+
)
|
15 |
+
with FSDP.state_dict_type(
|
16 |
+
model, StateDictType.FULL_STATE_DICT, fsdp_fullstate_save_policy
|
17 |
+
):
|
18 |
+
checkpoint = model.state_dict()
|
19 |
+
|
20 |
+
return checkpoint
|
21 |
+
|
22 |
+
|
23 |
+
def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, wrap_strategy="size", min_num_params=int(5e7), transformer_module=None, cpu_offload=False):
|
24 |
+
if mixed_precision:
|
25 |
+
mixed_precision_policy = MixedPrecision(
|
26 |
+
param_dtype=torch.bfloat16,
|
27 |
+
reduce_dtype=torch.float32,
|
28 |
+
buffer_dtype=torch.float32,
|
29 |
+
cast_forward_inputs=False
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
mixed_precision_policy = None
|
33 |
+
|
34 |
+
if wrap_strategy == "transformer":
|
35 |
+
auto_wrap_policy = partial(
|
36 |
+
transformer_auto_wrap_policy,
|
37 |
+
transformer_layer_cls=transformer_module
|
38 |
+
)
|
39 |
+
elif wrap_strategy == "size":
|
40 |
+
auto_wrap_policy = partial(
|
41 |
+
size_based_auto_wrap_policy,
|
42 |
+
min_num_params=min_num_params
|
43 |
+
)
|
44 |
+
else:
|
45 |
+
raise ValueError(f"Invalid wrap strategy: {wrap_strategy}")
|
46 |
+
|
47 |
+
os.environ["NCCL_CROSS_NIC"] = "1"
|
48 |
+
|
49 |
+
sharding_strategy = {
|
50 |
+
"full": ShardingStrategy.FULL_SHARD,
|
51 |
+
"hybrid_full": ShardingStrategy.HYBRID_SHARD,
|
52 |
+
"hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
|
53 |
+
"no_shard": ShardingStrategy.NO_SHARD,
|
54 |
+
}[sharding_strategy]
|
55 |
+
|
56 |
+
module = FSDP(
|
57 |
+
module,
|
58 |
+
auto_wrap_policy=auto_wrap_policy,
|
59 |
+
sharding_strategy=sharding_strategy,
|
60 |
+
mixed_precision=mixed_precision_policy,
|
61 |
+
device_id=torch.cuda.current_device(),
|
62 |
+
limit_all_gathers=True,
|
63 |
+
use_orig_params=True,
|
64 |
+
cpu_offload=CPUOffload(offload_params=cpu_offload),
|
65 |
+
sync_module_states=False # Load ckpt on rank 0 and sync to other ranks
|
66 |
+
)
|
67 |
+
return module
|
68 |
+
|
69 |
+
|
70 |
+
def barrier():
|
71 |
+
if dist.is_initialized():
|
72 |
+
dist.barrier()
|
73 |
+
|
74 |
+
|
75 |
+
def launch_distributed_job(backend: str = "nccl"):
|
76 |
+
rank = int(os.environ["RANK"])
|
77 |
+
local_rank = int(os.environ["LOCAL_RANK"])
|
78 |
+
world_size = int(os.environ["WORLD_SIZE"])
|
79 |
+
host = os.environ["MASTER_ADDR"]
|
80 |
+
port = int(os.environ["MASTER_PORT"])
|
81 |
+
|
82 |
+
if ":" in host: # IPv6
|
83 |
+
init_method = f"tcp://[{host}]:{port}"
|
84 |
+
else: # IPv4
|
85 |
+
init_method = f"tcp://{host}:{port}"
|
86 |
+
dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
|
87 |
+
init_method=init_method, timeout=timedelta(minutes=30))
|
88 |
+
torch.cuda.set_device(local_rank)
|
89 |
+
|
90 |
+
|
91 |
+
class EMA_FSDP:
|
92 |
+
def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
|
93 |
+
self.decay = decay
|
94 |
+
self.shadow = {}
|
95 |
+
self._init_shadow(fsdp_module)
|
96 |
+
|
97 |
+
@torch.no_grad()
|
98 |
+
def _init_shadow(self, fsdp_module):
|
99 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
100 |
+
with FSDP.summon_full_params(fsdp_module, writeback=False):
|
101 |
+
for n, p in fsdp_module.module.named_parameters():
|
102 |
+
self.shadow[n] = p.detach().clone().float().cpu()
|
103 |
+
|
104 |
+
@torch.no_grad()
|
105 |
+
def update(self, fsdp_module):
|
106 |
+
d = self.decay
|
107 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
108 |
+
with FSDP.summon_full_params(fsdp_module, writeback=False):
|
109 |
+
for n, p in fsdp_module.module.named_parameters():
|
110 |
+
self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
|
111 |
+
|
112 |
+
# Optional helpers ---------------------------------------------------
|
113 |
+
def state_dict(self):
|
114 |
+
return self.shadow # picklable
|
115 |
+
|
116 |
+
def load_state_dict(self, sd):
|
117 |
+
self.shadow = {k: v.clone() for k, v in sd.items()}
|
118 |
+
|
119 |
+
def copy_to(self, fsdp_module):
|
120 |
+
# load EMA weights into an (unwrapped) copy of the generator
|
121 |
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
122 |
+
with FSDP.summon_full_params(fsdp_module, writeback=True):
|
123 |
+
for n, p in fsdp_module.module.named_parameters():
|
124 |
+
if n in self.shadow:
|
125 |
+
p.data.copy_(self.shadow[n].to(p.dtype, device=p.device))
|
utils/lmdb.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
|
4 |
+
def get_array_shape_from_lmdb(env, array_name):
|
5 |
+
with env.begin() as txn:
|
6 |
+
image_shape = txn.get(f"{array_name}_shape".encode()).decode()
|
7 |
+
image_shape = tuple(map(int, image_shape.split()))
|
8 |
+
return image_shape
|
9 |
+
|
10 |
+
|
11 |
+
def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
|
12 |
+
"""
|
13 |
+
Store rows of multiple numpy arrays in a single LMDB.
|
14 |
+
Each row is stored separately with a naming convention.
|
15 |
+
"""
|
16 |
+
with env.begin(write=True) as txn:
|
17 |
+
for array_name, array in arrays_dict.items():
|
18 |
+
for i, row in enumerate(array):
|
19 |
+
# Convert row to bytes
|
20 |
+
if isinstance(row, str):
|
21 |
+
row_bytes = row.encode()
|
22 |
+
else:
|
23 |
+
row_bytes = row.tobytes()
|
24 |
+
|
25 |
+
data_key = f'{array_name}_{start_index + i}_data'.encode()
|
26 |
+
|
27 |
+
txn.put(data_key, row_bytes)
|
28 |
+
|
29 |
+
|
30 |
+
def process_data_dict(data_dict, seen_prompts):
|
31 |
+
output_dict = {}
|
32 |
+
|
33 |
+
all_videos = []
|
34 |
+
all_prompts = []
|
35 |
+
for prompt, video in data_dict.items():
|
36 |
+
if prompt in seen_prompts:
|
37 |
+
continue
|
38 |
+
else:
|
39 |
+
seen_prompts.add(prompt)
|
40 |
+
|
41 |
+
video = video.half().numpy()
|
42 |
+
all_videos.append(video)
|
43 |
+
all_prompts.append(prompt)
|
44 |
+
|
45 |
+
if len(all_videos) == 0:
|
46 |
+
return {"latents": np.array([]), "prompts": np.array([])}
|
47 |
+
|
48 |
+
all_videos = np.concatenate(all_videos, axis=0)
|
49 |
+
|
50 |
+
output_dict['latents'] = all_videos
|
51 |
+
output_dict['prompts'] = np.array(all_prompts)
|
52 |
+
|
53 |
+
return output_dict
|
54 |
+
|
55 |
+
|
56 |
+
def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape=None):
|
57 |
+
"""
|
58 |
+
Retrieve a specific row from a specific array in the LMDB.
|
59 |
+
"""
|
60 |
+
data_key = f'{array_name}_{row_index}_data'.encode()
|
61 |
+
|
62 |
+
with lmdb_env.begin() as txn:
|
63 |
+
row_bytes = txn.get(data_key)
|
64 |
+
|
65 |
+
if dtype == str:
|
66 |
+
array = row_bytes.decode()
|
67 |
+
else:
|
68 |
+
array = np.frombuffer(row_bytes, dtype=dtype)
|
69 |
+
|
70 |
+
if shape is not None and len(shape) > 0:
|
71 |
+
array = array.reshape(shape)
|
72 |
+
return array
|
utils/loss.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class DenoisingLoss(ABC):
|
6 |
+
@abstractmethod
|
7 |
+
def __call__(
|
8 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
9 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
10 |
+
alphas_cumprod: torch.Tensor,
|
11 |
+
timestep: torch.Tensor,
|
12 |
+
**kwargs
|
13 |
+
) -> torch.Tensor:
|
14 |
+
"""
|
15 |
+
Base class for denoising loss.
|
16 |
+
Input:
|
17 |
+
- x: the clean data with shape [B, F, C, H, W]
|
18 |
+
- x_pred: the predicted clean data with shape [B, F, C, H, W]
|
19 |
+
- noise: the noise with shape [B, F, C, H, W]
|
20 |
+
- noise_pred: the predicted noise with shape [B, F, C, H, W]
|
21 |
+
- alphas_cumprod: the cumulative product of alphas (defining the noise schedule) with shape [T]
|
22 |
+
- timestep: the current timestep with shape [B, F]
|
23 |
+
"""
|
24 |
+
pass
|
25 |
+
|
26 |
+
|
27 |
+
class X0PredLoss(DenoisingLoss):
|
28 |
+
def __call__(
|
29 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
30 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
31 |
+
alphas_cumprod: torch.Tensor,
|
32 |
+
timestep: torch.Tensor,
|
33 |
+
**kwargs
|
34 |
+
) -> torch.Tensor:
|
35 |
+
return torch.mean((x - x_pred) ** 2)
|
36 |
+
|
37 |
+
|
38 |
+
class VPredLoss(DenoisingLoss):
|
39 |
+
def __call__(
|
40 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
41 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
42 |
+
alphas_cumprod: torch.Tensor,
|
43 |
+
timestep: torch.Tensor,
|
44 |
+
**kwargs
|
45 |
+
) -> torch.Tensor:
|
46 |
+
weights = 1 / (1 - alphas_cumprod[timestep].reshape(*timestep.shape, 1, 1, 1))
|
47 |
+
return torch.mean(weights * (x - x_pred) ** 2)
|
48 |
+
|
49 |
+
|
50 |
+
class NoisePredLoss(DenoisingLoss):
|
51 |
+
def __call__(
|
52 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
53 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
54 |
+
alphas_cumprod: torch.Tensor,
|
55 |
+
timestep: torch.Tensor,
|
56 |
+
**kwargs
|
57 |
+
) -> torch.Tensor:
|
58 |
+
return torch.mean((noise - noise_pred) ** 2)
|
59 |
+
|
60 |
+
|
61 |
+
class FlowPredLoss(DenoisingLoss):
|
62 |
+
def __call__(
|
63 |
+
self, x: torch.Tensor, x_pred: torch.Tensor,
|
64 |
+
noise: torch.Tensor, noise_pred: torch.Tensor,
|
65 |
+
alphas_cumprod: torch.Tensor,
|
66 |
+
timestep: torch.Tensor,
|
67 |
+
**kwargs
|
68 |
+
) -> torch.Tensor:
|
69 |
+
return torch.mean((kwargs["flow_pred"] - (noise - x)) ** 2)
|
70 |
+
|
71 |
+
|
72 |
+
NAME_TO_CLASS = {
|
73 |
+
"x0": X0PredLoss,
|
74 |
+
"v": VPredLoss,
|
75 |
+
"noise": NoisePredLoss,
|
76 |
+
"flow": FlowPredLoss
|
77 |
+
}
|
78 |
+
|
79 |
+
|
80 |
+
def get_denoising_loss(loss_type: str) -> DenoisingLoss:
|
81 |
+
return NAME_TO_CLASS[loss_type]
|
utils/misc.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def set_seed(seed: int, deterministic: bool = False):
|
7 |
+
"""
|
8 |
+
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
seed (`int`):
|
12 |
+
The seed to set.
|
13 |
+
deterministic (`bool`, *optional*, defaults to `False`):
|
14 |
+
Whether to use deterministic algorithms where available. Can slow down training.
|
15 |
+
"""
|
16 |
+
random.seed(seed)
|
17 |
+
np.random.seed(seed)
|
18 |
+
torch.manual_seed(seed)
|
19 |
+
torch.cuda.manual_seed_all(seed)
|
20 |
+
|
21 |
+
if deterministic:
|
22 |
+
torch.use_deterministic_algorithms(True)
|
23 |
+
|
24 |
+
|
25 |
+
def merge_dict_list(dict_list):
|
26 |
+
if len(dict_list) == 1:
|
27 |
+
return dict_list[0]
|
28 |
+
|
29 |
+
merged_dict = {}
|
30 |
+
for k, v in dict_list[0].items():
|
31 |
+
if isinstance(v, torch.Tensor):
|
32 |
+
if v.ndim == 0:
|
33 |
+
merged_dict[k] = torch.stack([d[k] for d in dict_list], dim=0)
|
34 |
+
else:
|
35 |
+
merged_dict[k] = torch.cat([d[k] for d in dict_list], dim=0)
|
36 |
+
else:
|
37 |
+
# for non-tensor values, we just copy the value from the first item
|
38 |
+
merged_dict[k] = v
|
39 |
+
return merged_dict
|