# Provenance (captured from the repository web view):
#   uploader: daquanzhou
#   commit 613c9ab: "merge github repos and lfs track ckpt/path/safetensors/pt"
#   file size: 1.45 kB
import torch
from nodes import VAEEncode
class VAEDecodeBatched:
    """Decode a latent batch through a VAE in slices of at most ``per_batch``.

    The latent tensor is sliced along its first dimension so that
    ``vae.decode`` never receives more than ``per_batch`` entries per call;
    the per-slice results are concatenated back into one tensor.

    The upper-case class attributes are node-style declarations
    (presumably consumed by the ComfyUI node loader -- confirm against the
    host application).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration: latents, a VAE and the slice size."""
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1}),
            }
        }

    # The label ends in four non-ASCII symbols (movie camera followed by
    # negative-circled V, H, S). They are written as escapes so the source
    # stays pure ASCII and cannot be mis-decoded into mojibake again.
    CATEGORY = "Video Helper Suite \U0001F3A5\U0001F165\U0001F157\U0001F162/batched nodes"

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    def decode(self, vae, samples, per_batch):
        """Decode ``samples["samples"]`` slice by slice and join the results.

        Args:
            vae: Object exposing ``decode(tensor)``.
            samples: Mapping whose ``"samples"`` entry is a tensor that is
                sliced along dimension 0.
            per_batch: Maximum number of entries passed to ``vae.decode``
                per call; the final slice may be shorter.

        Returns:
            A one-element tuple holding the decoded slices concatenated
            along dimension 0.

        Raises:
            ValueError: If ``per_batch`` is 0 (``range`` rejects a zero step).
            RuntimeError: Presumably from ``torch.cat`` when the latent
                tensor has no entries, since nothing is decoded in that case.
        """
        latents = samples["samples"]
        decoded = [
            vae.decode(latents[start_idx:start_idx + per_batch])
            for start_idx in range(0, latents.shape[0], per_batch)
        ]
        return (torch.cat(decoded, dim=0), )
class VAEEncodeBatched:
    """Encode an image batch through a VAE in slices of at most ``per_batch``.

    The pixel tensor is sliced along its first dimension, each slice is
    cropped by a crop helper, reduced to its first three channels and passed
    to ``vae.encode``; the per-slice latents are concatenated back into one
    tensor.

    The upper-case class attributes are node-style declarations
    (presumably consumed by the ComfyUI node loader -- confirm against the
    host application).
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input declaration: pixels, a VAE and the slice size."""
        return {
            "required": {
                "pixels": ("IMAGE", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1}),
            }
        }

    # The label ends in four non-ASCII symbols (movie camera followed by
    # negative-circled V, H, S). They are written as escapes so the source
    # stays pure ASCII and cannot be mis-decoded into mojibake again.
    CATEGORY = "Video Helper Suite \U0001F3A5\U0001F165\U0001F157\U0001F162/batched nodes"

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    def encode(self, vae, pixels, per_batch):
        """Encode ``pixels`` slice by slice and join the resulting latents.

        Args:
            vae: Object exposing ``encode(tensor)`` and, optionally,
                ``vae_encode_crop_pixels(tensor)``.
            pixels: Four-dimensional tensor (the code indexes four axes and
                keeps the first three entries of the last one), sliced
                along dimension 0.
            per_batch: Maximum number of entries passed to ``vae.encode``
                per call; the final slice may be shorter.

        Returns:
            A one-element tuple holding ``{"samples": latents}`` where
            ``latents`` is the per-slice output concatenated along
            dimension 0.

        Raises:
            ValueError: If ``per_batch`` is 0 (``range`` rejects a zero step).
            RuntimeError: Presumably from ``torch.cat`` when ``pixels`` has
                no entries, since nothing is encoded in that case.
        """
        # NOTE(review): newer ComfyUI releases appear to expose the crop
        # helper on the VAE object rather than on nodes.VAEEncode -- confirm
        # against the targeted version. Prefer the VAE's own helper when it
        # exists; otherwise keep the original nodes.VAEEncode call.
        crop_pixels = getattr(vae, "vae_encode_crop_pixels", None)
        if crop_pixels is None:
            crop_pixels = VAEEncode.vae_encode_crop_pixels

        encoded = []
        for start_idx in range(0, pixels.shape[0], per_batch):
            sub_pixels = crop_pixels(pixels[start_idx:start_idx + per_batch])
            # Only the first three channels are encoded; any further
            # channel is dropped.
            encoded.append(vae.encode(sub_pixels[:, :, :, :3]))
        return ({"samples": torch.cat(encoded, dim=0)}, )