Fabrice-TIERCELIN committed on
Commit aa24895 · verified · 1 Parent(s): 9eb54a4

Upload 3 files

hyvideo/vae/autoencoder_kl_causal_3d.py ADDED
@@ -0,0 +1,603 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+ from typing import Dict, Optional, Tuple, Union
20
+ from dataclasses import dataclass
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+
25
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
26
+
27
+ try:
28
+ # This diffusers is modified and packed in the mirror.
29
+ from diffusers.loaders import FromOriginalVAEMixin
30
+ except ImportError:
31
+ # Use this to be compatible with the original diffusers.
32
+ from diffusers.loaders.single_file_model import FromOriginalModelMixin as FromOriginalVAEMixin
33
+ from diffusers.utils.accelerate_utils import apply_forward_hook
34
+ from diffusers.models.attention_processor import (
35
+ ADDED_KV_ATTENTION_PROCESSORS,
36
+ CROSS_ATTENTION_PROCESSORS,
37
+ Attention,
38
+ AttentionProcessor,
39
+ AttnAddedKVProcessor,
40
+ AttnProcessor,
41
+ )
42
+ from diffusers.models.modeling_outputs import AutoencoderKLOutput
43
+ from diffusers.models.modeling_utils import ModelMixin
44
+ from .vae import DecoderCausal3D, BaseOutput, DecoderOutput, DiagonalGaussianDistribution, EncoderCausal3D
45
+
46
+
47
+ @dataclass
48
+ class DecoderOutput2(BaseOutput):
49
+ sample: torch.FloatTensor
50
+ posterior: Optional[DiagonalGaussianDistribution] = None
51
+
52
+
53
+ class AutoencoderKLCausal3D(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
54
+ r"""
55
+ A VAE model with KL loss for encoding images/videos into latents and decoding latent representations into images/videos.
56
+
57
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
58
+ for all models (such as downloading or saving).
59
+ """
60
+
61
+ _supports_gradient_checkpointing = True
62
+
63
+ @register_to_config
64
+ def __init__(
65
+ self,
66
+ in_channels: int = 3,
67
+ out_channels: int = 3,
68
+ down_block_types: Tuple[str] = ("DownEncoderBlockCausal3D",),
69
+ up_block_types: Tuple[str] = ("UpDecoderBlockCausal3D",),
70
+ block_out_channels: Tuple[int] = (64,),
71
+ layers_per_block: int = 1,
72
+ act_fn: str = "silu",
73
+ latent_channels: int = 4,
74
+ norm_num_groups: int = 32,
75
+ sample_size: int = 32,
76
+ sample_tsize: int = 64,
77
+ scaling_factor: float = 0.18215,
78
+ force_upcast: float = True,
79
+ spatial_compression_ratio: int = 8,
80
+ time_compression_ratio: int = 4,
81
+ mid_block_add_attention: bool = True,
82
+ ):
83
+ super().__init__()
84
+
85
+ self.time_compression_ratio = time_compression_ratio
86
+
87
+ self.encoder = EncoderCausal3D(
88
+ in_channels=in_channels,
89
+ out_channels=latent_channels,
90
+ down_block_types=down_block_types,
91
+ block_out_channels=block_out_channels,
92
+ layers_per_block=layers_per_block,
93
+ act_fn=act_fn,
94
+ norm_num_groups=norm_num_groups,
95
+ double_z=True,
96
+ time_compression_ratio=time_compression_ratio,
97
+ spatial_compression_ratio=spatial_compression_ratio,
98
+ mid_block_add_attention=mid_block_add_attention,
99
+ )
100
+
101
+ self.decoder = DecoderCausal3D(
102
+ in_channels=latent_channels,
103
+ out_channels=out_channels,
104
+ up_block_types=up_block_types,
105
+ block_out_channels=block_out_channels,
106
+ layers_per_block=layers_per_block,
107
+ norm_num_groups=norm_num_groups,
108
+ act_fn=act_fn,
109
+ time_compression_ratio=time_compression_ratio,
110
+ spatial_compression_ratio=spatial_compression_ratio,
111
+ mid_block_add_attention=mid_block_add_attention,
112
+ )
113
+
114
+ self.quant_conv = nn.Conv3d(2 * latent_channels, 2 * latent_channels, kernel_size=1)
115
+ self.post_quant_conv = nn.Conv3d(latent_channels, latent_channels, kernel_size=1)
116
+
117
+ self.use_slicing = False
118
+ self.use_spatial_tiling = False
119
+ self.use_temporal_tiling = False
120
+
121
+ # only relevant if vae tiling is enabled
122
+ self.tile_sample_min_tsize = sample_tsize
123
+ self.tile_latent_min_tsize = sample_tsize // time_compression_ratio
124
+
125
+ self.tile_sample_min_size = self.config.sample_size
126
+ sample_size = (
127
+ self.config.sample_size[0]
128
+ if isinstance(self.config.sample_size, (list, tuple))
129
+ else self.config.sample_size
130
+ )
131
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
132
+ self.tile_overlap_factor = 0.25
133
+
134
+ def _set_gradient_checkpointing(self, module, value=False):
135
+ if isinstance(module, (EncoderCausal3D, DecoderCausal3D)):
136
+ module.gradient_checkpointing = value
137
+
138
+ def enable_temporal_tiling(self, use_tiling: bool = True):
139
+ self.use_temporal_tiling = use_tiling
140
+
141
+ def disable_temporal_tiling(self):
142
+ self.enable_temporal_tiling(False)
143
+
144
+ def enable_spatial_tiling(self, use_tiling: bool = True):
145
+ self.use_spatial_tiling = use_tiling
146
+
147
+ def disable_spatial_tiling(self):
148
+ self.enable_spatial_tiling(False)
149
+
150
+ def enable_tiling(self, use_tiling: bool = True):
151
+ r"""
152
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
153
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and allows
154
+ processing larger videos.
155
+ """
156
+ self.enable_spatial_tiling(use_tiling)
157
+ self.enable_temporal_tiling(use_tiling)
158
+
159
+ def disable_tiling(self):
160
+ r"""
161
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
162
+ decoding in one step.
163
+ """
164
+ self.disable_spatial_tiling()
165
+ self.disable_temporal_tiling()
166
+
167
+ def enable_slicing(self):
168
+ r"""
169
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor into slices to
170
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
171
+ """
172
+ self.use_slicing = True
173
+
174
+ def disable_slicing(self):
175
+ r"""
176
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
177
+ decoding in one step.
178
+ """
179
+ self.use_slicing = False
180
+
181
+ @property
182
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
183
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
184
+ r"""
185
+ Returns:
186
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
187
+ indexed by its weight name.
188
+ """
189
+ # set recursively
190
+ processors = {}
191
+
192
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
193
+ if hasattr(module, "get_processor"):
194
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
195
+
196
+ for sub_name, child in module.named_children():
197
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
198
+
199
+ return processors
200
+
201
+ for name, module in self.named_children():
202
+ fn_recursive_add_processors(name, module, processors)
203
+
204
+ return processors
205
+
206
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
207
+ def set_attn_processor(
208
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
209
+ ):
210
+ r"""
211
+ Sets the attention processor to use to compute attention.
212
+
213
+ Parameters:
214
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
215
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
216
+ for **all** `Attention` layers.
217
+
218
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
219
+ processor. This is strongly recommended when setting trainable attention processors.
220
+
221
+ """
222
+ count = len(self.attn_processors.keys())
223
+
224
+ if isinstance(processor, dict) and len(processor) != count:
225
+ raise ValueError(
226
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
227
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
228
+ )
229
+
230
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
231
+ if hasattr(module, "set_processor"):
232
+ if not isinstance(processor, dict):
233
+ module.set_processor(processor, _remove_lora=_remove_lora)
234
+ else:
235
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
236
+
237
+ for sub_name, child in module.named_children():
238
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
239
+
240
+ for name, module in self.named_children():
241
+ fn_recursive_attn_processor(name, module, processor)
242
+
243
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
244
+ def set_default_attn_processor(self):
245
+ """
246
+ Disables custom attention processors and sets the default attention implementation.
247
+ """
248
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
249
+ processor = AttnAddedKVProcessor()
250
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
251
+ processor = AttnProcessor()
252
+ else:
253
+ raise ValueError(
254
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
255
+ )
256
+
257
+ self.set_attn_processor(processor, _remove_lora=True)
258
+
259
+ @apply_forward_hook
260
+ def encode(
261
+ self, x: torch.FloatTensor, return_dict: bool = True
262
+ ) -> Union[AutoencoderKLOutput, Tuple[DiagonalGaussianDistribution]]:
263
+ """
264
+ Encode a batch of images/videos into latents.
265
+
266
+ Args:
267
+ x (`torch.FloatTensor`): Input batch of images/videos.
268
+ return_dict (`bool`, *optional*, defaults to `True`):
269
+ Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
270
+
271
+ Returns:
272
+ The latent representations of the encoded images/videos. If `return_dict` is True, a
273
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
274
+ """
275
+ assert len(x.shape) == 5, "The input tensor should have 5 dimensions."
276
+
277
+ if self.use_temporal_tiling and x.shape[2] > self.tile_sample_min_tsize:
278
+ return self.temporal_tiled_encode(x, return_dict=return_dict)
279
+
280
+ if self.use_spatial_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
281
+ return self.spatial_tiled_encode(x, return_dict=return_dict)
282
+
283
+ if self.use_slicing and x.shape[0] > 1:
284
+ encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)]
285
+ h = torch.cat(encoded_slices)
286
+ else:
287
+ h = self.encoder(x)
288
+
289
+ moments = self.quant_conv(h)
290
+ posterior = DiagonalGaussianDistribution(moments)
291
+
292
+ if not return_dict:
293
+ return (posterior,)
294
+
295
+ return AutoencoderKLOutput(latent_dist=posterior)
296
+
297
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
298
+ assert len(z.shape) == 5, "The input tensor should have 5 dimensions."
299
+
300
+ if self.use_temporal_tiling and z.shape[2] > self.tile_latent_min_tsize:
301
+ return self.temporal_tiled_decode(z, return_dict=return_dict)
302
+
303
+ if self.use_spatial_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
304
+ return self.spatial_tiled_decode(z, return_dict=return_dict)
305
+
306
+ z = self.post_quant_conv(z)
307
+ dec = self.decoder(z)
308
+
309
+ if not return_dict:
310
+ return (dec,)
311
+
312
+ return DecoderOutput(sample=dec)
313
+
314
+ @apply_forward_hook
315
+ def decode(
316
+ self, z: torch.FloatTensor, return_dict: bool = True, generator=None
317
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
318
+ """
319
+ Decode a batch of images/videos.
320
+
321
+ Args:
322
+ z (`torch.FloatTensor`): Input batch of latent vectors.
323
+ return_dict (`bool`, *optional*, defaults to `True`):
324
+ Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
325
+
326
+ Returns:
327
+ [`~models.vae.DecoderOutput`] or `tuple`:
328
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
329
+ returned.
330
+
331
+ """
332
+ if self.use_slicing and z.shape[0] > 1:
333
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
334
+ decoded = torch.cat(decoded_slices)
335
+ else:
336
+ decoded = self._decode(z).sample
337
+
338
+ if not return_dict:
339
+ return (decoded,)
340
+
341
+ return DecoderOutput(sample=decoded)
342
+
343
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
344
+ blend_extent = min(a.shape[-2], b.shape[-2], blend_extent)
345
+ for y in range(blend_extent):
346
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (y / blend_extent)
347
+ return b
348
+
349
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
350
+ blend_extent = min(a.shape[-1], b.shape[-1], blend_extent)
351
+ for x in range(blend_extent):
352
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
353
+ return b
354
+
355
+ def blend_t(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
356
+ blend_extent = min(a.shape[-3], b.shape[-3], blend_extent)
357
+ for x in range(blend_extent):
358
+ b[:, :, x, :, :] = a[:, :, -blend_extent + x, :, :] * (1 - x / blend_extent) + b[:, :, x, :, :] * (x / blend_extent)
359
+ return b
360
+
361
+ def spatial_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True, return_moments: bool = False) -> AutoencoderKLOutput:
362
+ r"""Encode a batch of images/videos using a tiled encoder.
363
+
364
+ When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several
365
+ steps. This is useful to keep memory use constant regardless of image/videos size. The end result of tiled encoding is
366
+ different from non-tiled encoding because each tile uses a different encoder. To avoid tiling artifacts, the
367
+ tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the
368
+ output, but they should be much less noticeable.
369
+
370
+ Args:
371
+ x (`torch.FloatTensor`): Input batch of images/videos.
372
+ return_dict (`bool`, *optional*, defaults to `True`):
373
+ Whether or not to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
374
+
375
+ Returns:
376
+ [`~models.autoencoder_kl.AutoencoderKLOutput`] or `tuple`:
377
+ If return_dict is True, a [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain
378
+ `tuple` is returned.
379
+ """
380
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
381
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
382
+ row_limit = self.tile_latent_min_size - blend_extent
383
+
384
+ # Split video into tiles and encode them separately.
385
+ rows = []
386
+ for i in range(0, x.shape[-2], overlap_size):
387
+ row = []
388
+ for j in range(0, x.shape[-1], overlap_size):
389
+ tile = x[:, :, :, i: i + self.tile_sample_min_size, j: j + self.tile_sample_min_size]
390
+ tile = self.encoder(tile)
391
+ tile = self.quant_conv(tile)
392
+ row.append(tile)
393
+ rows.append(row)
394
+ result_rows = []
395
+ for i, row in enumerate(rows):
396
+ result_row = []
397
+ for j, tile in enumerate(row):
398
+ # blend the above tile and the left tile
399
+ # to the current tile and add the current tile to the result row
400
+ if i > 0:
401
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
402
+ if j > 0:
403
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
404
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
405
+ result_rows.append(torch.cat(result_row, dim=-1))
406
+
407
+ moments = torch.cat(result_rows, dim=-2)
408
+ if return_moments:
409
+ return moments
410
+
411
+ posterior = DiagonalGaussianDistribution(moments)
412
+ if not return_dict:
413
+ return (posterior,)
414
+
415
+ return AutoencoderKLOutput(latent_dist=posterior)
416
+
417
+ def spatial_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
418
+ r"""
419
+ Decode a batch of images/videos using a tiled decoder.
420
+
421
+ Args:
422
+ z (`torch.FloatTensor`): Input batch of latent vectors.
423
+ return_dict (`bool`, *optional*, defaults to `True`):
424
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
425
+
426
+ Returns:
427
+ [`~models.vae.DecoderOutput`] or `tuple`:
428
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
429
+ returned.
430
+ """
431
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
432
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
433
+ row_limit = self.tile_sample_min_size - blend_extent
434
+
435
+ # Split z into overlapping tiles and decode them separately.
436
+ # The tiles have an overlap to avoid seams between tiles.
437
+ rows = []
438
+ for i in range(0, z.shape[-2], overlap_size):
439
+ row = []
440
+ for j in range(0, z.shape[-1], overlap_size):
441
+ tile = z[:, :, :, i: i + self.tile_latent_min_size, j: j + self.tile_latent_min_size]
442
+ tile = self.post_quant_conv(tile)
443
+ decoded = self.decoder(tile)
444
+ row.append(decoded)
445
+ rows.append(row)
446
+ result_rows = []
447
+ for i, row in enumerate(rows):
448
+ result_row = []
449
+ for j, tile in enumerate(row):
450
+ # blend the above tile and the left tile
451
+ # to the current tile and add the current tile to the result row
452
+ if i > 0:
453
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
454
+ if j > 0:
455
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
456
+ result_row.append(tile[:, :, :, :row_limit, :row_limit])
457
+ result_rows.append(torch.cat(result_row, dim=-1))
458
+
459
+ dec = torch.cat(result_rows, dim=-2)
460
+ if not return_dict:
461
+ return (dec,)
462
+
463
+ return DecoderOutput(sample=dec)
464
+
465
+ def temporal_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
466
+
467
+ B, C, T, H, W = x.shape
468
+ overlap_size = int(self.tile_sample_min_tsize * (1 - self.tile_overlap_factor))
469
+ blend_extent = int(self.tile_latent_min_tsize * self.tile_overlap_factor)
470
+ t_limit = self.tile_latent_min_tsize - blend_extent
471
+
472
+ # Split the video into tiles and encode them separately.
473
+ row = []
474
+ for i in range(0, T, overlap_size):
475
+ tile = x[:, :, i: i + self.tile_sample_min_tsize + 1, :, :]
476
+ if self.use_spatial_tiling and (tile.shape[-1] > self.tile_sample_min_size or tile.shape[-2] > self.tile_sample_min_size):
477
+ tile = self.spatial_tiled_encode(tile, return_moments=True)
478
+ else:
479
+ tile = self.encoder(tile)
480
+ tile = self.quant_conv(tile)
481
+ if i > 0:
482
+ tile = tile[:, :, 1:, :, :]
483
+ row.append(tile)
484
+ result_row = []
485
+ for i, tile in enumerate(row):
486
+ if i > 0:
487
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
488
+ result_row.append(tile[:, :, :t_limit, :, :])
489
+ else:
490
+ result_row.append(tile[:, :, :t_limit + 1, :, :])
491
+
492
+ moments = torch.cat(result_row, dim=2)
493
+ posterior = DiagonalGaussianDistribution(moments)
494
+
495
+ if not return_dict:
496
+ return (posterior,)
497
+
498
+ return AutoencoderKLOutput(latent_dist=posterior)
499
+
500
+ def temporal_tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
501
+ # Split z into overlapping tiles and decode them separately.
502
+
503
+ B, C, T, H, W = z.shape
504
+ overlap_size = int(self.tile_latent_min_tsize * (1 - self.tile_overlap_factor))
505
+ blend_extent = int(self.tile_sample_min_tsize * self.tile_overlap_factor)
506
+ t_limit = self.tile_sample_min_tsize - blend_extent
507
+
508
+ row = []
509
+ for i in range(0, T, overlap_size):
510
+ tile = z[:, :, i: i + self.tile_latent_min_tsize + 1, :, :]
511
+ if self.use_spatial_tiling and (tile.shape[-1] > self.tile_latent_min_size or tile.shape[-2] > self.tile_latent_min_size):
512
+ decoded = self.spatial_tiled_decode(tile, return_dict=True).sample
513
+ else:
514
+ tile = self.post_quant_conv(tile)
515
+ decoded = self.decoder(tile)
516
+ if i > 0:
517
+ decoded = decoded[:, :, 1:, :, :]
518
+ row.append(decoded)
519
+ result_row = []
520
+ for i, tile in enumerate(row):
521
+ if i > 0:
522
+ tile = self.blend_t(row[i - 1], tile, blend_extent)
523
+ result_row.append(tile[:, :, :t_limit, :, :])
524
+ else:
525
+ result_row.append(tile[:, :, :t_limit + 1, :, :])
526
+
527
+ dec = torch.cat(result_row, dim=2)
528
+ if not return_dict:
529
+ return (dec,)
530
+
531
+ return DecoderOutput(sample=dec)
532
+
533
+ def forward(
534
+ self,
535
+ sample: torch.FloatTensor,
536
+ sample_posterior: bool = False,
537
+ return_dict: bool = True,
538
+ return_posterior: bool = False,
539
+ generator: Optional[torch.Generator] = None,
540
+ ) -> Union[DecoderOutput2, torch.FloatTensor]:
541
+ r"""
542
+ Args:
543
+ sample (`torch.FloatTensor`): Input sample.
544
+ sample_posterior (`bool`, *optional*, defaults to `False`):
545
+ Whether to sample from the posterior.
546
+ return_dict (`bool`, *optional*, defaults to `True`):
547
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
548
+ """
549
+ x = sample
550
+ posterior = self.encode(x).latent_dist
551
+ if sample_posterior:
552
+ z = posterior.sample(generator=generator)
553
+ else:
554
+ z = posterior.mode()
555
+ dec = self.decode(z).sample
556
+
557
+ if not return_dict:
558
+ if return_posterior:
559
+ return (dec, posterior)
560
+ else:
561
+ return (dec,)
562
+ if return_posterior:
563
+ return DecoderOutput2(sample=dec, posterior=posterior)
564
+ else:
565
+ return DecoderOutput2(sample=dec)
566
+
567
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
568
+ def fuse_qkv_projections(self):
569
+ """
570
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
571
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
572
+
573
+ <Tip warning={true}>
574
+
575
+ This API is 🧪 experimental.
576
+
577
+ </Tip>
578
+ """
579
+ self.original_attn_processors = None
580
+
581
+ for _, attn_processor in self.attn_processors.items():
582
+ if "Added" in str(attn_processor.__class__.__name__):
583
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
584
+
585
+ self.original_attn_processors = self.attn_processors
586
+
587
+ for module in self.modules():
588
+ if isinstance(module, Attention):
589
+ module.fuse_projections(fuse=True)
590
+
591
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
592
+ def unfuse_qkv_projections(self):
593
+ """Disables the fused QKV projection if enabled.
594
+
595
+ <Tip warning={true}>
596
+
597
+ This API is 🧪 experimental.
598
+
599
+ </Tip>
600
+
601
+ """
602
+ if self.original_attn_processors is not None:
603
+ self.set_attn_processor(self.original_attn_processors)
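
A quick usage sketch for the class above (not part of the committed file): it assumes the full `hyvideo` package from this repository is importable and simply exercises the tiling/slicing toggles and the encode/decode round trip with the 5D (batch, channels, frames, height, width) layout asserted in `encode`/`_decode`. Tensor sizes are small, assumed values chosen so the default config runs quickly.

# Illustrative sketch only; sizes and import path are assumptions, not part of this commit.
import torch
from hyvideo.vae.autoencoder_kl_causal_3d import AutoencoderKLCausal3D

vae = AutoencoderKLCausal3D()   # defaults: 3 image channels, 4 latent channels, one causal block
vae.eval()

# Tiling bounds memory for large frames / long clips; slicing splits the batch dimension.
vae.enable_tiling()
vae.enable_slicing()

video = torch.randn(1, 3, 5, 32, 32)   # (B, C, T, H, W)

with torch.no_grad():
    posterior = vae.encode(video).latent_dist      # DiagonalGaussianDistribution
    latents = posterior.mode()                     # deterministic latents
    recon = vae.decode(latents).sample             # back to pixel space

print(latents.shape, recon.shape)
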
hyvideo/vae/unet_causal_3d_blocks.py ADDED
@@ -0,0 +1,764 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ #
16
+ # Modified from diffusers==0.29.2
17
+ #
18
+ # ==============================================================================
19
+
20
+ from typing import Optional, Tuple, Union
21
+
22
+ import torch
23
+ import torch.nn.functional as F
24
+ from torch import nn
25
+ from einops import rearrange
26
+
27
+ from diffusers.utils import logging
28
+ from diffusers.models.activations import get_activation
29
+ from diffusers.models.attention_processor import SpatialNorm
30
+ from diffusers.models.attention_processor import Attention
31
+ from diffusers.models.normalization import AdaGroupNorm
32
+ from diffusers.models.normalization import RMSNorm
33
+
34
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
35
+
36
+
37
+ def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, device, batch_size: int = None):
38
+ seq_len = n_frame * n_hw
39
+ mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
40
+ for i in range(seq_len):
41
+ i_frame = i // n_hw
42
+ mask[i, : (i_frame + 1) * n_hw] = 0
43
+ if batch_size is not None:
44
+ mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
45
+ return mask
46
+
47
+
48
+ class CausalConv3d(nn.Module):
49
+ """
50
+ Implements a causal 3D convolution layer where each position only depends on previous timesteps and current spatial locations.
51
+ This maintains temporal causality in video generation tasks.
52
+ """
53
+
54
+ def __init__(
55
+ self,
56
+ chan_in,
57
+ chan_out,
58
+ kernel_size: Union[int, Tuple[int, int, int]],
59
+ stride: Union[int, Tuple[int, int, int]] = 1,
60
+ dilation: Union[int, Tuple[int, int, int]] = 1,
61
+ pad_mode='replicate',
62
+ **kwargs
63
+ ):
64
+ super().__init__()
65
+
66
+ self.pad_mode = pad_mode
67
+ padding = (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size - 1, 0) # W, H, T
68
+ self.time_causal_padding = padding
69
+
70
+ self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, stride=stride, dilation=dilation, **kwargs)
71
+
72
+ def forward(self, x):
73
+ x = F.pad(x, self.time_causal_padding, mode=self.pad_mode)
74
+ return self.conv(x)
75
+
76
+
77
+ class UpsampleCausal3D(nn.Module):
78
+ """
79
+ A 3D upsampling layer with an optional convolution.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ channels: int,
85
+ use_conv: bool = False,
86
+ use_conv_transpose: bool = False,
87
+ out_channels: Optional[int] = None,
88
+ name: str = "conv",
89
+ kernel_size: Optional[int] = None,
90
+ padding=1,
91
+ norm_type=None,
92
+ eps=None,
93
+ elementwise_affine=None,
94
+ bias=True,
95
+ interpolate=True,
96
+ upsample_factor=(2, 2, 2),
97
+ ):
98
+ super().__init__()
99
+ self.channels = channels
100
+ self.out_channels = out_channels or channels
101
+ self.use_conv = use_conv
102
+ self.use_conv_transpose = use_conv_transpose
103
+ self.name = name
104
+ self.interpolate = interpolate
105
+ self.upsample_factor = upsample_factor
106
+
107
+ if norm_type == "ln_norm":
108
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
109
+ elif norm_type == "rms_norm":
110
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
111
+ elif norm_type is None:
112
+ self.norm = None
113
+ else:
114
+ raise ValueError(f"unknown norm_type: {norm_type}")
115
+
116
+ conv = None
117
+ if use_conv_transpose:
118
+ raise NotImplementedError
119
+ elif use_conv:
120
+ if kernel_size is None:
121
+ kernel_size = 3
122
+ conv = CausalConv3d(self.channels, self.out_channels, kernel_size=kernel_size, bias=bias)
123
+
124
+ if name == "conv":
125
+ self.conv = conv
126
+ else:
127
+ self.Conv2d_0 = conv
128
+
129
+ def forward(
130
+ self,
131
+ hidden_states: torch.FloatTensor,
132
+ output_size: Optional[int] = None,
133
+ scale: float = 1.0,
134
+ ) -> torch.FloatTensor:
135
+ assert hidden_states.shape[1] == self.channels
136
+
137
+ if self.norm is not None:
138
+ raise NotImplementedError
139
+
140
+ if self.use_conv_transpose:
141
+ return self.conv(hidden_states)
142
+
143
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
144
+ dtype = hidden_states.dtype
145
+ if dtype == torch.bfloat16:
146
+ hidden_states = hidden_states.to(torch.float32)
147
+
148
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
149
+ if hidden_states.shape[0] >= 64:
150
+ hidden_states = hidden_states.contiguous()
151
+
152
+ # if `output_size` is passed we force the interpolation output
153
+ # size and do not make use of `scale_factor=2`
154
+ if self.interpolate:
155
+ B, C, T, H, W = hidden_states.shape
156
+ first_h, other_h = hidden_states.split((1, T - 1), dim=2)
157
+ if output_size is None:
158
+ if T > 1:
159
+ other_h = F.interpolate(other_h, scale_factor=self.upsample_factor, mode="nearest")
160
+
161
+ first_h = first_h.squeeze(2)
162
+ first_h = F.interpolate(first_h, scale_factor=self.upsample_factor[1:], mode="nearest")
163
+ first_h = first_h.unsqueeze(2)
164
+ else:
165
+ raise NotImplementedError
166
+
167
+ if T > 1:
168
+ hidden_states = torch.cat((first_h, other_h), dim=2)
169
+ else:
170
+ hidden_states = first_h
171
+
172
+ # If the input is bfloat16, we cast back to bfloat16
173
+ if dtype == torch.bfloat16:
174
+ hidden_states = hidden_states.to(dtype)
175
+
176
+ if self.use_conv:
177
+ if self.name == "conv":
178
+ hidden_states = self.conv(hidden_states)
179
+ else:
180
+ hidden_states = self.Conv2d_0(hidden_states)
181
+
182
+ return hidden_states
183
+
184
+
185
+ class DownsampleCausal3D(nn.Module):
186
+ """
187
+ A 3D downsampling layer with an optional convolution.
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ channels: int,
193
+ use_conv: bool = False,
194
+ out_channels: Optional[int] = None,
195
+ padding: int = 1,
196
+ name: str = "conv",
197
+ kernel_size=3,
198
+ norm_type=None,
199
+ eps=None,
200
+ elementwise_affine=None,
201
+ bias=True,
202
+ stride=2,
203
+ ):
204
+ super().__init__()
205
+ self.channels = channels
206
+ self.out_channels = out_channels or channels
207
+ self.use_conv = use_conv
208
+ self.padding = padding
209
+ stride = stride
210
+ self.name = name
211
+
212
+ if norm_type == "ln_norm":
213
+ self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
214
+ elif norm_type == "rms_norm":
215
+ self.norm = RMSNorm(channels, eps, elementwise_affine)
216
+ elif norm_type is None:
217
+ self.norm = None
218
+ else:
219
+ raise ValueError(f"unknown norm_type: {norm_type}")
220
+
221
+ if use_conv:
222
+ conv = CausalConv3d(
223
+ self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, bias=bias
224
+ )
225
+ else:
226
+ raise NotImplementedError
227
+
228
+ if name == "conv":
229
+ self.Conv2d_0 = conv
230
+ self.conv = conv
231
+ elif name == "Conv2d_0":
232
+ self.conv = conv
233
+ else:
234
+ self.conv = conv
235
+
236
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
237
+ assert hidden_states.shape[1] == self.channels
238
+
239
+ if self.norm is not None:
240
+ hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
241
+
242
+ assert hidden_states.shape[1] == self.channels
243
+
244
+ hidden_states = self.conv(hidden_states)
245
+
246
+ return hidden_states
247
+
248
+
249
+ class ResnetBlockCausal3D(nn.Module):
250
+ r"""
251
+ A Resnet block.
252
+ """
253
+
254
+ def __init__(
255
+ self,
256
+ *,
257
+ in_channels: int,
258
+ out_channels: Optional[int] = None,
259
+ conv_shortcut: bool = False,
260
+ dropout: float = 0.0,
261
+ temb_channels: int = 512,
262
+ groups: int = 32,
263
+ groups_out: Optional[int] = None,
264
+ pre_norm: bool = True,
265
+ eps: float = 1e-6,
266
+ non_linearity: str = "swish",
267
+ skip_time_act: bool = False,
268
+ # default, scale_shift, ada_group, spatial
269
+ time_embedding_norm: str = "default",
270
+ kernel: Optional[torch.FloatTensor] = None,
271
+ output_scale_factor: float = 1.0,
272
+ use_in_shortcut: Optional[bool] = None,
273
+ up: bool = False,
274
+ down: bool = False,
275
+ conv_shortcut_bias: bool = True,
276
+ conv_3d_out_channels: Optional[int] = None,
277
+ ):
278
+ super().__init__()
279
+ self.pre_norm = pre_norm
280
+ self.pre_norm = True
281
+ self.in_channels = in_channels
282
+ out_channels = in_channels if out_channels is None else out_channels
283
+ self.out_channels = out_channels
284
+ self.use_conv_shortcut = conv_shortcut
285
+ self.up = up
286
+ self.down = down
287
+ self.output_scale_factor = output_scale_factor
288
+ self.time_embedding_norm = time_embedding_norm
289
+ self.skip_time_act = skip_time_act
290
+
291
+ linear_cls = nn.Linear
292
+
293
+ if groups_out is None:
294
+ groups_out = groups
295
+
296
+ if self.time_embedding_norm == "ada_group":
297
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
298
+ elif self.time_embedding_norm == "spatial":
299
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
300
+ else:
301
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
302
+
303
+ self.conv1 = CausalConv3d(in_channels, out_channels, kernel_size=3, stride=1)
304
+
305
+ if temb_channels is not None:
306
+ if self.time_embedding_norm == "default":
307
+ self.time_emb_proj = linear_cls(temb_channels, out_channels)
308
+ elif self.time_embedding_norm == "scale_shift":
309
+ self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
310
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
311
+ self.time_emb_proj = None
312
+ else:
313
+ raise ValueError(f"Unknown time_embedding_norm : {self.time_embedding_norm} ")
314
+ else:
315
+ self.time_emb_proj = None
316
+
317
+ if self.time_embedding_norm == "ada_group":
318
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
319
+ elif self.time_embedding_norm == "spatial":
320
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
321
+ else:
322
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
323
+
324
+ self.dropout = torch.nn.Dropout(dropout)
325
+ conv_3d_out_channels = conv_3d_out_channels or out_channels
326
+ self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)
327
+
328
+ self.nonlinearity = get_activation(non_linearity)
329
+
330
+ self.upsample = self.downsample = None
331
+ if self.up:
332
+ self.upsample = UpsampleCausal3D(in_channels, use_conv=False)
333
+ elif self.down:
334
+ self.downsample = DownsampleCausal3D(in_channels, use_conv=False, name="op")
335
+
336
+ self.use_in_shortcut = self.in_channels != conv_3d_out_channels if use_in_shortcut is None else use_in_shortcut
337
+
338
+ self.conv_shortcut = None
339
+ if self.use_in_shortcut:
340
+ self.conv_shortcut = CausalConv3d(
341
+ in_channels,
342
+ conv_3d_out_channels,
343
+ kernel_size=1,
344
+ stride=1,
345
+ bias=conv_shortcut_bias,
346
+ )
347
+
348
+ def forward(
349
+ self,
350
+ input_tensor: torch.FloatTensor,
351
+ temb: torch.FloatTensor,
352
+ scale: float = 1.0,
353
+ ) -> torch.FloatTensor:
354
+ hidden_states = input_tensor
355
+
356
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
357
+ hidden_states = self.norm1(hidden_states, temb)
358
+ else:
359
+ hidden_states = self.norm1(hidden_states)
360
+
361
+ hidden_states = self.nonlinearity(hidden_states)
362
+
363
+ if self.upsample is not None:
364
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
365
+ if hidden_states.shape[0] >= 64:
366
+ input_tensor = input_tensor.contiguous()
367
+ hidden_states = hidden_states.contiguous()
368
+ input_tensor = (
369
+ self.upsample(input_tensor, scale=scale)
370
+ )
371
+ hidden_states = (
372
+ self.upsample(hidden_states, scale=scale)
373
+ )
374
+ elif self.downsample is not None:
375
+ input_tensor = (
376
+ self.downsample(input_tensor, scale=scale)
377
+ )
378
+ hidden_states = (
379
+ self.downsample(hidden_states, scale=scale)
380
+ )
381
+
382
+ hidden_states = self.conv1(hidden_states)
383
+
384
+ if self.time_emb_proj is not None:
385
+ if not self.skip_time_act:
386
+ temb = self.nonlinearity(temb)
387
+ temb = (
388
+ self.time_emb_proj(temb, scale)[:, :, None, None]
389
+ )
390
+
391
+ if temb is not None and self.time_embedding_norm == "default":
392
+ hidden_states = hidden_states + temb
393
+
394
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
395
+ hidden_states = self.norm2(hidden_states, temb)
396
+ else:
397
+ hidden_states = self.norm2(hidden_states)
398
+
399
+ if temb is not None and self.time_embedding_norm == "scale_shift":
400
+ scale, shift = torch.chunk(temb, 2, dim=1)
401
+ hidden_states = hidden_states * (1 + scale) + shift
402
+
403
+ hidden_states = self.nonlinearity(hidden_states)
404
+
405
+ hidden_states = self.dropout(hidden_states)
406
+ hidden_states = self.conv2(hidden_states)
407
+
408
+ if self.conv_shortcut is not None:
409
+ input_tensor = (
410
+ self.conv_shortcut(input_tensor)
411
+ )
412
+
413
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
414
+
415
+ return output_tensor
416
+
417
+
418
+ def get_down_block3d(
419
+ down_block_type: str,
420
+ num_layers: int,
421
+ in_channels: int,
422
+ out_channels: int,
423
+ temb_channels: int,
424
+ add_downsample: bool,
425
+ downsample_stride: int,
426
+ resnet_eps: float,
427
+ resnet_act_fn: str,
428
+ transformer_layers_per_block: int = 1,
429
+ num_attention_heads: Optional[int] = None,
430
+ resnet_groups: Optional[int] = None,
431
+ cross_attention_dim: Optional[int] = None,
432
+ downsample_padding: Optional[int] = None,
433
+ dual_cross_attention: bool = False,
434
+ use_linear_projection: bool = False,
435
+ only_cross_attention: bool = False,
436
+ upcast_attention: bool = False,
437
+ resnet_time_scale_shift: str = "default",
438
+ attention_type: str = "default",
439
+ resnet_skip_time_act: bool = False,
440
+ resnet_out_scale_factor: float = 1.0,
441
+ cross_attention_norm: Optional[str] = None,
442
+ attention_head_dim: Optional[int] = None,
443
+ downsample_type: Optional[str] = None,
444
+ dropout: float = 0.0,
445
+ ):
446
+ # If attn head dim is not defined, we default it to the number of heads
447
+ if attention_head_dim is None:
448
+ logger.warn(
449
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
450
+ )
451
+ attention_head_dim = num_attention_heads
452
+
453
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
454
+ if down_block_type == "DownEncoderBlockCausal3D":
455
+ return DownEncoderBlockCausal3D(
456
+ num_layers=num_layers,
457
+ in_channels=in_channels,
458
+ out_channels=out_channels,
459
+ dropout=dropout,
460
+ add_downsample=add_downsample,
461
+ downsample_stride=downsample_stride,
462
+ resnet_eps=resnet_eps,
463
+ resnet_act_fn=resnet_act_fn,
464
+ resnet_groups=resnet_groups,
465
+ downsample_padding=downsample_padding,
466
+ resnet_time_scale_shift=resnet_time_scale_shift,
467
+ )
468
+ raise ValueError(f"{down_block_type} does not exist.")
469
+
470
+
471
+ def get_up_block3d(
472
+ up_block_type: str,
473
+ num_layers: int,
474
+ in_channels: int,
475
+ out_channels: int,
476
+ prev_output_channel: int,
477
+ temb_channels: int,
478
+ add_upsample: bool,
479
+ upsample_scale_factor: Tuple,
480
+ resnet_eps: float,
481
+ resnet_act_fn: str,
482
+ resolution_idx: Optional[int] = None,
483
+ transformer_layers_per_block: int = 1,
484
+ num_attention_heads: Optional[int] = None,
485
+ resnet_groups: Optional[int] = None,
486
+ cross_attention_dim: Optional[int] = None,
487
+ dual_cross_attention: bool = False,
488
+ use_linear_projection: bool = False,
489
+ only_cross_attention: bool = False,
490
+ upcast_attention: bool = False,
491
+ resnet_time_scale_shift: str = "default",
492
+ attention_type: str = "default",
493
+ resnet_skip_time_act: bool = False,
494
+ resnet_out_scale_factor: float = 1.0,
495
+ cross_attention_norm: Optional[str] = None,
496
+ attention_head_dim: Optional[int] = None,
497
+ upsample_type: Optional[str] = None,
498
+ dropout: float = 0.0,
499
+ ) -> nn.Module:
500
+ # If attn head dim is not defined, we default it to the number of heads
501
+ if attention_head_dim is None:
502
+ logger.warn(
503
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
504
+ )
505
+ attention_head_dim = num_attention_heads
506
+
507
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
508
+ if up_block_type == "UpDecoderBlockCausal3D":
509
+ return UpDecoderBlockCausal3D(
510
+ num_layers=num_layers,
511
+ in_channels=in_channels,
512
+ out_channels=out_channels,
513
+ resolution_idx=resolution_idx,
514
+ dropout=dropout,
515
+ add_upsample=add_upsample,
516
+ upsample_scale_factor=upsample_scale_factor,
517
+ resnet_eps=resnet_eps,
518
+ resnet_act_fn=resnet_act_fn,
519
+ resnet_groups=resnet_groups,
520
+ resnet_time_scale_shift=resnet_time_scale_shift,
521
+ temb_channels=temb_channels,
522
+ )
523
+ raise ValueError(f"{up_block_type} does not exist.")
524
+
525
+
526
+ class UNetMidBlockCausal3D(nn.Module):
527
+ """
528
+ A 3D UNet mid-block [`UNetMidBlockCausal3D`] with multiple residual blocks and optional attention blocks.
529
+ """
530
+
531
+ def __init__(
532
+ self,
533
+ in_channels: int,
534
+ temb_channels: int,
535
+ dropout: float = 0.0,
536
+ num_layers: int = 1,
537
+ resnet_eps: float = 1e-6,
538
+ resnet_time_scale_shift: str = "default", # default, spatial
539
+ resnet_act_fn: str = "swish",
540
+ resnet_groups: int = 32,
541
+ attn_groups: Optional[int] = None,
542
+ resnet_pre_norm: bool = True,
543
+ add_attention: bool = True,
544
+ attention_head_dim: int = 1,
545
+ output_scale_factor: float = 1.0,
546
+ ):
547
+ super().__init__()
548
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
549
+ self.add_attention = add_attention
550
+
551
+ if attn_groups is None:
552
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
553
+
554
+ # there is always at least one resnet
555
+ resnets = [
556
+ ResnetBlockCausal3D(
557
+ in_channels=in_channels,
558
+ out_channels=in_channels,
559
+ temb_channels=temb_channels,
560
+ eps=resnet_eps,
561
+ groups=resnet_groups,
562
+ dropout=dropout,
563
+ time_embedding_norm=resnet_time_scale_shift,
564
+ non_linearity=resnet_act_fn,
565
+ output_scale_factor=output_scale_factor,
566
+ pre_norm=resnet_pre_norm,
567
+ )
568
+ ]
569
+ attentions = []
570
+
571
+ if attention_head_dim is None:
572
+ logger.warn(
573
+ f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
574
+ )
575
+ attention_head_dim = in_channels
576
+
577
+ for _ in range(num_layers):
578
+ if self.add_attention:
579
+ attentions.append(
580
+ Attention(
581
+ in_channels,
582
+ heads=in_channels // attention_head_dim,
583
+ dim_head=attention_head_dim,
584
+ rescale_output_factor=output_scale_factor,
585
+ eps=resnet_eps,
586
+ norm_num_groups=attn_groups,
587
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
588
+ residual_connection=True,
589
+ bias=True,
590
+ upcast_softmax=True,
591
+ _from_deprecated_attn_block=True,
592
+ )
593
+ )
594
+ else:
595
+ attentions.append(None)
596
+
597
+ resnets.append(
598
+ ResnetBlockCausal3D(
599
+ in_channels=in_channels,
600
+ out_channels=in_channels,
601
+ temb_channels=temb_channels,
602
+ eps=resnet_eps,
603
+ groups=resnet_groups,
604
+ dropout=dropout,
605
+ time_embedding_norm=resnet_time_scale_shift,
606
+ non_linearity=resnet_act_fn,
607
+ output_scale_factor=output_scale_factor,
608
+ pre_norm=resnet_pre_norm,
609
+ )
610
+ )
611
+
612
+ self.attentions = nn.ModuleList(attentions)
613
+ self.resnets = nn.ModuleList(resnets)
614
+
615
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
616
+ hidden_states = self.resnets[0](hidden_states, temb)
617
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
618
+ if attn is not None:
619
+ B, C, T, H, W = hidden_states.shape
620
+ hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
621
+ attention_mask = prepare_causal_attention_mask(
622
+ T, H * W, hidden_states.dtype, hidden_states.device, batch_size=B
623
+ )
624
+ hidden_states = attn(hidden_states, temb=temb, attention_mask=attention_mask)
625
+ hidden_states = rearrange(hidden_states, "b (f h w) c -> b c f h w", f=T, h=H, w=W)
626
+ hidden_states = resnet(hidden_states, temb)
627
+
628
+ return hidden_states
629
+
630
+
631
+ class DownEncoderBlockCausal3D(nn.Module):
632
+ def __init__(
633
+ self,
634
+ in_channels: int,
635
+ out_channels: int,
636
+ dropout: float = 0.0,
637
+ num_layers: int = 1,
638
+ resnet_eps: float = 1e-6,
639
+ resnet_time_scale_shift: str = "default",
640
+ resnet_act_fn: str = "swish",
641
+ resnet_groups: int = 32,
642
+ resnet_pre_norm: bool = True,
643
+ output_scale_factor: float = 1.0,
644
+ add_downsample: bool = True,
645
+ downsample_stride: int = 2,
646
+ downsample_padding: int = 1,
647
+ ):
648
+ super().__init__()
649
+ resnets = []
650
+
651
+ for i in range(num_layers):
652
+ in_channels = in_channels if i == 0 else out_channels
653
+ resnets.append(
654
+ ResnetBlockCausal3D(
655
+ in_channels=in_channels,
656
+ out_channels=out_channels,
657
+ temb_channels=None,
658
+ eps=resnet_eps,
659
+ groups=resnet_groups,
660
+ dropout=dropout,
661
+ time_embedding_norm=resnet_time_scale_shift,
662
+ non_linearity=resnet_act_fn,
663
+ output_scale_factor=output_scale_factor,
664
+ pre_norm=resnet_pre_norm,
665
+ )
666
+ )
667
+
668
+ self.resnets = nn.ModuleList(resnets)
669
+
670
+ if add_downsample:
671
+ self.downsamplers = nn.ModuleList(
672
+ [
673
+ DownsampleCausal3D(
674
+ out_channels,
675
+ use_conv=True,
676
+ out_channels=out_channels,
677
+ padding=downsample_padding,
678
+ name="op",
679
+ stride=downsample_stride,
680
+ )
681
+ ]
682
+ )
683
+ else:
684
+ self.downsamplers = None
685
+
686
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
687
+ for resnet in self.resnets:
688
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
689
+
690
+ if self.downsamplers is not None:
691
+ for downsampler in self.downsamplers:
692
+ hidden_states = downsampler(hidden_states, scale)
693
+
694
+ return hidden_states
695
+
696
+
697
+ class UpDecoderBlockCausal3D(nn.Module):
698
+ def __init__(
699
+ self,
700
+ in_channels: int,
701
+ out_channels: int,
702
+ resolution_idx: Optional[int] = None,
703
+ dropout: float = 0.0,
704
+ num_layers: int = 1,
705
+ resnet_eps: float = 1e-6,
706
+ resnet_time_scale_shift: str = "default", # default, spatial
707
+ resnet_act_fn: str = "swish",
708
+ resnet_groups: int = 32,
709
+ resnet_pre_norm: bool = True,
710
+ output_scale_factor: float = 1.0,
711
+ add_upsample: bool = True,
712
+ upsample_scale_factor=(2, 2, 2),
713
+ temb_channels: Optional[int] = None,
714
+ ):
715
+ super().__init__()
716
+ resnets = []
717
+
718
+ for i in range(num_layers):
719
+ input_channels = in_channels if i == 0 else out_channels
720
+
721
+ resnets.append(
722
+ ResnetBlockCausal3D(
723
+ in_channels=input_channels,
724
+ out_channels=out_channels,
725
+ temb_channels=temb_channels,
726
+ eps=resnet_eps,
727
+ groups=resnet_groups,
728
+ dropout=dropout,
729
+ time_embedding_norm=resnet_time_scale_shift,
730
+ non_linearity=resnet_act_fn,
731
+ output_scale_factor=output_scale_factor,
732
+ pre_norm=resnet_pre_norm,
733
+ )
734
+ )
735
+
736
+ self.resnets = nn.ModuleList(resnets)
737
+
738
+ if add_upsample:
739
+ self.upsamplers = nn.ModuleList(
740
+ [
741
+ UpsampleCausal3D(
742
+ out_channels,
743
+ use_conv=True,
744
+ out_channels=out_channels,
745
+ upsample_factor=upsample_scale_factor,
746
+ )
747
+ ]
748
+ )
749
+ else:
750
+ self.upsamplers = None
751
+
752
+ self.resolution_idx = resolution_idx
753
+
754
+ def forward(
755
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
756
+ ) -> torch.FloatTensor:
757
+ for resnet in self.resnets:
758
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
759
+
760
+ if self.upsamplers is not None:
761
+ for upsampler in self.upsamplers:
762
+ hidden_states = upsampler(hidden_states)
763
+
764
+ return hidden_states
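
Two pieces above enforce temporal causality: `prepare_causal_attention_mask` (a token in frame f may only attend to tokens of frames <= f) and `CausalConv3d` (time-axis padding of `kernel_size - 1` frames on the left only). Below is a small illustrative check of both behaviours, under the assumption that the `hyvideo` package is importable; shapes are tiny and chosen only for demonstration.

import torch
from hyvideo.vae.unet_causal_3d_blocks import CausalConv3d, prepare_causal_attention_mask

# 3 frames x 2 spatial tokens: each row has zeros up to the end of its own frame, -inf afterwards.
mask = prepare_causal_attention_mask(n_frame=3, n_hw=2, dtype=torch.float32, device="cpu")
print(mask)

# Perturbing the last frame must not change earlier outputs of the causal convolution.
conv = CausalConv3d(chan_in=1, chan_out=1, kernel_size=3)
x = torch.randn(1, 1, 4, 8, 8)
y1 = conv(x)
x2 = x.clone()
x2[:, :, -1] += 1.0                    # modify only the final frame
y2 = conv(x2)
print(torch.allclose(y1[:, :, :-1], y2[:, :, :-1]))   # True: earlier frames unaffected
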
hyvideo/vae/vae.py ADDED
@@ -0,0 +1,355 @@
1
+ from dataclasses import dataclass
2
+ from typing import Optional, Tuple
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torch.nn as nn
7
+
8
+ from diffusers.utils import BaseOutput, is_torch_version
9
+ from diffusers.utils.torch_utils import randn_tensor
10
+ from diffusers.models.attention_processor import SpatialNorm
11
+ from .unet_causal_3d_blocks import (
12
+ CausalConv3d,
13
+ UNetMidBlockCausal3D,
14
+ get_down_block3d,
15
+ get_up_block3d,
16
+ )
17
+
18
+
19
+ @dataclass
20
+ class DecoderOutput(BaseOutput):
21
+ r"""
22
+ Output of decoding method.
23
+
24
+ Args:
25
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
26
+ The decoded output sample from the last layer of the model.
27
+ """
28
+
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ class EncoderCausal3D(nn.Module):
33
+ r"""
34
+ The `EncoderCausal3D` layer of a variational autoencoder that encodes its input into a latent representation.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_channels: int = 3,
40
+ out_channels: int = 3,
41
+ down_block_types: Tuple[str, ...] = ("DownEncoderBlockCausal3D",),
42
+ block_out_channels: Tuple[int, ...] = (64,),
43
+ layers_per_block: int = 2,
44
+ norm_num_groups: int = 32,
45
+ act_fn: str = "silu",
46
+ double_z: bool = True,
47
+ mid_block_add_attention=True,
48
+ time_compression_ratio: int = 4,
49
+ spatial_compression_ratio: int = 8,
50
+ ):
51
+ super().__init__()
52
+ self.layers_per_block = layers_per_block
53
+
54
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[0], kernel_size=3, stride=1)
55
+ self.mid_block = None
56
+ self.down_blocks = nn.ModuleList([])
57
+
58
+ # down
59
+ output_channel = block_out_channels[0]
60
+ for i, down_block_type in enumerate(down_block_types):
61
+ input_channel = output_channel
62
+ output_channel = block_out_channels[i]
63
+ is_final_block = i == len(block_out_channels) - 1
64
+ num_spatial_downsample_layers = int(np.log2(spatial_compression_ratio))
65
+ num_time_downsample_layers = int(np.log2(time_compression_ratio))
66
+
67
+ if time_compression_ratio == 4:
68
+ add_spatial_downsample = bool(i < num_spatial_downsample_layers)
69
+ add_time_downsample = bool(
70
+ i >= (len(block_out_channels) - 1 - num_time_downsample_layers)
71
+ and not is_final_block
72
+ )
73
+ else:
74
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
75
+
76
+ downsample_stride_HW = (2, 2) if add_spatial_downsample else (1, 1)
77
+ downsample_stride_T = (2,) if add_time_downsample else (1,)
78
+ downsample_stride = tuple(downsample_stride_T + downsample_stride_HW)
79
+ down_block = get_down_block3d(
80
+ down_block_type,
81
+ num_layers=self.layers_per_block,
82
+ in_channels=input_channel,
83
+ out_channels=output_channel,
84
+ add_downsample=bool(add_spatial_downsample or add_time_downsample),
85
+ downsample_stride=downsample_stride,
86
+ resnet_eps=1e-6,
87
+ downsample_padding=0,
88
+ resnet_act_fn=act_fn,
89
+ resnet_groups=norm_num_groups,
90
+ attention_head_dim=output_channel,
91
+ temb_channels=None,
92
+ )
93
+ self.down_blocks.append(down_block)
94
+
95
+ # mid
96
+ self.mid_block = UNetMidBlockCausal3D(
97
+ in_channels=block_out_channels[-1],
98
+ resnet_eps=1e-6,
99
+ resnet_act_fn=act_fn,
100
+ output_scale_factor=1,
101
+ resnet_time_scale_shift="default",
102
+ attention_head_dim=block_out_channels[-1],
103
+ resnet_groups=norm_num_groups,
104
+ temb_channels=None,
105
+ add_attention=mid_block_add_attention,
106
+ )
107
+
108
+ # out
109
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
110
+ self.conv_act = nn.SiLU()
111
+
112
+ conv_out_channels = 2 * out_channels if double_z else out_channels
113
+ self.conv_out = CausalConv3d(block_out_channels[-1], conv_out_channels, kernel_size=3)
114
+
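The loop in `__init__` above derives a per-block (T, H, W) stride from the two compression ratios: log2(spatial_compression_ratio) leading blocks downsample spatially, and the temporal stride of 2 is pushed onto the last non-final blocks. A small sketch of the same arithmetic, with an assumed four-block configuration (the channel counts are illustrative, not taken from this file):

# Illustrative check of the block-wise downsample schedule (assumed config).
import numpy as np

block_out_channels = (128, 256, 512, 512)   # assumed example values
spatial_compression_ratio = 8                # -> 3 spatial downsamples (log2 8)
time_compression_ratio = 4                   # -> 2 temporal downsamples (log2 4)

num_blocks = len(block_out_channels)
num_spatial = int(np.log2(spatial_compression_ratio))
num_time = int(np.log2(time_compression_ratio))

for i in range(num_blocks):
    is_final = i == num_blocks - 1
    add_spatial = i < num_spatial
    add_time = i >= (num_blocks - 1 - num_time) and not is_final
    stride = ((2,) if add_time else (1,)) + ((2, 2) if add_spatial else (1, 1))
    print(f"block {i}: stride (T, H, W) = {stride}")
# Prints: block 0 -> (1, 2, 2), block 1 -> (2, 2, 2), block 2 -> (2, 2, 2), block 3 -> (1, 1, 1),
# i.e. a total spatial factor of 8 and a total temporal factor of 4.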
115
+ def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor:
116
+ r"""The forward method of the `EncoderCausal3D` class."""
117
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions"
118
+
119
+ sample = self.conv_in(sample)
120
+
121
+ # down
122
+ for down_block in self.down_blocks:
123
+ sample = down_block(sample)
124
+
125
+ # middle
126
+ sample = self.mid_block(sample)
127
+
128
+ # post-process
129
+ sample = self.conv_norm_out(sample)
130
+ sample = self.conv_act(sample)
131
+ sample = self.conv_out(sample)
132
+
133
+ return sample
134
+
135
+
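A minimal usage sketch for this encoder, with assumed constructor arguments; the expected shapes follow from the schedule above, and a causal temporal stack maps 4n+1 input frames to roughly n+1 latent frames:

# Hedged usage sketch; every constructor argument here is an illustrative assumption.
import torch

encoder = EncoderCausal3D(
    in_channels=3,
    out_channels=16,                        # latent channels (assumed)
    down_block_types=("DownEncoderBlockCausal3D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    double_z=True,
    time_compression_ratio=4,
    spatial_compression_ratio=8,
)

video = torch.randn(1, 3, 33, 256, 256)     # (B, C, T, H, W), T = 4*8 + 1
with torch.no_grad():
    moments = encoder(video)
# Expected, given the compression ratios: (1, 32, 9, 32, 32) --
# channels doubled because double_z=True, T: 33 -> 9, H and W: 256 -> 32.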
136
+ class DecoderCausal3D(nn.Module):
137
+ r"""
138
+ The `DecoderCausal3D` layer of a variational autoencoder that decodes its latent representation into an output sample.
139
+ """
140
+
141
+ def __init__(
142
+ self,
143
+ in_channels: int = 3,
144
+ out_channels: int = 3,
145
+ up_block_types: Tuple[str, ...] = ("UpDecoderBlockCausal3D",),
146
+ block_out_channels: Tuple[int, ...] = (64,),
147
+ layers_per_block: int = 2,
148
+ norm_num_groups: int = 32,
149
+ act_fn: str = "silu",
150
+ norm_type: str = "group", # group, spatial
151
+ mid_block_add_attention=True,
152
+ time_compression_ratio: int = 4,
153
+ spatial_compression_ratio: int = 8,
154
+ ):
155
+ super().__init__()
156
+ self.layers_per_block = layers_per_block
157
+
158
+ self.conv_in = CausalConv3d(in_channels, block_out_channels[-1], kernel_size=3, stride=1)
159
+ self.mid_block = None
160
+ self.up_blocks = nn.ModuleList([])
161
+
162
+ temb_channels = in_channels if norm_type == "spatial" else None
163
+
164
+ # mid
165
+ self.mid_block = UNetMidBlockCausal3D(
166
+ in_channels=block_out_channels[-1],
167
+ resnet_eps=1e-6,
168
+ resnet_act_fn=act_fn,
169
+ output_scale_factor=1,
170
+ resnet_time_scale_shift="default" if norm_type == "group" else norm_type,
171
+ attention_head_dim=block_out_channels[-1],
172
+ resnet_groups=norm_num_groups,
173
+ temb_channels=temb_channels,
174
+ add_attention=mid_block_add_attention,
175
+ )
176
+
177
+ # up
178
+ reversed_block_out_channels = list(reversed(block_out_channels))
179
+ output_channel = reversed_block_out_channels[0]
180
+ for i, up_block_type in enumerate(up_block_types):
181
+ prev_output_channel = output_channel
182
+ output_channel = reversed_block_out_channels[i]
183
+ is_final_block = i == len(block_out_channels) - 1
184
+ num_spatial_upsample_layers = int(np.log2(spatial_compression_ratio))
185
+ num_time_upsample_layers = int(np.log2(time_compression_ratio))
186
+
187
+ if time_compression_ratio == 4:
188
+ add_spatial_upsample = bool(i < num_spatial_upsample_layers)
189
+ add_time_upsample = bool(
190
+ i >= len(block_out_channels) - 1 - num_time_upsample_layers
191
+ and not is_final_block
192
+ )
193
+ else:
194
+ raise ValueError(f"Unsupported time_compression_ratio: {time_compression_ratio}.")
195
+
196
+ upsample_scale_factor_HW = (2, 2) if add_spatial_upsample else (1, 1)
197
+ upsample_scale_factor_T = (2,) if add_time_upsample else (1,)
198
+ upsample_scale_factor = tuple(upsample_scale_factor_T + upsample_scale_factor_HW)
199
+ up_block = get_up_block3d(
200
+ up_block_type,
201
+ num_layers=self.layers_per_block + 1,
202
+ in_channels=prev_output_channel,
203
+ out_channels=output_channel,
204
+ prev_output_channel=None,
205
+ add_upsample=bool(add_spatial_upsample or add_time_upsample),
206
+ upsample_scale_factor=upsample_scale_factor,
207
+ resnet_eps=1e-6,
208
+ resnet_act_fn=act_fn,
209
+ resnet_groups=norm_num_groups,
210
+ attention_head_dim=output_channel,
211
+ temb_channels=temb_channels,
212
+ resnet_time_scale_shift=norm_type,
213
+ )
214
+ self.up_blocks.append(up_block)
215
+ prev_output_channel = output_channel
216
+
217
+ # out
218
+ if norm_type == "spatial":
219
+ self.conv_norm_out = SpatialNorm(block_out_channels[0], temb_channels)
220
+ else:
221
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6)
222
+ self.conv_act = nn.SiLU()
223
+ self.conv_out = CausalConv3d(block_out_channels[0], out_channels, kernel_size=3)
224
+
225
+ self.gradient_checkpointing = False
226
+
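Nothing in this class flips `gradient_checkpointing` itself; the enclosing autoencoder is expected to toggle the flag. A hedged sketch of manual use during training (`decoder` and `latents` are assumed to exist already):

# Minimal sketch: enable activation checkpointing on an existing DecoderCausal3D
# instance before a training forward pass (trades recompute for memory).
decoder.gradient_checkpointing = True
decoder.train()
out = decoder(latents)   # mid/up blocks now run under torch.utils.checkpoint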
227
+ def forward(
228
+ self,
229
+ sample: torch.FloatTensor,
230
+ latent_embeds: Optional[torch.FloatTensor] = None,
231
+ ) -> torch.FloatTensor:
232
+ r"""The forward method of the `DecoderCausal3D` class."""
233
+ assert len(sample.shape) == 5, "The input tensor should have 5 dimensions."
234
+
235
+ sample = self.conv_in(sample)
236
+
237
+ upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
238
+ if self.training and self.gradient_checkpointing:
239
+
240
+ def create_custom_forward(module):
241
+ def custom_forward(*inputs):
242
+ return module(*inputs)
243
+
244
+ return custom_forward
245
+
246
+ if is_torch_version(">=", "1.11.0"):
247
+ # middle
248
+ sample = torch.utils.checkpoint.checkpoint(
249
+ create_custom_forward(self.mid_block),
250
+ sample,
251
+ latent_embeds,
252
+ use_reentrant=False,
253
+ )
254
+ sample = sample.to(upscale_dtype)
255
+
256
+ # up
257
+ for up_block in self.up_blocks:
258
+ sample = torch.utils.checkpoint.checkpoint(
259
+ create_custom_forward(up_block),
260
+ sample,
261
+ latent_embeds,
262
+ use_reentrant=False,
263
+ )
264
+ else:
265
+ # middle
266
+ sample = torch.utils.checkpoint.checkpoint(
267
+ create_custom_forward(self.mid_block), sample, latent_embeds
268
+ )
269
+ sample = sample.to(upscale_dtype)
270
+
271
+ # up
272
+ for up_block in self.up_blocks:
273
+ sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample, latent_embeds)
274
+ else:
275
+ # middle
276
+ sample = self.mid_block(sample, latent_embeds)
277
+ sample = sample.to(upscale_dtype)
278
+
279
+ # up
280
+ for up_block in self.up_blocks:
281
+ sample = up_block(sample, latent_embeds)
282
+
283
+ # post-process
284
+ if latent_embeds is None:
285
+ sample = self.conv_norm_out(sample)
286
+ else:
287
+ sample = self.conv_norm_out(sample, latent_embeds)
288
+ sample = self.conv_act(sample)
289
+ sample = self.conv_out(sample)
290
+
291
+ return sample
292
+
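A matching usage sketch for the decoder, again with assumed, illustrative constructor arguments that mirror the encoder sketch above:

# Hedged shape-level check: latents back to pixel space.
import torch

decoder = DecoderCausal3D(
    in_channels=16,                         # latent channels (assumed)
    out_channels=3,
    up_block_types=("UpDecoderBlockCausal3D",) * 4,
    block_out_channels=(128, 256, 512, 512),
    layers_per_block=2,
    time_compression_ratio=4,
    spatial_compression_ratio=8,
)

latents = torch.randn(1, 16, 9, 32, 32)     # (B, C, T, H, W) in latent space
with torch.no_grad():
    video = decoder(latents)
# Expected, as the mirror of the encoder sketch: (1, 3, 33, 256, 256).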
293
+
294
+ class DiagonalGaussianDistribution(object):
295
+ def __init__(self, parameters: torch.Tensor, deterministic: bool = False):
296
+ if parameters.ndim == 3:
297
+ dim = 2 # (B, L, C)
298
+ elif parameters.ndim == 5 or parameters.ndim == 4:
299
+ dim = 1 # (B, C, T, H, W) / (B, C, H, W)
300
+ else:
301
+ raise NotImplementedError
302
+ self.parameters = parameters
303
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=dim)
304
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
305
+ self.deterministic = deterministic
306
+ self.std = torch.exp(0.5 * self.logvar)
307
+ self.var = torch.exp(self.logvar)
308
+ if self.deterministic:
309
+ self.var = self.std = torch.zeros_like(
310
+ self.mean, device=self.parameters.device, dtype=self.parameters.dtype
311
+ )
312
+
313
+ def sample(self, generator: Optional[torch.Generator] = None) -> torch.FloatTensor:
314
+ # make sure sample is on the same device as the parameters and has same dtype
315
+ sample = randn_tensor(
316
+ self.mean.shape,
317
+ generator=generator,
318
+ device=self.parameters.device,
319
+ dtype=self.parameters.dtype,
320
+ )
321
+ x = self.mean + self.std * sample
322
+ return x
323
+
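`sample` uses the reparameterization trick, z = mean + std * eps, so the draw stays differentiable with respect to the predicted moments. A small, self-contained sketch (the `moments` shape is an assumption):

# z = mu + sigma * eps keeps sampling differentiable w.r.t. the encoder output.
import torch

moments = torch.randn(1, 32, 9, 32, 32)             # assumed (mean | logvar) stacked along dim=1
posterior = DiagonalGaussianDistribution(moments)
gen = torch.Generator().manual_seed(0)
z = posterior.sample(generator=gen)                 # stochastic latent, reproducible via the generator
z_det = posterior.mode()                            # deterministic alternative: the posterior mean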
324
+ def kl(self, other: "DiagonalGaussianDistribution" = None) -> torch.Tensor:
325
+ if self.deterministic:
326
+ return torch.Tensor([0.0])
327
+ else:
328
+ reduce_dim = list(range(1, self.mean.ndim))
329
+ if other is None:
330
+ return 0.5 * torch.sum(
331
+ torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar,
332
+ dim=reduce_dim,
333
+ )
334
+ else:
335
+ return 0.5 * torch.sum(
336
+ torch.pow(self.mean - other.mean, 2) / other.var
337
+ + self.var / other.var
338
+ - 1.0
339
+ - self.logvar
340
+ + other.logvar,
341
+ dim=reduce_dim,
342
+ )
343
+
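`kl` is the closed-form KL divergence between diagonal Gaussians, summed over every non-batch dimension. A quick, purely illustrative cross-check against `torch.distributions`:

# Compare the hand-written KL (against a standard normal) with torch.distributions.
import torch
from torch.distributions import Normal, kl_divergence

params = torch.randn(2, 8, 3, 4, 4)                 # (B, 2*C, T, H, W): mean and logvar stacked
p = DiagonalGaussianDistribution(params)

ref = kl_divergence(Normal(p.mean, p.std),
                    Normal(torch.zeros_like(p.mean), torch.ones_like(p.std)))
ref = ref.sum(dim=[1, 2, 3, 4])                     # sum over non-batch dims, as kl() does
assert torch.allclose(p.kl(), ref, atol=1e-5)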
344
+ def nll(self, sample: torch.Tensor, dims: Tuple[int, ...] = (1, 2, 3)) -> torch.Tensor:
345
+ if self.deterministic:
346
+ return torch.Tensor([0.0])
347
+ logtwopi = np.log(2.0 * np.pi)
348
+ return 0.5 * torch.sum(
349
+ logtwopi + self.logvar +
350
+ torch.pow(sample - self.mean, 2) / self.var,
351
+ dim=dims,
352
+ )
353
+
354
+ def mode(self) -> torch.Tensor:
355
+ return self.mean
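Putting the pieces of this file together, a hedged, shape-level sketch of the full VAE path, reusing the illustrative `encoder` and `decoder` built in the sketches above (the real autoencoder wrapper also applies quant/post-quant convolutions, omitted here; the KL weight below is an arbitrary assumed hyperparameter):

# Shape-level sketch of encode -> posterior -> sample -> decode.
import torch

video = torch.randn(1, 3, 33, 256, 256)

moments = encoder(video)                             # (1, 32, 9, 32, 32) with double_z=True
posterior = DiagonalGaussianDistribution(moments)    # split into mean / logvar along dim=1
z = posterior.sample()                               # (1, 16, 9, 32, 32)
recon = decoder(z)                                   # back to (1, 3, 33, 256, 256)

rec_loss = torch.nn.functional.mse_loss(recon, video)
kl_loss = posterior.kl().mean()                      # per-sample KL, averaged over the batch
loss = rec_loss + 1e-6 * kl_loss                     # KL weight is an assumed hyperparameter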