File size: 676 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
def get_latent_dimensions(
    num_frames,
    width,
    height,
    *,
    spatial_downsample=8,
    temporal_downsample=6,
    in_channels=12,
    batch_size=1,
):
    """Return the latent tensor shape ``(B, C, T, H, W)`` for a video request.

    Note that the arguments are given as ``(width, height)`` while the
    returned tuple is ordered ``(..., H, W)``.

    Args:
        num_frames: Number of frames in pixel space.
        width: Frame width in pixels.
        height: Frame height in pixels.
        spatial_downsample: Integer factor dividing height and width
            (floor division). Defaults to 8.
        temporal_downsample: Integer factor applied along the frame axis.
            Defaults to 6.
        in_channels: Size of the channel dimension ``C``. Defaults to 12.
        batch_size: Size of the batch dimension ``B``. Defaults to 1.

    Returns:
        A tuple ``(B, C, T, H, W)`` of ints.
    """
    # The first frame maps to one latent frame on its own; every further
    # group of `temporal_downsample` frames adds one more.
    latent_frames = (num_frames - 1) // temporal_downsample + 1
    latent_height = height // spatial_downsample
    latent_width = width // spatial_downsample
    return (batch_size, in_channels, latent_frames, latent_height, latent_width)
def add_latent_noise(model, latent_shape, sigma_schedule, samples=None, generator=None):
    """Draw Gaussian noise and optionally blend it with existing latents.

    Args:
        model: Object exposing a ``device`` attribute; the noise is created
            on that device and ``samples`` is moved to it.
        latent_shape: Shape of the noise tensor to draw.
        sigma_schedule: Indexable sequence; only its first entry is read, and
            only when ``samples`` is provided.
            NOTE(review): presumably the initial noise level in [0, 1] -
            confirm against the caller.
        samples: Optional tensor to mix with the noise. When ``None`` the
            pure noise is returned.
            NOTE(review): assumed broadcast-compatible with ``latent_shape``;
            no shape check is performed here.
        generator: Optional ``torch.Generator`` passed to ``torch.randn``.
            NOTE(review): presumably must live on ``model.device`` - verify.

    Returns:
        The float32 noise tensor when ``samples`` is ``None``; otherwise
        ``noise * sigma + (1 - sigma) * samples`` with
        ``sigma = sigma_schedule[0]``.
    """
    noise = torch.randn(
        latent_shape,
        device=model.device,
        generator=generator,
        dtype=torch.float32,
    )
    if samples is None:
        return noise

    # Only the first schedule entry is used for the blend.
    sigma = sigma_schedule[0]
    return noise * sigma + (1 - sigma) * samples.to(model.device)
|