Spaces:
Running
Running
File size: 628 Bytes
224a33f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import math
import torch
def cosine_schedule(t: torch.Tensor):
    """Cosine schedule, as used in the MaskGIT paper.

    Args:
        t: Tensor with one entry per batch element (shape ``(batch_size,)``),
            with values expected to lie in ``[0, 1]``.

    Returns:
        ``cos(t * pi / 2)`` element-wise: 1 at ``t == 0``, falling to 0 at
        ``t == 1``.
    """
    # Scale [0, 1] onto the quarter period [0, pi/2] of the cosine.
    angle = t * math.pi * 0.5
    return torch.cos(angle)
def cubic_schedule(t):
    """Cubic schedule: return ``1 - t**3``.

    Works element-wise on tensors and also on plain Python numbers. It stays
    close to 1 for small ``t`` and drops to 0 at ``t == 1``.
    """
    cubed = t**3
    return 1 - cubed
def linear_schedule(t):
    """Linear schedule: return ``1 - t``.

    Works element-wise on tensors and also on plain Python numbers, going
    from 1 at ``t == 0`` to 0 at ``t == 1``.
    """
    remaining = 1 - t
    return remaining
def square_root_schedule(t):
    """Square-root schedule: return ``1 - sqrt(t)``.

    Uses ``torch.sqrt``, so ``t`` must be a tensor. The value falls quickly
    near ``t == 0`` and reaches 0 at ``t == 1``.
    """
    root = torch.sqrt(t)
    return 1 - root
def square_schedule(t):
    """Quadratic schedule: return ``1 - t**2``.

    Works element-wise on tensors and also on plain Python numbers, going
    from 1 at ``t == 0`` to 0 at ``t == 1``.
    """
    squared = t**2
    return 1 - squared
# Maps a schedule name to the function implementing it. Each schedule takes
# t in [0, 1] and returns 1 at t == 0 and 0 at t == 1.
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # Bare name, consistent with every other key in this registry.
    "square_root": square_root_schedule,
    # Legacy key (the only one with a "_schedule" suffix), kept so existing
    # configs that reference it continue to resolve.
    "square_root_schedule": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
}
|