M3Site / esm /utils /noise_schedules.py
anonymousforpaper's picture
Upload 103 files
224a33f verified
raw
history blame contribute delete
628 Bytes
import math
import torch
def cosine_schedule(t: torch.Tensor):
    """Cosine decay schedule from the MaskGIT paper: ``cos(t * pi / 2)``.

    Args:
        t: Tensor with values in ``[0, 1]``, typically of shape
            ``(batch_size,)``. ``t == 0`` maps to 1 and ``t == 1`` maps to 0.

    Returns:
        Tensor of the same shape as ``t``; for inputs in ``[0, 1]`` every
        value lies in ``[0, 1]``.
    """
    # In float32, ``pi / 2`` rounds to a value slightly above the true one, so
    # the raw cosine at t == 1 comes out around -4.4e-8 instead of 0. Clamp at
    # zero so callers that treat the result as a fraction/probability never
    # receive a (tiny) negative value.
    return torch.cos(t * math.pi * 0.5).clamp(min=0.0)
def cubic_schedule(t):
    """Return the cubic decay ``1 - t**3``.

    For ``t`` in ``[0, 1]`` the value stays close to 1 for small ``t`` and
    falls steeply as ``t`` approaches 1. Only arithmetic operators are used,
    so both Python numbers and tensors are accepted.
    """
    cubed = t**3
    return 1 - cubed
def linear_schedule(t):
    """Return ``1 - t``, a schedule that decreases at a constant rate.

    Only arithmetic operators are used, so both Python numbers and tensors
    are accepted.
    """
    complement = 1 - t
    return complement
def square_root_schedule(t):
    """Return ``1 - sqrt(t)``.

    For ``t`` in ``[0, 1]`` the value drops fastest near ``t == 0`` and
    flattens out toward ``t == 1``. ``t`` must be a tensor because
    ``torch.sqrt`` is used.
    """
    root = torch.sqrt(t)
    return 1 - root
def square_schedule(t):
    """Return the quadratic decay ``1 - t**2``.

    Only arithmetic operators are used, so both Python numbers and tensors
    are accepted.
    """
    squared = t**2
    return 1 - squared
# Maps a schedule name to its function.
# NOTE(review): the square-root entry is keyed "square_root_schedule" (with a
# suffix) while every other key is the bare name; presumably this exact string
# is looked up elsewhere (configs/callers), so it is left as-is — confirm
# before renaming.
NOISE_SCHEDULE_REGISTRY = dict(
    cosine=cosine_schedule,
    linear=linear_schedule,
    square_root_schedule=square_root_schedule,
    cubic=cubic_schedule,
    square=square_schedule,
)