import torch

from nodes import VAEEncode


class VAEDecodeBatched:
    """Node that decodes a LATENT batch to IMAGEs a fixed number of items at a time.

    The latent tensor is sliced along its first dimension into chunks of at
    most ``per_batch`` items. Each chunk is passed to ``vae.decode`` separately,
    and the decoded chunks are concatenated back together along dim 0.
    """

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input spec: ``samples``, ``vae`` and ``per_batch`` (INT, default 16, min 1)."""
        per_batch_spec = ("INT", {"default": 16, "min": 1})
        return {
            "required": {
                "samples": ("LATENT",),
                "vae": ("VAE",),
                "per_batch": per_batch_spec,
            }
        }

    def decode(self, vae, samples, per_batch):
        """Decode ``samples["samples"]`` in slices of ``per_batch`` along dim 0.

        Args:
            vae: object exposing ``decode`` (assumed to be the host app's VAE
                wrapper — not shown here).
            samples: LATENT dict; only its ``"samples"`` tensor is read.
            per_batch: maximum number of items handed to ``vae.decode`` per call.

        Returns:
            A 1-tuple holding the decoded chunks concatenated along dim 0.
            If the latent batch is empty there is nothing to concatenate, and
            ``torch.cat`` raises.
        """
        latent = samples["samples"]
        # Each vae.decode call sees at most per_batch items, presumably to
        # bound peak memory on long sequences.
        decoded_chunks = [
            vae.decode(latent[offset:offset + per_batch])
            for offset in range(0, latent.shape[0], per_batch)
        ]
        return (torch.cat(decoded_chunks, dim=0),)


class VAEEncodeBatched:
    """Node that encodes an IMAGE batch to a LATENT a fixed number of items at a time.

    The pixel tensor is sliced along its first dimension into chunks of at
    most ``per_batch`` items. Each chunk is cropped with
    ``VAEEncode.vae_encode_crop_pixels``, reduced to the first three entries of
    its last axis, and encoded. The results are concatenated along dim 0.
    """

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    @classmethod
    def INPUT_TYPES(s):
        """Return the node's input spec: ``pixels``, ``vae`` and ``per_batch`` (INT, default 16, min 1)."""
        per_batch_spec = ("INT", {"default": 16, "min": 1})
        return {
            "required": {
                "pixels": ("IMAGE",),
                "vae": ("VAE",),
                "per_batch": per_batch_spec,
            }
        }

    def encode(self, vae, pixels, per_batch):
        """Encode ``pixels`` in slices of ``per_batch`` along dim 0.

        Args:
            vae: object exposing ``encode`` (assumed to be the host app's VAE
                wrapper — not shown here).
            pixels: image tensor, indexed here as 4-D; presumably
                (batch, height, width, channels) — TODO confirm.
            per_batch: maximum number of items handed to ``vae.encode`` per call.

        Returns:
            A 1-tuple holding a LATENT dict whose ``"samples"`` entry is the
            encoded chunks concatenated along dim 0. If the pixel batch is
            empty, ``torch.cat`` raises.
        """

        def encode_chunk(chunk):
            # Crop via the helper. Presumably this trims spatial dims to
            # sizes the VAE accepts; see VAEEncode.
            cropped = VAEEncode.vae_encode_crop_pixels(chunk)
            # Keep only the first three entries of the last axis
            # (presumably RGB, discarding any alpha channel).
            return vae.encode(cropped[:, :, :, :3])

        starts = range(0, pixels.shape[0], per_batch)
        encoded = torch.cat(
            [encode_chunk(pixels[offset:offset + per_batch]) for offset in starts],
            dim=0,
        )
        return ({"samples": encoded},)