# thanks to MinusZoneAI: https://github.com/MinusZoneAI/ComfyUI-CogVideoX-MZ/blob/b98b98bd04621e4c85547866c12de2ec723ae98a/mz_enable_vae_encode_tiling.py
"""Patch a diffusers video VAE so that ``encode`` can split large inputs into spatial tiles.

Entry point: :func:`enable_vae_encode_tiling`. The patched VAE must provide the
attributes used below (``encoder``, ``quant_conv``, ``blend_v``, ``blend_h``,
``use_slicing``, ``config.block_out_channels`` and the ``tile_sample_min_*`` /
``tile_overlap_factor_*`` settings) -- presumably a CogVideoX VAE, per the
credited source; confirm for other models.
"""
from types import MethodType
from typing import Optional

import torch
from diffusers.utils.accelerate_utils import apply_forward_hook
from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
from diffusers.models.modeling_outputs import AutoencoderKLOutput

# Number of frames handed to the encoder per call. The first call additionally
# takes the leftover ``num_frames % _FRAME_BATCH_SIZE`` frames.
_FRAME_BATCH_SIZE = 4


def _frame_batches(num_frames: int, frame_batch_size: int):
    """Yield ``(start, end)`` frame ranges that together cover all ``num_frames`` frames.

    The first range is ``[0, frame_batch_size + num_frames % frame_batch_size)`` and
    every later range spans ``frame_batch_size`` frames. This matches clips with
    ``1``, ``frame_batch_size * k`` or ``frame_batch_size * k + 1`` frames. At least
    one range is always produced, so clips shorter than ``frame_batch_size`` are
    encoded in a single call instead of yielding nothing to concatenate.
    """
    remaining_frames = num_frames % frame_batch_size
    num_batches = max(1, num_frames // frame_batch_size)
    for k in range(num_batches):
        start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
        end_frame = frame_batch_size * (k + 1) + remaining_frames
        yield start_frame, end_frame


def _encode_frames(vae, x: torch.Tensor) -> torch.Tensor:
    """Run ``vae.encoder`` on ``x``, then ``vae.quant_conv`` when it is not None."""
    h = vae.encoder(x)
    # Depending on the diffusers version the encoder may return a tuple whose
    # first element is the encoded tensor. Unwrap it *before* quant_conv, which
    # operates on a tensor, not a tuple.
    if isinstance(h, tuple):
        h = h[0]
    if vae.quant_conv is not None:
        h = vae.quant_conv(h)
    return h


def _maybe_clear_context_cache(vae) -> None:
    """Call ``vae._clear_fake_context_parallel_cache()`` if the VAE defines it.

    This is called once per clip/tile, after all of its frame batches, so whatever
    the cache holds for one clip/tile is not reused for the next. VAEs without the
    method are skipped. Errors raised by the method itself propagate instead of
    being silently swallowed.
    """
    clear_cache = getattr(vae, "_clear_fake_context_parallel_cache", None)
    if clear_cache is not None:
        clear_cache()


@apply_forward_hook
def encode(self, x: torch.Tensor, return_dict: bool = True):
    """
    Encode a batch of videos into latents.

    Args:
        x (`torch.Tensor`):
            Input batch of videos, unpacked downstream as
            ``(batch, channels, frames, height, width)``.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.

    Returns:
        The latent distribution of the encoded videos. If `return_dict` is True, a
        [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a one-element
        `tuple` holding the [`DiagonalGaussianDistribution`].
    """
    if self.use_slicing and x.shape[0] > 1:
        # Slicing mode: encode one sample of the batch at a time.
        h = torch.cat([self._encode(x_slice) for x_slice in x.split(1)])
    else:
        h = self._encode(x)
    posterior = DiagonalGaussianDistribution(h)
    if not return_dict:
        return (posterior,)
    return AutoencoderKLOutput(latent_dist=posterior)


def tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of videos using a tiled encoder.

    The input is split spatially into overlapping tiles, and each tile is encoded
    independently. This keeps memory use bounded by the tile size rather than the
    frame size. Because tiles are encoded separately, the result differs from
    non-tiled encoding. To avoid seams, neighbouring latent tiles are blended over
    their overlap. You may still see tile-sized changes in the output, but they
    should be much less noticeable.

    Args:
        x (`torch.Tensor`): Input batch of videos, ``(batch, channels, frames, height, width)``.

    Returns:
        `torch.Tensor`: The latent representation of the encoded videos.
    """
    # For a rough memory estimate, take a look at the `tiled_decode` method.
    _, _, num_frames, height, width = x.shape

    # Stride between sample-space tile origins, and blend width / kept size in latent space.
    overlap_height = int(self.tile_sample_min_height * (1 - self.tile_overlap_factor_height))
    overlap_width = int(self.tile_sample_min_width * (1 - self.tile_overlap_factor_width))
    blend_extent_height = int(self.tile_latent_min_height * self.tile_overlap_factor_height)
    blend_extent_width = int(self.tile_latent_min_width * self.tile_overlap_factor_width)
    row_limit_height = self.tile_latent_min_height - blend_extent_height
    row_limit_width = self.tile_latent_min_width - blend_extent_width

    # Frame ranges are identical for every tile, so compute them once.
    frame_ranges = list(_frame_batches(num_frames, _FRAME_BATCH_SIZE))

    # Split x into overlapping tiles and encode them separately.
    rows = []
    for i in range(0, height, overlap_height):
        row = []
        for j in range(0, width, overlap_width):
            time = []
            for start_frame, end_frame in frame_ranges:
                tile = x[
                    :,
                    :,
                    start_frame:end_frame,
                    i : i + self.tile_sample_min_height,
                    j : j + self.tile_sample_min_width,
                ]
                time.append(_encode_frames(self, tile))
            # Reset per tile so the next tile starts from a clean cache.
            _maybe_clear_context_cache(self)
            row.append(torch.cat(time, dim=2))
        rows.append(row)

    result_rows = []
    for i, row in enumerate(rows):
        result_row = []
        for j, tile in enumerate(row):
            # Blend the tile above and the tile to the left into the current
            # tile, then keep only its non-overlapping part.
            if i > 0:
                tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
            if j > 0:
                tile = self.blend_h(row[j - 1], tile, blend_extent_width)
            result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
        result_rows.append(torch.cat(result_row, dim=4))

    return torch.cat(result_rows, dim=3)


def _encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
    """Encode ``x`` to the raw latent tensor (before it is wrapped in a distribution).

    Dispatches to :func:`tiled_encode` when encode tiling is enabled and the
    frame exceeds the tile size. Otherwise it encodes the clip in frame batches
    along dim 2 and concatenates the results.

    Args:
        x: ``(batch, channels, frames, height, width)`` input.
        return_dict: Unused; kept so the signature matches callers that pass it.

    Returns:
        The encoded tensor, concatenated over frames (dim 2).
    """
    _, _, num_frames, height, width = x.shape

    if self.use_encode_tiling and (
        width > self.tile_sample_min_width or height > self.tile_sample_min_height
    ):
        return self.tiled_encode(x)

    # Batches must be encoded in order; a single-frame clip is one batch.
    h = [
        _encode_frames(self, x[:, :, start_frame:end_frame])
        for start_frame, end_frame in _frame_batches(num_frames, _FRAME_BATCH_SIZE)
    ]
    _maybe_clear_context_cache(self)
    return torch.cat(h, dim=2)


def enable_encode_tiling(
    self,
    tile_sample_min_height: Optional[int] = None,
    tile_sample_min_width: Optional[int] = None,
    tile_overlap_factor_height: Optional[float] = None,
    tile_overlap_factor_width: Optional[float] = None,
) -> None:
    r"""
    Enable tiled VAE encoding.

    When this option is enabled, the VAE splits the input tensor into spatial tiles
    and encodes them in several steps. This saves a large amount of memory and
    allows processing larger videos. Arguments left as ``None`` keep the VAE's
    current value.

    Args:
        tile_sample_min_height (`int`, *optional*):
            Sample-space tile height; inputs taller than this are split across the height dimension.
        tile_sample_min_width (`int`, *optional*):
            Sample-space tile width; inputs wider than this are split across the width dimension.
        tile_overlap_factor_height (`float`, *optional*):
            Fraction of a tile that overlaps the next tile vertically. Overlap avoids
            tiling artifacts across the height dimension. Must be in ``[0, 1)``: the
            vertical tile stride is ``tile_sample_min_height * (1 - factor)``. Higher
            values process more tiles and slow encoding down.
        tile_overlap_factor_width (`float`, *optional*):
            Fraction of a tile that overlaps the next tile horizontally. Overlap avoids
            tiling artifacts across the width dimension. Must be in ``[0, 1)``: the
            horizontal tile stride is ``tile_sample_min_width * (1 - factor)``. Higher
            values process more tiles and slow encoding down.
    """
    self.use_encode_tiling = True
    self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
    self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width

    # Latent tile size = sample tile size / 2 ** (num blocks - 1), presumably the
    # encoder's spatial downsampling factor.
    latent_scale = 2 ** (len(self.config.block_out_channels) - 1)
    self.tile_latent_min_height = int(self.tile_sample_min_height / latent_scale)
    self.tile_latent_min_width = int(self.tile_sample_min_width / latent_scale)

    # Compare against None so an explicit 0.0 (no overlap) is honoured rather
    # than silently replaced by the current value.
    if tile_overlap_factor_height is not None:
        self.tile_overlap_factor_height = tile_overlap_factor_height
    if tile_overlap_factor_width is not None:
        self.tile_overlap_factor_width = tile_overlap_factor_width


def enable_vae_encode_tiling(vae):
    """Patch ``vae`` in place so its ``encode`` supports spatial tiling, and enable it.

    Binds this module's ``encode``, ``_encode``, ``tiled_encode`` and
    ``enable_encode_tiling`` onto the instance. It then calls
    ``enable_encode_tiling()`` with the VAE's current tile settings, which sets
    ``use_encode_tiling = True``.

    Returns:
        The same ``vae`` object, for chaining.
    """
    vae.encode = MethodType(encode, vae)
    vae._encode = MethodType(_encode, vae)
    vae.tiled_encode = MethodType(tiled_encode, vae)
    vae.enable_encode_tiling = MethodType(enable_encode_tiling, vae)
    vae.enable_encode_tiling()
    return vae