YuxueYang committed
Commit 2a59fa8 · 1 Parent(s): f68f71d

Upload demo

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. .gitattributes +1 -0
  2. README.md +1 -1
  3. __assets__/demos/demo_1/first_frame.jpg +0 -0
  4. __assets__/demos/demo_1/layer_0.jpg +0 -0
  5. __assets__/demos/demo_1/layer_1.jpg +0 -0
  6. __assets__/demos/demo_1/layer_2.jpg +0 -0
  7. __assets__/demos/demo_1/sketch.mp4 +0 -0
  8. __assets__/demos/demo_1/trajectory.json +200 -0
  9. __assets__/demos/demo_1/trajectory.npz +3 -0
  10. __assets__/demos/demo_2/first_frame.jpg +0 -0
  11. __assets__/demos/demo_2/layer_0.jpg +0 -0
  12. __assets__/demos/demo_2/layer_1.jpg +0 -0
  13. __assets__/demos/demo_2/layer_2.jpg +0 -0
  14. __assets__/demos/demo_2/sketch.mp4 +0 -0
  15. __assets__/demos/demo_2/trajectory.json +200 -0
  16. __assets__/demos/demo_2/trajectory.npz +3 -0
  17. __assets__/demos/demo_3/first_frame.jpg +0 -0
  18. __assets__/demos/demo_3/last_frame.jpg +0 -0
  19. __assets__/demos/demo_3/layer_0.jpg +0 -0
  20. __assets__/demos/demo_3/layer_0_last.jpg +0 -0
  21. __assets__/demos/demo_3/layer_1.jpg +0 -0
  22. __assets__/demos/demo_3/layer_1_last.jpg +0 -0
  23. __assets__/demos/demo_3/layer_2.jpg +0 -0
  24. __assets__/demos/demo_3/layer_2_last.jpg +0 -0
  25. __assets__/demos/demo_3/layer_3.jpg +0 -0
  26. __assets__/demos/demo_3/layer_3_last.jpg +0 -0
  27. __assets__/demos/demo_3/sketch.mp4 +0 -0
  28. __assets__/demos/demo_3/trajectory.json +134 -0
  29. __assets__/demos/demo_3/trajectory.npz +3 -0
  30. __assets__/demos/demo_4/first_frame.jpg +0 -0
  31. __assets__/demos/demo_4/layer_0.jpg +0 -0
  32. __assets__/demos/demo_4/layer_1.jpg +0 -0
  33. __assets__/demos/demo_4/layer_2.jpg +0 -0
  34. __assets__/demos/demo_4/sketch.mp4 +0 -0
  35. __assets__/demos/demo_4/trajectory.json +200 -0
  36. __assets__/demos/demo_4/trajectory.npz +3 -0
  37. __assets__/demos/demo_5/first_frame.jpg +0 -0
  38. __assets__/demos/demo_5/layer_0.jpg +0 -0
  39. __assets__/demos/demo_5/layer_1.jpg +0 -0
  40. __assets__/demos/demo_5/sketch.mp4 +0 -0
  41. __assets__/demos/demo_5/trajectory.json +332 -0
  42. __assets__/demos/demo_5/trajectory.npz +3 -0
  43. __assets__/figs/demos.gif +3 -0
  44. app.py +651 -0
  45. lvdm/basics.py +100 -0
  46. lvdm/common.py +94 -0
  47. lvdm/models/autoencoder.py +143 -0
  48. lvdm/models/condition.py +477 -0
  49. lvdm/models/controlnet.py +500 -0
  50. lvdm/models/layer_controlnet.py +444 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.gif filter=lfs diff=lfs merge=lfs -text
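With this attribute in place, Git stores matching .gif files as small LFS pointer files instead of raw binary blobs. The pointer layout is the same three-line format that appears in the trajectory.npz diffs later in this commit, roughly:

```
version https://git-lfs.github.com/spec/v1
oid sha256:<sha256 of the file contents>
size <size in bytes>
```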
README.md CHANGED
@@ -8,7 +8,7 @@ sdk_version: 5.23.1
app_file: app.py
pinned: false
license: mit
- short_description: https://arxiv.org/abs/2501.08295
+ short_description: "LayerAnimate: Layer-level Control for Animation"
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
__assets__/demos/demo_1/first_frame.jpg ADDED
__assets__/demos/demo_1/layer_0.jpg ADDED
__assets__/demos/demo_1/layer_1.jpg ADDED
__assets__/demos/demo_1/layer_2.jpg ADDED
__assets__/demos/demo_1/sketch.mp4 ADDED
Binary file (65.5 kB).
 
__assets__/demos/demo_1/trajectory.json ADDED
@@ -0,0 +1,200 @@
+ A 200-line JSON array of 3 trajectories, each a list of 16 [x, y] control points (matching the 16-frame clip length used by app.py). See the raw diff for the full coordinate values.
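For reference, a minimal sketch of how a file like this is consumed: app.py (added later in this commit) loads the JSON and resamples each track to the 16-frame clip length with a PCHIP interpolator before building the trajectory heatmaps. The helper name `resample_tracks` and the `CLIP_LENGTH` constant below are illustrative, not part of the repository.

```python
import json
import numpy as np
from scipy.interpolate import PchipInterpolator

CLIP_LENGTH = 16  # matches LENGTH in app.py

def resample_tracks(path, clip_length=CLIP_LENGTH):
    """Load a demo trajectory.json and resample every track to clip_length points."""
    with open(path, "r") as f:
        tracks = json.load(f)  # a list of tracks, each a list of [x, y] points
    resampled = []
    for track in tracks:
        x = [p[0] for p in track]
        y = [p[1] for p in track]
        t = np.linspace(0, 1, len(track))
        t_new = np.linspace(0, 1, clip_length)
        # PCHIP (monotone cubic) interpolation, mirroring app.py's run()
        resampled.append(np.stack([PchipInterpolator(t, x)(t_new),
                                   PchipInterpolator(t, y)(t_new)], axis=-1))
    return np.stack(resampled)  # [num_tracks, clip_length, 2]

# e.g. resample_tracks("__assets__/demos/demo_1/trajectory.json").shape -> (3, 16, 2)
```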
__assets__/demos/demo_1/trajectory.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:232a68740a9d2828277e786d760cb2d7436f4617ae1d64d31a61888be0c65ea1
+ size 994
__assets__/demos/demo_2/first_frame.jpg ADDED
__assets__/demos/demo_2/layer_0.jpg ADDED
__assets__/demos/demo_2/layer_1.jpg ADDED
__assets__/demos/demo_2/layer_2.jpg ADDED
__assets__/demos/demo_2/sketch.mp4 ADDED
Binary file (13 kB).
 
__assets__/demos/demo_2/trajectory.json ADDED
@@ -0,0 +1,200 @@
+ A 200-line JSON array of 3 trajectories, each a list of 16 [x, y] control points. See the raw diff for the full coordinate values.
__assets__/demos/demo_2/trajectory.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ba8194e3bd1376e10cb6c708d59603c406269b95bb1e266b20c7cfa66e248875
+ size 972
__assets__/demos/demo_3/first_frame.jpg ADDED
__assets__/demos/demo_3/last_frame.jpg ADDED
__assets__/demos/demo_3/layer_0.jpg ADDED
__assets__/demos/demo_3/layer_0_last.jpg ADDED
__assets__/demos/demo_3/layer_1.jpg ADDED
__assets__/demos/demo_3/layer_1_last.jpg ADDED
__assets__/demos/demo_3/layer_2.jpg ADDED
__assets__/demos/demo_3/layer_2_last.jpg ADDED
__assets__/demos/demo_3/layer_3.jpg ADDED
__assets__/demos/demo_3/layer_3_last.jpg ADDED
__assets__/demos/demo_3/sketch.mp4 ADDED
Binary file (54.1 kB).
 
__assets__/demos/demo_3/trajectory.json ADDED
@@ -0,0 +1,134 @@
+ A 134-line JSON array of 2 trajectories, each a list of 16 [x, y] control points. See the raw diff for the full coordinate values.
__assets__/demos/demo_3/trajectory.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a1080b8523b361f2e4fb3f5591c88f50e44d176a404e5f62b04cfc2bfe8c2f5d
+ size 857
__assets__/demos/demo_4/first_frame.jpg ADDED
__assets__/demos/demo_4/layer_0.jpg ADDED
__assets__/demos/demo_4/layer_1.jpg ADDED
__assets__/demos/demo_4/layer_2.jpg ADDED
__assets__/demos/demo_4/sketch.mp4 ADDED
Binary file (65.7 kB).
 
__assets__/demos/demo_4/trajectory.json ADDED
@@ -0,0 +1,200 @@
+ A 200-line JSON array of 3 trajectories, each a list of 16 [x, y] control points. See the raw diff for the full coordinate values.
__assets__/demos/demo_4/trajectory.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c2904e38cbc8820daaa5f88085bbfc33aa3cd8b9be7d9588e02d6cadcccf2fa
+ size 973
__assets__/demos/demo_5/first_frame.jpg ADDED
__assets__/demos/demo_5/layer_0.jpg ADDED
__assets__/demos/demo_5/layer_1.jpg ADDED
__assets__/demos/demo_5/sketch.mp4 ADDED
Binary file (93.4 kB).
 
__assets__/demos/demo_5/trajectory.json ADDED
@@ -0,0 +1,332 @@
+ A 332-line JSON array of 5 trajectories, each a list of 16 [x, y] control points. See the raw diff for the full coordinate values.
__assets__/demos/demo_5/trajectory.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e9da4a1142e8210f0486ff1682fe7853e8714ecf813bfef4b9019efbc102f61
+ size 1222
__assets__/figs/demos.gif ADDED

Git LFS Details

  • SHA256: 1fec782faeaf8433550a05a782216b449e84bb3e1c1db03cbcd2fbb25f5a0bc1
  • Pointer size: 133 Bytes
  • Size of remote file: 10.4 MB
app.py ADDED
@@ -0,0 +1,651 @@
1
+ import argparse
2
+ import datetime
3
+ import os
4
+ import json
5
+
6
+ import torch
7
+ import torchvision.transforms as transforms
8
+ from torchvision.transforms import functional as F
9
+
10
+ import spaces
11
+ from huggingface_hub import snapshot_download
12
+ import gradio as gr
13
+
14
+ from diffusers import DDIMScheduler
15
+
16
+ from lvdm.models.unet import UNetModel
17
+ from lvdm.models.autoencoder import AutoencoderKL, AutoencoderKL_Dualref
18
+ from lvdm.models.condition import FrozenOpenCLIPEmbedder, FrozenOpenCLIPImageEmbedderV2, Resampler
19
+ from lvdm.models.layer_controlnet import LayerControlNet
20
+ from lvdm.pipelines.pipeline_animation import AnimationPipeline
21
+ from lvdm.utils import generate_gaussian_heatmap, save_videos_grid, save_videos_with_traj
22
+
23
+ from einops import rearrange
24
+ import cv2
25
+ import decord
26
+ from PIL import Image
27
+ import numpy as np
28
+ from scipy.interpolate import PchipInterpolator
29
+
30
+ SAVE_DIR = "outputs"
31
+ LENGTH = 16
32
+ WIDTH = 512
33
+ HEIGHT = 320
34
+ LAYER_CAPACITY = 4
35
+ DEVICE = "cuda"
36
+
37
+ os.makedirs("checkpoints", exist_ok=True)
38
+
39
+ snapshot_download(
40
+ "Yuppie1204/LayerAnimate-Mix",
41
+ local_dir="checkpoints/LayerAnimate-Mix",
42
+ )
43
+
44
+ class LayerAnimate:
45
+
46
+ @spaces.GPU
47
+ def __init__(self):
48
+ self.savedir = SAVE_DIR
49
+ os.makedirs(self.savedir, exist_ok=True)
50
+
51
+ self.weight_dtype = torch.bfloat16
52
+ self.device = DEVICE
53
+ self.text_encoder = FrozenOpenCLIPEmbedder().eval()
54
+ self.image_encoder = FrozenOpenCLIPImageEmbedderV2().eval()
55
+
56
+ self.W = WIDTH
57
+ self.H = HEIGHT
58
+ self.L = LENGTH
59
+ self.layer_capacity = LAYER_CAPACITY
60
+
61
+ self.transforms = transforms.Compose([
62
+ transforms.Resize(min(self.H, self.W)),
63
+ transforms.CenterCrop((self.H, self.W)),
64
+ ])
65
+ self.pipeline = None
66
+ self.generator = None
67
+ # sample_grid is used to generate fixed trajectories to freeze static layers
68
+ self.sample_grid = np.meshgrid(np.linspace(0, self.W - 1, 10, dtype=int), np.linspace(0, self.H - 1, 10, dtype=int))
69
+ self.sample_grid = np.stack(self.sample_grid, axis=-1).reshape(-1, 1, 2)
70
+ self.sample_grid = np.repeat(self.sample_grid, self.L, axis=1) # [N, F, 2]
71
+
72
+ @spaces.GPU
73
+ def set_seed(self, seed):
74
+ np.random.seed(seed)
75
+ torch.manual_seed(seed)
76
+ self.generator = torch.Generator(self.device).manual_seed(seed)
77
+
78
+ @spaces.GPU
79
+ def set_model(self, pretrained_model_path):
80
+ scheduler = DDIMScheduler.from_pretrained(pretrained_model_path, subfolder="scheduler")
81
+ image_projector = Resampler.from_pretrained(pretrained_model_path, subfolder="image_projector").eval()
82
+ vae, vae_dualref = None, None
83
+ if "I2V" or "Mix" in pretrained_model_path:
84
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").eval()
85
+ if "Interp" or "Mix" in pretrained_model_path:
86
+ vae_dualref = AutoencoderKL_Dualref.from_pretrained(pretrained_model_path, subfolder="vae_dualref").eval()
87
+ unet = UNetModel.from_pretrained(pretrained_model_path, subfolder="unet").eval()
88
+ layer_controlnet = LayerControlNet.from_pretrained(pretrained_model_path, subfolder="layer_controlnet").eval()
89
+
90
+ self.pipeline = AnimationPipeline(
91
+ vae=vae, vae_dualref=vae_dualref, text_encoder=self.text_encoder, image_encoder=self.image_encoder, image_projector=image_projector,
92
+ unet=unet, layer_controlnet=layer_controlnet, scheduler=scheduler
93
+ ).to(device=self.device, dtype=self.weight_dtype)
94
+ if "Interp" or "Mix" in pretrained_model_path:
95
+ self.pipeline.vae_dualref.decoder.to(dtype=torch.float32)
96
+ return pretrained_model_path
97
+
98
+ def upload_image(self, image):
99
+ image = self.transforms(image)
100
+ return image
101
+
102
+ def run(self, input_image, input_image_end, pretrained_model_path, seed,
103
+ prompt, n_prompt, num_inference_steps, guidance_scale,
104
+ *layer_args):
105
+ self.set_seed(seed)
106
+ global layer_tracking_points
107
+ args_layer_tracking_points = [layer_tracking_points[i].value for i in range(self.layer_capacity)]
108
+
109
+ args_layer_masks = layer_args[:self.layer_capacity]
110
+ args_layer_masks_end = layer_args[self.layer_capacity : 2 * self.layer_capacity]
111
+ args_layer_controls = layer_args[2 * self.layer_capacity : 3 * self.layer_capacity]
112
+ args_layer_scores = list(layer_args[3 * self.layer_capacity : 4 * self.layer_capacity])
113
+ args_layer_sketches = layer_args[4 * self.layer_capacity : 5 * self.layer_capacity]
114
+ args_layer_valids = layer_args[5 * self.layer_capacity : 6 * self.layer_capacity]
115
+ args_layer_statics = layer_args[6 * self.layer_capacity : 7 * self.layer_capacity]
116
+ for layer_idx in range(self.layer_capacity):
117
+ if args_layer_controls[layer_idx] != "score":
118
+ args_layer_scores[layer_idx] = -1
119
+ if args_layer_statics[layer_idx]:
120
+ args_layer_scores[layer_idx] = 0
121
+
122
+ mode = "i2v"
123
+ image1 = F.to_tensor(input_image) * 2 - 1
124
+ frame_tensor = image1[None].to(self.device) # [F, C, H, W]
125
+ if input_image_end is not None:
126
+ mode = "interpolate"
127
+ image2 = F.to_tensor(input_image_end) * 2 - 1
128
+ frame_tensor2 = image2[None].to(self.device)
129
+ frame_tensor = torch.cat([frame_tensor, frame_tensor2], dim=0)
130
+ frame_tensor = frame_tensor[None]
131
+
132
+ if mode == "interpolate":
133
+ layer_masks = torch.zeros((1, self.layer_capacity, 2, 1, self.H, self.W), dtype=torch.bool)
134
+ else:
135
+ layer_masks = torch.zeros((1, self.layer_capacity, 1, 1, self.H, self.W), dtype=torch.bool)
136
+ for layer_idx in range(self.layer_capacity):
137
+ if args_layer_masks[layer_idx] is not None:
138
+ mask = F.to_tensor(args_layer_masks[layer_idx]) > 0.5
139
+ layer_masks[0, layer_idx, 0] = mask
140
+ if args_layer_masks_end[layer_idx] is not None and mode == "interpolate":
141
+ mask = F.to_tensor(args_layer_masks_end[layer_idx]) > 0.5
142
+ layer_masks[0, layer_idx, 1] = mask
143
+ layer_masks = layer_masks.to(self.device)
144
+ layer_regions = layer_masks * frame_tensor[:, None]
145
+ layer_validity = torch.tensor([args_layer_valids], dtype=torch.bool, device=self.device)
146
+ motion_scores = torch.tensor([args_layer_scores], dtype=self.weight_dtype, device=self.device)
147
+ layer_static = torch.tensor([args_layer_statics], dtype=torch.bool, device=self.device)
148
+
149
+ sketch = torch.ones((1, self.layer_capacity, self.L, 3, self.H, self.W), dtype=self.weight_dtype)
150
+ for layer_idx in range(self.layer_capacity):
151
+ sketch_path = args_layer_sketches[layer_idx]
152
+ if sketch_path is not None:
153
+ video_reader = decord.VideoReader(sketch_path)
154
+ assert len(video_reader) == self.L, f"The sketch video has {len(video_reader)} frames, but it should match the clip length {self.L}."
155
+ video_frames = video_reader.get_batch(range(self.L)).asnumpy()
156
+ sketch_values = [F.to_tensor(self.transforms(Image.fromarray(frame))) for frame in video_frames]
157
+ sketch_values = torch.stack(sketch_values) * 2 - 1
158
+ sketch[0, layer_idx] = sketch_values
159
+ sketch = sketch.to(self.device)
160
+
161
+ heatmap = torch.zeros((1, self.layer_capacity, self.L, 3, self.H, self.W), dtype=self.weight_dtype)
162
+ heatmap[:, :, :, 0] -= 1
163
+ trajectory = []
164
+ traj_layer_index = []
165
+ for layer_idx in range(self.layer_capacity):
166
+ tracking_points = args_layer_tracking_points[layer_idx]
167
+ if args_layer_statics[layer_idx]:
168
+ # generate pseudo trajectory for static layers
169
+ temp_layer_mask = layer_masks[0, layer_idx, 0, 0].cpu().numpy()
170
+ valid_flag = temp_layer_mask[self.sample_grid[:, 0, 1], self.sample_grid[:, 0, 0]]
171
+ valid_grid = self.sample_grid[valid_flag] # [F, N, 2]
172
+ trajectory.extend(list(valid_grid))
173
+ traj_layer_index.extend([layer_idx] * valid_grid.shape[0])
174
+ else:
175
+ for temp_track in tracking_points:
176
+ if len(temp_track) > 1:
177
+ x = [point[0] for point in temp_track]
178
+ y = [point[1] for point in temp_track]
179
+ t = np.linspace(0, 1, len(temp_track))
180
+ fx = PchipInterpolator(t, x)
181
+ fy = PchipInterpolator(t, y)
182
+ t_new = np.linspace(0, 1, self.L)
183
+ x_new = fx(t_new)
184
+ y_new = fy(t_new)
185
+ temp_traj = np.stack([x_new, y_new], axis=-1).astype(np.float32)
186
+ trajectory.append(temp_traj)
187
+ traj_layer_index.append(layer_idx)
188
+ elif len(temp_track) == 1:
189
+ trajectory.append(np.array(temp_track * self.L))
190
+ traj_layer_index.append(layer_idx)
191
+
192
+ trajectory = np.stack(trajectory)
193
+ trajectory = np.transpose(trajectory, (1, 0, 2))
194
+ traj_layer_index = np.array(traj_layer_index)
195
+ heatmap = generate_gaussian_heatmap(trajectory, self.W, self.H, traj_layer_index, self.layer_capacity, offset=True)
196
+ heatmap = rearrange(heatmap, "f n c h w -> (f n) c h w")
197
+ graymap, offset = heatmap[:, :1], heatmap[:, 1:]
198
+ graymap = graymap / 255.
199
+ rad = torch.sqrt(offset[:, 0:1]**2 + offset[:, 1:2]**2)
200
+ rad_max = torch.max(rad)
201
+ epsilon = 1e-5
202
+ offset = offset / (rad_max + epsilon)
203
+ graymap = graymap * 2 - 1
204
+ heatmap = torch.cat([graymap, offset], dim=1)
205
+ heatmap = rearrange(heatmap, '(f n) c h w -> n f c h w', n=self.layer_capacity)
206
+ heatmap = heatmap[None]
207
+ heatmap = heatmap.to(self.device)
208
+
209
+ sample = self.pipeline(
210
+ prompt,
211
+ self.L,
212
+ self.H,
213
+ self.W,
214
+ frame_tensor,
215
+ layer_masks = layer_masks,
216
+ layer_regions = layer_regions,
217
+ layer_static = layer_static,
218
+ motion_scores = motion_scores,
219
+ sketch = sketch,
220
+ trajectory = heatmap,
221
+ layer_validity = layer_validity,
222
+ num_inference_steps = num_inference_steps,
223
+ guidance_scale = guidance_scale,
224
+ guidance_rescale = 0.7,
225
+ negative_prompt = n_prompt,
226
+ num_videos_per_prompt = 1,
227
+ eta = 1.0,
228
+ generator = self.generator,
229
+ fps = 24,
230
+ mode = mode,
231
+ weight_dtype = self.weight_dtype,
232
+ output_type = "tensor",
233
+ ).videos
234
+ output_video_path = os.path.join(self.savedir, "video.mp4")
235
+ save_videos_grid(sample, output_video_path, fps=8)
236
+ output_video_traj_path = os.path.join(self.savedir, "video_with_traj.mp4")
237
+ vis_traj_flag = np.zeros(trajectory.shape[1], dtype=bool)
238
+ for traj_idx in range(trajectory.shape[1]):
239
+ if not args_layer_statics[traj_layer_index[traj_idx]]:
240
+ vis_traj_flag[traj_idx] = True
241
+ vis_traj = torch.from_numpy(trajectory[:, vis_traj_flag])
242
+ save_videos_with_traj(sample[0], vis_traj, output_video_traj_path, fps=8, line_width=7, circle_radius=10)
243
+ return output_video_path, output_video_traj_path
244
+
245
+
246
+ def update_layer_region(image, layer_mask):
247
+ if image is None or layer_mask is None:
248
+ return None, False
249
+ layer_mask_tensor = (F.to_tensor(layer_mask) > 0.5).float()
250
+ image = F.to_tensor(image)
251
+ layer_region = image * layer_mask_tensor
252
+ layer_region = F.to_pil_image(layer_region)
253
+ layer_region.putalpha(layer_mask)
254
+ return layer_region, True
255
+
256
+ def control_layers(control_type):
257
+ if control_type == "score":
258
+ return gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
259
+ elif control_type == "trajectory":
260
+ return gr.update(visible=False), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)
261
+ else:
262
+ return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
263
+
264
+ def visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask):
265
+ first_mask_tensor = (F.to_tensor(first_mask) > 0.5).float()
266
+ first_frame = F.to_tensor(first_frame)
267
+ first_region = first_frame * first_mask_tensor
268
+ first_region = F.to_pil_image(first_region)
269
+ first_region.putalpha(first_mask)
270
+ transparent_background = first_region.convert('RGBA')
271
+
272
+ if last_frame is not None and last_mask is not None:
273
+ last_mask_tensor = (F.to_tensor(last_mask) > 0.5).float()
274
+ last_frame = F.to_tensor(last_frame)
275
+ last_region = last_frame * last_mask_tensor
276
+ last_region = F.to_pil_image(last_region)
277
+ last_region.putalpha(last_mask)
278
+ transparent_background_end = last_region.convert('RGBA')
279
+
280
+ width, height = transparent_background.size
281
+ transparent_layer = np.zeros((height, width, 4))
282
+ for track in tracking_points:
283
+ if len(track) > 1:
284
+ for i in range(len(track)-1):
285
+ start_point = np.array(track[i], dtype=np.int32)
286
+ end_point = np.array(track[i+1], dtype=np.int32)
287
+ vx = end_point[0] - start_point[0]
288
+ vy = end_point[1] - start_point[1]
289
+ arrow_length = max(np.sqrt(vx**2 + vy**2), 1)
290
+ if i == len(track)-2:
291
+ cv2.arrowedLine(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2, tipLength=8 / arrow_length)
292
+ else:
293
+ cv2.line(transparent_layer, tuple(start_point), tuple(end_point), (255, 0, 0, 255), 2,)
294
+ elif len(track) == 1:
295
+ cv2.circle(transparent_layer, tuple(track[0]), 5, (255, 0, 0, 255), -1)
296
+ transparent_layer = Image.fromarray(transparent_layer.astype(np.uint8))
297
+ trajectory_map = Image.alpha_composite(transparent_background, transparent_layer)
298
+ if last_frame is not None and last_mask is not None:
299
+ trajectory_map_end = Image.alpha_composite(transparent_background_end, transparent_layer)
300
+ else:
301
+ trajectory_map_end = None
302
+ return trajectory_map, trajectory_map_end
303
+
304
+ def add_drag(layer_idx):
305
+ global layer_tracking_points
306
+ tracking_points = layer_tracking_points[layer_idx].value
307
+ tracking_points.append([])
308
+ return
309
+
310
+ def delete_last_drag(layer_idx, first_frame, first_mask, last_frame, last_mask):
311
+ global layer_tracking_points
312
+ tracking_points = layer_tracking_points[layer_idx].value
313
+ tracking_points.pop()
314
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
315
+ return trajectory_map, trajectory_map_end
316
+
317
+ def delete_last_step(layer_idx, first_frame, first_mask, last_frame, last_mask):
318
+ global layer_tracking_points
319
+ tracking_points = layer_tracking_points[layer_idx].value
320
+ tracking_points[-1].pop()
321
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
322
+ return trajectory_map, trajectory_map_end
323
+
324
+ def add_tracking_points(layer_idx, first_frame, first_mask, last_frame, last_mask, evt: gr.SelectData): # SelectData is a subclass of EventData
325
+ print(f"You selected {evt.value} at {evt.index} from {evt.target}")
326
+ global layer_tracking_points
327
+ tracking_points = layer_tracking_points[layer_idx].value
328
+ tracking_points[-1].append(evt.index)
329
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
330
+ return trajectory_map, trajectory_map_end
331
+
332
+ def reset_states(layer_idx, first_frame, first_mask, last_frame, last_mask):
333
+ global layer_tracking_points
334
+ layer_tracking_points[layer_idx].value = [[]]
335
+ tracking_points = layer_tracking_points[layer_idx].value
336
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
337
+ return trajectory_map, trajectory_map_end
338
+
339
+ def upload_tracking_points(tracking_path, layer_idx, first_frame, first_mask, last_frame, last_mask):
340
+ if tracking_path is None:
341
+ layer_region, _ = update_layer_region(first_frame, first_mask)
342
+ layer_region_end, _ = update_layer_region(last_frame, last_mask)
343
+ return layer_region, layer_region_end
344
+
345
+ global layer_tracking_points
346
+ with open(tracking_path, "r") as f:
347
+ tracking_points = json.load(f)
348
+ layer_tracking_points[layer_idx].value = tracking_points
349
+ trajectory_map, trajectory_map_end = visualize_trajectory(tracking_points, first_frame, first_mask, last_frame, last_mask)
350
+ return trajectory_map, trajectory_map_end
351
+
352
+ def reset_all_controls():
353
+ global layer_tracking_points
354
+ outputs = []
355
+ # Reset tracking points states
356
+ for layer_idx in range(LAYER_CAPACITY):
357
+ layer_tracking_points[layer_idx].value = [[]]
358
+
359
+ # Reset global components
360
+ outputs.extend([
361
+ "an anime scene.", # text prompt
362
+ "", # negative text prompt
363
+ 50, # inference steps
364
+ 7.5, # guidance scale
365
+ 42, # seed
366
+ None, # input image
367
+ None, # input image end
368
+ None, # output video
369
+ None, # output video with trajectory
370
+ ])
371
+ # Reset layer controls visibility
372
+ outputs.extend([None] * LAYER_CAPACITY) # layer masks
373
+ outputs.extend([None] * LAYER_CAPACITY) # layer masks end
374
+ outputs.extend([None] * LAYER_CAPACITY) # layer regions
375
+ outputs.extend([None] * LAYER_CAPACITY) # layer regions end
376
+ outputs.extend(["sketch"] * LAYER_CAPACITY) # layer controls
377
+ outputs.extend([gr.update(visible=False, value=-1) for _ in range(LAYER_CAPACITY)]) # layer score controls
378
+ outputs.extend([gr.update(visible=False) for _ in range(4 * LAYER_CAPACITY)]) # layer trajectory control 4 buttons
379
+ outputs.extend([gr.update(visible=False, value=None) for _ in range(LAYER_CAPACITY)]) # layer trajectory file
380
+ outputs.extend([None] * LAYER_CAPACITY) # layer sketch controls
381
+ outputs.extend([False] * LAYER_CAPACITY) # layer validity
382
+ outputs.extend([False] * LAYER_CAPACITY) # layer statics
383
+ return outputs
384
+
385
+ if __name__ == "__main__":
386
+ with gr.Blocks() as demo:
387
+ gr.Markdown("""<h1 align="center">LayerAnimate: Layer-level Control for Animation</h1><br>""")
388
+
389
+ gr.Markdown("""Gradio Demo for <a href='https://arxiv.org/abs/2501.08295'><b>LayerAnimate: Layer-level Control for Animation</b></a>.<br>
390
+ The GitHub repo can be found at https://github.com/IamCreateAI/LayerAnimate<br>
391
+ The template is inspired by Framer.""")
392
+
393
+ gr.Image(label="LayerAnimate: Layer-level Control for Animation", value="__assets__/figs/demos.gif", height=540, width=960)
394
+
395
+ gr.Markdown("""## Usage: <br>
396
+ 1. Select a pretrained model from the "Pretrained Model" dropdown in the right column.<br>
397
+ 2. Upload frames in the right column.<br>
398
+ &ensp; 2.1. Upload the first frame.<br>
399
+ &ensp; 2.2. (Optional) Upload the last frame.<br>
400
+ 3. Input layer-level controls in the left column.<br>
401
+ &ensp; 3.1. Upload a layer mask image for each layer; masks can be obtained from tools such as https://huggingface.co/spaces/yumyum2081/SAM2-Image-Predictor.<br>
402
+ &ensp; 3.2. Choose a control type from "score" (motion score), "trajectory", and "sketch".<br>
403
+ &ensp; 3.3. For trajectory control, you can draw trajectories on the layer regions.<br>
404
+ &ensp; &ensp; 3.3.1. Click "Add New Trajectory" to add a new trajectory.<br>
405
+ &ensp; &ensp; 3.3.2. Click "Reset" to reset all trajectories.<br>
406
+ &ensp; &ensp; 3.3.3. Click "Delete Last Step" to delete the latest clicked control point.<br>
407
+ &ensp; &ensp; 3.3.4. Click "Delete Last Trajectory" to delete the latest whole path.<br>
408
+ &ensp; &ensp; 3.3.5. Alternatively, upload a trajectory file in JSON format; examples are provided below.<br>
409
+ &ensp; 3.4. For sketch control, you can upload a sketch video.<br>
410
+ 4. We provide four layers for you to control, and it is not necessary to use all of them.<br>
411
+ 5. Click the "Run" button to generate videos.<br>
412
+ 6. **Note: Remember to click the "Clear" button to clear all controls before switching to another example.**<br>
413
+ """)
414
+
415
+ layeranimate = LayerAnimate()
416
+ layer_indices = [gr.Number(value=i, visible=False) for i in range(LAYER_CAPACITY)]
417
+ layer_tracking_points = [gr.State([[]]) for _ in range(LAYER_CAPACITY)]
418
+ layer_masks = []
419
+ layer_masks_end = []
420
+ layer_regions = []
421
+ layer_regions_end = []
422
+ layer_controls = []
423
+ layer_score_controls = []
424
+ layer_traj_controls = []
425
+ layer_traj_files = []
426
+ layer_sketch_controls = []
427
+ layer_statics = []
428
+ layer_valids = []
429
+
430
+ with gr.Row():
431
+ with gr.Column(scale=1):
432
+ for layer_idx in range(LAYER_CAPACITY):
433
+ with gr.Accordion(label=f"Layer {layer_idx+1}", open=True if layer_idx == 0 else False):
434
+ gr.Markdown("""<div align="center"><b>Layer Masks</b></div>""")
435
+ gr.Markdown("**Note**: Layer mask for the last frame is not required in I2V mode.")
436
+ with gr.Row():
437
+ with gr.Column():
438
+ layer_masks.append(gr.Image(
439
+ label="Layer Mask for First Frame",
440
+ height=320,
441
+ width=512,
442
+ image_mode="L",
443
+ type="pil",
444
+ ))
445
+
446
+ with gr.Column():
447
+ layer_masks_end.append(gr.Image(
448
+ label="Layer Mask for Last Frame",
449
+ height=320,
450
+ width=512,
451
+ image_mode="L",
452
+ type="pil",
453
+ ))
454
+ gr.Markdown("""<div align="center"><b>Layer Regions</b></div>""")
455
+ with gr.Row():
456
+ with gr.Column():
457
+ layer_regions.append(gr.Image(
458
+ label="Layer Region for First Frame",
459
+ height=320,
460
+ width=512,
461
+ image_mode="RGBA",
462
+ type="pil",
463
+ # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
464
+ ))
465
+
466
+ with gr.Column():
467
+ layer_regions_end.append(gr.Image(
468
+ label="Layer Region for Last Frame",
469
+ height=320,
470
+ width=512,
471
+ image_mode="RGBA",
472
+ type="pil",
473
+ # value=Image.new("RGBA", (512, 320), (255, 255, 255, 0)),
474
+ ))
475
+ layer_controls.append(
476
+ gr.Radio(["score", "trajectory", "sketch"], label="Choose A Control Type", value="sketch")
477
+ )
478
+ layer_score_controls.append(
479
+ gr.Number(label="Motion Score", value=-1, visible=False)
480
+ )
481
+ layer_traj_controls.append(
482
+ [
483
+ gr.Button(value="Add New Trajectory", visible=False),
484
+ gr.Button(value="Reset", visible=False),
485
+ gr.Button(value="Delete Last Step", visible=False),
486
+ gr.Button(value="Delete Last Trajectory", visible=False),
487
+ ]
488
+ )
489
+ layer_traj_files.append(
490
+ gr.File(label="Trajectory File", visible=False)
491
+ )
492
+ layer_sketch_controls.append(
493
+ gr.Video(label="Sketch", height=320, width=512, visible=True)
494
+ )
495
+ layer_controls[layer_idx].change(
496
+ fn=control_layers,
497
+ inputs=layer_controls[layer_idx],
498
+ outputs=[layer_score_controls[layer_idx], *layer_traj_controls[layer_idx], layer_traj_files[layer_idx], layer_sketch_controls[layer_idx]]
499
+ )
500
+ with gr.Row():
501
+ layer_valids.append(gr.Checkbox(label="Valid", info="Is the layer valid?"))
502
+ layer_statics.append(gr.Checkbox(label="Static", info="Is the layer static?"))
503
+
504
+ with gr.Column(scale=1):
505
+ pretrained_model_path = gr.Dropdown(
506
+ label="Pretrained Model",
507
+ choices=[
508
+ "None",
509
+ "checkpoints/LayerAnimate-Mix",
510
+ ],
511
+ value="None",
512
+ )
513
+ text_prompt = gr.Textbox(label="Text Prompt", value="an anime scene.")
514
+ text_n_prompt = gr.Textbox(label="Negative Text Prompt", value="")
515
+ with gr.Row():
516
+ num_inference_steps = gr.Number(label="Inference Steps", value=50, minimum=1, maximum=1000)
517
+ guidance_scale = gr.Number(label="Guidance Scale", value=7.5)
518
+ seed = gr.Number(label="Seed", value=42)
519
+ with gr.Row():
520
+ input_image = gr.Image(
521
+ label="First Frame",
522
+ height=320,
523
+ width=512,
524
+ type="pil",
525
+ )
526
+ input_image_end = gr.Image(
527
+ label="Last Frame",
528
+ height=320,
529
+ width=512,
530
+ type="pil",
531
+ )
532
+ run_button = gr.Button(value="Run")
533
+ with gr.Row():
534
+ output_video = gr.Video(
535
+ label="Output Video",
536
+ height=320,
537
+ width=512,
538
+ )
539
+ output_video_traj = gr.Video(
540
+ label="Output Video with Trajectory",
541
+ height=320,
542
+ width=512,
543
+ )
544
+ clear_button = gr.Button(value="Clear")
545
+
546
+ with gr.Row():
547
+ gr.Markdown("""
548
+ ## Citation
549
+ ```bibtex
550
+ @article{yang2025layeranimate,
551
+ author = {Yang, Yuxue and Fan, Lue and Lin, Zuzeng and Wang, Feng and Zhang, Zhaoxiang},
552
+ title = {LayerAnimate: Layer-level Control for Animation},
553
+ journal = {arXiv preprint arXiv:2501.08295},
554
+ year = {2025},
555
+ }
556
+ ```
557
+ """)
558
+
559
+ pretrained_model_path.input(layeranimate.set_model, pretrained_model_path, pretrained_model_path)
560
+ input_image.upload(layeranimate.upload_image, input_image, input_image)
561
+ input_image_end.upload(layeranimate.upload_image, input_image_end, input_image_end)
562
+ for i in range(LAYER_CAPACITY):
563
+ layer_masks[i].upload(layeranimate.upload_image, layer_masks[i], layer_masks[i])
564
+ layer_masks[i].change(update_layer_region, [input_image, layer_masks[i]], [layer_regions[i], layer_valids[i]])
565
+ layer_masks_end[i].upload(layeranimate.upload_image, layer_masks_end[i], layer_masks_end[i])
566
+ layer_masks_end[i].change(update_layer_region, [input_image_end, layer_masks_end[i]], [layer_regions_end[i], layer_valids[i]])
567
+ layer_traj_controls[i][0].click(add_drag, layer_indices[i], None)
568
+ layer_traj_controls[i][1].click(
569
+ reset_states,
570
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
571
+ [layer_regions[i], layer_regions_end[i]]
572
+ )
573
+ layer_traj_controls[i][2].click(
574
+ delete_last_step,
575
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
576
+ [layer_regions[i], layer_regions_end[i]]
577
+ )
578
+ layer_traj_controls[i][3].click(
579
+ delete_last_drag,
580
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
581
+ [layer_regions[i], layer_regions_end[i]]
582
+ )
583
+ layer_traj_files[i].change(
584
+ upload_tracking_points,
585
+ [layer_traj_files[i], layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
586
+ [layer_regions[i], layer_regions_end[i]]
587
+ )
588
+ layer_regions[i].select(
589
+ add_tracking_points,
590
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
591
+ [layer_regions[i], layer_regions_end[i]]
592
+ )
593
+ layer_regions_end[i].select(
594
+ add_tracking_points,
595
+ [layer_indices[i], input_image, layer_masks[i], input_image_end, layer_masks_end[i]],
596
+ [layer_regions[i], layer_regions_end[i]]
597
+ )
598
+ run_button.click(
599
+ layeranimate.run,
600
+ [input_image, input_image_end, pretrained_model_path, seed, text_prompt, text_n_prompt, num_inference_steps, guidance_scale,
601
+ *layer_masks, *layer_masks_end, *layer_controls, *layer_score_controls, *layer_sketch_controls, *layer_valids, *layer_statics],
602
+ [output_video, output_video_traj]
603
+ )
604
+ clear_button.click(
605
+ reset_all_controls,
606
+ [],
607
+ [
608
+ text_prompt, text_n_prompt, num_inference_steps, guidance_scale, seed,
609
+ input_image, input_image_end, output_video, output_video_traj,
610
+ *layer_masks, *layer_masks_end, *layer_regions, *layer_regions_end,
611
+ *layer_controls, *layer_score_controls, *[button for temp_layer_controls in layer_traj_controls for button in temp_layer_controls], *layer_traj_files,
612
+ *layer_sketch_controls, *layer_valids, *layer_statics
613
+ ]
614
+ )
615
+ examples = gr.Examples(
616
+ examples=[
617
+ [
618
+ "__assets__/demos/demo_3/first_frame.jpg", "__assets__/demos/demo_3/last_frame.jpg",
619
+ "score", "__assets__/demos/demo_3/layer_0.jpg", "__assets__/demos/demo_3/layer_0_last.jpg", 0.4, None, None, True, False,
620
+ "score", "__assets__/demos/demo_3/layer_1.jpg", "__assets__/demos/demo_3/layer_1_last.jpg", 0.2, None, None, True, False,
621
+ "trajectory", "__assets__/demos/demo_3/layer_2.jpg", "__assets__/demos/demo_3/layer_2_last.jpg", -1, "__assets__/demos/demo_3/trajectory.json", None, True, False,
622
+ "sketch", "__assets__/demos/demo_3/layer_3.jpg", "__assets__/demos/demo_3/layer_3_last.jpg", -1, None, "__assets__/demos/demo_3/sketch.mp4", True, False,
623
+ 52
624
+ ],
625
+ [
626
+ "__assets__/demos/demo_4/first_frame.jpg", None,
627
+ "score", "__assets__/demos/demo_4/layer_0.jpg", None, 0.0, None, None, True, True,
628
+ "trajectory", "__assets__/demos/demo_4/layer_1.jpg", None, -1, "__assets__/demos/demo_4/trajectory.json", None, True, False,
629
+ "sketch", "__assets__/demos/demo_4/layer_2.jpg", None, -1, None, "__assets__/demos/demo_4/sketch.mp4", True, False,
630
+ "score", None, None, -1, None, None, False, False,
631
+ 42
632
+ ],
633
+ [
634
+ "__assets__/demos/demo_5/first_frame.jpg", None,
635
+ "sketch", "__assets__/demos/demo_5/layer_0.jpg", None, -1, None, "__assets__/demos/demo_5/sketch.mp4", True, False,
636
+ "trajectory", "__assets__/demos/demo_5/layer_1.jpg", None, -1, "__assets__/demos/demo_5/trajectory.json", None, True, False,
637
+ "score", None, None, -1, None, None, False, False,
638
+ "score", None, None, -1, None, None, False, False,
639
+ 47
640
+ ],
641
+ ],
642
+ inputs=[
643
+ input_image, input_image_end,
644
+ layer_controls[0], layer_masks[0], layer_masks_end[0], layer_score_controls[0], layer_traj_files[0], layer_sketch_controls[0], layer_valids[0], layer_statics[0],
645
+ layer_controls[1], layer_masks[1], layer_masks_end[1], layer_score_controls[1], layer_traj_files[1], layer_sketch_controls[1], layer_valids[1], layer_statics[1],
646
+ layer_controls[2], layer_masks[2], layer_masks_end[2], layer_score_controls[2], layer_traj_files[2], layer_sketch_controls[2], layer_valids[2], layer_statics[2],
647
+ layer_controls[3], layer_masks[3], layer_masks_end[3], layer_score_controls[3], layer_traj_files[3], layer_sketch_controls[3], layer_valids[3], layer_statics[3],
648
+ seed
649
+ ],
650
+ )
651
+ demo.launch()
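Each example row above packs the per-layer controls positionally, in the same order as the `inputs` list that follows it (control mode, first-frame mask, last-frame mask, motion score, trajectory file, sketch video, valid flag, static flag). Below is a minimal sketch of building one such row programmatically, added only to make that ordering explicit; the `layer_entry` helper is hypothetical and not part of app.py.

def layer_entry(mode, mask, mask_end=None, score=-1, traj=None, sketch=None, valid=True, static=False):
    # fixed per-layer order: mode, first-frame mask, last-frame mask, motion score,
    # trajectory JSON, sketch video, valid flag, static flag
    return [mode, mask, mask_end, score, traj, sketch, valid, static]

example_row = (
    ["__assets__/demos/demo_4/first_frame.jpg", None]
    + layer_entry("score", "__assets__/demos/demo_4/layer_0.jpg", score=0.0, static=True)
    + layer_entry("trajectory", "__assets__/demos/demo_4/layer_1.jpg", traj="__assets__/demos/demo_4/trajectory.json")
    + layer_entry("sketch", "__assets__/demos/demo_4/layer_2.jpg", sketch="__assets__/demos/demo_4/sketch.mp4")
    + layer_entry("score", None, valid=False)   # unused fourth layer slot
    + [42]                                      # seed
)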
lvdm/basics.py ADDED
@@ -0,0 +1,100 @@
1
+ # adapted from
2
+ # https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
3
+ # and
4
+ # https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
5
+ # and
6
+ # https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
7
+ #
8
+ # thanks!
9
+
10
+ import torch.nn as nn
11
+ from .utils import instantiate_from_config
12
+
13
+
14
+ def disabled_train(self, mode=True):
15
+ """Overwrite model.train with this function to make sure train/eval mode
16
+ does not change anymore."""
17
+ return self
18
+
19
+ def zero_module(module):
20
+ """
21
+ Zero out the parameters of a module and return it.
22
+ """
23
+ for p in module.parameters():
24
+ p.detach().zero_()
25
+ return module
26
+
27
+ def scale_module(module, scale):
28
+ """
29
+ Scale the parameters of a module and return it.
30
+ """
31
+ for p in module.parameters():
32
+ p.detach().mul_(scale)
33
+ return module
34
+
35
+
36
+ def conv_nd(dims, *args, **kwargs):
37
+ """
38
+ Create a 1D, 2D, or 3D convolution module.
39
+ """
40
+ if dims == 1:
41
+ return nn.Conv1d(*args, **kwargs)
42
+ elif dims == 2:
43
+ return nn.Conv2d(*args, **kwargs)
44
+ elif dims == 3:
45
+ return nn.Conv3d(*args, **kwargs)
46
+ raise ValueError(f"unsupported dimensions: {dims}")
47
+
48
+
49
+ def linear(*args, **kwargs):
50
+ """
51
+ Create a linear module.
52
+ """
53
+ return nn.Linear(*args, **kwargs)
54
+
55
+
56
+ def avg_pool_nd(dims, *args, **kwargs):
57
+ """
58
+ Create a 1D, 2D, or 3D average pooling module.
59
+ """
60
+ if dims == 1:
61
+ return nn.AvgPool1d(*args, **kwargs)
62
+ elif dims == 2:
63
+ return nn.AvgPool2d(*args, **kwargs)
64
+ elif dims == 3:
65
+ return nn.AvgPool3d(*args, **kwargs)
66
+ raise ValueError(f"unsupported dimensions: {dims}")
67
+
68
+
69
+ def nonlinearity(type='silu'):
70
+ if type == 'silu':
71
+ return nn.SiLU()
72
+ elif type == 'leaky_relu':
73
+ return nn.LeakyReLU()
74
+
75
+
76
+ class GroupNormSpecific(nn.GroupNorm):
77
+ def forward(self, x):
78
+ return super().forward(x.float()).type(x.dtype)
79
+
80
+
81
+ def normalization(channels, num_groups=32):
82
+ """
83
+ Make a standard normalization layer.
84
+ :param channels: number of input channels.
85
+ :return: an nn.Module for normalization.
86
+ """
87
+ return GroupNormSpecific(num_groups, channels)
88
+
89
+
90
+ class HybridConditioner(nn.Module):
91
+
92
+ def __init__(self, c_concat_config, c_crossattn_config):
93
+ super().__init__()
94
+ self.concat_conditioner = instantiate_from_config(c_concat_config)
95
+ self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
96
+
97
+ def forward(self, c_concat, c_crossattn):
98
+ c_concat = self.concat_conditioner(c_concat)
99
+ c_crossattn = self.crossattn_conditioner(c_crossattn)
100
+ return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
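zero_module above is the ControlNet-style zero initialization: a zero-initialized projection contributes nothing on the first forward pass, so a newly attached control branch starts out as a no-op and only grows during training. A minimal standalone sketch, assuming the repository's lvdm package is importable:

import torch
from lvdm.basics import zero_module, conv_nd

proj = zero_module(conv_nd(2, 4, 320, 3, padding=1))  # 2D conv with weights and bias zeroed
x = torch.randn(1, 4, 32, 32)
assert torch.all(proj(x) == 0)  # the branch is silent at initialization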
lvdm/common.py ADDED
@@ -0,0 +1,94 @@
1
+ import math
2
+ from inspect import isfunction
3
+ import torch
4
+ from torch import nn
5
+ import torch.distributed as dist
6
+
7
+
8
+ def gather_data(data, return_np=True):
9
+ ''' gather data from multiple processes to one list '''
10
+ data_list = [torch.zeros_like(data) for _ in range(dist.get_world_size())]
11
+ dist.all_gather(data_list, data) # gather not supported with NCCL
12
+ if return_np:
13
+ data_list = [data.cpu().numpy() for data in data_list]
14
+ return data_list
15
+
16
+ def autocast(f):
17
+ def do_autocast(*args, **kwargs):
18
+ with torch.cuda.amp.autocast(enabled=True,
19
+ dtype=torch.get_autocast_gpu_dtype(),
20
+ cache_enabled=torch.is_autocast_cache_enabled()):
21
+ return f(*args, **kwargs)
22
+ return do_autocast
23
+
24
+
25
+ def extract_into_tensor(a, t, x_shape):
26
+ b, *_ = t.shape
27
+ out = a.gather(-1, t)
28
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
29
+
30
+
31
+ def noise_like(shape, device, repeat=False):
32
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
33
+ noise = lambda: torch.randn(shape, device=device)
34
+ return repeat_noise() if repeat else noise()
35
+
36
+
37
+ def default(val, d):
38
+ if exists(val):
39
+ return val
40
+ return d() if isfunction(d) else d
41
+
42
+ def exists(val):
43
+ return val is not None
44
+
45
+ def identity(*args, **kwargs):
46
+ return nn.Identity()
47
+
48
+ def uniq(arr):
49
+ return {el: True for el in arr}.keys()
50
+
51
+ def mean_flat(tensor):
52
+ """
53
+ Take the mean over all non-batch dimensions.
54
+ """
55
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
56
+
57
+ def ismap(x):
58
+ if not isinstance(x, torch.Tensor):
59
+ return False
60
+ return (len(x.shape) == 4) and (x.shape[1] > 3)
61
+
62
+ def isimage(x):
63
+ if not isinstance(x,torch.Tensor):
64
+ return False
65
+ return (len(x.shape) == 4) and (x.shape[1] == 3 or x.shape[1] == 1)
66
+
67
+ def max_neg_value(t):
68
+ return -torch.finfo(t.dtype).max
69
+
70
+ def shape_to_str(x):
71
+ shape_str = "x".join([str(x) for x in x.shape])
72
+ return shape_str
73
+
74
+ def init_(tensor):
75
+ dim = tensor.shape[-1]
76
+ std = 1 / math.sqrt(dim)
77
+ tensor.uniform_(-std, std)
78
+ return tensor
79
+
80
+ ckpt = torch.utils.checkpoint.checkpoint
81
+ def checkpoint(func, inputs, params, flag):
82
+ """
83
+ Evaluate a function without caching intermediate activations, allowing for
84
+ reduced memory at the expense of extra compute in the backward pass.
85
+ :param func: the function to evaluate.
86
+ :param inputs: the argument sequence to pass to `func`.
87
+ :param params: a sequence of parameters `func` depends on but does not
88
+ explicitly take as arguments.
89
+ :param flag: if False, disable gradient checkpointing.
90
+ """
91
+ if flag:
92
+ return ckpt(func, *inputs, use_reentrant=False)
93
+ else:
94
+ return func(*inputs)
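The checkpoint helper above trades extra compute in the backward pass for lower activation memory, and simply calls the wrapped function when the flag is off. A minimal usage sketch (illustrative only, assuming the lvdm package is importable):

import torch
import torch.nn as nn
from lvdm.common import checkpoint

layer = nn.Linear(64, 64)
x = torch.randn(8, 64, requires_grad=True)

# flag=True: activations are recomputed during backward; flag=False: plain forward call
y = checkpoint(lambda inp: layer(inp), (x,), layer.parameters(), True)
y.sum().backward()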
lvdm/models/autoencoder.py ADDED
@@ -0,0 +1,143 @@
1
+ import os
2
+ from functools import partial
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ import numpy as np
7
+ from einops import rearrange
8
+ import torch.nn.functional as F
9
+ from torch.utils.checkpoint import checkpoint
10
+ from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
11
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
12
+ from diffusers.models import ModelMixin
13
+ from diffusers.utils import BaseOutput
14
+
15
+ from ..modules.ae_modules import Encoder, Decoder
16
+ from ..modules.ae_dualref_modules import VideoDecoder
17
+ from ..utils import instantiate_from_config
18
+
19
+
20
+ @dataclass
21
+ class DecoderOutput(BaseOutput):
22
+ """
23
+ Output of decoding method.
24
+
25
+ Args:
26
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
27
+ Decoded output sample of the model. Output of the last layer of the model.
28
+ """
29
+
30
+ sample: torch.FloatTensor
31
+
32
+
33
+ @dataclass
34
+ class AutoencoderKLOutput(BaseOutput):
35
+ """
36
+ Output of AutoencoderKL encoding method.
37
+
38
+ Args:
39
+ latent_dist (`DiagonalGaussianDistribution`):
40
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
41
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
42
+ """
43
+
44
+ latent_dist: "DiagonalGaussianDistribution"
45
+
46
+
47
+ class AutoencoderKL(ModelMixin, ConfigMixin):
48
+ @register_to_config
49
+ def __init__(self,
50
+ ddconfig,
51
+ embed_dim,
52
+ image_key="image",
53
+ input_dim=4,
54
+ use_checkpoint=False,
55
+ ):
56
+ super().__init__()
57
+ self.image_key = image_key
58
+ self.encoder = Encoder(**ddconfig)
59
+ self.decoder = Decoder(**ddconfig)
60
+ assert ddconfig["double_z"]
61
+ self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1)
62
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1)
63
+ self.embed_dim = embed_dim
64
+ self.input_dim = input_dim
65
+ self.use_checkpoint = use_checkpoint
66
+
67
+ def encode(self, x, return_hidden_states=False, **kwargs):
68
+ if return_hidden_states:
69
+ h, hidden = self.encoder(x, return_hidden_states)
70
+ moments = self.quant_conv(h)
71
+ posterior = DiagonalGaussianDistribution(moments)
72
+ return AutoencoderKLOutput(latent_dist=posterior), hidden
73
+ else:
74
+ h = self.encoder(x)
75
+ moments = self.quant_conv(h)
76
+ posterior = DiagonalGaussianDistribution(moments)
77
+ return AutoencoderKLOutput(latent_dist=posterior)
78
+
79
+ def decode(self, z, **kwargs):
80
+ if len(kwargs) == 0: ## use the original decoder in AutoencoderKL
81
+ z = self.post_quant_conv(z)
82
+ dec = self.decoder(z, **kwargs) ##change for SVD decoder by adding **kwargs
83
+ return dec
84
+
85
+ def forward(self, input, sample_posterior=True, **additional_decode_kwargs):
86
+ input_tuple = (input, )
87
+ forward_temp = partial(self._forward, sample_posterior=sample_posterior, **additional_decode_kwargs)
88
+ return checkpoint(forward_temp, input_tuple, self.parameters(), self.use_checkpoint)
89
+
90
+
91
+ def _forward(self, input, sample_posterior=True, **additional_decode_kwargs):
92
+ posterior = self.encode(input)[0]
93
+ if sample_posterior:
94
+ z = posterior.sample()
95
+ else:
96
+ z = posterior.mode()
97
+ dec = self.decode(z, **additional_decode_kwargs)
98
+ ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
99
+ return dec, posterior
100
+
101
+ def get_input(self, batch, k):
102
+ x = batch[k]
103
+ if x.dim() == 5 and self.input_dim == 4:
104
+ b,c,t,h,w = x.shape
105
+ self.b = b
106
+ self.t = t
107
+ x = rearrange(x, 'b c t h w -> (b t) c h w')
108
+
109
+ return x
110
+
111
+ def get_last_layer(self):
112
+ return self.decoder.conv_out.weight
113
+
114
+
115
+ class AutoencoderKL_Dualref(AutoencoderKL):
116
+ @register_to_config
117
+ def __init__(self,
118
+ ddconfig,
119
+ embed_dim,
120
+ image_key="image",
121
+ input_dim=4,
122
+ use_checkpoint=False,
123
+ ):
124
+ super().__init__(ddconfig, embed_dim, image_key, input_dim, use_checkpoint)
125
+ self.decoder = VideoDecoder(**ddconfig)
126
+
127
+ def _forward(self, input, batch_size, sample_posterior=True, **additional_decode_kwargs):
128
+ posterior, hidden_states = self.encode(input, return_hidden_states=True)
129
+
130
+ hidden_states_first_last = []
131
+ ### use only the first and last hidden states
132
+ for hid in hidden_states:
133
+ hid = rearrange(hid, '(b t) c h w -> b c t h w', b=batch_size)
134
+ hid_new = torch.cat([hid[:, :, 0:1], hid[:, :, -1:]], dim=2)
135
+ hidden_states_first_last.append(hid_new)
136
+
137
+ if sample_posterior:
138
+ z = posterior[0].sample()
139
+ else:
140
+ z = posterior[0].mode()
141
+ dec = self.decode(z, ref_context=hidden_states_first_last, **additional_decode_kwargs)
142
+ ## print(input.shape, dec.shape) torch.Size([16, 3, 256, 256]) torch.Size([16, 3, 256, 256])
143
+ return dec, posterior
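AutoencoderKL above wraps the usual Stable-Diffusion-style VAE: encode returns an AutoencoderKLOutput whose latent_dist can be sampled, and decode maps latents back to pixel space. A minimal round-trip sketch; the ddconfig values here are placeholder assumptions (the real configuration ships with the pretrained checkpoint):

import torch
from lvdm.models.autoencoder import AutoencoderKL

ddconfig = dict(double_z=True, z_channels=4, resolution=256, in_channels=3, out_ch=3,
                ch=128, ch_mult=[1, 2, 4, 4], num_res_blocks=2, attn_resolutions=[], dropout=0.0)
vae = AutoencoderKL(ddconfig, embed_dim=4)

frames = torch.randn(2, 3, 256, 256)          # (b t) c h w, as produced by get_input()
posterior = vae.encode(frames).latent_dist    # DiagonalGaussianDistribution
z = posterior.sample()                        # (2, 4, 32, 32) latents, 8x downsampled
recon = vae.decode(z)                         # back to (2, 3, 256, 256)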
lvdm/models/condition.py ADDED
@@ -0,0 +1,477 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from torchvision.transforms import functional as F
5
+ import open_clip
6
+ from torch.utils.checkpoint import checkpoint
7
+ from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel
8
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
9
+ from diffusers.models import ModelMixin
10
+ from ..common import autocast
11
+ from ..utils import count_params
12
+
13
+
14
+ class AbstractEncoder(nn.Module):
15
+ def __init__(self):
16
+ super().__init__()
17
+
18
+ def encode(self, *args, **kwargs):
19
+ raise NotImplementedError
20
+
21
+ @property
22
+ def device(self):
23
+ return next(self.parameters()).device
24
+
25
+ @property
26
+ def dtype(self):
27
+ return next(self.parameters()).dtype
28
+
29
+ class IdentityEncoder(AbstractEncoder):
30
+ def encode(self, x):
31
+ return x
32
+
33
+
34
+ class ClassEmbedder(nn.Module):
35
+ def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1):
36
+ super().__init__()
37
+ self.key = key
38
+ self.embedding = nn.Embedding(n_classes, embed_dim)
39
+ self.n_classes = n_classes
40
+ self.ucg_rate = ucg_rate
41
+
42
+ def forward(self, batch, key=None, disable_dropout=False):
43
+ if key is None:
44
+ key = self.key
45
+ # this is for use in crossattn
46
+ c = batch[key][:, None]
47
+ if self.ucg_rate > 0. and not disable_dropout:
48
+ mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate)
49
+ c = mask * c + (1 - mask) * torch.ones_like(c) * (self.n_classes - 1)
50
+ c = c.long()
51
+ c = self.embedding(c)
52
+ return c
53
+
54
+ def get_unconditional_conditioning(self, bs, device="cuda"):
55
+ uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000)
56
+ uc = torch.ones((bs,), device=device) * uc_class
57
+ uc = {self.key: uc}
58
+ return uc
59
+
60
+
61
+ def disabled_train(self, mode=True):
62
+ """Overwrite model.train with this function to make sure train/eval mode
63
+ does not change anymore."""
64
+ return self
65
+
66
+
67
+ class FrozenT5Embedder(AbstractEncoder):
68
+ """Uses the T5 transformer encoder for text"""
69
+
70
+ def __init__(self, version="google/t5-v1_1-large", max_length=77,
71
+ freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl
72
+ super().__init__()
73
+ self.tokenizer = T5Tokenizer.from_pretrained(version)
74
+ self.transformer = T5EncoderModel.from_pretrained(version)
75
+ self.max_length = max_length # TODO: typical value?
76
+ if freeze:
77
+ self.freeze()
78
+
79
+ def freeze(self):
80
+ self.transformer = self.transformer.eval()
81
+ # self.train = disabled_train
82
+ for param in self.parameters():
83
+ param.requires_grad = False
84
+
85
+ def forward(self, text):
86
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
87
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
88
+ tokens = batch_encoding["input_ids"].to(self.device)
89
+ outputs = self.transformer(input_ids=tokens)
90
+
91
+ z = outputs.last_hidden_state
92
+ return z
93
+
94
+ def encode(self, text):
95
+ return self(text)
96
+
97
+
98
+ class FrozenCLIPEmbedder(AbstractEncoder):
99
+ """Uses the CLIP transformer encoder for text (from huggingface)"""
100
+ LAYERS = [
101
+ "last",
102
+ "pooled",
103
+ "hidden"
104
+ ]
105
+
106
+ def __init__(self, version="openai/clip-vit-large-patch14", max_length=77,
107
+ freeze=True, layer="last", layer_idx=None): # clip-vit-base-patch32
108
+ super().__init__()
109
+ assert layer in self.LAYERS
110
+ self.tokenizer = CLIPTokenizer.from_pretrained(version)
111
+ self.transformer = CLIPTextModel.from_pretrained(version)
112
+ self.max_length = max_length
113
+ if freeze:
114
+ self.freeze()
115
+ self.layer = layer
116
+ self.layer_idx = layer_idx
117
+ if layer == "hidden":
118
+ assert layer_idx is not None
119
+ assert 0 <= abs(layer_idx) <= 12
120
+
121
+ def freeze(self):
122
+ self.transformer = self.transformer.eval()
123
+ # self.train = disabled_train
124
+ for param in self.parameters():
125
+ param.requires_grad = False
126
+
127
+ def forward(self, text):
128
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
129
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
130
+ tokens = batch_encoding["input_ids"].to(self.device)
131
+ outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer == "hidden")
132
+ if self.layer == "last":
133
+ z = outputs.last_hidden_state
134
+ elif self.layer == "pooled":
135
+ z = outputs.pooler_output[:, None, :]
136
+ else:
137
+ z = outputs.hidden_states[self.layer_idx]
138
+ return z
139
+
140
+ def encode(self, text):
141
+ return self(text)
142
+
143
+
144
+ class FrozenOpenCLIPEmbedder(AbstractEncoder):
145
+ """
146
+ Uses the OpenCLIP transformer encoder for text
147
+ """
148
+ LAYERS = [
149
+ # "pooled",
150
+ "last",
151
+ "penultimate"
152
+ ]
153
+
154
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
155
+ freeze=True, layer="penultimate"):
156
+ super().__init__()
157
+ assert layer in self.LAYERS
158
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version)
159
+ del model.visual
160
+ self.model = model
161
+
162
+ self.max_length = max_length
163
+ if freeze:
164
+ self.freeze()
165
+ self.layer = layer
166
+ if self.layer == "last":
167
+ self.layer_idx = 0
168
+ elif self.layer == "penultimate":
169
+ self.layer_idx = 1
170
+ else:
171
+ raise NotImplementedError()
172
+
173
+ def freeze(self):
174
+ self.model = self.model.eval()
175
+ for param in self.parameters():
176
+ param.requires_grad = False
177
+
178
+ def forward(self, text):
179
+ tokens = open_clip.tokenize(text) ## all clip models use 77 as context length
180
+ z = self.encode_with_transformer(tokens.to(self.device))
181
+ return z
182
+
183
+ def encode_with_transformer(self, text):
184
+ x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model]
185
+ x = x + self.model.positional_embedding
186
+ x = x.permute(1, 0, 2) # NLD -> LND
187
+ x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask)
188
+ x = x.permute(1, 0, 2) # LND -> NLD
189
+ x = self.model.ln_final(x)
190
+ return x
191
+
192
+ def text_transformer_forward(self, x: torch.Tensor, attn_mask=None):
193
+ for i, r in enumerate(self.model.transformer.resblocks):
194
+ if i == len(self.model.transformer.resblocks) - self.layer_idx:
195
+ break
196
+ if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting():
197
+ x = checkpoint(r, x, attn_mask)
198
+ else:
199
+ x = r(x, attn_mask=attn_mask)
200
+ return x
201
+
202
+ def encode(self, text):
203
+ return self(text)
204
+
205
+
206
+ class FrozenOpenCLIPImageEmbedder(AbstractEncoder):
207
+ """
208
+ Uses the OpenCLIP vision transformer encoder for images
209
+ """
210
+
211
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", max_length=77,
212
+ freeze=True, layer="pooled", antialias=True, ucg_rate=0.):
213
+ super().__init__()
214
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
215
+ pretrained=version, )
216
+ del model.transformer
217
+ self.model = model
218
+ self.preprocess_val = preprocess_val
219
+ # self.mapper = torch.nn.Linear(1280, 1024)
220
+ self.max_length = max_length
221
+ if freeze:
222
+ self.freeze()
223
+ self.layer = layer
224
+ if self.layer == "penultimate":
225
+ raise NotImplementedError()
226
+ self.layer_idx = 1
227
+
228
+ self.antialias = antialias
229
+
230
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
231
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
232
+ self.ucg_rate = ucg_rate
233
+
234
+ def preprocess(self, x):
235
+ # normalize to [0,1]
236
+ x = F.resize(x, (224, 224), interpolation=F.InterpolationMode.BICUBIC, antialias=self.antialias)
237
+ x = (x + 1.) / 2.
238
+ # renormalize according to clip
239
+ x = F.normalize(x, mean=self.mean, std=self.std)
240
+ return x
241
+
242
+ def freeze(self):
243
+ self.model = self.model.eval()
244
+ for param in self.model.parameters():
245
+ param.requires_grad = False
246
+
247
+ @autocast
248
+ def forward(self, image, no_dropout=False):
249
+ z = self.encode_with_vision_transformer(image)
250
+ if self.ucg_rate > 0. and not no_dropout:
251
+ z = torch.bernoulli((1. - self.ucg_rate) * torch.ones(z.shape[0], device=z.device))[:, None] * z
252
+ return z
253
+
254
+ def encode_with_vision_transformer(self, img):
255
+ img = self.preprocess(img)
256
+ x = self.model.visual(img)
257
+ return x
258
+
259
+ def encode(self, text):
260
+ return self(text)
261
+
262
+ class FrozenOpenCLIPImageEmbedderV2(AbstractEncoder):
263
+ """
264
+ Uses the OpenCLIP vision transformer encoder for images
265
+ """
266
+
267
+ def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k",
268
+ freeze=True, layer="pooled", antialias=True):
269
+ super().__init__()
270
+ model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'),
271
+ pretrained=version, )
272
+ del model.transformer
273
+ self.model = model
274
+ self.preprocess_val = preprocess_val
275
+
276
+ if freeze:
277
+ self.freeze()
278
+ self.layer = layer
279
+ if self.layer == "penultimate":
280
+ raise NotImplementedError()
281
+ self.layer_idx = 1
282
+
283
+ self.antialias = antialias
284
+
285
+ self.register_buffer('mean', torch.Tensor([0.48145466, 0.4578275, 0.40821073]), persistent=False)
286
+ self.register_buffer('std', torch.Tensor([0.26862954, 0.26130258, 0.27577711]), persistent=False)
287
+
288
+
289
+ def preprocess(self, x):
290
+ # normalize to [0,1]
291
+ x = F.resize(x, (224, 224), interpolation=F.InterpolationMode.BICUBIC, antialias=self.antialias)
292
+ x = (x + 1.) / 2.
293
+ # renormalize according to clip
294
+ x = F.normalize(x, mean=self.mean, std=self.std)
295
+ return x
296
+
297
+ def freeze(self):
298
+ self.model = self.model.eval()
299
+ for param in self.model.parameters():
300
+ param.requires_grad = False
301
+
302
+ def forward(self, image, no_dropout=False):
303
+ ## image: b c h w
304
+ z = self.encode_with_vision_transformer(image)
305
+ return z
306
+
307
+ def encode_with_vision_transformer(self, x):
308
+ x = self.preprocess(x)
309
+
310
+ # to patches - whether to use dual patchnorm - https://arxiv.org/abs/2302.01327v1
311
+ if self.model.visual.input_patchnorm:
312
+ # einops - rearrange(x, 'b c (h p1) (w p2) -> b (h w) (c p1 p2)')
313
+ x = x.reshape(x.shape[0], x.shape[1], self.model.visual.grid_size[0], self.model.visual.patch_size[0], self.model.visual.grid_size[1], self.model.visual.patch_size[1])
314
+ x = x.permute(0, 2, 4, 1, 3, 5)
315
+ x = x.reshape(x.shape[0], self.model.visual.grid_size[0] * self.model.visual.grid_size[1], -1)
316
+ x = self.model.visual.patchnorm_pre_ln(x)
317
+ x = self.model.visual.conv1(x)
318
+ else:
319
+ x = self.model.visual.conv1(x) # shape = [*, width, grid, grid]
320
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
321
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
322
+
323
+ # class embeddings and positional embeddings
324
+ x = torch.cat(
325
+ [self.model.visual.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
326
+ x], dim=1) # shape = [*, grid ** 2 + 1, width]
327
+ x = x + self.model.visual.positional_embedding.to(x.dtype)
328
+
329
+ # a patch_dropout of 0. would mean it is disabled and this function would do nothing but return what was passed in
330
+ x = self.model.visual.patch_dropout(x)
331
+ x = self.model.visual.ln_pre(x)
332
+
333
+ x = x.permute(1, 0, 2) # NLD -> LND
334
+ x = self.model.visual.transformer(x)
335
+ x = x.permute(1, 0, 2) # LND -> NLD
336
+
337
+ return x
338
+
339
+ class FrozenCLIPT5Encoder(AbstractEncoder):
340
+ def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl",
341
+ clip_max_length=77, t5_max_length=77):
342
+ super().__init__()
343
+ self.clip_encoder = FrozenCLIPEmbedder(clip_version, max_length=clip_max_length)
344
+ self.t5_encoder = FrozenT5Embedder(t5_version, max_length=t5_max_length)
345
+ print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder) * 1.e-6:.2f} M parameters, "
346
+ f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder) * 1.e-6:.2f} M params.")
347
+
348
+ def encode(self, text):
349
+ return self(text)
350
+
351
+ def forward(self, text):
352
+ clip_z = self.clip_encoder.encode(text)
353
+ t5_z = self.t5_encoder.encode(text)
354
+ return [clip_z, t5_z]
355
+
356
+
357
+ # FFN
358
+ def FeedForward(dim, mult=4):
359
+ inner_dim = int(dim * mult)
360
+ return nn.Sequential(
361
+ nn.LayerNorm(dim),
362
+ nn.Linear(dim, inner_dim, bias=False),
363
+ nn.GELU(),
364
+ nn.Linear(inner_dim, dim, bias=False),
365
+ )
366
+
367
+
368
+ def reshape_tensor(x, heads):
369
+ bs, length, width = x.shape
370
+ #(bs, length, width) --> (bs, length, n_heads, dim_per_head)
371
+ x = x.view(bs, length, heads, -1)
372
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
373
+ x = x.transpose(1, 2)
374
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
375
+ x = x.reshape(bs, heads, length, -1)
376
+ return x
377
+
378
+
379
+ class PerceiverAttention(nn.Module):
380
+ def __init__(self, *, dim, dim_head=64, heads=8):
381
+ super().__init__()
382
+ self.scale = dim_head**-0.5
383
+ self.dim_head = dim_head
384
+ self.heads = heads
385
+ inner_dim = dim_head * heads
386
+
387
+ self.norm1 = nn.LayerNorm(dim)
388
+ self.norm2 = nn.LayerNorm(dim)
389
+
390
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
391
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
392
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
393
+
394
+
395
+ def forward(self, x, latents):
396
+ """
397
+ Args:
398
+ x (torch.Tensor): image features
399
+ shape (b, n1, D)
400
+ latent (torch.Tensor): latent features
401
+ shape (b, n2, D)
402
+ """
403
+ x = self.norm1(x)
404
+ latents = self.norm2(latents)
405
+
406
+ b, l, _ = latents.shape
407
+
408
+ q = self.to_q(latents)
409
+ kv_input = torch.cat((x, latents), dim=-2)
410
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
411
+
412
+ q = reshape_tensor(q, self.heads)
413
+ k = reshape_tensor(k, self.heads)
414
+ v = reshape_tensor(v, self.heads)
415
+
416
+ # attention
417
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
418
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
419
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
420
+ out = weight @ v
421
+
422
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
423
+
424
+ return self.to_out(out)
425
+
426
+
427
+ class Resampler(ModelMixin, ConfigMixin):
428
+ @register_to_config
429
+ def __init__(
430
+ self,
431
+ dim=1024,
432
+ depth=8,
433
+ dim_head=64,
434
+ heads=16,
435
+ num_queries=8,
436
+ embedding_dim=768,
437
+ output_dim=1024,
438
+ ff_mult=4,
439
+ video_length=None, # using frame-wise version or not
440
+ ):
441
+ super().__init__()
442
+ ## queries for a single frame / image
443
+ self.num_queries = num_queries
444
+ self.video_length = video_length
445
+
446
+ ## <num_queries> queries for each frame
447
+ if video_length is not None:
448
+ num_queries = num_queries * video_length
449
+
450
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
451
+ self.proj_in = nn.Linear(embedding_dim, dim)
452
+ self.proj_out = nn.Linear(dim, output_dim)
453
+ self.norm_out = nn.LayerNorm(output_dim)
454
+
455
+ self.layers = nn.ModuleList([])
456
+ for _ in range(depth):
457
+ self.layers.append(
458
+ nn.ModuleList(
459
+ [
460
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
461
+ FeedForward(dim=dim, mult=ff_mult),
462
+ ]
463
+ )
464
+ )
465
+
466
+ def forward(self, x):
467
+ latents = self.latents.repeat(x.size(0), 1, 1) ## B (T L) C
468
+ x = self.proj_in(x)
469
+
470
+ for attn, ff in self.layers:
471
+ latents = attn(x, latents) + latents
472
+ latents = ff(latents) + latents
473
+
474
+ latents = self.proj_out(latents)
475
+ latents = self.norm_out(latents) # B L C or B (T L) C
476
+
477
+ return latents
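The Resampler above compresses per-frame CLIP image tokens into a fixed number of learned query tokens per frame (the frame-wise variant when video_length is set). A minimal sketch with illustrative dimensions; a ViT-H/14 encoder at 224x224 yields 257 tokens of width 1280:

import torch
from lvdm.models.condition import Resampler

resampler = Resampler(dim=1024, depth=4, dim_head=64, heads=12, num_queries=16,
                      embedding_dim=1280, output_dim=1024, ff_mult=4, video_length=16)

clip_tokens = torch.randn(2, 16 * 257, 1280)  # b, (t * n_image_tokens), c
img_context = resampler(clip_tokens)          # -> (2, 16 * 16, 1024): 16 queries per frame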
lvdm/models/controlnet.py ADDED
@@ -0,0 +1,500 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from einops import rearrange, repeat
3
+ import numpy as np
4
+ from functools import partial
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from .unet import TimestepEmbedSequential, ResBlock, Downsample, Upsample, TemporalConvBlock
9
+ from ..basics import zero_module, conv_nd
10
+ from ..modules.attention import SpatialTransformer, TemporalTransformer
11
+ from ..common import checkpoint
12
+
13
+ from diffusers import __version__
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from diffusers.models.model_loading_utils import load_state_dict
18
+ from diffusers.utils import (
19
+ SAFETENSORS_WEIGHTS_NAME,
20
+ WEIGHTS_NAME,
21
+ logging,
22
+ _get_model_file,
23
+ _add_variant
24
+ )
25
+ from omegaconf import ListConfig, DictConfig, OmegaConf
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class ResBlock_v2(nn.Module):
32
+ def __init__(
33
+ self,
34
+ channels,
35
+ emb_channels,
36
+ dropout,
37
+ out_channels=None,
38
+ dims=2,
39
+ use_checkpoint=False,
40
+ use_conv=False,
41
+ up=False,
42
+ down=False,
43
+ use_temporal_conv=False,
44
+ tempspatial_aware=False
45
+ ):
46
+ super().__init__()
47
+ self.channels = channels
48
+ self.emb_channels = emb_channels
49
+ self.dropout = dropout
50
+ self.out_channels = out_channels or channels
51
+ self.use_conv = use_conv
52
+ self.use_checkpoint = use_checkpoint
53
+ self.use_temporal_conv = use_temporal_conv
54
+
55
+ self.in_layers = nn.Sequential(
56
+ nn.GroupNorm(32, channels),
57
+ nn.SiLU(),
58
+ zero_module(conv_nd(dims, channels, self.out_channels, 3, padding=1)),
59
+ )
60
+
61
+ self.updown = up or down
62
+
63
+ if up:
64
+ self.h_upd = Upsample(channels, False, dims)
65
+ self.x_upd = Upsample(channels, False, dims)
66
+ elif down:
67
+ self.h_upd = Downsample(channels, False, dims)
68
+ self.x_upd = Downsample(channels, False, dims)
69
+ else:
70
+ self.h_upd = self.x_upd = nn.Identity()
71
+
72
+ if self.out_channels == channels:
73
+ self.skip_connection = nn.Identity()
74
+ elif use_conv:
75
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 3, padding=1)
76
+ else:
77
+ self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
78
+
79
+ if self.use_temporal_conv:
80
+ self.temopral_conv = TemporalConvBlock(
81
+ self.out_channels,
82
+ self.out_channels,
83
+ dropout=0.1,
84
+ spatial_aware=tempspatial_aware
85
+ )
86
+
87
+ def forward(self, x, batch_size=None):
88
+ """
89
+ Apply the block to a Tensor, conditioned on a timestep embedding.
90
+ :param x: an [N x C x ...] Tensor of features.
91
+ :return: an [N x C x ...] Tensor of outputs.
92
+ """
93
+ input_tuple = (x, )
94
+ if batch_size:
95
+ forward_batchsize = partial(self._forward, batch_size=batch_size)
96
+ return checkpoint(forward_batchsize, input_tuple, self.parameters(), self.use_checkpoint)
97
+ return checkpoint(self._forward, input_tuple, self.parameters(), self.use_checkpoint)
98
+
99
+ def _forward(self, x, batch_size=None):
100
+ if self.updown:
101
+ in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
102
+ h = in_rest(x)
103
+ h = self.h_upd(h)
104
+ x = self.x_upd(x)
105
+ h = in_conv(h)
106
+ else:
107
+ h = self.in_layers(x)
108
+ h = self.skip_connection(x) + h
109
+
110
+ if self.use_temporal_conv and batch_size:
111
+ h = rearrange(h, '(b t) c h w -> b c t h w', b=batch_size)
112
+ h = self.temopral_conv(h)
113
+ h = rearrange(h, 'b c t h w -> (b t) c h w')
114
+ return h
115
+
116
+
117
+ class TrajectoryEncoder(nn.Module):
118
+ def __init__(self, cin, time_embed_dim, channels=[320, 640, 1280, 1280], nums_rb=3,
119
+ dropout=0.0, use_checkpoint=False, tempspatial_aware=False, temporal_conv=False):
120
+ super(TrajectoryEncoder, self).__init__()
121
+ # self.unshuffle = nn.PixelUnshuffle(8)
122
+ self.channels = channels
123
+ self.nums_rb = nums_rb
124
+ self.body = []
125
+ # self.conv_out = []
126
+ for i in range(len(channels)):
127
+ for j in range(nums_rb):
128
+ if (i != 0) and (j == 0):
129
+ self.body.append(
130
+ ResBlock_v2(channels[i - 1], time_embed_dim, dropout,
131
+ out_channels=channels[i], dims=2, use_checkpoint=use_checkpoint,
132
+ tempspatial_aware=tempspatial_aware,
133
+ use_temporal_conv=temporal_conv,
134
+ down=True
135
+ )
136
+ )
137
+ else:
138
+ self.body.append(
139
+ ResBlock_v2(channels[i], time_embed_dim, dropout,
140
+ out_channels=channels[i], dims=2, use_checkpoint=use_checkpoint,
141
+ tempspatial_aware=tempspatial_aware,
142
+ use_temporal_conv=temporal_conv,
143
+ down=False
144
+ )
145
+ )
146
+ self.body.append(
147
+ ResBlock_v2(channels[-1], time_embed_dim, dropout,
148
+ out_channels=channels[-1], dims=2, use_checkpoint=use_checkpoint,
149
+ tempspatial_aware=tempspatial_aware,
150
+ use_temporal_conv=temporal_conv,
151
+ down=True
152
+ )
153
+ )
154
+ self.body = nn.ModuleList(self.body)
155
+ self.conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
156
+ self.conv_out = zero_module(conv_nd(2, channels[-1], channels[-1], 3, 1, 1))
157
+
158
+ def forward(self, x, batch_size=None):
159
+ # unshuffle
160
+ # x = self.unshuffle(x)
161
+ # extract features
162
+ # features = []
163
+ x = self.conv_in(x)
164
+ for i in range(len(self.channels)):
165
+ for j in range(self.nums_rb):
166
+ idx = i * self.nums_rb + j
167
+ x = self.body[idx](x, batch_size)
168
+ x = self.body[-1](x, batch_size)
169
+ out = self.conv_out(x)
170
+ return out
171
+
172
+
173
+ class ControlNet(ModelMixin, ConfigMixin):
174
+ _supports_gradient_checkpointing = True
175
+
176
+ @register_to_config
177
+ def __init__(
178
+ self,
179
+ in_channels,
180
+ model_channels,
181
+ out_channels,
182
+ num_res_blocks,
183
+ attention_resolutions,
184
+ dropout=0.0,
185
+ channel_mult=(1, 2, 4, 8),
186
+ conv_resample=True,
187
+ dims=2,
188
+ context_dim=None,
189
+ use_scale_shift_norm=False,
190
+ resblock_updown=False,
191
+ num_heads=-1,
192
+ num_head_channels=-1,
193
+ transformer_depth=1,
194
+ use_linear=False,
195
+ use_checkpoint=False,
196
+ temporal_conv=False,
197
+ tempspatial_aware=False,
198
+ temporal_attention=True,
199
+ use_relative_position=True,
200
+ use_causal_attention=False,
201
+ temporal_length=None,
202
+ addition_attention=False,
203
+ temporal_selfatt_only=True,
204
+ image_cross_attention=False,
205
+ image_cross_attention_scale_learnable=False,
206
+ default_fps=4,
207
+ fps_condition=False,
208
+ ignore_noisy_latents=True,
209
+ conditioning_channels=4,
210
+ ):
211
+ super().__init__()
212
+ if num_heads == -1:
213
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
214
+ if num_head_channels == -1:
215
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
216
+
217
+ self.in_channels = in_channels
218
+ self.model_channels = model_channels
219
+ self.out_channels = out_channels
220
+ self.num_res_blocks = num_res_blocks
221
+ self.attention_resolutions = attention_resolutions
222
+ self.dropout = dropout
223
+ self.channel_mult = channel_mult
224
+ self.conv_resample = conv_resample
225
+ self.temporal_attention = temporal_attention
226
+ time_embed_dim = model_channels * 4
227
+ self.use_checkpoint = use_checkpoint
228
+ temporal_self_att_only = True
229
+ self.addition_attention = addition_attention
230
+ self.temporal_length = temporal_length
231
+ self.image_cross_attention = image_cross_attention
232
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
233
+ self.default_fps = default_fps
234
+ self.fps_condition = fps_condition
235
+ self.ignore_noisy_latents = ignore_noisy_latents
236
+
237
+ ## Time embedding blocks
238
+ self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
239
+ self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
240
+
241
+ if fps_condition:
242
+ self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
243
+ nn.init.zeros_(self.fps_embedding.linear_2.weight)
244
+ nn.init.zeros_(self.fps_embedding.linear_2.bias)
245
+
246
+ # self.cond_embedding = TrajectoryEncoder(
247
+ # cin=conditioning_channels, time_embed_dim=time_embed_dim, channels=trajectory_channels, nums_rb=3,
248
+ # dropout=dropout, use_checkpoint=use_checkpoint, tempspatial_aware=tempspatial_aware, temporal_conv=False
249
+ # )
250
+ self.cond_embedding = zero_module(conv_nd(dims, conditioning_channels, model_channels, 3, padding=1))
251
+ self.input_blocks = nn.ModuleList(
252
+ [
253
+ TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
254
+ ]
255
+ )
256
+
257
+ ## Output Block
258
+ self.downsample_output = nn.ModuleList(
259
+ [
260
+ nn.Sequential(
261
+ nn.GroupNorm(32, model_channels),
262
+ nn.SiLU(),
263
+ zero_module(conv_nd(dims, model_channels, model_channels, 3, padding=1))
264
+ )
265
+ ]
266
+ )
267
+
268
+ if self.addition_attention:
269
+ self.init_attn = TimestepEmbedSequential(
270
+ TemporalTransformer(
271
+ model_channels,
272
+ n_heads=8,
273
+ d_head=num_head_channels,
274
+ depth=transformer_depth,
275
+ context_dim=context_dim,
276
+ use_checkpoint=use_checkpoint, only_self_att=temporal_selfatt_only,
277
+ causal_attention=False, relative_position=use_relative_position,
278
+ temporal_length=temporal_length
279
+ )
280
+ )
281
+
282
+ ch = model_channels
283
+ ds = 1
284
+ for level, mult in enumerate(channel_mult):
285
+ for _ in range(num_res_blocks):
286
+ layers = [
287
+ ResBlock(ch, time_embed_dim, dropout,
288
+ out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
289
+ use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
290
+ use_temporal_conv=temporal_conv
291
+ )
292
+ ]
293
+ ch = mult * model_channels
294
+ if ds in attention_resolutions:
295
+ if num_head_channels == -1:
296
+ dim_head = ch // num_heads
297
+ else:
298
+ num_heads = ch // num_head_channels
299
+ dim_head = num_head_channels
300
+ layers.append(
301
+ SpatialTransformer(ch, num_heads, dim_head,
302
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
303
+ use_checkpoint=use_checkpoint, disable_self_attn=False,
304
+ video_length=temporal_length, image_cross_attention=self.image_cross_attention,
305
+ image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
306
+ )
307
+ )
308
+ if self.temporal_attention:
309
+ layers.append(
310
+ TemporalTransformer(ch, num_heads, dim_head,
311
+ depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
312
+ use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
313
+ causal_attention=use_causal_attention, relative_position=use_relative_position,
314
+ temporal_length=temporal_length
315
+ )
316
+ )
317
+ self.input_blocks.append(TimestepEmbedSequential(*layers))
318
+ self.downsample_output.append(
319
+ nn.Sequential(
320
+ nn.GroupNorm(32, ch),
321
+ nn.SiLU(),
322
+ zero_module(conv_nd(dims, ch, ch, 3, padding=1))
323
+ )
324
+ )
325
+ if level < len(channel_mult) - 1:
326
+ out_ch = ch
327
+ self.input_blocks.append(
328
+ TimestepEmbedSequential(
329
+ ResBlock(ch, time_embed_dim, dropout,
330
+ out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
331
+ use_scale_shift_norm=use_scale_shift_norm,
332
+ down=True
333
+ )
334
+ if resblock_updown
335
+ else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
336
+ )
337
+ )
338
+ self.downsample_output.append(
339
+ nn.Sequential(
340
+ nn.GroupNorm(32, out_ch),
341
+ nn.SiLU(),
342
+ zero_module(conv_nd(dims, out_ch, out_ch, 3, padding=1))
343
+ )
344
+ )
345
+ ch = out_ch
346
+ ds *= 2
347
+
348
+ def forward(
349
+ self,
350
+ noisy_latents,
351
+ timesteps,
352
+ context_text,
353
+ context_img=None,
354
+ fps=None,
355
+ condition=None, # [b, t, c, h, w]
356
+ ):
357
+ if self.ignore_noisy_latents:
358
+ noisy_latents = torch.zeros_like(noisy_latents)
359
+
360
+ b, _, t, height, width = noisy_latents.shape
361
+ t_emb = self.time_proj(timesteps).type(noisy_latents.dtype)
362
+ emb = self.time_embed(t_emb)
363
+
364
+ ## repeat t times for context [(b t) 77 768] & time embedding
365
+ ## check if we use per-frame image conditioning
366
+ if context_img is not None: ## decompose context into text and image
367
+ context_text = context_text.repeat_interleave(repeats=t, dim=0)
368
+ context_img = rearrange(context_img, 'b (t l) c -> (b t) l c', t=t)
369
+ context = torch.cat([context_text, context_img], dim=1)
370
+ else:
371
+ context = context_text.repeat_interleave(repeats=t, dim=0)
372
+ emb = emb.repeat_interleave(repeats=t, dim=0)
373
+
374
+ ## always in shape (b n t) c h w, except for temporal layer
375
+ noisy_latents = rearrange(noisy_latents, 'b c t h w -> (b t) c h w')
376
+ condition = rearrange(condition, 'b t c h w -> (b t) c h w')
377
+
378
+ ## combine emb
379
+ if self.fps_condition:
380
+ if fps is None:
381
+ fps = torch.tensor(
382
+ [self.default_fps] * b, dtype=torch.long, device=noisy_latents.device)
383
+ fps_emb = self.time_proj(fps).type(noisy_latents.dtype)
384
+
385
+ fps_embed = self.fps_embedding(fps_emb)
386
+ fps_embed = fps_embed.repeat_interleave(repeats=t, dim=0)
387
+ emb = emb + fps_embed
388
+
389
+ h = noisy_latents.type(self.dtype)
390
+ hs = []
391
+ for id, module in enumerate(self.input_blocks):
392
+ h = module(h, emb, context=context, batch_size=b)
393
+ if id == 0:
394
+ h = h + self.cond_embedding(condition)
395
+ if self.addition_attention:
396
+ h = self.init_attn(h, emb, context=context, batch_size=b)
397
+ hs.append(h)
398
+
399
+ guidance_feature_list = []
400
+ for hidden, module in zip(hs, self.downsample_output):
401
+ h = module(hidden)
402
+ guidance_feature_list.append(h)
403
+
404
+ return guidance_feature_list
405
+
406
+ @classmethod
407
+ def from_pretrained(cls, pretrained_model_name_or_path, layer_encoder_additional_kwargs={}, **kwargs):
408
+ cache_dir = kwargs.pop("cache_dir", None)
409
+ force_download = kwargs.pop("force_download", False)
410
+ proxies = kwargs.pop("proxies", None)
411
+ local_files_only = kwargs.pop("local_files_only", None)
412
+ token = kwargs.pop("token", None)
413
+ revision = kwargs.pop("revision", None)
414
+ subfolder = kwargs.pop("subfolder", None)
415
+ variant = kwargs.pop("variant", None)
416
+ use_safetensors = kwargs.pop("use_safetensors", None)
417
+
418
+ allow_pickle = False
419
+ if use_safetensors is None:
420
+ use_safetensors = True
421
+ allow_pickle = True
422
+
423
+ # Load config if we don't provide a configuration
424
+ config_path = pretrained_model_name_or_path
425
+
426
+ user_agent = {
427
+ "diffusers": __version__,
428
+ "file_type": "model",
429
+ "framework": "pytorch",
430
+ }
431
+
432
+ # load config
433
+ config, unused_kwargs, commit_hash = cls.load_config(
434
+ config_path,
435
+ cache_dir=cache_dir,
436
+ return_unused_kwargs=True,
437
+ return_commit_hash=True,
438
+ force_download=force_download,
439
+ proxies=proxies,
440
+ local_files_only=local_files_only,
441
+ token=token,
442
+ revision=revision,
443
+ subfolder=subfolder,
444
+ user_agent=user_agent,
445
+ **kwargs,
446
+ )
447
+
448
+ for key, value in layer_encoder_additional_kwargs.items():
449
+ if isinstance(value, (ListConfig, DictConfig)):
450
+ config[key] = OmegaConf.to_container(value, resolve=True)
451
+ else:
452
+ config[key] = value
453
+
454
+ # load model
455
+ model_file = None
456
+ if use_safetensors:
457
+ try:
458
+ model_file = _get_model_file(
459
+ pretrained_model_name_or_path,
460
+ weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
461
+ cache_dir=cache_dir,
462
+ force_download=force_download,
463
+ proxies=proxies,
464
+ local_files_only=local_files_only,
465
+ token=token,
466
+ revision=revision,
467
+ subfolder=subfolder,
468
+ user_agent=user_agent,
469
+ commit_hash=commit_hash,
470
+ )
471
+
472
+ except IOError as e:
473
+ logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
474
+ if not allow_pickle:
475
+ raise
476
+ logger.warning(
477
+ "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
478
+ )
479
+
480
+ if model_file is None:
481
+ model_file = _get_model_file(
482
+ pretrained_model_name_or_path,
483
+ weights_name=_add_variant(WEIGHTS_NAME, variant),
484
+ cache_dir=cache_dir,
485
+ force_download=force_download,
486
+ proxies=proxies,
487
+ local_files_only=local_files_only,
488
+ token=token,
489
+ revision=revision,
490
+ subfolder=subfolder,
491
+ user_agent=user_agent,
492
+ commit_hash=commit_hash,
493
+ )
494
+
495
+ model = cls.from_config(config, **unused_kwargs)
496
+ state_dict = load_state_dict(model_file, variant)
497
+
498
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
499
+ print(f"Controlnet loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
500
+ return model
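ControlNet.forward above flattens frames into the batch dimension for the spatial layers and folds them back (via the known batch size) for the temporal ones; the same convention is used in the layer controlnet below. A minimal standalone sketch of that rearrange pattern:

import torch
from einops import rearrange

b, c, t, h, w = 1, 4, 16, 40, 64
noisy_latents = torch.randn(b, c, t, h, w)
condition = torch.randn(b, t, c, h, w)

latents_2d = rearrange(noisy_latents, 'b c t h w -> (b t) c h w')   # frames become batch entries
cond_2d = rearrange(condition, 'b t c h w -> (b t) c h w')          # matches cond_embedding's 2D conv
latents_3d = rearrange(latents_2d, '(b t) c h w -> b c t h w', b=b) # temporal layers fold back
assert torch.equal(latents_3d, noisy_latents)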
lvdm/models/layer_controlnet.py ADDED
@@ -0,0 +1,444 @@
1
+ from typing import Any, Dict, List, Optional, Tuple, Union
2
+ from einops import rearrange, repeat
3
+ import numpy as np
4
+ from functools import partial
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from .unet import TimestepEmbedSequential, ResBlock, Downsample, Upsample, TemporalConvBlock
9
+ from ..basics import zero_module, conv_nd
10
+ from ..modules.attention import SpatialTransformer, TemporalTransformer
11
+ from ..common import checkpoint
12
+
13
+ from diffusers import __version__
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
17
+ from diffusers.models.model_loading_utils import load_state_dict
18
+ from diffusers.utils import (
19
+ SAFETENSORS_WEIGHTS_NAME,
20
+ WEIGHTS_NAME,
21
+ logging,
22
+ _get_model_file,
23
+ _add_variant
24
+ )
25
+ from omegaconf import ListConfig, DictConfig, OmegaConf
26
+
27
+
28
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
29
+
30
+
31
+ class ControlNetConditioningEmbedding(nn.Module):
32
+ """
33
+ Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
34
+ [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
35
+ training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
36
+ convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
37
+ (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
38
+ model) to encode image-space conditions ... into feature maps ..."
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ conditioning_embedding_channels: int,
44
+ conditioning_channels: int = 3,
45
+ block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
46
+ ):
47
+ super().__init__()
48
+
49
+ self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
50
+
51
+ self.blocks = nn.ModuleList([])
52
+
53
+ for i in range(len(block_out_channels) - 1):
54
+ channel_in = block_out_channels[i]
55
+ channel_out = block_out_channels[i + 1]
56
+ self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
57
+ self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
58
+
59
+ self.conv_out = zero_module(
60
+ nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
61
+ )
62
+
63
+ def forward(self, conditioning):
64
+ embedding = self.conv_in(conditioning)
65
+ embedding = F.silu(embedding)
66
+
67
+ for block in self.blocks:
68
+ embedding = block(embedding)
69
+ embedding = F.silu(embedding)
70
+
71
+ embedding = self.conv_out(embedding)
72
+
73
+ return embedding
74
+
75
+
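The encoder above follows the ControlNet recipe quoted in the docstring: three stride-2 convolutions reduce image-space conditions by 8x so they match the latent resolution, and the zeroed output conv keeps the branch silent at initialization. A shape sketch, assuming the default block_out_channels and a 320x512 input:

import torch
from lvdm.models.layer_controlnet import ControlNetConditioningEmbedding

embed = ControlNetConditioningEmbedding(conditioning_embedding_channels=320, conditioning_channels=3)
traj_map = torch.randn(16, 3, 320, 512)   # (b * n_layer * t) c H W image-space condition
feat = embed(traj_map)                    # -> (16, 320, 40, 64); exactly zero before training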
76
+ class LayerControlNet(ModelMixin, ConfigMixin):
77
+ _supports_gradient_checkpointing = True
78
+
79
+ @register_to_config
80
+ def __init__(
81
+ self,
82
+ in_channels,
83
+ model_channels,
84
+ out_channels,
85
+ num_res_blocks,
86
+ attention_resolutions,
87
+ dropout=0.0,
88
+ channel_mult=(1, 2, 4, 8),
89
+ conv_resample=True,
90
+ dims=2,
91
+ context_dim=None,
92
+ use_scale_shift_norm=False,
93
+ resblock_updown=False,
94
+ num_heads=-1,
95
+ num_head_channels=-1,
96
+ transformer_depth=1,
97
+ use_linear=False,
98
+ use_checkpoint=False,
99
+ temporal_conv=False,
100
+ tempspatial_aware=False,
101
+ temporal_attention=True,
102
+ use_relative_position=True,
103
+ use_causal_attention=False,
104
+ temporal_length=None,
105
+ addition_attention=False,
106
+ temporal_selfatt_only=True,
107
+ image_cross_attention=False,
108
+ image_cross_attention_scale_learnable=False,
109
+ default_fps=4,
110
+ fps_condition=False,
111
+ ignore_noisy_latents=True,
112
+ condition_channels={},
113
+ control_injection_mode='add',
114
+ use_vae_for_trajectory=False,
115
+ ):
116
+ super().__init__()
117
+ if num_heads == -1:
118
+ assert num_head_channels != -1, 'Either num_heads or num_head_channels has to be set'
119
+ if num_head_channels == -1:
120
+ assert num_heads != -1, 'Either num_heads or num_head_channels has to be set'
121
+
122
+ self.in_channels = in_channels
123
+ self.model_channels = model_channels
124
+ self.out_channels = out_channels
125
+ self.num_res_blocks = num_res_blocks
126
+ self.attention_resolutions = attention_resolutions
127
+ self.dropout = dropout
128
+ self.channel_mult = channel_mult
129
+ self.conv_resample = conv_resample
130
+ self.temporal_attention = temporal_attention
131
+ time_embed_dim = model_channels * 4
132
+ self.use_checkpoint = use_checkpoint
133
+ temporal_self_att_only = True
134
+ self.addition_attention = addition_attention
135
+ self.temporal_length = temporal_length
136
+ self.image_cross_attention = image_cross_attention
137
+ self.image_cross_attention_scale_learnable = image_cross_attention_scale_learnable
138
+ self.default_fps = default_fps
139
+ self.fps_condition = fps_condition
140
+ self.ignore_noisy_latents = ignore_noisy_latents
141
+ assert len(condition_channels) > 0, 'Condition types must be specified'
142
+ self.condition_channels = condition_channels
143
+ self.control_injection_mode = control_injection_mode
144
+ self.use_vae_for_trajectory = use_vae_for_trajectory
145
+
146
+ ## Time embedding blocks
147
+ self.time_proj = Timesteps(model_channels, flip_sin_to_cos=True, downscale_freq_shift=0)
148
+ self.time_embed = TimestepEmbedding(model_channels, time_embed_dim)
149
+
150
+ if fps_condition:
151
+ self.fps_embedding = TimestepEmbedding(model_channels, time_embed_dim)
152
+ nn.init.zeros_(self.fps_embedding.linear_2.weight)
153
+ nn.init.zeros_(self.fps_embedding.linear_2.bias)
154
+
155
+ if "motion_score" in condition_channels:
156
+ if control_injection_mode == 'add':
157
+ self.motion_embedding = zero_module(conv_nd(dims, condition_channels["motion_score"], model_channels, 3, padding=1))
158
+ elif control_injection_mode == 'concat':
159
+ self.motion_embedding = zero_module(conv_nd(dims, condition_channels["motion_score"], condition_channels["motion_score"], 3, padding=1))
160
+ else:
161
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
162
+ if "sketch" in condition_channels:
163
+ if control_injection_mode == 'add':
164
+ self.sketch_embedding = zero_module(conv_nd(dims, condition_channels["sketch"], model_channels, 3, padding=1))
165
+ elif control_injection_mode == 'concat':
166
+ self.sketch_embedding = zero_module(conv_nd(dims, condition_channels["sketch"], condition_channels["sketch"], 3, padding=1))
167
+ else:
168
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
169
+ if "trajectory" in condition_channels:
170
+ if control_injection_mode == 'add':
171
+ if use_vae_for_trajectory:
172
+ self.trajectory_embedding = zero_module(conv_nd(dims, condition_channels["trajectory"], model_channels, 3, padding=1))
173
+ else:
174
+ self.trajectory_embedding = ControlNetConditioningEmbedding(model_channels, condition_channels["trajectory"])
175
+ elif control_injection_mode == 'concat':
176
+ if use_vae_for_trajectory:
177
+ self.trajectory_embedding = zero_module(conv_nd(dims, condition_channels["trajectory"], condition_channels["trajectory"], 3, padding=1))
178
+ else:
179
+ self.trajectory_embedding = ControlNetConditioningEmbedding(condition_channels["trajectory"], condition_channels["trajectory"])
180
+ else:
181
+ raise ValueError(f"control_injection_mode {control_injection_mode} is not supported, use 'add' or 'concat'")
182
+
183
+        self.input_blocks = nn.ModuleList(
+            [
+                TimestepEmbedSequential(conv_nd(dims, in_channels, model_channels, 3, padding=1))
+            ]
+        )
+
+        if self.addition_attention:
+            self.init_attn = TimestepEmbedSequential(
+                TemporalTransformer(
+                    model_channels,
+                    n_heads=8,
+                    d_head=num_head_channels,
+                    depth=transformer_depth,
+                    context_dim=context_dim,
+                    use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+                    causal_attention=False, relative_position=use_relative_position,
+                    temporal_length=temporal_length
+                )
+            )
+
+        ch = model_channels
+        ds = 1  # cumulative downsampling factor, compared against attention_resolutions
+        for level, mult in enumerate(channel_mult):
+            for _ in range(num_res_blocks):
+                layers = [
+                    ResBlock(ch, time_embed_dim, dropout,
+                        out_channels=mult * model_channels, dims=dims, use_checkpoint=use_checkpoint,
+                        use_scale_shift_norm=use_scale_shift_norm, tempspatial_aware=tempspatial_aware,
+                        use_temporal_conv=temporal_conv
+                    )
+                ]
+                ch = mult * model_channels
+                if ds in attention_resolutions:
+                    if num_head_channels == -1:
+                        dim_head = ch // num_heads
+                    else:
+                        num_heads = ch // num_head_channels
+                        dim_head = num_head_channels
+                    layers.append(
+                        SpatialTransformer(ch, num_heads, dim_head,
+                            depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+                            use_checkpoint=use_checkpoint, disable_self_attn=False,
+                            video_length=temporal_length, image_cross_attention=self.image_cross_attention,
+                            image_cross_attention_scale_learnable=self.image_cross_attention_scale_learnable,
+                        )
+                    )
+                    if self.temporal_attention:
+                        layers.append(
+                            TemporalTransformer(ch, num_heads, dim_head,
+                                depth=transformer_depth, context_dim=context_dim, use_linear=use_linear,
+                                use_checkpoint=use_checkpoint, only_self_att=temporal_self_att_only,
+                                causal_attention=use_causal_attention, relative_position=use_relative_position,
+                                temporal_length=temporal_length
+                            )
+                        )
+                self.input_blocks.append(TimestepEmbedSequential(*layers))
+
+            if level < len(channel_mult) - 1:
+                out_ch = ch
+                self.input_blocks.append(
+                    TimestepEmbedSequential(
+                        ResBlock(ch, time_embed_dim, dropout,
+                            out_channels=out_ch, dims=dims, use_checkpoint=use_checkpoint,
+                            use_scale_shift_norm=use_scale_shift_norm,
+                            down=True
+                        )
+                        if resblock_updown
+                        else Downsample(ch, conv_resample, dims=dims, out_channels=out_ch)
+                    )
+                )
+                ch = out_ch
+                ds *= 2
+
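+        # Layout sketch (hypothetical hyperparameters, not fixed by this file): with
+        # channel_mult=(1, 2, 4, 4) and num_res_blocks=2 the list would hold 1 input conv
+        # + 4 * 2 res-block stages + 3 downsamplers = 12 input blocks; only the blocks that
+        # contain a SpatialTransformer later contribute layer features in forward().
+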
+    def forward(
+        self,
+        noisy_latents,
+        timesteps,
+        context_text,
+        context_img=None,
+        fps=None,
+        layer_latents=None,       # [b, n_layer, t, c, h, w]
+        layer_latent_mask=None,   # [b, n_layer, t, 1, h, w]
+        motion_scores=None,       # [b, n_layer]
+        sketch=None,              # [b, n_layer, t, c, h, w]
+        trajectory=None,          # [b, n_layer, t, c, h, w]
+    ):
+        if self.ignore_noisy_latents:
+            # keep the tensor shape but drop its channels so the noisy latents carry no signal
+            noisy_latents_shape = list(noisy_latents.shape)
+            noisy_latents_shape[1] = 0
+            noisy_latents = torch.zeros(noisy_latents_shape, device=noisy_latents.device, dtype=noisy_latents.dtype)
+
+        b, _, t, height, width = noisy_latents.shape
+        n_layer = layer_latents.shape[1]
+        t_emb = self.time_proj(timesteps).type(noisy_latents.dtype)
+        emb = self.time_embed(t_emb)
+
+        ## repeat t times for context [(b t) 77 768] & time embedding
+        ## check if we use per-frame image conditioning
+        if context_img is not None: ## decompose context into text and image
+            context_text = repeat(context_text, 'b l c -> (b n t) l c', n=n_layer, t=t)
+            context_img = repeat(context_img, 'b tl c -> b n tl c', n=n_layer)
+            context_img = rearrange(context_img, 'b n (t l) c -> (b n t) l c', t=t)
+            context = torch.cat([context_text, context_img], dim=1)
+        else:
+            context = repeat(context_text, 'b l c -> (b n t) l c', n=n_layer, t=t)
+        emb = repeat(emb, 'b c -> (b n t) c', n=n_layer, t=t)
+
+        ## always in shape (b n t) c h w, except for temporal layer
+        noisy_latents = repeat(noisy_latents, 'b c t h w -> (b n t) c h w', n=n_layer)
+
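+        # Shape sketch (sizes are hypothetical): with b=1, n_layer=3, t=16 and 40x40 latents,
+        # context and emb are tiled to a leading dim of 1*3*16 = 48, and noisy_latents
+        # becomes [48, c, 40, 40].
+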
+        ## combine emb
+        if self.fps_condition:
+            if fps is None:
+                fps = torch.tensor(
+                    [self.default_fps] * b, dtype=torch.long, device=noisy_latents.device)
+            fps_emb = self.time_proj(fps).type(noisy_latents.dtype)
+
+            fps_embed = self.fps_embedding(fps_emb)
+            fps_embed = repeat(fps_embed, 'b c -> (b n t) c', n=n_layer, t=t)
+            emb = emb + fps_embed
+
+        ## process conditions
+        layer_condition = torch.cat([layer_latents, layer_latent_mask], dim=3)
+        layer_condition = rearrange(layer_condition, 'b n t c h w -> (b n t) c h w')
+        h = torch.cat([noisy_latents, layer_condition], dim=1)
+
+        if "motion_score" in self.condition_channels:
+            motion_condition = repeat(motion_scores, 'b n -> b n t 1 h w', t=t, h=height, w=width)
+            motion_condition = torch.cat([motion_condition, layer_latent_mask], dim=3)
+            motion_condition = rearrange(motion_condition, 'b n t c h w -> (b n t) c h w')
+            motion_condition = self.motion_embedding(motion_condition)
+            if self.control_injection_mode == 'concat':
+                h = torch.cat([h, motion_condition], dim=1)
+
+        if "sketch" in self.condition_channels:
+            sketch_condition = rearrange(sketch, 'b n t c h w -> (b n t) c h w')
+            sketch_condition = self.sketch_embedding(sketch_condition)
+            if self.control_injection_mode == 'concat':
+                h = torch.cat([h, sketch_condition], dim=1)
+
+        if "trajectory" in self.condition_channels:
+            traj_condition = rearrange(trajectory, 'b n t c h w -> (b n t) c h w')
+            traj_condition = self.trajectory_embedding(traj_condition)
+            if self.control_injection_mode == 'concat':
+                h = torch.cat([h, traj_condition], dim=1)
+
+        layer_features = []
+        for id, module in enumerate(self.input_blocks):
+            h = module(h, emb, context=context, batch_size=b*n_layer)
+            if id == 0:
+                if self.control_injection_mode == 'add':
+                    if "motion_score" in self.condition_channels:
+                        h = h + motion_condition
+                    if "sketch" in self.condition_channels:
+                        h = h + sketch_condition
+                    if "trajectory" in self.condition_channels:
+                        h = h + traj_condition
+                if self.addition_attention:
+                    h = self.init_attn(h, emb, context=context, batch_size=b*n_layer)
+            # only blocks that contain a SpatialTransformer contribute per-layer features
+            if SpatialTransformer in [type(m) for m in module]:
+                layer_features.append(rearrange(h, '(b n t) c h w -> b n t c h w', b=b, n=n_layer))
+
+        return layer_features
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, layer_controlnet_additional_kwargs={}, **kwargs):
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        variant = kwargs.pop("variant", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        # Load config if we don't provide a configuration
+        config_path = pretrained_model_name_or_path
+
+        user_agent = {
+            "diffusers": __version__,
+            "file_type": "model",
+            "framework": "pytorch",
+        }
+
+        # load config
+        config, unused_kwargs, commit_hash = cls.load_config(
+            config_path,
+            cache_dir=cache_dir,
+            return_unused_kwargs=True,
+            return_commit_hash=True,
+            force_download=force_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            **kwargs,
+        )
+
+        # override config entries with any additional kwargs (OmegaConf nodes are converted
+        # to plain containers first)
+        for key, value in layer_controlnet_additional_kwargs.items():
+            if isinstance(value, (ListConfig, DictConfig)):
+                config[key] = OmegaConf.to_container(value, resolve=True)
+            else:
+                config[key] = value
+
+        # load model
+        model_file = None
+        if use_safetensors:
+            try:
+                model_file = _get_model_file(
+                    pretrained_model_name_or_path,
+                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
+                    cache_dir=cache_dir,
+                    force_download=force_download,
+                    proxies=proxies,
+                    local_files_only=local_files_only,
+                    token=token,
+                    revision=revision,
+                    subfolder=subfolder,
+                    user_agent=user_agent,
+                    commit_hash=commit_hash,
+                )
+
+            except IOError as e:
+                logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
+                if not allow_pickle:
+                    raise
+                logger.warning(
+                    "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                )
+
+        if model_file is None:
+            model_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=_add_variant(WEIGHTS_NAME, variant),
+                cache_dir=cache_dir,
+                force_download=force_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+
+        model = cls.from_config(config, **unused_kwargs)
+        state_dict = load_state_dict(model_file, variant)
+
+        # the first conv's input width depends on the configured conditions; drop its
+        # pretrained weight when the channel count differs so the freshly initialized
+        # weight is kept by the strict=False load below
+        if state_dict['input_blocks.0.0.weight'].shape[1] != model.input_blocks[0][0].weight.shape[1]:
+            state_dict.pop('input_blocks.0.0.weight')
+
+        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
+        print(f"LayerControlNet loaded from {model_file} with {len(missing_keys)} missing keys and {len(unexpected_keys)} unexpected keys.")
+        return model
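+
+# Minimal loading sketch (illustrative only; the path and subfolder below are hypothetical
+# placeholders, not values shipped with this commit):
+#
+#     controlnet = LayerControlNet.from_pretrained(
+#         "path/to/LayerAnimate-checkpoint",   # hypothetical local dir or Hub repo id
+#         subfolder="layer_controlnet",        # hypothetical subfolder holding config + weights
+#     )
+#     controlnet.eval()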