import torch
import torch.nn.functional as F

from medical_diffusion.models.noise_schedulers import BasicNoiseScheduler
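
# Forward diffusion process implemented below (DDPM, https://arxiv.org/abs/2006.11239, Eq. 4):
#   q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)
# so a noisy sample at any timestep t is available in closed form:
#   x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps,   eps ~ N(0, I)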

class GaussianNoiseScheduler(BasicNoiseScheduler):
    def __init__(
        self,
        timesteps=1000,
        T=None,
        schedule_strategy='cosine',
        beta_start=0.0001,  # default 1e-4; stable-diffusion uses ~1e-3
        beta_end=0.02,
        betas=None,
    ):
        super().__init__(timesteps, T)

        self.schedule_strategy = schedule_strategy

        if betas is not None:
            betas = torch.as_tensor(betas, dtype=torch.float64)
        elif schedule_strategy == "linear":
            betas = torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
        elif schedule_strategy == "scaled_linear":  # proposed as "quadratic" in https://arxiv.org/abs/2006.11239, used in stable-diffusion
            betas = torch.linspace(beta_start**0.5, beta_end**0.5, timesteps, dtype=torch.float64)**2
        elif schedule_strategy == "cosine":
            # Cosine schedule of https://arxiv.org/abs/2102.09672: define alpha_bar(t) directly
            # and recover the betas via beta_t = 1 - alpha_bar(t) / alpha_bar(t-1).
            s = 0.008
            x = torch.linspace(0, timesteps, timesteps + 1, dtype=torch.float64)  # [0, T]
            alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
            alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
            betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
            betas = torch.clip(betas, 0, 0.999)
        else:
            raise NotImplementedError(f"{schedule_strategy} is not implemented for {self.__class__}")


        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.)

        # Store schedule constants as float32 buffers so they follow the module across devices.
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))

        register_buffer('betas', betas)  # increasing within (0, 1)

        register_buffer('alphas', alphas)  # decreasing within (1, 0)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
        register_buffer('posterior_variance', betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod))
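
        # Together, the three posterior buffers implement the closed-form posterior of
        # the forward process (https://arxiv.org/abs/2006.11239, Eq. 7):
        #   q(x_{t-1} | x_t, x_0) = N(coef1 * x_0 + coef2 * x_t, posterior_variance * I)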
            

    def estimate_x_t(self, x_0, t, x_T=None):
        """Diffuse x_0 forward to timestep t: x_t = sqrt(alpha_bar_t)*x_0 + sqrt(1-alpha_bar_t)*x_T."""
        # NOTE: t == 0 means diffused for 1 step (https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils.py#L108)
        # NOTE: t == 0 means not diffused for cold-diffusion (in contradiction to the above comment) https://github.com/arpitbansal297/Cold-Diffusion-Models/blob/c828140b7047ca22f995b99fbcda360bc30fc25d/denoising-diffusion-pytorch/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L361
        x_T = self.x_final(x_0) if x_T is None else x_T
        # Index per sample (instead of a vectorized extract) so that out-of-range
        # timesteps can be clamped: t < 0 returns x_0 unchanged, t >= T returns pure noise.
        def clipper(b):
            tb = t[b]
            if tb < 0:
                return x_0[b]
            elif tb >= self.T:
                return x_T[b]
            else:
                return self.sqrt_alphas_cumprod[tb] * x_0[b] + self.sqrt_one_minus_alphas_cumprod[tb] * x_T[b]
        x_t = torch.stack([clipper(b) for b in range(t.shape[0])])
        return x_t
    

    def estimate_x_t_prior_from_x_T(self, x_t, t, x_T, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        x_0 = self.estimate_x_0(x_t, x_T, t, clip_x0)
        return self.estimate_x_t_prior_from_x_0(x_t, t, x_0, use_log, clip_x0, var_scale, cold_diffusion)


    def estimate_x_t_prior_from_x_0(self, x_t, t, x_0, use_log=True, clip_x0=True, var_scale=0, cold_diffusion=False):
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0

        if cold_diffusion:  # improved sampling of https://arxiv.org/abs/2208.09392: x_{t-1} = x_t - D(x_0, t) + D(x_0, t-1)
            x_T_est = self.estimate_x_T(x_t, x_0, t)  # or use x_T estimated by UNet if available?
            x_t_est = self.estimate_x_t(x_0, t, x_T=x_T_est)
            x_t_prior = self.estimate_x_t(x_0, t - 1, x_T=x_T_est)
            noise_t = x_t_est - x_t_prior
            x_t_prior = x_t - noise_t
        else:
            mean = self.estimate_mean_t(x_t, x_0, t)
            variance = self.estimate_variance_t(t, x_t.ndim, use_log, var_scale)
            std = torch.exp(0.5 * variance) if use_log else torch.sqrt(variance)
            std[t == 0] = 0.0  # no noise is added at the final denoising step
            x_T = self.x_final(x_t)  # fresh Gaussian noise
            x_t_prior = mean + std * x_T
        return x_t_prior, x_0

    
    def estimate_mean_t(self, x_t, x_0, t):
        """Posterior mean of q(x_{t-1} | x_t, x_0)."""
        ndim = x_t.ndim
        return (self.extract(self.posterior_mean_coef1, t, ndim) * x_0 +
                self.extract(self.posterior_mean_coef2, t, ndim) * x_t)
    

    def estimate_variance_t(self, t, ndim, log=True, var_scale=0, eps=1e-20):
        """Interpolate between the posterior variance (lower bound) and beta_t (upper
        bound), in log space if `log=True`, as in https://arxiv.org/abs/2102.09672."""
        min_variance = self.extract(self.posterior_variance, t, ndim)
        max_variance = self.extract(self.betas, t, ndim)
        if log:
            min_variance = torch.log(min_variance.clamp(min=eps))
            max_variance = torch.log(max_variance.clamp(min=eps))
        return var_scale * max_variance + (1 - var_scale) * min_variance
    

    def estimate_x_0(self, x_t, x_T, t, clip_x0=True):
        """Invert the forward diffusion: recover x_0 from x_t and the (predicted) noise x_T."""
        ndim = x_t.ndim
        x_0 = (self.extract(self.sqrt_recip_alphas_cumprod, t, ndim) * x_t -
                self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim) * x_T)
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return x_0


    def estimate_x_T(self, x_t, x_0, t, clip_x0=True):
        """Recover the noise x_T from x_t and x_0 by rearranging the forward diffusion."""
        ndim = x_t.ndim
        x_0 = self._clip_x_0(x_0) if clip_x0 else x_0
        return ((self.extract(self.sqrt_recip_alphas_cumprod, t, ndim) * x_t - x_0) /
                self.extract(self.sqrt_recipm1_alphas_cumprod, t, ndim))
    
    
    @classmethod
    def x_final(cls, x):
        """Fully-diffused state: pure Gaussian noise with the same shape as x."""
        return torch.randn_like(x)

    @classmethod
    def _clip_x_0(cls, x_0):
        # See "static/dynamic thresholding" in Imagen https://arxiv.org/abs/2205.11487

        # "static thresholding"
        m = 1  # set to about 4*sigma = 4 if latent diffusion is used
        x_0 = x_0.clamp(-m, m)

        # "dynamic thresholding" (alternative, kept for reference):
        # r = torch.stack([torch.quantile(torch.abs(x_0_b), 0.997) for x_0_b in x_0])
        # r = torch.maximum(r, torch.full_like(r, m))
        # x_0 = torch.stack([x_0_b.clamp(-r_b, r_b) / r_b * m for x_0_b, r_b in zip(x_0, r)])

        return x_0
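
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; assumes BasicNoiseScheduler is an
# nn.Module subclass that sets `self.T` and provides `extract()`, and that a
# hypothetical trained `model(x_t, t)` predicts the injected noise):
#
#   scheduler = GaussianNoiseScheduler(timesteps=1000, schedule_strategy='cosine')
#
#   # Training step: diffuse clean images and regress the injected noise.
#   x_0 = torch.rand(4, 3, 64, 64) * 2 - 1               # images scaled to [-1, 1]
#   t = torch.randint(0, scheduler.T, (4,))              # one timestep per sample
#   noise = torch.randn_like(x_0)
#   x_t = scheduler.estimate_x_t(x_0, t, x_T=noise)
#   loss = F.mse_loss(model(x_t, t), noise)
#
#   # Sampling: iterate the reverse process starting from pure noise.
#   x_t = torch.randn(4, 3, 64, 64)
#   for i in reversed(range(scheduler.T)):
#       t = torch.full((4,), i, dtype=torch.long)
#       x_t, _ = scheduler.estimate_x_t_prior_from_x_T(x_t, t, model(x_t, t))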