import torch


def get_latent_dimensions(
    num_frames: int,
    width: int,
    height: int,
    *,
    spatial_downsample: int = 8,
    temporal_downsample: int = 6,
    in_channels: int = 12,
    batch_size: int = 1,
) -> tuple:
    """Return the latent shape ``(B, C, T, H, W)`` for a clip of the given size.

    The keyword-only parameters default to the values that used to be
    hard-coded in this function, so existing positional calls are unaffected.

    Args:
        num_frames: Number of frames in the clip.
        width: Frame width.
        height: Frame height.
        spatial_downsample: Divisor applied to ``height`` and ``width``.
        temporal_downsample: Divisor applied to ``num_frames - 1``.
        in_channels: Value used for the channel entry of the shape.
        batch_size: Value used for the batch entry of the shape.

    Returns:
        A 5-tuple ``(batch_size, in_channels, T, H, W)``.
    """
    # The first frame always yields one latent step; the remaining frames are
    # grouped with floor division.
    latent_frames = (num_frames - 1) // temporal_downsample + 1
    latent_height = height // spatial_downsample
    latent_width = width // spatial_downsample
    return (batch_size, in_channels, latent_frames, latent_height, latent_width)


def add_latent_noise(model, latent_shape, sigma_schedule, samples, generator):
    """Draw Gaussian noise and optionally blend it with existing latents.

    Args:
        model: Object exposing a ``device`` attribute; the noise is created on
            that device and ``samples`` is moved to it.
        latent_shape: Shape of the noise tensor to draw.
        sigma_schedule: Indexable schedule; only element ``0`` is read, and
            only when ``samples`` is not ``None``.
        samples: Existing latents to mix with the noise, or ``None`` to return
            pure noise. Presumably broadcastable to ``latent_shape`` -- this
            function does not check it.
        generator: ``torch.Generator`` passed to ``torch.randn``.
            NOTE(review): it likely has to live on ``model.device`` -- confirm
            against callers.

    Returns:
        Pure float32 noise when ``samples`` is ``None``; otherwise
        ``sigma * noise + (1 - sigma) * samples`` with
        ``sigma = sigma_schedule[0]``.
    """
    noise = torch.randn(
        latent_shape,
        device=model.device,
        generator=generator,
        dtype=torch.float32,
    )
    if samples is None:
        return noise

    # Linear interpolation between the provided latents and the fresh noise,
    # weighted by the first sigma of the schedule.
    sigma = sigma_schedule[0]
    return noise * sigma + (1 - sigma) * samples.to(model.device)