multimodalart (HF Staff) committed on
Commit 0fd2f06 · verified · 1 Parent(s): 029803b

Upload 80 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. LICENSE.md +173 -0
  2. README.md +106 -10
  3. configs/default_config.yaml +20 -0
  4. configs/self_forcing_dmd.yaml +51 -0
  5. configs/self_forcing_sid.yaml +53 -0
  6. demo.py +631 -0
  7. demo_utils/constant.py +41 -0
  8. demo_utils/memory.py +135 -0
  9. demo_utils/taehv.py +313 -0
  10. demo_utils/utils.py +616 -0
  11. demo_utils/vae.py +390 -0
  12. demo_utils/vae_block3.py +291 -0
  13. demo_utils/vae_torch2trt.py +308 -0
  14. images/.gitkeep +0 -0
  15. inference.py +179 -0
  16. model/__init__.py +14 -0
  17. model/base.py +222 -0
  18. model/causvid.py +391 -0
  19. model/diffusion.py +125 -0
  20. model/dmd.py +332 -0
  21. model/gan.py +295 -0
  22. model/ode_regression.py +138 -0
  23. model/sid.py +283 -0
  24. pipeline/__init__.py +13 -0
  25. pipeline/bidirectional_diffusion_inference.py +110 -0
  26. pipeline/bidirectional_inference.py +71 -0
  27. pipeline/causal_diffusion_inference.py +342 -0
  28. pipeline/causal_inference.py +305 -0
  29. pipeline/self_forcing_training.py +267 -0
  30. prompts/MovieGenVideoBench.txt +0 -0
  31. prompts/MovieGenVideoBench_extended.txt +0 -0
  32. prompts/vbench/all_dimension.txt +946 -0
  33. prompts/vbench/all_dimension_extended.txt +0 -0
  34. requirements.txt +40 -0
  35. scripts/create_lmdb_14b_shards.py +101 -0
  36. scripts/create_lmdb_iterative.py +60 -0
  37. scripts/generate_ode_pairs.py +120 -0
  38. setup.py +6 -0
  39. templates/demo.html +615 -0
  40. train.py +47 -0
  41. trainer/__init__.py +11 -0
  42. trainer/diffusion.py +265 -0
  43. trainer/distillation.py +388 -0
  44. trainer/gan.py +464 -0
  45. trainer/ode.py +242 -0
  46. utils/dataset.py +220 -0
  47. utils/distributed.py +125 -0
  48. utils/lmdb.py +72 -0
  49. utils/loss.py +81 -0
  50. utils/misc.py +39 -0
LICENSE.md ADDED
@@ -0,0 +1,173 @@
1
+ # Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ Creative Commons Corporation (“Creative Commons”) is not a law firm and does not provide legal services or legal advice. Distribution of Creative Commons public licenses does not create a lawyer-client or other relationship. Creative Commons makes its licenses and related information available on an “as-is” basis. Creative Commons gives no warranties regarding its licenses, any material licensed under their terms and conditions, or any related information. Creative Commons disclaims all liability for damages resulting from their use to the fullest extent possible.
4
+
5
+ ### Using Creative Commons Public Licenses
6
+
7
+ Creative Commons public licenses provide a standard set of terms and conditions that creators and other rights holders may use to share original works of authorship and other material subject to copyright and certain other rights specified in the public license below. The following considerations are for informational purposes only, are not exhaustive, and do not form part of our licenses.
8
+
9
+ * __Considerations for licensors:__ Our public licenses are intended for use by those authorized to give the public permission to use material in ways otherwise restricted by copyright and certain other rights. Our licenses are irrevocable. Licensors should read and understand the terms and conditions of the license they choose before applying it. Licensors should also secure all rights necessary before applying our licenses so that the public can reuse the material as expected. Licensors should clearly mark any material not subject to the license. This includes other CC-licensed material, or material used under an exception or limitation to copyright. [More considerations for licensors](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensors).
10
+
11
+ * __Considerations for the public:__ By using one of our public licenses, a licensor grants the public permission to use the licensed material under specified terms and conditions. If the licensor’s permission is not necessary for any reason–for example, because of any applicable exception or limitation to copyright–then that use is not regulated by the license. Our licenses grant only permissions under copyright and certain other rights that a licensor has authority to grant. Use of the licensed material may still be restricted for other reasons, including because others have copyright or other rights in the material. A licensor may make special requests, such as asking that all changes be marked or described. Although not required by our licenses, you are encouraged to respect those requests where reasonable. [More considerations for the public](http://wiki.creativecommons.org/Considerations_for_licensors_and_licensees#Considerations_for_licensees).
12
+
13
+ ## Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License
14
+
15
+ By exercising the Licensed Rights (defined below), You accept and agree to be bound by the terms and conditions of this Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International Public License ("Public License"). To the extent this Public License may be interpreted as a contract, You are granted the Licensed Rights in consideration of Your acceptance of these terms and conditions, and the Licensor grants You such rights in consideration of benefits the Licensor receives from making the Licensed Material available under these terms and conditions.
16
+
17
+ ### Section 1 – Definitions.
18
+
19
+ a. __Adapted Material__ means material subject to Copyright and Similar Rights that is derived from or based upon the Licensed Material and in which the Licensed Material is translated, altered, arranged, transformed, or otherwise modified in a manner requiring permission under the Copyright and Similar Rights held by the Licensor. For purposes of this Public License, where the Licensed Material is a musical work, performance, or sound recording, Adapted Material is always produced where the Licensed Material is synched in timed relation with a moving image.
20
+
21
+ b. __Adapter's License__ means the license You apply to Your Copyright and Similar Rights in Your contributions to Adapted Material in accordance with the terms and conditions of this Public License.
22
+
23
+ c. __BY-NC-SA Compatible License__ means a license listed at [creativecommons.org/compatiblelicenses](http://creativecommons.org/compatiblelicenses), approved by Creative Commons as essentially the equivalent of this Public License.
24
+
25
+ d. __Copyright and Similar Rights__ means copyright and/or similar rights closely related to copyright including, without limitation, performance, broadcast, sound recording, and Sui Generis Database Rights, without regard to how the rights are labeled or categorized. For purposes of this Public License, the rights specified in Section 2(b)(1)-(2) are not Copyright and Similar Rights.
26
+
27
+ e. __Effective Technological Measures__ means those measures that, in the absence of proper authority, may not be circumvented under laws fulfilling obligations under Article 11 of the WIPO Copyright Treaty adopted on December 20, 1996, and/or similar international agreements.
28
+
29
+ f. __Exceptions and Limitations__ means fair use, fair dealing, and/or any other exception or limitation to Copyright and Similar Rights that applies to Your use of the Licensed Material.
30
+
31
+ g. __License Elements__ means the license attributes listed in the name of a Creative Commons Public License. The License Elements of this Public License are Attribution, NonCommercial, and ShareAlike.
32
+
33
+ h. __Licensed Material__ means the artistic or literary work, database, or other material to which the Licensor applied this Public License.
34
+
35
+ i. __Licensed Rights__ means the rights granted to You subject to the terms and conditions of this Public License, which are limited to all Copyright and Similar Rights that apply to Your use of the Licensed Material and that the Licensor has authority to license.
36
+
37
+ j. __Licensor__ means the individual(s) or entity(ies) granting rights under this Public License.
38
+
39
+ k. __NonCommercial__ means not primarily intended for or directed towards commercial advantage or monetary compensation. For purposes of this Public License, the exchange of the Licensed Material for other material subject to Copyright and Similar Rights by digital file-sharing or similar means is NonCommercial provided there is no payment of monetary compensation in connection with the exchange.
40
+
41
+ l. __Share__ means to provide material to the public by any means or process that requires permission under the Licensed Rights, such as reproduction, public display, public performance, distribution, dissemination, communication, or importation, and to make material available to the public including in ways that members of the public may access the material from a place and at a time individually chosen by them.
42
+
43
+ m. __Sui Generis Database Rights__ means rights other than copyright resulting from Directive 96/9/EC of the European Parliament and of the Council of 11 March 1996 on the legal protection of databases, as amended and/or succeeded, as well as other essentially equivalent rights anywhere in the world.
44
+
45
+ n. __You__ means the individual or entity exercising the Licensed Rights under this Public License. Your has a corresponding meaning.
46
+
47
+ ### Section 2 – Scope.
48
+
49
+ a. ___License grant.___
50
+
51
+ 1. Subject to the terms and conditions of this Public License, the Licensor hereby grants You a worldwide, royalty-free, non-sublicensable, non-exclusive, irrevocable license to exercise the Licensed Rights in the Licensed Material to:
52
+
53
+ A. reproduce and Share the Licensed Material, in whole or in part, for NonCommercial purposes only; and
54
+
55
+ B. produce, reproduce, and Share Adapted Material for NonCommercial purposes only.
56
+
57
+ 2. __Exceptions and Limitations.__ For the avoidance of doubt, where Exceptions and Limitations apply to Your use, this Public License does not apply, and You do not need to comply with its terms and conditions.
58
+
59
+ 3. __Term.__ The term of this Public License is specified in Section 6(a).
60
+
61
+ 4. __Media and formats; technical modifications allowed.__ The Licensor authorizes You to exercise the Licensed Rights in all media and formats whether now known or hereafter created, and to make technical modifications necessary to do so. The Licensor waives and/or agrees not to assert any right or authority to forbid You from making technical modifications necessary to exercise the Licensed Rights, including technical modifications necessary to circumvent Effective Technological Measures. For purposes of this Public License, simply making modifications authorized by this Section 2(a)(4) never produces Adapted Material.
62
+
63
+ 5. __Downstream recipients.__
64
+
65
+ A. __Offer from the Licensor – Licensed Material.__ Every recipient of the Licensed Material automatically receives an offer from the Licensor to exercise the Licensed Rights under the terms and conditions of this Public License.
66
+
67
+ B. __Additional offer from the Licensor – Adapted Material.__ Every recipient of Adapted Material from You automatically receives an offer from the Licensor to exercise the Licensed Rights in the Adapted Material under the conditions of the Adapter’s License You apply.
68
+
69
+ C. __No downstream restrictions.__ You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, the Licensed Material if doing so restricts exercise of the Licensed Rights by any recipient of the Licensed Material.
70
+
71
+ 6. __No endorsement.__ Nothing in this Public License constitutes or may be construed as permission to assert or imply that You are, or that Your use of the Licensed Material is, connected with, or sponsored, endorsed, or granted official status by, the Licensor or others designated to receive attribution as provided in Section 3(a)(1)(A)(i).
72
+
73
+ b. ___Other rights.___
74
+
75
+ 1. Moral rights, such as the right of integrity, are not licensed under this Public License, nor are publicity, privacy, and/or other similar personality rights; however, to the extent possible, the Licensor waives and/or agrees not to assert any such rights held by the Licensor to the limited extent necessary to allow You to exercise the Licensed Rights, but not otherwise.
76
+
77
+ 2. Patent and trademark rights are not licensed under this Public License.
78
+
79
+ 3. To the extent possible, the Licensor waives any right to collect royalties from You for the exercise of the Licensed Rights, whether directly or through a collecting society under any voluntary or waivable statutory or compulsory licensing scheme. In all other cases the Licensor expressly reserves any right to collect such royalties, including when the Licensed Material is used other than for NonCommercial purposes.
80
+
81
+ ### Section 3 – License Conditions.
82
+
83
+ Your exercise of the Licensed Rights is expressly made subject to the following conditions.
84
+
85
+ a. ___Attribution.___
86
+
87
+ 1. If You Share the Licensed Material (including in modified form), You must:
88
+
89
+ A. retain the following if it is supplied by the Licensor with the Licensed Material:
90
+
91
+ i. identification of the creator(s) of the Licensed Material and any others designated to receive attribution, in any reasonable manner requested by the Licensor (including by pseudonym if designated);
92
+
93
+ ii. a copyright notice;
94
+
95
+ iii. a notice that refers to this Public License;
96
+
97
+ iv. a notice that refers to the disclaimer of warranties;
98
+
99
+ v. a URI or hyperlink to the Licensed Material to the extent reasonably practicable;
100
+
101
+ B. indicate if You modified the Licensed Material and retain an indication of any previous modifications; and
102
+
103
+ C. indicate the Licensed Material is licensed under this Public License, and include the text of, or the URI or hyperlink to, this Public License.
104
+
105
+ 2. You may satisfy the conditions in Section 3(a)(1) in any reasonable manner based on the medium, means, and context in which You Share the Licensed Material. For example, it may be reasonable to satisfy the conditions by providing a URI or hyperlink to a resource that includes the required information.
106
+
107
+ 3. If requested by the Licensor, You must remove any of the information required by Section 3(a)(1)(A) to the extent reasonably practicable.
108
+
109
+ b. ___ShareAlike.___
110
+
111
+ In addition to the conditions in Section 3(a), if You Share Adapted Material You produce, the following conditions also apply.
112
+
113
+ 1. The Adapter’s License You apply must be a Creative Commons license with the same License Elements, this version or later, or a BY-NC-SA Compatible License.
114
+
115
+ 2. You must include the text of, or the URI or hyperlink to, the Adapter's License You apply. You may satisfy this condition in any reasonable manner based on the medium, means, and context in which You Share Adapted Material.
116
+
117
+ 3. You may not offer or impose any additional or different terms or conditions on, or apply any Effective Technological Measures to, Adapted Material that restrict exercise of the rights granted under the Adapter's License You apply.
118
+
119
+ ### Section 4 – Sui Generis Database Rights.
120
+
121
+ Where the Licensed Rights include Sui Generis Database Rights that apply to Your use of the Licensed Material:
122
+
123
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right to extract, reuse, reproduce, and Share all or a substantial portion of the contents of the database for NonCommercial purposes only;
124
+
125
+ b. if You include all or a substantial portion of the database contents in a database in which You have Sui Generis Database Rights, then the database in which You have Sui Generis Database Rights (but not its individual contents) is Adapted Material, including for purposes of Section 3(b); and
126
+
127
+ c. You must comply with the conditions in Section 3(a) if You Share all or a substantial portion of the contents of the database.
128
+
129
+ For the avoidance of doubt, this Section 4 supplements and does not replace Your obligations under this Public License where the Licensed Rights include other Copyright and Similar Rights.
130
+
131
+ ### Section 5 – Disclaimer of Warranties and Limitation of Liability.
132
+
133
+ a. __Unless otherwise separately undertaken by the Licensor, to the extent possible, the Licensor offers the Licensed Material as-is and as-available, and makes no representations or warranties of any kind concerning the Licensed Material, whether express, implied, statutory, or other. This includes, without limitation, warranties of title, merchantability, fitness for a particular purpose, non-infringement, absence of latent or other defects, accuracy, or the presence or absence of errors, whether or not known or discoverable. Where disclaimers of warranties are not allowed in full or in part, this disclaimer may not apply to You.__
134
+
135
+ b. __To the extent possible, in no event will the Licensor be liable to You on any legal theory (including, without limitation, negligence) or otherwise for any direct, special, indirect, incidental, consequential, punitive, exemplary, or other losses, costs, expenses, or damages arising out of this Public License or use of the Licensed Material, even if the Licensor has been advised of the possibility of such losses, costs, expenses, or damages. Where a limitation of liability is not allowed in full or in part, this limitation may not apply to You.__
136
+
137
+ c. The disclaimer of warranties and limitation of liability provided above shall be interpreted in a manner that, to the extent possible, most closely approximates an absolute disclaimer and waiver of all liability.
138
+
139
+ ### Section 6 – Term and Termination.
140
+
141
+ a. This Public License applies for the term of the Copyright and Similar Rights licensed here. However, if You fail to comply with this Public License, then Your rights under this Public License terminate automatically.
142
+
143
+ b. Where Your right to use the Licensed Material has terminated under Section 6(a), it reinstates:
144
+
145
+ 1. automatically as of the date the violation is cured, provided it is cured within 30 days of Your discovery of the violation; or
146
+
147
+ 2. upon express reinstatement by the Licensor.
148
+
149
+ For the avoidance of doubt, this Section 6(b) does not affect any right the Licensor may have to seek remedies for Your violations of this Public License.
150
+
151
+ c. For the avoidance of doubt, the Licensor may also offer the Licensed Material under separate terms or conditions or stop distributing the Licensed Material at any time; however, doing so will not terminate this Public License.
152
+
153
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public License.
154
+
155
+ ### Section 7 – Other Terms and Conditions.
156
+
157
+ a. The Licensor shall not be bound by any additional or different terms or conditions communicated by You unless expressly agreed.
158
+
159
+ b. Any arrangements, understandings, or agreements regarding the Licensed Material not stated herein are separate from and independent of the terms and conditions of this Public License.
160
+
161
+ ### Section 8 – Interpretation.
162
+
163
+ a. For the avoidance of doubt, this Public License does not, and shall not be interpreted to, reduce, limit, restrict, or impose conditions on any use of the Licensed Material that could lawfully be made without permission under this Public License.
164
+
165
+ b. To the extent possible, if any provision of this Public License is deemed unenforceable, it shall be automatically reformed to the minimum extent necessary to make it enforceable. If the provision cannot be reformed, it shall be severed from this Public License without affecting the enforceability of the remaining terms and conditions.
166
+
167
+ c. No term or condition of this Public License will be waived and no failure to comply consented to unless expressly agreed to by the Licensor.
168
+
169
+ d. Nothing in this Public License constitutes or may be interpreted as a limitation upon, or waiver of, any privileges and immunities that apply to the Licensor or You, including from the legal processes of any jurisdiction or authority.
170
+
171
+ > Creative Commons is not a party to its public licenses. Notwithstanding, Creative Commons may elect to apply one of its public licenses to material it publishes and in those instances will be considered the “Licensor.” The text of the Creative Commons public licenses is dedicated to the public domain under the CC0 Public Domain Dedication. Except for the limited purpose of indicating that material is shared under a Creative Commons public license or as otherwise permitted by the Creative Commons policies published at creativecommons.org/policies, Creative Commons does not authorize the use of the trademark “Creative Commons” or any other trademark or logo of Creative Commons without its prior written consent including, without limitation, in connection with any unauthorized modifications to any of its public licenses or any other arrangements, understandings, or agreements concerning use of licensed material. For the avoidance of doubt, this paragraph does not form part of the public licenses.
172
+ >
173
+ > Creative Commons may be contacted at creativecommons.org
README.md CHANGED
@@ -1,13 +1,109 @@
1
  ---
2
- title: Self Forcing
3
- emoji: 🦀
4
- colorFrom: purple
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.34.1
8
- app_file: app.py
9
- pinned: false
10
- short_description: Real-Time video generation with Self-Forcing Wan 1.3B
11
  ---
12
 
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ <p align="center">
2
+ <h1 align="center">Self Forcing</h1>
3
+ <h3 align="center">Bridging the Train-Test Gap in Autoregressive Video Diffusion</h3>
4
+ </p>
5
+ <p align="center">
6
+ <p align="center">
7
+ <a href="https://www.xunhuang.me/">Xun Huang</a><sup>1</sup>
8
+ ·
9
+ <a href="https://zhengqili.github.io/">Zhengqi Li</a><sup>1</sup>
10
+ ·
11
+ <a href="https://guandehe.github.io/">Guande He</a><sup>2</sup>
12
+ ·
13
+ <a href="https://mingyuanzhou.github.io/">Mingyuan Zhou</a><sup>2</sup>
14
+ ·
15
+ <a href="https://research.adobe.com/person/eli-shechtman/">Eli Shechtman</a><sup>1</sup><br>
16
+ <sup>1</sup>Adobe Research <sup>2</sup>UT Austin
17
+ </p>
18
+ <h3 align="center"><a href="https://arxiv.org/abs/2506.08009">Paper</a> | <a href="https://self-forcing.github.io">Website</a> | <a href="https://huggingface.co/gdhe17/Self-Forcing/tree/main">Models (HuggingFace)</a></h3>
19
+ </p>
20
+
21
  ---
22
+
23
+ Self Forcing trains autoregressive video diffusion models by **simulating the inference process during training**, performing autoregressive rollout with KV caching. It resolves the train-test distribution mismatch and enables **real-time, streaming video generation on a single RTX 4090** while matching the quality of state-of-the-art diffusion models.
24
+
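To make the mechanism concrete, here is a schematic of the rollout that Self Forcing simulates at training time and runs at inference time. This is a minimal sketch with a hypothetical `model` interface (`init_kv_cache`, `denoise`, `update_kv_cache` are placeholders, not functions from this repo); the actual loop lives in `pipeline/causal_inference.py` and is mirrored in `demo.py` further down in this diff.

```python
import torch

def blockwise_rollout(model, noise, denoising_step_list=(1000, 750, 500, 250),
                      frames_per_block=3):
    """Generate latent frames block by block, reusing a KV cache as context."""
    batch, num_frames = noise.shape[:2]
    kv_cache = model.init_kv_cache(batch)                 # placeholder helper
    latents = torch.zeros_like(noise)
    for start in range(0, num_frames, frames_per_block):
        block = noise[:, start:start + frames_per_block]
        for t in denoising_step_list:                     # few-step denoising per block
            block = model.denoise(block, t, kv_cache)     # attends to cached earlier blocks
        latents[:, start:start + frames_per_block] = block
        model.update_kv_cache(block, kv_cache)            # later blocks condition on this one
    return latents
```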
25
  ---
26
 
27
+
28
+ https://github.com/user-attachments/assets/7548c2db-fe03-4ba8-8dd3-52d2c6160739
29
+
30
+
31
+ ## Requirements
32
+ We tested this repo on the following setup:
33
+ * NVIDIA GPU with at least 24 GB of memory (RTX 4090, A100, and H100 have been tested).
34
+ * Linux operating system.
35
+ * 64 GB RAM.
36
+
37
+ Other hardware setups may also work but have not been tested.
38
+
39
+ ## Installation
40
+ Create a conda environment and install dependencies:
41
+ ```
42
+ conda create -n self_forcing python=3.10 -y
43
+ conda activate self_forcing
44
+ pip install -r requirements.txt
45
+ pip install flash-attn --no-build-isolation
46
+ python setup.py develop
47
+ ```
48
+
49
+ ## Quick Start
50
+ ### Download checkpoints
51
+ ```
52
+ huggingface-cli download Wan-AI/Wan2.1-T2V-1.3B --local-dir-use-symlinks False --local-dir wan_models/Wan2.1-T2V-1.3B
53
+ huggingface-cli download gdhe17/Self-Forcing checkpoints/self_forcing_dmd.pt --local-dir .
54
+ ```
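If you prefer to script the downloads, the same files can be fetched with the `huggingface_hub` Python API; this is a minimal sketch equivalent to the CLI commands above.

```python
from huggingface_hub import snapshot_download, hf_hub_download

# Wan2.1-T2V-1.3B base model (tokenizer, VAE, text encoder, etc.)
snapshot_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    local_dir="wan_models/Wan2.1-T2V-1.3B",
)

# Self-Forcing DMD checkpoint into ./checkpoints/
hf_hub_download(
    repo_id="gdhe17/Self-Forcing",
    filename="checkpoints/self_forcing_dmd.pt",
    local_dir=".",
)
```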
55
+
56
+ ### GUI demo
57
+ ```
58
+ python demo.py
59
+ ```
60
+ Note:
61
+ * **Our model works better with long, detailed prompts**, since it is trained on such prompts. We will integrate prompt extension into the codebase (similar to [Wan2.1](https://github.com/Wan-Video/Wan2.1/tree/main?tab=readme-ov-file#2-using-prompt-extention)) in the future. For now, we recommend using a third-party LLM (such as GPT-4o) to extend your prompt before providing it to the model; see the sketch after these notes.
62
+ * You may want to adjust the FPS so that playback is smooth on your device.
63
+ * Speed can be improved by enabling `torch.compile`, [TAEHV-VAE](https://github.com/madebyollin/taehv/), or FP8 linear layers, although the latter two options may sacrifice quality. We recommend using `torch.compile` whenever possible and enabling TAEHV-VAE only if a further speedup is needed.
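A minimal sketch of the prompt extension suggested in the first note, assuming the `openai` package and an `OPENAI_API_KEY` in your environment; the instruction text below is illustrative, not the one used to train the model.

```python
from openai import OpenAI

def extend_prompt(short_prompt: str) -> str:
    """Ask an LLM to expand a short prompt into a long, detailed description."""
    client = OpenAI()
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": (
                "Rewrite the user's video prompt as a single, long, highly detailed "
                "description of the scene, subjects, motion, lighting, and camera."
            )},
            {"role": "user", "content": short_prompt},
        ],
    )
    return response.choices[0].message.content

print(extend_prompt("A red fox running through snow at dawn"))
```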
64
+
65
+ ### CLI Inference
66
+ Example inference script using the chunk-wise autoregressive checkpoint trained with DMD:
67
+ ```
68
+ python inference.py \
69
+ --config_path configs/self_forcing_dmd.yaml \
70
+ --output_folder videos/self_forcing_dmd \
71
+ --checkpoint_path checkpoints/self_forcing_dmd.pt \
72
+ --data_path prompts/MovieGenVideoBench_extended.txt \
73
+ --use_ema
74
+ ```
75
+ Other config files and their corresponding checkpoints can be found in the [configs](configs) folder and in our [Hugging Face repo](https://huggingface.co/gdhe17/Self-Forcing/tree/main/checkpoints).
76
+
77
+ ## Training
78
+ ### Download text prompts and ODE initialized checkpoint
79
+ ```
80
+ huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt --local-dir .
81
+ huggingface-cli download gdhe17/Self-Forcing vidprom_filtered_extended.txt --local-dir prompts
82
+ ```
83
+ Note: our training algorithm (except for the GAN version) is data-free (**no video data is needed**). For now, we directly provide the ODE initialization checkpoint; instructions for performing ODE initialization yourself (which is identical to the process described in the [CausVid](https://github.com/tianweiy/CausVid) repo) will be added in the future.
84
+
85
+ ### Self Forcing Training with DMD
86
+ ```
87
+ torchrun --nnodes=8 --nproc_per_node=8 --rdzv_id=5235 \
88
+ --rdzv_backend=c10d \
89
+ --rdzv_endpoint $MASTER_ADDR \
90
+ train.py \
91
+ --config_path configs/self_forcing_dmd.yaml \
92
+ --logdir logs/self_forcing_dmd \
93
+ --disable-wandb
94
+ ```
95
+ Our training run uses 600 iterations and completes in under 2 hours on 64 H100 GPUs. With gradient accumulation, it should be possible to reproduce the results in under 16 hours on 8 H100 GPUs (a generic sketch of the accumulation pattern follows).
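Gradient accumulation is not implemented in this commit; the snippet below is only a generic PyTorch sketch of the pattern, with a toy model standing in for the distillation trainer. With 8 GPUs and `batch_size: 1`, accumulating 8 micro-batches per optimizer step keeps the effective `total_batch_size` of 64 used in the provided configs.

```python
import torch

# Toy stand-ins for the real trainer (illustration only).
model = torch.nn.Linear(16, 16)
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-6, betas=(0.0, 0.999))
accum_steps = 8  # 8 micro-batches of size 1 per GPU -> 64 effective across 8 GPUs

optimizer.zero_grad()
for step in range(64):
    x = torch.randn(1, 16)            # micro-batch of size 1, as in the configs
    loss = model(x).pow(2).mean()     # placeholder loss
    (loss / accum_steps).backward()   # scale so accumulated gradients average correctly
    if (step + 1) % accum_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
```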
96
+
97
+ ## Acknowledgements
98
+ This codebase is built on top of the open-source implementation of [CausVid](https://github.com/tianweiy/CausVid) by [Tianwei Yin](https://tianweiy.github.io/) and the [Wan2.1](https://github.com/Wan-Video/Wan2.1) repo.
99
+
100
+ ## Citation
101
+ If you find this codebase useful for your research, please cite our paper:
102
+ ```
103
+ @article{huang2025selfforcing,
104
+ title={Self Forcing: Bridging the Train-Test Gap in Autoregressive Video Diffusion},
105
+ author={Huang, Xun and Li, Zhengqi and He, Guande and Zhou, Mingyuan and Shechtman, Eli},
106
+ journal={arXiv preprint arXiv:2506.08009},
107
+ year={2025}
108
+ }
109
+ ```
configs/default_config.yaml ADDED
@@ -0,0 +1,20 @@
1
+ independent_first_frame: false
2
+ warp_denoising_step: false
3
+ weight_decay: 0.01
4
+ same_step_across_blocks: true
5
+ discriminator_lr_multiplier: 1.0
6
+ last_step_only: false
7
+ i2v: false
8
+ num_training_frames: 21
9
+ gc_interval: 100
10
+ context_noise: 0
11
+ causal: true
12
+
13
+ ckpt_step: 0
14
+ prompt_name: MovieGenVideoBench
15
+ prompt_path: prompts/MovieGenVideoBench.txt
16
+ eval_first_n: 64
17
+ num_samples: 1
18
+ height: 480
19
+ width: 832
20
+ num_frames: 81
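These defaults are not used on their own; `demo.py` (and the trainers) overlay an experiment config on top of this file with OmegaConf, so any key here can be overridden by, e.g., `configs/self_forcing_dmd.yaml`. A minimal sketch of that merge:

```python
from omegaconf import OmegaConf

# Experiment config takes precedence over the shared defaults (same pattern as demo.py).
default_config = OmegaConf.load("configs/default_config.yaml")
config = OmegaConf.merge(default_config, OmegaConf.load("configs/self_forcing_dmd.yaml"))

print(config.num_frames)      # 81, from default_config.yaml
print(config.timestep_shift)  # 5.0, from self_forcing_dmd.yaml
```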
configs/self_forcing_dmd.yaml ADDED
@@ -0,0 +1,51 @@
1
+ generator_ckpt: checkpoints/ode_init.pt
2
+ generator_fsdp_wrap_strategy: size
3
+ real_score_fsdp_wrap_strategy: size
4
+ fake_score_fsdp_wrap_strategy: size
5
+ real_name: Wan2.1-T2V-14B
6
+ text_encoder_fsdp_wrap_strategy: size
7
+ denoising_step_list:
8
+ - 1000
9
+ - 750
10
+ - 500
11
+ - 250
12
+ warp_denoising_step: true # remove the `- 0` entry from denoising_step_list when warp_denoising_step is true
13
+ ts_schedule: false
14
+ num_train_timestep: 1000
15
+ timestep_shift: 5.0
16
+ guidance_scale: 3.0
17
+ denoising_loss_type: flow
18
+ mixed_precision: true
19
+ seed: 0
20
+ wandb_host: WANDB_HOST
21
+ wandb_key: WANDB_KEY
22
+ wandb_entity: WANDB_ENTITY
23
+ wandb_project: WANDB_PROJECT
24
+ sharding_strategy: hybrid_full
25
+ lr: 2.0e-06
26
+ lr_critic: 4.0e-07
27
+ beta1: 0.0
28
+ beta2: 0.999
29
+ beta1_critic: 0.0
30
+ beta2_critic: 0.999
31
+ data_path: prompts/vidprom_filtered_extended.txt
32
+ batch_size: 1
33
+ ema_weight: 0.99
34
+ ema_start_step: 200
35
+ total_batch_size: 64
36
+ log_iters: 50
37
+ negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' # standard Wan2.1 Chinese negative prompt (roughly: oversaturated, overexposed, static, blurry details, subtitles, worst/low quality, JPEG artifacts, deformed anatomy, extra or fused fingers, frozen frames, cluttered background, walking backwards)
38
+ dfake_gen_update_ratio: 5
39
+ image_or_video_shape:
40
+ - 1
41
+ - 21
42
+ - 16
43
+ - 60
44
+ - 104
45
+ distribution_loss: dmd
46
+ trainer: score_distillation
47
+ gradient_checkpointing: true
48
+ num_frame_per_block: 3
49
+ load_raw_video: false
50
+ model_kwargs:
51
+ timestep_shift: 5.0
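For intuition about `timestep_shift` and `warp_denoising_step`: when warping is enabled, the four timesteps above are pushed toward the high-noise end of the schedule. The sketch below uses the sigma shift common to flow-matching schedulers, `sigma' = shift * sigma / (1 + (shift - 1) * sigma)`; the exact formula used by this repo is an assumption here, not taken from this commit.

```python
# Warp timesteps in [0, 1000] with shift = 5.0 (assumed flow-matching sigma shift).
shift = 5.0
denoising_step_list = [1000, 750, 500, 250]

def warp(t: float, shift: float, t_max: float = 1000.0) -> float:
    sigma = t / t_max
    return t_max * shift * sigma / (1 + (shift - 1) * sigma)

# Warped steps cluster near t_max, spending more of the 4-step budget on high noise.
print([round(warp(t, shift), 1) for t in denoising_step_list])
```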
configs/self_forcing_sid.yaml ADDED
@@ -0,0 +1,53 @@
1
+ generator_ckpt: checkpoints/ode_init.pt
2
+ generator_fsdp_wrap_strategy: size
3
+ real_score_fsdp_wrap_strategy: size
4
+ fake_score_fsdp_wrap_strategy: size
5
+ real_name: Wan2.1-T2V-1.3B
6
+ text_encoder_fsdp_wrap_strategy: size
7
+ denoising_step_list:
8
+ - 1000
9
+ - 750
10
+ - 500
11
+ - 250
12
+ warp_denoising_step: true # remove the `- 0` entry from denoising_step_list when warp_denoising_step is true
13
+ ts_schedule: false
14
+ num_train_timestep: 1000
15
+ timestep_shift: 5.0
16
+ guidance_scale: 3.0
17
+ denoising_loss_type: flow
18
+ mixed_precision: true
19
+ seed: 0
20
+ wandb_host: WANDB_HOST
21
+ wandb_key: WANDB_KEY
22
+ wandb_entity: WANDB_ENTITY
23
+ wandb_project: WANDB_PROJECT
24
+ sharding_strategy: hybrid_full
25
+ lr: 2.0e-06
26
+ lr_critic: 2.0e-06
27
+ beta1: 0.0
28
+ beta2: 0.999
29
+ beta1_critic: 0.0
30
+ beta2_critic: 0.999
31
+ weight_decay: 0.0
32
+ data_path: prompts/vidprom_filtered_extended.txt
33
+ batch_size: 1
34
+ sid_alpha: 1.0
35
+ ema_weight: 0.99
36
+ ema_start_step: 200
37
+ total_batch_size: 64
38
+ log_iters: 50
39
+ negative_prompt: '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走' # standard Wan2.1 Chinese negative prompt (roughly: oversaturated, overexposed, static, blurry details, subtitles, worst/low quality, JPEG artifacts, deformed anatomy, extra or fused fingers, frozen frames, cluttered background, walking backwards)
40
+ dfake_gen_update_ratio: 5
41
+ image_or_video_shape:
42
+ - 1
43
+ - 21
44
+ - 16
45
+ - 60
46
+ - 104
47
+ distribution_loss: dmd
48
+ trainer: score_distillation
49
+ gradient_checkpointing: true
50
+ num_frame_per_block: 3
51
+ load_raw_video: false
52
+ model_kwargs:
53
+ timestep_shift: 5.0
demo.py ADDED
@@ -0,0 +1,631 @@
1
+ """
2
+ Demo for Self-Forcing.
3
+ """
4
+
5
+ import os
6
+ import re
7
+ import random
8
+ import time
9
+ import base64
10
+ import argparse
11
+ import hashlib
12
+ import subprocess
13
+ import urllib.request
14
+ from io import BytesIO
15
+ from PIL import Image
16
+ import numpy as np
17
+ import torch
18
+ from omegaconf import OmegaConf
19
+ from flask import Flask, render_template, jsonify
20
+ from flask_socketio import SocketIO, emit
21
+ import queue
22
+ from threading import Thread, Event
23
+
24
+ from pipeline import CausalInferencePipeline
25
+ from demo_utils.constant import ZERO_VAE_CACHE
26
+ from demo_utils.vae_block3 import VAEDecoderWrapper
27
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
28
+ from demo_utils.utils import generate_timestamp
29
+ from demo_utils.memory import gpu, get_cuda_free_memory_gb, DynamicSwapInstaller, move_model_to_device_with_memory_preservation
30
+
31
+ # Parse arguments
32
+ parser = argparse.ArgumentParser()
33
+ parser.add_argument('--port', type=int, default=5001)
34
+ parser.add_argument('--host', type=str, default='0.0.0.0')
35
+ parser.add_argument("--checkpoint_path", type=str, default='./checkpoints/self_forcing_dmd.pt')
36
+ parser.add_argument("--config_path", type=str, default='./configs/self_forcing_dmd.yaml')
37
+ parser.add_argument('--trt', action='store_true')
38
+ args = parser.parse_args()
39
+
40
+ print(f'Free VRAM {get_cuda_free_memory_gb(gpu)} GB')
41
+ low_memory = get_cuda_free_memory_gb(gpu) < 40
42
+
43
+ # Load models
44
+ config = OmegaConf.load(args.config_path)
45
+ default_config = OmegaConf.load("configs/default_config.yaml")
46
+ config = OmegaConf.merge(default_config, config)
47
+
48
+ text_encoder = WanTextEncoder()
49
+
50
+ # Global variables for dynamic model switching
51
+ current_vae_decoder = None
52
+ current_use_taehv = False
53
+ fp8_applied = False
54
+ torch_compile_applied = False
55
+ global frame_number
56
+ frame_number = 0
57
+ anim_name = ""
58
+ frame_rate = 6
59
+
60
+ def initialize_vae_decoder(use_taehv=False, use_trt=False):
61
+ """Initialize VAE decoder based on the selected option"""
62
+ global current_vae_decoder, current_use_taehv
63
+
64
+ if use_trt:
65
+ from demo_utils.vae import VAETRTWrapper
66
+ current_vae_decoder = VAETRTWrapper()
67
+ return current_vae_decoder
68
+
69
+ if use_taehv:
70
+ from demo_utils.taehv import TAEHV
71
+ # Check if taew2_1.pth exists in checkpoints folder, download if missing
72
+ taehv_checkpoint_path = "checkpoints/taew2_1.pth"
73
+ if not os.path.exists(taehv_checkpoint_path):
74
+ print(f"taew2_1.pth not found in checkpoints folder {taehv_checkpoint_path}. Downloading...")
75
+ os.makedirs("checkpoints", exist_ok=True)
76
+ download_url = "https://github.com/madebyollin/taehv/raw/main/taew2_1.pth"
77
+ try:
78
+ urllib.request.urlretrieve(download_url, taehv_checkpoint_path)
79
+ print(f"Successfully downloaded taew2_1.pth to {taehv_checkpoint_path}")
80
+ except Exception as e:
81
+ print(f"Failed to download taew2_1.pth: {e}")
82
+ raise
83
+
84
+ class DotDict(dict):
85
+ __getattr__ = dict.__getitem__
86
+ __setattr__ = dict.__setitem__
87
+
88
+ class TAEHVDiffusersWrapper(torch.nn.Module):
89
+ def __init__(self):
90
+ super().__init__()
91
+ self.dtype = torch.float16
92
+ self.taehv = TAEHV(checkpoint_path=taehv_checkpoint_path).to(self.dtype)
93
+ self.config = DotDict(scaling_factor=1.0)
94
+
95
+ def decode(self, latents, return_dict=None):
96
+ # n, c, t, h, w = latents.shape
97
+ # low-memory, set parallel=True for faster + higher memory
98
+ return self.taehv.decode_video(latents, parallel=False).mul_(2).sub_(1)
99
+
100
+ current_vae_decoder = TAEHVDiffusersWrapper()
101
+ else:
102
+ current_vae_decoder = VAEDecoderWrapper()
103
+ vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
104
+ decoder_state_dict = {}
105
+ for key, value in vae_state_dict.items():
106
+ if 'decoder.' in key or 'conv2' in key:
107
+ decoder_state_dict[key] = value
108
+ current_vae_decoder.load_state_dict(decoder_state_dict)
109
+
110
+ current_vae_decoder.eval()
111
+ current_vae_decoder.to(dtype=torch.float16)
112
+ current_vae_decoder.requires_grad_(False)
113
+ current_vae_decoder.to(gpu)
114
+ current_use_taehv = use_taehv
115
+
116
+ print(f"✅ VAE decoder initialized with {'TAEHV' if use_taehv else 'default VAE'}")
117
+ return current_vae_decoder
118
+
119
+
120
+ # Initialize with default VAE
121
+ vae_decoder = initialize_vae_decoder(use_taehv=False, use_trt=args.trt)
122
+
123
+ transformer = WanDiffusionWrapper(is_causal=True)
124
+ state_dict = torch.load(args.checkpoint_path, map_location="cpu")
125
+ transformer.load_state_dict(state_dict['generator_ema'])
126
+
127
+ text_encoder.eval()
128
+ transformer.eval()
129
+
130
+ transformer.to(dtype=torch.float16)
131
+ text_encoder.to(dtype=torch.bfloat16)
132
+
133
+ text_encoder.requires_grad_(False)
134
+ transformer.requires_grad_(False)
135
+
136
+ pipeline = CausalInferencePipeline(
137
+ config,
138
+ device=gpu,
139
+ generator=transformer,
140
+ text_encoder=text_encoder,
141
+ vae=vae_decoder
142
+ )
143
+
144
+ if low_memory:
145
+ DynamicSwapInstaller.install_model(text_encoder, device=gpu)
146
+ else:
147
+ text_encoder.to(gpu)
148
+ transformer.to(gpu)
149
+
150
+ # Flask and SocketIO setup
151
+ app = Flask(__name__)
152
+ app.config['SECRET_KEY'] = 'frontend_buffered_demo'
153
+ socketio = SocketIO(app, cors_allowed_origins="*")
154
+
155
+ generation_active = False
156
+ stop_event = Event()
157
+ frame_send_queue = queue.Queue()
158
+ sender_thread = None
159
+ models_compiled = False
160
+
161
+
162
+ def tensor_to_base64_frame(frame_tensor):
163
+ """Convert a single frame tensor to base64 image string."""
164
+ global frame_number, anim_name
165
+ # Clamp and normalize to 0-255
166
+ frame = torch.clamp(frame_tensor.float(), -1., 1.) * 127.5 + 127.5
167
+ frame = frame.to(torch.uint8).cpu().numpy()
168
+
169
+ # CHW -> HWC
170
+ if len(frame.shape) == 3:
171
+ frame = np.transpose(frame, (1, 2, 0))
172
+
173
+ # Convert to PIL Image
174
+ if frame.shape[2] == 3: # RGB
175
+ image = Image.fromarray(frame, 'RGB')
176
+ else: # Handle other formats
177
+ image = Image.fromarray(frame)
178
+
179
+ # Convert to base64
180
+ buffer = BytesIO()
181
+ image.save(buffer, format='JPEG', quality=100)
182
+ if not os.path.exists("./images/%s" % anim_name):
183
+ os.makedirs("./images/%s" % anim_name)
184
+ frame_number += 1
185
+ image.save("./images/%s/%s_%03d.jpg" % (anim_name, anim_name, frame_number))
186
+ img_str = base64.b64encode(buffer.getvalue()).decode()
187
+ return f"data:image/jpeg;base64,{img_str}"
188
+
189
+
190
+ def frame_sender_worker():
191
+ """Background thread that processes frame send queue non-blocking."""
192
+ global frame_send_queue, generation_active, stop_event
193
+
194
+ print("📡 Frame sender thread started")
195
+
196
+ while True:
197
+ frame_data = None
198
+ try:
199
+ # Get frame data from queue
200
+ frame_data = frame_send_queue.get(timeout=1.0)
201
+
202
+ if frame_data is None: # Shutdown signal
203
+ frame_send_queue.task_done() # Mark shutdown signal as done
204
+ break
205
+
206
+ frame_tensor, frame_index, block_index, job_id = frame_data
207
+
208
+ # Convert tensor to base64
209
+ base64_frame = tensor_to_base64_frame(frame_tensor)
210
+
211
+ # Send via SocketIO
212
+ try:
213
+ socketio.emit('frame_ready', {
214
+ 'data': base64_frame,
215
+ 'frame_index': frame_index,
216
+ 'block_index': block_index,
217
+ 'job_id': job_id
218
+ })
219
+ except Exception as e:
220
+ print(f"⚠️ Failed to send frame {frame_index}: {e}")
221
+
222
+ frame_send_queue.task_done()
223
+
224
+ except queue.Empty:
225
+ # Check if we should continue running
226
+ if not generation_active and frame_send_queue.empty():
227
+ break
228
+ except Exception as e:
229
+ print(f"❌ Frame sender error: {e}")
230
+ # Make sure to mark task as done even if there's an error
231
+ if frame_data is not None:
232
+ try:
233
+ frame_send_queue.task_done()
234
+ except Exception as e:
235
+ print(f"❌ Failed to mark frame task as done: {e}")
236
+ break
237
+
238
+ print("📡 Frame sender thread stopped")
239
+
240
+
241
+ @torch.no_grad()
242
+ def generate_video_stream(prompt, seed, enable_torch_compile=False, enable_fp8=False, use_taehv=False):
243
+ """Generate video and push frames immediately to frontend."""
244
+ global generation_active, stop_event, frame_send_queue, sender_thread, models_compiled, torch_compile_applied, fp8_applied, current_vae_decoder, current_use_taehv, frame_rate, anim_name
245
+
246
+ try:
247
+ generation_active = True
248
+ stop_event.clear()
249
+ job_id = generate_timestamp()
250
+
251
+ # Start frame sender thread if not already running
252
+ if sender_thread is None or not sender_thread.is_alive():
253
+ sender_thread = Thread(target=frame_sender_worker, daemon=True)
254
+ sender_thread.start()
255
+
256
+ # Emit progress updates
257
+ def emit_progress(message, progress):
258
+ try:
259
+ socketio.emit('progress', {
260
+ 'message': message,
261
+ 'progress': progress,
262
+ 'job_id': job_id
263
+ })
264
+ except Exception as e:
265
+ print(f"❌ Failed to emit progress: {e}")
266
+
267
+ emit_progress('Starting generation...', 0)
268
+
269
+ # Handle VAE decoder switching
270
+ if use_taehv != current_use_taehv:
271
+ emit_progress('Switching VAE decoder...', 2)
272
+ print(f"🔄 Switching VAE decoder to {'TAEHV' if use_taehv else 'default VAE'}")
273
+ current_vae_decoder = initialize_vae_decoder(use_taehv=use_taehv)
274
+ # Update pipeline with new VAE decoder
275
+ pipeline.vae = current_vae_decoder
276
+
277
+ # Handle FP8 quantization
278
+ if enable_fp8 and not fp8_applied:
279
+ emit_progress('Applying FP8 quantization...', 3)
280
+ print("🔧 Applying FP8 quantization to transformer")
281
+ from torchao.quantization.quant_api import quantize_, Float8DynamicActivationFloat8WeightConfig, PerTensor
282
+ quantize_(transformer, Float8DynamicActivationFloat8WeightConfig(granularity=PerTensor()))
283
+ fp8_applied = True
284
+
285
+ # Text encoding
286
+ emit_progress('Encoding text prompt...', 8)
287
+ conditional_dict = text_encoder(text_prompts=[prompt])
288
+ for key, value in conditional_dict.items():
289
+ conditional_dict[key] = value.to(dtype=torch.float16)
290
+ if low_memory:
291
+ gpu_memory_preservation = get_cuda_free_memory_gb(gpu) + 5
292
+ move_model_to_device_with_memory_preservation(
293
+ text_encoder,target_device=gpu, preserved_memory_gb=gpu_memory_preservation)
294
+
295
+ # Handle torch.compile if enabled
296
+ torch_compile_applied = enable_torch_compile
297
+ if enable_torch_compile and not models_compiled:
298
+ # Compile transformer and decoder
299
+ transformer.compile(mode="max-autotune-no-cudagraphs")
300
+ if not current_use_taehv and not low_memory and not args.trt:
301
+ current_vae_decoder.compile(mode="max-autotune-no-cudagraphs")
302
+
303
+ # Initialize generation
304
+ emit_progress('Initializing generation...', 12)
305
+
306
+ rnd = torch.Generator(gpu).manual_seed(seed)
307
+ # all_latents = torch.zeros([1, 21, 16, 60, 104], device=gpu, dtype=torch.bfloat16)
308
+
309
+ pipeline._initialize_kv_cache(batch_size=1, dtype=torch.float16, device=gpu)
310
+ pipeline._initialize_crossattn_cache(batch_size=1, dtype=torch.float16, device=gpu)
311
+
312
+ noise = torch.randn([1, 21, 16, 60, 104], device=gpu, dtype=torch.float16, generator=rnd)
313
+
314
+ # Generation parameters
315
+ num_blocks = 7
316
+ current_start_frame = 0
317
+ num_input_frames = 0
318
+ all_num_frames = [pipeline.num_frame_per_block] * num_blocks
319
+ if current_use_taehv:
320
+ vae_cache = None
321
+ else:
322
+ vae_cache = ZERO_VAE_CACHE
323
+ for i in range(len(vae_cache)):
324
+ vae_cache[i] = vae_cache[i].to(device=gpu, dtype=torch.float16)
325
+
326
+ total_frames_sent = 0
327
+ generation_start_time = time.time()
328
+
329
+ emit_progress('Generating frames... (frontend handles timing)', 15)
330
+
331
+ for idx, current_num_frames in enumerate(all_num_frames):
332
+ if not generation_active or stop_event.is_set():
333
+ break
334
+
335
+ progress = int(((idx + 1) / len(all_num_frames)) * 80) + 15
336
+
337
+ # Special message for first block with torch.compile
338
+ if idx == 0 and torch_compile_applied and not models_compiled:
339
+ emit_progress(
340
+ f'Processing block 1/{len(all_num_frames)} - Compiling models (may take 5-10 minutes)...', progress)
341
+ print(f"🔥 Processing block {idx+1}/{len(all_num_frames)}")
342
+ models_compiled = True
343
+ else:
344
+ emit_progress(f'Processing block {idx+1}/{len(all_num_frames)}...', progress)
345
+ print(f"🔄 Processing block {idx+1}/{len(all_num_frames)}")
346
+
347
+ block_start_time = time.time()
348
+
349
+ noisy_input = noise[:, current_start_frame -
350
+ num_input_frames:current_start_frame + current_num_frames - num_input_frames]
351
+
352
+ # Denoising loop
353
+ denoising_start = time.time()
354
+ for index, current_timestep in enumerate(pipeline.denoising_step_list):
355
+ if not generation_active or stop_event.is_set():
356
+ break
357
+
358
+ timestep = torch.ones([1, current_num_frames], device=noise.device,
359
+ dtype=torch.int64) * current_timestep
360
+
361
+ if index < len(pipeline.denoising_step_list) - 1:
362
+ _, denoised_pred = transformer(
363
+ noisy_image_or_video=noisy_input,
364
+ conditional_dict=conditional_dict,
365
+ timestep=timestep,
366
+ kv_cache=pipeline.kv_cache1,
367
+ crossattn_cache=pipeline.crossattn_cache,
368
+ current_start=current_start_frame * pipeline.frame_seq_length
369
+ )
370
+ next_timestep = pipeline.denoising_step_list[index + 1]
371
+ noisy_input = pipeline.scheduler.add_noise(
372
+ denoised_pred.flatten(0, 1),
373
+ torch.randn_like(denoised_pred.flatten(0, 1)),
374
+ next_timestep * torch.ones([1 * current_num_frames], device=noise.device, dtype=torch.long)
375
+ ).unflatten(0, denoised_pred.shape[:2])
376
+ else:
377
+ _, denoised_pred = transformer(
378
+ noisy_image_or_video=noisy_input,
379
+ conditional_dict=conditional_dict,
380
+ timestep=timestep,
381
+ kv_cache=pipeline.kv_cache1,
382
+ crossattn_cache=pipeline.crossattn_cache,
383
+ current_start=current_start_frame * pipeline.frame_seq_length
384
+ )
385
+
386
+ if not generation_active or stop_event.is_set():
387
+ break
388
+
389
+ denoising_time = time.time() - denoising_start
390
+ print(f"⚡ Block {idx+1} denoising completed in {denoising_time:.2f}s")
391
+
392
+ # Record output
393
+ # all_latents[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
394
+
395
+ # Update KV cache for next block
396
+ if idx != len(all_num_frames) - 1:
397
+ transformer(
398
+ noisy_image_or_video=denoised_pred,
399
+ conditional_dict=conditional_dict,
400
+ timestep=torch.zeros_like(timestep),
401
+ kv_cache=pipeline.kv_cache1,
402
+ crossattn_cache=pipeline.crossattn_cache,
403
+ current_start=current_start_frame * pipeline.frame_seq_length,
404
+ )
405
+
406
+ # Decode to pixels and send frames immediately
407
+ print(f"🎨 Decoding block {idx+1} to pixels...")
408
+ decode_start = time.time()
409
+ if args.trt:
410
+ all_current_pixels = []
411
+ for i in range(denoised_pred.shape[1]):
412
+ is_first_frame = torch.tensor(1.0).cuda().half() if idx == 0 and i == 0 else \
413
+ torch.tensor(0.0).cuda().half()
414
+ outputs = vae_decoder.forward(denoised_pred[:, i:i + 1, :, :, :].half(), is_first_frame, *vae_cache)
415
+ # outputs = vae_decoder.forward(denoised_pred.float(), *vae_cache)
416
+ current_pixels, vae_cache = outputs[0], outputs[1:]
417
+ print(current_pixels.max(), current_pixels.min())
418
+ all_current_pixels.append(current_pixels.clone())
419
+ pixels = torch.cat(all_current_pixels, dim=1)
420
+ if idx == 0:
421
+ pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
422
+ else:
423
+ if current_use_taehv:
424
+ if vae_cache is None:
425
+ vae_cache = denoised_pred
426
+ else:
427
+ denoised_pred = torch.cat([vae_cache, denoised_pred], dim=1)
428
+ vae_cache = denoised_pred[:, -3:, :, :, :]
429
+ pixels = current_vae_decoder.decode(denoised_pred)
430
+ print(f"denoised_pred shape: {denoised_pred.shape}")
431
+ print(f"pixels shape: {pixels.shape}")
432
+ if idx == 0:
433
+ pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
434
+ else:
435
+ pixels = pixels[:, 12:, :, :, :]
436
+
437
+ else:
438
+ pixels, vae_cache = current_vae_decoder(denoised_pred.half(), *vae_cache)
439
+ if idx == 0:
440
+ pixels = pixels[:, 3:, :, :, :] # Skip first 3 frames of first block
441
+
442
+ decode_time = time.time() - decode_start
443
+ print(f"🎨 Block {idx+1} VAE decoding completed in {decode_time:.2f}s")
444
+
445
+ # Queue frames for non-blocking sending
446
+ block_frames = pixels.shape[1]
447
+ print(f"📡 Queueing {block_frames} frames from block {idx+1} for sending...")
448
+ queue_start = time.time()
449
+
450
+ for frame_idx in range(block_frames):
451
+ if not generation_active or stop_event.is_set():
452
+ break
453
+
454
+ frame_tensor = pixels[0, frame_idx].cpu()
455
+
456
+ # Queue frame data in non-blocking way
457
+ frame_send_queue.put((frame_tensor, total_frames_sent, idx, job_id))
458
+ total_frames_sent += 1
459
+
460
+ queue_time = time.time() - queue_start
461
+ block_time = time.time() - block_start_time
462
+ print(f"✅ Block {idx+1} completed in {block_time:.2f}s ({block_frames} frames queued in {queue_time:.3f}s)")
463
+
464
+ current_start_frame += current_num_frames
465
+
466
+ generation_time = time.time() - generation_start_time
467
+ print(f"🎉 Generation completed in {generation_time:.2f}s! {total_frames_sent} frames queued for sending")
468
+
469
+ # Wait for all frames to be sent before completing
470
+ emit_progress('Waiting for all frames to be sent...', 97)
471
+ print("⏳ Waiting for all frames to be sent...")
472
+ frame_send_queue.join() # Wait for all queued frames to be processed
473
+ print("✅ All frames sent successfully!")
474
+
475
+ generate_mp4_from_images("./images","./videos/"+anim_name+".mp4", frame_rate )
476
+ # Final progress update
477
+ emit_progress('Generation complete!', 100)
478
+
479
+ try:
480
+ socketio.emit('generation_complete', {
481
+ 'message': 'Video generation completed!',
482
+ 'total_frames': total_frames_sent,
483
+ 'generation_time': f"{generation_time:.2f}s",
484
+ 'job_id': job_id
485
+ })
486
+ except Exception as e:
487
+ print(f"❌ Failed to emit generation complete: {e}")
488
+
489
+ except Exception as e:
490
+ print(f"❌ Generation failed: {e}")
491
+ try:
492
+ socketio.emit('error', {
493
+ 'message': f'Generation failed: {str(e)}',
494
+ 'job_id': job_id
495
+ })
496
+ except Exception as e:
497
+ print(f"❌ Failed to emit error: {e}")
498
+ finally:
499
+ generation_active = False
500
+ stop_event.set()
501
+
502
+ # Clean up sender thread
503
+ try:
504
+ frame_send_queue.put(None)
505
+ except Exception as e:
506
+ print(f"❌ Failed to put None in frame_send_queue: {e}")
507
+
508
+
509
+ def generate_mp4_from_images(image_directory, output_video_path, fps=24):
510
+ """
511
+ Generate an MP4 video from a directory of images ordered alphabetically.
512
+
513
+ :param image_directory: Path to the directory containing images.
514
+ :param output_video_path: Path where the output MP4 will be saved.
515
+ :param fps: Frames per second for the output video.
516
+ """
517
+ global anim_name
518
+ # Construct the ffmpeg command
519
+ cmd = [
520
+ 'ffmpeg',
521
+ '-framerate', str(fps),
522
+ '-i', os.path.join(image_directory, anim_name+'/'+anim_name+'_%03d.jpg'), # Adjust the pattern if necessary
523
+ '-c:v', 'libx264',
524
+ '-pix_fmt', 'yuv420p',
525
+ output_video_path
526
+ ]
527
+ try:
528
+ subprocess.run(cmd, check=True)
529
+ print(f"Video saved to {output_video_path}")
530
+ except subprocess.CalledProcessError as e:
531
+ print(f"An error occurred: {e}")
532
+
533
+ def calculate_sha256(data):
534
+ # Convert data to bytes if it's not already
535
+ if isinstance(data, str):
536
+ data = data.encode()
537
+ # Calculate SHA-256 hash
538
+ sha256_hash = hashlib.sha256(data).hexdigest()
539
+ return sha256_hash
540
+
541
+ # Socket.IO event handlers
542
+ @socketio.on('connect')
543
+ def handle_connect():
544
+ print('Client connected')
545
+ emit('status', {'message': 'Connected to frontend-buffered demo server'})
546
+
547
+
548
+ @socketio.on('disconnect')
549
+ def handle_disconnect():
550
+ print('Client disconnected')
551
+
552
+
553
+ @socketio.on('start_generation')
554
+ def handle_start_generation(data):
555
+ global generation_active, frame_number, anim_name, frame_rate
556
+
557
+ frame_number = 0
558
+ if generation_active:
559
+ emit('error', {'message': 'Generation already in progress'})
560
+ return
561
+
562
+ prompt = data.get('prompt', '')
563
+
564
+ seed = data.get('seed', -1)
565
+ if seed==-1:
566
+ seed = random.randint(0, 2**32)
567
+
568
+ # Extract words up to the first punctuation or newline
569
+ words_up_to_punctuation = re.split(r'[^\w\s]', prompt)[0].strip() if prompt else ''
570
+ if not words_up_to_punctuation:
571
+ words_up_to_punctuation = re.split(r'[\n\r]', prompt)[0].strip()
572
+
573
+ # Calculate SHA-256 hash of the entire prompt
574
+ sha256_hash = calculate_sha256(prompt)
575
+
576
+ # Create anim_name with the extracted words and first 10 characters of the hash
577
+ anim_name = f"{words_up_to_punctuation[:20]}_{str(seed)}_{sha256_hash[:10]}"
578
+
579
+ generation_active = True
580
+ generation_start_time = time.time()
581
+ enable_torch_compile = data.get('enable_torch_compile', False)
582
+ enable_fp8 = data.get('enable_fp8', False)
583
+ use_taehv = data.get('use_taehv', False)
584
+ frame_rate = data.get('fps', 6)
585
+
586
+ if not prompt:
587
+ emit('error', {'message': 'Prompt is required'})
588
+ return
589
+
590
+ # Start generation in background thread
591
+ socketio.start_background_task(generate_video_stream, prompt, seed,
592
+ enable_torch_compile, enable_fp8, use_taehv)
593
+ emit('status', {'message': 'Generation started - frames will be sent immediately'})
594
+
595
+
596
+ @socketio.on('stop_generation')
597
+ def handle_stop_generation():
598
+ global generation_active, stop_event, frame_send_queue
599
+ generation_active = False
600
+ stop_event.set()
601
+
602
+ # Signal sender thread to stop (will be processed after current frames)
603
+ try:
604
+ frame_send_queue.put(None)
605
+ except Exception as e:
606
+ print(f"❌ Failed to put None in frame_send_queue: {e}")
607
+
608
+ emit('status', {'message': 'Generation stopped'})
609
+
610
+ # Web routes
611
+
612
+
613
+ @app.route('/')
614
+ def index():
615
+ return render_template('demo.html')
616
+
617
+
618
+ @app.route('/api/status')
619
+ def api_status():
620
+ return jsonify({
621
+ 'generation_active': generation_active,
622
+ 'free_vram_gb': get_cuda_free_memory_gb(gpu),
623
+ 'fp8_applied': fp8_applied,
624
+ 'torch_compile_applied': torch_compile_applied,
625
+ 'current_use_taehv': current_use_taehv
626
+ })
627
+
628
+
629
+ if __name__ == '__main__':
630
+ print(f"🚀 Starting demo on http://{args.host}:{args.port}")
631
+ socketio.run(app, host=args.host, port=args.port, debug=False)
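For reference, the server above is driven entirely over Socket.IO. Below is a minimal client sketch, assuming the `python-socketio[client]` package; the event names and payload keys mirror the handlers in `demo.py`, while everything else (prompt text, host, port) is illustrative.

```python
import socketio

sio = socketio.Client()

@sio.on("frame_ready")
def on_frame(data):
    # data["data"] is a base64-encoded JPEG frame
    print(f"got frame {data['frame_index']} (block {data['block_index']})")

@sio.on("generation_complete")
def on_done(data):
    print(f"done: {data['total_frames']} frames in {data['generation_time']}")
    sio.disconnect()

sio.connect("http://localhost:5001")
sio.emit("start_generation", {
    "prompt": "A red fox running through snow at dawn",  # use a long, detailed prompt in practice
    "seed": -1,             # -1 -> server picks a random seed
    "fps": 6,
    "enable_torch_compile": False,
    "enable_fp8": False,
    "use_taehv": False,
})
sio.wait()
```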
demo_utils/constant.py ADDED
@@ -0,0 +1,41 @@
1
+
2
+ import torch
3
+
4
+
5
+ ZERO_VAE_CACHE = [
6
+ torch.zeros(1, 16, 2, 60, 104),
7
+ torch.zeros(1, 384, 2, 60, 104),
8
+ torch.zeros(1, 384, 2, 60, 104),
9
+ torch.zeros(1, 384, 2, 60, 104),
10
+ torch.zeros(1, 384, 2, 60, 104),
11
+ torch.zeros(1, 384, 2, 60, 104),
12
+ torch.zeros(1, 384, 2, 60, 104),
13
+ torch.zeros(1, 384, 2, 60, 104),
14
+ torch.zeros(1, 384, 2, 60, 104),
15
+ torch.zeros(1, 384, 2, 60, 104),
16
+ torch.zeros(1, 384, 2, 60, 104),
17
+ torch.zeros(1, 384, 2, 60, 104),
18
+ torch.zeros(1, 192, 2, 120, 208),
19
+ torch.zeros(1, 384, 2, 120, 208),
20
+ torch.zeros(1, 384, 2, 120, 208),
21
+ torch.zeros(1, 384, 2, 120, 208),
22
+ torch.zeros(1, 384, 2, 120, 208),
23
+ torch.zeros(1, 384, 2, 120, 208),
24
+ torch.zeros(1, 384, 2, 120, 208),
25
+ torch.zeros(1, 192, 2, 240, 416),
26
+ torch.zeros(1, 192, 2, 240, 416),
27
+ torch.zeros(1, 192, 2, 240, 416),
28
+ torch.zeros(1, 192, 2, 240, 416),
29
+ torch.zeros(1, 192, 2, 240, 416),
30
+ torch.zeros(1, 192, 2, 240, 416),
31
+ torch.zeros(1, 96, 2, 480, 832),
32
+ torch.zeros(1, 96, 2, 480, 832),
33
+ torch.zeros(1, 96, 2, 480, 832),
34
+ torch.zeros(1, 96, 2, 480, 832),
35
+ torch.zeros(1, 96, 2, 480, 832),
36
+ torch.zeros(1, 96, 2, 480, 832),
37
+ torch.zeros(1, 96, 2, 480, 832)
38
+ ]
39
+
40
+ feat_names = [f"vae_cache_{i}" for i in range(len(ZERO_VAE_CACHE))]
41
+ ALL_INPUTS_NAMES = ["z", "use_cache"] + feat_names
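For reference, a minimal sketch of how these constants could be consumed when priming a streaming decode loop; the device/dtype choices are assumptions, while the names come directly from this file.

# Sketch: pair each zero cache tensor with its input name before the first decode step.
import torch
from demo_utils.constant import ZERO_VAE_CACHE, ALL_INPUTS_NAMES

vae_cache = [c.cuda().half() for c in ZERO_VAE_CACHE]        # fresh (all-zero) recurrent state
cache_by_name = dict(zip(ALL_INPUTS_NAMES[2:], vae_cache))   # skip "z" and "use_cache"
print(len(cache_by_name), list(cache_by_name)[0])            # 32 entries, starting at vae_cache_0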
demo_utils/memory.py ADDED
@@ -0,0 +1,135 @@
1
+ # Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
2
+ # Apache-2.0 License
3
+ # By lllyasviel
4
+
5
+ import torch
6
+
7
+
8
+ cpu = torch.device('cpu')
9
+ gpu = torch.device(f'cuda:{torch.cuda.current_device()}')
10
+ gpu_complete_modules = []
11
+
12
+
13
+ class DynamicSwapInstaller:
14
+ @staticmethod
15
+ def _install_module(module: torch.nn.Module, **kwargs):
16
+ original_class = module.__class__
17
+ module.__dict__['forge_backup_original_class'] = original_class
18
+
19
+ def hacked_get_attr(self, name: str):
20
+ if '_parameters' in self.__dict__:
21
+ _parameters = self.__dict__['_parameters']
22
+ if name in _parameters:
23
+ p = _parameters[name]
24
+ if p is None:
25
+ return None
26
+ if p.__class__ == torch.nn.Parameter:
27
+ return torch.nn.Parameter(p.to(**kwargs), requires_grad=p.requires_grad)
28
+ else:
29
+ return p.to(**kwargs)
30
+ if '_buffers' in self.__dict__:
31
+ _buffers = self.__dict__['_buffers']
32
+ if name in _buffers:
33
+ return _buffers[name].to(**kwargs)
34
+ return super(original_class, self).__getattr__(name)
35
+
36
+ module.__class__ = type('DynamicSwap_' + original_class.__name__, (original_class,), {
37
+ '__getattr__': hacked_get_attr,
38
+ })
39
+
40
+ return
41
+
42
+ @staticmethod
43
+ def _uninstall_module(module: torch.nn.Module):
44
+ if 'forge_backup_original_class' in module.__dict__:
45
+ module.__class__ = module.__dict__.pop('forge_backup_original_class')
46
+ return
47
+
48
+ @staticmethod
49
+ def install_model(model: torch.nn.Module, **kwargs):
50
+ for m in model.modules():
51
+ DynamicSwapInstaller._install_module(m, **kwargs)
52
+ return
53
+
54
+ @staticmethod
55
+ def uninstall_model(model: torch.nn.Module):
56
+ for m in model.modules():
57
+ DynamicSwapInstaller._uninstall_module(m)
58
+ return
59
+
60
+
61
+ def fake_diffusers_current_device(model: torch.nn.Module, target_device: torch.device):
62
+ if hasattr(model, 'scale_shift_table'):
63
+ model.scale_shift_table.data = model.scale_shift_table.data.to(target_device)
64
+ return
65
+
66
+ for k, p in model.named_modules():
67
+ if hasattr(p, 'weight'):
68
+ p.to(target_device)
69
+ return
70
+
71
+
72
+ def get_cuda_free_memory_gb(device=None):
73
+ if device is None:
74
+ device = gpu
75
+
76
+ memory_stats = torch.cuda.memory_stats(device)
77
+ bytes_active = memory_stats['active_bytes.all.current']
78
+ bytes_reserved = memory_stats['reserved_bytes.all.current']
79
+ bytes_free_cuda, _ = torch.cuda.mem_get_info(device)
80
+ bytes_inactive_reserved = bytes_reserved - bytes_active
81
+ bytes_total_available = bytes_free_cuda + bytes_inactive_reserved
82
+ return bytes_total_available / (1024 ** 3)
83
+
84
+
85
+ def move_model_to_device_with_memory_preservation(model, target_device, preserved_memory_gb=0):
86
+ print(f'Moving {model.__class__.__name__} to {target_device} with preserved memory: {preserved_memory_gb} GB')
87
+
88
+ for m in model.modules():
89
+ if get_cuda_free_memory_gb(target_device) <= preserved_memory_gb:
90
+ torch.cuda.empty_cache()
91
+ return
92
+
93
+ if hasattr(m, 'weight'):
94
+ m.to(device=target_device)
95
+
96
+ model.to(device=target_device)
97
+ torch.cuda.empty_cache()
98
+ return
99
+
100
+
101
+ def offload_model_from_device_for_memory_preservation(model, target_device, preserved_memory_gb=0):
102
+ print(f'Offloading {model.__class__.__name__} from {target_device} to preserve memory: {preserved_memory_gb} GB')
103
+
104
+ for m in model.modules():
105
+ if get_cuda_free_memory_gb(target_device) >= preserved_memory_gb:
106
+ torch.cuda.empty_cache()
107
+ return
108
+
109
+ if hasattr(m, 'weight'):
110
+ m.to(device=cpu)
111
+
112
+ model.to(device=cpu)
113
+ torch.cuda.empty_cache()
114
+ return
115
+
116
+
117
+ def unload_complete_models(*args):
118
+ for m in gpu_complete_modules + list(args):
119
+ m.to(device=cpu)
120
+ print(f'Unloaded {m.__class__.__name__} as complete.')
121
+
122
+ gpu_complete_modules.clear()
123
+ torch.cuda.empty_cache()
124
+ return
125
+
126
+
127
+ def load_model_as_complete(model, target_device, unload=True):
128
+ if unload:
129
+ unload_complete_models()
130
+
131
+ model.to(device=target_device)
132
+ print(f'Loaded {model.__class__.__name__} to {target_device} as complete.')
133
+
134
+ gpu_complete_modules.append(model)
135
+ return
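A toy usage sketch of the memory helpers above (the two-layer model is illustrative and a CUDA device is assumed; the function and class names are the ones defined in this file).

# Sketch: stream a toy model's weights to the GPU on access, then load it wholesale.
import torch
from demo_utils.memory import (DynamicSwapInstaller, get_cuda_free_memory_gb,
                               load_model_as_complete, unload_complete_models, gpu)

model = torch.nn.Sequential(torch.nn.Linear(256, 256), torch.nn.ReLU(), torch.nn.Linear(256, 16))
print(f'free VRAM: {get_cuda_free_memory_gb(gpu):.2f} GB')

DynamicSwapInstaller.install_model(model, device=gpu)      # parameters stay registered on CPU,
y = model(torch.randn(4, 256, device=gpu))                 # but are cast to the GPU on each access
DynamicSwapInstaller.uninstall_model(model)

load_model_as_complete(model, gpu)                         # or move it in one shot and track it
unload_complete_models()                                   # later: push tracked models back to CPU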
demo_utils/taehv.py ADDED
@@ -0,0 +1,313 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ Tiny AutoEncoder for Hunyuan Video
4
+ (DNN for encoding / decoding videos to Hunyuan Video's latent space)
5
+ """
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from tqdm.auto import tqdm
10
+ from collections import namedtuple
11
+
12
+ DecoderResult = namedtuple("DecoderResult", ("frame", "memory"))
13
+ TWorkItem = namedtuple("TWorkItem", ("input_tensor", "block_index"))
14
+
15
+
16
+ def conv(n_in, n_out, **kwargs):
17
+ return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)
18
+
19
+
20
+ class Clamp(nn.Module):
21
+ def forward(self, x):
22
+ return torch.tanh(x / 3) * 3
23
+
24
+
25
+ class MemBlock(nn.Module):
26
+ def __init__(self, n_in, n_out):
27
+ super().__init__()
28
+ self.conv = nn.Sequential(conv(n_in * 2, n_out), nn.ReLU(inplace=True),
29
+ conv(n_out, n_out), nn.ReLU(inplace=True), conv(n_out, n_out))
30
+ self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
31
+ self.act = nn.ReLU(inplace=True)
32
+
33
+ def forward(self, x, past):
34
+ return self.act(self.conv(torch.cat([x, past], 1)) + self.skip(x))
35
+
36
+
37
+ class TPool(nn.Module):
38
+ def __init__(self, n_f, stride):
39
+ super().__init__()
40
+ self.stride = stride
41
+ self.conv = nn.Conv2d(n_f * stride, n_f, 1, bias=False)
42
+
43
+ def forward(self, x):
44
+ _NT, C, H, W = x.shape
45
+ return self.conv(x.reshape(-1, self.stride * C, H, W))
46
+
47
+
48
+ class TGrow(nn.Module):
49
+ def __init__(self, n_f, stride):
50
+ super().__init__()
51
+ self.stride = stride
52
+ self.conv = nn.Conv2d(n_f, n_f * stride, 1, bias=False)
53
+
54
+ def forward(self, x):
55
+ _NT, C, H, W = x.shape
56
+ x = self.conv(x)
57
+ return x.reshape(-1, C, H, W)
58
+
59
+
60
+ def apply_model_with_memblocks(model, x, parallel, show_progress_bar):
61
+ """
62
+ Apply a sequential model with memblocks to the given input.
63
+ Args:
64
+ - model: nn.Sequential of blocks to apply
65
+ - x: input data, of dimensions NTCHW
66
+ - parallel: if True, parallelize over timesteps (fast but uses O(T) memory)
67
+ if False, each timestep will be processed sequentially (slow but uses O(1) memory)
68
+ - show_progress_bar: if True, enables tqdm progressbar display
69
+
70
+ Returns NTCHW tensor of output data.
71
+ """
72
+ assert x.ndim == 5, f"TAEHV operates on NTCHW tensors, but got {x.ndim}-dim tensor"
73
+ N, T, C, H, W = x.shape
74
+ if parallel:
75
+ x = x.reshape(N * T, C, H, W)
76
+ # parallel over input timesteps, iterate over blocks
77
+ for b in tqdm(model, disable=not show_progress_bar):
78
+ if isinstance(b, MemBlock):
79
+ NT, C, H, W = x.shape
80
+ T = NT // N
81
+ _x = x.reshape(N, T, C, H, W)
82
+ mem = F.pad(_x, (0, 0, 0, 0, 0, 0, 1, 0), value=0)[:, :T].reshape(x.shape)
83
+ x = b(x, mem)
84
+ else:
85
+ x = b(x)
86
+ NT, C, H, W = x.shape
87
+ T = NT // N
88
+ x = x.view(N, T, C, H, W)
89
+ else:
90
+ # TODO(oboerbohan): at least on macos this still gradually uses more memory during decode...
91
+ # need to fix :(
92
+ out = []
93
+ # iterate over input timesteps and also iterate over blocks.
94
+ # because of the cursed TPool/TGrow blocks, this is not a nested loop,
95
+ # it's actually a ***graph traversal*** problem! so let's make a queue
96
+ work_queue = [TWorkItem(xt, 0) for t, xt in enumerate(x.reshape(N, T * C, H, W).chunk(T, dim=1))]
97
+ # in addition to manually managing our queue, we also need to manually manage our progressbar.
98
+ # we'll update it for every source node that we consume.
99
+ progress_bar = tqdm(range(T), disable=not show_progress_bar)
100
+ # we'll also need a separate addressable memory per node as well
101
+ mem = [None] * len(model)
102
+ while work_queue:
103
+ xt, i = work_queue.pop(0)
104
+ if i == 0:
105
+ # new source node consumed
106
+ progress_bar.update(1)
107
+ if i == len(model):
108
+ # reached end of the graph, append result to output list
109
+ out.append(xt)
110
+ else:
111
+ # fetch the block to process
112
+ b = model[i]
113
+ if isinstance(b, MemBlock):
114
+ # mem blocks are simple since we're visiting the graph in causal order
115
+ if mem[i] is None:
116
+ xt_new = b(xt, xt * 0)
117
+ mem[i] = xt
118
+ else:
119
+ xt_new = b(xt, mem[i])
120
+ mem[i].copy_(xt) # inplace might reduce mysterious pytorch memory allocations? doesn't help though
121
+ # add successor to work queue
122
+ work_queue.insert(0, TWorkItem(xt_new, i + 1))
123
+ elif isinstance(b, TPool):
124
+ # pool blocks are miserable
125
+ if mem[i] is None:
126
+ mem[i] = [] # pool memory is itself a queue of inputs to pool
127
+ mem[i].append(xt)
128
+ if len(mem[i]) > b.stride:
129
+ # pool mem is in invalid state, we should have pooled before this
130
+ raise ValueError("???")
131
+ elif len(mem[i]) < b.stride:
132
+ # pool mem is not yet full, go back to processing the work queue
133
+ pass
134
+ else:
135
+ # pool mem is ready, run the pool block
136
+ N, C, H, W = xt.shape
137
+ xt = b(torch.cat(mem[i], 1).view(N * b.stride, C, H, W))
138
+ # reset the pool mem
139
+ mem[i] = []
140
+ # add successor to work queue
141
+ work_queue.insert(0, TWorkItem(xt, i + 1))
142
+ elif isinstance(b, TGrow):
143
+ xt = b(xt)
144
+ NT, C, H, W = xt.shape
145
+ # each tgrow has multiple successor nodes
146
+ for xt_next in reversed(xt.view(N, b.stride * C, H, W).chunk(b.stride, 1)):
147
+ # add successor to work queue
148
+ work_queue.insert(0, TWorkItem(xt_next, i + 1))
149
+ else:
150
+ # normal block with no funny business
151
+ xt = b(xt)
152
+ # add successor to work queue
153
+ work_queue.insert(0, TWorkItem(xt, i + 1))
154
+ progress_bar.close()
155
+ x = torch.stack(out, 1)
156
+ return x
157
+
158
+
159
+ class TAEHV(nn.Module):
160
+ latent_channels = 16
161
+ image_channels = 3
162
+
163
+ def __init__(self, checkpoint_path="taehv.pth", decoder_time_upscale=(True, True), decoder_space_upscale=(True, True, True)):
164
+ """Initialize pretrained TAEHV from the given checkpoint.
165
+
166
+ Arg:
167
+ checkpoint_path: path to weight file to load. taehv.pth for Hunyuan, taew2_1.pth for Wan 2.1.
168
+ decoder_time_upscale: whether temporal upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
169
+ decoder_space_upscale: whether spatial upsampling is enabled for each block. upsampling can be disabled for a cheaper preview.
170
+ """
171
+ super().__init__()
172
+ self.encoder = nn.Sequential(
173
+ conv(TAEHV.image_channels, 64), nn.ReLU(inplace=True),
174
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
175
+ TPool(64, 2), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
176
+ TPool(64, 1), conv(64, 64, stride=2, bias=False), MemBlock(64, 64), MemBlock(64, 64), MemBlock(64, 64),
177
+ conv(64, TAEHV.latent_channels),
178
+ )
179
+ n_f = [256, 128, 64, 64]
180
+ self.frames_to_trim = 2**sum(decoder_time_upscale) - 1
181
+ self.decoder = nn.Sequential(
182
+ Clamp(), conv(TAEHV.latent_channels, n_f[0]), nn.ReLU(inplace=True),
183
+ MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), MemBlock(n_f[0], n_f[0]), nn.Upsample(
184
+ scale_factor=2 if decoder_space_upscale[0] else 1), TGrow(n_f[0], 1), conv(n_f[0], n_f[1], bias=False),
185
+ MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), MemBlock(n_f[1], n_f[1]), nn.Upsample(
186
+ scale_factor=2 if decoder_space_upscale[1] else 1), TGrow(n_f[1], 2 if decoder_time_upscale[0] else 1), conv(n_f[1], n_f[2], bias=False),
187
+ MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), MemBlock(n_f[2], n_f[2]), nn.Upsample(
188
+ scale_factor=2 if decoder_space_upscale[2] else 1), TGrow(n_f[2], 2 if decoder_time_upscale[1] else 1), conv(n_f[2], n_f[3], bias=False),
189
+ nn.ReLU(inplace=True), conv(n_f[3], TAEHV.image_channels),
190
+ )
191
+ if checkpoint_path is not None:
192
+ self.load_state_dict(self.patch_tgrow_layers(torch.load(
193
+ checkpoint_path, map_location="cpu", weights_only=True)))
194
+
195
+ def patch_tgrow_layers(self, sd):
196
+ """Patch TGrow layers to use a smaller kernel if needed.
197
+
198
+ Args:
199
+ sd: state dict to patch
200
+ """
201
+ new_sd = self.state_dict()
202
+ for i, layer in enumerate(self.decoder):
203
+ if isinstance(layer, TGrow):
204
+ key = f"decoder.{i}.conv.weight"
205
+ if sd[key].shape[0] > new_sd[key].shape[0]:
206
+ # take the last-timestep output channels
207
+ sd[key] = sd[key][-new_sd[key].shape[0]:]
208
+ return sd
209
+
210
+ def encode_video(self, x, parallel=True, show_progress_bar=True):
211
+ """Encode a sequence of frames.
212
+
213
+ Args:
214
+ x: input NTCHW RGB (C=3) tensor with values in [0, 1].
215
+ parallel: if True, all frames will be processed at once.
216
+ (this is faster but may require more memory).
217
+ if False, frames will be processed sequentially.
218
+ Returns NTCHW latent tensor with ~Gaussian values.
219
+ """
220
+ return apply_model_with_memblocks(self.encoder, x, parallel, show_progress_bar)
221
+
222
+ def decode_video(self, x, parallel=True, show_progress_bar=False):
223
+ """Decode a sequence of frames.
224
+
225
+ Args:
226
+ x: input NTCHW latent (C=12) tensor with ~Gaussian values.
227
+ parallel: if True, all frames will be processed at once.
228
+ (this is faster but may require more memory).
229
+ if False, frames will be processed sequentially.
230
+ Returns NTCHW RGB tensor with ~[0, 1] values.
231
+ """
232
+ x = apply_model_with_memblocks(self.decoder, x, parallel, show_progress_bar)
233
+ # return x[:, self.frames_to_trim:]
234
+ return x
235
+
236
+ def forward(self, x):
237
+ # NOTE: `self.c` is not defined on TAEHV; decoding here is an assumption -- prefer encode_video / decode_video.
+ return self.decode_video(x)
238
+
239
+
240
+ @torch.no_grad()
241
+ def main():
242
+ """Run TAEHV roundtrip reconstruction on the given video paths."""
243
+ import os
244
+ import sys
245
+ import cv2 # no highly esteemed deed is commemorated here
246
+
247
+ class VideoTensorReader:
248
+ def __init__(self, video_file_path):
249
+ self.cap = cv2.VideoCapture(video_file_path)
250
+ assert self.cap.isOpened(), f"Could not load {video_file_path}"
251
+ self.fps = self.cap.get(cv2.CAP_PROP_FPS)
252
+
253
+ def __iter__(self):
254
+ return self
255
+
256
+ def __next__(self):
257
+ ret, frame = self.cap.read()
258
+ if not ret:
259
+ self.cap.release()
260
+ raise StopIteration # End of video or error
261
+ return torch.from_numpy(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)).permute(2, 0, 1) # BGR HWC -> RGB CHW
262
+
263
+ class VideoTensorWriter:
264
+ def __init__(self, video_file_path, width_height, fps=30):
265
+ self.writer = cv2.VideoWriter(video_file_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, width_height)
266
+ assert self.writer.isOpened(), f"Could not create writer for {video_file_path}"
267
+
268
+ def write(self, frame_tensor):
269
+ assert frame_tensor.ndim == 3 and frame_tensor.shape[0] == 3, f"{frame_tensor.shape}??"
270
+ self.writer.write(cv2.cvtColor(frame_tensor.permute(1, 2, 0).numpy(),
271
+ cv2.COLOR_RGB2BGR)) # RGB CHW -> BGR HWC
272
+
273
+ def __del__(self):
274
+ if hasattr(self, 'writer'):
275
+ self.writer.release()
276
+
277
+ dev = torch.device("cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu")
278
+ dtype = torch.float16
279
+ checkpoint_path = os.getenv("TAEHV_CHECKPOINT_PATH", "taehv.pth")
280
+ checkpoint_name = os.path.splitext(os.path.basename(checkpoint_path))[0]
281
+ print(
282
+ f"Using device \033[31m{dev}\033[0m, dtype \033[32m{dtype}\033[0m, checkpoint \033[34m{checkpoint_name}\033[0m ({checkpoint_path})")
283
+ taehv = TAEHV(checkpoint_path=checkpoint_path).to(dev, dtype)
284
+ for video_path in sys.argv[1:]:
285
+ print(f"Processing {video_path}...")
286
+ video_in = VideoTensorReader(video_path)
287
+ video = torch.stack(list(video_in), 0)[None]
288
+ vid_dev = video.to(dev, dtype).div_(255.0)
289
+ # convert to device tensor
290
+ if video.numel() < 100_000_000:
291
+ print(f" {video_path} seems small enough, will process all frames in parallel")
292
+ # convert to device tensor
293
+ vid_enc = taehv.encode_video(vid_dev)
294
+ print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
295
+ vid_dec = taehv.decode_video(vid_enc)
296
+ print(f" Decoded {video_path} -> {vid_dec.shape}")
297
+ else:
298
+ print(f" {video_path} seems large, will process each frame sequentially")
299
+ # convert to device tensor
300
+ vid_enc = taehv.encode_video(vid_dev, parallel=False)
301
+ print(f" Encoded {video_path} -> {vid_enc.shape}. Decoding...")
302
+ vid_dec = taehv.decode_video(vid_enc, parallel=False)
303
+ print(f" Decoded {video_path} -> {vid_dec.shape}")
304
+ video_out_path = video_path + f".reconstructed_by_{checkpoint_name}.mp4"
305
+ video_out = VideoTensorWriter(
306
+ video_out_path, (vid_dec.shape[-1], vid_dec.shape[-2]), fps=int(round(video_in.fps)))
307
+ for frame in vid_dec.clamp_(0, 1).mul_(255).round_().byte().cpu()[0]:
308
+ video_out.write(frame)
309
+ print(f" Saved to {video_out_path}")
310
+
311
+
312
+ if __name__ == "__main__":
313
+ main()
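Besides the CLI roundtrip in main(), the class can be used directly on latents; a minimal decode sketch follows (the checkpoint filename, tensor sizes and half-precision choice are assumptions).

# Sketch: decode a short latent sequence to RGB frames with TAEHV.
import torch
from demo_utils.taehv import TAEHV

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
taehv = TAEHV(checkpoint_path='taew2_1.pth').to(device, torch.float16)   # Wan 2.1 weights assumed

latents = torch.randn(1, 3, 16, 60, 104, device=device, dtype=torch.float16)  # N, T, C=16, H, W
with torch.no_grad():
    frames = taehv.decode_video(latents, parallel=True)    # roughly N, 4*T, 3, 8*H, 8*W in ~[0, 1]
print(frames.shape)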
demo_utils/utils.py ADDED
@@ -0,0 +1,616 @@
1
+ # Copied from https://github.com/lllyasviel/FramePack/tree/main/demo_utils
2
+ # Apache-2.0 License
3
+ # By lllyasviel
4
+
5
+ import os
6
+ import cv2
7
+ import json
8
+ import random
9
+ import glob
10
+ import torch
11
+ import einops
12
+ import numpy as np
13
+ import datetime
14
+ import torchvision
15
+
16
+ from PIL import Image
17
+
18
+
19
+ def min_resize(x, m):
20
+ if x.shape[0] < x.shape[1]:
21
+ s0 = m
22
+ s1 = int(float(m) / float(x.shape[0]) * float(x.shape[1]))
23
+ else:
24
+ s0 = int(float(m) / float(x.shape[1]) * float(x.shape[0]))
25
+ s1 = m
26
+ new_max = max(s1, s0)
27
+ raw_max = max(x.shape[0], x.shape[1])
28
+ if new_max < raw_max:
29
+ interpolation = cv2.INTER_AREA
30
+ else:
31
+ interpolation = cv2.INTER_LANCZOS4
32
+ y = cv2.resize(x, (s1, s0), interpolation=interpolation)
33
+ return y
34
+
35
+
36
+ def d_resize(x, y):
37
+ H, W, C = y.shape
38
+ new_min = min(H, W)
39
+ raw_min = min(x.shape[0], x.shape[1])
40
+ if new_min < raw_min:
41
+ interpolation = cv2.INTER_AREA
42
+ else:
43
+ interpolation = cv2.INTER_LANCZOS4
44
+ y = cv2.resize(x, (W, H), interpolation=interpolation)
45
+ return y
46
+
47
+
48
+ def resize_and_center_crop(image, target_width, target_height):
49
+ if target_height == image.shape[0] and target_width == image.shape[1]:
50
+ return image
51
+
52
+ pil_image = Image.fromarray(image)
53
+ original_width, original_height = pil_image.size
54
+ scale_factor = max(target_width / original_width, target_height / original_height)
55
+ resized_width = int(round(original_width * scale_factor))
56
+ resized_height = int(round(original_height * scale_factor))
57
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
58
+ left = (resized_width - target_width) / 2
59
+ top = (resized_height - target_height) / 2
60
+ right = (resized_width + target_width) / 2
61
+ bottom = (resized_height + target_height) / 2
62
+ cropped_image = resized_image.crop((left, top, right, bottom))
63
+ return np.array(cropped_image)
64
+
65
+
66
+ def resize_and_center_crop_pytorch(image, target_width, target_height):
67
+ B, C, H, W = image.shape
68
+
69
+ if H == target_height and W == target_width:
70
+ return image
71
+
72
+ scale_factor = max(target_width / W, target_height / H)
73
+ resized_width = int(round(W * scale_factor))
74
+ resized_height = int(round(H * scale_factor))
75
+
76
+ resized = torch.nn.functional.interpolate(image, size=(resized_height, resized_width), mode='bilinear', align_corners=False)
77
+
78
+ top = (resized_height - target_height) // 2
79
+ left = (resized_width - target_width) // 2
80
+ cropped = resized[:, :, top:top + target_height, left:left + target_width]
81
+
82
+ return cropped
83
+
84
+
85
+ def resize_without_crop(image, target_width, target_height):
86
+ if target_height == image.shape[0] and target_width == image.shape[1]:
87
+ return image
88
+
89
+ pil_image = Image.fromarray(image)
90
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
91
+ return np.array(resized_image)
92
+
93
+
94
+ def just_crop(image, w, h):
95
+ if h == image.shape[0] and w == image.shape[1]:
96
+ return image
97
+
98
+ original_height, original_width = image.shape[:2]
99
+ k = min(original_height / h, original_width / w)
100
+ new_width = int(round(w * k))
101
+ new_height = int(round(h * k))
102
+ x_start = (original_width - new_width) // 2
103
+ y_start = (original_height - new_height) // 2
104
+ cropped_image = image[y_start:y_start + new_height, x_start:x_start + new_width]
105
+ return cropped_image
106
+
107
+
108
+ def write_to_json(data, file_path):
109
+ temp_file_path = file_path + ".tmp"
110
+ with open(temp_file_path, 'wt', encoding='utf-8') as temp_file:
111
+ json.dump(data, temp_file, indent=4)
112
+ os.replace(temp_file_path, file_path)
113
+ return
114
+
115
+
116
+ def read_from_json(file_path):
117
+ with open(file_path, 'rt', encoding='utf-8') as file:
118
+ data = json.load(file)
119
+ return data
120
+
121
+
122
+ def get_active_parameters(m):
123
+ return {k: v for k, v in m.named_parameters() if v.requires_grad}
124
+
125
+
126
+ def cast_training_params(m, dtype=torch.float32):
127
+ result = {}
128
+ for n, param in m.named_parameters():
129
+ if param.requires_grad:
130
+ param.data = param.to(dtype)
131
+ result[n] = param
132
+ return result
133
+
134
+
135
+ def separate_lora_AB(parameters, B_patterns=None):
136
+ parameters_normal = {}
137
+ parameters_B = {}
138
+
139
+ if B_patterns is None:
140
+ B_patterns = ['.lora_B.', '__zero__']
141
+
142
+ for k, v in parameters.items():
143
+ if any(B_pattern in k for B_pattern in B_patterns):
144
+ parameters_B[k] = v
145
+ else:
146
+ parameters_normal[k] = v
147
+
148
+ return parameters_normal, parameters_B
149
+
150
+
151
+ def set_attr_recursive(obj, attr, value):
152
+ attrs = attr.split(".")
153
+ for name in attrs[:-1]:
154
+ obj = getattr(obj, name)
155
+ setattr(obj, attrs[-1], value)
156
+ return
157
+
158
+
159
+ def print_tensor_list_size(tensors):
160
+ total_size = 0
161
+ total_elements = 0
162
+
163
+ if isinstance(tensors, dict):
164
+ tensors = tensors.values()
165
+
166
+ for tensor in tensors:
167
+ total_size += tensor.nelement() * tensor.element_size()
168
+ total_elements += tensor.nelement()
169
+
170
+ total_size_MB = total_size / (1024 ** 2)
171
+ total_elements_B = total_elements / 1e9
172
+
173
+ print(f"Total number of tensors: {len(tensors)}")
174
+ print(f"Total size of tensors: {total_size_MB:.2f} MB")
175
+ print(f"Total number of parameters: {total_elements_B:.3f} billion")
176
+ return
177
+
178
+
179
+ @torch.no_grad()
180
+ def batch_mixture(a, b=None, probability_a=0.5, mask_a=None):
181
+ batch_size = a.size(0)
182
+
183
+ if b is None:
184
+ b = torch.zeros_like(a)
185
+
186
+ if mask_a is None:
187
+ mask_a = torch.rand(batch_size) < probability_a
188
+
189
+ mask_a = mask_a.to(a.device)
190
+ mask_a = mask_a.reshape((batch_size,) + (1,) * (a.dim() - 1))
191
+ result = torch.where(mask_a, a, b)
192
+ return result
193
+
194
+
195
+ @torch.no_grad()
196
+ def zero_module(module):
197
+ for p in module.parameters():
198
+ p.detach().zero_()
199
+ return module
200
+
201
+
202
+ @torch.no_grad()
203
+ def supress_lower_channels(m, k, alpha=0.01):
204
+ data = m.weight.data.clone()
205
+
206
+ assert int(data.shape[1]) >= k
207
+
208
+ data[:, :k] = data[:, :k] * alpha
209
+ m.weight.data = data.contiguous().clone()
210
+ return m
211
+
212
+
213
+ def freeze_module(m):
214
+ if not hasattr(m, '_forward_inside_frozen_module'):
215
+ m._forward_inside_frozen_module = m.forward
216
+ m.requires_grad_(False)
217
+ m.forward = torch.no_grad()(m.forward)
218
+ return m
219
+
220
+
221
+ def get_latest_safetensors(folder_path):
222
+ safetensors_files = glob.glob(os.path.join(folder_path, '*.safetensors'))
223
+
224
+ if not safetensors_files:
225
+ raise ValueError('No file to resume!')
226
+
227
+ latest_file = max(safetensors_files, key=os.path.getmtime)
228
+ latest_file = os.path.abspath(os.path.realpath(latest_file))
229
+ return latest_file
230
+
231
+
232
+ def generate_random_prompt_from_tags(tags_str, min_length=3, max_length=32):
233
+ tags = tags_str.split(', ')
234
+ tags = random.sample(tags, k=min(random.randint(min_length, max_length), len(tags)))
235
+ prompt = ', '.join(tags)
236
+ return prompt
237
+
238
+
239
+ def interpolate_numbers(a, b, n, round_to_int=False, gamma=1.0):
240
+ numbers = a + (b - a) * (np.linspace(0, 1, n) ** gamma)
241
+ if round_to_int:
242
+ numbers = np.round(numbers).astype(int)
243
+ return numbers.tolist()
244
+
245
+
246
+ def uniform_random_by_intervals(inclusive, exclusive, n, round_to_int=False):
247
+ edges = np.linspace(0, 1, n + 1)
248
+ points = np.random.uniform(edges[:-1], edges[1:])
249
+ numbers = inclusive + (exclusive - inclusive) * points
250
+ if round_to_int:
251
+ numbers = np.round(numbers).astype(int)
252
+ return numbers.tolist()
253
+
254
+
255
+ def soft_append_bcthw(history, current, overlap=0):
256
+ if overlap <= 0:
257
+ return torch.cat([history, current], dim=2)
258
+
259
+ assert history.shape[2] >= overlap, f"History length ({history.shape[2]}) must be >= overlap ({overlap})"
260
+ assert current.shape[2] >= overlap, f"Current length ({current.shape[2]}) must be >= overlap ({overlap})"
261
+
262
+ weights = torch.linspace(1, 0, overlap, dtype=history.dtype, device=history.device).view(1, 1, -1, 1, 1)
263
+ blended = weights * history[:, :, -overlap:] + (1 - weights) * current[:, :, :overlap]
264
+ output = torch.cat([history[:, :, :-overlap], blended, current[:, :, overlap:]], dim=2)
265
+
266
+ return output.to(history)
267
+
268
+
269
+ def save_bcthw_as_mp4(x, output_filename, fps=10, crf=0):
270
+ b, c, t, h, w = x.shape
271
+
272
+ per_row = b
273
+ for p in [6, 5, 4, 3, 2]:
274
+ if b % p == 0:
275
+ per_row = p
276
+ break
277
+
278
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
279
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
280
+ x = x.detach().cpu().to(torch.uint8)
281
+ x = einops.rearrange(x, '(m n) c t h w -> t (m h) (n w) c', n=per_row)
282
+ torchvision.io.write_video(output_filename, x, fps=fps, video_codec='libx264', options={'crf': str(int(crf))})
283
+ return x
284
+
285
+
286
+ def save_bcthw_as_png(x, output_filename):
287
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
288
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
289
+ x = x.detach().cpu().to(torch.uint8)
290
+ x = einops.rearrange(x, 'b c t h w -> c (b h) (t w)')
291
+ torchvision.io.write_png(x, output_filename)
292
+ return output_filename
293
+
294
+
295
+ def save_bchw_as_png(x, output_filename):
296
+ os.makedirs(os.path.dirname(os.path.abspath(os.path.realpath(output_filename))), exist_ok=True)
297
+ x = torch.clamp(x.float(), -1., 1.) * 127.5 + 127.5
298
+ x = x.detach().cpu().to(torch.uint8)
299
+ x = einops.rearrange(x, 'b c h w -> c h (b w)')
300
+ torchvision.io.write_png(x, output_filename)
301
+ return output_filename
302
+
303
+
304
+ def add_tensors_with_padding(tensor1, tensor2):
305
+ if tensor1.shape == tensor2.shape:
306
+ return tensor1 + tensor2
307
+
308
+ shape1 = tensor1.shape
309
+ shape2 = tensor2.shape
310
+
311
+ new_shape = tuple(max(s1, s2) for s1, s2 in zip(shape1, shape2))
312
+
313
+ padded_tensor1 = torch.zeros(new_shape)
314
+ padded_tensor2 = torch.zeros(new_shape)
315
+
316
+ padded_tensor1[tuple(slice(0, s) for s in shape1)] = tensor1
317
+ padded_tensor2[tuple(slice(0, s) for s in shape2)] = tensor2
318
+
319
+ result = padded_tensor1 + padded_tensor2
320
+ return result
321
+
322
+
323
+ def print_free_mem():
324
+ torch.cuda.empty_cache()
325
+ free_mem, total_mem = torch.cuda.mem_get_info(0)
326
+ free_mem_mb = free_mem / (1024 ** 2)
327
+ total_mem_mb = total_mem / (1024 ** 2)
328
+ print(f"Free memory: {free_mem_mb:.2f} MB")
329
+ print(f"Total memory: {total_mem_mb:.2f} MB")
330
+ return
331
+
332
+
333
+ def print_gpu_parameters(device, state_dict, log_count=1):
334
+ summary = {"device": device, "keys_count": len(state_dict)}
335
+
336
+ logged_params = {}
337
+ for i, (key, tensor) in enumerate(state_dict.items()):
338
+ if i >= log_count:
339
+ break
340
+ logged_params[key] = tensor.flatten()[:3].tolist()
341
+
342
+ summary["params"] = logged_params
343
+
344
+ print(str(summary))
345
+ return
346
+
347
+
348
+ def visualize_txt_as_img(width, height, text, font_path='font/DejaVuSans.ttf', size=18):
349
+ from PIL import Image, ImageDraw, ImageFont
350
+
351
+ txt = Image.new("RGB", (width, height), color="white")
352
+ draw = ImageDraw.Draw(txt)
353
+ font = ImageFont.truetype(font_path, size=size)
354
+
355
+ if text == '':
356
+ return np.array(txt)
357
+
358
+ # Split text into lines that fit within the image width
359
+ lines = []
360
+ words = text.split()
361
+ current_line = words[0]
362
+
363
+ for word in words[1:]:
364
+ line_with_word = f"{current_line} {word}"
365
+ if draw.textbbox((0, 0), line_with_word, font=font)[2] <= width:
366
+ current_line = line_with_word
367
+ else:
368
+ lines.append(current_line)
369
+ current_line = word
370
+
371
+ lines.append(current_line)
372
+
373
+ # Draw the text line by line
374
+ y = 0
375
+ line_height = draw.textbbox((0, 0), "A", font=font)[3]
376
+
377
+ for line in lines:
378
+ if y + line_height > height:
379
+ break # stop drawing if the next line will be outside the image
380
+ draw.text((0, y), line, fill="black", font=font)
381
+ y += line_height
382
+
383
+ return np.array(txt)
384
+
385
+
386
+ def blue_mark(x):
387
+ x = x.copy()
388
+ c = x[:, :, 2]
389
+ b = cv2.blur(c, (9, 9))
390
+ x[:, :, 2] = ((c - b) * 16.0 + b).clip(-1, 1)
391
+ return x
392
+
393
+
394
+ def green_mark(x):
395
+ x = x.copy()
396
+ x[:, :, 2] = -1
397
+ x[:, :, 0] = -1
398
+ return x
399
+
400
+
401
+ def frame_mark(x):
402
+ x = x.copy()
403
+ x[:64] = -1
404
+ x[-64:] = -1
405
+ x[:, :8] = 1
406
+ x[:, -8:] = 1
407
+ return x
408
+
409
+
410
+ @torch.inference_mode()
411
+ def pytorch2numpy(imgs):
412
+ results = []
413
+ for x in imgs:
414
+ y = x.movedim(0, -1)
415
+ y = y * 127.5 + 127.5
416
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
417
+ results.append(y)
418
+ return results
419
+
420
+
421
+ @torch.inference_mode()
422
+ def numpy2pytorch(imgs):
423
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.5 - 1.0
424
+ h = h.movedim(-1, 1)
425
+ return h
426
+
427
+
428
+ @torch.no_grad()
429
+ def duplicate_prefix_to_suffix(x, count, zero_out=False):
430
+ if zero_out:
431
+ return torch.cat([x, torch.zeros_like(x[:count])], dim=0)
432
+ else:
433
+ return torch.cat([x, x[:count]], dim=0)
434
+
435
+
436
+ def weighted_mse(a, b, weight):
437
+ return torch.mean(weight.float() * (a.float() - b.float()) ** 2)
438
+
439
+
440
+ def clamped_linear_interpolation(x, x_min, y_min, x_max, y_max, sigma=1.0):
441
+ x = (x - x_min) / (x_max - x_min)
442
+ x = max(0.0, min(x, 1.0))
443
+ x = x ** sigma
444
+ return y_min + x * (y_max - y_min)
445
+
446
+
447
+ def expand_to_dims(x, target_dims):
448
+ return x.view(*x.shape, *([1] * max(0, target_dims - x.dim())))
449
+
450
+
451
+ def repeat_to_batch_size(tensor: torch.Tensor, batch_size: int):
452
+ if tensor is None:
453
+ return None
454
+
455
+ first_dim = tensor.shape[0]
456
+
457
+ if first_dim == batch_size:
458
+ return tensor
459
+
460
+ if batch_size % first_dim != 0:
461
+ raise ValueError(f"Cannot evenly repeat first dim {first_dim} to match batch_size {batch_size}.")
462
+
463
+ repeat_times = batch_size // first_dim
464
+
465
+ return tensor.repeat(repeat_times, *[1] * (tensor.dim() - 1))
466
+
467
+
468
+ def dim5(x):
469
+ return expand_to_dims(x, 5)
470
+
471
+
472
+ def dim4(x):
473
+ return expand_to_dims(x, 4)
474
+
475
+
476
+ def dim3(x):
477
+ return expand_to_dims(x, 3)
478
+
479
+
480
+ def crop_or_pad_yield_mask(x, length):
481
+ B, F, C = x.shape
482
+ device = x.device
483
+ dtype = x.dtype
484
+
485
+ if F < length:
486
+ y = torch.zeros((B, length, C), dtype=dtype, device=device)
487
+ mask = torch.zeros((B, length), dtype=torch.bool, device=device)
488
+ y[:, :F, :] = x
489
+ mask[:, :F] = True
490
+ return y, mask
491
+
492
+ return x[:, :length, :], torch.ones((B, length), dtype=torch.bool, device=device)
493
+
494
+
495
+ def extend_dim(x, dim, minimal_length, zero_pad=False):
496
+ original_length = int(x.shape[dim])
497
+
498
+ if original_length >= minimal_length:
499
+ return x
500
+
501
+ if zero_pad:
502
+ padding_shape = list(x.shape)
503
+ padding_shape[dim] = minimal_length - original_length
504
+ padding = torch.zeros(padding_shape, dtype=x.dtype, device=x.device)
505
+ else:
506
+ idx = (slice(None),) * dim + (slice(-1, None),) + (slice(None),) * (len(x.shape) - dim - 1)
507
+ last_element = x[idx]
508
+ padding = last_element.repeat_interleave(minimal_length - original_length, dim=dim)
509
+
510
+ return torch.cat([x, padding], dim=dim)
511
+
512
+
513
+ def lazy_positional_encoding(t, repeats=None):
514
+ if not isinstance(t, list):
515
+ t = [t]
516
+
517
+ from diffusers.models.embeddings import get_timestep_embedding
518
+
519
+ te = torch.tensor(t)
520
+ te = get_timestep_embedding(timesteps=te, embedding_dim=256, flip_sin_to_cos=True, downscale_freq_shift=0.0, scale=1.0)
521
+
522
+ if repeats is None:
523
+ return te
524
+
525
+ te = te[:, None, :].expand(-1, repeats, -1)
526
+
527
+ return te
528
+
529
+
530
+ def state_dict_offset_merge(A, B, C=None):
531
+ result = {}
532
+ keys = A.keys()
533
+
534
+ for key in keys:
535
+ A_value = A[key]
536
+ B_value = B[key].to(A_value)
537
+
538
+ if C is None:
539
+ result[key] = A_value + B_value
540
+ else:
541
+ C_value = C[key].to(A_value)
542
+ result[key] = A_value + B_value - C_value
543
+
544
+ return result
545
+
546
+
547
+ def state_dict_weighted_merge(state_dicts, weights):
548
+ if len(state_dicts) != len(weights):
549
+ raise ValueError("Number of state dictionaries must match number of weights")
550
+
551
+ if not state_dicts:
552
+ return {}
553
+
554
+ total_weight = sum(weights)
555
+
556
+ if total_weight == 0:
557
+ raise ValueError("Sum of weights cannot be zero")
558
+
559
+ normalized_weights = [w / total_weight for w in weights]
560
+
561
+ keys = state_dicts[0].keys()
562
+ result = {}
563
+
564
+ for key in keys:
565
+ result[key] = state_dicts[0][key] * normalized_weights[0]
566
+
567
+ for i in range(1, len(state_dicts)):
568
+ state_dict_value = state_dicts[i][key].to(result[key])
569
+ result[key] += state_dict_value * normalized_weights[i]
570
+
571
+ return result
572
+
573
+
574
+ def group_files_by_folder(all_files):
575
+ grouped_files = {}
576
+
577
+ for file in all_files:
578
+ folder_name = os.path.basename(os.path.dirname(file))
579
+ if folder_name not in grouped_files:
580
+ grouped_files[folder_name] = []
581
+ grouped_files[folder_name].append(file)
582
+
583
+ list_of_lists = list(grouped_files.values())
584
+ return list_of_lists
585
+
586
+
587
+ def generate_timestamp():
588
+ now = datetime.datetime.now()
589
+ timestamp = now.strftime('%y%m%d_%H%M%S')
590
+ milliseconds = f"{int(now.microsecond / 1000):03d}"
591
+ random_number = random.randint(0, 9999)
592
+ return f"{timestamp}_{milliseconds}_{random_number}"
593
+
594
+
595
+ def write_PIL_image_with_png_info(image, metadata, path):
596
+ from PIL.PngImagePlugin import PngInfo
597
+
598
+ png_info = PngInfo()
599
+ for key, value in metadata.items():
600
+ png_info.add_text(key, value)
601
+
602
+ image.save(path, "PNG", pnginfo=png_info)
603
+ return image
604
+
605
+
606
+ def torch_safe_save(content, path):
607
+ torch.save(content, path + '_tmp')
608
+ os.replace(path + '_tmp', path)
609
+ return path
610
+
611
+
612
+ def move_optimizer_to_device(optimizer, device):
613
+ for state in optimizer.state.values():
614
+ for k, v in state.items():
615
+ if isinstance(v, torch.Tensor):
616
+ state[k] = v.to(device)
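Two of the helpers above are worth a small illustration: soft_append_bcthw cross-fades overlapping clips and save_bcthw_as_mp4 writes a [-1, 1] BCTHW tensor to disk. The clip sizes and output path below are assumptions.

# Sketch: blend two overlapping BCTHW chunks and write the result as an MP4.
import torch
from demo_utils.utils import soft_append_bcthw, save_bcthw_as_mp4

history = torch.rand(1, 3, 24, 64, 64) * 2 - 1   # values in [-1, 1], as the saver expects
current = torch.rand(1, 3, 24, 64, 64) * 2 - 1

merged = soft_append_bcthw(history, current, overlap=8)    # 24 + 24 - 8 = 40 frames
save_bcthw_as_mp4(merged, 'outputs/merged.mp4', fps=16, crf=18)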
demo_utils/vae.py ADDED
@@ -0,0 +1,390 @@
1
+ from typing import List
2
+ from einops import rearrange
3
+ import tensorrt as trt
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from demo_utils.constant import ALL_INPUTS_NAMES, ZERO_VAE_CACHE
8
+ from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, Upsample
9
+
10
+ CACHE_T = 2
11
+
12
+
13
+ class ResidualBlock(nn.Module):
14
+
15
+ def __init__(self, in_dim, out_dim, dropout=0.0):
16
+ super().__init__()
17
+ self.in_dim = in_dim
18
+ self.out_dim = out_dim
19
+
20
+ # layers
21
+ self.residual = nn.Sequential(
22
+ RMS_norm(in_dim, images=False), nn.SiLU(),
23
+ CausalConv3d(in_dim, out_dim, 3, padding=1),
24
+ RMS_norm(out_dim, images=False), nn.SiLU(), nn.Dropout(dropout),
25
+ CausalConv3d(out_dim, out_dim, 3, padding=1))
26
+ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
27
+ if in_dim != out_dim else nn.Identity()
28
+
29
+ def forward(self, x, feat_cache_1, feat_cache_2):
30
+ h = self.shortcut(x)
31
+ feat_cache = feat_cache_1
32
+ out_feat_cache = []
33
+ for layer in self.residual:
34
+ if isinstance(layer, CausalConv3d):
35
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
36
+ if cache_x.shape[2] < 2 and feat_cache is not None:
37
+ # cache last frame of last two chunk
38
+ cache_x = torch.cat([
39
+ feat_cache[:, :, -1, :, :].unsqueeze(2).to(
40
+ cache_x.device), cache_x
41
+ ],
42
+ dim=2)
43
+ x = layer(x, feat_cache)
44
+ out_feat_cache.append(cache_x)
45
+ feat_cache = feat_cache_2
46
+ else:
47
+ x = layer(x)
48
+ return x + h, *out_feat_cache
49
+
50
+
51
+ class Resample(nn.Module):
52
+
53
+ def __init__(self, dim, mode):
54
+ assert mode in ('none', 'upsample2d', 'upsample3d')
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.mode = mode
58
+
59
+ # layers
60
+ if mode == 'upsample2d':
61
+ self.resample = nn.Sequential(
62
+ Upsample(scale_factor=(2., 2.), mode='nearest'),
63
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
64
+ elif mode == 'upsample3d':
65
+ self.resample = nn.Sequential(
66
+ Upsample(scale_factor=(2., 2.), mode='nearest'),
67
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
68
+ self.time_conv = CausalConv3d(
69
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
70
+ else:
71
+ self.resample = nn.Identity()
72
+
73
+ def forward(self, x, is_first_frame, feat_cache):
74
+ if self.mode == 'upsample3d':
75
+ b, c, t, h, w = x.size()
76
+ # x, out_feat_cache = torch.cond(
77
+ # is_first_frame,
78
+ # lambda: (torch.cat([torch.zeros_like(x), x], dim=2), feat_cache.clone()),
79
+ # lambda: self.temporal_conv(x, feat_cache),
80
+ # )
86
+ x, out_feat_cache = self.temporal_conv(x, is_first_frame, feat_cache)
87
+ out_feat_cache = torch.cond(
88
+ is_first_frame,
89
+ lambda: feat_cache.clone().contiguous(),
90
+ lambda: out_feat_cache.clone().contiguous(),
91
+ )
92
+ # if is_first_frame:
93
+ # x = torch.cat([torch.zeros_like(x), x], dim=2)
94
+ # out_feat_cache = feat_cache.clone()
95
+ # else:
96
+ # x, out_feat_cache = self.temporal_conv(x, feat_cache)
97
+ else:
98
+ out_feat_cache = None
99
+ t = x.shape[2]
100
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
101
+ x = self.resample(x)
102
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
103
+ return x, out_feat_cache
104
+
105
+ def temporal_conv(self, x, is_first_frame, feat_cache):
106
+ b, c, t, h, w = x.size()
107
+ cache_x = x[:, :, -CACHE_T:, :, :].clone()
108
+ if cache_x.shape[2] < 2 and feat_cache is not None:
109
+ cache_x = torch.cat([
110
+ torch.zeros_like(cache_x),
111
+ cache_x
112
+ ], dim=2)
113
+ x = torch.cond(
114
+ is_first_frame,
115
+ lambda: torch.cat([torch.zeros_like(x), x], dim=1).contiguous(),
116
+ lambda: self.time_conv(x, feat_cache).contiguous(),
117
+ )
118
+ # x = self.time_conv(x, feat_cache)
119
+ out_feat_cache = cache_x
120
+
121
+ x = x.reshape(b, 2, c, t, h, w)
122
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
123
+ 3)
124
+ x = x.reshape(b, c, t * 2, h, w)
125
+ return x.contiguous(), out_feat_cache.contiguous()
126
+
127
+ def init_weight(self, conv):
128
+ conv_weight = conv.weight
129
+ nn.init.zeros_(conv_weight)
130
+ c1, c2, t, h, w = conv_weight.size()
131
+ one_matrix = torch.eye(c1, c2)
132
+ init_matrix = one_matrix
133
+ nn.init.zeros_(conv_weight)
134
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
135
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
136
+ conv.weight.data.copy_(conv_weight)
137
+ nn.init.zeros_(conv.bias.data)
138
+
139
+ def init_weight2(self, conv):
140
+ conv_weight = conv.weight.data
141
+ nn.init.zeros_(conv_weight)
142
+ c1, c2, t, h, w = conv_weight.size()
143
+ init_matrix = torch.eye(c1 // 2, c2)
144
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
145
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
146
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
147
+ conv.weight.data.copy_(conv_weight)
148
+ nn.init.zeros_(conv.bias.data)
149
+
150
+
151
+ class VAEDecoderWrapperSingle(nn.Module):
152
+ def __init__(self):
153
+ super().__init__()
154
+ self.decoder = VAEDecoder3d()
155
+ mean = [
156
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
157
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
158
+ ]
159
+ std = [
160
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
161
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
162
+ ]
163
+ self.mean = torch.tensor(mean, dtype=torch.float32)
164
+ self.std = torch.tensor(std, dtype=torch.float32)
165
+ self.z_dim = 16
166
+ self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
167
+
168
+ def forward(
169
+ self,
170
+ z: torch.Tensor,
171
+ is_first_frame: torch.Tensor,
172
+ *feat_cache: List[torch.Tensor]
173
+ ):
174
+ # from [batch_size, num_frames, num_channels, height, width]
175
+ # to [batch_size, num_channels, num_frames, height, width]
176
+ z = z.permute(0, 2, 1, 3, 4)
177
+ assert z.shape[2] == 1
178
+ feat_cache = list(feat_cache)
179
+ is_first_frame = is_first_frame.bool()
180
+
181
+ device, dtype = z.device, z.dtype
182
+ scale = [self.mean.to(device=device, dtype=dtype),
183
+ 1.0 / self.std.to(device=device, dtype=dtype)]
184
+
185
+ if isinstance(scale[0], torch.Tensor):
186
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
187
+ 1, self.z_dim, 1, 1, 1)
188
+ else:
189
+ z = z / scale[1] + scale[0]
190
+ x = self.conv2(z)
191
+ out, feat_cache = self.decoder(x, is_first_frame, feat_cache=feat_cache)
192
+ out = out.clamp_(-1, 1)
193
+ # from [batch_size, num_channels, num_frames, height, width]
194
+ # to [batch_size, num_frames, num_channels, height, width]
195
+ out = out.permute(0, 2, 1, 3, 4)
196
+ return out, feat_cache
197
+
198
+
199
+ class VAEDecoder3d(nn.Module):
200
+ def __init__(self,
201
+ dim=96,
202
+ z_dim=16,
203
+ dim_mult=[1, 2, 4, 4],
204
+ num_res_blocks=2,
205
+ attn_scales=[],
206
+ temperal_upsample=[True, True, False],
207
+ dropout=0.0):
208
+ super().__init__()
209
+ self.dim = dim
210
+ self.z_dim = z_dim
211
+ self.dim_mult = dim_mult
212
+ self.num_res_blocks = num_res_blocks
213
+ self.attn_scales = attn_scales
214
+ self.temperal_upsample = temperal_upsample
215
+ self.cache_t = 2
216
+ self.decoder_conv_num = 32
217
+
218
+ # dimensions
219
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
220
+ scale = 1.0 / 2**(len(dim_mult) - 2)
221
+
222
+ # init block
223
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
224
+
225
+ # middle blocks
226
+ self.middle = nn.Sequential(
227
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
228
+ ResidualBlock(dims[0], dims[0], dropout))
229
+
230
+ # upsample blocks
231
+ upsamples = []
232
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
233
+ # residual (+attention) blocks
234
+ if i == 1 or i == 2 or i == 3:
235
+ in_dim = in_dim // 2
236
+ for _ in range(num_res_blocks + 1):
237
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
238
+ if scale in attn_scales:
239
+ upsamples.append(AttentionBlock(out_dim))
240
+ in_dim = out_dim
241
+
242
+ # upsample block
243
+ if i != len(dim_mult) - 1:
244
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
245
+ upsamples.append(Resample(out_dim, mode=mode))
246
+ scale *= 2.0
247
+ self.upsamples = nn.Sequential(*upsamples)
248
+
249
+ # output blocks
250
+ self.head = nn.Sequential(
251
+ RMS_norm(out_dim, images=False), nn.SiLU(),
252
+ CausalConv3d(out_dim, 3, 3, padding=1))
253
+
254
+ def forward(
255
+ self,
256
+ x: torch.Tensor,
257
+ is_first_frame: torch.Tensor,
258
+ feat_cache: List[torch.Tensor]
259
+ ):
260
+ idx = 0
261
+ out_feat_cache = []
262
+
263
+ # conv1
264
+ cache_x = x[:, :, -self.cache_t:, :, :].clone()
265
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
266
+ # cache last frame of last two chunk
267
+ cache_x = torch.cat([
268
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
269
+ cache_x.device), cache_x
270
+ ],
271
+ dim=2)
272
+ x = self.conv1(x, feat_cache[idx])
273
+ out_feat_cache.append(cache_x)
274
+ idx += 1
275
+
276
+ # middle
277
+ for layer in self.middle:
278
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
279
+ x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
280
+ idx += 2
281
+ out_feat_cache.append(out_feat_cache_1)
282
+ out_feat_cache.append(out_feat_cache_2)
283
+ else:
284
+ x = layer(x)
285
+
286
+ # upsamples
287
+ for layer in self.upsamples:
288
+ if isinstance(layer, Resample):
289
+ x, cache_x = layer(x, is_first_frame, feat_cache[idx])
290
+ if cache_x is not None:
291
+ out_feat_cache.append(cache_x)
292
+ idx += 1
293
+ else:
294
+ x, out_feat_cache_1, out_feat_cache_2 = layer(x, feat_cache[idx], feat_cache[idx + 1])
295
+ idx += 2
296
+ out_feat_cache.append(out_feat_cache_1)
297
+ out_feat_cache.append(out_feat_cache_2)
298
+
299
+ # head
300
+ for layer in self.head:
301
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
302
+ cache_x = x[:, :, -self.cache_t:, :, :].clone()
303
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
304
+ # cache last frame of last two chunk
305
+ cache_x = torch.cat([
306
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
307
+ cache_x.device), cache_x
308
+ ],
309
+ dim=2)
310
+ x = layer(x, feat_cache[idx])
311
+ out_feat_cache.append(cache_x)
312
+ idx += 1
313
+ else:
314
+ x = layer(x)
315
+ return x, out_feat_cache
316
+
317
+
318
+ class VAETRTWrapper():
319
+ def __init__(self):
320
+ TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
321
+ with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(TRT_LOGGER) as rt:
322
+ self.engine: trt.ICudaEngine = rt.deserialize_cuda_engine(f.read())
323
+
324
+ self.context: trt.IExecutionContext = self.engine.create_execution_context()
325
+ self.stream = torch.cuda.current_stream().cuda_stream
326
+
327
+ # ──────────────────────────────
328
+ # 2️⃣ Feed the engine with tensors
329
+ # (name-based API in TRT ≥10)
330
+ # ──────────────────────────────
331
+ self.dtype_map = {
332
+ trt.float32: torch.float32,
333
+ trt.float16: torch.float16,
334
+ trt.int8: torch.int8,
335
+ trt.int32: torch.int32,
336
+ }
337
+ test_input = torch.zeros(1, 16, 1, 60, 104).cuda().half()
338
+ is_first_frame = torch.tensor(1.0).cuda().half()
339
+ test_cache_inputs = [c.cuda().half() for c in ZERO_VAE_CACHE]
340
+ test_inputs = [test_input, is_first_frame] + test_cache_inputs
341
+
342
+ # keep references so buffers stay alive
343
+ self.device_buffers, self.outputs = {}, []
344
+
345
+ # ---- inputs ----
346
+ for i, name in enumerate(ALL_INPUTS_NAMES):
347
+ tensor, scale = test_inputs[i], 1 / 127
348
+ tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
349
+
350
+ # dynamic shapes
351
+ if -1 in self.engine.get_tensor_shape(name):
352
+ # name-based dynamic-shape API (new in TensorRT >= 10)
353
+ self.context.set_input_shape(name, tuple(tensor.shape))
354
+
355
+ # set_tensor_address replaces the old bindings[] array
356
+ self.context.set_tensor_address(name, int(tensor.data_ptr()))
357
+ self.device_buffers[name] = tensor # keep pointer alive
358
+
359
+ # ---- (after all input shapes are known) infer output shapes ----
360
+ # infer_shapes propagates shapes once all input shapes are set
361
+ self.context.infer_shapes()
362
+
363
+ for i in range(self.engine.num_io_tensors):
364
+ name = self.engine.get_tensor_name(i)
365
+ # get_tensor_mode replaces the old binding_is_input check
366
+ if self.engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
367
+ shape = tuple(self.context.get_tensor_shape(name))
368
+ dtype = self.dtype_map[self.engine.get_tensor_dtype(name)]
369
+ out = torch.empty(shape, dtype=dtype, device="cuda").contiguous()
370
+
371
+ self.context.set_tensor_address(name, int(out.data_ptr()))
372
+ self.outputs.append(out)
373
+ self.device_buffers[name] = out
374
+
375
+ # helper to quant-convert on the fly
376
+ def quantize_if_needed(self, t, expected_dtype, scale):
377
+ if expected_dtype == trt.int8 and t.dtype != torch.int8:
378
+ t = torch.clamp((t / scale).round(), -128, 127).to(torch.int8).contiguous()
379
+ return t # keep pointer alive
380
+
381
+ def forward(self, *test_inputs):
382
+ for i, name in enumerate(ALL_INPUTS_NAMES):
383
+ tensor, scale = test_inputs[i], 1 / 127
384
+ tensor = self.quantize_if_needed(tensor, self.engine.get_tensor_dtype(name), scale)
385
+ self.context.set_tensor_address(name, int(tensor.data_ptr()))
386
+ self.device_buffers[name] = tensor
387
+
388
+ self.context.execute_async_v3(stream_handle=self.stream)
389
+ torch.cuda.current_stream().synchronize()
390
+ return self.outputs
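To show how the pieces above fit together, here is a streaming-decode sketch that feeds the returned feature cache back into VAEDecoderWrapperSingle chunk by chunk; weight loading is omitted and all shapes/dtypes are assumptions that simply mirror ZERO_VAE_CACHE.

# Sketch: causal, chunk-by-chunk decoding with an explicit recurrent feature cache.
import torch
from demo_utils.constant import ZERO_VAE_CACHE
from demo_utils.vae import VAEDecoderWrapperSingle

decoder = VAEDecoderWrapperSingle().cuda().half().eval()   # real use would load pretrained weights
feat_cache = [c.cuda().half() for c in ZERO_VAE_CACHE]

frames = []
with torch.no_grad():
    for i in range(4):                                     # four single-frame latent chunks
        z = torch.randn(1, 1, 16, 60, 104, device='cuda', dtype=torch.float16)   # B, T=1, C, H, W
        is_first = torch.tensor(float(i == 0), device='cuda', dtype=torch.float16)
        rgb, feat_cache = decoder(z, is_first, *feat_cache)
        frames.append(rgb)

video = torch.cat(frames, dim=1)   # B, T_out, 3, 480, 832 for these latent sizes
print(video.shape)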
demo_utils/vae_block3.py ADDED
@@ -0,0 +1,291 @@
1
+ from typing import List
2
+ from einops import rearrange
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ from wan.modules.vae import AttentionBlock, CausalConv3d, RMS_norm, ResidualBlock, Upsample
7
+
8
+
9
+ class Resample(nn.Module):
10
+
11
+ def __init__(self, dim, mode):
12
+ assert mode in ('none', 'upsample2d', 'upsample3d', 'downsample2d',
13
+ 'downsample3d')
14
+ super().__init__()
15
+ self.dim = dim
16
+ self.mode = mode
17
+ self.cache_t = 2
18
+
19
+ # layers
20
+ if mode == 'upsample2d':
21
+ self.resample = nn.Sequential(
22
+ Upsample(scale_factor=(2., 2.), mode='nearest'),
23
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
24
+ elif mode == 'upsample3d':
25
+ self.resample = nn.Sequential(
26
+ Upsample(scale_factor=(2., 2.), mode='nearest'),
27
+ nn.Conv2d(dim, dim // 2, 3, padding=1))
28
+ self.time_conv = CausalConv3d(
29
+ dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
30
+
31
+ elif mode == 'downsample2d':
32
+ self.resample = nn.Sequential(
33
+ nn.ZeroPad2d((0, 1, 0, 1)),
34
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
35
+ elif mode == 'downsample3d':
36
+ self.resample = nn.Sequential(
37
+ nn.ZeroPad2d((0, 1, 0, 1)),
38
+ nn.Conv2d(dim, dim, 3, stride=(2, 2)))
39
+ self.time_conv = CausalConv3d(
40
+ dim, dim, (3, 1, 1), stride=(2, 1, 1), padding=(0, 0, 0))
41
+
42
+ else:
43
+ self.resample = nn.Identity()
44
+
45
+ def forward(self, x, feat_cache=None, feat_idx=[0]):
46
+ b, c, t, h, w = x.size()
47
+ if self.mode == 'upsample3d':
48
+ if feat_cache is not None:
49
+ idx = feat_idx[0]
50
+ if feat_cache[idx] is None:
51
+ feat_cache[idx] = 'Rep'
52
+ feat_idx[0] += 1
53
+ else:
54
+
55
+ cache_x = x[:, :, -self.cache_t:, :, :].clone()
56
+ if cache_x.shape[2] < 2 and feat_cache[
57
+ idx] is not None and feat_cache[idx] != 'Rep':
58
+ # cache the last frame of the last two chunks
59
+ cache_x = torch.cat([
60
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
61
+ cache_x.device), cache_x
62
+ ],
63
+ dim=2)
64
+ if cache_x.shape[2] < 2 and feat_cache[
65
+ idx] is not None and feat_cache[idx] == 'Rep':
66
+ cache_x = torch.cat([
67
+ torch.zeros_like(cache_x).to(cache_x.device),
68
+ cache_x
69
+ ],
70
+ dim=2)
71
+ if feat_cache[idx] == 'Rep':
72
+ x = self.time_conv(x)
73
+ else:
74
+ x = self.time_conv(x, feat_cache[idx])
75
+ feat_cache[idx] = cache_x
76
+ feat_idx[0] += 1
77
+
78
+ x = x.reshape(b, 2, c, t, h, w)
79
+ x = torch.stack((x[:, 0, :, :, :, :], x[:, 1, :, :, :, :]),
80
+ 3)
81
+ x = x.reshape(b, c, t * 2, h, w)
82
+ t = x.shape[2]
83
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
84
+ x = self.resample(x)
85
+ x = rearrange(x, '(b t) c h w -> b c t h w', t=t)
86
+
87
+ if self.mode == 'downsample3d':
88
+ if feat_cache is not None:
89
+ idx = feat_idx[0]
90
+ if feat_cache[idx] is None:
91
+ feat_cache[idx] = x.clone()
92
+ feat_idx[0] += 1
93
+ else:
94
+
95
+ cache_x = x[:, :, -1:, :, :].clone()
96
+ # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
97
+ # # cache last frame of last two chunk
98
+ # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
99
+
100
+ x = self.time_conv(
101
+ torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
102
+ feat_cache[idx] = cache_x
103
+ feat_idx[0] += 1
104
+ return x
105
+
106
+ def init_weight(self, conv):
107
+ conv_weight = conv.weight
108
+ nn.init.zeros_(conv_weight)
109
+ c1, c2, t, h, w = conv_weight.size()
110
+ one_matrix = torch.eye(c1, c2)
111
+ init_matrix = one_matrix
112
+ nn.init.zeros_(conv_weight)
113
+ # conv_weight.data[:,:,-1,1,1] = init_matrix * 0.5
114
+ conv_weight.data[:, :, 1, 0, 0] = init_matrix # * 0.5
115
+ conv.weight.data.copy_(conv_weight)
116
+ nn.init.zeros_(conv.bias.data)
117
+
118
+ def init_weight2(self, conv):
119
+ conv_weight = conv.weight.data
120
+ nn.init.zeros_(conv_weight)
121
+ c1, c2, t, h, w = conv_weight.size()
122
+ init_matrix = torch.eye(c1 // 2, c2)
123
+ # init_matrix = repeat(init_matrix, 'o ... -> (o 2) ...').permute(1,0,2).contiguous().reshape(c1,c2)
124
+ conv_weight[:c1 // 2, :, -1, 0, 0] = init_matrix
125
+ conv_weight[c1 // 2:, :, -1, 0, 0] = init_matrix
126
+ conv.weight.data.copy_(conv_weight)
127
+ nn.init.zeros_(conv.bias.data)
128
+
129
+
130
+ class VAEDecoderWrapper(nn.Module):
131
+ def __init__(self):
132
+ super().__init__()
133
+ self.decoder = VAEDecoder3d()
134
+ mean = [
135
+ -0.7571, -0.7089, -0.9113, 0.1075, -0.1745, 0.9653, -0.1517, 1.5508,
136
+ 0.4134, -0.0715, 0.5517, -0.3632, -0.1922, -0.9497, 0.2503, -0.2921
137
+ ]
138
+ std = [
139
+ 2.8184, 1.4541, 2.3275, 2.6558, 1.2196, 1.7708, 2.6052, 2.0743,
140
+ 3.2687, 2.1526, 2.8652, 1.5579, 1.6382, 1.1253, 2.8251, 1.9160
141
+ ]
142
+ self.mean = torch.tensor(mean, dtype=torch.float32)
143
+ self.std = torch.tensor(std, dtype=torch.float32)
144
+ self.z_dim = 16
145
+ self.conv2 = CausalConv3d(self.z_dim, self.z_dim, 1)
146
+
147
+ def forward(
148
+ self,
149
+ z: torch.Tensor,
150
+ *feat_cache: List[torch.Tensor]
151
+ ):
152
+ # from [batch_size, num_frames, num_channels, height, width]
153
+ # to [batch_size, num_channels, num_frames, height, width]
154
+ z = z.permute(0, 2, 1, 3, 4)
155
+ feat_cache = list(feat_cache)
156
+ print("Length of feat_cache: ", len(feat_cache))
157
+
158
+ device, dtype = z.device, z.dtype
159
+ scale = [self.mean.to(device=device, dtype=dtype),
160
+ 1.0 / self.std.to(device=device, dtype=dtype)]
161
+
162
+ if isinstance(scale[0], torch.Tensor):
163
+ z = z / scale[1].view(1, self.z_dim, 1, 1, 1) + scale[0].view(
164
+ 1, self.z_dim, 1, 1, 1)
165
+ else:
166
+ z = z / scale[1] + scale[0]
167
+ iter_ = z.shape[2]
168
+ x = self.conv2(z)
169
+ for i in range(iter_):
170
+ if i == 0:
171
+ out, feat_cache = self.decoder(
172
+ x[:, :, i:i + 1, :, :],
173
+ feat_cache=feat_cache)
174
+ else:
175
+ out_, feat_cache = self.decoder(
176
+ x[:, :, i:i + 1, :, :],
177
+ feat_cache=feat_cache)
178
+ out = torch.cat([out, out_], 2)
179
+
180
+ out = out.float().clamp_(-1, 1)
181
+ # from [batch_size, num_channels, num_frames, height, width]
182
+ # to [batch_size, num_frames, num_channels, height, width]
183
+ out = out.permute(0, 2, 1, 3, 4)
184
+ return out, feat_cache
185
+
186
+
187
+ class VAEDecoder3d(nn.Module):
188
+ def __init__(self,
189
+ dim=96,
190
+ z_dim=16,
191
+ dim_mult=[1, 2, 4, 4],
192
+ num_res_blocks=2,
193
+ attn_scales=[],
194
+ temperal_upsample=[True, True, False],
195
+ dropout=0.0):
196
+ super().__init__()
197
+ self.dim = dim
198
+ self.z_dim = z_dim
199
+ self.dim_mult = dim_mult
200
+ self.num_res_blocks = num_res_blocks
201
+ self.attn_scales = attn_scales
202
+ self.temperal_upsample = temperal_upsample
203
+ self.cache_t = 2
204
+ self.decoder_conv_num = 32
205
+
206
+ # dimensions
207
+ dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
208
+ scale = 1.0 / 2**(len(dim_mult) - 2)
209
+
210
+ # init block
211
+ self.conv1 = CausalConv3d(z_dim, dims[0], 3, padding=1)
212
+
213
+ # middle blocks
214
+ self.middle = nn.Sequential(
215
+ ResidualBlock(dims[0], dims[0], dropout), AttentionBlock(dims[0]),
216
+ ResidualBlock(dims[0], dims[0], dropout))
217
+
218
+ # upsample blocks
219
+ upsamples = []
220
+ for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
221
+ # residual (+attention) blocks
222
+ if i == 1 or i == 2 or i == 3:
223
+ in_dim = in_dim // 2
224
+ for _ in range(num_res_blocks + 1):
225
+ upsamples.append(ResidualBlock(in_dim, out_dim, dropout))
226
+ if scale in attn_scales:
227
+ upsamples.append(AttentionBlock(out_dim))
228
+ in_dim = out_dim
229
+
230
+ # upsample block
231
+ if i != len(dim_mult) - 1:
232
+ mode = 'upsample3d' if temperal_upsample[i] else 'upsample2d'
233
+ upsamples.append(Resample(out_dim, mode=mode))
234
+ scale *= 2.0
235
+ self.upsamples = nn.Sequential(*upsamples)
236
+
237
+ # output blocks
238
+ self.head = nn.Sequential(
239
+ RMS_norm(out_dim, images=False), nn.SiLU(),
240
+ CausalConv3d(out_dim, 3, 3, padding=1))
241
+
242
+ def forward(
243
+ self,
244
+ x: torch.Tensor,
245
+ feat_cache: List[torch.Tensor]
246
+ ):
247
+ feat_idx = [0]
248
+
249
+ # conv1
250
+ idx = feat_idx[0]
251
+ cache_x = x[:, :, -self.cache_t:, :, :].clone()
252
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
253
+ # cache the last frame of the last two chunks
254
+ cache_x = torch.cat([
255
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
256
+ cache_x.device), cache_x
257
+ ],
258
+ dim=2)
259
+ x = self.conv1(x, feat_cache[idx])
260
+ feat_cache[idx] = cache_x
261
+ feat_idx[0] += 1
262
+
263
+ # middle
264
+ for layer in self.middle:
265
+ if isinstance(layer, ResidualBlock) and feat_cache is not None:
266
+ x = layer(x, feat_cache, feat_idx)
267
+ else:
268
+ x = layer(x)
269
+
270
+ # upsamples
271
+ for layer in self.upsamples:
272
+ x = layer(x, feat_cache, feat_idx)
273
+
274
+ # head
275
+ for layer in self.head:
276
+ if isinstance(layer, CausalConv3d) and feat_cache is not None:
277
+ idx = feat_idx[0]
278
+ cache_x = x[:, :, -self.cache_t:, :, :].clone()
279
+ if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
280
+ # cache the last frame of the last two chunks
281
+ cache_x = torch.cat([
282
+ feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
283
+ cache_x.device), cache_x
284
+ ],
285
+ dim=2)
286
+ x = layer(x, feat_cache[idx])
287
+ feat_cache[idx] = cache_x
288
+ feat_idx[0] += 1
289
+ else:
290
+ x = layer(x)
291
+ return x, feat_cache
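A hedged usage sketch of `VAEDecoderWrapper` as defined above (not repository code): latents are shaped [batch, frames, 16, H, W], and the cache list returned by one call is fed back into the next call. `NUM_CACHE` is an assumption and must match the number of cache slots the decoder actually consumes.

import torch
from demo_utils.vae_block3 import VAEDecoderWrapper

NUM_CACHE = 32                                        # assumption: must equal the decoder's cache-slot count
decoder = VAEDecoderWrapper().half().cuda().eval()    # untrained weights in this sketch
feat_cache = [None] * NUM_CACHE                       # empty cache before the first latent chunk
z = torch.randn(1, 1, 16, 60, 104, dtype=torch.float16, device="cuda")
with torch.inference_mode():
    rgb, feat_cache = decoder(z, *feat_cache)         # rgb in [-1, 1]; reuse feat_cache for the next chunk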
demo_utils/vae_torch2trt.py ADDED
@@ -0,0 +1,308 @@
1
+ # ---- INT8 (optional) ----
2
+ from demo_utils.vae import (
3
+ VAEDecoderWrapperSingle, # main nn.Module
4
+ ZERO_VAE_CACHE # helper constants shipped with your code base
5
+ )
6
+ import pycuda.driver as cuda # ← add
7
+ import pycuda.autoinit # noqa
8
+
9
+ import sys
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ import tensorrt as trt
14
+
15
+ from utils.dataset import ShardingLMDBDataset
16
+
17
+ data_path = "/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb_oneshard"
18
+ dataset = ShardingLMDBDataset(data_path, max_pair=int(1e8))
19
+ dataloader = torch.utils.data.DataLoader(
20
+ dataset,
21
+ batch_size=1,
22
+ num_workers=0
23
+ )
24
+
25
+ # ─────────────────────────────────────────────────────────
26
+ # 1️⃣ Bring the PyTorch model into scope
27
+ # (all code you pasted lives in `vae_decoder.py`)
28
+ # ─────────────────────────────────────────────────────────
29
+
30
+ # --- dummy tensors (exact shapes you posted) ---
31
+ dummy_input = torch.randn(1, 1, 16, 60, 104).half().cuda()
32
+ is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
33
+ dummy_cache_input = [
34
+ torch.randn(*s.shape).half().cuda() if isinstance(s, torch.Tensor) else s
35
+ for s in ZERO_VAE_CACHE # keep exactly the same ordering
36
+ ]
37
+ inputs = [dummy_input, is_first_frame, *dummy_cache_input]
38
+
39
+ # ─────────────────────────────────────────────────────────
40
+ # 2️⃣ Export → ONNX
41
+ # ─────────────────────────────────────────────────────────
42
+ model = VAEDecoderWrapperSingle().half().cuda().eval()
43
+
44
+ vae_state_dict = torch.load('wan_models/Wan2.1-T2V-1.3B/Wan2.1_VAE.pth', map_location="cpu")
45
+ decoder_state_dict = {}
46
+ for key, value in vae_state_dict.items():
47
+ if 'decoder.' in key or 'conv2' in key:
48
+ decoder_state_dict[key] = value
49
+ model.load_state_dict(decoder_state_dict)
50
+ model = model.half().cuda().eval() # only batch dim dynamic
51
+
52
+ onnx_path = Path("vae_decoder.onnx")
53
+ feat_names = [f"vae_cache_{i}" for i in range(len(dummy_cache_input))]
54
+ all_inputs_names = ["z", "use_cache"] + feat_names
55
+
56
+ with torch.inference_mode():
57
+ torch.onnx.export(
58
+ model,
59
+ tuple(inputs), # must be a tuple
60
+ onnx_path.as_posix(),
61
+ input_names=all_inputs_names,
62
+ output_names=["rgb_out", "cache_out"],
63
+ opset_version=17,
64
+ do_constant_folding=True,
65
+ dynamo=True
66
+ )
67
+ print(f"✅ ONNX graph saved to {onnx_path.resolve()}")
68
+
69
+ # (Optional) quick sanity-check with ONNX-Runtime
70
+ try:
71
+ import onnxruntime as ort
72
+ sess = ort.InferenceSession(onnx_path.as_posix(),
73
+ providers=["CUDAExecutionProvider"])
74
+ ort_inputs = {n: t.cpu().numpy() for n, t in zip(all_inputs_names, inputs)}
75
+ _ = sess.run(None, ort_inputs)
76
+ print("✅ ONNX graph is executable")
77
+ except Exception as e:
78
+ print("⚠️ ONNX check failed:", e)
79
+
80
+ # ─────────────────────────────────────────────────────────
81
+ # 3️⃣ Build the TensorRT engine
82
+ # ─────────────────────────────────────────────────────────
83
+ TRT_LOGGER = trt.Logger(trt.Logger.WARNING)
84
+ builder = trt.Builder(TRT_LOGGER)
85
+ network = builder.create_network(
86
+ 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
87
+ parser = trt.OnnxParser(network, TRT_LOGGER)
88
+
89
+ with open(onnx_path, "rb") as f:
90
+ if not parser.parse(f.read()):
91
+ for i in range(parser.num_errors):
92
+ print(parser.get_error(i))
93
+ sys.exit("❌ ONNX → TRT parsing failed")
94
+
95
+ config = builder.create_builder_config()
96
+
97
+
98
+ def set_workspace(config, bytes_):
99
+ """Version-agnostic workspace limit."""
100
+ if hasattr(config, "max_workspace_size"): # TRT 8 / 9
101
+ config.max_workspace_size = bytes_
102
+ else: # TRT 10+
103
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, bytes_)
104
+
105
+
106
+ # …
107
+ config = builder.create_builder_config()
108
+ set_workspace(config, 4 << 30) # 4 GB
109
110
+
111
+ if builder.platform_has_fast_fp16:
112
+ config.set_flag(trt.BuilderFlag.FP16)
113
+
114
+ # ---- INT8 (optional) ----
115
+ # provide a calibrator if you need an INT8 engine; comment this
116
+ # block if you only care about FP16.
117
+ # ─────────────────────────────────────────────────────────
118
+ # helper: version-agnostic workspace limit
119
+ # ─────────────────────────────────────────────────────────
120
+
121
+
122
+ def set_workspace(config: trt.IBuilderConfig, bytes_: int = 4 << 30):
123
+ """
124
+ TRT < 10.x → config.max_workspace_size
125
+ TRT ≥ 10.x → config.set_memory_pool_limit(...)
126
+ """
127
+ if hasattr(config, "max_workspace_size"): # TRT 8 / 9
128
+ config.max_workspace_size = bytes_
129
+ else: # TRT 10+
130
+ config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE,
131
+ bytes_)
132
+
133
+ # ─────────────────────────────────────────────────────────
134
+ # (optional) INT-8 calibrator
135
+ # ─────────────────────────────────────────────────────────
136
+ # Only keep this block if you really need INT8; it is skipped gracefully if PyCUDA is not present
137
+
138
+
139
+ class VAECalibrator(trt.IInt8EntropyCalibrator2):
140
+ def __init__(self, loader, cache="calibration.cache", max_batches=10):
141
+ super().__init__()
142
+ self.loader = iter(loader)
143
+ self.batch_size = loader.batch_size or 1
144
+ self.max_batches = max_batches
145
+ self.count = 0
146
+ self.cache_file = cache
147
+ self.stream = cuda.Stream()
148
+ self.dev_ptrs = {}
149
+
150
+ # --- TRT 10 needs BOTH spellings ---
151
+ def get_batch_size(self):
152
+ return self.batch_size
153
+
154
+ def getBatchSize(self):
155
+ return self.batch_size
156
+
157
+ def get_batch(self, names):
158
+ if self.count >= self.max_batches:
159
+ return None
160
+
161
+ # Randomly sample the number of warm-up decode iterations (0 to 10 inclusive)
162
+ import random
163
+ vae_idx = random.randint(0, 10)
164
+ data = next(self.loader)
165
+
166
+ latent = data['ode_latent'][0][:, :1]
167
+ is_first_frame = torch.tensor([1.0], device="cuda", dtype=torch.float16)
168
+ feat_cache = ZERO_VAE_CACHE
169
+ for i in range(vae_idx):
170
+ inputs = [latent, is_first_frame, *feat_cache]
171
+ with torch.inference_mode():
172
+ outputs = model(*inputs)
173
+ latent = data['ode_latent'][0][:, i + 1:i + 2]
174
+ is_first_frame = torch.tensor([0.0], device="cuda", dtype=torch.float16)
175
+ feat_cache = outputs[1:]
176
+
177
+ # -------- ensure context is current --------
178
+ z_np = latent.cpu().numpy().astype('float32')
179
+
180
+ ptrs = [] # list[int] – one entry per name
181
+ for name in names: # <-- match TRT's binding order
182
+ if name == "z":
183
+ arr = z_np
184
+ elif name == "use_cache":
185
+ arr = is_first_frame.cpu().numpy().astype('float32')
186
+ else:
187
+ idx = int(name.split('_')[-1]) # "vae_cache_17" -> 17
188
+ arr = feat_cache[idx].cpu().numpy().astype('float32')
189
+
190
+ if name not in self.dev_ptrs:
191
+ self.dev_ptrs[name] = cuda.mem_alloc(arr.nbytes)
192
+
193
+ cuda.memcpy_htod_async(self.dev_ptrs[name], arr, self.stream)
194
+ ptrs.append(int(self.dev_ptrs[name])) # ***int() is required***
195
+
196
+ self.stream.synchronize()
197
+ self.count += 1
198
+ print(f"Calibration batch {self.count}/{self.max_batches}")
199
+ return ptrs
200
+
201
+ # --- calibration-cache helpers (both spellings) ---
202
+ def read_calibration_cache(self):
203
+ try:
204
+ with open(self.cache_file, "rb") as f:
205
+ return f.read()
206
+ except FileNotFoundError:
207
+ return None
208
+
209
+ def readCalibrationCache(self):
210
+ return self.read_calibration_cache()
211
+
212
+ def write_calibration_cache(self, cache):
213
+ with open(self.cache_file, "wb") as f:
214
+ f.write(cache)
215
+
216
+ def writeCalibrationCache(self, cache):
217
+ self.write_calibration_cache(cache)
218
+
219
+
220
+ # ─────────────────────────────────────────────────────────
221
+ # Builder-config + optimisation profile
222
+ # ─────────────────────────────────────────────────────────
223
+ config = builder.create_builder_config()
224
+ set_workspace(config, 4 << 30) # 4 GB
225
+
226
+ # ► enable FP16 if possible
227
+ if builder.platform_has_fast_fp16:
228
+ config.set_flag(trt.BuilderFlag.FP16)
229
+
230
+ # ► enable INT-8 (delete this block if you don’t need it)
231
+ if cuda is not None:
232
+ config.set_flag(trt.BuilderFlag.INT8)
233
+ # supply any representative batch you like – here we reuse the latent z
234
+ calib = VAECalibrator(dataloader)
235
+ # TRT-10 renamed the setter:
236
+ if hasattr(config, "set_int8_calibrator"): # TRT 10+
237
+ config.set_int8_calibrator(calib)
238
+ else: # TRT ≤ 9
239
+ config.int8_calibrator = calib
240
+
241
+ # ---- optimisation profile ----
242
+ profile = builder.create_optimization_profile()
243
+ profile.set_shape(all_inputs_names[0], # latent z
244
+ min=(1, 1, 16, 60, 104),
245
+ opt=(1, 1, 16, 60, 104),
246
+ max=(1, 1, 16, 60, 104))
247
+ profile.set_shape("use_cache", # scalar flag
248
+ min=(1,), opt=(1,), max=(1,))
249
+ for name, tensor in zip(all_inputs_names[2:], dummy_cache_input):
250
+ profile.set_shape(name, tensor.shape, tensor.shape, tensor.shape)
251
+
252
+ config.add_optimization_profile(profile)
253
+
254
+ # ─────────────────────────────────────────────────────────
255
+ # Build the engine (API changed in TRT-10)
256
+ # ─────────────────────────────────────────────────────────
257
+ print("⚙️ Building engine … (can take a minute)")
258
+
259
+ if hasattr(builder, "build_serialized_network"): # TRT 10+
260
+ serialized_engine = builder.build_serialized_network(network, config)
261
+ assert serialized_engine is not None, "build_serialized_network() failed"
262
+ plan_path = Path("checkpoints/vae_decoder_int8.trt")
263
+ plan_path.write_bytes(serialized_engine)
264
+ engine_bytes = serialized_engine # keep for smoke-test
265
+ else: # TRT ≤ 9
266
+ engine = builder.build_engine(network, config)
267
+ assert engine is not None, "build_engine() returned None"
268
+ plan_path = Path("checkpoints/vae_decoder_int8.trt")
269
+ plan_path.write_bytes(engine.serialize())
270
+ engine_bytes = engine.serialize()
271
+
272
+ print(f"✅ TensorRT engine written to {plan_path.resolve()}")
273
+
274
+ # ─────────────────────────────────────────────────────────
275
+ # 4️⃣ Quick smoke-test with the brand-new engine
276
+ # ─────────────────────────────────────────────────────────
277
+ with trt.Runtime(TRT_LOGGER) as rt:
278
+ engine = rt.deserialize_cuda_engine(engine_bytes)
279
+ context = engine.create_execution_context()
280
+ stream = torch.cuda.current_stream().cuda_stream
281
+
282
+ # pre-allocate device buffers once
283
+ device_buffers, outputs = {}, []
284
+ dtype_map = {trt.float32: torch.float32,
285
+ trt.float16: torch.float16,
286
+ trt.int8: torch.int8,
287
+ trt.int32: torch.int32}
288
+
289
+ for name, tensor in zip(all_inputs_names, inputs):
290
+ if -1 in engine.get_tensor_shape(name): # dynamic input
291
+ context.set_input_shape(name, tensor.shape)
292
+ context.set_tensor_address(name, int(tensor.data_ptr()))
293
+ device_buffers[name] = tensor
294
+
295
+ context.infer_shapes() # propagate ⇢ outputs
296
+ for i in range(engine.num_io_tensors):
297
+ name = engine.get_tensor_name(i)
298
+ if engine.get_tensor_mode(name) == trt.TensorIOMode.OUTPUT:
299
+ shape = tuple(context.get_tensor_shape(name))
300
+ dtype = dtype_map[engine.get_tensor_dtype(name)]
301
+ out = torch.empty(shape, dtype=dtype, device="cuda")
302
+ context.set_tensor_address(name, int(out.data_ptr()))
303
+ outputs.append(out)
304
+ print(f"output {name} shape: {shape}")
305
+
306
+ context.execute_async_v3(stream_handle=stream)
307
+ torch.cuda.current_stream().synchronize()
308
+ print("✅ TRT execution OK – first output shape:", outputs[0].shape)
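As an aside (not part of the committed file): the script builds and smoke-tests the engine in one process but does not show reloading the serialized plan later. A minimal sketch, reusing only the plan path written above and TensorRT calls the script already uses:

import tensorrt as trt

logger = trt.Logger(trt.Logger.WARNING)
with open("checkpoints/vae_decoder_int8.trt", "rb") as f, trt.Runtime(logger) as runtime:
    engine = runtime.deserialize_cuda_engine(f.read())
context = engine.create_execution_context()
# Bind tensors and run exactly as in the smoke test above:
# context.set_input_shape(...), context.set_tensor_address(...), context.execute_async_v3(stream)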
images/.gitkeep ADDED
File without changes
inference.py ADDED
@@ -0,0 +1,179 @@
1
+ import argparse
2
+ import torch
3
+ import os
4
+ from omegaconf import OmegaConf
5
+ from tqdm import tqdm
6
+ from torchvision import transforms
7
+ from torchvision.io import write_video
8
+ from einops import rearrange
9
+ import torch.distributed as dist
10
+ from torch.utils.data import DataLoader, SequentialSampler
11
+ from torch.utils.data.distributed import DistributedSampler
12
+
13
+ from pipeline import (
14
+ CausalDiffusionInferencePipeline,
15
+ CausalInferencePipeline
16
+ )
17
+ from utils.dataset import TextDataset, TextImagePairDataset
18
+ from utils.misc import set_seed
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--config_path", type=str, help="Path to the config file")
22
+ parser.add_argument("--checkpoint_path", type=str, help="Path to the checkpoint folder")
23
+ parser.add_argument("--data_path", type=str, help="Path to the dataset")
24
+ parser.add_argument("--extended_prompt_path", type=str, help="Path to the extended prompt")
25
+ parser.add_argument("--output_folder", type=str, help="Output folder")
26
+ parser.add_argument("--num_output_frames", type=int, default=21,
27
+ help="Number of output (latent) frames to generate")
28
+ parser.add_argument("--i2v", action="store_true", help="Whether to perform I2V (or T2V by default)")
29
+ parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA parameters")
30
+ parser.add_argument("--seed", type=int, default=0, help="Random seed")
31
+ parser.add_argument("--num_samples", type=int, default=1, help="Number of samples to generate per prompt")
32
+ parser.add_argument("--save_with_index", action="store_true",
33
+ help="Whether to save the video using the index or prompt as the filename")
34
+ args = parser.parse_args()
35
+
36
+ # Initialize distributed inference
37
+ if "LOCAL_RANK" in os.environ:
38
+ dist.init_process_group(backend='nccl')
39
+ local_rank = int(os.environ["LOCAL_RANK"])
40
+ torch.cuda.set_device(local_rank)
41
+ device = torch.device(f"cuda:{local_rank}")
42
+ world_size = dist.get_world_size()
43
+ set_seed(args.seed + local_rank)
44
+ else:
45
+ device = torch.device("cuda")
46
+ local_rank = 0
47
+ world_size = 1
48
+ set_seed(args.seed)
49
+
50
+ torch.set_grad_enabled(False)
51
+
52
+ config = OmegaConf.load(args.config_path)
53
+ default_config = OmegaConf.load("configs/default_config.yaml")
54
+ config = OmegaConf.merge(default_config, config)
55
+
56
+ # Initialize pipeline
57
+ if hasattr(config, 'denoising_step_list'):
58
+ # Few-step inference
59
+ pipeline = CausalInferencePipeline(config, device=device)
60
+ else:
61
+ # Multi-step diffusion inference
62
+ pipeline = CausalDiffusionInferencePipeline(config, device=device)
63
+
64
+ if args.checkpoint_path:
65
+ state_dict = torch.load(args.checkpoint_path, map_location="cpu")
66
+ pipeline.generator.load_state_dict(state_dict['generator' if not args.use_ema else 'generator_ema'])
67
+
68
+ pipeline = pipeline.to(device=device, dtype=torch.bfloat16)
69
+
70
+ # Create dataset
71
+ if args.i2v:
72
+ assert not dist.is_initialized(), "I2V does not support distributed inference yet"
73
+ transform = transforms.Compose([
74
+ transforms.Resize((480, 832)),
75
+ transforms.ToTensor(),
76
+ transforms.Normalize([0.5], [0.5])
77
+ ])
78
+ dataset = TextImagePairDataset(args.data_path, transform=transform)
79
+ else:
80
+ dataset = TextDataset(prompt_path=args.data_path, extended_prompt_path=args.extended_prompt_path)
81
+ num_prompts = len(dataset)
82
+ print(f"Number of prompts: {num_prompts}")
83
+
84
+ if dist.is_initialized():
85
+ sampler = DistributedSampler(dataset, shuffle=False, drop_last=True)
86
+ else:
87
+ sampler = SequentialSampler(dataset)
88
+ dataloader = DataLoader(dataset, batch_size=1, sampler=sampler, num_workers=0, drop_last=False)
89
+
90
+ # Create output directory (only on main process to avoid race conditions)
91
+ if local_rank == 0:
92
+ os.makedirs(args.output_folder, exist_ok=True)
93
+
94
+ if dist.is_initialized():
95
+ dist.barrier()
96
+
97
+
98
+ def encode(self, videos: torch.Tensor) -> torch.Tensor:
99
+ device, dtype = videos[0].device, videos[0].dtype
100
+ scale = [self.mean.to(device=device, dtype=dtype),
101
+ 1.0 / self.std.to(device=device, dtype=dtype)]
102
+ output = [
103
+ self.model.encode(u.unsqueeze(0), scale).float().squeeze(0)
104
+ for u in videos
105
+ ]
106
+
107
+ output = torch.stack(output, dim=0)
108
+ return output
109
+
110
+
111
+ for i, batch_data in tqdm(enumerate(dataloader), disable=(local_rank != 0)):
112
+ idx = batch_data['idx'].item()
113
+
114
+ # For DataLoader batch_size=1, the batch_data is already a single item, but in a batch container
115
+ # Unpack the batch data for convenience
116
+ if isinstance(batch_data, dict):
117
+ batch = batch_data
118
+ elif isinstance(batch_data, list):
119
+ batch = batch_data[0] # First (and only) item in the batch
120
+
121
+ all_video = []
122
+ num_generated_frames = 0 # Number of generated (latent) frames
123
+
124
+ if args.i2v:
125
+ # For image-to-video, batch contains image and caption
126
+ prompt = batch['prompts'][0] # Get caption from batch
127
+ prompts = [prompt] * args.num_samples
128
+
129
+ # Process the image
130
+ image = batch['image'].squeeze(0).unsqueeze(0).unsqueeze(2).to(device=device, dtype=torch.bfloat16)
131
+
132
+ # Encode the input image as the first latent
133
+ initial_latent = pipeline.vae.encode_to_latent(image).to(device=device, dtype=torch.bfloat16)
134
+ initial_latent = initial_latent.repeat(args.num_samples, 1, 1, 1, 1)
135
+
136
+ sampled_noise = torch.randn(
137
+ [args.num_samples, args.num_output_frames - 1, 16, 60, 104], device=device, dtype=torch.bfloat16
138
+ )
139
+ else:
140
+ # For text-to-video, batch is just the text prompt
141
+ prompt = batch['prompts'][0]
142
+ extended_prompt = batch['extended_prompts'][0] if 'extended_prompts' in batch else None
143
+ if extended_prompt is not None:
144
+ prompts = [extended_prompt] * args.num_samples
145
+ else:
146
+ prompts = [prompt] * args.num_samples
147
+ initial_latent = None
148
+
149
+ sampled_noise = torch.randn(
150
+ [args.num_samples, args.num_output_frames, 16, 60, 104], device=device, dtype=torch.bfloat16
151
+ )
152
+
153
+ # Generate 81 frames
154
+ video, latents = pipeline.inference(
155
+ noise=sampled_noise,
156
+ text_prompts=prompts,
157
+ return_latents=True,
158
+ initial_latent=initial_latent,
159
+ )
160
+ current_video = rearrange(video, 'b t c h w -> b t h w c').cpu()
161
+ all_video.append(current_video)
162
+ num_generated_frames += latents.shape[1]
163
+
164
+ # Final output video
165
+ video = 255.0 * torch.cat(all_video, dim=1)
166
+
167
+ # Clear VAE cache
168
+ pipeline.vae.model.clear_cache()
169
+
170
+ # Save the video if the current prompt is not a dummy prompt
171
+ if idx < num_prompts:
172
+ model = "regular" if not args.use_ema else "ema"
173
+ for seed_idx in range(args.num_samples):
174
+ # All processes save their videos
175
+ if args.save_with_index:
176
+ output_path = os.path.join(args.output_folder, f'{idx}-{seed_idx}_{model}.mp4')
177
+ else:
178
+ output_path = os.path.join(args.output_folder, f'{prompt[:100]}-{seed_idx}.mp4')
179
+ write_video(output_path, video[seed_idx], fps=16)
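Note (not part of the committed file): `--num_output_frames` counts latent frames. With the Wan VAE's 4x temporal upsampling, where the first latent frame maps to a single pixel frame, the default of 21 latent frames decodes to the 81 pixel frames mentioned in the comment above. A small helper makes the relationship explicit:

def latent_to_pixel_frames(num_latent_frames: int, temporal_upsample: int = 4) -> int:
    # The first latent frame decodes to one pixel frame; every later latent frame
    # expands by the VAE's temporal upsampling factor.
    return 1 + (num_latent_frames - 1) * temporal_upsample


assert latent_to_pixel_frames(21) == 81   # matches the "Generate 81 frames" comment above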
model/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ from .diffusion import CausalDiffusion
2
+ from .causvid import CausVid
3
+ from .dmd import DMD
4
+ from .gan import GAN
5
+ from .sid import SiD
6
+ from .ode_regression import ODERegression
7
+ __all__ = [
8
+ "CausalDiffusion",
9
+ "CausVid",
10
+ "DMD",
11
+ "GAN",
12
+ "SiD",
13
+ "ODERegression"
14
+ ]
model/base.py ADDED
@@ -0,0 +1,222 @@
1
+ from typing import Tuple
2
+ from einops import rearrange
3
+ from torch import nn
4
+ import torch.distributed as dist
5
+ import torch
6
+
7
+ from pipeline import SelfForcingTrainingPipeline
8
+ from utils.loss import get_denoising_loss
9
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
10
+
11
+
12
+ class BaseModel(nn.Module):
13
+ def __init__(self, args, device):
14
+ super().__init__()
15
+ self._initialize_models(args, device)
16
+
17
+ self.device = device
18
+ self.args = args
19
+ self.dtype = torch.bfloat16 if args.mixed_precision else torch.float32
20
+ if hasattr(args, "denoising_step_list"):
21
+ self.denoising_step_list = torch.tensor(args.denoising_step_list, dtype=torch.long)
22
+ if args.warp_denoising_step:
23
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
24
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
25
+
26
+ def _initialize_models(self, args, device):
27
+ self.real_model_name = getattr(args, "real_name", "Wan2.1-T2V-1.3B")
28
+ self.fake_model_name = getattr(args, "fake_name", "Wan2.1-T2V-1.3B")
29
+
30
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
31
+ self.generator.model.requires_grad_(True)
32
+
33
+ self.real_score = WanDiffusionWrapper(model_name=self.real_model_name, is_causal=False)
34
+ self.real_score.model.requires_grad_(False)
35
+
36
+ self.fake_score = WanDiffusionWrapper(model_name=self.fake_model_name, is_causal=False)
37
+ self.fake_score.model.requires_grad_(True)
38
+
39
+ self.text_encoder = WanTextEncoder()
40
+ self.text_encoder.requires_grad_(False)
41
+
42
+ self.vae = WanVAEWrapper()
43
+ self.vae.requires_grad_(False)
44
+
45
+ self.scheduler = self.generator.get_scheduler()
46
+ self.scheduler.timesteps = self.scheduler.timesteps.to(device)
47
+
48
+ def _get_timestep(
49
+ self,
50
+ min_timestep: int,
51
+ max_timestep: int,
52
+ batch_size: int,
53
+ num_frame: int,
54
+ num_frame_per_block: int,
55
+ uniform_timestep: bool = False
56
+ ) -> torch.Tensor:
57
+ """
58
+ Randomly generate a timestep tensor based on the generator's task type. It uniformly samples a timestep
59
+ from the range [min_timestep, max_timestep], and returns a tensor of shape [batch_size, num_frame].
60
+ - If uniform_timestep, it will use the same timestep for all frames.
61
+ - If not uniform_timestep, it will use a different timestep for each block.
62
+ """
63
+ if uniform_timestep:
64
+ timestep = torch.randint(
65
+ min_timestep,
66
+ max_timestep,
67
+ [batch_size, 1],
68
+ device=self.device,
69
+ dtype=torch.long
70
+ ).repeat(1, num_frame)
71
+ return timestep
72
+ else:
73
+ timestep = torch.randint(
74
+ min_timestep,
75
+ max_timestep,
76
+ [batch_size, num_frame],
77
+ device=self.device,
78
+ dtype=torch.long
79
+ )
80
+ # make the noise level the same within every block
81
+ if self.independent_first_frame:
82
+ # the first frame is always kept the same
83
+ timestep_from_second = timestep[:, 1:]
84
+ timestep_from_second = timestep_from_second.reshape(
85
+ timestep_from_second.shape[0], -1, num_frame_per_block)
86
+ timestep_from_second[:, :, 1:] = timestep_from_second[:, :, 0:1]
87
+ timestep_from_second = timestep_from_second.reshape(
88
+ timestep_from_second.shape[0], -1)
89
+ timestep = torch.cat([timestep[:, 0:1], timestep_from_second], dim=1)
90
+ else:
91
+ timestep = timestep.reshape(
92
+ timestep.shape[0], -1, num_frame_per_block)
93
+ timestep[:, :, 1:] = timestep[:, :, 0:1]
94
+ timestep = timestep.reshape(timestep.shape[0], -1)
95
+ return timestep
96
+
97
+
98
+ class SelfForcingModel(BaseModel):
99
+ def __init__(self, args, device):
100
+ super().__init__(args, device)
101
+ self.denoising_loss_func = get_denoising_loss(args.denoising_loss_type)()
102
+
103
+ def _run_generator(
104
+ self,
105
+ image_or_video_shape,
106
+ conditional_dict: dict,
107
+ initial_latent: torch.tensor = None
108
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
109
+ """
110
+ Optionally simulate the generator's input from noise using backward simulation
111
+ and then run the generator for one-step.
112
+ Input:
113
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
114
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
115
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
116
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
117
+ - initial_latent: a tensor containing the initial latents [B, F, C, H, W].
118
+ Output:
119
+ - pred_image: a tensor with shape [B, F, C, H, W].
120
+ - denoised_timestep: an integer
121
+ """
122
+ # Step 1: Sample noise and backward simulate the generator's input
123
+ assert getattr(self.args, "backward_simulation", True), "Backward simulation needs to be enabled"
124
+ if initial_latent is not None:
125
+ conditional_dict["initial_latent"] = initial_latent
126
+ if self.args.i2v:
127
+ noise_shape = [image_or_video_shape[0], image_or_video_shape[1] - 1, *image_or_video_shape[2:]]
128
+ else:
129
+ noise_shape = image_or_video_shape.copy()
130
+
131
+ # During training, the number of generated frames should be uniformly sampled from
132
+ # [21, self.num_training_frames], but still being a multiple of self.num_frame_per_block
133
+ min_num_frames = 20 if self.args.independent_first_frame else 21
134
+ max_num_frames = self.num_training_frames - 1 if self.args.independent_first_frame else self.num_training_frames
135
+ assert max_num_frames % self.num_frame_per_block == 0
136
+ assert min_num_frames % self.num_frame_per_block == 0
137
+ max_num_blocks = max_num_frames // self.num_frame_per_block
138
+ min_num_blocks = min_num_frames // self.num_frame_per_block
139
+ num_generated_blocks = torch.randint(min_num_blocks, max_num_blocks + 1, (1,), device=self.device)
140
+ dist.broadcast(num_generated_blocks, src=0)
141
+ num_generated_blocks = num_generated_blocks.item()
142
+ num_generated_frames = num_generated_blocks * self.num_frame_per_block
143
+ if self.args.independent_first_frame and initial_latent is None:
144
+ num_generated_frames += 1
145
+ min_num_frames += 1
146
+ # Sync num_generated_frames across all processes
147
+ noise_shape[1] = num_generated_frames
148
+
149
+ pred_image_or_video, denoised_timestep_from, denoised_timestep_to = self._consistency_backward_simulation(
150
+ noise=torch.randn(noise_shape,
151
+ device=self.device, dtype=self.dtype),
152
+ **conditional_dict,
153
+ )
154
+ # Slice last 21 frames
155
+ if pred_image_or_video.shape[1] > 21:
156
+ with torch.no_grad():
157
+ # Reencode to get image latent
158
+ latent_to_decode = pred_image_or_video[:, :-20, ...]
159
+ # Decode to video
160
+ pixels = self.vae.decode_to_pixel(latent_to_decode)
161
+ frame = pixels[:, -1:, ...].to(self.dtype)
162
+ frame = rearrange(frame, "b t c h w -> b c t h w")
163
+ # Encode frame to get image latent
164
+ image_latent = self.vae.encode_to_latent(frame).to(self.dtype)
165
+ pred_image_or_video_last_21 = torch.cat([image_latent, pred_image_or_video[:, -20:, ...]], dim=1)
166
+ else:
167
+ pred_image_or_video_last_21 = pred_image_or_video
168
+
169
+ if num_generated_frames != min_num_frames:
170
+ # Currently, we do not use gradient for the first chunk, since it contains image latents
171
+ gradient_mask = torch.ones_like(pred_image_or_video_last_21, dtype=torch.bool)
172
+ if self.args.independent_first_frame:
173
+ gradient_mask[:, :1] = False
174
+ else:
175
+ gradient_mask[:, :self.num_frame_per_block] = False
176
+ else:
177
+ gradient_mask = None
178
+
179
+ pred_image_or_video_last_21 = pred_image_or_video_last_21.to(self.dtype)
180
+ return pred_image_or_video_last_21, gradient_mask, denoised_timestep_from, denoised_timestep_to
181
+
182
+ def _consistency_backward_simulation(
183
+ self,
184
+ noise: torch.Tensor,
185
+ **conditional_dict: dict
186
+ ) -> torch.Tensor:
187
+ """
188
+ Simulate the generator's input from noise to avoid training/inference mismatch.
189
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
190
+ Here we use the consistency sampler (https://arxiv.org/abs/2303.01469)
191
+ Input:
192
+ - noise: a tensor sampled from N(0, 1) with shape [B, F, C, H, W] where the number of frame is 1 for images.
193
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
194
+ Output:
195
+ - output: a tensor with shape [B, T, F, C, H, W].
196
+ T is the total number of timesteps. output[0] is a pure noise and output[i] and i>0
197
+ represents the x0 prediction at each timestep.
198
+ """
199
+ if self.inference_pipeline is None:
200
+ self._initialize_inference_pipeline()
201
+
202
+ return self.inference_pipeline.inference_with_trajectory(
203
+ noise=noise, **conditional_dict
204
+ )
205
+
206
+ def _initialize_inference_pipeline(self):
207
+ """
208
+ Lazy initialize the inference pipeline during the first backward simulation run.
209
+ Here we encapsulate the inference code with a model-dependent outside function.
210
+ We pass our FSDP-wrapped modules into the pipeline to save memory.
211
+ """
212
+ self.inference_pipeline = SelfForcingTrainingPipeline(
213
+ denoising_step_list=self.denoising_step_list,
214
+ scheduler=self.scheduler,
215
+ generator=self.generator,
216
+ num_frame_per_block=self.num_frame_per_block,
217
+ independent_first_frame=self.args.independent_first_frame,
218
+ same_step_across_blocks=self.args.same_step_across_blocks,
219
+ last_step_only=self.args.last_step_only,
220
+ num_max_frames=self.num_training_frames,
221
+ context_noise=self.args.context_noise
222
+ )
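Note (not part of the committed file): the reshape logic in `_get_timestep` above is easier to see in a standalone sketch with arbitrary example sizes; every frame inside a block ends up sharing the block's first timestep.

import torch

batch_size, num_frames, frames_per_block = 2, 12, 3
t = torch.randint(0, 1000, (batch_size, num_frames))
t = t.reshape(batch_size, -1, frames_per_block)
t[:, :, 1:] = t[:, :, 0:1]        # every frame inherits its block's first timestep
t = t.reshape(batch_size, -1)     # shape [2, 12], constant within each block of 3 frames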
model/causvid.py ADDED
@@ -0,0 +1,391 @@
1
+ import torch.nn.functional as F
2
+ from typing import Tuple
3
+ import torch
4
+
5
+ from model.base import BaseModel
6
+
7
+
8
+ class CausVid(BaseModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the CausVid module (DMD-style distillation with a causal student).
12
+ This class is self-contained and computes the generator and fake-score losses
13
+ in the forward pass.
14
+ """
15
+ super().__init__(args, device)
16
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
17
+ self.num_training_frames = getattr(args, "num_training_frames", 21)
18
+
19
+ if self.num_frame_per_block > 1:
20
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
21
+
22
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
23
+ if self.independent_first_frame:
24
+ self.generator.model.independent_first_frame = True
25
+ if args.gradient_checkpointing:
26
+ self.generator.enable_gradient_checkpointing()
27
+ self.fake_score.enable_gradient_checkpointing()
28
+
29
+ # Step 2: Initialize all dmd hyperparameters
30
+ self.num_train_timestep = args.num_train_timestep
31
+ self.min_step = int(0.02 * self.num_train_timestep)
32
+ self.max_step = int(0.98 * self.num_train_timestep)
33
+ if hasattr(args, "real_guidance_scale"):
34
+ self.real_guidance_scale = args.real_guidance_scale
35
+ self.fake_guidance_scale = args.fake_guidance_scale
36
+ else:
37
+ self.real_guidance_scale = args.guidance_scale
38
+ self.fake_guidance_scale = 0.0
39
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
40
+ self.teacher_forcing = getattr(args, "teacher_forcing", False)
41
+
42
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
43
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
44
+ else:
45
+ self.scheduler.alphas_cumprod = None
46
+
47
+ def _compute_kl_grad(
48
+ self, noisy_image_or_video: torch.Tensor,
49
+ estimated_clean_image_or_video: torch.Tensor,
50
+ timestep: torch.Tensor,
51
+ conditional_dict: dict, unconditional_dict: dict,
52
+ normalization: bool = True
53
+ ) -> Tuple[torch.Tensor, dict]:
54
+ """
55
+ Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
56
+ Input:
57
+ - noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
58
+ - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
59
+ - timestep: a tensor with shape [B, F] containing the randomly generated timestep.
60
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
61
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
62
+ - normalization: a boolean indicating whether to normalize the gradient.
63
+ Output:
64
+ - kl_grad: a tensor representing the KL grad.
65
+ - kl_log_dict: a dictionary containing the intermediate tensors for logging.
66
+ """
67
+ # Step 1: Compute the fake score
68
+ _, pred_fake_image_cond = self.fake_score(
69
+ noisy_image_or_video=noisy_image_or_video,
70
+ conditional_dict=conditional_dict,
71
+ timestep=timestep
72
+ )
73
+
74
+ if self.fake_guidance_scale != 0.0:
75
+ _, pred_fake_image_uncond = self.fake_score(
76
+ noisy_image_or_video=noisy_image_or_video,
77
+ conditional_dict=unconditional_dict,
78
+ timestep=timestep
79
+ )
80
+ pred_fake_image = pred_fake_image_cond + (
81
+ pred_fake_image_cond - pred_fake_image_uncond
82
+ ) * self.fake_guidance_scale
83
+ else:
84
+ pred_fake_image = pred_fake_image_cond
85
+
86
+ # Step 2: Compute the real score
87
+ # We compute the conditional and unconditional prediction
88
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
89
+ _, pred_real_image_cond = self.real_score(
90
+ noisy_image_or_video=noisy_image_or_video,
91
+ conditional_dict=conditional_dict,
92
+ timestep=timestep
93
+ )
94
+
95
+ _, pred_real_image_uncond = self.real_score(
96
+ noisy_image_or_video=noisy_image_or_video,
97
+ conditional_dict=unconditional_dict,
98
+ timestep=timestep
99
+ )
100
+
101
+ pred_real_image = pred_real_image_cond + (
102
+ pred_real_image_cond - pred_real_image_uncond
103
+ ) * self.real_guidance_scale
104
+
105
+ # Step 3: Compute the DMD gradient (DMD paper eq. 7).
106
+ grad = (pred_fake_image - pred_real_image)
107
+
108
+ # TODO: Change the normalizer for causal teacher
109
+ if normalization:
110
+ # Step 4: Gradient normalization (DMD paper eq. 8).
111
+ p_real = (estimated_clean_image_or_video - pred_real_image)
112
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
113
+ grad = grad / normalizer
114
+ grad = torch.nan_to_num(grad)
115
+
116
+ return grad, {
117
+ "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
118
+ "timestep": timestep.detach()
119
+ }
120
+
121
+ def compute_distribution_matching_loss(
122
+ self,
123
+ image_or_video: torch.Tensor,
124
+ conditional_dict: dict,
125
+ unconditional_dict: dict,
126
+ gradient_mask: torch.Tensor = None,
127
+ ) -> Tuple[torch.Tensor, dict]:
128
+ """
129
+ Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
130
+ Input:
131
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
132
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
133
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
134
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute loss .
135
+ Output:
136
+ - dmd_loss: a scalar tensor representing the DMD loss.
137
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
138
+ """
139
+ original_latent = image_or_video
140
+
141
+ batch_size, num_frame = image_or_video.shape[:2]
142
+
143
+ with torch.no_grad():
144
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
145
+ timestep = self._get_timestep(
146
+ 0,
147
+ self.num_train_timestep,
148
+ batch_size,
149
+ num_frame,
150
+ self.num_frame_per_block,
151
+ uniform_timestep=True
152
+ )
153
+
154
+ if self.timestep_shift > 1:
155
+ timestep = self.timestep_shift * \
156
+ (timestep / 1000) / \
157
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
158
+ timestep = timestep.clamp(self.min_step, self.max_step)
159
+
160
+ noise = torch.randn_like(image_or_video)
161
+ noisy_latent = self.scheduler.add_noise(
162
+ image_or_video.flatten(0, 1),
163
+ noise.flatten(0, 1),
164
+ timestep.flatten(0, 1)
165
+ ).detach().unflatten(0, (batch_size, num_frame))
166
+
167
+ # Step 2: Compute the KL grad
168
+ grad, dmd_log_dict = self._compute_kl_grad(
169
+ noisy_image_or_video=noisy_latent,
170
+ estimated_clean_image_or_video=original_latent,
171
+ timestep=timestep,
172
+ conditional_dict=conditional_dict,
173
+ unconditional_dict=unconditional_dict
174
+ )
175
+
176
+ if gradient_mask is not None:
177
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
178
+ )[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
179
+ else:
180
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
181
+ ), (original_latent.double() - grad.double()).detach(), reduction="mean")
182
+ return dmd_loss, dmd_log_dict
183
+
184
+ def _run_generator(
185
+ self,
186
+ image_or_video_shape,
187
+ conditional_dict: dict,
188
+ clean_latent: torch.tensor
189
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
190
+ """
191
+ Optionally simulate the generator's input from noise using backward simulation
192
+ and then run the generator for one-step.
193
+ Input:
194
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
195
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
196
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
197
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
198
+ - initial_latent: a tensor containing the initial latents [B, F, C, H, W].
199
+ Output:
200
+ - pred_image: a tensor with shape [B, F, C, H, W].
201
+ """
202
+ simulated_noisy_input = []
203
+ for timestep in self.denoising_step_list:
204
+ noise = torch.randn(
205
+ image_or_video_shape, device=self.device, dtype=self.dtype)
206
+
207
+ noisy_timestep = timestep * torch.ones(
208
+ image_or_video_shape[:2], device=self.device, dtype=torch.long)
209
+
210
+ if timestep != 0:
211
+ noisy_image = self.scheduler.add_noise(
212
+ clean_latent.flatten(0, 1),
213
+ noise.flatten(0, 1),
214
+ noisy_timestep.flatten(0, 1)
215
+ ).unflatten(0, image_or_video_shape[:2])
216
+ else:
217
+ noisy_image = clean_latent
218
+
219
+ simulated_noisy_input.append(noisy_image)
220
+
221
+ simulated_noisy_input = torch.stack(simulated_noisy_input, dim=1)
222
+
223
+ # Step 2: Randomly sample a timestep and pick the corresponding input
224
+ index = self._get_timestep(
225
+ 0,
226
+ len(self.denoising_step_list),
227
+ image_or_video_shape[0],
228
+ image_or_video_shape[1],
229
+ self.num_frame_per_block,
230
+ uniform_timestep=False
231
+ )
232
+
233
+ # select the corresponding timestep's noisy input from the stacked tensor [B, T, F, C, H, W]
234
+ noisy_input = torch.gather(
235
+ simulated_noisy_input, dim=1,
236
+ index=index.reshape(index.shape[0], 1, index.shape[1], 1, 1, 1).expand(
237
+ -1, -1, -1, *image_or_video_shape[2:]).to(self.device)
238
+ ).squeeze(1)
239
+
240
+ timestep = self.denoising_step_list[index].to(self.device)
241
+
242
+ _, pred_image_or_video = self.generator(
243
+ noisy_image_or_video=noisy_input,
244
+ conditional_dict=conditional_dict,
245
+ timestep=timestep,
246
+ clean_x=clean_latent if self.teacher_forcing else None,
247
+ )
248
+
249
+ gradient_mask = None # timestep != 0
250
+
251
+ pred_image_or_video = pred_image_or_video.type_as(noisy_input)
252
+
253
+ return pred_image_or_video, gradient_mask
254
+
255
+ def generator_loss(
256
+ self,
257
+ image_or_video_shape,
258
+ conditional_dict: dict,
259
+ unconditional_dict: dict,
260
+ clean_latent: torch.Tensor,
261
+ initial_latent: torch.Tensor = None
262
+ ) -> Tuple[torch.Tensor, dict]:
263
+ """
264
+ Generate image/videos from noise and compute the DMD loss.
265
+ The noisy input to the generator is backward simulated.
266
+ This removes the need of any datasets during distillation.
267
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
268
+ Input:
269
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
270
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
271
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
272
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
273
+ Output:
274
+ - loss: a scalar tensor representing the generator loss.
275
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
276
+ """
277
+ # Step 1: Run generator on backward simulated noisy input
278
+ pred_image, gradient_mask = self._run_generator(
279
+ image_or_video_shape=image_or_video_shape,
280
+ conditional_dict=conditional_dict,
281
+ clean_latent=clean_latent
282
+ )
283
+
284
+ # Step 2: Compute the DMD loss
285
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
286
+ image_or_video=pred_image,
287
+ conditional_dict=conditional_dict,
288
+ unconditional_dict=unconditional_dict,
289
+ gradient_mask=gradient_mask
290
+ )
291
+
292
+ # Step 3: TODO: Implement the GAN loss
293
+
294
+ return dmd_loss, dmd_log_dict
295
+
296
+ def critic_loss(
297
+ self,
298
+ image_or_video_shape,
299
+ conditional_dict: dict,
300
+ unconditional_dict: dict,
301
+ clean_latent: torch.Tensor,
302
+ initial_latent: torch.Tensor = None
303
+ ) -> Tuple[torch.Tensor, dict]:
304
+ """
305
+ Generate image/videos from noise and train the critic with generated samples.
306
+ The noisy input to the generator is backward simulated.
307
+ This removes the need of any datasets during distillation.
308
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
309
+ Input:
310
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
311
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
312
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
313
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
314
+ Output:
315
+ - loss: a scalar tensor representing the generator loss.
316
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
317
+ """
318
+
319
+ # Step 1: Run generator on backward simulated noisy input
320
+ with torch.no_grad():
321
+ generated_image, _ = self._run_generator(
322
+ image_or_video_shape=image_or_video_shape,
323
+ conditional_dict=conditional_dict,
324
+ clean_latent=clean_latent
325
+ )
326
+
327
+ # Step 2: Compute the fake prediction
328
+ critic_timestep = self._get_timestep(
329
+ 0,
330
+ self.num_train_timestep,
331
+ image_or_video_shape[0],
332
+ image_or_video_shape[1],
333
+ self.num_frame_per_block,
334
+ uniform_timestep=True
335
+ )
336
+
337
+ if self.timestep_shift > 1:
338
+ critic_timestep = self.timestep_shift * \
339
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
340
+
341
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
342
+
343
+ critic_noise = torch.randn_like(generated_image)
344
+ noisy_generated_image = self.scheduler.add_noise(
345
+ generated_image.flatten(0, 1),
346
+ critic_noise.flatten(0, 1),
347
+ critic_timestep.flatten(0, 1)
348
+ ).unflatten(0, image_or_video_shape[:2])
349
+
350
+ _, pred_fake_image = self.fake_score(
351
+ noisy_image_or_video=noisy_generated_image,
352
+ conditional_dict=conditional_dict,
353
+ timestep=critic_timestep
354
+ )
355
+
356
+ # Step 3: Compute the denoising loss for the fake critic
357
+ if self.args.denoising_loss_type == "flow":
358
+ from utils.wan_wrapper import WanDiffusionWrapper
359
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
360
+ scheduler=self.scheduler,
361
+ x0_pred=pred_fake_image.flatten(0, 1),
362
+ xt=noisy_generated_image.flatten(0, 1),
363
+ timestep=critic_timestep.flatten(0, 1)
364
+ )
365
+ pred_fake_noise = None
366
+ else:
367
+ flow_pred = None
368
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
369
+ x0=pred_fake_image.flatten(0, 1),
370
+ xt=noisy_generated_image.flatten(0, 1),
371
+ timestep=critic_timestep.flatten(0, 1)
372
+ ).unflatten(0, image_or_video_shape[:2])
373
+
374
+ denoising_loss = self.denoising_loss_func(
375
+ x=generated_image.flatten(0, 1),
376
+ x_pred=pred_fake_image.flatten(0, 1),
377
+ noise=critic_noise.flatten(0, 1),
378
+ noise_pred=pred_fake_noise,
379
+ alphas_cumprod=self.scheduler.alphas_cumprod,
380
+ timestep=critic_timestep.flatten(0, 1),
381
+ flow_pred=flow_pred
382
+ )
383
+
384
+ # Step 4: TODO: Compute the GAN loss
385
+
386
+ # Step 5: Debugging Log
387
+ critic_log_dict = {
388
+ "critic_timestep": critic_timestep.detach()
389
+ }
390
+
391
+ return denoising_loss, critic_log_dict
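Aside (not part of the committed file): the timestep shift applied in both `compute_distribution_matching_loss` and `critic_loss` above is the same scalar warp; written as a plain function for clarity:

def shift_timestep(t: float, shift: float, num_train_timestep: int = 1000) -> float:
    # s(t) = shift * u / (1 + (shift - 1) * u) with u = t / num_train_timestep,
    # then rescaled back to [0, num_train_timestep).
    u = t / num_train_timestep
    return shift * u / (1 + (shift - 1) * u) * num_train_timestep


# shift > 1 pushes sampled timesteps toward the noisier end, e.g.
# shift_timestep(500, 3.0) == 750.0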
model/diffusion.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import Tuple
2
+ import torch
3
+
4
+ from model.base import BaseModel
5
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
6
+
7
+
8
+ class CausalDiffusion(BaseModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the Diffusion loss module.
12
+ """
13
+ super().__init__(args, device)
14
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
15
+ if self.num_frame_per_block > 1:
16
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
17
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
18
+ if self.independent_first_frame:
19
+ self.generator.model.independent_first_frame = True
20
+
21
+ if args.gradient_checkpointing:
22
+ self.generator.enable_gradient_checkpointing()
23
+
24
+ # Step 2: Initialize all hyperparameters
25
+ self.num_train_timestep = args.num_train_timestep
26
+ self.min_step = int(0.02 * self.num_train_timestep)
27
+ self.max_step = int(0.98 * self.num_train_timestep)
28
+ self.guidance_scale = args.guidance_scale
29
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
30
+ self.teacher_forcing = getattr(args, "teacher_forcing", False)
31
+ # Noise augmentation in teacher forcing, we add small noise to clean context latents
32
+ self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
33
+
34
+ def _initialize_models(self, args):
35
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
36
+ self.generator.model.requires_grad_(True)
37
+
38
+ self.text_encoder = WanTextEncoder()
39
+ self.text_encoder.requires_grad_(False)
40
+
41
+ self.vae = WanVAEWrapper()
42
+ self.vae.requires_grad_(False)
43
+
44
+ def generator_loss(
45
+ self,
46
+ image_or_video_shape,
47
+ conditional_dict: dict,
48
+ unconditional_dict: dict,
49
+ clean_latent: torch.Tensor,
50
+ initial_latent: torch.Tensor = None
51
+ ) -> Tuple[torch.Tensor, dict]:
52
+ """
53
+ Generate image/videos from noise and compute the DMD loss.
54
+ The noisy input to the generator is backward simulated.
55
+ This removes the need of any datasets during distillation.
56
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
57
+ Input:
58
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
59
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
60
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
61
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W], used as the denoising target.
62
+ Output:
63
+ - loss: a scalar tensor representing the generator loss.
64
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
65
+ """
66
+ noise = torch.randn_like(clean_latent)
67
+ batch_size, num_frame = image_or_video_shape[:2]
68
+
69
+ # Step 2: Randomly sample a timestep and add noise to denoiser inputs
70
+ index = self._get_timestep(
71
+ 0,
72
+ self.scheduler.num_train_timesteps,
73
+ image_or_video_shape[0],
74
+ image_or_video_shape[1],
75
+ self.num_frame_per_block,
76
+ uniform_timestep=False
77
+ )
78
+ timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
79
+ noisy_latents = self.scheduler.add_noise(
80
+ clean_latent.flatten(0, 1),
81
+ noise.flatten(0, 1),
82
+ timestep.flatten(0, 1)
83
+ ).unflatten(0, (batch_size, num_frame))
84
+ training_target = self.scheduler.training_target(clean_latent, noise, timestep)
85
+
86
+ # Step 3: Noise augmentation, also add small noise to clean context latents
87
+ if self.noise_augmentation_max_timestep > 0:
88
+ index_clean_aug = self._get_timestep(
89
+ 0,
90
+ self.noise_augmentation_max_timestep,
91
+ image_or_video_shape[0],
92
+ image_or_video_shape[1],
93
+ self.num_frame_per_block,
94
+ uniform_timestep=False
95
+ )
96
+ timestep_clean_aug = self.scheduler.timesteps[index_clean_aug].to(dtype=self.dtype, device=self.device)
97
+ clean_latent_aug = self.scheduler.add_noise(
98
+ clean_latent.flatten(0, 1),
99
+ noise.flatten(0, 1),
100
+ timestep_clean_aug.flatten(0, 1)
101
+ ).unflatten(0, (batch_size, num_frame))
102
+ else:
103
+ clean_latent_aug = clean_latent
104
+ timestep_clean_aug = None
105
+
106
+ # Compute loss
107
+ flow_pred, x0_pred = self.generator(
108
+ noisy_image_or_video=noisy_latents,
109
+ conditional_dict=conditional_dict,
110
+ timestep=timestep,
111
+ clean_x=clean_latent_aug if self.teacher_forcing else None,
112
+ aug_t=timestep_clean_aug if self.teacher_forcing else None
113
+ )
114
+ # loss = torch.nn.functional.mse_loss(flow_pred.float(), training_target.float())
115
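+ # Per-frame MSE against the flow-matching target, weighted below by the scheduler's timestep-dependent training weight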
+ loss = torch.nn.functional.mse_loss(
116
+ flow_pred.float(), training_target.float(), reduction='none'
117
+ ).mean(dim=(2, 3, 4))
118
+ loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
119
+ loss = loss.mean()
120
+
121
+ log_dict = {
122
+ "x0": clean_latent.detach(),
123
+ "x0_pred": x0_pred.detach()
124
+ }
125
+ return loss, log_dict
model/dmd.py ADDED
@@ -0,0 +1,332 @@
1
+ from pipeline import SelfForcingTrainingPipeline
2
+ import torch.nn.functional as F
3
+ from typing import Optional, Tuple
4
+ import torch
5
+
6
+ from model.base import SelfForcingModel
7
+
8
+
9
+ class DMD(SelfForcingModel):
10
+ def __init__(self, args, device):
11
+ """
12
+ Initialize the DMD (Distribution Matching Distillation) module.
13
+ This class is self-contained and compute generator and fake score losses
14
+ in the forward pass.
15
+ """
16
+ super().__init__(args, device)
17
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
18
+ self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
19
+ self.num_training_frames = getattr(args, "num_training_frames", 21)
20
+
21
+ if self.num_frame_per_block > 1:
22
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
23
+
24
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
25
+ if self.independent_first_frame:
26
+ self.generator.model.independent_first_frame = True
27
+ if args.gradient_checkpointing:
28
+ self.generator.enable_gradient_checkpointing()
29
+ self.fake_score.enable_gradient_checkpointing()
30
+
31
+ # this will be init later with fsdp-wrapped modules
32
+ self.inference_pipeline: SelfForcingTrainingPipeline = None
33
+
34
+ # Step 2: Initialize all dmd hyperparameters
35
+ self.num_train_timestep = args.num_train_timestep
36
+ self.min_step = int(0.02 * self.num_train_timestep)
37
+ self.max_step = int(0.98 * self.num_train_timestep)
38
+ if hasattr(args, "real_guidance_scale"):
39
+ self.real_guidance_scale = args.real_guidance_scale
40
+ self.fake_guidance_scale = args.fake_guidance_scale
41
+ else:
42
+ self.real_guidance_scale = args.guidance_scale
43
+ self.fake_guidance_scale = 0.0
44
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
45
+ self.ts_schedule = getattr(args, "ts_schedule", True)
46
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
47
+ self.min_score_timestep = getattr(args, "min_score_timestep", 0)
48
+
49
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
50
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
51
+ else:
52
+ self.scheduler.alphas_cumprod = None
53
+
54
+ def _compute_kl_grad(
55
+ self, noisy_image_or_video: torch.Tensor,
56
+ estimated_clean_image_or_video: torch.Tensor,
57
+ timestep: torch.Tensor,
58
+ conditional_dict: dict, unconditional_dict: dict,
59
+ normalization: bool = True
60
+ ) -> Tuple[torch.Tensor, dict]:
61
+ """
62
+ Compute the KL grad (eq 7 in https://arxiv.org/abs/2311.18828).
63
+ Input:
64
+ - noisy_image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
65
+ - estimated_clean_image_or_video: a tensor with shape [B, F, C, H, W] representing the estimated clean image or video.
66
+ - timestep: a tensor with shape [B, F] containing the randomly generated timestep.
67
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
68
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
69
+ - normalization: a boolean indicating whether to normalize the gradient.
70
+ Output:
71
+ - kl_grad: a tensor representing the KL grad.
72
+ - kl_log_dict: a dictionary containing the intermediate tensors for logging.
73
+ """
74
+ # Step 1: Compute the fake score
75
+ _, pred_fake_image_cond = self.fake_score(
76
+ noisy_image_or_video=noisy_image_or_video,
77
+ conditional_dict=conditional_dict,
78
+ timestep=timestep
79
+ )
80
+
81
+ if self.fake_guidance_scale != 0.0:
82
+ _, pred_fake_image_uncond = self.fake_score(
83
+ noisy_image_or_video=noisy_image_or_video,
84
+ conditional_dict=unconditional_dict,
85
+ timestep=timestep
86
+ )
87
+ pred_fake_image = pred_fake_image_cond + (
88
+ pred_fake_image_cond - pred_fake_image_uncond
89
+ ) * self.fake_guidance_scale
90
+ else:
91
+ pred_fake_image = pred_fake_image_cond
92
+
93
+ # Step 2: Compute the real score
94
+ # We compute the conditional and unconditional prediction
95
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
96
+ _, pred_real_image_cond = self.real_score(
97
+ noisy_image_or_video=noisy_image_or_video,
98
+ conditional_dict=conditional_dict,
99
+ timestep=timestep
100
+ )
101
+
102
+ _, pred_real_image_uncond = self.real_score(
103
+ noisy_image_or_video=noisy_image_or_video,
104
+ conditional_dict=unconditional_dict,
105
+ timestep=timestep
106
+ )
107
+
108
+ pred_real_image = pred_real_image_cond + (
109
+ pred_real_image_cond - pred_real_image_uncond
110
+ ) * self.real_guidance_scale
111
+
112
+ # Step 3: Compute the DMD gradient (DMD paper eq. 7).
113
+ grad = (pred_fake_image - pred_real_image)
114
+
115
+ # TODO: Change the normalizer for causal teacher
116
+ if normalization:
117
+ # Step 4: Gradient normalization (DMD paper eq. 8).
118
+ p_real = (estimated_clean_image_or_video - pred_real_image)
119
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
120
+ grad = grad / normalizer
121
+ grad = torch.nan_to_num(grad)
122
+
123
+ return grad, {
124
+ "dmdtrain_gradient_norm": torch.mean(torch.abs(grad)).detach(),
125
+ "timestep": timestep.detach()
126
+ }
127
+
128
+ def compute_distribution_matching_loss(
129
+ self,
130
+ image_or_video: torch.Tensor,
131
+ conditional_dict: dict,
132
+ unconditional_dict: dict,
133
+ gradient_mask: Optional[torch.Tensor] = None,
134
+ denoised_timestep_from: int = 0,
135
+ denoised_timestep_to: int = 0
136
+ ) -> Tuple[torch.Tensor, dict]:
137
+ """
138
+ Compute the DMD loss (eq 7 in https://arxiv.org/abs/2311.18828).
139
+ Input:
140
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frame is 1 for images.
141
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
142
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
143
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute the loss on.
144
+ Output:
145
+ - dmd_loss: a scalar tensor representing the DMD loss.
146
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
147
+ """
148
+ original_latent = image_or_video
149
+
150
+ batch_size, num_frame = image_or_video.shape[:2]
151
+
152
+ with torch.no_grad():
153
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
154
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
155
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
156
+ timestep = self._get_timestep(
157
+ min_timestep,
158
+ max_timestep,
159
+ batch_size,
160
+ num_frame,
161
+ self.num_frame_per_block,
162
+ uniform_timestep=True
163
+ )
164
+
165
+ # TODO:should we change it to `timestep = self.scheduler.timesteps[timestep]`?
166
+ if self.timestep_shift > 1:
167
+ timestep = self.timestep_shift * \
168
+ (timestep / 1000) / \
169
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
170
+ timestep = timestep.clamp(self.min_step, self.max_step)
171
+
172
+ noise = torch.randn_like(image_or_video)
173
+ noisy_latent = self.scheduler.add_noise(
174
+ image_or_video.flatten(0, 1),
175
+ noise.flatten(0, 1),
176
+ timestep.flatten(0, 1)
177
+ ).detach().unflatten(0, (batch_size, num_frame))
178
+
179
+ # Step 2: Compute the KL grad
180
+ grad, dmd_log_dict = self._compute_kl_grad(
181
+ noisy_image_or_video=noisy_latent,
182
+ estimated_clean_image_or_video=original_latent,
183
+ timestep=timestep,
184
+ conditional_dict=conditional_dict,
185
+ unconditional_dict=unconditional_dict
186
+ )
187
+
188
+ if gradient_mask is not None:
189
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
190
+ )[gradient_mask], (original_latent.double() - grad.double()).detach()[gradient_mask], reduction="mean")
191
+ else:
192
+ dmd_loss = 0.5 * F.mse_loss(original_latent.double(
193
+ ), (original_latent.double() - grad.double()).detach(), reduction="mean")
194
+ return dmd_loss, dmd_log_dict
195
+
196
+ def generator_loss(
197
+ self,
198
+ image_or_video_shape,
199
+ conditional_dict: dict,
200
+ unconditional_dict: dict,
201
+ clean_latent: torch.Tensor,
202
+ initial_latent: torch.Tensor = None
203
+ ) -> Tuple[torch.Tensor, dict]:
204
+ """
205
+ Generate image/videos from noise and compute the DMD loss.
206
+ The noisy input to the generator is backward simulated.
207
+ This removes the need of any datasets during distillation.
208
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
209
+ Input:
210
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
211
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
212
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
213
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
214
+ Output:
215
+ - loss: a scalar tensor representing the generator loss.
216
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
217
+ """
218
+ # Step 1: Unroll generator to obtain fake videos
219
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
220
+ image_or_video_shape=image_or_video_shape,
221
+ conditional_dict=conditional_dict,
222
+ initial_latent=initial_latent
223
+ )
224
+
225
+ # Step 2: Compute the DMD loss
226
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
227
+ image_or_video=pred_image,
228
+ conditional_dict=conditional_dict,
229
+ unconditional_dict=unconditional_dict,
230
+ gradient_mask=gradient_mask,
231
+ denoised_timestep_from=denoised_timestep_from,
232
+ denoised_timestep_to=denoised_timestep_to
233
+ )
234
+
235
+ return dmd_loss, dmd_log_dict
236
+
237
+ def critic_loss(
238
+ self,
239
+ image_or_video_shape,
240
+ conditional_dict: dict,
241
+ unconditional_dict: dict,
242
+ clean_latent: torch.Tensor,
243
+ initial_latent: torch.Tensor = None
244
+ ) -> Tuple[torch.Tensor, dict]:
245
+ """
246
+ Generate image/videos from noise and train the critic with generated samples.
247
+ The noisy input to the generator is backward simulated.
248
+ This removes the need of any datasets during distillation.
249
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
250
+ Input:
251
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
252
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
253
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
254
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
255
+ Output:
256
+ - loss: a scalar tensor representing the generator loss.
257
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
258
+ """
259
+
260
+ # Step 1: Run generator on backward simulated noisy input
261
+ with torch.no_grad():
262
+ generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
263
+ image_or_video_shape=image_or_video_shape,
264
+ conditional_dict=conditional_dict,
265
+ initial_latent=initial_latent
266
+ )
267
+
268
+ # Step 2: Compute the fake prediction
269
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
270
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
271
+ critic_timestep = self._get_timestep(
272
+ min_timestep,
273
+ max_timestep,
274
+ image_or_video_shape[0],
275
+ image_or_video_shape[1],
276
+ self.num_frame_per_block,
277
+ uniform_timestep=True
278
+ )
279
+
280
+ if self.timestep_shift > 1:
281
+ critic_timestep = self.timestep_shift * \
282
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
283
+
284
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
285
+
286
+ critic_noise = torch.randn_like(generated_image)
287
+ noisy_generated_image = self.scheduler.add_noise(
288
+ generated_image.flatten(0, 1),
289
+ critic_noise.flatten(0, 1),
290
+ critic_timestep.flatten(0, 1)
291
+ ).unflatten(0, image_or_video_shape[:2])
292
+
293
+ _, pred_fake_image = self.fake_score(
294
+ noisy_image_or_video=noisy_generated_image,
295
+ conditional_dict=conditional_dict,
296
+ timestep=critic_timestep
297
+ )
298
+
299
+ # Step 3: Compute the denoising loss for the fake critic
300
+ if self.args.denoising_loss_type == "flow":
301
+ from utils.wan_wrapper import WanDiffusionWrapper
302
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
303
+ scheduler=self.scheduler,
304
+ x0_pred=pred_fake_image.flatten(0, 1),
305
+ xt=noisy_generated_image.flatten(0, 1),
306
+ timestep=critic_timestep.flatten(0, 1)
307
+ )
308
+ pred_fake_noise = None
309
+ else:
310
+ flow_pred = None
311
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
312
+ x0=pred_fake_image.flatten(0, 1),
313
+ xt=noisy_generated_image.flatten(0, 1),
314
+ timestep=critic_timestep.flatten(0, 1)
315
+ ).unflatten(0, image_or_video_shape[:2])
316
+
317
+ denoising_loss = self.denoising_loss_func(
318
+ x=generated_image.flatten(0, 1),
319
+ x_pred=pred_fake_image.flatten(0, 1),
320
+ noise=critic_noise.flatten(0, 1),
321
+ noise_pred=pred_fake_noise,
322
+ alphas_cumprod=self.scheduler.alphas_cumprod,
323
+ timestep=critic_timestep.flatten(0, 1),
324
+ flow_pred=flow_pred
325
+ )
326
+
327
+ # Step 5: Debugging Log
328
+ critic_log_dict = {
329
+ "critic_timestep": critic_timestep.detach()
330
+ }
331
+
332
+ return denoising_loss, critic_log_dict
model/gan.py ADDED
@@ -0,0 +1,295 @@
1
+ import copy
2
+ from pipeline import SelfForcingTrainingPipeline
3
+ import torch.nn.functional as F
4
+ from typing import Tuple
5
+ import torch
6
+
7
+ from model.base import SelfForcingModel
8
+
9
+
10
+ class GAN(SelfForcingModel):
11
+ def __init__(self, args, device):
12
+ """
13
+ Initialize the GAN module.
14
+ This class is self-contained and computes the generator and fake score losses
15
+ in the forward pass.
16
+ """
17
+ super().__init__(args, device)
18
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
19
+ self.same_step_across_blocks = getattr(args, "same_step_across_blocks", True)
20
+ self.concat_time_embeddings = getattr(args, "concat_time_embeddings", False)
21
+ self.num_class = args.num_class
22
+ self.relativistic_discriminator = getattr(args, "relativistic_discriminator", False)
23
+
24
+ if self.num_frame_per_block > 1:
25
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
26
+
27
+ self.fake_score.adding_cls_branch(
28
+ atten_dim=1536, num_class=args.num_class, time_embed_dim=1536 if self.concat_time_embeddings else 0)
29
+ self.fake_score.model.requires_grad_(True)
30
+
31
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
32
+ if self.independent_first_frame:
33
+ self.generator.model.independent_first_frame = True
34
+ if args.gradient_checkpointing:
35
+ self.generator.enable_gradient_checkpointing()
36
+ self.fake_score.enable_gradient_checkpointing()
37
+
38
+ # this will be init later with fsdp-wrapped modules
39
+ self.inference_pipeline: SelfForcingTrainingPipeline = None
40
+
41
+ # Step 2: Initialize all dmd hyperparameters
42
+ self.num_train_timestep = args.num_train_timestep
43
+ self.min_step = int(0.02 * self.num_train_timestep)
44
+ self.max_step = int(0.98 * self.num_train_timestep)
45
+ if hasattr(args, "real_guidance_scale"):
46
+ self.real_guidance_scale = args.real_guidance_scale
47
+ self.fake_guidance_scale = args.fake_guidance_scale
48
+ else:
49
+ self.real_guidance_scale = args.guidance_scale
50
+ self.fake_guidance_scale = 0.0
51
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
52
+ self.critic_timestep_shift = getattr(args, "critic_timestep_shift", self.timestep_shift)
53
+ self.ts_schedule = getattr(args, "ts_schedule", True)
54
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
55
+ self.min_score_timestep = getattr(args, "min_score_timestep", 0)
56
+
57
+ self.gan_g_weight = getattr(args, "gan_g_weight", 1e-2)
58
+ self.gan_d_weight = getattr(args, "gan_d_weight", 1e-2)
59
+ self.r1_weight = getattr(args, "r1_weight", 0.0)
60
+ self.r2_weight = getattr(args, "r2_weight", 0.0)
61
+ self.r1_sigma = getattr(args, "r1_sigma", 0.01)
62
+ self.r2_sigma = getattr(args, "r2_sigma", 0.01)
63
+
64
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
65
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
66
+ else:
67
+ self.scheduler.alphas_cumprod = None
68
+
69
+ def _run_cls_pred_branch(self,
70
+ noisy_image_or_video: torch.Tensor,
71
+ conditional_dict: dict,
72
+ timestep: torch.Tensor) -> torch.Tensor:
73
+ """
74
+ Run the classifier prediction branch on the generated image or video.
75
+ Input:
76
+ - image_or_video: a tensor with shape [B, F, C, H, W].
77
+ Output:
78
+ - cls_pred: a tensor with shape [B, 1, 1, 1, 1] representing the feature map for classification.
79
+ """
80
+ _, _, noisy_logit = self.fake_score(
81
+ noisy_image_or_video=noisy_image_or_video,
82
+ conditional_dict=conditional_dict,
83
+ timestep=timestep,
84
+ classify_mode=True,
85
+ concat_time_embeddings=self.concat_time_embeddings
86
+ )
87
+
88
+ return noisy_logit
89
+
90
+ def generator_loss(
91
+ self,
92
+ image_or_video_shape,
93
+ conditional_dict: dict,
94
+ unconditional_dict: dict,
95
+ clean_latent: torch.Tensor,
96
+ initial_latent: torch.Tensor = None
97
+ ) -> Tuple[torch.Tensor, dict]:
98
+ """
99
+ Generate image/videos from noise and compute the GAN generator loss.
100
+ The noisy input to the generator is backward simulated.
101
+ This removes the need of any datasets during distillation.
102
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
103
+ Input:
104
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
105
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
106
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
107
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
108
+ Output:
109
+ - loss: a scalar tensor representing the generator loss.
110
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
111
+ """
112
+ # Step 1: Unroll generator to obtain fake videos
113
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
114
+ image_or_video_shape=image_or_video_shape,
115
+ conditional_dict=conditional_dict,
116
+ initial_latent=initial_latent
117
+ )
118
+
119
+ # Step 2: Get timestep and add noise to generated/real latents
120
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
121
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
122
+ critic_timestep = self._get_timestep(
123
+ min_timestep,
124
+ max_timestep,
125
+ image_or_video_shape[0],
126
+ image_or_video_shape[1],
127
+ self.num_frame_per_block,
128
+ uniform_timestep=True
129
+ )
130
+
131
+ if self.critic_timestep_shift > 1:
132
+ critic_timestep = self.critic_timestep_shift * \
133
+ (critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
134
+
135
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
136
+
137
+ critic_noise = torch.randn_like(pred_image)
138
+ noisy_fake_latent = self.scheduler.add_noise(
139
+ pred_image.flatten(0, 1),
140
+ critic_noise.flatten(0, 1),
141
+ critic_timestep.flatten(0, 1)
142
+ ).unflatten(0, image_or_video_shape[:2])
143
+
144
+ # Step 4: Compute the real GAN discriminator loss
145
+ real_image_or_video = clean_latent.clone()
146
+ critic_noise = torch.randn_like(real_image_or_video)
147
+ noisy_real_latent = self.scheduler.add_noise(
148
+ real_image_or_video.flatten(0, 1),
149
+ critic_noise.flatten(0, 1),
150
+ critic_timestep.flatten(0, 1)
151
+ ).unflatten(0, image_or_video_shape[:2])
152
+
153
+ conditional_dict["prompt_embeds"] = torch.concatenate(
154
+ (conditional_dict["prompt_embeds"], conditional_dict["prompt_embeds"]), dim=0)
155
+ critic_timestep = torch.concatenate((critic_timestep, critic_timestep), dim=0)
156
+ noisy_latent = torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0)
157
+ _, _, noisy_logit = self.fake_score(
158
+ noisy_image_or_video=noisy_latent,
159
+ conditional_dict=conditional_dict,
160
+ timestep=critic_timestep,
161
+ classify_mode=True,
162
+ concat_time_embeddings=self.concat_time_embeddings
163
+ )
164
+ noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
165
+
166
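+ # Non-saturating generator loss on the fake logits; the relativistic variant scores each fake logit relative to its paired real logit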
+ if not self.relativistic_discriminator:
167
+ gan_G_loss = F.softplus(-noisy_fake_logit.float()).mean() * self.gan_g_weight
168
+ else:
169
+ relative_fake_logit = noisy_fake_logit - noisy_real_logit
170
+ gan_G_loss = F.softplus(-relative_fake_logit.float()).mean() * self.gan_g_weight
171
+
172
+ return gan_G_loss
173
+
174
+ def critic_loss(
175
+ self,
176
+ image_or_video_shape,
177
+ conditional_dict: dict,
178
+ unconditional_dict: dict,
179
+ clean_latent: torch.Tensor,
180
+ real_image_or_video: torch.Tensor,
181
+ initial_latent: torch.Tensor = None
182
+ ) -> Tuple[torch.Tensor, dict]:
183
+ """
184
+ Generate image/videos from noise and train the critic with generated samples.
185
+ The noisy input to the generator is backward simulated.
186
+ This removes the need of any datasets during distillation.
187
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
188
+ Input:
189
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
190
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
191
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
192
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
193
+ Output:
194
+ - loss: a scalar tensor representing the generator loss.
195
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
196
+ """
197
+
198
+ # Step 1: Run generator on backward simulated noisy input
199
+ with torch.no_grad():
200
+ generated_image, _, denoised_timestep_from, denoised_timestep_to, num_sim_steps = self._run_generator(
201
+ image_or_video_shape=image_or_video_shape,
202
+ conditional_dict=conditional_dict,
203
+ initial_latent=initial_latent
204
+ )
205
+
206
+ # Step 2: Get timestep and add noise to generated/real latents
207
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
208
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
209
+ critic_timestep = self._get_timestep(
210
+ min_timestep,
211
+ max_timestep,
212
+ image_or_video_shape[0],
213
+ image_or_video_shape[1],
214
+ self.num_frame_per_block,
215
+ uniform_timestep=True
216
+ )
217
+
218
+ if self.critic_timestep_shift > 1:
219
+ critic_timestep = self.critic_timestep_shift * \
220
+ (critic_timestep / 1000) / (1 + (self.critic_timestep_shift - 1) * (critic_timestep / 1000)) * 1000
221
+
222
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
223
+
224
+ critic_noise = torch.randn_like(generated_image)
225
+ noisy_fake_latent = self.scheduler.add_noise(
226
+ generated_image.flatten(0, 1),
227
+ critic_noise.flatten(0, 1),
228
+ critic_timestep.flatten(0, 1)
229
+ ).unflatten(0, image_or_video_shape[:2])
230
+
231
+ # Step 4: Compute the real GAN discriminator loss
232
+ noisy_real_latent = self.scheduler.add_noise(
233
+ real_image_or_video.flatten(0, 1),
234
+ critic_noise.flatten(0, 1),
235
+ critic_timestep.flatten(0, 1)
236
+ ).unflatten(0, image_or_video_shape[:2])
237
+
238
+ conditional_dict_cloned = copy.deepcopy(conditional_dict)
239
+ conditional_dict_cloned["prompt_embeds"] = torch.concatenate(
240
+ (conditional_dict_cloned["prompt_embeds"], conditional_dict_cloned["prompt_embeds"]), dim=0)
241
+ _, _, noisy_logit = self.fake_score(
242
+ noisy_image_or_video=torch.concatenate((noisy_fake_latent, noisy_real_latent), dim=0),
243
+ conditional_dict=conditional_dict_cloned,
244
+ timestep=torch.concatenate((critic_timestep, critic_timestep), dim=0),
245
+ classify_mode=True,
246
+ concat_time_embeddings=self.concat_time_embeddings
247
+ )
248
+ noisy_fake_logit, noisy_real_logit = noisy_logit.chunk(2, dim=0)
249
+
250
+ if not self.relativistic_discriminator:
251
+ gan_D_loss = F.softplus(-noisy_real_logit.float()).mean() + F.softplus(noisy_fake_logit.float()).mean()
252
+ else:
253
+ relative_real_logit = noisy_real_logit - noisy_fake_logit
254
+ gan_D_loss = F.softplus(-relative_real_logit.float()).mean()
255
+ gan_D_loss = gan_D_loss * self.gan_d_weight
256
+
257
+ # R1 regularization
258
+ if self.r1_weight > 0.:
259
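+ # Finite-difference approximation of the R1 gradient penalty: perturb the real latents with Gaussian noise of scale r1_sigma and penalize the squared change in logits per unit sigma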
+ noisy_real_latent_perturbed = noisy_real_latent.clone()
260
+ epison_real = self.r1_sigma * torch.randn_like(noisy_real_latent_perturbed)
261
+ noisy_real_latent_perturbed = noisy_real_latent_perturbed + epison_real
262
+ noisy_real_logit_perturbed = self._run_cls_pred_branch(
263
+ noisy_image_or_video=noisy_real_latent_perturbed,
264
+ conditional_dict=conditional_dict,
265
+ timestep=critic_timestep
266
+ )
267
+
268
+ r1_grad = (noisy_real_logit_perturbed - noisy_real_logit) / self.r1_sigma
269
+ r1_loss = self.r1_weight * torch.mean((r1_grad)**2)
270
+ else:
271
+ r1_loss = torch.zeros_like(gan_D_loss)
272
+
273
+ # R2 regularization
274
+ if self.r2_weight > 0.:
275
+ noisy_fake_latent_perturbed = noisy_fake_latent.clone()
276
+ epison_generated = self.r2_sigma * torch.randn_like(noisy_fake_latent_perturbed)
277
+ noisy_fake_latent_perturbed = noisy_fake_latent_perturbed + epison_generated
278
+ noisy_fake_logit_perturbed = self._run_cls_pred_branch(
279
+ noisy_image_or_video=noisy_fake_latent_perturbed,
280
+ conditional_dict=conditional_dict,
281
+ timestep=critic_timestep
282
+ )
283
+
284
+ r2_grad = (noisy_fake_logit_perturbed - noisy_fake_logit) / self.r2_sigma
285
+ r2_loss = self.r2_weight * torch.mean((r2_grad)**2)
286
+ else:
287
+ r2_loss = torch.zeros_like(gan_D_loss)
288
+
289
+ critic_log_dict = {
290
+ "critic_timestep": critic_timestep.detach(),
291
+ 'noisy_real_logit': noisy_real_logit.detach(),
292
+ 'noisy_fake_logit': noisy_fake_logit.detach(),
293
+ }
294
+
295
+ return (gan_D_loss, r1_loss, r2_loss), critic_log_dict
model/ode_regression.py ADDED
@@ -0,0 +1,138 @@
1
+ import torch.nn.functional as F
2
+ from typing import Tuple
3
+ import torch
4
+
5
+ from model.base import BaseModel
6
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
7
+
8
+
9
+ class ODERegression(BaseModel):
10
+ def __init__(self, args, device):
11
+ """
12
+ Initialize the ODERegression module.
13
+ This class is self-contained and computes the generator losses
14
+ in the forward pass given precomputed ode solution pairs.
15
+ This class supports the ode regression loss for both causal and bidirectional models.
16
+ See Sec 4.3 of CausVid https://arxiv.org/abs/2412.07772 for details
17
+ """
18
+ super().__init__(args, device)
19
+
20
+ # Step 1: Initialize all models
21
+
22
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
23
+ self.generator.model.requires_grad_(True)
24
+ if getattr(args, "generator_ckpt", False):
25
+ print(f"Loading pretrained generator from {args.generator_ckpt}")
26
+ state_dict = torch.load(args.generator_ckpt, map_location="cpu")[
27
+ 'generator']
28
+ self.generator.load_state_dict(
29
+ state_dict, strict=True
30
+ )
31
+
32
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
33
+
34
+ if self.num_frame_per_block > 1:
35
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
36
+
37
+ self.independent_first_frame = getattr(args, "independent_first_frame", False)
38
+ if self.independent_first_frame:
39
+ self.generator.model.independent_first_frame = True
40
+ if args.gradient_checkpointing:
41
+ self.generator.enable_gradient_checkpointing()
42
+
43
+ # Step 2: Initialize all hyperparameters
44
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
45
+
46
+ def _initialize_models(self, args):
47
+ self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=True)
48
+ self.generator.model.requires_grad_(True)
49
+
50
+ self.text_encoder = WanTextEncoder()
51
+ self.text_encoder.requires_grad_(False)
52
+
53
+ self.vae = WanVAEWrapper()
54
+ self.vae.requires_grad_(False)
55
+
56
+ @torch.no_grad()
57
+ def _prepare_generator_input(self, ode_latent: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ Given a tensor containing the whole ODE sampling trajectories,
60
+ randomly choose an intermediate timestep and return the latent as well as the corresponding timestep.
61
+ Input:
62
+ - ode_latent: a tensor containing the whole ODE sampling trajectories [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
63
+ Output:
64
+ - noisy_input: a tensor containing the selected latent [batch_size, num_frames, num_channels, height, width].
65
+ - timestep: a tensor containing the corresponding timestep [batch_size].
66
+ """
67
+ batch_size, num_denoising_steps, num_frames, num_channels, height, width = ode_latent.shape
68
+
69
+ # Step 1: Randomly choose a timestep for each frame
70
+ index = self._get_timestep(
71
+ 0,
72
+ len(self.denoising_step_list),
73
+ batch_size,
74
+ num_frames,
75
+ self.num_frame_per_block,
76
+ uniform_timestep=False
77
+ )
78
+ if self.args.i2v:
79
+ index[:, 0] = len(self.denoising_step_list) - 1
80
+
81
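+ # For each frame, gather the latent at its sampled denoising step from the stored ODE trajectory (dim=1 indexes denoising steps, ordered from most noisy to clean)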
+ noisy_input = torch.gather(
82
+ ode_latent, dim=1,
83
+ index=index.reshape(batch_size, 1, num_frames, 1, 1, 1).expand(
84
+ -1, -1, -1, num_channels, height, width).to(self.device)
85
+ ).squeeze(1)
86
+
87
+ timestep = self.denoising_step_list[index].to(self.device)
88
+
89
+ # if self.extra_noise_step > 0:
90
+ # random_timestep = torch.randint(0, self.extra_noise_step, [
91
+ # batch_size, num_frames], device=self.device, dtype=torch.long)
92
+ # perturbed_noisy_input = self.scheduler.add_noise(
93
+ # noisy_input.flatten(0, 1),
94
+ # torch.randn_like(noisy_input.flatten(0, 1)),
95
+ # random_timestep.flatten(0, 1)
96
+ # ).detach().unflatten(0, (batch_size, num_frames)).type_as(noisy_input)
97
+
98
+ # noisy_input[timestep == 0] = perturbed_noisy_input[timestep == 0]
99
+
100
+ return noisy_input, timestep
101
+
102
+ def generator_loss(self, ode_latent: torch.Tensor, conditional_dict: dict) -> Tuple[torch.Tensor, dict]:
103
+ """
104
+ Generate image/videos from noisy latents and compute the ODE regression loss.
105
+ Input:
106
+ - ode_latent: a tensor containing the ODE latents [batch_size, num_denoising_steps, num_frames, num_channels, height, width].
107
+ They are ordered from most noisy to clean latents.
108
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
109
+ Output:
110
+ - loss: a scalar tensor representing the generator loss.
111
+ - log_dict: a dictionary containing additional information for loss timestep breakdown.
112
+ """
113
+ # Step 1: Run generator on noisy latents
114
+ target_latent = ode_latent[:, -1]
115
+
116
+ noisy_input, timestep = self._prepare_generator_input(
117
+ ode_latent=ode_latent)
118
+
119
+ _, pred_image_or_video = self.generator(
120
+ noisy_image_or_video=noisy_input,
121
+ conditional_dict=conditional_dict,
122
+ timestep=timestep
123
+ )
124
+
125
+ # Step 2: Compute the regression loss
126
+ mask = timestep != 0
127
+
128
+ loss = F.mse_loss(
129
+ pred_image_or_video[mask], target_latent[mask], reduction="mean")
130
+
131
+ log_dict = {
132
+ "unnormalized_loss": F.mse_loss(pred_image_or_video, target_latent, reduction='none').mean(dim=[1, 2, 3, 4]).detach(),
133
+ "timestep": timestep.float().mean(dim=1).detach(),
134
+ "input": noisy_input.detach(),
135
+ "output": pred_image_or_video.detach(),
136
+ }
137
+
138
+ return loss, log_dict
model/sid.py ADDED
@@ -0,0 +1,283 @@
1
+ from pipeline import SelfForcingTrainingPipeline
2
+ from typing import Optional, Tuple
3
+ import torch
4
+
5
+ from model.base import SelfForcingModel
6
+
7
+
8
+ class SiD(SelfForcingModel):
9
+ def __init__(self, args, device):
10
+ """
11
+ Initialize the SiD (Score identity Distillation) module.
12
+ This class is self-contained and computes the generator and fake score losses
13
+ in the forward pass.
14
+ """
15
+ super().__init__(args, device)
16
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
17
+
18
+ if self.num_frame_per_block > 1:
19
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
20
+
21
+ if args.gradient_checkpointing:
22
+ self.generator.enable_gradient_checkpointing()
23
+ self.fake_score.enable_gradient_checkpointing()
24
+ self.real_score.enable_gradient_checkpointing()
25
+
26
+ # this will be init later with fsdp-wrapped modules
27
+ self.inference_pipeline: SelfForcingTrainingPipeline = None
28
+
29
+ # Step 2: Initialize all dmd hyperparameters
30
+ self.num_train_timestep = args.num_train_timestep
31
+ self.min_step = int(0.02 * self.num_train_timestep)
32
+ self.max_step = int(0.98 * self.num_train_timestep)
33
+ if hasattr(args, "real_guidance_scale"):
34
+ self.real_guidance_scale = args.real_guidance_scale
35
+ else:
36
+ self.real_guidance_scale = args.guidance_scale
37
+ self.timestep_shift = getattr(args, "timestep_shift", 1.0)
38
+ self.sid_alpha = getattr(args, "sid_alpha", 1.0)
39
+ self.ts_schedule = getattr(args, "ts_schedule", True)
40
+ self.ts_schedule_max = getattr(args, "ts_schedule_max", False)
+ self.min_score_timestep = getattr(args, "min_score_timestep", 0)  # referenced below when ts_schedule is disabled; default mirrors the DMD/GAN modules
41
+
42
+ if getattr(self.scheduler, "alphas_cumprod", None) is not None:
43
+ self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(device)
44
+ else:
45
+ self.scheduler.alphas_cumprod = None
46
+
47
+ def compute_distribution_matching_loss(
48
+ self,
49
+ image_or_video: torch.Tensor,
50
+ conditional_dict: dict,
51
+ unconditional_dict: dict,
52
+ gradient_mask: Optional[torch.Tensor] = None,
53
+ denoised_timestep_from: int = 0,
54
+ denoised_timestep_to: int = 0
55
+ ) -> Tuple[torch.Tensor, dict]:
56
+ """
57
+ Compute the SiD distribution matching loss.
58
+ Input:
59
+ - image_or_video: a tensor with shape [B, F, C, H, W] where the number of frames is 1 for images.
60
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
61
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
62
+ - gradient_mask: a boolean tensor with the same shape as image_or_video indicating which pixels to compute the loss on.
63
+ Output:
64
+ - dmd_loss: a scalar tensor representing the DMD loss.
65
+ - dmd_log_dict: a dictionary containing the intermediate tensors for logging.
66
+ """
67
+ original_latent = image_or_video
68
+
69
+ batch_size, num_frame = image_or_video.shape[:2]
70
+
71
+ # Step 1: Randomly sample timestep based on the given schedule and corresponding noise
72
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
73
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
74
+ timestep = self._get_timestep(
75
+ min_timestep,
76
+ max_timestep,
77
+ batch_size,
78
+ num_frame,
79
+ self.num_frame_per_block,
80
+ uniform_timestep=True
81
+ )
82
+
83
+ if self.timestep_shift > 1:
84
+ timestep = self.timestep_shift * \
85
+ (timestep / 1000) / \
86
+ (1 + (self.timestep_shift - 1) * (timestep / 1000)) * 1000
87
+ timestep = timestep.clamp(self.min_step, self.max_step)
88
+
89
+ noise = torch.randn_like(image_or_video)
90
+ noisy_latent = self.scheduler.add_noise(
91
+ image_or_video.flatten(0, 1),
92
+ noise.flatten(0, 1),
93
+ timestep.flatten(0, 1)
94
+ ).unflatten(0, (batch_size, num_frame))
95
+
96
+ # Step 2: SiD (May be wrap it?)
97
+ noisy_image_or_video = noisy_latent
98
+ # Step 2.1: Compute the fake score
99
+ _, pred_fake_image = self.fake_score(
100
+ noisy_image_or_video=noisy_image_or_video,
101
+ conditional_dict=conditional_dict,
102
+ timestep=timestep
103
+ )
104
+ # Step 2.2: Compute the real score
105
+ # We compute the conditional and unconditional prediction
106
+ # and add them together to achieve cfg (https://arxiv.org/abs/2207.12598)
107
+ # NOTE: This step may cause OOM issue, which can be addressed by the CFG-free technique
108
+
109
+ _, pred_real_image_cond = self.real_score(
110
+ noisy_image_or_video=noisy_image_or_video,
111
+ conditional_dict=conditional_dict,
112
+ timestep=timestep
113
+ )
114
+
115
+ _, pred_real_image_uncond = self.real_score(
116
+ noisy_image_or_video=noisy_image_or_video,
117
+ conditional_dict=unconditional_dict,
118
+ timestep=timestep
119
+ )
120
+
121
+ pred_real_image = pred_real_image_cond + (
122
+ pred_real_image_cond - pred_real_image_uncond
123
+ ) * self.real_guidance_scale
124
+
125
+ # Step 2.3: SiD Loss
126
+ # TODO: Add alpha
127
+ # TODO: Double?
128
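+ # SiD loss: (pred_real - pred_fake) * ((pred_real - x_g) - alpha * (pred_real - pred_fake)), evaluated elementwise in double precision; x_g is the generator sample (original_latent)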
+ sid_loss = (pred_real_image.double() - pred_fake_image.double()) * ((pred_real_image.double() - original_latent.double()) - self.sid_alpha * (pred_real_image.double() - pred_fake_image.double()))
129
+
130
+ # Step 2.4: Loss normalizer
131
+ with torch.no_grad():
132
+ p_real = (original_latent - pred_real_image)
133
+ normalizer = torch.abs(p_real).mean(dim=[1, 2, 3, 4], keepdim=True)
134
+ sid_loss = sid_loss / normalizer
135
+
136
+ sid_loss = torch.nan_to_num(sid_loss)
137
+ num_frame = sid_loss.shape[1]
138
+ sid_loss = sid_loss.mean()
139
+
140
+ sid_log_dict = {
141
+ "dmdtrain_gradient_norm": torch.zeros_like(sid_loss),
142
+ "timestep": timestep.detach()
143
+ }
144
+
145
+ return sid_loss, sid_log_dict
146
+
147
+ def generator_loss(
148
+ self,
149
+ image_or_video_shape,
150
+ conditional_dict: dict,
151
+ unconditional_dict: dict,
152
+ clean_latent: torch.Tensor,
153
+ initial_latent: torch.Tensor = None
154
+ ) -> Tuple[torch.Tensor, dict]:
155
+ """
156
+ Generate image/videos from noise and compute the SiD loss.
157
+ The noisy input to the generator is backward simulated.
158
+ This removes the need of any datasets during distillation.
159
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
160
+ Input:
161
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
162
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
163
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
164
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
165
+ Output:
166
+ - loss: a scalar tensor representing the generator loss.
167
+ - generator_log_dict: a dictionary containing the intermediate tensors for logging.
168
+ """
169
+ # Step 1: Unroll generator to obtain fake videos
170
+ pred_image, gradient_mask, denoised_timestep_from, denoised_timestep_to = self._run_generator(
171
+ image_or_video_shape=image_or_video_shape,
172
+ conditional_dict=conditional_dict,
173
+ initial_latent=initial_latent
174
+ )
175
+
176
+ # Step 2: Compute the DMD loss
177
+ dmd_loss, dmd_log_dict = self.compute_distribution_matching_loss(
178
+ image_or_video=pred_image,
179
+ conditional_dict=conditional_dict,
180
+ unconditional_dict=unconditional_dict,
181
+ gradient_mask=gradient_mask,
182
+ denoised_timestep_from=denoised_timestep_from,
183
+ denoised_timestep_to=denoised_timestep_to
184
+ )
185
+
186
+ return dmd_loss, dmd_log_dict
187
+
188
+ def critic_loss(
189
+ self,
190
+ image_or_video_shape,
191
+ conditional_dict: dict,
192
+ unconditional_dict: dict,
193
+ clean_latent: torch.Tensor,
194
+ initial_latent: torch.Tensor = None
195
+ ) -> Tuple[torch.Tensor, dict]:
196
+ """
197
+ Generate image/videos from noise and train the critic with generated samples.
198
+ The noisy input to the generator is backward simulated.
199
+ This removes the need of any datasets during distillation.
200
+ See Sec 4.5 of the DMD2 paper (https://arxiv.org/abs/2405.14867) for details.
201
+ Input:
202
+ - image_or_video_shape: a list containing the shape of the image or video [B, F, C, H, W].
203
+ - conditional_dict: a dictionary containing the conditional information (e.g. text embeddings, image embeddings).
204
+ - unconditional_dict: a dictionary containing the unconditional information (e.g. null/negative text embeddings, null/negative image embeddings).
205
+ - clean_latent: a tensor containing the clean latents [B, F, C, H, W]. Need to be passed when no backward simulation is used.
206
+ Output:
207
+ - loss: a scalar tensor representing the generator loss.
208
+ - critic_log_dict: a dictionary containing the intermediate tensors for logging.
209
+ """
210
+
211
+ # Step 1: Run generator on backward simulated noisy input
212
+ with torch.no_grad():
213
+ generated_image, _, denoised_timestep_from, denoised_timestep_to = self._run_generator(
214
+ image_or_video_shape=image_or_video_shape,
215
+ conditional_dict=conditional_dict,
216
+ initial_latent=initial_latent
217
+ )
218
+
219
+ # Step 2: Compute the fake prediction
220
+ min_timestep = denoised_timestep_to if self.ts_schedule and denoised_timestep_to is not None else self.min_score_timestep
221
+ max_timestep = denoised_timestep_from if self.ts_schedule_max and denoised_timestep_from is not None else self.num_train_timestep
222
+ critic_timestep = self._get_timestep(
223
+ min_timestep,
224
+ max_timestep,
225
+ image_or_video_shape[0],
226
+ image_or_video_shape[1],
227
+ self.num_frame_per_block,
228
+ uniform_timestep=True
229
+ )
230
+
231
+ if self.timestep_shift > 1:
232
+ critic_timestep = self.timestep_shift * \
233
+ (critic_timestep / 1000) / (1 + (self.timestep_shift - 1) * (critic_timestep / 1000)) * 1000
234
+
235
+ critic_timestep = critic_timestep.clamp(self.min_step, self.max_step)
236
+
237
+ critic_noise = torch.randn_like(generated_image)
238
+ noisy_generated_image = self.scheduler.add_noise(
239
+ generated_image.flatten(0, 1),
240
+ critic_noise.flatten(0, 1),
241
+ critic_timestep.flatten(0, 1)
242
+ ).unflatten(0, image_or_video_shape[:2])
243
+
244
+ _, pred_fake_image = self.fake_score(
245
+ noisy_image_or_video=noisy_generated_image,
246
+ conditional_dict=conditional_dict,
247
+ timestep=critic_timestep
248
+ )
249
+
250
+ # Step 3: Compute the denoising loss for the fake critic
251
+ if self.args.denoising_loss_type == "flow":
252
+ from utils.wan_wrapper import WanDiffusionWrapper
253
+ flow_pred = WanDiffusionWrapper._convert_x0_to_flow_pred(
254
+ scheduler=self.scheduler,
255
+ x0_pred=pred_fake_image.flatten(0, 1),
256
+ xt=noisy_generated_image.flatten(0, 1),
257
+ timestep=critic_timestep.flatten(0, 1)
258
+ )
259
+ pred_fake_noise = None
260
+ else:
261
+ flow_pred = None
262
+ pred_fake_noise = self.scheduler.convert_x0_to_noise(
263
+ x0=pred_fake_image.flatten(0, 1),
264
+ xt=noisy_generated_image.flatten(0, 1),
265
+ timestep=critic_timestep.flatten(0, 1)
266
+ ).unflatten(0, image_or_video_shape[:2])
267
+
268
+ denoising_loss = self.denoising_loss_func(
269
+ x=generated_image.flatten(0, 1),
270
+ x_pred=pred_fake_image.flatten(0, 1),
271
+ noise=critic_noise.flatten(0, 1),
272
+ noise_pred=pred_fake_noise,
273
+ alphas_cumprod=self.scheduler.alphas_cumprod,
274
+ timestep=critic_timestep.flatten(0, 1),
275
+ flow_pred=flow_pred
276
+ )
277
+
278
+ # Step 5: Debugging Log
279
+ critic_log_dict = {
280
+ "critic_timestep": critic_timestep.detach()
281
+ }
282
+
283
+ return denoising_loss, critic_log_dict
pipeline/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .bidirectional_diffusion_inference import BidirectionalDiffusionInferencePipeline
2
+ from .bidirectional_inference import BidirectionalInferencePipeline
3
+ from .causal_diffusion_inference import CausalDiffusionInferencePipeline
4
+ from .causal_inference import CausalInferencePipeline
5
+ from .self_forcing_training import SelfForcingTrainingPipeline
6
+
7
+ __all__ = [
8
+ "BidirectionalDiffusionInferencePipeline",
9
+ "BidirectionalInferencePipeline",
10
+ "CausalDiffusionInferencePipeline",
11
+ "CausalInferencePipeline",
12
+ "SelfForcingTrainingPipeline"
13
+ ]
pipeline/bidirectional_diffusion_inference.py ADDED
@@ -0,0 +1,110 @@
1
+ from tqdm import tqdm
2
+ from typing import List
3
+ import torch
4
+
5
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
6
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
7
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
8
+
9
+
10
+ class BidirectionalDiffusionInferencePipeline(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ args,
14
+ device,
15
+ generator=None,
16
+ text_encoder=None,
17
+ vae=None
18
+ ):
19
+ super().__init__()
20
+ # Step 1: Initialize all models
21
+ self.generator = WanDiffusionWrapper(
22
+ **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
23
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
24
+ self.vae = WanVAEWrapper() if vae is None else vae
25
+
26
+ # Step 2: Initialize scheduler
27
+ self.num_train_timesteps = args.num_train_timestep
28
+ self.sampling_steps = 50
29
+ self.sample_solver = 'unipc'
30
+ self.shift = 8.0
31
+
32
+ self.args = args
33
+
34
+ def inference(
35
+ self,
36
+ noise: torch.Tensor,
37
+ text_prompts: List[str],
38
+ return_latents=False
39
+ ) -> torch.Tensor:
40
+ """
41
+ Perform inference on the given noise and text prompts.
42
+ Inputs:
43
+ noise (torch.Tensor): The input noise tensor of shape
44
+ (batch_size, num_frames, num_channels, height, width).
45
+ text_prompts (List[str]): The list of text prompts.
46
+ Outputs:
47
+ video (torch.Tensor): The generated video tensor of shape
48
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
49
+ """
50
+
51
+ conditional_dict = self.text_encoder(
52
+ text_prompts=text_prompts
53
+ )
54
+ unconditional_dict = self.text_encoder(
55
+ text_prompts=[self.args.negative_prompt] * len(text_prompts)
56
+ )
57
+
58
+ latents = noise
59
+
60
+ sample_scheduler = self._initialize_sample_scheduler(noise)
61
+ for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
62
+ latent_model_input = latents
63
+ timestep = t * torch.ones([latents.shape[0], 21], device=noise.device, dtype=torch.float32)
64
+
65
+ flow_pred_cond, _ = self.generator(latent_model_input, conditional_dict, timestep)
66
+ flow_pred_uncond, _ = self.generator(latent_model_input, unconditional_dict, timestep)
67
+
68
+ flow_pred = flow_pred_uncond + self.args.guidance_scale * (
69
+ flow_pred_cond - flow_pred_uncond)
70
+
71
+ temp_x0 = sample_scheduler.step(
72
+ flow_pred.unsqueeze(0),
73
+ t,
74
+ latents.unsqueeze(0),
75
+ return_dict=False)[0]
76
+ latents = temp_x0.squeeze(0)
77
+
78
+ x0 = latents
79
+ video = self.vae.decode_to_pixel(x0)
80
+ video = (video * 0.5 + 0.5).clamp(0, 1)
81
+
82
+ del sample_scheduler
83
+
84
+ if return_latents:
85
+ return video, latents
86
+ else:
87
+ return video
88
+
89
+ def _initialize_sample_scheduler(self, noise):
90
+ if self.sample_solver == 'unipc':
91
+ sample_scheduler = FlowUniPCMultistepScheduler(
92
+ num_train_timesteps=self.num_train_timesteps,
93
+ shift=1,
94
+ use_dynamic_shifting=False)
95
+ sample_scheduler.set_timesteps(
96
+ self.sampling_steps, device=noise.device, shift=self.shift)
97
+ self.timesteps = sample_scheduler.timesteps
98
+ elif self.sample_solver == 'dpm++':
99
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
100
+ num_train_timesteps=self.num_train_timesteps,
101
+ shift=1,
102
+ use_dynamic_shifting=False)
103
+ sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
104
+ self.timesteps, _ = retrieve_timesteps(
105
+ sample_scheduler,
106
+ device=noise.device,
107
+ sigmas=sampling_sigmas)
108
+ else:
109
+ raise NotImplementedError("Unsupported solver.")
110
+ return sample_scheduler
pipeline/bidirectional_inference.py ADDED
@@ -0,0 +1,71 @@
1
+ from typing import List
2
+ import torch
3
+
4
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
5
+
6
+
7
+ class BidirectionalInferencePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ args,
11
+ device,
12
+ generator=None,
13
+ text_encoder=None,
14
+ vae=None
15
+ ):
16
+ super().__init__()
17
+ # Step 1: Initialize all models
18
+ self.generator = WanDiffusionWrapper(
19
+ **getattr(args, "model_kwargs", {}), is_causal=False) if generator is None else generator
20
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
21
+ self.vae = WanVAEWrapper() if vae is None else vae
22
+
23
+ # Step 2: Initialize all bidirectional Wan hyperparameters
24
+ self.scheduler = self.generator.get_scheduler()
25
+ self.denoising_step_list = torch.tensor(
26
+ args.denoising_step_list, dtype=torch.long, device=device)
27
+ if self.denoising_step_list[-1] == 0:
28
+ self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
29
+ if args.warp_denoising_step:
30
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
31
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
32
+
33
+ def inference(self, noise: torch.Tensor, text_prompts: List[str]) -> torch.Tensor:
34
+ """
35
+ Perform inference on the given noise and text prompts.
36
+ Inputs:
37
+ noise (torch.Tensor): The input noise tensor of shape
38
+ (batch_size, num_frames, num_channels, height, width).
39
+ text_prompts (List[str]): The list of text prompts.
40
+ Outputs:
41
+ video (torch.Tensor): The generated video tensor of shape
42
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
43
+ """
44
+ conditional_dict = self.text_encoder(
45
+ text_prompts=text_prompts
46
+ )
47
+
48
+ # initial point
49
+ noisy_image_or_video = noise
50
+
51
+ # use the last n-1 timesteps to simulate the generator's input
52
+ for index, current_timestep in enumerate(self.denoising_step_list[:-1]):
53
+ _, pred_image_or_video = self.generator(
54
+ noisy_image_or_video=noisy_image_or_video,
55
+ conditional_dict=conditional_dict,
56
+ timestep=torch.ones(
57
+ noise.shape[:2], dtype=torch.long, device=noise.device) * current_timestep
58
+ ) # [B, F, C, H, W]
59
+
60
+ next_timestep = self.denoising_step_list[index + 1] * torch.ones(
61
+ noise.shape[:2], dtype=torch.long, device=noise.device)
62
+
63
+ noisy_image_or_video = self.scheduler.add_noise(
64
+ pred_image_or_video.flatten(0, 1),
65
+ torch.randn_like(pred_image_or_video.flatten(0, 1)),
66
+ next_timestep.flatten(0, 1)
67
+ ).unflatten(0, noise.shape[:2])
68
+
69
+ video = self.vae.decode_to_pixel(pred_image_or_video)
70
+ video = (video * 0.5 + 0.5).clamp(0, 1)
71
+ return video
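For reference, a minimal usage sketch of the few-step BidirectionalInferencePipeline defined above. The OmegaConf config loading, the bfloat16 dtype and the 21x16x60x104 latent shape are assumptions; the constructor and inference() signatures follow the class as written.

import torch
from omegaconf import OmegaConf  # assumed config loader
from pipeline.bidirectional_inference import BidirectionalInferencePipeline

# The config is assumed to define denoising_step_list, warp_denoising_step
# and (optionally) model_kwargs, as read by __init__ above.
config = OmegaConf.load("path/to/config.yaml")  # hypothetical path
device = torch.device("cuda")

pipe = BidirectionalInferencePipeline(config, device=device)
pipe = pipe.to(device=device, dtype=torch.bfloat16)

noise = torch.randn(1, 21, 16, 60, 104, device=device, dtype=torch.bfloat16)
video = pipe.inference(noise, text_prompts=["a boat sailing smoothly on a calm lake"])
print(video.shape)  # (batch, frames, channels, height, width), values in [0, 1]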
pipeline/causal_diffusion_inference.py ADDED
@@ -0,0 +1,342 @@
1
+ from tqdm import tqdm
2
+ from typing import List, Optional
3
+ import torch
4
+
5
+ from wan.utils.fm_solvers import FlowDPMSolverMultistepScheduler, get_sampling_sigmas, retrieve_timesteps
6
+ from wan.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
7
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
8
+
9
+
10
+ class CausalDiffusionInferencePipeline(torch.nn.Module):
11
+ def __init__(
12
+ self,
13
+ args,
14
+ device,
15
+ generator=None,
16
+ text_encoder=None,
17
+ vae=None
18
+ ):
19
+ super().__init__()
20
+ # Step 1: Initialize all models
21
+ self.generator = WanDiffusionWrapper(
22
+ **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
23
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
24
+ self.vae = WanVAEWrapper() if vae is None else vae
25
+
26
+ # Step 2: Initialize scheduler
27
+ self.num_train_timesteps = args.num_train_timestep
28
+ self.sampling_steps = 50
29
+ self.sample_solver = 'unipc'
30
+ self.shift = args.timestep_shift
31
+
32
+ self.num_transformer_blocks = 30
33
+ self.frame_seq_length = 1560
34
+
35
+ self.kv_cache_pos = None
36
+ self.kv_cache_neg = None
37
+ self.crossattn_cache_pos = None
38
+ self.crossattn_cache_neg = None
39
+ self.args = args
40
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
41
+ self.independent_first_frame = args.independent_first_frame
42
+ self.local_attn_size = self.generator.model.local_attn_size
43
+
44
+ print(f"KV inference with {self.num_frame_per_block} frames per block")
45
+
46
+ if self.num_frame_per_block > 1:
47
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
48
+
49
+ def inference(
50
+ self,
51
+ noise: torch.Tensor,
52
+ text_prompts: List[str],
53
+ initial_latent: Optional[torch.Tensor] = None,
54
+ return_latents: bool = False,
55
+ start_frame_index: Optional[int] = 0
56
+ ) -> torch.Tensor:
57
+ """
58
+ Perform inference on the given noise and text prompts.
59
+ Inputs:
60
+ noise (torch.Tensor): The input noise tensor of shape
61
+ (batch_size, num_output_frames, num_channels, height, width).
62
+ text_prompts (List[str]): The list of text prompts.
63
+ initial_latent (torch.Tensor): The initial latent tensor of shape
64
+ (batch_size, num_input_frames, num_channels, height, width).
65
+ If num_input_frames is 1, perform image to video.
66
+ If num_input_frames is greater than 1, perform video extension.
67
+ return_latents (bool): Whether to return the latents.
68
+ start_frame_index (int): The index of the first frame of the current window (used for long video generation).
69
+ Outputs:
70
+ video (torch.Tensor): The generated video tensor of shape
71
+ (batch_size, num_frames, num_channels, height, width). It is normalized to be in the range [0, 1].
72
+ """
73
+ batch_size, num_frames, num_channels, height, width = noise.shape
74
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
75
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
76
+ # noise should still be a multiple of num_frame_per_block
77
+ assert num_frames % self.num_frame_per_block == 0
78
+ num_blocks = num_frames // self.num_frame_per_block
79
+ elif self.independent_first_frame and initial_latent is None:
80
+ # Using a [1, 4, 4, 4, 4, 4] model to generate a video without image conditioning
81
+ assert (num_frames - 1) % self.num_frame_per_block == 0
82
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
83
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
84
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
85
+ conditional_dict = self.text_encoder(
86
+ text_prompts=text_prompts
87
+ )
88
+ unconditional_dict = self.text_encoder(
89
+ text_prompts=[self.args.negative_prompt] * len(text_prompts)
90
+ )
91
+
92
+ output = torch.zeros(
93
+ [batch_size, num_output_frames, num_channels, height, width],
94
+ device=noise.device,
95
+ dtype=noise.dtype
96
+ )
97
+
98
+ # Step 1: Initialize KV cache to all zeros
99
+ if self.kv_cache_pos is None:
100
+ self._initialize_kv_cache(
101
+ batch_size=batch_size,
102
+ dtype=noise.dtype,
103
+ device=noise.device
104
+ )
105
+ self._initialize_crossattn_cache(
106
+ batch_size=batch_size,
107
+ dtype=noise.dtype,
108
+ device=noise.device
109
+ )
110
+ else:
111
+ # reset cross attn cache
112
+ for block_index in range(self.num_transformer_blocks):
113
+ self.crossattn_cache_pos[block_index]["is_init"] = False
114
+ self.crossattn_cache_neg[block_index]["is_init"] = False
115
+ # reset kv cache
116
+ for block_index in range(len(self.kv_cache_pos)):
117
+ self.kv_cache_pos[block_index]["global_end_index"] = torch.tensor(
118
+ [0], dtype=torch.long, device=noise.device)
119
+ self.kv_cache_pos[block_index]["local_end_index"] = torch.tensor(
120
+ [0], dtype=torch.long, device=noise.device)
121
+ self.kv_cache_neg[block_index]["global_end_index"] = torch.tensor(
122
+ [0], dtype=torch.long, device=noise.device)
123
+ self.kv_cache_neg[block_index]["local_end_index"] = torch.tensor(
124
+ [0], dtype=torch.long, device=noise.device)
125
+
126
+ # Step 2: Cache context feature
127
+ current_start_frame = start_frame_index
128
+ cache_start_frame = 0
129
+ if initial_latent is not None:
130
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
131
+ if self.independent_first_frame:
132
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
133
+ assert (num_input_frames - 1) % self.num_frame_per_block == 0
134
+ num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
135
+ output[:, :1] = initial_latent[:, :1]
136
+ self.generator(
137
+ noisy_image_or_video=initial_latent[:, :1],
138
+ conditional_dict=conditional_dict,
139
+ timestep=timestep * 0,
140
+ kv_cache=self.kv_cache_pos,
141
+ crossattn_cache=self.crossattn_cache_pos,
142
+ current_start=current_start_frame * self.frame_seq_length,
143
+ cache_start=cache_start_frame * self.frame_seq_length
144
+ )
145
+ self.generator(
146
+ noisy_image_or_video=initial_latent[:, :1],
147
+ conditional_dict=unconditional_dict,
148
+ timestep=timestep * 0,
149
+ kv_cache=self.kv_cache_neg,
150
+ crossattn_cache=self.crossattn_cache_neg,
151
+ current_start=current_start_frame * self.frame_seq_length,
152
+ cache_start=cache_start_frame * self.frame_seq_length
153
+ )
154
+ current_start_frame += 1
155
+ cache_start_frame += 1
156
+ else:
157
+ # Assume num_input_frames is self.num_frame_per_block * num_input_blocks
158
+ assert num_input_frames % self.num_frame_per_block == 0
159
+ num_input_blocks = num_input_frames // self.num_frame_per_block
160
+
161
+ for block_index in range(num_input_blocks):
162
+ current_ref_latents = \
163
+ initial_latent[:, cache_start_frame:cache_start_frame + self.num_frame_per_block]
164
+ output[:, cache_start_frame:cache_start_frame + self.num_frame_per_block] = current_ref_latents
165
+ self.generator(
166
+ noisy_image_or_video=current_ref_latents,
167
+ conditional_dict=conditional_dict,
168
+ timestep=timestep * 0,
169
+ kv_cache=self.kv_cache_pos,
170
+ crossattn_cache=self.crossattn_cache_pos,
171
+ current_start=current_start_frame * self.frame_seq_length,
172
+ cache_start=cache_start_frame * self.frame_seq_length
173
+ )
174
+ self.generator(
175
+ noisy_image_or_video=current_ref_latents,
176
+ conditional_dict=unconditional_dict,
177
+ timestep=timestep * 0,
178
+ kv_cache=self.kv_cache_neg,
179
+ crossattn_cache=self.crossattn_cache_neg,
180
+ current_start=current_start_frame * self.frame_seq_length,
181
+ cache_start=cache_start_frame * self.frame_seq_length
182
+ )
183
+ current_start_frame += self.num_frame_per_block
184
+ cache_start_frame += self.num_frame_per_block
185
+
186
+ # Step 3: Temporal denoising loop
187
+ all_num_frames = [self.num_frame_per_block] * num_blocks
188
+ if self.independent_first_frame and initial_latent is None:
189
+ all_num_frames = [1] + all_num_frames
190
+ for current_num_frames in all_num_frames:
191
+ noisy_input = noise[
192
+ :, cache_start_frame - num_input_frames:cache_start_frame + current_num_frames - num_input_frames]
193
+ latents = noisy_input
194
+
195
+ # Step 3.1: Spatial denoising loop
196
+ sample_scheduler = self._initialize_sample_scheduler(noise)
197
+ for _, t in enumerate(tqdm(sample_scheduler.timesteps)):
198
+ latent_model_input = latents
199
+ timestep = t * torch.ones(
200
+ [batch_size, current_num_frames], device=noise.device, dtype=torch.float32
201
+ )
202
+
203
+ flow_pred_cond, _ = self.generator(
204
+ noisy_image_or_video=latent_model_input,
205
+ conditional_dict=conditional_dict,
206
+ timestep=timestep,
207
+ kv_cache=self.kv_cache_pos,
208
+ crossattn_cache=self.crossattn_cache_pos,
209
+ current_start=current_start_frame * self.frame_seq_length,
210
+ cache_start=cache_start_frame * self.frame_seq_length
211
+ )
212
+ flow_pred_uncond, _ = self.generator(
213
+ noisy_image_or_video=latent_model_input,
214
+ conditional_dict=unconditional_dict,
215
+ timestep=timestep,
216
+ kv_cache=self.kv_cache_neg,
217
+ crossattn_cache=self.crossattn_cache_neg,
218
+ current_start=current_start_frame * self.frame_seq_length,
219
+ cache_start=cache_start_frame * self.frame_seq_length
220
+ )
221
+
222
+ flow_pred = flow_pred_uncond + self.args.guidance_scale * (
223
+ flow_pred_cond - flow_pred_uncond)
224
+
225
+ temp_x0 = sample_scheduler.step(
226
+ flow_pred,
227
+ t,
228
+ latents,
229
+ return_dict=False)[0]
230
+ latents = temp_x0
231
+ print(f"kv_cache['local_end_index']: {self.kv_cache_pos[0]['local_end_index']}")
232
+ print(f"kv_cache['global_end_index']: {self.kv_cache_pos[0]['global_end_index']}")
233
+
234
+ # Step 3.2: record the model's output
235
+ output[:, cache_start_frame:cache_start_frame + current_num_frames] = latents
236
+
237
+ # Step 3.3: rerun with timestep zero to update KV cache using clean context
238
+ self.generator(
239
+ noisy_image_or_video=latents,
240
+ conditional_dict=conditional_dict,
241
+ timestep=timestep * 0,
242
+ kv_cache=self.kv_cache_pos,
243
+ crossattn_cache=self.crossattn_cache_pos,
244
+ current_start=current_start_frame * self.frame_seq_length,
245
+ cache_start=cache_start_frame * self.frame_seq_length
246
+ )
247
+ self.generator(
248
+ noisy_image_or_video=latents,
249
+ conditional_dict=unconditional_dict,
250
+ timestep=timestep * 0,
251
+ kv_cache=self.kv_cache_neg,
252
+ crossattn_cache=self.crossattn_cache_neg,
253
+ current_start=current_start_frame * self.frame_seq_length,
254
+ cache_start=cache_start_frame * self.frame_seq_length
255
+ )
256
+
257
+ # Step 3.4: update the start and end frame indices
258
+ current_start_frame += current_num_frames
259
+ cache_start_frame += current_num_frames
260
+
261
+ # Step 4: Decode the output
262
+ video = self.vae.decode_to_pixel(output)
263
+ video = (video * 0.5 + 0.5).clamp(0, 1)
264
+
265
+ if return_latents:
266
+ return video, output
267
+ else:
268
+ return video
269
+
270
+ def _initialize_kv_cache(self, batch_size, dtype, device):
271
+ """
272
+ Initialize a Per-GPU KV cache for the Wan model.
273
+ """
274
+ kv_cache_pos = []
275
+ kv_cache_neg = []
276
+ if self.local_attn_size != -1:
277
+ # Use the local attention size to compute the KV cache size
278
+ kv_cache_size = self.local_attn_size * self.frame_seq_length
279
+ else:
280
+ # Use the default KV cache size
281
+ kv_cache_size = 32760
282
+
283
+ for _ in range(self.num_transformer_blocks):
284
+ kv_cache_pos.append({
285
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
286
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
287
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
288
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
289
+ })
290
+ kv_cache_neg.append({
291
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
292
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
293
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
294
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
295
+ })
296
+
297
+ self.kv_cache_pos = kv_cache_pos # always store the clean cache
298
+ self.kv_cache_neg = kv_cache_neg # always store the clean cache
299
+
300
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
301
+ """
302
+ Initialize a Per-GPU cross-attention cache for the Wan model.
303
+ """
304
+ crossattn_cache_pos = []
305
+ crossattn_cache_neg = []
306
+ for _ in range(self.num_transformer_blocks):
307
+ crossattn_cache_pos.append({
308
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
309
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
310
+ "is_init": False
311
+ })
312
+ crossattn_cache_neg.append({
313
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
314
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
315
+ "is_init": False
316
+ })
317
+
318
+ self.crossattn_cache_pos = crossattn_cache_pos # always store the clean cache
319
+ self.crossattn_cache_neg = crossattn_cache_neg # always store the clean cache
320
+
321
+ def _initialize_sample_scheduler(self, noise):
322
+ if self.sample_solver == 'unipc':
323
+ sample_scheduler = FlowUniPCMultistepScheduler(
324
+ num_train_timesteps=self.num_train_timesteps,
325
+ shift=1,
326
+ use_dynamic_shifting=False)
327
+ sample_scheduler.set_timesteps(
328
+ self.sampling_steps, device=noise.device, shift=self.shift)
329
+ self.timesteps = sample_scheduler.timesteps
330
+ elif self.sample_solver == 'dpm++':
331
+ sample_scheduler = FlowDPMSolverMultistepScheduler(
332
+ num_train_timesteps=self.num_train_timesteps,
333
+ shift=1,
334
+ use_dynamic_shifting=False)
335
+ sampling_sigmas = get_sampling_sigmas(self.sampling_steps, self.shift)
336
+ self.timesteps, _ = retrieve_timesteps(
337
+ sample_scheduler,
338
+ device=noise.device,
339
+ sigmas=sampling_sigmas)
340
+ else:
341
+ raise NotImplementedError("Unsupported solver.")
342
+ return sample_scheduler
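A quick sanity check on the memory implied by _initialize_kv_cache above: the default cache of 32760 tokens corresponds to 21 latent frames at 1560 tokens per frame, and each k/v tensor is roughly 96 MiB in bfloat16. The bfloat16 dtype and batch size 1 are assumptions; the shapes come from the allocations above.

# Cache-size arithmetic for the allocations in _initialize_kv_cache above.
frame_seq_length = 1560          # tokens per latent frame (set in __init__)
default_kv_cache_size = 32760    # default when local_attn_size == -1
print(default_kv_cache_size // frame_seq_length)  # -> 21 latent frames

# One k or v tensor: [batch=1, 32760, 12 heads, 128 dim] in bfloat16 (2 bytes).
bytes_per_tensor = 1 * default_kv_cache_size * 12 * 128 * 2
print(f"{bytes_per_tensor / 2**20:.1f} MiB per k/v tensor")   # ~96.0 MiB

# 30 transformer blocks x (k + v) x (positive + negative caches):
total = bytes_per_tensor * 30 * 2 * 2
print(f"{total / 2**30:.1f} GiB for the full KV cache")       # ~11.2 GiB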
pipeline/causal_inference.py ADDED
@@ -0,0 +1,305 @@
1
+ from typing import List, Optional
2
+ import torch
3
+
4
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
5
+
6
+
7
+ class CausalInferencePipeline(torch.nn.Module):
8
+ def __init__(
9
+ self,
10
+ args,
11
+ device,
12
+ generator=None,
13
+ text_encoder=None,
14
+ vae=None
15
+ ):
16
+ super().__init__()
17
+ # Step 1: Initialize all models
18
+ self.generator = WanDiffusionWrapper(
19
+ **getattr(args, "model_kwargs", {}), is_causal=True) if generator is None else generator
20
+ self.text_encoder = WanTextEncoder() if text_encoder is None else text_encoder
21
+ self.vae = WanVAEWrapper() if vae is None else vae
22
+
23
+ # Step 2: Initialize all causal hyperparameters
24
+ self.scheduler = self.generator.get_scheduler()
25
+ self.denoising_step_list = torch.tensor(
26
+ args.denoising_step_list, dtype=torch.long)
27
+ if args.warp_denoising_step:
28
+ timesteps = torch.cat((self.scheduler.timesteps.cpu(), torch.tensor([0], dtype=torch.float32)))
29
+ self.denoising_step_list = timesteps[1000 - self.denoising_step_list]
30
+
31
+ self.num_transformer_blocks = 30
32
+ self.frame_seq_length = 1560
33
+
34
+ self.kv_cache1 = None
35
+ self.args = args
36
+ self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
37
+ self.independent_first_frame = args.independent_first_frame
38
+ self.local_attn_size = self.generator.model.local_attn_size
39
+
40
+ print(f"KV inference with {self.num_frame_per_block} frames per block")
41
+
42
+ if self.num_frame_per_block > 1:
43
+ self.generator.model.num_frame_per_block = self.num_frame_per_block
44
+
45
+ def inference(
46
+ self,
47
+ noise: torch.Tensor,
48
+ text_prompts: List[str],
49
+ initial_latent: Optional[torch.Tensor] = None,
50
+ return_latents: bool = False,
51
+ profile: bool = False
52
+ ) -> torch.Tensor:
53
+ """
54
+ Perform inference on the given noise and text prompts.
55
+ Inputs:
56
+ noise (torch.Tensor): The input noise tensor of shape
57
+ (batch_size, num_output_frames, num_channels, height, width).
58
+ text_prompts (List[str]): The list of text prompts.
59
+ initial_latent (torch.Tensor): The initial latent tensor of shape
60
+ (batch_size, num_input_frames, num_channels, height, width).
61
+ If num_input_frames is 1, perform image to video.
62
+ If num_input_frames is greater than 1, perform video extension.
63
+ return_latents (bool): Whether to return the latents.
64
+ Outputs:
65
+ video (torch.Tensor): The generated video tensor of shape
66
+ (batch_size, num_output_frames, num_channels, height, width).
67
+ It is normalized to be in the range [0, 1].
68
+ """
69
+ batch_size, num_frames, num_channels, height, width = noise.shape
70
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
71
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
72
+ # noise should still be a multiple of num_frame_per_block
73
+ assert num_frames % self.num_frame_per_block == 0
74
+ num_blocks = num_frames // self.num_frame_per_block
75
+ else:
76
+ # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
77
+ assert (num_frames - 1) % self.num_frame_per_block == 0
78
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
79
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
80
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
81
+ conditional_dict = self.text_encoder(
82
+ text_prompts=text_prompts
83
+ )
84
+
85
+ output = torch.zeros(
86
+ [batch_size, num_output_frames, num_channels, height, width],
87
+ device=noise.device,
88
+ dtype=noise.dtype
89
+ )
90
+
91
+ # Set up profiling if requested
92
+ if profile:
93
+ init_start = torch.cuda.Event(enable_timing=True)
94
+ init_end = torch.cuda.Event(enable_timing=True)
95
+ diffusion_start = torch.cuda.Event(enable_timing=True)
96
+ diffusion_end = torch.cuda.Event(enable_timing=True)
97
+ vae_start = torch.cuda.Event(enable_timing=True)
98
+ vae_end = torch.cuda.Event(enable_timing=True)
99
+ block_times = []
100
+ block_start = torch.cuda.Event(enable_timing=True)
101
+ block_end = torch.cuda.Event(enable_timing=True)
102
+ init_start.record()
103
+
104
+ # Step 1: Initialize KV cache to all zeros
105
+ if self.kv_cache1 is None:
106
+ self._initialize_kv_cache(
107
+ batch_size=batch_size,
108
+ dtype=noise.dtype,
109
+ device=noise.device
110
+ )
111
+ self._initialize_crossattn_cache(
112
+ batch_size=batch_size,
113
+ dtype=noise.dtype,
114
+ device=noise.device
115
+ )
116
+ else:
117
+ # reset cross attn cache
118
+ for block_index in range(self.num_transformer_blocks):
119
+ self.crossattn_cache[block_index]["is_init"] = False
120
+ # reset kv cache
121
+ for block_index in range(len(self.kv_cache1)):
122
+ self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
123
+ [0], dtype=torch.long, device=noise.device)
124
+ self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
125
+ [0], dtype=torch.long, device=noise.device)
126
+
127
+ # Step 2: Cache context feature
128
+ current_start_frame = 0
129
+ if initial_latent is not None:
130
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
131
+ if self.independent_first_frame:
132
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
133
+ assert (num_input_frames - 1) % self.num_frame_per_block == 0
134
+ num_input_blocks = (num_input_frames - 1) // self.num_frame_per_block
135
+ output[:, :1] = initial_latent[:, :1]
136
+ self.generator(
137
+ noisy_image_or_video=initial_latent[:, :1],
138
+ conditional_dict=conditional_dict,
139
+ timestep=timestep * 0,
140
+ kv_cache=self.kv_cache1,
141
+ crossattn_cache=self.crossattn_cache,
142
+ current_start=current_start_frame * self.frame_seq_length,
143
+ )
144
+ current_start_frame += 1
145
+ else:
146
+ # Assume num_input_frames is self.num_frame_per_block * num_input_blocks
147
+ assert num_input_frames % self.num_frame_per_block == 0
148
+ num_input_blocks = num_input_frames // self.num_frame_per_block
149
+
150
+ for _ in range(num_input_blocks):
151
+ current_ref_latents = \
152
+ initial_latent[:, current_start_frame:current_start_frame + self.num_frame_per_block]
153
+ output[:, current_start_frame:current_start_frame + self.num_frame_per_block] = current_ref_latents
154
+ self.generator(
155
+ noisy_image_or_video=current_ref_latents,
156
+ conditional_dict=conditional_dict,
157
+ timestep=timestep * 0,
158
+ kv_cache=self.kv_cache1,
159
+ crossattn_cache=self.crossattn_cache,
160
+ current_start=current_start_frame * self.frame_seq_length,
161
+ )
162
+ current_start_frame += self.num_frame_per_block
163
+
164
+ if profile:
165
+ init_end.record()
166
+ torch.cuda.synchronize()
167
+ diffusion_start.record()
168
+
169
+ # Step 3: Temporal denoising loop
170
+ all_num_frames = [self.num_frame_per_block] * num_blocks
171
+ if self.independent_first_frame and initial_latent is None:
172
+ all_num_frames = [1] + all_num_frames
173
+ for current_num_frames in all_num_frames:
174
+ if profile:
175
+ block_start.record()
176
+
177
+ noisy_input = noise[
178
+ :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
179
+
180
+ # Step 3.1: Spatial denoising loop
181
+ for index, current_timestep in enumerate(self.denoising_step_list):
182
+ print(f"current_timestep: {current_timestep}")
183
+ # set current timestep
184
+ timestep = torch.ones(
185
+ [batch_size, current_num_frames],
186
+ device=noise.device,
187
+ dtype=torch.int64) * current_timestep
188
+
189
+ if index < len(self.denoising_step_list) - 1:
190
+ _, denoised_pred = self.generator(
191
+ noisy_image_or_video=noisy_input,
192
+ conditional_dict=conditional_dict,
193
+ timestep=timestep,
194
+ kv_cache=self.kv_cache1,
195
+ crossattn_cache=self.crossattn_cache,
196
+ current_start=current_start_frame * self.frame_seq_length
197
+ )
198
+ next_timestep = self.denoising_step_list[index + 1]
199
+ noisy_input = self.scheduler.add_noise(
200
+ denoised_pred.flatten(0, 1),
201
+ torch.randn_like(denoised_pred.flatten(0, 1)),
202
+ next_timestep * torch.ones(
203
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
204
+ ).unflatten(0, denoised_pred.shape[:2])
205
+ else:
206
+ # for getting real output
207
+ _, denoised_pred = self.generator(
208
+ noisy_image_or_video=noisy_input,
209
+ conditional_dict=conditional_dict,
210
+ timestep=timestep,
211
+ kv_cache=self.kv_cache1,
212
+ crossattn_cache=self.crossattn_cache,
213
+ current_start=current_start_frame * self.frame_seq_length
214
+ )
215
+
216
+ # Step 3.2: record the model's output
217
+ output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
218
+
219
+ # Step 3.3: rerun at the context timestep (args.context_noise) to update the KV cache with the denoised context
220
+ context_timestep = torch.ones_like(timestep) * self.args.context_noise
221
+ self.generator(
222
+ noisy_image_or_video=denoised_pred,
223
+ conditional_dict=conditional_dict,
224
+ timestep=context_timestep,
225
+ kv_cache=self.kv_cache1,
226
+ crossattn_cache=self.crossattn_cache,
227
+ current_start=current_start_frame * self.frame_seq_length,
228
+ )
229
+
230
+ if profile:
231
+ block_end.record()
232
+ torch.cuda.synchronize()
233
+ block_time = block_start.elapsed_time(block_end)
234
+ block_times.append(block_time)
235
+
236
+ # Step 3.4: update the start and end frame indices
237
+ current_start_frame += current_num_frames
238
+
239
+ if profile:
240
+ # End diffusion timing and synchronize CUDA
241
+ diffusion_end.record()
242
+ torch.cuda.synchronize()
243
+ diffusion_time = diffusion_start.elapsed_time(diffusion_end)
244
+ init_time = init_start.elapsed_time(init_end)
245
+ vae_start.record()
246
+
247
+ # Step 4: Decode the output
248
+ video = self.vae.decode_to_pixel(output, use_cache=False)
249
+ video = (video * 0.5 + 0.5).clamp(0, 1)
250
+
251
+ if profile:
252
+ # End VAE timing and synchronize CUDA
253
+ vae_end.record()
254
+ torch.cuda.synchronize()
255
+ vae_time = vae_start.elapsed_time(vae_end)
256
+ total_time = init_time + diffusion_time + vae_time
257
+
258
+ print("Profiling results:")
259
+ print(f" - Initialization/caching time: {init_time:.2f} ms ({100 * init_time / total_time:.2f}%)")
260
+ print(f" - Diffusion generation time: {diffusion_time:.2f} ms ({100 * diffusion_time / total_time:.2f}%)")
261
+ for i, block_time in enumerate(block_times):
262
+ print(f" - Block {i} generation time: {block_time:.2f} ms ({100 * block_time / diffusion_time:.2f}% of diffusion)")
263
+ print(f" - VAE decoding time: {vae_time:.2f} ms ({100 * vae_time / total_time:.2f}%)")
264
+ print(f" - Total time: {total_time:.2f} ms")
265
+
266
+ if return_latents:
267
+ return video, output
268
+ else:
269
+ return video
270
+
271
+ def _initialize_kv_cache(self, batch_size, dtype, device):
272
+ """
273
+ Initialize a Per-GPU KV cache for the Wan model.
274
+ """
275
+ kv_cache1 = []
276
+ if self.local_attn_size != -1:
277
+ # Use the local attention size to compute the KV cache size
278
+ kv_cache_size = self.local_attn_size * self.frame_seq_length
279
+ else:
280
+ # Use the default KV cache size
281
+ kv_cache_size = 32760
282
+
283
+ for _ in range(self.num_transformer_blocks):
284
+ kv_cache1.append({
285
+ "k": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
286
+ "v": torch.zeros([batch_size, kv_cache_size, 12, 128], dtype=dtype, device=device),
287
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
288
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
289
+ })
290
+
291
+ self.kv_cache1 = kv_cache1 # always store the clean cache
292
+
293
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
294
+ """
295
+ Initialize a Per-GPU cross-attention cache for the Wan model.
296
+ """
297
+ crossattn_cache = []
298
+
299
+ for _ in range(self.num_transformer_blocks):
300
+ crossattn_cache.append({
301
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
302
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
303
+ "is_init": False
304
+ })
305
+ self.crossattn_cache = crossattn_cache
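A minimal usage sketch for CausalInferencePipeline as defined above. The OmegaConf config path, dtype and latent shape are assumptions; the config is assumed to carry the causal hyperparameters read in __init__ (denoising_step_list, warp_denoising_step, independent_first_frame, context_noise, and optionally num_frame_per_block).

import torch
from omegaconf import OmegaConf  # assumed config loader
from pipeline.causal_inference import CausalInferencePipeline

config = OmegaConf.load("path/to/config.yaml")  # hypothetical path
device = torch.device("cuda")

pipe = CausalInferencePipeline(config, device=device)
pipe = pipe.to(device=device, dtype=torch.bfloat16)

# Text-to-video: pure noise in. For image-to-video, additionally pass
# initial_latent with one (or more) clean latent frames.
noise = torch.randn(1, 21, 16, 60, 104, device=device, dtype=torch.bfloat16)
video, latents = pipe.inference(
    noise,
    text_prompts=["a dog running happily"],
    return_latents=True,
    profile=False,   # set True to print the per-block timing breakdown above
)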
pipeline/self_forcing_training.py ADDED
@@ -0,0 +1,267 @@
1
+ from utils.wan_wrapper import WanDiffusionWrapper
2
+ from utils.scheduler import SchedulerInterface
3
+ from typing import List, Optional
4
+ import torch
5
+ import torch.distributed as dist
6
+
7
+
8
+ class SelfForcingTrainingPipeline:
9
+ def __init__(self,
10
+ denoising_step_list: List[int],
11
+ scheduler: SchedulerInterface,
12
+ generator: WanDiffusionWrapper,
13
+ num_frame_per_block=3,
14
+ independent_first_frame: bool = False,
15
+ same_step_across_blocks: bool = False,
16
+ last_step_only: bool = False,
17
+ num_max_frames: int = 21,
18
+ context_noise: int = 0,
19
+ **kwargs):
20
+ super().__init__()
21
+ self.scheduler = scheduler
22
+ self.generator = generator
23
+ self.denoising_step_list = denoising_step_list
24
+ if self.denoising_step_list[-1] == 0:
25
+ self.denoising_step_list = self.denoising_step_list[:-1] # remove the zero timestep for inference
26
+
27
+ # Wan specific hyperparameters
28
+ self.num_transformer_blocks = 30
29
+ self.frame_seq_length = 1560
30
+ self.num_frame_per_block = num_frame_per_block
31
+ self.context_noise = context_noise
32
+ self.i2v = False
33
+
34
+ self.kv_cache1 = None
35
+ self.kv_cache2 = None
36
+ self.independent_first_frame = independent_first_frame
37
+ self.same_step_across_blocks = same_step_across_blocks
38
+ self.last_step_only = last_step_only
39
+ self.kv_cache_size = num_max_frames * self.frame_seq_length
40
+
41
+ def generate_and_sync_list(self, num_blocks, num_denoising_steps, device):
42
+ rank = dist.get_rank() if dist.is_initialized() else 0
43
+
44
+ if rank == 0:
45
+ # Generate random indices
46
+ indices = torch.randint(
47
+ low=0,
48
+ high=num_denoising_steps,
49
+ size=(num_blocks,),
50
+ device=device
51
+ )
52
+ if self.last_step_only:
53
+ indices = torch.ones_like(indices) * (num_denoising_steps - 1)
54
+ else:
55
+ indices = torch.empty(num_blocks, dtype=torch.long, device=device)
56
+
57
+ dist.broadcast(indices, src=0) # Broadcast the random indices to all ranks
58
+ return indices.tolist()
59
+
60
+ def inference_with_trajectory(
61
+ self,
62
+ noise: torch.Tensor,
63
+ initial_latent: Optional[torch.Tensor] = None,
64
+ return_sim_step: bool = False,
65
+ **conditional_dict
66
+ ) -> torch.Tensor:
67
+ batch_size, num_frames, num_channels, height, width = noise.shape
68
+ if not self.independent_first_frame or (self.independent_first_frame and initial_latent is not None):
69
+ # If the first frame is independent and the first frame is provided, then the number of frames in the
70
+ # noise should still be a multiple of num_frame_per_block
71
+ assert num_frames % self.num_frame_per_block == 0
72
+ num_blocks = num_frames // self.num_frame_per_block
73
+ else:
74
+ # Using a [1, 4, 4, 4, 4, 4, ...] model to generate a video without image conditioning
75
+ assert (num_frames - 1) % self.num_frame_per_block == 0
76
+ num_blocks = (num_frames - 1) // self.num_frame_per_block
77
+ num_input_frames = initial_latent.shape[1] if initial_latent is not None else 0
78
+ num_output_frames = num_frames + num_input_frames # add the initial latent frames
79
+ output = torch.zeros(
80
+ [batch_size, num_output_frames, num_channels, height, width],
81
+ device=noise.device,
82
+ dtype=noise.dtype
83
+ )
84
+
85
+ # Step 1: Initialize KV cache to all zeros
86
+ self._initialize_kv_cache(
87
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
88
+ )
89
+ self._initialize_crossattn_cache(
90
+ batch_size=batch_size, dtype=noise.dtype, device=noise.device
91
+ )
92
+ # if self.kv_cache1 is None:
93
+ # self._initialize_kv_cache(
94
+ # batch_size=batch_size,
95
+ # dtype=noise.dtype,
96
+ # device=noise.device,
97
+ # )
98
+ # self._initialize_crossattn_cache(
99
+ # batch_size=batch_size,
100
+ # dtype=noise.dtype,
101
+ # device=noise.device
102
+ # )
103
+ # else:
104
+ # # reset cross attn cache
105
+ # for block_index in range(self.num_transformer_blocks):
106
+ # self.crossattn_cache[block_index]["is_init"] = False
107
+ # # reset kv cache
108
+ # for block_index in range(len(self.kv_cache1)):
109
+ # self.kv_cache1[block_index]["global_end_index"] = torch.tensor(
110
+ # [0], dtype=torch.long, device=noise.device)
111
+ # self.kv_cache1[block_index]["local_end_index"] = torch.tensor(
112
+ # [0], dtype=torch.long, device=noise.device)
113
+
114
+ # Step 2: Cache context feature
115
+ current_start_frame = 0
116
+ if initial_latent is not None:
117
+ timestep = torch.ones([batch_size, 1], device=noise.device, dtype=torch.int64) * 0
118
+ # Assume num_input_frames is 1 + self.num_frame_per_block * num_input_blocks
119
+ output[:, :1] = initial_latent
120
+ with torch.no_grad():
121
+ self.generator(
122
+ noisy_image_or_video=initial_latent,
123
+ conditional_dict=conditional_dict,
124
+ timestep=timestep * 0,
125
+ kv_cache=self.kv_cache1,
126
+ crossattn_cache=self.crossattn_cache,
127
+ current_start=current_start_frame * self.frame_seq_length
128
+ )
129
+ current_start_frame += 1
130
+
131
+ # Step 3: Temporal denoising loop
132
+ all_num_frames = [self.num_frame_per_block] * num_blocks
133
+ if self.independent_first_frame and initial_latent is None:
134
+ all_num_frames = [1] + all_num_frames
135
+ num_denoising_steps = len(self.denoising_step_list)
136
+ exit_flags = self.generate_and_sync_list(len(all_num_frames), num_denoising_steps, device=noise.device)
137
+ start_gradient_frame_index = num_output_frames - 21
138
+
139
+ # for block_index in range(num_blocks):
140
+ for block_index, current_num_frames in enumerate(all_num_frames):
141
+ noisy_input = noise[
142
+ :, current_start_frame - num_input_frames:current_start_frame + current_num_frames - num_input_frames]
143
+
144
+ # Step 3.1: Spatial denoising loop
145
+ for index, current_timestep in enumerate(self.denoising_step_list):
146
+ if self.same_step_across_blocks:
147
+ exit_flag = (index == exit_flags[0])
148
+ else:
149
+ exit_flag = (index == exit_flags[block_index]) # Only backprop at the randomly selected timestep (consistent across all ranks)
150
+ timestep = torch.ones(
151
+ [batch_size, current_num_frames],
152
+ device=noise.device,
153
+ dtype=torch.int64) * current_timestep
154
+
155
+ if not exit_flag:
156
+ with torch.no_grad():
157
+ _, denoised_pred = self.generator(
158
+ noisy_image_or_video=noisy_input,
159
+ conditional_dict=conditional_dict,
160
+ timestep=timestep,
161
+ kv_cache=self.kv_cache1,
162
+ crossattn_cache=self.crossattn_cache,
163
+ current_start=current_start_frame * self.frame_seq_length
164
+ )
165
+ next_timestep = self.denoising_step_list[index + 1]
166
+ noisy_input = self.scheduler.add_noise(
167
+ denoised_pred.flatten(0, 1),
168
+ torch.randn_like(denoised_pred.flatten(0, 1)),
169
+ next_timestep * torch.ones(
170
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
171
+ ).unflatten(0, denoised_pred.shape[:2])
172
+ else:
173
+ # for getting real output
174
+ # with torch.set_grad_enabled(current_start_frame >= start_gradient_frame_index):
175
+ if current_start_frame < start_gradient_frame_index:
176
+ with torch.no_grad():
177
+ _, denoised_pred = self.generator(
178
+ noisy_image_or_video=noisy_input,
179
+ conditional_dict=conditional_dict,
180
+ timestep=timestep,
181
+ kv_cache=self.kv_cache1,
182
+ crossattn_cache=self.crossattn_cache,
183
+ current_start=current_start_frame * self.frame_seq_length
184
+ )
185
+ else:
186
+ _, denoised_pred = self.generator(
187
+ noisy_image_or_video=noisy_input,
188
+ conditional_dict=conditional_dict,
189
+ timestep=timestep,
190
+ kv_cache=self.kv_cache1,
191
+ crossattn_cache=self.crossattn_cache,
192
+ current_start=current_start_frame * self.frame_seq_length
193
+ )
194
+ break
195
+
196
+ # Step 3.2: record the model's output
197
+ output[:, current_start_frame:current_start_frame + current_num_frames] = denoised_pred
198
+
199
+ # Step 3.3: rerun at the context timestep (context_noise) to update the cache
200
+ context_timestep = torch.ones_like(timestep) * self.context_noise
201
+ # add context noise
202
+ denoised_pred = self.scheduler.add_noise(
203
+ denoised_pred.flatten(0, 1),
204
+ torch.randn_like(denoised_pred.flatten(0, 1)),
205
+ context_timestep * torch.ones(
206
+ [batch_size * current_num_frames], device=noise.device, dtype=torch.long)
207
+ ).unflatten(0, denoised_pred.shape[:2])
208
+ with torch.no_grad():
209
+ self.generator(
210
+ noisy_image_or_video=denoised_pred,
211
+ conditional_dict=conditional_dict,
212
+ timestep=context_timestep,
213
+ kv_cache=self.kv_cache1,
214
+ crossattn_cache=self.crossattn_cache,
215
+ current_start=current_start_frame * self.frame_seq_length
216
+ )
217
+
218
+ # Step 3.4: update the start and end frame indices
219
+ current_start_frame += current_num_frames
220
+
221
+ # Step 3.5: Return the denoised timestep
222
+ if not self.same_step_across_blocks:
223
+ denoised_timestep_from, denoised_timestep_to = None, None
224
+ elif exit_flags[0] == len(self.denoising_step_list) - 1:
225
+ denoised_timestep_to = 0
226
+ denoised_timestep_from = 1000 - torch.argmin(
227
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
228
+ else:
229
+ denoised_timestep_to = 1000 - torch.argmin(
230
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0] + 1].cuda()).abs(), dim=0).item()
231
+ denoised_timestep_from = 1000 - torch.argmin(
232
+ (self.scheduler.timesteps.cuda() - self.denoising_step_list[exit_flags[0]].cuda()).abs(), dim=0).item()
233
+
234
+ if return_sim_step:
235
+ return output, denoised_timestep_from, denoised_timestep_to, exit_flags[0] + 1
236
+
237
+ return output, denoised_timestep_from, denoised_timestep_to
238
+
239
+ def _initialize_kv_cache(self, batch_size, dtype, device):
240
+ """
241
+ Initialize a Per-GPU KV cache for the Wan model.
242
+ """
243
+ kv_cache1 = []
244
+
245
+ for _ in range(self.num_transformer_blocks):
246
+ kv_cache1.append({
247
+ "k": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
248
+ "v": torch.zeros([batch_size, self.kv_cache_size, 12, 128], dtype=dtype, device=device),
249
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
250
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
251
+ })
252
+
253
+ self.kv_cache1 = kv_cache1 # always store the clean cache
254
+
255
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
256
+ """
257
+ Initialize a Per-GPU cross-attention cache for the Wan model.
258
+ """
259
+ crossattn_cache = []
260
+
261
+ for _ in range(self.num_transformer_blocks):
262
+ crossattn_cache.append({
263
+ "k": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
264
+ "v": torch.zeros([batch_size, 512, 12, 128], dtype=dtype, device=device),
265
+ "is_init": False
266
+ })
267
+ self.crossattn_cache = crossattn_cache
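The training pipeline above backpropagates through only one randomly chosen denoising step per block (synchronized across ranks) and only for blocks that fall inside the last 21 output frames. The snippet below illustrates that exit-step sampling in isolation, without torch.distributed; the block count and step count are example values.

import torch

# Example values: 21 latent frames / 3 frames per block, 4 denoising steps.
num_blocks = 7
num_denoising_steps = 4

# Rank 0 in generate_and_sync_list above draws these indices and broadcasts them.
exit_flags = torch.randint(0, num_denoising_steps, (num_blocks,)).tolist()
print(exit_flags)  # e.g. [2, 0, 3, 1, 1, 3, 0]

# In inference_with_trajectory, block b runs denoising steps 0..exit_flags[b]:
# earlier steps run under torch.no_grad(), the selected step keeps gradients
# only if the block lies within the last 21 output frames, and the spatial
# loop then breaks before reaching the remaining steps.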
prompts/MovieGenVideoBench.txt ADDED
The diff for this file is too large to render. See raw diff
 
prompts/MovieGenVideoBench_extended.txt ADDED
The diff for this file is too large to render. See raw diff
 
prompts/vbench/all_dimension.txt ADDED
@@ -0,0 +1,946 @@
1
+ In a still frame, a stop sign
2
+ a toilet, frozen in time
3
+ a laptop, frozen in time
4
+ A tranquil tableau of alley
5
+ A tranquil tableau of bar
6
+ A tranquil tableau of barn
7
+ A tranquil tableau of bathroom
8
+ A tranquil tableau of bedroom
9
+ A tranquil tableau of cliff
10
+ In a still frame, courtyard
11
+ In a still frame, gas station
12
+ A tranquil tableau of house
13
+ indoor gymnasium, frozen in time
14
+ A tranquil tableau of indoor library
15
+ A tranquil tableau of kitchen
16
+ A tranquil tableau of palace
17
+ In a still frame, parking lot
18
+ In a still frame, phone booth
19
+ A tranquil tableau of restaurant
20
+ A tranquil tableau of tower
21
+ A tranquil tableau of a bowl
22
+ A tranquil tableau of an apple
23
+ A tranquil tableau of a bench
24
+ A tranquil tableau of a bed
25
+ A tranquil tableau of a chair
26
+ A tranquil tableau of a cup
27
+ A tranquil tableau of a dining table
28
+ In a still frame, a pear
29
+ A tranquil tableau of a bunch of grapes
30
+ A tranquil tableau of a bowl on the kitchen counter
31
+ A tranquil tableau of a beautiful, handcrafted ceramic bowl
32
+ A tranquil tableau of an antique bowl
33
+ A tranquil tableau of an exquisite mahogany dining table
34
+ A tranquil tableau of a wooden bench in the park
35
+ A tranquil tableau of a beautiful wrought-iron bench surrounded by blooming flowers
36
+ In a still frame, a park bench with a view of the lake
37
+ A tranquil tableau of a vintage rocking chair was placed on the porch
38
+ A tranquil tableau of the jail cell was small and dimly lit, with cold, steel bars
39
+ A tranquil tableau of the phone booth was tucked away in a quiet alley
40
+ a dilapidated phone booth stood as a relic of a bygone era on the sidewalk, frozen in time
41
+ A tranquil tableau of the old red barn stood weathered and iconic against the backdrop of the countryside
42
+ A tranquil tableau of a picturesque barn was painted a warm shade of red and nestled in a picturesque meadow
43
+ In a still frame, within the desolate desert, an oasis unfolded, characterized by the stoic presence of palm trees and a motionless, glassy pool of water
44
+ In a still frame, the Parthenon's majestic Doric columns stand in serene solitude atop the Acropolis, framed by the tranquil Athenian landscape
45
+ In a still frame, the Temple of Hephaestus, with its timeless Doric grace, stands stoically against the backdrop of a quiet Athens
46
+ In a still frame, the ornate Victorian streetlamp stands solemnly, adorned with intricate ironwork and stained glass panels
47
+ A tranquil tableau of the Stonehenge presented itself as an enigmatic puzzle, each colossal stone meticulously placed against the backdrop of tranquility
48
+ In a still frame, in the vast desert, an oasis nestled among dunes, featuring tall palm trees and an air of serenity
49
+ static view on a desert scene with an oasis, palm trees, and a clear, calm pool of water
50
+ A tranquil tableau of an ornate Victorian streetlamp standing on a cobblestone street corner, illuminating the empty night
51
+ A tranquil tableau of a tranquil lakeside cabin nestled among tall pines, its reflection mirrored perfectly in the calm water
52
+ In a still frame, a vintage gas lantern, adorned with intricate details, gracing a historic cobblestone square
53
+ In a still frame, a tranquil Japanese tea ceremony room, with tatami mats, a delicate tea set, and a bonsai tree in the corner
54
+ A tranquil tableau of the Parthenon stands resolute in its classical elegance, a timeless symbol of Athens' cultural legacy
55
+ A tranquil tableau of in the heart of Plaka, the neoclassical architecture of the old city harmonizes with the ancient ruins
56
+ A tranquil tableau of in the desolate beauty of the American Southwest, Chaco Canyon's ancient ruins whispered tales of an enigmatic civilization that once thrived amidst the arid landscapes
57
+ A tranquil tableau of at the edge of the Arabian Desert, the ancient city of Petra beckoned with its enigmatic rock-carved façades
58
+ In a still frame, amidst the cobblestone streets, an Art Nouveau lamppost stood tall
59
+ A tranquil tableau of in the quaint village square, a traditional wrought-iron streetlamp featured delicate filigree patterns and amber-hued glass panels
60
+ A tranquil tableau of the lampposts were adorned with Art Deco motifs, their geometric shapes and frosted glass creating a sense of vintage glamour
61
+ In a still frame, in the picturesque square, a Gothic-style lamppost adorned with intricate stone carvings added a touch of medieval charm to the setting
62
+ In a still frame, in the heart of the old city, a row of ornate lantern-style streetlamps bathed the narrow alleyway in a warm, welcoming light
63
+ A tranquil tableau of in the heart of the Utah desert, a massive sandstone arch spanned the horizon
64
+ A tranquil tableau of in the Arizona desert, a massive stone bridge arched across a rugged canyon
65
+ A tranquil tableau of in the corner of the minimalist tea room, a bonsai tree added a touch of nature's beauty to the otherwise simple and elegant space
66
+ In a still frame, amidst the hushed ambiance of the traditional tea room, a meticulously arranged tea set awaited, with porcelain cups, a bamboo whisk
67
+ In a still frame, nestled in the Zen garden, a rustic teahouse featured tatami seating and a traditional charcoal brazier
68
+ A tranquil tableau of a country estate's library featured elegant wooden shelves
69
+ A tranquil tableau of beneath the shade of a solitary oak tree, an old wooden park bench sat patiently
70
+ A tranquil tableau of beside a tranquil pond, a weeping willow tree draped its branches gracefully over the water's surface, creating a serene tableau of reflection and calm
71
+ A tranquil tableau of in the Zen garden, a perfectly raked gravel path led to a serene rock garden
72
+ In a still frame, a tranquil pond was fringed by weeping cherry trees, their blossoms drifting lazily onto the glassy surface
73
+ In a still frame, within the historic library's reading room, rows of antique leather chairs and mahogany tables offered a serene haven for literary contemplation
74
+ A tranquil tableau of a peaceful orchid garden showcased a variety of delicate blooms
75
+ A tranquil tableau of in the serene courtyard, a centuries-old stone well stood as a symbol of a bygone era, its mossy stones bearing witness to the passage of time
76
+ a bird and a cat
77
+ a cat and a dog
78
+ a dog and a horse
79
+ a horse and a sheep
80
+ a sheep and a cow
81
+ a cow and an elephant
82
+ an elephant and a bear
83
+ a bear and a zebra
84
+ a zebra and a giraffe
85
+ a giraffe and a bird
86
+ a chair and a couch
87
+ a couch and a potted plant
88
+ a potted plant and a tv
89
+ a tv and a laptop
90
+ a laptop and a remote
91
+ a remote and a keyboard
92
+ a keyboard and a cell phone
93
+ a cell phone and a book
94
+ a book and a clock
95
+ a clock and a backpack
96
+ a backpack and an umbrella
97
+ an umbrella and a handbag
98
+ a handbag and a tie
99
+ a tie and a suitcase
100
+ a suitcase and a vase
101
+ a vase and scissors
102
+ scissors and a teddy bear
103
+ a teddy bear and a frisbee
104
+ a frisbee and skis
105
+ skis and a snowboard
106
+ a snowboard and a sports ball
107
+ a sports ball and a kite
108
+ a kite and a baseball bat
109
+ a baseball bat and a baseball glove
110
+ a baseball glove and a skateboard
111
+ a skateboard and a surfboard
112
+ a surfboard and a tennis racket
113
+ a tennis racket and a bottle
114
+ a bottle and a chair
115
+ an airplane and a train
116
+ a train and a boat
117
+ a boat and an airplane
118
+ a bicycle and a car
119
+ a car and a motorcycle
120
+ a motorcycle and a bus
121
+ a bus and a traffic light
122
+ a traffic light and a fire hydrant
123
+ a fire hydrant and a stop sign
124
+ a stop sign and a parking meter
125
+ a parking meter and a truck
126
+ a truck and a bicycle
127
+ a toilet and a hair drier
128
+ a hair drier and a toothbrush
129
+ a toothbrush and a sink
130
+ a sink and a toilet
131
+ a wine glass and a chair
132
+ a cup and a couch
133
+ a fork and a potted plant
134
+ a knife and a tv
135
+ a spoon and a laptop
136
+ a bowl and a remote
137
+ a banana and a keyboard
138
+ an apple and a cell phone
139
+ a sandwich and a book
140
+ an orange and a clock
141
+ broccoli and a backpack
142
+ a carrot and an umbrella
143
+ a hot dog and a handbag
144
+ a pizza and a tie
145
+ a donut and a suitcase
146
+ a cake and a vase
147
+ an oven and scissors
148
+ a toaster and a teddy bear
149
+ a microwave and a frisbee
150
+ a refrigerator and skis
151
+ a bicycle and an airplane
152
+ a car and a train
153
+ a motorcycle and a boat
154
+ a person and a toilet
155
+ a person and a hair drier
156
+ a person and a toothbrush
157
+ a person and a sink
158
+ A person is riding a bike
159
+ A person is marching
160
+ A person is roller skating
161
+ A person is tasting beer
162
+ A person is clapping
163
+ A person is drawing
164
+ A person is petting animal (not cat)
165
+ A person is eating watermelon
166
+ A person is playing harp
167
+ A person is wrestling
168
+ A person is riding scooter
169
+ A person is sweeping floor
170
+ A person is skateboarding
171
+ A person is dunking basketball
172
+ A person is playing flute
173
+ A person is stretching leg
174
+ A person is tying tie
175
+ A person is skydiving
176
+ A person is shooting goal (soccer)
177
+ A person is playing piano
178
+ A person is finger snapping
179
+ A person is canoeing or kayaking
180
+ A person is laughing
181
+ A person is digging
182
+ A person is clay pottery making
183
+ A person is shooting basketball
184
+ A person is bending back
185
+ A person is shaking hands
186
+ A person is bandaging
187
+ A person is push up
188
+ A person is catching or throwing frisbee
189
+ A person is playing trumpet
190
+ A person is flying kite
191
+ A person is filling eyebrows
192
+ A person is shuffling cards
193
+ A person is folding clothes
194
+ A person is smoking
195
+ A person is tai chi
196
+ A person is squat
197
+ A person is playing controller
198
+ A person is throwing axe
199
+ A person is giving or receiving award
200
+ A person is air drumming
201
+ A person is taking a shower
202
+ A person is planting trees
203
+ A person is sharpening knives
204
+ A person is robot dancing
205
+ A person is rock climbing
206
+ A person is hula hooping
207
+ A person is writing
208
+ A person is bungee jumping
209
+ A person is pushing cart
210
+ A person is cleaning windows
211
+ A person is cutting watermelon
212
+ A person is cheerleading
213
+ A person is washing hands
214
+ A person is ironing
215
+ A person is cutting nails
216
+ A person is hugging
217
+ A person is trimming or shaving beard
218
+ A person is jogging
219
+ A person is making bed
220
+ A person is washing dishes
221
+ A person is grooming dog
222
+ A person is doing laundry
223
+ A person is knitting
224
+ A person is reading book
225
+ A person is baby waking up
226
+ A person is massaging legs
227
+ A person is brushing teeth
228
+ A person is crawling baby
229
+ A person is motorcycling
230
+ A person is driving car
231
+ A person is sticking tongue out
232
+ A person is shaking head
233
+ A person is sword fighting
234
+ A person is doing aerobics
235
+ A person is strumming guitar
236
+ A person is riding or walking with horse
237
+ A person is archery
238
+ A person is catching or throwing baseball
239
+ A person is playing chess
240
+ A person is rock scissors paper
241
+ A person is using computer
242
+ A person is arranging flowers
243
+ A person is bending metal
244
+ A person is ice skating
245
+ A person is climbing a rope
246
+ A person is crying
247
+ A person is dancing ballet
248
+ A person is getting a haircut
249
+ A person is running on treadmill
250
+ A person is kissing
251
+ A person is counting money
252
+ A person is barbequing
253
+ A person is peeling apples
254
+ A person is milking cow
255
+ A person is shining shoes
256
+ A person is making snowman
257
+ A person is sailing
258
+ a person swimming in ocean
259
+ a person giving a presentation to a room full of colleagues
260
+ a person washing the dishes
261
+ a person eating a burger
262
+ a person walking in the snowstorm
263
+ a person drinking coffee in a cafe
264
+ a person playing guitar
265
+ a bicycle leaning against a tree
266
+ a bicycle gliding through a snowy field
267
+ a bicycle slowing down to stop
268
+ a bicycle accelerating to gain speed
269
+ a car stuck in traffic during rush hour
270
+ a car turning a corner
271
+ a car slowing down to stop
272
+ a car accelerating to gain speed
273
+ a motorcycle cruising along a coastal highway
274
+ a motorcycle turning a corner
275
+ a motorcycle slowing down to stop
276
+ a motorcycle gliding through a snowy field
277
+ a motorcycle accelerating to gain speed
278
+ an airplane soaring through a clear blue sky
279
+ an airplane taking off
280
+ an airplane landing smoothly on a runway
281
+ an airplane accelerating to gain speed
282
+ a bus turning a corner
283
+ a bus stuck in traffic during rush hour
284
+ a bus accelerating to gain speed
285
+ a train speeding down the tracks
286
+ a train crossing over a tall bridge
287
+ a train accelerating to gain speed
288
+ a truck turning a corner
289
+ a truck anchored in a tranquil bay
290
+ a truck stuck in traffic during rush hour
291
+ a truck slowing down to stop
292
+ a truck accelerating to gain speed
293
+ a boat sailing smoothly on a calm lake
294
+ a boat slowing down to stop
295
+ a boat accelerating to gain speed
296
+ a bird soaring gracefully in the sky
297
+ a bird building a nest from twigs and leaves
298
+ a bird flying over a snowy forest
299
+ a cat grooming itself meticulously with its tongue
300
+ a cat playing in park
301
+ a cat drinking water
302
+ a cat running happily
303
+ a dog enjoying a peaceful walk
304
+ a dog playing in park
305
+ a dog drinking water
306
+ a dog running happily
307
+ a horse bending down to drink water from a river
308
+ a horse galloping across an open field
309
+ a horse taking a peaceful walk
310
+ a horse running to join a herd of its kind
311
+ a sheep bending down to drink water from a river
312
+ a sheep taking a peaceful walk
313
+ a sheep running to join a herd of its kind
314
+ a cow bending down to drink water from a river
315
+ a cow chewing cud while resting in a tranquil barn
316
+ a cow running to join a herd of its kind
317
+ an elephant spraying itself with water using its trunk to cool down
318
+ an elephant taking a peaceful walk
319
+ an elephant running to join a herd of its kind
320
+ a bear catching a salmon in its powerful jaws
321
+ a bear sniffing the air for scents of food
322
+ a bear climbing a tree
323
+ a bear hunting for prey
324
+ a zebra bending down to drink water from a river
325
+ a zebra running to join a herd of its kind
326
+ a zebra taking a peaceful walk
327
+ a giraffe bending down to drink water from a river
328
+ a giraffe taking a peaceful walk
329
+ a giraffe running to join a herd of its kind
330
+ a person
331
+ a bicycle
332
+ a car
333
+ a motorcycle
334
+ an airplane
335
+ a bus
336
+ a train
337
+ a truck
338
+ a boat
339
+ a traffic light
340
+ a fire hydrant
341
+ a stop sign
342
+ a parking meter
343
+ a bench
344
+ a bird
345
+ a cat
346
+ a dog
347
+ a horse
348
+ a sheep
349
+ a cow
350
+ an elephant
351
+ a bear
352
+ a zebra
353
+ a giraffe
354
+ a backpack
355
+ an umbrella
356
+ a handbag
357
+ a tie
358
+ a suitcase
359
+ a frisbee
360
+ skis
361
+ a snowboard
362
+ a sports ball
363
+ a kite
364
+ a baseball bat
365
+ a baseball glove
366
+ a skateboard
367
+ a surfboard
368
+ a tennis racket
369
+ a bottle
370
+ a wine glass
371
+ a cup
372
+ a fork
373
+ a knife
374
+ a spoon
375
+ a bowl
376
+ a banana
377
+ an apple
378
+ a sandwich
379
+ an orange
380
+ broccoli
381
+ a carrot
382
+ a hot dog
383
+ a pizza
384
+ a donut
385
+ a cake
386
+ a chair
387
+ a couch
388
+ a potted plant
389
+ a bed
390
+ a dining table
391
+ a toilet
392
+ a tv
393
+ a laptop
394
+ a remote
395
+ a keyboard
396
+ a cell phone
397
+ a microwave
398
+ an oven
399
+ a toaster
400
+ a sink
401
+ a refrigerator
402
+ a book
403
+ a clock
404
+ a vase
405
+ scissors
406
+ a teddy bear
407
+ a hair drier
408
+ a toothbrush
409
+ a red bicycle
410
+ a green bicycle
411
+ a blue bicycle
412
+ a yellow bicycle
413
+ an orange bicycle
414
+ a purple bicycle
415
+ a pink bicycle
416
+ a black bicycle
417
+ a white bicycle
418
+ a red car
419
+ a green car
420
+ a blue car
421
+ a yellow car
422
+ an orange car
423
+ a purple car
424
+ a pink car
425
+ a black car
426
+ a white car
427
+ a red bird
428
+ a green bird
429
+ a blue bird
430
+ a yellow bird
431
+ an orange bird
432
+ a purple bird
433
+ a pink bird
434
+ a black bird
435
+ a white bird
436
+ a black cat
437
+ a white cat
438
+ an orange cat
439
+ a yellow cat
440
+ a red umbrella
441
+ a green umbrella
442
+ a blue umbrella
443
+ a yellow umbrella
444
+ an orange umbrella
445
+ a purple umbrella
446
+ a pink umbrella
447
+ a black umbrella
448
+ a white umbrella
449
+ a red suitcase
450
+ a green suitcase
451
+ a blue suitcase
452
+ a yellow suitcase
453
+ an orange suitcase
454
+ a purple suitcase
455
+ a pink suitcase
456
+ a black suitcase
457
+ a white suitcase
458
+ a red bowl
459
+ a green bowl
460
+ a blue bowl
461
+ a yellow bowl
462
+ an orange bowl
463
+ a purple bowl
464
+ a pink bowl
465
+ a black bowl
466
+ a white bowl
467
+ a red chair
468
+ a green chair
469
+ a blue chair
470
+ a yellow chair
471
+ an orange chair
472
+ a purple chair
473
+ a pink chair
474
+ a black chair
475
+ a white chair
476
+ a red clock
477
+ a green clock
478
+ a blue clock
479
+ a yellow clock
480
+ an orange clock
481
+ a purple clock
482
+ a pink clock
483
+ a black clock
484
+ a white clock
485
+ a red vase
486
+ a green vase
487
+ a blue vase
488
+ a yellow vase
489
+ an orange vase
490
+ a purple vase
491
+ a pink vase
492
+ a black vase
493
+ a white vase
494
+ A beautiful coastal beach in spring, waves lapping on sand, Van Gogh style
495
+ A beautiful coastal beach in spring, waves lapping on sand, oil painting
496
+ A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo
497
+ A beautiful coastal beach in spring, waves lapping on sand, black and white
498
+ A beautiful coastal beach in spring, waves lapping on sand, pixel art
499
+ A beautiful coastal beach in spring, waves lapping on sand, in cyberpunk style
500
+ A beautiful coastal beach in spring, waves lapping on sand, animated style
501
+ A beautiful coastal beach in spring, waves lapping on sand, watercolor painting
502
+ A beautiful coastal beach in spring, waves lapping on sand, surrealism style
503
+ The bund Shanghai, Van Gogh style
504
+ The bund Shanghai, oil painting
505
+ The bund Shanghai by Hokusai, in the style of Ukiyo
506
+ The bund Shanghai, black and white
507
+ The bund Shanghai, pixel art
508
+ The bund Shanghai, in cyberpunk style
509
+ The bund Shanghai, animated style
510
+ The bund Shanghai, watercolor painting
511
+ The bund Shanghai, surrealism style
512
+ a shark is swimming in the ocean, Van Gogh style
513
+ a shark is swimming in the ocean, oil painting
514
+ a shark is swimming in the ocean by Hokusai, in the style of Ukiyo
515
+ a shark is swimming in the ocean, black and white
516
+ a shark is swimming in the ocean, pixel art
517
+ a shark is swimming in the ocean, in cyberpunk style
518
+ a shark is swimming in the ocean, animated style
519
+ a shark is swimming in the ocean, watercolor painting
520
+ a shark is swimming in the ocean, surrealism style
521
+ A panda drinking coffee in a cafe in Paris, Van Gogh style
522
+ A panda drinking coffee in a cafe in Paris, oil painting
523
+ A panda drinking coffee in a cafe in Paris by Hokusai, in the style of Ukiyo
524
+ A panda drinking coffee in a cafe in Paris, black and white
525
+ A panda drinking coffee in a cafe in Paris, pixel art
526
+ A panda drinking coffee in a cafe in Paris, in cyberpunk style
527
+ A panda drinking coffee in a cafe in Paris, animated style
528
+ A panda drinking coffee in a cafe in Paris, watercolor painting
529
+ A panda drinking coffee in a cafe in Paris, surrealism style
530
+ A cute happy Corgi playing in park, sunset, Van Gogh style
531
+ A cute happy Corgi playing in park, sunset, oil painting
532
+ A cute happy Corgi playing in park, sunset by Hokusai, in the style of Ukiyo
533
+ A cute happy Corgi playing in park, sunset, black and white
534
+ A cute happy Corgi playing in park, sunset, pixel art
535
+ A cute happy Corgi playing in park, sunset, in cyberpunk style
536
+ A cute happy Corgi playing in park, sunset, animated style
537
+ A cute happy Corgi playing in park, sunset, watercolor painting
538
+ A cute happy Corgi playing in park, sunset, surrealism style
539
+ Gwen Stacy reading a book, Van Gogh style
540
+ Gwen Stacy reading a book, oil painting
541
+ Gwen Stacy reading a book by Hokusai, in the style of Ukiyo
542
+ Gwen Stacy reading a book, black and white
543
+ Gwen Stacy reading a book, pixel art
544
+ Gwen Stacy reading a book, in cyberpunk style
545
+ Gwen Stacy reading a book, animated style
546
+ Gwen Stacy reading a book, watercolor painting
547
+ Gwen Stacy reading a book, surrealism style
548
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, Van Gogh style
549
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, oil painting
550
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Hokusai, in the style of Ukiyo
551
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, black and white
552
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pixel art
553
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in cyberpunk style
554
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, animated style
555
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, watercolor painting
556
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, surrealism style
557
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, Van Gogh style
558
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, oil painting
559
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas by Hokusai, in the style of Ukiyo
560
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, black and white
561
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pixel art
562
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in cyberpunk style
563
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, animated style
564
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, watercolor painting
565
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, surrealism style
566
+ An astronaut flying in space, Van Gogh style
567
+ An astronaut flying in space, oil painting
568
+ An astronaut flying in space by Hokusai, in the style of Ukiyo
569
+ An astronaut flying in space, black and white
570
+ An astronaut flying in space, pixel art
571
+ An astronaut flying in space, in cyberpunk style
572
+ An astronaut flying in space, animated style
573
+ An astronaut flying in space, watercolor painting
574
+ An astronaut flying in space, surrealism style
575
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, Van Gogh style
576
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, oil painting
577
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks by Hokusai, in the style of Ukiyo
578
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, black and white
579
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pixel art
580
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in cyberpunk style
581
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, animated style
582
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, watercolor painting
583
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, surrealism style
584
+ A beautiful coastal beach in spring, waves lapping on sand, in super slow motion
585
+ A beautiful coastal beach in spring, waves lapping on sand, zoom in
586
+ A beautiful coastal beach in spring, waves lapping on sand, zoom out
587
+ A beautiful coastal beach in spring, waves lapping on sand, pan left
588
+ A beautiful coastal beach in spring, waves lapping on sand, pan right
589
+ A beautiful coastal beach in spring, waves lapping on sand, tilt up
590
+ A beautiful coastal beach in spring, waves lapping on sand, tilt down
591
+ A beautiful coastal beach in spring, waves lapping on sand, with an intense shaking effect
592
+ A beautiful coastal beach in spring, waves lapping on sand, featuring a steady and smooth perspective
593
+ A beautiful coastal beach in spring, waves lapping on sand, racking focus
594
+ The bund Shanghai, in super slow motion
595
+ The bund Shanghai, zoom in
596
+ The bund Shanghai, zoom out
597
+ The bund Shanghai, pan left
598
+ The bund Shanghai, pan right
599
+ The bund Shanghai, tilt up
600
+ The bund Shanghai, tilt down
601
+ The bund Shanghai, with an intense shaking effect
602
+ The bund Shanghai, featuring a steady and smooth perspective
603
+ The bund Shanghai, racking focus
604
+ a shark is swimming in the ocean, in super slow motion
605
+ a shark is swimming in the ocean, zoom in
606
+ a shark is swimming in the ocean, zoom out
607
+ a shark is swimming in the ocean, pan left
608
+ a shark is swimming in the ocean, pan right
609
+ a shark is swimming in the ocean, tilt up
610
+ a shark is swimming in the ocean, tilt down
611
+ a shark is swimming in the ocean, with an intense shaking effect
612
+ a shark is swimming in the ocean, featuring a steady and smooth perspective
613
+ a shark is swimming in the ocean, racking focus
614
+ A panda drinking coffee in a cafe in Paris, in super slow motion
615
+ A panda drinking coffee in a cafe in Paris, zoom in
616
+ A panda drinking coffee in a cafe in Paris, zoom out
617
+ A panda drinking coffee in a cafe in Paris, pan left
618
+ A panda drinking coffee in a cafe in Paris, pan right
619
+ A panda drinking coffee in a cafe in Paris, tilt up
620
+ A panda drinking coffee in a cafe in Paris, tilt down
621
+ A panda drinking coffee in a cafe in Paris, with an intense shaking effect
622
+ A panda drinking coffee in a cafe in Paris, featuring a steady and smooth perspective
623
+ A panda drinking coffee in a cafe in Paris, racking focus
624
+ A cute happy Corgi playing in park, sunset, in super slow motion
625
+ A cute happy Corgi playing in park, sunset, zoom in
626
+ A cute happy Corgi playing in park, sunset, zoom out
627
+ A cute happy Corgi playing in park, sunset, pan left
628
+ A cute happy Corgi playing in park, sunset, pan right
629
+ A cute happy Corgi playing in park, sunset, tilt up
630
+ A cute happy Corgi playing in park, sunset, tilt down
631
+ A cute happy Corgi playing in park, sunset, with an intense shaking effect
632
+ A cute happy Corgi playing in park, sunset, featuring a steady and smooth perspective
633
+ A cute happy Corgi playing in park, sunset, racking focus
634
+ Gwen Stacy reading a book, in super slow motion
635
+ Gwen Stacy reading a book, zoom in
636
+ Gwen Stacy reading a book, zoom out
637
+ Gwen Stacy reading a book, pan left
638
+ Gwen Stacy reading a book, pan right
639
+ Gwen Stacy reading a book, tilt up
640
+ Gwen Stacy reading a book, tilt down
641
+ Gwen Stacy reading a book, with an intense shaking effect
642
+ Gwen Stacy reading a book, featuring a steady and smooth perspective
643
+ Gwen Stacy reading a book, racking focus
644
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, in super slow motion
645
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom in
646
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, zoom out
647
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan left
648
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, pan right
649
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt up
650
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, tilt down
651
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, with an intense shaking effect
652
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, featuring a steady and smooth perspective
653
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background, racking focus
654
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, in super slow motion
655
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom in
656
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, zoom out
657
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan left
658
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, pan right
659
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt up
660
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, tilt down
661
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, with an intense shaking effect
662
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, featuring a steady and smooth perspective
663
+ A couple in formal evening wear going home get caught in a heavy downpour with umbrellas, racking focus
664
+ An astronaut flying in space, in super slow motion
665
+ An astronaut flying in space, zoom in
666
+ An astronaut flying in space, zoom out
667
+ An astronaut flying in space, pan left
668
+ An astronaut flying in space, pan right
669
+ An astronaut flying in space, tilt up
670
+ An astronaut flying in space, tilt down
671
+ An astronaut flying in space, with an intense shaking effect
672
+ An astronaut flying in space, featuring a steady and smooth perspective
673
+ An astronaut flying in space, racking focus
674
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, in super slow motion
675
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom in
676
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, zoom out
677
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan left
678
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, pan right
679
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt up
680
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, tilt down
681
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, with an intense shaking effect
682
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, featuring a steady and smooth perspective
683
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks, racking focus
684
+ Close up of grapes on a rotating table.
685
+ Turtle swimming in ocean.
686
+ A storm trooper vacuuming the beach.
687
+ A panda standing on a surfboard in the ocean in sunset.
688
+ An astronaut feeding ducks on a sunny afternoon, reflection from the water.
689
+ Two pandas discussing an academic paper.
690
+ Sunset time lapse at the beach with moving clouds and colors in the sky.
691
+ A fat rabbit wearing a purple robe walking through a fantasy landscape.
692
+ A koala bear playing piano in the forest.
693
+ An astronaut flying in space.
694
+ Fireworks.
695
+ An animated painting of fluffy white clouds moving in sky.
696
+ Flying through fantasy landscapes.
697
+ A bigfoot walking in the snowstorm.
698
+ A squirrel eating a burger.
699
+ A cat wearing sunglasses and working as a lifeguard at a pool.
700
+ Snow rocky mountains peaks canyon. snow blanketed rocky mountains surround and shadow deep canyons. the canyons twist and bend through the high elevated mountain peaks.
701
+ Splash of turquoise water in extreme slow motion, alpha channel included.
702
+ an ice cream is melting on the table.
703
+ a drone flying over a snowy forest.
704
+ a shark is swimming in the ocean.
705
+ Aerial panoramic video from a drone of a fantasy land.
706
+ a teddy bear is swimming in the ocean.
707
+ time lapse of sunrise on mars.
708
+ golden fish swimming in the ocean.
709
+ An artist brush painting on a canvas close up.
710
+ A drone view of celebration with Christmas tree and fireworks, starry sky - background.
711
+ happy dog wearing a yellow turtleneck, studio, portrait, facing camera, dark background
712
+ Origami dancers in white paper, 3D render, on white background, studio shot, dancing modern dance.
713
+ Campfire at night in a snowy forest with starry sky in the background.
714
+ a fantasy landscape
715
+ A 3D model of a 1800s victorian house.
716
+ this is how I do makeup in the morning.
717
+ A raccoon that looks like a turtle, digital art.
718
+ Robot dancing in Times Square.
719
+ Busy freeway at night.
720
+ Balloon full of water exploding in extreme slow motion.
721
+ An astronaut is riding a horse in the space in a photorealistic style.
722
+ Macro slo-mo. Slow motion cropped closeup of roasted coffee beans falling into an empty bowl.
723
+ Sewing machine, old sewing machine working.
724
+ Motion colour drop in water, ink swirling in water, colourful ink in water, abstraction fancy dream cloud of ink.
725
+ Few big purple plums rotating on the turntable. water drops appear on the skin during rotation. isolated on the white background. close-up. macro.
726
+ Vampire makeup face of beautiful girl, red contact lenses.
727
+ Ashtray full of butts on table, smoke flowing on black background, close-up
728
+ Pacific coast, carmel by the sea ocean and waves.
729
+ A teddy bear is playing drum kit in NYC Times Square.
730
+ A corgi is playing drum kit.
731
+ An Iron man is playing the electronic guitar, high electronic guitar.
732
+ A raccoon is playing the electronic guitar.
733
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background by Vincent van Gogh
734
+ A corgi's head depicted as an explosion of a nebula
735
+ A fantasy landscape
736
+ A future where humans have achieved teleportation technology
737
+ A jellyfish floating through the ocean, with bioluminescent tentacles
738
+ A Mars rover moving on Mars
739
+ A panda drinking coffee in a cafe in Paris
740
+ A space shuttle launching into orbit, with flames and smoke billowing out from the engines
741
+ A steam train moving on a mountainside
742
+ A super cool giant robot in Cyberpunk Beijing
743
+ A tropical beach at sunrise, with palm trees and crystal-clear water in the foreground
744
+ Cinematic shot of Van Gogh's selfie, Van Gogh style
745
+ Gwen Stacy reading a book
746
+ Iron Man flying in the sky
747
+ The bund Shanghai, oil painting
748
+ Yoda playing guitar on the stage
749
+ A beautiful coastal beach in spring, waves lapping on sand by Hokusai, in the style of Ukiyo
750
+ A beautiful coastal beach in spring, waves lapping on sand by Vincent van Gogh
751
+ A boat sailing leisurely along the Seine River with the Eiffel Tower in background
752
+ A car moving slowly on an empty street, rainy evening
753
+ A cat eating food out of a bowl
754
+ A cat wearing sunglasses at a pool
755
+ A confused panda in calculus class
756
+ A cute fluffy panda eating Chinese food in a restaurant
757
+ A cute happy Corgi playing in park, sunset
758
+ A cute raccoon playing guitar in a boat on the ocean
759
+ A happy fuzzy panda playing guitar nearby a campfire, snow mountain in the background
760
+ A lightning striking atop of eiffel tower, dark clouds in the sky
761
+ A modern art museum, with colorful paintings
762
+ A panda cooking in the kitchen
763
+ A panda playing on a swing set
764
+ A polar bear is playing guitar
765
+ A raccoon dressed in suit playing the trumpet, stage background
766
+ A robot DJ is playing the turntable, in heavy raining futuristic tokyo rooftop cyberpunk night, sci-fi, fantasy
767
+ A shark swimming in clear Caribbean ocean
768
+ A super robot protecting city
769
+ A teddy bear washing the dishes
770
+ An epic tornado attacking above a glowing city at night, the tornado is made of smoke
771
+ An oil painting of a couple in formal evening wear going home get caught in a heavy downpour with umbrellas
772
+ Clown fish swimming through the coral reef
773
+ Hyper-realistic spaceship landing on Mars
774
+ The bund Shanghai, vibrant color
775
+ Vincent van Gogh is painting in the room
776
+ Yellow flowers swing in the wind
777
+ alley
778
+ amusement park
779
+ aquarium
780
+ arch
781
+ art gallery
782
+ bathroom
783
+ bakery shop
784
+ ballroom
785
+ bar
786
+ barn
787
+ basement
788
+ beach
789
+ bedroom
790
+ bridge
791
+ botanical garden
792
+ cafeteria
793
+ campsite
794
+ campus
795
+ carrousel
796
+ castle
797
+ cemetery
798
+ classroom
799
+ cliff
800
+ crosswalk
801
+ construction site
802
+ corridor
803
+ courtyard
804
+ desert
805
+ downtown
806
+ driveway
807
+ farm
808
+ food court
809
+ football field
810
+ forest road
811
+ fountain
812
+ gas station
813
+ glacier
814
+ golf course
815
+ indoor gymnasium
816
+ harbor
817
+ highway
818
+ hospital
819
+ house
820
+ iceberg
821
+ industrial area
822
+ jail cell
823
+ junkyard
824
+ kitchen
825
+ indoor library
826
+ lighthouse
827
+ laboratory
828
+ mansion
829
+ marsh
830
+ mountain
831
+ indoor movie theater
832
+ indoor museum
833
+ music studio
834
+ nursery
835
+ ocean
836
+ office
837
+ palace
838
+ parking lot
839
+ pharmacy
840
+ phone booth
841
+ raceway
842
+ restaurant
843
+ river
844
+ science museum
845
+ shower
846
+ ski slope
847
+ sky
848
+ skyscraper
849
+ baseball stadium
850
+ staircase
851
+ street
852
+ supermarket
853
+ indoor swimming pool
854
+ tower
855
+ outdoor track
856
+ train railway
857
+ train station platform
858
+ underwater coral reef
859
+ valley
860
+ volcano
861
+ waterfall
862
+ windmill
863
+ a bicycle on the left of a car, front view
864
+ a car on the right of a motorcycle, front view
865
+ a motorcycle on the left of a bus, front view
866
+ a bus on the right of a traffic light, front view
867
+ a traffic light on the left of a fire hydrant, front view
868
+ a fire hydrant on the right of a stop sign, front view
869
+ a stop sign on the left of a parking meter, front view
870
+ a parking meter on the right of a bench, front view
871
+ a bench on the left of a truck, front view
872
+ a truck on the right of a bicycle, front view
873
+ a bird on the left of a cat, front view
874
+ a cat on the right of a dog, front view
875
+ a dog on the left of a horse, front view
876
+ a horse on the right of a sheep, front view
877
+ a sheep on the left of a cow, front view
878
+ a cow on the right of an elephant, front view
879
+ an elephant on the left of a bear, front view
880
+ a bear on the right of a zebra, front view
881
+ a zebra on the left of a giraffe, front view
882
+ a giraffe on the right of a bird, front view
883
+ a bottle on the left of a wine glass, front view
884
+ a wine glass on the right of a cup, front view
885
+ a cup on the left of a fork, front view
886
+ a fork on the right of a knife, front view
887
+ a knife on the left of a spoon, front view
888
+ a spoon on the right of a bowl, front view
889
+ a bowl on the left of a bottle, front view
890
+ a potted plant on the left of a remote, front view
891
+ a remote on the right of a clock, front view
892
+ a clock on the left of a vase, front view
893
+ a vase on the right of scissors, front view
894
+ scissors on the left of a teddy bear, front view
895
+ a teddy bear on the right of a potted plant, front view
896
+ a frisbee on the left of a sports ball, front view
897
+ a sports ball on the right of a baseball bat, front view
898
+ a baseball bat on the left of a baseball glove, front view
899
+ a baseball glove on the right of a tennis racket, front view
900
+ a tennis racket on the left of a frisbee, front view
901
+ a toilet on the left of a hair drier, front view
902
+ a hair drier on the right of a toothbrush, front view
903
+ a toothbrush on the left of a sink, front view
904
+ a sink on the right of a toilet, front view
905
+ a chair on the left of a couch, front view
906
+ a couch on the right of a bed, front view
907
+ a bed on the left of a tv, front view
908
+ a tv on the right of a dining table, front view
909
+ a dining table on the left of a chair, front view
910
+ an airplane on the left of a train, front view
911
+ a train on the right of a boat, front view
912
+ a boat on the left of an airplane, front view
913
+ an oven on the top of a toaster, front view
914
+ an oven on the bottom of a toaster, front view
915
+ a toaster on the top of a microwave, front view
916
+ a toaster on the bottom of a microwave, front view
917
+ a microwave on the top of an oven, front view
918
+ a microwave on the bottom of an oven, front view
919
+ a banana on the top of an apple, front view
920
+ a banana on the bottom of an apple, front view
921
+ an apple on the top of a sandwich, front view
922
+ an apple on the bottom of a sandwich, front view
923
+ a sandwich on the top of an orange, front view
924
+ a sandwich on the bottom of an orange, front view
925
+ an orange on the top of a carrot, front view
926
+ an orange on the bottom of a carrot, front view
927
+ a carrot on the top of a hot dog, front view
928
+ a carrot on the bottom of a hot dog, front view
929
+ a hot dog on the top of a pizza, front view
930
+ a hot dog on the bottom of a pizza, front view
931
+ a pizza on the top of a donut, front view
932
+ a pizza on the bottom of a donut, front view
933
+ a donut on the top of broccoli, front view
934
+ a donut on the bottom of broccoli, front view
935
+ broccoli on the top of a banana, front view
936
+ broccoli on the bottom of a banana, front view
937
+ skis on the top of a snowboard, front view
938
+ skis on the bottom of a snowboard, front view
939
+ a snowboard on the top of a kite, front view
940
+ a snowboard on the bottom of a kite, front view
941
+ a kite on the top of a skateboard, front view
942
+ a kite on the bottom of a skateboard, front view
943
+ a skateboard on the top of a surfboard, front view
944
+ a skateboard on the bottom of a surfboard, front view
945
+ a surfboard on the top of skis, front view
946
+ a surfboard on the bottom of skis, front view
prompts/vbench/all_dimension_extended.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,40 @@
1
+ torch>=2.4.0
2
+ torchvision>=0.19.0
3
+ opencv-python>=4.9.0.80
4
+ diffusers==0.31.0
5
+ transformers>=4.49.0
6
+ tokenizers>=0.20.3
7
+ accelerate>=1.1.1
8
+ tqdm
9
+ imageio
10
+ easydict
11
+ ftfy
12
+ dashscope
13
+ imageio-ffmpeg
14
+ numpy==1.24.4
15
+ wandb
16
+ omegaconf
17
+ einops
18
+ av==13.1.0
19
+ opencv-python
20
+ git+https://github.com/openai/CLIP.git
21
+ open_clip_torch
22
+ starlette
23
+ pycocotools
24
+ lmdb
25
+ matplotlib
26
+ sentencepiece
27
+ pydantic==2.10.6
28
+ scikit-image
29
+ huggingface_hub[cli]
30
+ dominate
31
+ nvidia-pyindex
32
+ nvidia-tensorrt
33
+ pycuda
34
+ onnx
35
+ onnxruntime
36
+ onnxscript
37
+ onnxconverter_common
38
+ flask
39
+ flask-socketio
40
+ torchao
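
The pins above can typically be installed with `pip install -r requirements.txt`. As a quick sanity check that the core pinned packages resolved, something like the following (illustrative only, not part of this upload) can be run afterwards:

```python
# Illustrative sketch (not part of this upload): report installed versions of
# a few of the packages pinned in requirements.txt.
import importlib.metadata as metadata

for pkg in ["torch", "diffusers", "transformers", "numpy", "av", "pydantic"]:
    try:
        print(f"{pkg}=={metadata.version(pkg)}")
    except metadata.PackageNotFoundError:
        print(f"{pkg}: not installed")
```
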
scripts/create_lmdb_14b_shards.py ADDED
@@ -0,0 +1,101 @@
1
+ """
2
+ python create_lmdb_14b_shards.py \
3
+ --data_path /mnt/localssd/wanx_14b_data \
4
+ --lmdb_path /mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb
5
+ """
6
+ from tqdm import tqdm
7
+ import numpy as np
8
+ import argparse
9
+ import torch
10
+ import lmdb
11
+ import glob
12
+ import os
13
+
14
+ from utils.lmdb import store_arrays_to_lmdb, process_data_dict
15
+
16
+
17
+ def main():
18
+ """
19
+ Aggregate all ODE pairs inside a folder into an LMDB dataset.
20
+ Each .pt file should contain a (key, value) pair representing a
21
+ video's ODE trajectories.
22
+ """
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument("--data_path", type=str,
25
+ required=True, help="path to ode pairs")
26
+ parser.add_argument("--lmdb_path", type=str,
27
+ required=True, help="path to lmdb")
28
+ parser.add_argument("--num_shards", type=int,
29
+ default=16, help="num_shards")
30
+
31
+ args = parser.parse_args()
32
+
33
+ all_dirs = sorted(os.listdir(args.data_path))
34
+
35
+ # figure out the maximum map size needed
36
+ map_size = int(1e12) # adapt to your need, set to 1TB by default
37
+ os.makedirs(args.lmdb_path, exist_ok=True)
38
+ # 1) Open one LMDB env per shard
39
+ envs = []
40
+ num_shards = args.num_shards
41
+ for shard_id in range(num_shards):
42
+ print("shard_id ", shard_id)
43
+ path = os.path.join(args.lmdb_path, f"shard_{shard_id}")
44
+ env = lmdb.open(path,
45
+ map_size=map_size,
46
+ subdir=True, # set to True if you want a directory per env
47
+ readonly=False,
48
+ metasync=True,
49
+ sync=True,
50
+ lock=True,
51
+ readahead=False,
52
+ meminit=False)
53
+ envs.append(env)
54
+
55
+ counters = [0] * num_shards
56
+ seen_prompts = set() # for deduplication
57
+ total_samples = 0
58
+ all_files = []
59
+
60
+ for part_dir in all_dirs:
61
+ all_files += sorted(glob.glob(os.path.join(args.data_path, part_dir, "*.pt")))
62
+
63
+ # 2) Prepare a write transaction for each shard
64
+ for idx, file in tqdm(enumerate(all_files)):
65
+ try:
66
+ data_dict = torch.load(file)
67
+ data_dict = process_data_dict(data_dict, seen_prompts)
68
+ except Exception as e:
69
+ print(f"Error processing {file}: {e}")
70
+ continue
71
+
72
+ if data_dict["latents"].shape != (1, 21, 16, 60, 104):
73
+ continue
74
+
75
+ shard_id = idx % num_shards
76
+ # write to lmdb file
77
+ store_arrays_to_lmdb(envs[shard_id], data_dict, start_index=counters[shard_id])
78
+ counters[shard_id] += len(data_dict['prompts'])
79
+ data_shape = data_dict["latents"].shape
80
+
81
+ total_samples += len(data_dict['prompts'])  # count samples actually written
82
+
83
+ print(len(seen_prompts))
84
+
85
+ # save each entry's shape to lmdb
86
+ for shard_id, env in enumerate(envs):
87
+ with env.begin(write=True) as txn:
88
+ for key, val in (data_dict.items()):
89
+ assert len(data_shape) == 5
90
+ array_shape = np.array(data_shape)  # template shape; the sample count is patched in below
91
+ array_shape[0] = counters[shard_id]
92
+ shape_key = f"{key}_shape".encode()
93
+ print(shape_key, array_shape)
94
+ shape_str = " ".join(map(str, array_shape))
95
+ txn.put(shape_key, shape_str.encode())
96
+
97
+ print(f"Finished writing {total_samples} examples into {num_shards} shards under {args.lmdb_path}")
98
+
99
+
100
+ if __name__ == "__main__":
101
+ main()
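
For readers who want to verify what the script above wrote, here is a minimal read-back sketch. It is illustrative only and relies solely on what this script itself stores: the per-shard `<key>_shape` entries, written as space-separated integers. The layout of the per-sample entries comes from `utils.lmdb.store_arrays_to_lmdb`, which is not shown in this commit view, so no assumptions are made about those keys.

```python
# Minimal sketch (not part of this upload): read the "<key>_shape" metadata
# that create_lmdb_14b_shards.py writes into each shard.
import lmdb
import numpy as np

def read_shard_shapes(shard_path):
    """Return {array_name: shape} for one shard directory (e.g. .../shard_0)."""
    env = lmdb.open(shard_path, readonly=True, lock=False, subdir=True)
    shapes = {}
    with env.begin(write=False) as txn:
        for raw_key, raw_val in txn.cursor():
            key = raw_key.decode("utf-8", "ignore")
            if key.endswith("_shape"):
                shapes[key[: -len("_shape")]] = np.array(
                    raw_val.decode().split(), dtype=np.int64)
    env.close()
    return shapes

# Example (path is hypothetical):
# print(read_shard_shapes("/mnt/localssd/wanx_14B_shift-3.0_cfg-5.0_lmdb/shard_0"))
```
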
scripts/create_lmdb_iterative.py ADDED
@@ -0,0 +1,60 @@
1
+ from tqdm import tqdm
2
+ import numpy as np
3
+ import argparse
4
+ import torch
5
+ import lmdb
6
+ import glob
7
+ import os
8
+
9
+ from utils.lmdb import store_arrays_to_lmdb, process_data_dict
10
+
11
+
12
+ def main():
13
+ """
14
+ Aggregate all ODE pairs inside a folder into an LMDB dataset.
15
+ Each .pt file should contain a (key, value) pair representing a
16
+ video's ODE trajectories.
17
+ """
18
+ parser = argparse.ArgumentParser()
19
+ parser.add_argument("--data_path", type=str,
20
+ required=True, help="path to ode pairs")
21
+ parser.add_argument("--lmdb_path", type=str,
22
+ required=True, help="path to lmdb")
23
+
24
+ args = parser.parse_args()
25
+
26
+ all_files = sorted(glob.glob(os.path.join(args.data_path, "*.pt")))
27
+
28
+ # figure out the maximum map size needed
29
+ total_array_size = 5000000000000 # adapt to your need, set to 5TB by default
30
+
31
+ env = lmdb.open(args.lmdb_path, map_size=total_array_size * 2)
32
+
33
+ counter = 0
34
+
35
+ seen_prompts = set() # for deduplication
36
+
37
+ for index, file in tqdm(enumerate(all_files)):
38
+ # read from disk
39
+ data_dict = torch.load(file)
40
+
41
+ data_dict = process_data_dict(data_dict, seen_prompts)
42
+
43
+ # write to lmdb file
44
+ store_arrays_to_lmdb(env, data_dict, start_index=counter)
45
+ counter += len(data_dict['prompts'])
46
+
47
+ # save each entry's shape to lmdb
48
+ with env.begin(write=True) as txn:
49
+ for key, val in data_dict.items():
50
+ print(key, val)
51
+ array_shape = np.array(val.shape)
52
+ array_shape[0] = counter
53
+
54
+ shape_key = f"{key}_shape".encode()
55
+ shape_str = " ".join(map(str, array_shape))
56
+ txn.put(shape_key, shape_str.encode())
57
+
58
+
59
+ if __name__ == "__main__":
60
+ main()
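
The `map_size` above is a deliberately generous ceiling (the environment is opened with twice the 5 TB estimate); on filesystems with sparse-file support, LMDB typically only consumes the space actually written. If a tighter bound is preferred, a purely illustrative helper (an assumption, not something the repo provides) could size it from the input files:

```python
# Illustrative sketch (assumption, not part of this upload): derive a map_size
# from the total size of the .pt files instead of hard-coding 5 TB. The glob
# pattern mirrors the one used in create_lmdb_iterative.py.
import glob
import os

def estimate_map_size(data_path, headroom=2.0):
    total_bytes = sum(
        os.path.getsize(f) for f in glob.glob(os.path.join(data_path, "*.pt")))
    return int(total_bytes * headroom)
```
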
scripts/generate_ode_pairs.py ADDED
@@ -0,0 +1,120 @@
1
+ from utils.distributed import launch_distributed_job
2
+ from utils.scheduler import FlowMatchScheduler
3
+ from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder
4
+ from utils.dataset import TextDataset
5
+ import torch.distributed as dist
6
+ from tqdm import tqdm
7
+ import argparse
8
+ import torch
9
+ import math
10
+ import os
11
+
12
+
13
+ def init_model(device):
14
+ model = WanDiffusionWrapper().to(device).to(torch.float32)
15
+ encoder = WanTextEncoder().to(device).to(torch.float32)
16
+ model.model.requires_grad_(False)
17
+
18
+ scheduler = FlowMatchScheduler(
19
+ shift=8.0, sigma_min=0.0, extra_one_step=True)
20
+ scheduler.set_timesteps(num_inference_steps=48, denoising_strength=1.0)
21
+ scheduler.sigmas = scheduler.sigmas.to(device)
22
+
23
+ sample_neg_prompt = '色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走'  # default negative prompt (Chinese); roughly: garish colors, overexposure, static/blurry detail, subtitles, style/artwork text, gray cast, worst/low quality, JPEG artifacts, ugly or mutilated features, extra or fused fingers, poorly drawn hands/faces, deformed limbs, frozen frames, cluttered background, three legs, crowded background, walking backwards
24
+
25
+ unconditional_dict = encoder(
26
+ text_prompts=[sample_neg_prompt]
27
+ )
28
+
29
+ return model, encoder, scheduler, unconditional_dict
30
+
31
+
32
+ def main():
33
+ parser = argparse.ArgumentParser()
34
+ parser.add_argument("--local_rank", type=int, default=-1)
35
+ parser.add_argument("--output_folder", type=str)
36
+ parser.add_argument("--caption_path", type=str)
37
+ parser.add_argument("--guidance_scale", type=float, default=6.0)
38
+
39
+ args = parser.parse_args()
40
+
41
+ # set up torch.distributed; prompts are sharded across ranks below
42
+ launch_distributed_job()
43
+
44
+ device = torch.cuda.current_device()
45
+
46
+ torch.set_grad_enabled(False)
47
+ torch.backends.cuda.matmul.allow_tf32 = True
48
+ torch.backends.cudnn.allow_tf32 = True
49
+
50
+ model, encoder, scheduler, unconditional_dict = init_model(device=device)
51
+
52
+ dataset = TextDataset(args.caption_path)
53
+
54
+ # every rank creates the output folder (exist_ok=True makes this safe)
55
+ os.makedirs(args.output_folder, exist_ok=True)
56
+
57
+ for index in tqdm(range(int(math.ceil(len(dataset) / dist.get_world_size()))), disable=dist.get_rank() != 0):
58
+ prompt_index = index * dist.get_world_size() + dist.get_rank()
59
+ if prompt_index >= len(dataset):
60
+ continue
61
+ prompt = dataset[prompt_index]
62
+
63
+ conditional_dict = encoder(text_prompts=prompt)
64
+
65
+ latents = torch.randn(
66
+ [1, 21, 16, 60, 104], dtype=torch.float32, device=device
67
+ )
68
+
69
+ noisy_input = []
70
+
71
+ for progress_id, t in enumerate(tqdm(scheduler.timesteps)):
72
+ timestep = t * \
73
+ torch.ones([1, 21], device=device, dtype=torch.float32)
74
+
75
+ noisy_input.append(latents)
76
+
77
+ _, x0_pred_cond = model(
78
+ latents, conditional_dict, timestep
79
+ )
80
+
81
+ _, x0_pred_uncond = model(
82
+ latents, unconditional_dict, timestep
83
+ )
84
+
85
+ x0_pred = x0_pred_uncond + args.guidance_scale * (
86
+ x0_pred_cond - x0_pred_uncond
87
+ )
88
+
89
+ flow_pred = model._convert_x0_to_flow_pred(
90
+ scheduler=scheduler,
91
+ x0_pred=x0_pred.flatten(0, 1),
92
+ xt=latents.flatten(0, 1),
93
+ timestep=timestep.flatten(0, 1)
94
+ ).unflatten(0, x0_pred.shape[:2])
95
+
96
+ latents = scheduler.step(
97
+ flow_pred.flatten(0, 1),
98
+ scheduler.timesteps[progress_id] * torch.ones(
99
+ [1, 21], device=device, dtype=torch.long).flatten(0, 1),
100
+ latents.flatten(0, 1)
101
+ ).unflatten(dim=0, sizes=flow_pred.shape[:2])
102
+
103
+ noisy_input.append(latents)
104
+
105
+ noisy_inputs = torch.stack(noisy_input, dim=1)
106
+
107
+ noisy_inputs = noisy_inputs[:, [0, 12, 24, 36, -1]]
108
+
109
+ stored_data = noisy_inputs
110
+
111
+ torch.save(
112
+ {prompt: stored_data.cpu().detach()},
113
+ os.path.join(args.output_folder, f"{prompt_index:05d}.pt")
114
+ )
115
+
116
+ dist.barrier()
117
+
118
+
119
+ if __name__ == "__main__":
120
+ main()
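
Each saved file maps the prompt string to the sub-sampled denoising trajectory; with the latent shape used above, the stored tensor is `[1, 5, 21, 16, 60, 104]` (initial noise, three intermediate states, and the final latents). A minimal inspection sketch, illustrative only and with a hypothetical path:

```python
# Minimal sketch (not part of this upload): inspect one file written by
# generate_ode_pairs.py. Each file maps a prompt to the selected trajectory
# states stacked along dim=1.
import torch

def inspect_pair(path):
    data = torch.load(path, map_location="cpu")
    for prompt, trajectory in data.items():
        print(f"prompt: {prompt[:80]}")
        print(f"trajectory shape: {tuple(trajectory.shape)}")  # e.g. (1, 5, 21, 16, 60, 104)

# inspect_pair("output_folder/00000.pt")  # hypothetical path
```
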
setup.py ADDED
@@ -0,0 +1,6 @@
1
+ from setuptools import setup, find_packages
2
+ setup(
3
+ name="self_forcing",
4
+ version="0.0.1",
5
+ packages=find_packages(),
6
+ )
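
With this minimal setup.py, the project can typically be installed in editable mode (for example, `pip install -e .`), which is what lets the scripts above resolve absolute imports such as `from utils.lmdb import store_arrays_to_lmdb` and `from utils.dataset import TextDataset` regardless of the working directory.
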
templates/demo.html ADDED
@@ -0,0 +1,615 @@
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>Self Forcing</title>
7
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/socket.io/4.0.0/socket.io.js"></script>
8
+ <style>
9
+ body {
10
+ font-family: Arial, sans-serif;
11
+ max-width: 1400px;
12
+ margin: 0 auto;
13
+ padding: 20px;
14
+ background-color: #f5f5f5;
15
+ }
16
+ .container {
17
+ background: white;
18
+ padding: 20px;
19
+ border-radius: 10px;
20
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
21
+ }
22
+ .main-layout {
23
+ display: grid;
24
+ grid-template-columns: 1fr 1fr;
25
+ gap: 30px;
26
+ margin-top: 20px;
27
+ }
28
+ .left-column {
29
+ padding-right: 15px;
30
+ }
31
+ .right-column {
32
+ padding-left: 15px;
33
+ }
34
+ @media (max-width: 768px) {
35
+ .main-layout {
36
+ grid-template-columns: 1fr;
37
+ gap: 20px;
38
+ }
39
+ .left-column, .right-column {
40
+ padding: 0;
41
+ }
42
+ }
43
+ .controls {
44
+ margin-bottom: 20px;
45
+ }
46
+ .control-group {
47
+ margin-bottom: 15px;
48
+ }
49
+ label {
50
+ display: block;
51
+ margin-bottom: 5px;
52
+ font-weight: bold;
53
+ }
54
+ input, textarea, button, select {
55
+ padding: 8px;
56
+ border: 1px solid #ddd;
57
+ border-radius: 4px;
58
+ }
59
+ textarea {
60
+ width: 100%;
61
+ height: 90px;
62
+ resize: vertical;
63
+ }
64
+ input[type="range"] {
65
+ width: 200px;
66
+ }
67
+ button {
68
+ background-color: #007bff;
69
+ color: white;
70
+ border: none;
71
+ padding: 10px 20px;
72
+ cursor: pointer;
73
+ margin-right: 10px;
74
+ }
75
+ button:hover {
76
+ background-color: #0056b3;
77
+ }
78
+ button:disabled {
79
+ background-color: #6c757d;
80
+ cursor: not-allowed;
81
+ }
82
+ .stop-btn {
83
+ background-color: #dc3545;
84
+ }
85
+ .stop-btn:hover {
86
+ background-color: #c82333;
87
+ }
88
+ .video-container {
89
+ text-align: center;
90
+ background: #000;
91
+ border-radius: 8px;
92
+ padding: 20px;
93
+ margin: 20px auto;
94
+ display: flex;
95
+ flex-direction: column;
96
+ align-items: center;
97
+ justify-content: center;
98
+ }
99
+ #videoFrame {
100
+ max-width: 100%;
101
+ height: auto;
102
+ border-radius: 4px;
103
+ }
104
+ .progress-container {
105
+ margin: 20px 0;
106
+ }
107
+ .progress-bar {
108
+ width: 100%;
109
+ height: 20px;
110
+ background-color: #e9ecef;
111
+ border-radius: 10px;
112
+ overflow: hidden;
113
+ }
114
+ .progress-fill {
115
+ height: 100%;
116
+ background-color: #007bff;
117
+ transition: width 0.3s ease;
118
+ }
119
+ .status {
120
+ margin: 10px 0;
121
+ padding: 10px;
122
+ border-radius: 4px;
123
+ }
124
+ .status.info {
125
+ background-color: #d1ecf1;
126
+ color: #0c5460;
127
+ }
128
+ .status.error {
129
+ background-color: #f8d7da;
130
+ color: #721c24;
131
+ }
132
+ .status.success {
133
+ background-color: #d4edda;
134
+ color: #155724;
135
+ }
136
+ .frame-info {
137
+ color: #666;
138
+ font-size: 0.9em;
139
+ margin-top: 10px;
140
+ }
141
+ .buffer-info {
142
+ background-color: #e3f2fd;
143
+ padding: 15px;
144
+ border-radius: 4px;
145
+ margin: 15px 0;
146
+ color: #1976d2;
147
+ }
148
+ .playback-controls {
149
+ margin: 15px 0;
150
+ display: flex;
151
+ align-items: center;
152
+ justify-content: center;
153
+ gap: 10px;
154
+ }
155
+ .playback-controls button {
156
+ margin: 0 5px;
157
+ padding: 8px 15px;
158
+ }
159
+ #playbackSpeed {
160
+ width: 80px;
161
+ }
162
+ .torch-compile-toggle {
163
+ background-color: #f8f9fa;
164
+ border: 1px solid #dee2e6;
165
+ border-radius: 6px;
166
+ padding: 10px;
167
+ margin: 0;
168
+ flex: 1;
169
+ min-width: 120px;
170
+ }
171
+ .torch-compile-toggle label {
172
+ display: flex;
173
+ align-items: center;
174
+ font-weight: bold;
175
+ color: #495057;
176
+ margin-bottom: 0;
177
+ font-size: 0.9em;
178
+ }
179
+ .torch-compile-toggle input[type="checkbox"] {
180
+ transform: scale(1.1);
181
+ margin-right: 8px;
182
+ }
183
+ </style>
184
+ </head>
185
+ <body>
186
+ <div class="container">
187
+ <h1>🚀 Self Forcing</h1>
188
+
189
+ <div class="main-layout">
190
+ <div class="left-column">
191
+ <div class="controls">
192
+ <div class="control-group">
193
+ <label for="prompt">Prompt (long, detailed prompts work better):</label>
194
+ <textarea id="prompt" placeholder="Describe the video you want to generate..."></textarea>
195
+
196
+ <div style="margin-top: 10px;">
197
+ <label>Quick Prompts:</label>
198
+ <div style="display: flex; flex-direction: column; gap: 8px; margin-top: 5px;">
199
+ <button type="button" onclick="setQuickPrompt('quick-demo-1')" style="background-color: #28a745; font-size: 11px; padding: 8px; width: 100%; text-align: left; white-space: pre-wrap; line-height: 1.3; min-height: 60px; border-radius: 4px; color: white; border: none; cursor: pointer;">A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.</button>
200
+ <button type="button" onclick="setQuickPrompt('quick-demo-2')" style="background-color: #17a2b8; font-size: 11px; padding: 8px; width: 100%; text-align: left; white-space: pre-wrap; line-height: 1.3; min-height: 60px; border-radius: 4px; color: white; border: none; cursor: pointer;">A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field.</button>
201
+ </div>
202
+ </div>
203
+ </div>
204
+
205
+ <div style="display: flex; gap: 20px;">
206
+ <div class="control-group">
207
+ <label for="seed">Seed:</label>
208
+ <input type="number" id="seed" value="-1" min="0" max="999999">
209
+ </div>
210
+
211
+ <div class="control-group">
212
+ <label for="fps">Target FPS: <span id="fpsValue">6</span></label>
213
+ <input type="range" id="fps" min="2" max="16" value="6" step="0.5">
214
+ </div>
215
+
216
+ <!-- <div class="control-group">
217
+ <label for="blocks">Total Blocks: <span id="blocksValue">7</span></label>
218
+ <input type="range" id="blocks" min="3" max="10" value="7" step="1">
219
+ </div> -->
220
+ </div>
221
+
222
+ <div class="control-group">
223
+ <div style="display: flex; gap: 15px; align-items: flex-start; flex-wrap: wrap;">
224
+ <div class="torch-compile-toggle">
225
+ <label>
226
+ <input type="checkbox" id="torchCompile">
227
+ 🔥 torch.compile
228
+ </label>
229
+ </div>
230
+ <div class="torch-compile-toggle">
231
+ <label>
232
+ <input type="checkbox" id="fp8Toggle">
233
+ ⚡ FP8 Quantization
234
+ </label>
235
+ </div>
236
+ <div class="torch-compile-toggle">
237
+ <label>
238
+ <input type="checkbox" id="taehvToggle">
239
+ ⚡ TAEHV VAE
240
+ </label>
241
+ </div>
242
+ </div>
243
+ <!-- <div style="font-size: 0.85em; color: #666; margin-top: 5px;">
244
+ <strong>Note:</strong> torch.compile and FP8 are one-time toggles (cannot be changed once applied)
245
+ </div> -->
246
+ </div>
247
+
248
+ <div class="control-group">
249
+ <button id="startBtn" onclick="startGeneration()">🚀 Start Generation</button>
250
+ <button id="stopBtn" onclick="stopGeneration()" disabled class="stop-btn">⏹️ Stop</button>
251
+ </div>
252
+ </div>
253
+
254
+ <div class="progress-container">
255
+ <div class="progress-bar">
256
+ <div id="progressFill" class="progress-fill" style="width: 0%"></div>
257
+ </div>
258
+ <div id="progressText">Ready to generate</div>
259
+ </div>
260
+ </div>
261
+
262
+ <div class="right-column">
263
+ <div class="buffer-info">
264
+ <strong>📦 Frame Buffer:</strong> <span id="bufferCount">0</span> frames ready |
265
+ <strong>📺 Displayed:</strong> <span id="displayedCount">0</span> frames
266
+ <!-- <strong>⚡ Receive Rate:</strong> <span id="receiveRate">0</span> fps -->
267
+ </div>
268
+
269
+ <div class="playback-controls">
270
+ <button id="playBtn" onclick="togglePlayback()" disabled>▶️ Play</button>
271
+ <button id="resetBtn" onclick="resetPlayback()" disabled>⏮️ Reset</button>
272
+ <label for="playbackSpeed">Speed:</label>
273
+ <select id="playbackSpeed" onchange="updatePlaybackSpeed()">
274
+ <option value="0.25">0.25x</option>
275
+ <option value="0.5">0.5x</option>
276
+ <option value="0.75">0.75x</option>
277
+ <option value="1" selected>1x</option>
278
+ <option value="1.25">1.25x</option>
279
+ <option value="1.5">1.5x</option>
280
+ <option value="2">2x</option>
281
+ </select>
282
+ </div>
283
+
284
+ <div id="statusContainer"></div>
285
+
286
+ <div class="video-container">
287
+ <img id="videoFrame" src="" alt="Video frames will appear here" style="display: none;">
288
+ <div id="placeholderText">Click "Start Generation" to begin</div>
289
+ <div id="frameInfo" class="frame-info"></div>
290
+ </div>
291
+ </div>
292
+ </div>
293
+ </div>
294
+
295
+ <script>
296
+ const socket = io();
297
+ let frameBuffer = []; // Store all received frames
298
+ let currentFrameIndex = 0;
299
+ let isPlaying = false;
300
+ let playbackInterval = null;
301
+ let targetFps = 6;
302
+ let playbackSpeed = 1.0;
303
+ let startTime = null;
304
+ let lastReceiveTime = null;
305
+ let receiveCount = 0;
306
+ let receiveRate = 0;
307
+
308
+ // State tracking for one-time toggles
309
+ let torchCompileApplied = false;
310
+ let fp8Applied = false;
311
+
312
+ // Update slider values
313
+ document.getElementById('fps').oninput = function() {
314
+ targetFps = parseFloat(this.value);
315
+ document.getElementById('fpsValue').textContent = this.value;
316
+ updatePlaybackTiming();
317
+ };
318
+
319
+ // document.getElementById('blocks').oninput = function() {
320
+ // document.getElementById('blocksValue').textContent = this.value;
321
+ // };
322
+
323
+ // Handle toggle behavior and fetch current status
324
+ function updateToggleStates() {
325
+ fetch('/api/status')
326
+ .then(response => response.json())
327
+ .then(data => {
328
+ torchCompileApplied = data.torch_compile_applied;
329
+ fp8Applied = data.fp8_applied;
330
+
331
+ // Update UI based on current state
332
+ const torchToggle = document.getElementById('torchCompile');
333
+ const fp8Toggle = document.getElementById('fp8Toggle');
334
+ const taehvToggle = document.getElementById('taehvToggle');
335
+
336
+ // Disable one-time toggles if already applied
337
+ if (torchCompileApplied) {
338
+ torchToggle.checked = true;
339
+ torchToggle.disabled = true;
340
+ torchToggle.parentElement.style.opacity = '0.6';
341
+ }
342
+
343
+ if (fp8Applied) {
344
+ fp8Toggle.checked = true;
345
+ fp8Toggle.disabled = true;
346
+ fp8Toggle.parentElement.style.opacity = '0.6';
347
+ }
348
+
349
+ // Set TAEHV toggle based on current state
350
+ taehvToggle.checked = data.current_use_taehv;
351
+ })
352
+ .catch(err => console.log('Status check failed:', err));
353
+ }
354
+
355
+ // Handle torch.compile toggle
356
+ document.getElementById('torchCompile').onchange = function() {
357
+ if (torchCompileApplied && !this.checked) {
358
+ this.checked = true; // Prevent unchecking
359
+ alert('torch.compile cannot be disabled once applied');
360
+ }
361
+ };
362
+
363
+ // Handle FP8 toggle
364
+ document.getElementById('fp8Toggle').onchange = function() {
365
+ if (fp8Applied && !this.checked) {
366
+ this.checked = true; // Prevent unchecking
367
+ alert('FP8 quantization cannot be disabled once applied');
368
+ }
369
+ };
370
+
371
+ // Update toggle states on page load
372
+ updateToggleStates();
373
+
374
+ // Socket event handlers
375
+ socket.on('connect', function() {
376
+ // showStatus('Connected to frontend-buffered server', 'info');
377
+ });
378
+
379
+ socket.on('status', function(data) {
380
+ // showStatus(data.message, 'info');
381
+ });
382
+
383
+ socket.on('progress', function(data) {
384
+ updateProgress(data.progress, data.message);
385
+ });
386
+
387
+ socket.on('frame_ready', function(data) {
388
+ // Add frame to buffer immediately
389
+ frameBuffer.push(data);
390
+ receiveCount++;
391
+
392
+ // Calculate receive rate
393
+ const now = Date.now();
394
+ if (lastReceiveTime) {
395
+ const interval = (now - lastReceiveTime) / 1000;
396
+ receiveRate = (1 / interval).toFixed(1);
397
+ }
398
+ lastReceiveTime = now;
399
+
400
+ updateBufferInfo();
401
+
402
+ // Auto-start playback when we have some frames
403
+ if (frameBuffer.length === 5 && !isPlaying) {
404
+ // showStatus('Auto-starting playback with buffer of 5 frames', 'info');
405
+ startPlayback();
406
+ }
407
+ });
408
+
409
+ socket.on('generation_complete', function(data) {
410
+ // showStatus(data.message + ` (Generated in ${data.generation_time})`, 'success');
411
+ enableControls(true);
412
+ const duration = startTime ? ((Date.now() - startTime) / 1000).toFixed(1) : 'unknown';
413
+ updateFrameInfo(`Generation complete! ${data.total_frames} frames in ${duration}s`);
414
+
415
+ // Update toggle states after generation
416
+ updateToggleStates();
417
+ });
418
+
419
+ socket.on('error', function(data) {
420
+ // showStatus(`Error: ${data.message}`, 'error');
421
+ enableControls(true);
422
+ });
423
+
424
+ function startGeneration() {
425
+ const prompt = document.getElementById('prompt').value.trim();
426
+ if (!prompt) {
427
+ alert('Please enter a prompt');
428
+ return;
429
+ }
430
+
431
+ const seed = parseInt(document.getElementById('seed').value) || 31337;
432
+ // const totalBlocks = parseInt(document.getElementById('blocks').value) || 7;
433
+ const enableTorchCompile = document.getElementById('torchCompile').checked;
434
+ const enableFp8 = document.getElementById('fp8Toggle').checked;
435
+ const useTaehv = document.getElementById('taehvToggle').checked;
436
+
437
+ // Reset state
438
+ frameBuffer = [];
439
+ currentFrameIndex = 0;
440
+ receiveCount = 0;
441
+ receiveRate = 0;
442
+ stopPlayback();
443
+
444
+ enableControls(false);
445
+ startTime = Date.now();
446
+
447
+ socket.emit('start_generation', {
448
+ prompt: prompt,
449
+ seed: seed,
450
+ enable_torch_compile: enableTorchCompile,
451
+ enable_fp8: enableFp8,
452
+ use_taehv: useTaehv
453
+ });
454
+ }
455
+
456
+ function stopGeneration() {
457
+ socket.emit('stop_generation');
458
+ enableControls(true);
459
+ }
460
+
461
+ function togglePlayback() {
462
+ if (isPlaying) {
463
+ stopPlayback();
464
+ } else {
465
+ startPlayback();
466
+ }
467
+ }
468
+
469
+ function startPlayback() {
470
+ if (frameBuffer.length === 0) return;
471
+
472
+ isPlaying = true;
473
+ document.getElementById('playBtn').textContent = '⏸️ Pause';
474
+ document.getElementById('playBtn').disabled = false;
475
+ document.getElementById('resetBtn').disabled = false;
476
+
477
+ updatePlaybackTiming();
478
+ // showStatus('Playback started', 'info');
479
+ }
480
+
481
+ function stopPlayback() {
482
+ isPlaying = false;
483
+ if (playbackInterval) {
484
+ clearInterval(playbackInterval);
485
+ playbackInterval = null;
486
+ }
487
+ document.getElementById('playBtn').textContent = '▶️ Play';
488
+ }
489
+
490
+ function resetPlayback() {
491
+ stopPlayback();
492
+
493
+ // Clear the entire frame buffer
494
+ frameBuffer = [];
495
+ currentFrameIndex = 0;
496
+ receiveCount = 0;
497
+ receiveRate = 0;
498
+
499
+ // Reset video display to initial state
500
+ const img = document.getElementById('videoFrame');
501
+ const placeholder = document.getElementById('placeholderText');
502
+
503
+ img.src = '';
504
+ img.style.display = 'none';
505
+ placeholder.style.display = 'block';
506
+
507
+ // Update UI
508
+ updateBufferInfo();
509
+ updateFrameInfo('Reset - buffer cleared');
510
+
511
+ // Disable playback controls since there's no content
512
+ document.getElementById('playBtn').disabled = true;
513
+ document.getElementById('resetBtn').disabled = true;
514
+ }
515
+
516
+ function updatePlaybackSpeed() {
517
+ playbackSpeed = parseFloat(document.getElementById('playbackSpeed').value);
518
+ if (isPlaying) {
519
+ updatePlaybackTiming();
520
+ }
521
+ }
522
+
523
+ function updatePlaybackTiming() {
524
+ if (playbackInterval) {
525
+ clearInterval(playbackInterval);
526
+ }
527
+
528
+ if (isPlaying) {
529
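+ // Show one frame every (1000 / targetFps) ms, shortened or stretched by the playback-speed multiplier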
+ const interval = (1000 / targetFps) / playbackSpeed;
530
+ playbackInterval = setInterval(displayNextFrame, interval);
531
+ }
532
+ }
533
+
534
+ function displayNextFrame() {
535
+ if (currentFrameIndex >= frameBuffer.length) {
536
+ // Reached end of buffer
537
+ if (document.querySelector('#progressFill').style.width === '100%') {
538
+ // Generation complete, stop playback
539
+ stopPlayback();
540
+ // showStatus('Playback complete', 'success');
541
+ }
542
+ return;
543
+ }
544
+
545
+ const frameData = frameBuffer[currentFrameIndex];
546
+ displayFrame(frameData);
547
+ currentFrameIndex++;
548
+
549
+ updateBufferInfo();
550
+ }
551
+
552
+ function displayFrame(frameData) {
553
+ const img = document.getElementById('videoFrame');
554
+ const placeholder = document.getElementById('placeholderText');
555
+
556
+ img.src = frameData.data;
557
+ img.style.display = 'block';
558
+ placeholder.style.display = 'none';
559
+
560
+ const elapsed = startTime ? ((Date.now() - startTime) / 1000).toFixed(1) : '0';
561
+ updateFrameInfo(`Frame ${frameData.frame_index + 1} | Block ${frameData.block_index + 1} | ${elapsed}s elapsed | ${targetFps} FPS @ ${playbackSpeed}x speed`);
562
+ }
563
+
564
+ function updateBufferInfo() {
565
+ document.getElementById('bufferCount').textContent = frameBuffer.length;
566
+ document.getElementById('displayedCount').textContent = currentFrameIndex;
567
+ // document.getElementById('receiveRate').textContent = receiveRate;
568
+ }
569
+
570
+ function setQuickPrompt(type) {
571
+ const promptBox = document.getElementById('prompt');
572
+ if (type === 'quick-demo-1') {
573
+ promptBox.value = 'A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.';
574
+ } else if (type === 'quick-demo-2') {
575
+ promptBox.value = 'A white and orange tabby cat is seen happily darting through a dense garden, as if chasing something. Its eyes are wide and happy as it jogs forward, scanning the branches, flowers, and leaves as it walks. The path is narrow as it makes its way between all the plants. the scene is captured from a ground-level angle, following the cat closely, giving a low and intimate perspective. The image is cinematic with warm tones and a grainy texture. The scattered daylight between the leaves and plants above creates a warm contrast, accentuating the cat’s orange fur. The shot is clear and sharp, with a shallow depth of field.';
576
+ }
577
+ }
578
+
579
+ function enableControls(enabled) {
580
+ document.getElementById('startBtn').disabled = !enabled;
581
+ document.getElementById('stopBtn').disabled = enabled;
582
+ }
583
+
584
+ function updateProgress(progress, message) {
585
+ document.getElementById('progressFill').style.width = progress + '%';
586
+ document.getElementById('progressText').textContent = message;
587
+ }
588
+
589
+ function updateFrameInfo(text) {
590
+ document.getElementById('frameInfo').textContent = text;
591
+ }
592
+
593
+ function showStatus(message, type) {
594
+ const container = document.getElementById('statusContainer');
595
+ const statusDiv = document.createElement('div');
596
+ statusDiv.className = `status ${type}`;
597
+ statusDiv.textContent = message;
598
+
599
+ container.insertBefore(statusDiv, container.firstChild);
600
+
601
+ // Remove old status messages (keep only last 3)
602
+ while (container.children.length > 3) {
603
+ container.removeChild(container.lastChild);
604
+ }
605
+
606
+ // Auto-remove after 5 seconds
607
+ setTimeout(() => {
608
+ if (statusDiv.parentNode) {
609
+ statusDiv.parentNode.removeChild(statusDiv);
610
+ }
611
+ }, 5000);
612
+ }
613
+ </script>
614
+ </body>
615
+ </html>
train.py ADDED
@@ -0,0 +1,47 @@
1
+ import argparse
2
+ import os
3
+ from omegaconf import OmegaConf
4
+ import wandb
5
+
6
+ from trainer import DiffusionTrainer, GANTrainer, ODETrainer, ScoreDistillationTrainer
7
+
8
+
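+ # Example launch (illustrative only; config path, log dir, and GPU count depend on your setup):
+ # torchrun --nproc_per_node=8 train.py --config_path configs/self_forcing_dmd.yaml --logdir logs/self_forcing_dmd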
9
+ def main():
10
+ parser = argparse.ArgumentParser()
11
+ parser.add_argument("--config_path", type=str, required=True)
12
+ parser.add_argument("--no_save", action="store_true")
13
+ parser.add_argument("--no_visualize", action="store_true")
14
+ parser.add_argument("--logdir", type=str, default="", help="Path to the directory to save logs")
15
+ parser.add_argument("--wandb-save-dir", type=str, default="", help="Path to the directory to save wandb logs")
16
+ parser.add_argument("--disable-wandb", action="store_true")
17
+
18
+ args = parser.parse_args()
19
+
20
+ config = OmegaConf.load(args.config_path)
21
+ default_config = OmegaConf.load("configs/default_config.yaml")
22
+ config = OmegaConf.merge(default_config, config)
23
+ config.no_save = args.no_save
24
+ config.no_visualize = args.no_visualize
25
+
26
+ # get the filename of config_path
27
+ config_name = os.path.basename(args.config_path).split(".")[0]
28
+ config.config_name = config_name
29
+ config.logdir = args.logdir
30
+ config.wandb_save_dir = args.wandb_save_dir
31
+ config.disable_wandb = args.disable_wandb
32
+
33
+ if config.trainer == "diffusion":
34
+ trainer = DiffusionTrainer(config)
35
+ elif config.trainer == "gan":
36
+ trainer = GANTrainer(config)
37
+ elif config.trainer == "ode":
38
+ trainer = ODETrainer(config)
39
+ elif config.trainer == "score_distillation":
40
+ trainer = ScoreDistillationTrainer(config)
+ else:
+ raise ValueError(f"Unknown trainer type: {config.trainer}")
41
+ trainer.train()
42
+
43
+ wandb.finish()
44
+
45
+
46
+ if __name__ == "__main__":
47
+ main()
trainer/__init__.py ADDED
@@ -0,0 +1,11 @@
1
+ from .diffusion import Trainer as DiffusionTrainer
2
+ from .gan import Trainer as GANTrainer
3
+ from .ode import Trainer as ODETrainer
4
+ from .distillation import Trainer as ScoreDistillationTrainer
5
+
6
+ __all__ = [
7
+ "DiffusionTrainer",
8
+ "GANTrainer",
9
+ "ODETrainer",
10
+ "ScoreDistillationTrainer"
11
+ ]
trainer/diffusion.py ADDED
@@ -0,0 +1,265 @@
1
+ import gc
2
+ import logging
3
+
4
+ from model import CausalDiffusion
5
+ from utils.dataset import ShardingLMDBDataset, cycle
6
+ from utils.misc import set_seed
7
+ import torch.distributed as dist
8
+ from omegaconf import OmegaConf
9
+ import torch
10
+ import wandb
11
+ import time
12
+ import os
13
+
14
+ from utils.distributed import EMA_FSDP, barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
15
+
16
+
17
+ class Trainer:
18
+ def __init__(self, config):
19
+ self.config = config
20
+ self.step = 0
21
+
22
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
23
+ torch.backends.cuda.matmul.allow_tf32 = True
24
+ torch.backends.cudnn.allow_tf32 = True
25
+
26
+ launch_distributed_job()
27
+ global_rank = dist.get_rank()
28
+
29
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
30
+ self.device = torch.cuda.current_device()
31
+ self.is_main_process = global_rank == 0
32
+ self.causal = config.causal
33
+ self.disable_wandb = config.disable_wandb
34
+
35
+ # use a random seed for the training
36
+ if config.seed == 0:
37
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
38
+ dist.broadcast(random_seed, src=0)
39
+ config.seed = random_seed.item()
40
+
41
+ set_seed(config.seed + global_rank)
42
+
43
+ if self.is_main_process and not self.disable_wandb:
44
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
45
+ wandb.init(
46
+ config=OmegaConf.to_container(config, resolve=True),
47
+ name=config.config_name,
48
+ mode="online",
49
+ entity=config.wandb_entity,
50
+ project=config.wandb_project,
51
+ dir=config.wandb_save_dir
52
+ )
53
+
54
+ self.output_path = config.logdir
55
+
56
+ # Step 2: Initialize the model and optimizer
57
+ self.model = CausalDiffusion(config, device=self.device)
58
+ self.model.generator = fsdp_wrap(
59
+ self.model.generator,
60
+ sharding_strategy=config.sharding_strategy,
61
+ mixed_precision=config.mixed_precision,
62
+ wrap_strategy=config.generator_fsdp_wrap_strategy
63
+ )
64
+
65
+ self.model.text_encoder = fsdp_wrap(
66
+ self.model.text_encoder,
67
+ sharding_strategy=config.sharding_strategy,
68
+ mixed_precision=config.mixed_precision,
69
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy
70
+ )
71
+
72
+ if not config.no_visualize or config.load_raw_video:
73
+ self.model.vae = self.model.vae.to(
74
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
75
+
76
+ self.generator_optimizer = torch.optim.AdamW(
77
+ [param for param in self.model.generator.parameters()
78
+ if param.requires_grad],
79
+ lr=config.lr,
80
+ betas=(config.beta1, config.beta2),
81
+ weight_decay=config.weight_decay
82
+ )
83
+
84
+ # Step 3: Initialize the dataloader
85
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
86
+ sampler = torch.utils.data.distributed.DistributedSampler(
87
+ dataset, shuffle=True, drop_last=True)
88
+ dataloader = torch.utils.data.DataLoader(
89
+ dataset,
90
+ batch_size=config.batch_size,
91
+ sampler=sampler,
92
+ num_workers=8)
93
+
94
+ if dist.get_rank() == 0:
95
+ print("DATASET SIZE %d" % len(dataset))
96
+ self.dataloader = cycle(dataloader)
97
+
98
+ ##############################################################################################################
99
+ # 6. Set up EMA parameter containers
100
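+ # Strip FSDP / activation-checkpointing / torch.compile wrapper prefixes so EMA keys match the plain parameter names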
+ rename_param = (
101
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
102
+ .replace("_checkpoint_wrapped_module.", "")
103
+ .replace("_orig_mod.", "")
104
+ )
105
+ self.name_to_trainable_params = {}
106
+ for n, p in self.model.generator.named_parameters():
107
+ if not p.requires_grad:
108
+ continue
109
+
110
+ renamed_n = rename_param(n)
111
+ self.name_to_trainable_params[renamed_n] = p
112
+ ema_weight = config.ema_weight
113
+ self.generator_ema = None
114
+ if (ema_weight is not None) and (ema_weight > 0.0):
115
+ print(f"Setting up EMA with weight {ema_weight}")
116
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
117
+
118
+ ##############################################################################################################
119
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
120
+ if getattr(config, "generator_ckpt", False):
121
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
122
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
123
+ if "generator" in state_dict:
124
+ state_dict = state_dict["generator"]
125
+ elif "model" in state_dict:
126
+ state_dict = state_dict["model"]
127
+ self.model.generator.load_state_dict(
128
+ state_dict, strict=True
129
+ )
130
+
131
+ ##############################################################################################################
132
+
133
+ # Let's delete EMA params for early steps to save some compute during training and inference
134
+ if self.step < config.ema_start_step:
135
+ self.generator_ema = None
136
+
137
+ self.max_grad_norm = 10.0
138
+ self.previous_time = None
139
+
140
+ def save(self):
141
+ print("Start gathering distributed model states...")
142
+ generator_state_dict = fsdp_state_dict(
143
+ self.model.generator)
144
+
145
+ if self.config.ema_start_step < self.step:
146
+ state_dict = {
147
+ "generator": generator_state_dict,
148
+ "generator_ema": self.generator_ema.state_dict(),
149
+ }
150
+ else:
151
+ state_dict = {
152
+ "generator": generator_state_dict,
153
+ }
154
+
155
+ if self.is_main_process:
156
+ os.makedirs(os.path.join(self.output_path,
157
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
158
+ torch.save(state_dict, os.path.join(self.output_path,
159
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
160
+ print("Model saved to", os.path.join(self.output_path,
161
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
162
+
163
+ def train_one_step(self, batch):
164
+ self.log_iters = 1
165
+
166
+ if self.step % 20 == 0:
167
+ torch.cuda.empty_cache()
168
+
169
+ # Step 1: Get the next batch of text prompts
170
+ text_prompts = batch["prompts"]
171
+ if not self.config.load_raw_video: # precomputed latent
172
+ clean_latent = batch["ode_latent"][:, -1].to(
173
+ device=self.device, dtype=self.dtype)
174
+ else: # encode raw video to latent
175
+ frames = batch["frames"].to(
176
+ device=self.device, dtype=self.dtype)
177
+ with torch.no_grad():
178
+ clean_latent = self.model.vae.encode_to_latent(
179
+ frames).to(device=self.device, dtype=self.dtype)
180
+ image_latent = clean_latent[:, 0:1, ]
181
+
182
+ batch_size = len(text_prompts)
183
+ image_or_video_shape = list(self.config.image_or_video_shape)
184
+ image_or_video_shape[0] = batch_size
185
+
186
+ # Step 2: Extract the conditional infos
187
+ with torch.no_grad():
188
+ conditional_dict = self.model.text_encoder(
189
+ text_prompts=text_prompts)
190
+
191
+ if not getattr(self, "unconditional_dict", None):
192
+ unconditional_dict = self.model.text_encoder(
193
+ text_prompts=[self.config.negative_prompt] * batch_size)
194
+ unconditional_dict = {k: v.detach()
195
+ for k, v in unconditional_dict.items()}
196
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
197
+ else:
198
+ unconditional_dict = self.unconditional_dict
199
+
200
+ # Step 3: Train the generator
201
+ generator_loss, log_dict = self.model.generator_loss(
202
+ image_or_video_shape=image_or_video_shape,
203
+ conditional_dict=conditional_dict,
204
+ unconditional_dict=unconditional_dict,
205
+ clean_latent=clean_latent,
206
+ initial_latent=image_latent
207
+ )
208
+ self.generator_optimizer.zero_grad()
209
+ generator_loss.backward()
210
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
211
+ self.max_grad_norm)
212
+ self.generator_optimizer.step()
213
+
214
+ # Increment the step since we finished gradient update
215
+ self.step += 1
216
+
217
+ wandb_loss_dict = {
218
+ "generator_loss": generator_loss.item(),
219
+ "generator_grad_norm": generator_grad_norm.item(),
220
+ }
221
+
222
+ # Step 4: Logging
223
+ if self.is_main_process:
224
+ if not self.disable_wandb:
225
+ wandb.log(wandb_loss_dict, step=self.step)
226
+
227
+ if self.step % self.config.gc_interval == 0:
228
+ if dist.get_rank() == 0:
229
+ logging.info("DistGarbageCollector: Running GC.")
230
+ gc.collect()
231
+
232
+ # Step 5. Create EMA params
233
+ # TODO: Implement EMA
234
+
235
+ def generate_video(self, pipeline, prompts, image=None):
236
+ batch_size = len(prompts)
237
+ sampled_noise = torch.randn(
238
+ [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
239
+ )
240
+ video, _ = pipeline.inference(
241
+ noise=sampled_noise,
242
+ text_prompts=prompts,
243
+ return_latents=True
244
+ )
245
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
246
+ return current_video
247
+
248
+ def train(self):
249
+ while True:
250
+ batch = next(self.dataloader)
251
+ self.train_one_step(batch)
252
+ if (not self.config.no_save) and self.step % self.config.log_iters == 0:
253
+ torch.cuda.empty_cache()
254
+ self.save()
255
+ torch.cuda.empty_cache()
256
+
257
+ barrier()
258
+ if self.is_main_process:
259
+ current_time = time.time()
260
+ if self.previous_time is None:
261
+ self.previous_time = current_time
262
+ else:
263
+ if not self.disable_wandb:
264
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
265
+ self.previous_time = current_time
trainer/distillation.py ADDED
@@ -0,0 +1,388 @@
1
+ import gc
2
+ import logging
3
+
4
+ from utils.dataset import ShardingLMDBDataset, cycle
5
+ from utils.dataset import TextDataset
6
+ from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
7
+ from utils.misc import (
8
+ set_seed,
9
+ merge_dict_list
10
+ )
11
+ import torch.distributed as dist
12
+ from omegaconf import OmegaConf
13
+ from model import CausVid, DMD, SiD
14
+ import torch
15
+ import wandb
16
+ import time
17
+ import os
18
+
19
+
20
+ class Trainer:
21
+ def __init__(self, config):
22
+ self.config = config
23
+ self.step = 0
24
+
25
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
26
+ torch.backends.cuda.matmul.allow_tf32 = True
27
+ torch.backends.cudnn.allow_tf32 = True
28
+
29
+ launch_distributed_job()
30
+ global_rank = dist.get_rank()
31
+ self.world_size = dist.get_world_size()
32
+
33
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
34
+ self.device = torch.cuda.current_device()
35
+ self.is_main_process = global_rank == 0
36
+ self.causal = config.causal
37
+ self.disable_wandb = config.disable_wandb
38
+
39
+ # use a random seed for the training
40
+ if config.seed == 0:
41
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
42
+ dist.broadcast(random_seed, src=0)
43
+ config.seed = random_seed.item()
44
+
45
+ set_seed(config.seed + global_rank)
46
+
47
+ if self.is_main_process and not self.disable_wandb:
48
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
49
+ wandb.init(
50
+ config=OmegaConf.to_container(config, resolve=True),
51
+ name=config.config_name,
52
+ mode="online",
53
+ entity=config.wandb_entity,
54
+ project=config.wandb_project,
55
+ dir=config.wandb_save_dir
56
+ )
57
+
58
+ self.output_path = config.logdir
59
+
60
+ # Step 2: Initialize the model and optimizer
61
+ if config.distribution_loss == "causvid":
62
+ self.model = CausVid(config, device=self.device)
63
+ elif config.distribution_loss == "dmd":
64
+ self.model = DMD(config, device=self.device)
65
+ elif config.distribution_loss == "sid":
66
+ self.model = SiD(config, device=self.device)
67
+ else:
68
+ raise ValueError("Invalid distribution matching loss")
69
+
70
+ # Save pretrained model state_dicts to CPU
71
+ self.fake_score_state_dict_cpu = self.model.fake_score.state_dict()
72
+
73
+ self.model.generator = fsdp_wrap(
74
+ self.model.generator,
75
+ sharding_strategy=config.sharding_strategy,
76
+ mixed_precision=config.mixed_precision,
77
+ wrap_strategy=config.generator_fsdp_wrap_strategy
78
+ )
79
+
80
+ self.model.real_score = fsdp_wrap(
81
+ self.model.real_score,
82
+ sharding_strategy=config.sharding_strategy,
83
+ mixed_precision=config.mixed_precision,
84
+ wrap_strategy=config.real_score_fsdp_wrap_strategy
85
+ )
86
+
87
+ self.model.fake_score = fsdp_wrap(
88
+ self.model.fake_score,
89
+ sharding_strategy=config.sharding_strategy,
90
+ mixed_precision=config.mixed_precision,
91
+ wrap_strategy=config.fake_score_fsdp_wrap_strategy
92
+ )
93
+
94
+ self.model.text_encoder = fsdp_wrap(
95
+ self.model.text_encoder,
96
+ sharding_strategy=config.sharding_strategy,
97
+ mixed_precision=config.mixed_precision,
98
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
99
+ cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
100
+ )
101
+
102
+ if not config.no_visualize or config.load_raw_video:
103
+ self.model.vae = self.model.vae.to(
104
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
105
+
106
+ self.generator_optimizer = torch.optim.AdamW(
107
+ [param for param in self.model.generator.parameters()
108
+ if param.requires_grad],
109
+ lr=config.lr,
110
+ betas=(config.beta1, config.beta2),
111
+ weight_decay=config.weight_decay
112
+ )
113
+
114
+ self.critic_optimizer = torch.optim.AdamW(
115
+ [param for param in self.model.fake_score.parameters()
116
+ if param.requires_grad],
117
+ lr=config.lr_critic if hasattr(config, "lr_critic") else config.lr,
118
+ betas=(config.beta1_critic, config.beta2_critic),
119
+ weight_decay=config.weight_decay
120
+ )
121
+
122
+ # Step 3: Initialize the dataloader
123
+ if self.config.i2v:
124
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
125
+ else:
126
+ dataset = TextDataset(config.data_path)
127
+ sampler = torch.utils.data.distributed.DistributedSampler(
128
+ dataset, shuffle=True, drop_last=True)
129
+ dataloader = torch.utils.data.DataLoader(
130
+ dataset,
131
+ batch_size=config.batch_size,
132
+ sampler=sampler,
133
+ num_workers=8)
134
+
135
+ if dist.get_rank() == 0:
136
+ print("DATASET SIZE %d" % len(dataset))
137
+ self.dataloader = cycle(dataloader)
138
+
139
+ ##############################################################################################################
140
+ # 6. Set up EMA parameter containers
141
+ rename_param = (
142
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
143
+ .replace("_checkpoint_wrapped_module.", "")
144
+ .replace("_orig_mod.", "")
145
+ )
146
+ self.name_to_trainable_params = {}
147
+ for n, p in self.model.generator.named_parameters():
148
+ if not p.requires_grad:
149
+ continue
150
+
151
+ renamed_n = rename_param(n)
152
+ self.name_to_trainable_params[renamed_n] = p
153
+ ema_weight = config.ema_weight
154
+ self.generator_ema = None
155
+ if (ema_weight is not None) and (ema_weight > 0.0):
156
+ print(f"Setting up EMA with weight {ema_weight}")
157
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
158
+
159
+ ##############################################################################################################
160
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
161
+ if getattr(config, "generator_ckpt", False):
162
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
163
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
164
+ if "generator" in state_dict:
165
+ state_dict = state_dict["generator"]
166
+ elif "model" in state_dict:
167
+ state_dict = state_dict["model"]
168
+ self.model.generator.load_state_dict(
169
+ state_dict, strict=True
170
+ )
171
+
172
+ ##############################################################################################################
173
+
174
+ # Let's delete EMA params for early steps to save some computes at training and inference
175
+ if self.step < config.ema_start_step:
176
+ self.generator_ema = None
177
+
178
+ self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
179
+ self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
180
+ self.previous_time = None
181
+
182
+ def save(self):
183
+ print("Start gathering distributed model states...")
184
+ generator_state_dict = fsdp_state_dict(
185
+ self.model.generator)
186
+ critic_state_dict = fsdp_state_dict(
187
+ self.model.fake_score)
188
+
189
+ if self.config.ema_start_step < self.step:
190
+ state_dict = {
191
+ "generator": generator_state_dict,
192
+ "critic": critic_state_dict,
193
+ "generator_ema": self.generator_ema.state_dict(),
194
+ }
195
+ else:
196
+ state_dict = {
197
+ "generator": generator_state_dict,
198
+ "critic": critic_state_dict,
199
+ }
200
+
201
+ if self.is_main_process:
202
+ os.makedirs(os.path.join(self.output_path,
203
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
204
+ torch.save(state_dict, os.path.join(self.output_path,
205
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
206
+ print("Model saved to", os.path.join(self.output_path,
207
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
208
+
209
+ def fwdbwd_one_step(self, batch, train_generator):
210
+ self.model.eval() # prevent any randomness (e.g. dropout)
211
+
212
+ if self.step % 20 == 0:
213
+ torch.cuda.empty_cache()
214
+
215
+ # Step 1: Get the next batch of text prompts
216
+ text_prompts = batch["prompts"]
217
+ if self.config.i2v:
218
+ clean_latent = None
219
+ image_latent = batch["ode_latent"][:, -1][:, 0:1, ].to(
220
+ device=self.device, dtype=self.dtype)
221
+ else:
222
+ clean_latent = None
223
+ image_latent = None
224
+
225
+ batch_size = len(text_prompts)
226
+ image_or_video_shape = list(self.config.image_or_video_shape)
227
+ image_or_video_shape[0] = batch_size
228
+
229
+ # Step 2: Extract the conditional infos
230
+ with torch.no_grad():
231
+ conditional_dict = self.model.text_encoder(
232
+ text_prompts=text_prompts)
233
+
234
+ if not getattr(self, "unconditional_dict", None):
235
+ unconditional_dict = self.model.text_encoder(
236
+ text_prompts=[self.config.negative_prompt] * batch_size)
237
+ unconditional_dict = {k: v.detach()
238
+ for k, v in unconditional_dict.items()}
239
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
240
+ else:
241
+ unconditional_dict = self.unconditional_dict
242
+
243
+ # Step 3: Store gradients for the generator (if training the generator)
244
+ if train_generator:
245
+ generator_loss, generator_log_dict = self.model.generator_loss(
246
+ image_or_video_shape=image_or_video_shape,
247
+ conditional_dict=conditional_dict,
248
+ unconditional_dict=unconditional_dict,
249
+ clean_latent=clean_latent,
250
+ initial_latent=image_latent if self.config.i2v else None
251
+ )
252
+
253
+ generator_loss.backward()
254
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
255
+ self.max_grad_norm_generator)
256
+
257
+ generator_log_dict.update({"generator_loss": generator_loss,
258
+ "generator_grad_norm": generator_grad_norm})
259
+
260
+ return generator_log_dict
261
+ else:
262
+ generator_log_dict = {}
263
+
264
+ # Step 4: Store gradients for the critic (if training the critic)
265
+ critic_loss, critic_log_dict = self.model.critic_loss(
266
+ image_or_video_shape=image_or_video_shape,
267
+ conditional_dict=conditional_dict,
268
+ unconditional_dict=unconditional_dict,
269
+ clean_latent=clean_latent,
270
+ initial_latent=image_latent if self.config.i2v else None
271
+ )
272
+
273
+ critic_loss.backward()
274
+ critic_grad_norm = self.model.fake_score.clip_grad_norm_(
275
+ self.max_grad_norm_critic)
276
+
277
+ critic_log_dict.update({"critic_loss": critic_loss,
278
+ "critic_grad_norm": critic_grad_norm})
279
+
280
+ return critic_log_dict
281
+
282
+ def generate_video(self, pipeline, prompts, image=None):
283
+ batch_size = len(prompts)
284
+ if image is not None:
285
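+ # Normalize the conditioning image to [1, C, 1, H, W] (adding a singleton temporal dim) before VAE encoding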
+ image = image.squeeze(0).unsqueeze(0).unsqueeze(2).to(device="cuda", dtype=torch.bfloat16)
286
+
287
+ # Encode the input image as the first latent
288
+ initial_latent = pipeline.vae.encode_to_latent(image).to(device="cuda", dtype=torch.bfloat16)
289
+ initial_latent = initial_latent.repeat(batch_size, 1, 1, 1, 1)
290
+ sampled_noise = torch.randn(
291
+ [batch_size, self.model.num_training_frames - 1, 16, 60, 104],
292
+ device="cuda",
293
+ dtype=self.dtype
294
+ )
295
+ else:
296
+ initial_latent = None
297
+ sampled_noise = torch.randn(
298
+ [batch_size, self.model.num_training_frames, 16, 60, 104],
299
+ device="cuda",
300
+ dtype=self.dtype
301
+ )
302
+
303
+ video, _ = pipeline.inference(
304
+ noise=sampled_noise,
305
+ text_prompts=prompts,
306
+ return_latents=True,
307
+ initial_latent=initial_latent
308
+ )
309
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
310
+ return current_video
311
+
312
+ def train(self):
313
+ start_step = self.step
314
+
315
+ while True:
316
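+ # The generator is updated once every dfake_gen_update_ratio steps; the critic is updated every step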
+ TRAIN_GENERATOR = self.step % self.config.dfake_gen_update_ratio == 0
317
+
318
+ # Train the generator
319
+ if TRAIN_GENERATOR:
320
+ self.generator_optimizer.zero_grad(set_to_none=True)
321
+ extras_list = []
322
+ batch = next(self.dataloader)
323
+ extra = self.fwdbwd_one_step(batch, True)
324
+ extras_list.append(extra)
325
+ generator_log_dict = merge_dict_list(extras_list)
326
+ self.generator_optimizer.step()
327
+ if self.generator_ema is not None:
328
+ self.generator_ema.update(self.model.generator)
329
+
330
+ # Train the critic
331
+ self.critic_optimizer.zero_grad(set_to_none=True)
332
+ extras_list = []
333
+ batch = next(self.dataloader)
334
+ extra = self.fwdbwd_one_step(batch, False)
335
+ extras_list.append(extra)
336
+ critic_log_dict = merge_dict_list(extras_list)
337
+ self.critic_optimizer.step()
338
+
339
+ # Increment the step since we finished gradient update
340
+ self.step += 1
341
+
342
+ # Create EMA params (if not already created)
343
+ if (self.step >= self.config.ema_start_step) and \
344
+ (self.generator_ema is None) and (self.config.ema_weight > 0):
345
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
346
+
347
+ # Save the model
348
+ if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
349
+ torch.cuda.empty_cache()
350
+ self.save()
351
+ torch.cuda.empty_cache()
352
+
353
+ # Logging
354
+ if self.is_main_process:
355
+ wandb_loss_dict = {}
356
+ if TRAIN_GENERATOR:
357
+ wandb_loss_dict.update(
358
+ {
359
+ "generator_loss": generator_log_dict["generator_loss"].mean().item(),
360
+ "generator_grad_norm": generator_log_dict["generator_grad_norm"].mean().item(),
361
+ "dmdtrain_gradient_norm": generator_log_dict["dmdtrain_gradient_norm"].mean().item()
362
+ }
363
+ )
364
+
365
+ wandb_loss_dict.update(
366
+ {
367
+ "critic_loss": critic_log_dict["critic_loss"].mean().item(),
368
+ "critic_grad_norm": critic_log_dict["critic_grad_norm"].mean().item()
369
+ }
370
+ )
371
+
372
+ if not self.disable_wandb:
373
+ wandb.log(wandb_loss_dict, step=self.step)
374
+
375
+ if self.step % self.config.gc_interval == 0:
376
+ if dist.get_rank() == 0:
377
+ logging.info("DistGarbageCollector: Running GC.")
378
+ gc.collect()
379
+ torch.cuda.empty_cache()
380
+
381
+ if self.is_main_process:
382
+ current_time = time.time()
383
+ if self.previous_time is None:
384
+ self.previous_time = current_time
385
+ else:
386
+ if not self.disable_wandb:
387
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
388
+ self.previous_time = current_time
trainer/gan.py ADDED
@@ -0,0 +1,464 @@
1
+ import gc
2
+ import logging
3
+
4
+ from utils.dataset import ShardingLMDBDataset, cycle
5
+ from utils.distributed import EMA_FSDP, fsdp_wrap, fsdp_state_dict, launch_distributed_job
6
+ from utils.misc import (
7
+ set_seed,
8
+ merge_dict_list
9
+ )
10
+ import torch.distributed as dist
11
+ from omegaconf import OmegaConf
12
+ from model import GAN
13
+ import torch
14
+ import wandb
15
+ import time
16
+ import os
17
+
18
+
19
+ class Trainer:
20
+ def __init__(self, config):
21
+ self.config = config
22
+ self.step = 0
23
+
24
+ # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
25
+ torch.backends.cuda.matmul.allow_tf32 = True
26
+ torch.backends.cudnn.allow_tf32 = True
27
+
28
+ launch_distributed_job()
29
+ global_rank = dist.get_rank()
30
+ self.world_size = dist.get_world_size()
31
+
32
+ self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
33
+ self.device = torch.cuda.current_device()
34
+ self.is_main_process = global_rank == 0
35
+ self.causal = config.causal
36
+ self.disable_wandb = config.disable_wandb
37
+
38
+ # Configuration for discriminator warmup
39
+ self.discriminator_warmup_steps = getattr(config, "discriminator_warmup_steps", 0)
40
+ self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
41
+ if self.in_discriminator_warmup and self.is_main_process:
42
+ print(f"Starting with discriminator warmup for {self.discriminator_warmup_steps} steps")
43
+ self.loss_scale = getattr(config, "loss_scale", 1.0)
44
+
45
+ # use a random seed for the training
46
+ if config.seed == 0:
47
+ random_seed = torch.randint(0, 10000000, (1,), device=self.device)
48
+ dist.broadcast(random_seed, src=0)
49
+ config.seed = random_seed.item()
50
+
51
+ set_seed(config.seed + global_rank)
52
+
53
+ if self.is_main_process and not self.disable_wandb:
54
+ wandb.login(host=config.wandb_host, key=config.wandb_key)
55
+ wandb.init(
56
+ config=OmegaConf.to_container(config, resolve=True),
57
+ name=config.config_name,
58
+ mode="online",
59
+ entity=config.wandb_entity,
60
+ project=config.wandb_project,
61
+ dir=config.wandb_save_dir
62
+ )
63
+
64
+ self.output_path = config.logdir
65
+
66
+ # Step 2: Initialize the model and optimizer
67
+ self.model = GAN(config, device=self.device)
68
+
69
+ self.model.generator = fsdp_wrap(
70
+ self.model.generator,
71
+ sharding_strategy=config.sharding_strategy,
72
+ mixed_precision=config.mixed_precision,
73
+ wrap_strategy=config.generator_fsdp_wrap_strategy
74
+ )
75
+
76
+ self.model.fake_score = fsdp_wrap(
77
+ self.model.fake_score,
78
+ sharding_strategy=config.sharding_strategy,
79
+ mixed_precision=config.mixed_precision,
80
+ wrap_strategy=config.fake_score_fsdp_wrap_strategy
81
+ )
82
+
83
+ self.model.text_encoder = fsdp_wrap(
84
+ self.model.text_encoder,
85
+ sharding_strategy=config.sharding_strategy,
86
+ mixed_precision=config.mixed_precision,
87
+ wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
88
+ cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
89
+ )
90
+
91
+ if not config.no_visualize or config.load_raw_video:
92
+ self.model.vae = self.model.vae.to(
93
+ device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
94
+
95
+ self.generator_optimizer = torch.optim.AdamW(
96
+ [param for param in self.model.generator.parameters()
97
+ if param.requires_grad],
98
+ lr=config.gen_lr,
99
+ betas=(config.beta1, config.beta2)
100
+ )
101
+
102
+ # Create separate parameter groups for the fake_score network
103
+ # One group for parameters with "_cls_pred_branch" or "_gan_ca_blocks" in the name
104
+ # and another group for all other parameters
105
+ fake_score_params = []
106
+ discriminator_params = []
107
+
108
+ for name, param in self.model.fake_score.named_parameters():
109
+ if param.requires_grad:
110
+ if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
111
+ discriminator_params.append(param)
112
+ else:
113
+ fake_score_params.append(param)
114
+
115
+ # Use the special learning rate for the special parameter group
116
+ # and the default critic learning rate for other parameters
117
+ self.critic_param_groups = [
118
+ {'params': fake_score_params, 'lr': config.critic_lr},
119
+ {'params': discriminator_params, 'lr': config.critic_lr * config.discriminator_lr_multiplier}
120
+ ]
121
+ if self.in_discriminator_warmup:
122
+ self.critic_optimizer = torch.optim.AdamW(
123
+ self.critic_param_groups,
124
+ betas=(0.9, config.beta2_critic)
125
+ )
126
+ else:
127
+ self.critic_optimizer = torch.optim.AdamW(
128
+ self.critic_param_groups,
129
+ betas=(config.beta1_critic, config.beta2_critic)
130
+ )
131
+
132
+ # Step 3: Initialize the dataloader
133
+ self.data_path = config.data_path
134
+ dataset = ShardingLMDBDataset(config.data_path, max_pair=int(1e8))
135
+ sampler = torch.utils.data.distributed.DistributedSampler(
136
+ dataset, shuffle=True, drop_last=True)
137
+ dataloader = torch.utils.data.DataLoader(
138
+ dataset,
139
+ batch_size=config.batch_size,
140
+ sampler=sampler,
141
+ num_workers=8)
142
+
143
+ if dist.get_rank() == 0:
144
+ print("DATASET SIZE %d" % len(dataset))
145
+
146
+ self.dataloader = cycle(dataloader)
147
+
148
+ ##############################################################################################################
149
+ # 6. Set up EMA parameter containers
150
+ rename_param = (
151
+ lambda name: name.replace("_fsdp_wrapped_module.", "")
152
+ .replace("_checkpoint_wrapped_module.", "")
153
+ .replace("_orig_mod.", "")
154
+ )
155
+ self.name_to_trainable_params = {}
156
+ for n, p in self.model.generator.named_parameters():
157
+ if not p.requires_grad:
158
+ continue
159
+
160
+ renamed_n = rename_param(n)
161
+ self.name_to_trainable_params[renamed_n] = p
162
+ ema_weight = config.ema_weight
163
+ self.generator_ema = None
164
+ if (ema_weight is not None) and (ema_weight > 0.0):
165
+ print(f"Setting up EMA with weight {ema_weight}")
166
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=ema_weight)
167
+
168
+ ##############################################################################################################
169
+ # 7. (If resuming) Load the model and optimizer, lr_scheduler, ema's statedicts
170
+ if getattr(config, "generator_ckpt", False):
171
+ print(f"Loading pretrained generator from {config.generator_ckpt}")
172
+ state_dict = torch.load(config.generator_ckpt, map_location="cpu")
173
+ if "generator" in state_dict:
174
+ state_dict = state_dict["generator"]
175
+ elif "model" in state_dict:
176
+ state_dict = state_dict["model"]
177
+ self.model.generator.load_state_dict(
178
+ state_dict, strict=True
179
+ )
180
+ if hasattr(config, "load"):
181
+ resume_ckpt_path_critic = os.path.join(config.load, "critic")
182
+ resume_ckpt_path_generator = os.path.join(config.load, "generator")
183
+ else:
184
+ resume_ckpt_path_critic = "none"
185
+ resume_ckpt_path_generator = "none"
186
+
187
+ _, _ = self.checkpointer_critic.try_best_load(
188
+ resume_ckpt_path=resume_ckpt_path_critic,
189
+ )
190
+ self.step, _ = self.checkpointer_generator.try_best_load(
191
+ resume_ckpt_path=resume_ckpt_path_generator,
192
+ force_start_w_ema=config.force_start_w_ema,
193
+ force_reset_zero_step=config.force_reset_zero_step,
194
+ force_reinit_ema=config.force_reinit_ema,
195
+ skip_optimizer_scheduler=config.skip_optimizer_scheduler,
196
+ )
197
+
198
+ ##############################################################################################################
199
+
200
+ # Let's delete EMA params for early steps to save some compute during training and inference
201
+ if self.step < config.ema_start_step:
202
+ self.generator_ema = None
203
+
204
+ self.max_grad_norm_generator = getattr(config, "max_grad_norm_generator", 10.0)
205
+ self.max_grad_norm_critic = getattr(config, "max_grad_norm_critic", 10.0)
206
+ self.previous_time = None
207
+
208
+ def save(self):
209
+ print("Start gathering distributed model states...")
210
+ generator_state_dict = fsdp_state_dict(
211
+ self.model.generator)
212
+ critic_state_dict = fsdp_state_dict(
213
+ self.model.fake_score)
214
+
215
+ if self.config.ema_start_step < self.step:
216
+ state_dict = {
217
+ "generator": generator_state_dict,
218
+ "critic": critic_state_dict,
219
+ "generator_ema": self.generator_ema.state_dict(),
220
+ }
221
+ else:
222
+ state_dict = {
223
+ "generator": generator_state_dict,
224
+ "critic": critic_state_dict,
225
+ }
226
+
227
+ if self.is_main_process:
228
+ os.makedirs(os.path.join(self.output_path,
229
+ f"checkpoint_model_{self.step:06d}"), exist_ok=True)
230
+ torch.save(state_dict, os.path.join(self.output_path,
231
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
232
+ print("Model saved to", os.path.join(self.output_path,
233
+ f"checkpoint_model_{self.step:06d}", "model.pt"))
234
+
235
+ def fwdbwd_one_step(self, batch, train_generator):
236
+ self.model.eval() # prevent any randomness (e.g. dropout)
237
+
238
+ if self.step % 20 == 0:
239
+ torch.cuda.empty_cache()
240
+
241
+ # Step 1: Get the next batch of text prompts
242
+ text_prompts = batch["prompts"] # next(self.dataloader)
243
+ if "ode_latent" in batch:
244
+ clean_latent = batch["ode_latent"][:, -1].to(device=self.device, dtype=self.dtype)
245
+ else:
246
+ frames = batch["frames"].to(device=self.device, dtype=self.dtype)
247
+ with torch.no_grad():
248
+ clean_latent = self.model.vae.encode_to_latent(
249
+ frames).to(device=self.device, dtype=self.dtype)
250
+
251
+ image_latent = clean_latent[:, 0:1, ]
252
+
253
+ batch_size = len(text_prompts)
254
+ image_or_video_shape = list(self.config.image_or_video_shape)
255
+ image_or_video_shape[0] = batch_size
256
+
257
+ # Step 2: Extract the conditional infos
258
+ with torch.no_grad():
259
+ conditional_dict = self.model.text_encoder(
260
+ text_prompts=text_prompts)
261
+
262
+ if not getattr(self, "unconditional_dict", None):
263
+ unconditional_dict = self.model.text_encoder(
264
+ text_prompts=[self.config.negative_prompt] * batch_size)
265
+ unconditional_dict = {k: v.detach()
266
+ for k, v in unconditional_dict.items()}
267
+ self.unconditional_dict = unconditional_dict # cache the unconditional_dict
268
+ else:
269
+ unconditional_dict = self.unconditional_dict
270
+
271
+ mini_bs, full_bs = (
272
+ batch["mini_bs"],
273
+ batch["full_bs"],
274
+ )
275
+
276
+ # Step 3: Store gradients for the generator (if training the generator)
277
+ if train_generator:
278
+ gan_G_loss = self.model.generator_loss(
279
+ image_or_video_shape=image_or_video_shape,
280
+ conditional_dict=conditional_dict,
281
+ unconditional_dict=unconditional_dict,
282
+ clean_latent=clean_latent,
283
+ initial_latent=image_latent if self.config.i2v else None
284
+ )
285
+
286
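+ # Weight this mini-batch by its share of the global batch so the accumulated gradient matches one full-batch update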
+ loss_ratio = mini_bs * self.world_size / full_bs
287
+ total_loss = gan_G_loss * loss_ratio * self.loss_scale
288
+
289
+ total_loss.backward()
290
+ generator_grad_norm = self.model.generator.clip_grad_norm_(
291
+ self.max_grad_norm_generator)
292
+
293
+ generator_log_dict = {"generator_grad_norm": generator_grad_norm,
294
+ "gan_G_loss": gan_G_loss}
295
+
296
+ return generator_log_dict
297
+ else:
298
+ generator_log_dict = {}
299
+
300
+ # Step 4: Store gradients for the critic (if training the critic)
301
+ (gan_D_loss, r1_loss, r2_loss), critic_log_dict = self.model.critic_loss(
302
+ image_or_video_shape=image_or_video_shape,
303
+ conditional_dict=conditional_dict,
304
+ unconditional_dict=unconditional_dict,
305
+ clean_latent=clean_latent,
306
+ real_image_or_video=clean_latent,
307
+ initial_latent=image_latent if self.config.i2v else None
308
+ )
309
+
310
+ loss_ratio = mini_bs * dist.get_world_size() / full_bs
311
+ total_loss = (gan_D_loss + 0.5 * (r1_loss + r2_loss)) * loss_ratio * self.loss_scale
312
+
313
+ total_loss.backward()
314
+ critic_grad_norm = self.model.fake_score.clip_grad_norm_(
315
+ self.max_grad_norm_critic)
316
+
317
+ critic_log_dict.update({"critic_grad_norm": critic_grad_norm,
318
+ "gan_D_loss": gan_D_loss,
319
+ "r1_loss": r1_loss,
320
+ "r2_loss": r2_loss})
321
+
322
+ return critic_log_dict
323
+
324
+ def generate_video(self, pipeline, prompts, image=None):
325
+ batch_size = len(prompts)
326
+ sampled_noise = torch.randn(
327
+ [batch_size, 21, 16, 60, 104], device="cuda", dtype=self.dtype
328
+ )
329
+ video, _ = pipeline.inference(
330
+ noise=sampled_noise,
331
+ text_prompts=prompts,
332
+ return_latents=True
333
+ )
334
+ current_video = video.permute(0, 1, 3, 4, 2).cpu().numpy() * 255.0
335
+ return current_video
336
+
337
+ def train(self):
338
+ start_step = self.step
339
+
340
+ while True:
341
+ if self.step == self.discriminator_warmup_steps and self.discriminator_warmup_steps != 0:
342
+ print("Resetting critic optimizer")
343
+ del self.critic_optimizer
344
+ torch.cuda.empty_cache()
345
+ # Create new optimizers
346
+ self.critic_optimizer = torch.optim.AdamW(
347
+ self.critic_param_groups,
348
+ betas=(self.config.beta1_critic, self.config.beta2_critic)
349
+ )
350
+ # Update checkpointer references
351
+ self.checkpointer_critic.optimizer = self.critic_optimizer
352
+ # Check if we're in the discriminator warmup phase
353
+ self.in_discriminator_warmup = self.step < self.discriminator_warmup_steps
354
+
355
+ # Only update generator and critic outside the warmup phase
356
+ TRAIN_GENERATOR = not self.in_discriminator_warmup and self.step % self.config.dfake_gen_update_ratio == 0
357
+
358
+ # Train the generator (only outside warmup phase)
359
+ if TRAIN_GENERATOR:
360
+ self.model.fake_score.requires_grad_(False)
361
+ self.model.generator.requires_grad_(True)
362
+ self.generator_optimizer.zero_grad(set_to_none=True)
363
+ extras_list = []
364
+ for ii, mini_batch in enumerate(next(self.dataloader)):
365
+ extra = self.fwdbwd_one_step(mini_batch, True)
366
+ extras_list.append(extra)
367
+ generator_log_dict = merge_dict_list(extras_list)
368
+ self.generator_optimizer.step()
369
+ if self.generator_ema is not None:
370
+ self.generator_ema.update(self.model.generator)
371
+ else:
372
+ generator_log_dict = {}
373
+
374
+ # Train the critic/discriminator
375
+ if self.in_discriminator_warmup:
376
+ # During warmup, only allow gradient for discriminator params
377
+ self.model.generator.requires_grad_(False)
378
+ self.model.fake_score.requires_grad_(False)
379
+
380
+ # Enable gradient only for discriminator params
381
+ for name, param in self.model.fake_score.named_parameters():
382
+ if "_cls_pred_branch" in name or "_gan_ca_blocks" in name:
383
+ param.requires_grad_(True)
384
+ else:
385
+ # Normal training mode
386
+ self.model.generator.requires_grad_(False)
387
+ self.model.fake_score.requires_grad_(True)
388
+
389
+ self.critic_optimizer.zero_grad(set_to_none=True)
390
+ extras_list = []
391
+ batch = next(self.dataloader)
392
+ extra = self.fwdbwd_one_step(batch, False)
393
+ extras_list.append(extra)
394
+ critic_log_dict = merge_dict_list(extras_list)
395
+ self.critic_optimizer.step()
396
+
397
+ # Increment the step since we finished gradient update
398
+ self.step += 1
399
+
400
+ # If we just finished warmup, print a message
401
+ if self.is_main_process and self.step == self.discriminator_warmup_steps:
402
+ print(f"Finished discriminator warmup after {self.discriminator_warmup_steps} steps")
403
+
404
+ # Create EMA params (if not already created)
405
+ if (self.step >= self.config.ema_start_step) and \
406
+ (self.generator_ema is None) and (self.config.ema_weight > 0):
407
+ self.generator_ema = EMA_FSDP(self.model.generator, decay=self.config.ema_weight)
408
+
409
+ # Save the model
410
+ if (not self.config.no_save) and (self.step - start_step) > 0 and self.step % self.config.log_iters == 0:
411
+ torch.cuda.empty_cache()
412
+ self.save()
413
+ torch.cuda.empty_cache()
414
+
415
+ # Logging
416
+ wandb_loss_dict = {
417
+ "generator_grad_norm": generator_log_dict["generator_grad_norm"],
418
+ "critic_grad_norm": critic_log_dict["critic_grad_norm"],
419
+ "real_logit": critic_log_dict["noisy_real_logit"],
420
+ "fake_logit": critic_log_dict["noisy_fake_logit"],
421
+ "r1_loss": critic_log_dict["r1_loss"],
422
+ "r2_loss": critic_log_dict["r2_loss"],
423
+ }
424
+ if TRAIN_GENERATOR:
425
+ wandb_loss_dict.update({
426
+ "generator_grad_norm": generator_log_dict["generator_grad_norm"],
427
+ })
428
+ self.all_gather_dict(wandb_loss_dict)
429
+ wandb_loss_dict["diff_logit"] = wandb_loss_dict["real_logit"] - wandb_loss_dict["fake_logit"]
430
+ wandb_loss_dict["reg_loss"] = 0.5 * (wandb_loss_dict["r1_loss"] + wandb_loss_dict["r2_loss"])
431
+
432
+ if self.is_main_process:
433
+ if self.in_discriminator_warmup:
434
+ warmup_status = f"[WARMUP {self.step}/{self.discriminator_warmup_steps}] Training only discriminator params"
435
+ print(warmup_status)
436
+ if not self.disable_wandb:
437
+ wandb_loss_dict.update({"warmup_status": 1.0})
438
+
439
+ if not self.disable_wandb:
440
+ wandb.log(wandb_loss_dict, step=self.step)
441
+
442
+ if self.step % self.config.gc_interval == 0:
443
+ if dist.get_rank() == 0:
444
+ logging.info("DistGarbageCollector: Running GC.")
445
+ gc.collect()
446
+ torch.cuda.empty_cache()
447
+
448
+ if self.is_main_process:
449
+ current_time = time.time()
450
+ if self.previous_time is None:
451
+ self.previous_time = current_time
452
+ else:
453
+ if not self.disable_wandb:
454
+ wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
455
+ self.previous_time = current_time
456
+
457
+ def all_gather_dict(self, target_dict):
458
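+ # Average each logged scalar across all ranks and replace it with a plain Python float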
+ for key, value in target_dict.items():
459
+ gathered_value = torch.zeros(
460
+ [self.world_size, *value.shape],
461
+ dtype=value.dtype, device=self.device)
462
+ dist.all_gather_into_tensor(gathered_value, value)
463
+ avg_value = gathered_value.mean().item()
464
+ target_dict[key] = avg_value
trainer/ode.py ADDED
@@ -0,0 +1,242 @@
+ import gc
+ import logging
+ from utils.dataset import ODERegressionLMDBDataset, cycle
+ from model import ODERegression
+ from collections import defaultdict
+ from utils.misc import (
+     set_seed
+ )
+ import torch.distributed as dist
+ from omegaconf import OmegaConf
+ import torch
+ import wandb
+ import time
+ import os
+
+ from utils.distributed import barrier, fsdp_wrap, fsdp_state_dict, launch_distributed_job
+
+
+ class Trainer:
+     def __init__(self, config):
+         self.config = config
+         self.step = 0
+
+         # Step 1: Initialize the distributed training environment (rank, seed, dtype, logging etc.)
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.allow_tf32 = True
+
+         launch_distributed_job()
+         global_rank = dist.get_rank()
+         self.world_size = dist.get_world_size()
+
+         self.dtype = torch.bfloat16 if config.mixed_precision else torch.float32
+         self.device = torch.cuda.current_device()
+         self.is_main_process = global_rank == 0
+         self.disable_wandb = config.disable_wandb
+
+         # use a random seed for the training
+         if config.seed == 0:
+             random_seed = torch.randint(0, 10000000, (1,), device=self.device)
+             dist.broadcast(random_seed, src=0)
+             config.seed = random_seed.item()
+
+         set_seed(config.seed + global_rank)
+
+         if self.is_main_process and not self.disable_wandb:
+             wandb.login(host=config.wandb_host, key=config.wandb_key)
+             wandb.init(
+                 config=OmegaConf.to_container(config, resolve=True),
+                 name=config.config_name,
+                 mode="online",
+                 entity=config.wandb_entity,
+                 project=config.wandb_project,
+                 dir=config.wandb_save_dir
+             )
+
+         self.output_path = config.logdir
+
+         # Step 2: Initialize the model and optimizer
+
+         assert config.distribution_loss == "ode", "Only ODE loss is supported for ODE training"
+         self.model = ODERegression(config, device=self.device)
+
+         self.model.generator = fsdp_wrap(
+             self.model.generator,
+             sharding_strategy=config.sharding_strategy,
+             mixed_precision=config.mixed_precision,
+             wrap_strategy=config.generator_fsdp_wrap_strategy
+         )
+         self.model.text_encoder = fsdp_wrap(
+             self.model.text_encoder,
+             sharding_strategy=config.sharding_strategy,
+             mixed_precision=config.mixed_precision,
+             wrap_strategy=config.text_encoder_fsdp_wrap_strategy,
+             cpu_offload=getattr(config, "text_encoder_cpu_offload", False)
+         )
+
+         if not config.no_visualize or config.load_raw_video:
+             self.model.vae = self.model.vae.to(
+                 device=self.device, dtype=torch.bfloat16 if config.mixed_precision else torch.float32)
+
+         self.generator_optimizer = torch.optim.AdamW(
+             [param for param in self.model.generator.parameters()
+              if param.requires_grad],
+             lr=config.lr,
+             betas=(config.beta1, config.beta2),
+             weight_decay=config.weight_decay
+         )
+
+         # Step 3: Initialize the dataloader
+         dataset = ODERegressionLMDBDataset(
+             config.data_path, max_pair=getattr(config, "max_pair", int(1e8)))
+         sampler = torch.utils.data.distributed.DistributedSampler(
+             dataset, shuffle=True, drop_last=True)
+         dataloader = torch.utils.data.DataLoader(
+             dataset, batch_size=config.batch_size, sampler=sampler, num_workers=8)
+         total_batch_size = getattr(config, "total_batch_size", None)
+         if total_batch_size is not None:
+             assert total_batch_size == config.batch_size * self.world_size, "Gradient accumulation is not supported for ODE training"
+         self.dataloader = cycle(dataloader)
+
+         self.step = 0
+
+         ##############################################################################################################
+         # Step 4: (If resuming) load the pretrained generator checkpoint
+         if getattr(config, "generator_ckpt", False):
+             print(f"Loading pretrained generator from {config.generator_ckpt}")
+             state_dict = torch.load(config.generator_ckpt, map_location="cpu")[
+                 'generator']
+             self.model.generator.load_state_dict(
+                 state_dict, strict=True
+             )
+
+         ##############################################################################################################
+
+         self.max_grad_norm = 10.0
+         self.previous_time = None
+
+     def save(self):
+         print("Start gathering distributed model states...")
+         generator_state_dict = fsdp_state_dict(
+             self.model.generator)
+         state_dict = {
+             "generator": generator_state_dict
+         }
+
+         if self.is_main_process:
+             os.makedirs(os.path.join(self.output_path,
+                                      f"checkpoint_model_{self.step:06d}"), exist_ok=True)
+             torch.save(state_dict, os.path.join(self.output_path,
+                                                 f"checkpoint_model_{self.step:06d}", "model.pt"))
+             print("Model saved to", os.path.join(self.output_path,
+                                                  f"checkpoint_model_{self.step:06d}", "model.pt"))
+
+     def train_one_step(self):
+         VISUALIZE = self.step % 100 == 0
+         self.model.eval()  # prevent any randomness (e.g. dropout)
+
+         # Step 1: Get the next batch of text prompts
+         batch = next(self.dataloader)
+         text_prompts = batch["prompts"]
+         ode_latent = batch["ode_latent"].to(
+             device=self.device, dtype=self.dtype)
+
+         # Step 2: Extract the conditional infos
+         with torch.no_grad():
+             conditional_dict = self.model.text_encoder(
+                 text_prompts=text_prompts)
+
+         # Step 3: Train the generator
+         generator_loss, log_dict = self.model.generator_loss(
+             ode_latent=ode_latent,
+             conditional_dict=conditional_dict
+         )
+
+         unnormalized_loss = log_dict["unnormalized_loss"]
+         timestep = log_dict["timestep"]
+
+         if self.world_size > 1:
+             gathered_unnormalized_loss = torch.zeros(
+                 [self.world_size, *unnormalized_loss.shape],
+                 dtype=unnormalized_loss.dtype, device=self.device)
+             gathered_timestep = torch.zeros(
+                 [self.world_size, *timestep.shape],
+                 dtype=timestep.dtype, device=self.device)
+
+             dist.all_gather_into_tensor(
+                 gathered_unnormalized_loss, unnormalized_loss)
+             dist.all_gather_into_tensor(gathered_timestep, timestep)
+         else:
+             gathered_unnormalized_loss = unnormalized_loss
+             gathered_timestep = timestep
+
+         loss_breakdown = defaultdict(list)
+         stats = {}
+
+         for index, t in enumerate(timestep):
+             loss_breakdown[str(int(t.item()) // 250 * 250)].append(
+                 unnormalized_loss[index].item())
+
+         for key_t in loss_breakdown.keys():
+             stats["loss_at_time_" + key_t] = sum(loss_breakdown[key_t]) / \
+                 len(loss_breakdown[key_t])
+
+         self.generator_optimizer.zero_grad()
+         generator_loss.backward()
+         generator_grad_norm = self.model.generator.clip_grad_norm_(
+             self.max_grad_norm)
+         self.generator_optimizer.step()
+
+         # Step 4: Visualization
+         if VISUALIZE and not self.config.no_visualize and not self.config.disable_wandb and self.is_main_process:
+             # Visualize the input, output, and ground truth
+             input = log_dict["input"]
+             output = log_dict["output"]
+             ground_truth = ode_latent[:, -1]
+
+             input_video = self.model.vae.decode_to_pixel(input)
+             output_video = self.model.vae.decode_to_pixel(output)
+             ground_truth_video = self.model.vae.decode_to_pixel(ground_truth)
+             input_video = 255.0 * (input_video.cpu().numpy() * 0.5 + 0.5)
+             output_video = 255.0 * (output_video.cpu().numpy() * 0.5 + 0.5)
+             ground_truth_video = 255.0 * (ground_truth_video.cpu().numpy() * 0.5 + 0.5)
+
+             # Log the videos to wandb
+             wandb.log({
+                 "input": wandb.Video(input_video, caption="Input", fps=16, format="mp4"),
+                 "output": wandb.Video(output_video, caption="Output", fps=16, format="mp4"),
+                 "ground_truth": wandb.Video(ground_truth_video, caption="Ground Truth", fps=16, format="mp4"),
+             }, step=self.step)
+
+         # Step 5: Logging
+         if self.is_main_process and not self.disable_wandb:
+             wandb_loss_dict = {
+                 "generator_loss": generator_loss.item(),
+                 "generator_grad_norm": generator_grad_norm.item(),
+                 **stats
+             }
+             wandb.log(wandb_loss_dict, step=self.step)
+
+         if self.step % self.config.gc_interval == 0:
+             if dist.get_rank() == 0:
+                 logging.info("DistGarbageCollector: Running GC.")
+             gc.collect()
+
+     def train(self):
+         while True:
+             self.train_one_step()
+             if (not self.config.no_save) and self.step % self.config.log_iters == 0:
+                 self.save()
+                 torch.cuda.empty_cache()
+
+             barrier()
+             if self.is_main_process:
+                 current_time = time.time()
+                 if self.previous_time is None:
+                     self.previous_time = current_time
+                 else:
+                     if not self.disable_wandb:
+                         wandb.log({"per iteration time": current_time - self.previous_time}, step=self.step)
+                     self.previous_time = current_time
+
+             self.step += 1
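This Trainer has no `__main__` entry point of its own; it is meant to be constructed from an OmegaConf config and driven by a launcher (the upload also includes a `train.py`). Below is a hedged sketch of such a launcher, assuming a config that defines the keys read in `__init__` above (`data_path`, `batch_size`, `lr`, `distribution_loss: ode`, `gc_interval`, `log_iters`, ...); the exact CLI of `train.py` may differ, and the experiment config name here is hypothetical.

```python
# Sketch only: run under torchrun, e.g.
#   torchrun --nproc_per_node=8 launch_ode.py
# "configs/ode_regression.yaml" is a hypothetical path; configs/default_config.yaml
# does ship with this upload.
from omegaconf import OmegaConf
from trainer.ode import Trainer


def main():
    config = OmegaConf.merge(
        OmegaConf.load("configs/default_config.yaml"),
        OmegaConf.load("configs/ode_regression.yaml"),  # hypothetical experiment config
    )
    trainer = Trainer(config)
    trainer.train()


if __name__ == "__main__":
    main()
```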
utils/dataset.py ADDED
@@ -0,0 +1,220 @@
+ from utils.lmdb import get_array_shape_from_lmdb, retrieve_row_from_lmdb
+ from torch.utils.data import Dataset
+ import numpy as np
+ import torch
+ import lmdb
+ import json
+ from pathlib import Path
+ from PIL import Image
+ import os
+
+
+ class TextDataset(Dataset):
+     def __init__(self, prompt_path, extended_prompt_path=None):
+         with open(prompt_path, encoding="utf-8") as f:
+             self.prompt_list = [line.rstrip() for line in f]
+
+         if extended_prompt_path is not None:
+             with open(extended_prompt_path, encoding="utf-8") as f:
+                 self.extended_prompt_list = [line.rstrip() for line in f]
+             assert len(self.extended_prompt_list) == len(self.prompt_list)
+         else:
+             self.extended_prompt_list = None
+
+     def __len__(self):
+         return len(self.prompt_list)
+
+     def __getitem__(self, idx):
+         batch = {
+             "prompts": self.prompt_list[idx],
+             "idx": idx,
+         }
+         if self.extended_prompt_list is not None:
+             batch["extended_prompts"] = self.extended_prompt_list[idx]
+         return batch
+
+
+ class ODERegressionLMDBDataset(Dataset):
+     def __init__(self, data_path: str, max_pair: int = int(1e8)):
+         self.env = lmdb.open(data_path, readonly=True,
+                              lock=False, readahead=False, meminit=False)
+
+         self.latents_shape = get_array_shape_from_lmdb(self.env, 'latents')
+         self.max_pair = max_pair
+
+     def __len__(self):
+         return min(self.latents_shape[0], self.max_pair)
+
+     def __getitem__(self, idx):
+         """
+         Outputs:
+             - prompts: List of Strings
+             - latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
+         """
+         latents = retrieve_row_from_lmdb(
+             self.env,
+             "latents", np.float16, idx, shape=self.latents_shape[1:]
+         )
+
+         if len(latents.shape) == 4:
+             latents = latents[None, ...]
+
+         prompts = retrieve_row_from_lmdb(
+             self.env,
+             "prompts", str, idx
+         )
+         return {
+             "prompts": prompts,
+             "ode_latent": torch.tensor(latents, dtype=torch.float32)
+         }
+
+
+ class ShardingLMDBDataset(Dataset):
+     def __init__(self, data_path: str, max_pair: int = int(1e8)):
+         self.envs = []
+         self.index = []
+
+         for fname in sorted(os.listdir(data_path)):
+             path = os.path.join(data_path, fname)
+             env = lmdb.open(path,
+                             readonly=True,
+                             lock=False,
+                             readahead=False,
+                             meminit=False)
+             self.envs.append(env)
+
+         self.latents_shape = [None] * len(self.envs)
+         for shard_id, env in enumerate(self.envs):
+             self.latents_shape[shard_id] = get_array_shape_from_lmdb(env, 'latents')
+             for local_i in range(self.latents_shape[shard_id][0]):
+                 self.index.append((shard_id, local_i))
+
+             # print("shard_id ", shard_id, " local_i ", local_i)
+
+         self.max_pair = max_pair
+
+     def __len__(self):
+         return len(self.index)
+
+     def __getitem__(self, idx):
+         """
+         Outputs:
+             - prompts: List of Strings
+             - latents: Tensor of shape (num_denoising_steps, num_frames, num_channels, height, width). It is ordered from pure noise to clean image.
+         """
+         shard_id, local_idx = self.index[idx]
+
+         latents = retrieve_row_from_lmdb(
+             self.envs[shard_id],
+             "latents", np.float16, local_idx,
+             shape=self.latents_shape[shard_id][1:]
+         )
+
+         if len(latents.shape) == 4:
+             latents = latents[None, ...]
+
+         prompts = retrieve_row_from_lmdb(
+             self.envs[shard_id],
+             "prompts", str, local_idx
+         )
+
+         return {
+             "prompts": prompts,
+             "ode_latent": torch.tensor(latents, dtype=torch.float32)
+         }
+
+
+ class TextImagePairDataset(Dataset):
+     def __init__(
+         self,
+         data_dir,
+         transform=None,
+         eval_first_n=-1,
+         pad_to_multiple_of=None
+     ):
+         """
+         Args:
+             data_dir (str): Path to the directory containing:
+                 - target_crop_info_*.json (metadata file)
+                 - */ (subdirectory containing images with matching aspect ratio)
+             transform (callable, optional): Optional transform to be applied on the image
+         """
+         self.transform = transform
+         data_dir = Path(data_dir)
+
+         # Find the metadata JSON file
+         metadata_files = list(data_dir.glob('target_crop_info_*.json'))
+         if not metadata_files:
+             raise FileNotFoundError(f"No metadata file found in {data_dir}")
+         if len(metadata_files) > 1:
+             raise ValueError(f"Multiple metadata files found in {data_dir}")
+
+         metadata_path = metadata_files[0]
+         # Extract aspect ratio from metadata filename (e.g. target_crop_info_26-15.json -> 26-15)
+         aspect_ratio = metadata_path.stem.split('_')[-1]
+
+         # Use aspect ratio subfolder for images
+         self.image_dir = data_dir / aspect_ratio
+         if not self.image_dir.exists():
+             raise FileNotFoundError(f"Image directory not found: {self.image_dir}")
+
+         # Load metadata
+         with open(metadata_path, 'r') as f:
+             self.metadata = json.load(f)
+
+         eval_first_n = eval_first_n if eval_first_n != -1 else len(self.metadata)
+         self.metadata = self.metadata[:eval_first_n]
+
+         # Verify all images exist
+         for item in self.metadata:
+             image_path = self.image_dir / item['file_name']
+             if not image_path.exists():
+                 raise FileNotFoundError(f"Image not found: {image_path}")
+
+         self.dummy_prompt = "DUMMY PROMPT"
+         self.pre_pad_len = len(self.metadata)
+         if pad_to_multiple_of is not None and len(self.metadata) % pad_to_multiple_of != 0:
+             # Duplicate the last entry
+             self.metadata += [self.metadata[-1]] * (
+                 pad_to_multiple_of - len(self.metadata) % pad_to_multiple_of
+             )
+
+     def __len__(self):
+         return len(self.metadata)
+
+     def __getitem__(self, idx):
+         """
+         Returns:
+             dict: A dictionary containing:
+                 - image: PIL Image
+                 - caption: str
+                 - target_bbox: list of int [x1, y1, x2, y2]
+                 - target_ratio: str
+                 - type: str
+                 - origin_size: tuple of int (width, height)
+         """
+         item = self.metadata[idx]
+
+         # Load image
+         image_path = self.image_dir / item['file_name']
+         image = Image.open(image_path).convert('RGB')
+
+         # Apply transform if specified
+         if self.transform:
+             image = self.transform(image)
+
+         return {
+             'image': image,
+             'prompts': item['caption'],
+             'target_bbox': item['target_crop']['target_bbox'],
+             'target_ratio': item['target_crop']['target_ratio'],
+             'type': item['type'],
+             'origin_size': (item['origin_width'], item['origin_height']),
+             'idx': idx
+         }
+
+
+ def cycle(dl):
+     while True:
+         for data in dl:
+             yield data
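The `cycle` helper turns a finite DataLoader into an endless iterator, which is how the trainers above can call `next(self.dataloader)` without tracking epoch boundaries. A small self-contained example using `TextDataset`; the prompt file here is created on the fly purely for illustration:

```python
import os
import tempfile
from torch.utils.data import DataLoader
from utils.dataset import TextDataset, cycle

# Write a throwaway prompt file (one prompt per line), as TextDataset expects.
tmp = tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False)
tmp.write("a cat playing piano\na drone shot of a coastline\n")
tmp.close()

dataset = TextDataset(tmp.name)
loader = cycle(DataLoader(dataset, batch_size=1, shuffle=True))

for _ in range(4):                      # more steps than prompts: cycle restarts the loader
    batch = next(loader)
    print(batch["idx"].item(), batch["prompts"][0])

os.unlink(tmp.name)
```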
utils/distributed.py ADDED
@@ -0,0 +1,125 @@
+ from datetime import timedelta
+ from functools import partial
+ import os
+ import torch
+ import torch.distributed as dist
+ from torch.distributed.fsdp import FullStateDictConfig, FullyShardedDataParallel as FSDP, MixedPrecision, ShardingStrategy, StateDictType
+ from torch.distributed.fsdp.api import CPUOffload
+ from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy
+
+
+ def fsdp_state_dict(model):
+     fsdp_fullstate_save_policy = FullStateDictConfig(
+         offload_to_cpu=True, rank0_only=True
+     )
+     with FSDP.state_dict_type(
+         model, StateDictType.FULL_STATE_DICT, fsdp_fullstate_save_policy
+     ):
+         checkpoint = model.state_dict()
+
+     return checkpoint
+
+
+ def fsdp_wrap(module, sharding_strategy="full", mixed_precision=False, wrap_strategy="size", min_num_params=int(5e7), transformer_module=None, cpu_offload=False):
+     if mixed_precision:
+         mixed_precision_policy = MixedPrecision(
+             param_dtype=torch.bfloat16,
+             reduce_dtype=torch.float32,
+             buffer_dtype=torch.float32,
+             cast_forward_inputs=False
+         )
+     else:
+         mixed_precision_policy = None
+
+     if wrap_strategy == "transformer":
+         auto_wrap_policy = partial(
+             transformer_auto_wrap_policy,
+             transformer_layer_cls=transformer_module
+         )
+     elif wrap_strategy == "size":
+         auto_wrap_policy = partial(
+             size_based_auto_wrap_policy,
+             min_num_params=min_num_params
+         )
+     else:
+         raise ValueError(f"Invalid wrap strategy: {wrap_strategy}")
+
+     os.environ["NCCL_CROSS_NIC"] = "1"
+
+     sharding_strategy = {
+         "full": ShardingStrategy.FULL_SHARD,
+         "hybrid_full": ShardingStrategy.HYBRID_SHARD,
+         "hybrid_zero2": ShardingStrategy._HYBRID_SHARD_ZERO2,
+         "no_shard": ShardingStrategy.NO_SHARD,
+     }[sharding_strategy]
+
+     module = FSDP(
+         module,
+         auto_wrap_policy=auto_wrap_policy,
+         sharding_strategy=sharding_strategy,
+         mixed_precision=mixed_precision_policy,
+         device_id=torch.cuda.current_device(),
+         limit_all_gathers=True,
+         use_orig_params=True,
+         cpu_offload=CPUOffload(offload_params=cpu_offload),
+         sync_module_states=False  # Load ckpt on rank 0 and sync to other ranks
+     )
+     return module
+
+
+ def barrier():
+     if dist.is_initialized():
+         dist.barrier()
+
+
+ def launch_distributed_job(backend: str = "nccl"):
+     rank = int(os.environ["RANK"])
+     local_rank = int(os.environ["LOCAL_RANK"])
+     world_size = int(os.environ["WORLD_SIZE"])
+     host = os.environ["MASTER_ADDR"]
+     port = int(os.environ["MASTER_PORT"])
+
+     if ":" in host:  # IPv6
+         init_method = f"tcp://[{host}]:{port}"
+     else:  # IPv4
+         init_method = f"tcp://{host}:{port}"
+     dist.init_process_group(rank=rank, world_size=world_size, backend=backend,
+                             init_method=init_method, timeout=timedelta(minutes=30))
+     torch.cuda.set_device(local_rank)
+
+
+ class EMA_FSDP:
+     def __init__(self, fsdp_module: torch.nn.Module, decay: float = 0.999):
+         self.decay = decay
+         self.shadow = {}
+         self._init_shadow(fsdp_module)
+
+     @torch.no_grad()
+     def _init_shadow(self, fsdp_module):
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         with FSDP.summon_full_params(fsdp_module, writeback=False):
+             for n, p in fsdp_module.module.named_parameters():
+                 self.shadow[n] = p.detach().clone().float().cpu()
+
+     @torch.no_grad()
+     def update(self, fsdp_module):
+         d = self.decay
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         with FSDP.summon_full_params(fsdp_module, writeback=False):
+             for n, p in fsdp_module.module.named_parameters():
+                 self.shadow[n].mul_(d).add_(p.detach().float().cpu(), alpha=1. - d)
+
+     # Optional helpers ---------------------------------------------------
+     def state_dict(self):
+         return self.shadow  # picklable
+
+     def load_state_dict(self, sd):
+         self.shadow = {k: v.clone() for k, v in sd.items()}
+
+     def copy_to(self, fsdp_module):
+         # load EMA weights into an (unwrapped) copy of the generator
+         from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+         with FSDP.summon_full_params(fsdp_module, writeback=True):
+             for n, p in fsdp_module.module.named_parameters():
+                 if n in self.shadow:
+                     p.data.copy_(self.shadow[n].to(device=p.device, dtype=p.dtype))
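`EMA_FSDP` keeps a full-precision CPU copy of the unsharded parameters and folds the live weights into it with an exponential moving average after each update. A minimal sketch of the intended call pattern; the tiny `nn.Linear` "generator", the SGD optimizer, and the output path are placeholders, not the models used elsewhere in this repo, and the script must run under torchrun with one CUDA device per rank:

```python
# Sketch only: torchrun --nproc_per_node=N ema_example.py
import torch
import torch.distributed as dist
import torch.nn as nn
from utils.distributed import launch_distributed_job, fsdp_wrap, EMA_FSDP

launch_distributed_job()
generator = fsdp_wrap(nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64)).cuda())
ema = EMA_FSDP(generator, decay=0.999)
optimizer = torch.optim.SGD(generator.parameters(), lr=1e-2)

for step in range(10):
    x = torch.randn(8, 64, device="cuda")
    loss = (generator(x) ** 2).mean()      # dummy objective for illustration
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    ema.update(generator)                  # fold current weights into the CPU shadow copy

# The shadow dict already lives on CPU, so rank 0 can save it directly.
if dist.get_rank() == 0:
    torch.save(ema.state_dict(), "ema_example.pt")   # hypothetical output path
```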
utils/lmdb.py ADDED
@@ -0,0 +1,72 @@
+ import numpy as np
+
+
+ def get_array_shape_from_lmdb(env, array_name):
+     with env.begin() as txn:
+         image_shape = txn.get(f"{array_name}_shape".encode()).decode()
+         image_shape = tuple(map(int, image_shape.split()))
+     return image_shape
+
+
+ def store_arrays_to_lmdb(env, arrays_dict, start_index=0):
+     """
+     Store rows of multiple numpy arrays in a single LMDB.
+     Each row is stored separately with a naming convention.
+     """
+     with env.begin(write=True) as txn:
+         for array_name, array in arrays_dict.items():
+             for i, row in enumerate(array):
+                 # Convert row to bytes
+                 if isinstance(row, str):
+                     row_bytes = row.encode()
+                 else:
+                     row_bytes = row.tobytes()
+
+                 data_key = f'{array_name}_{start_index + i}_data'.encode()
+
+                 txn.put(data_key, row_bytes)
+
+
+ def process_data_dict(data_dict, seen_prompts):
+     output_dict = {}
+
+     all_videos = []
+     all_prompts = []
+     for prompt, video in data_dict.items():
+         if prompt in seen_prompts:
+             continue
+         else:
+             seen_prompts.add(prompt)
+
+         video = video.half().numpy()
+         all_videos.append(video)
+         all_prompts.append(prompt)
+
+     if len(all_videos) == 0:
+         return {"latents": np.array([]), "prompts": np.array([])}
+
+     all_videos = np.concatenate(all_videos, axis=0)
+
+     output_dict['latents'] = all_videos
+     output_dict['prompts'] = np.array(all_prompts)
+
+     return output_dict
+
+
+ def retrieve_row_from_lmdb(lmdb_env, array_name, dtype, row_index, shape=None):
+     """
+     Retrieve a specific row from a specific array in the LMDB.
+     """
+     data_key = f'{array_name}_{row_index}_data'.encode()
+
+     with lmdb_env.begin() as txn:
+         row_bytes = txn.get(data_key)
+
+     if dtype == str:
+         array = row_bytes.decode()
+     else:
+         array = np.frombuffer(row_bytes, dtype=dtype)
+
+     if shape is not None and len(shape) > 0:
+         array = array.reshape(shape)
+     return array
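Note that `store_arrays_to_lmdb` only writes the per-row payloads; the `{array_name}_shape` key that `get_array_shape_from_lmdb` and the LMDB datasets rely on has to be written separately (presumably by the `scripts/create_lmdb_*` helpers in this upload). A small round-trip example with fake latents; the LMDB path and array sizes are arbitrary:

```python
import lmdb
import numpy as np
from utils.lmdb import store_arrays_to_lmdb, get_array_shape_from_lmdb, retrieve_row_from_lmdb

latents = np.random.randn(4, 3, 16, 8, 8).astype(np.float16)   # 4 rows of fake latents
prompts = np.array(["a red car", "a blue bird", "a green hill", "a calm lake"])

env = lmdb.open("example_lmdb", map_size=int(1e9))              # hypothetical path
store_arrays_to_lmdb(env, {"latents": latents, "prompts": prompts})

# Write the shape keys by hand, space-separated, as get_array_shape_from_lmdb expects.
with env.begin(write=True) as txn:
    txn.put(b"latents_shape", " ".join(map(str, latents.shape)).encode())
    txn.put(b"prompts_shape", str(len(prompts)).encode())

print(get_array_shape_from_lmdb(env, "latents"))                 # (4, 3, 16, 8, 8)
row = retrieve_row_from_lmdb(env, "latents", np.float16, 2, shape=latents.shape[1:])
print(row.shape, retrieve_row_from_lmdb(env, "prompts", str, 2))  # (3, 16, 8, 8) a green hill
```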
utils/loss.py ADDED
@@ -0,0 +1,81 @@
+ from abc import ABC, abstractmethod
+ import torch
+
+
+ class DenoisingLoss(ABC):
+     @abstractmethod
+     def __call__(
+         self, x: torch.Tensor, x_pred: torch.Tensor,
+         noise: torch.Tensor, noise_pred: torch.Tensor,
+         alphas_cumprod: torch.Tensor,
+         timestep: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         """
+         Base class for denoising loss.
+         Input:
+             - x: the clean data with shape [B, F, C, H, W]
+             - x_pred: the predicted clean data with shape [B, F, C, H, W]
+             - noise: the noise with shape [B, F, C, H, W]
+             - noise_pred: the predicted noise with shape [B, F, C, H, W]
+             - alphas_cumprod: the cumulative product of alphas (defining the noise schedule) with shape [T]
+             - timestep: the current timestep with shape [B, F]
+         """
+         pass
+
+
+ class X0PredLoss(DenoisingLoss):
+     def __call__(
+         self, x: torch.Tensor, x_pred: torch.Tensor,
+         noise: torch.Tensor, noise_pred: torch.Tensor,
+         alphas_cumprod: torch.Tensor,
+         timestep: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         return torch.mean((x - x_pred) ** 2)
+
+
+ class VPredLoss(DenoisingLoss):
+     def __call__(
+         self, x: torch.Tensor, x_pred: torch.Tensor,
+         noise: torch.Tensor, noise_pred: torch.Tensor,
+         alphas_cumprod: torch.Tensor,
+         timestep: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         weights = 1 / (1 - alphas_cumprod[timestep].reshape(*timestep.shape, 1, 1, 1))
+         return torch.mean(weights * (x - x_pred) ** 2)
+
+
+ class NoisePredLoss(DenoisingLoss):
+     def __call__(
+         self, x: torch.Tensor, x_pred: torch.Tensor,
+         noise: torch.Tensor, noise_pred: torch.Tensor,
+         alphas_cumprod: torch.Tensor,
+         timestep: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         return torch.mean((noise - noise_pred) ** 2)
+
+
+ class FlowPredLoss(DenoisingLoss):
+     def __call__(
+         self, x: torch.Tensor, x_pred: torch.Tensor,
+         noise: torch.Tensor, noise_pred: torch.Tensor,
+         alphas_cumprod: torch.Tensor,
+         timestep: torch.Tensor,
+         **kwargs
+     ) -> torch.Tensor:
+         return torch.mean((kwargs["flow_pred"] - (noise - x)) ** 2)
+
+
+ NAME_TO_CLASS = {
+     "x0": X0PredLoss,
+     "v": VPredLoss,
+     "noise": NoisePredLoss,
+     "flow": FlowPredLoss
+ }
+
+
+ def get_denoising_loss(loss_type: str) -> DenoisingLoss:
+     return NAME_TO_CLASS[loss_type]
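Note that `get_denoising_loss` returns the loss *class*, not an instance, so callers instantiate it themselves. A quick self-contained check with random tensors and an arbitrary linear noise schedule (the tensor shapes below just follow the [B, F, C, H, W] convention documented in the base class):

```python
import torch
from utils.loss import get_denoising_loss

B, F, C, H, W, T = 2, 3, 4, 8, 8, 1000
x = torch.randn(B, F, C, H, W)
x_pred = torch.randn(B, F, C, H, W)
noise = torch.randn(B, F, C, H, W)
noise_pred = torch.randn(B, F, C, H, W)
alphas_cumprod = torch.linspace(0.9999, 0.0001, T)   # arbitrary schedule for illustration
timestep = torch.randint(0, T, (B, F))

loss_fn = get_denoising_loss("v")()   # returns the class; instantiate it
loss = loss_fn(
    x=x, x_pred=x_pred,
    noise=noise, noise_pred=noise_pred,
    alphas_cumprod=alphas_cumprod,
    timestep=timestep,
)
print(loss.item())
```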
utils/misc.py ADDED
@@ -0,0 +1,39 @@
+ import numpy as np
+ import random
+ import torch
+
+
+ def set_seed(seed: int, deterministic: bool = False):
+     """
+     Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
+
+     Args:
+         seed (`int`):
+             The seed to set.
+         deterministic (`bool`, *optional*, defaults to `False`):
+             Whether to use deterministic algorithms where available. Can slow down training.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+     if deterministic:
+         torch.use_deterministic_algorithms(True)
+
+
+ def merge_dict_list(dict_list):
+     if len(dict_list) == 1:
+         return dict_list[0]
+
+     merged_dict = {}
+     for k, v in dict_list[0].items():
+         if isinstance(v, torch.Tensor):
+             if v.ndim == 0:
+                 merged_dict[k] = torch.stack([d[k] for d in dict_list], dim=0)
+             else:
+                 merged_dict[k] = torch.cat([d[k] for d in dict_list], dim=0)
+         else:
+             # for non-tensor values, we just copy the value from the first item
+             merged_dict[k] = v
+     return merged_dict
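`merge_dict_list` stacks 0-d tensors along a new leading dimension and concatenates higher-dimensional tensors along dim 0, which is presumably how the trainers merge per-microbatch log dicts before logging. For example (the keys below are made up for illustration):

```python
import torch
from utils.misc import merge_dict_list

logs = [
    {"loss": torch.tensor(0.5), "timestep": torch.tensor([100, 200]), "name": "example"},
    {"loss": torch.tensor(0.7), "timestep": torch.tensor([300, 400]), "name": "example"},
]
merged = merge_dict_list(logs)
print(merged["loss"])       # tensor([0.5000, 0.7000])        (0-d tensors are stacked)
print(merged["timestep"])   # tensor([100, 200, 300, 400])    (1-d tensors are concatenated)
print(merged["name"])       # "example"                       (non-tensors: first item wins)
```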