File size: 628 Bytes
224a33f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math

import torch


def cosine_schedule(t: torch.Tensor):
    """Cosine schedule ``cos(t * pi / 2)``; the schedule used in the MaskGIT paper.

    Args:
        t: Tensor of size ``(batch_size,)`` with values between 0 and 1.

    Returns:
        The elementwise cosine of ``t * pi / 2``: 1 at ``t = 0`` and
        (up to floating-point rounding) 0 at ``t = 1``.
    """
    angle = t * math.pi * 0.5
    return torch.cos(angle)


def cubic_schedule(t):
    """Cubic schedule ``1 - t**3``.

    Args:
        t: Progress value(s); anything supporting ``**`` and subtraction from
            an int (e.g. a number or a tensor). Presumably in [0, 1] -- the
            range is not checked here.

    Returns:
        ``1 - t**3``, which is 1 at ``t = 0`` and 0 at ``t = 1``.
    """
    t_cubed = t**3
    return 1 - t_cubed


def linear_schedule(t):
    """Linear schedule ``1 - t``.

    Args:
        t: Progress value(s); anything that can be subtracted from an int
            (e.g. a number or a tensor). Presumably in [0, 1] -- the range is
            not checked here.

    Returns:
        ``1 - t``, which is 1 at ``t = 0`` and 0 at ``t = 1``.
    """
    complement = 1 - t
    return complement


def square_root_schedule(t):
    """Square-root schedule ``1 - sqrt(t)``.

    Args:
        t: Progress value(s). A ``torch.Tensor`` is used unchanged; a plain
            Python number (or anything else ``torch.as_tensor`` accepts) is
            converted to a tensor first, so scalar inputs work here just as
            they do for the polynomial schedules in this module. Presumably
            in [0, 1] -- the range is not checked here.

    Returns:
        A tensor holding ``1 - sqrt(t)`` elementwise: 1 at ``t = 0`` and 0 at
        ``t = 1``.
    """
    # torch.sqrt only accepts tensors. as_tensor hands tensor inputs back
    # as-is (no copy, autograd history intact) and wraps everything else.
    root = torch.sqrt(torch.as_tensor(t))
    return 1 - root


def square_schedule(t):
    """Quadratic schedule ``1 - t**2``.

    Args:
        t: Progress value(s); anything supporting ``**`` and subtraction from
            an int (e.g. a number or a tensor). Presumably in [0, 1] -- the
            range is not checked here.

    Returns:
        ``1 - t**2``, which is 1 at ``t = 0`` and 0 at ``t = 1``.
    """
    t_squared = t**2
    return 1 - t_squared


# Lookup table from schedule name to the schedule function defined above.
NOISE_SCHEDULE_REGISTRY = {
    "cosine": cosine_schedule,
    "linear": linear_schedule,
    # Original key; the only one carrying a "_schedule" suffix. Kept so that
    # existing lookups keep working.
    "square_root_schedule": square_root_schedule,
    "cubic": cubic_schedule,
    "square": square_schedule,
    # Alias in the same bare-name style as the other keys. Appended last so
    # the iteration order of the pre-existing keys is unchanged.
    "square_root": square_root_schedule,
}