import math

import torch


def cosine_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return ``cos(t * pi / 2)`` elementwise.

    Maps ``t = 0`` to 1 and ``t = 1`` to 0, decreasing monotonically on [0, 1].

    Args:
        t: Tensor of values expected in [0, 1]. Per the original author, this
            is a tensor of size ``(batch_size,)``; any shape works elementwise.
    """
    # Per the original author, this is the schedule used in the MaskGIT paper.
    return torch.cos(t * math.pi * 0.5)


def cubic_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return ``1 - t**3`` elementwise (maps 0 -> 1 and 1 -> 0 on [0, 1])."""
    return 1 - t**3


def linear_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return ``1 - t`` elementwise (maps 0 -> 1 and 1 -> 0 on [0, 1])."""
    return 1 - t


def square_root_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return ``1 - sqrt(t)`` elementwise (maps 0 -> 1 and 1 -> 0 on [0, 1]).

    ``t`` must be a tensor because ``torch.sqrt`` is used. Negative entries
    produce NaN.
    """
    return 1 - torch.sqrt(t)


def square_schedule(t: torch.Tensor) -> torch.Tensor:
    """Return ``1 - t**2`` elementwise (maps 0 -> 1 and 1 -> 0 on [0, 1])."""
    return 1 - t**2


# Maps a schedule name to its function. Every schedule takes values of t in
# [0, 1] and returns 1 at t=0 and 0 at t=1.
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # "square_root" matches the naming of the other keys. The historical
    # "square_root_schedule" key is kept so existing configs keep working.
    "square_root": square_root_schedule,
    "square_root_schedule": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
}