import torch

from nodes import VAEEncode

class VAEDecodeBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "decode"

    def decode(self, vae, samples, per_batch):
        # Decode the latents in chunks of at most per_batch samples so peak
        # VRAM stays bounded, then concatenate the decoded images back into
        # a single batch.
        decoded = []
        for start_idx in range(0, samples["samples"].shape[0], per_batch):
            decoded.append(vae.decode(samples["samples"][start_idx:start_idx+per_batch]))
        return (torch.cat(decoded, dim=0), )
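
# --- Illustration (not part of the node) -------------------------------------
# A minimal, runnable sketch of the chunking pattern decode() relies on.
# `_dummy_decode` is a hypothetical stand-in for vae.decode: for any function
# applied independently per sample, mapping it over per_batch-sized slices and
# concatenating matches a single full pass, while bounding peak memory.
def _demo_chunked_decode(per_batch: int = 16) -> None:
    def _dummy_decode(latents: torch.Tensor) -> torch.Tensor:
        return latents * 2.0  # stand-in for a real, per-sample VAE decode

    latents = torch.randn(37, 4, 8, 8)  # 37 frames: deliberately not a multiple of per_batch
    chunks = [_dummy_decode(latents[i:i + per_batch])
              for i in range(0, latents.shape[0], per_batch)]
    # The last chunk is simply shorter; torch.cat restores the full batch.
    assert torch.equal(torch.cat(chunks, dim=0), _dummy_decode(latents))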

class VAEEncodeBatched:
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "pixels": ("IMAGE", ),
                "vae": ("VAE", ),
                "per_batch": ("INT", {"default": 16, "min": 1})
            }
        }

    CATEGORY = "Video Helper Suite 🎥🅥🅗🅢/batched nodes"
    RETURN_TYPES = ("LATENT",)
    FUNCTION = "encode"

    def encode(self, vae, pixels, per_batch):
        # Encode the images in chunks of at most per_batch frames, then
        # concatenate the resulting latents into one batch.
        t = []
        for start_idx in range(0, pixels.shape[0], per_batch):
            # Crop each chunk to dimensions the VAE accepts, and drop any
            # alpha channel before encoding (keep only RGB).
            sub_pixels = VAEEncode.vae_encode_crop_pixels(pixels[start_idx:start_idx+per_batch])
            t.append(vae.encode(sub_pixels[:, :, :, :3]))
        return ({"samples": torch.cat(t, dim=0)}, )