File size: 6,502 Bytes
d4c980e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 |
"""
Abstract SDE classes, Reverse SDE, and VE/VP SDEs.
Taken and adapted from https://github.com/yang-song/score_sde_pytorch/blob/1618ddea340f3e4a2ed7852a0694a809775cf8d0/sde_lib.py
"""
import abc
import warnings
import math
import scipy.special as sc
import numpy as np
from geco.util.tensors import batch_broadcast
import torch
from geco.util.registry import Registry
# Global registry mapping string names (e.g. "bbed" below) to SDE subclasses.
SDERegistry = Registry("SDE")
class SDE(abc.ABC):
    """SDE abstract class. Functions are designed for a mini-batch of inputs."""

    def __init__(self, N):
        """Construct an SDE.

        Args:
            N: number of discretization time steps.
        """
        super().__init__()
        self.N = N

    @property
    @abc.abstractmethod
    def T(self):
        """End time of the SDE."""
        pass

    @abc.abstractmethod
    def sde(self, x, t, *args):
        """Return the drift and diffusion coefficients of the SDE at state `x`, time `t`."""
        pass

    @abc.abstractmethod
    def marginal_prob(self, x, t, *args):
        """Parameters to determine the marginal distribution of the SDE, $p_t(x|args)$."""
        pass

    @abc.abstractmethod
    def prior_sampling(self, shape, *args):
        """Generate one sample from the prior distribution, $p_T(x|args)$ with shape `shape`."""
        pass

    @abc.abstractmethod
    def prior_logp(self, z):
        """Compute log-density of the prior distribution.

        Useful for computing the log-likelihood via probability flow ODE.

        Args:
            z: latent code
        Returns:
            log probability density
        """
        pass

    @staticmethod
    @abc.abstractmethod
    def add_argparse_args(parent_parser):
        """
        Add the necessary arguments for instantiation of this SDE class to an argparse ArgumentParser.
        """
        pass

    def discretize(self, x, t, y, stepsize):
        """Discretize the SDE in the form: x_{i+1} = x_i + f_i(x_i) + G_i z_i.

        Useful for reverse diffusion sampling and probability flow sampling.
        Defaults to Euler-Maruyama discretization.

        Args:
            x: a torch tensor
            t: a torch float representing the time step (from 0 to `self.T`)
            y: conditioning signal forwarded to `self.sde`
            stepsize: scalar time increment dt of one discretization step
        Returns:
            f, G
        """
        dt = stepsize
        # NOTE: an earlier revision derived the step as dt = 1/self.N; the step
        # size is now supplied explicitly by the caller.
        drift, diffusion = self.sde(x, t, y)
        f = drift * dt
        # Euler-Maruyama noise scale: g * sqrt(dt).
        G = diffusion * torch.sqrt(torch.tensor(dt, device=t.device))
        return f, G

    def reverse(oself, score_model, probability_flow=False):
        """Create the reverse-time SDE/ODE.

        Args:
            score_model: A function that takes x, t and y and returns the score.
            probability_flow: If `True`, create the reverse-time ODE used for probability flow sampling.
        """
        # `oself` ("outer self") avoids a name clash with the nested class's
        # own `self`; the values below are captured by RSDE's closures.
        N = oself.N
        T = oself.T
        sde_fn = oself.sde
        discretize_fn = oself.discretize

        # Build the class for reverse-time SDE.
        class RSDE(oself.__class__):
            def __init__(self):
                # Deliberately does NOT call super().__init__(): the forward
                # SDE's constructor arguments are not available here.
                self.N = N
                self.probability_flow = probability_flow

            @property
            def T(self):
                return T

            def sde(self, x, t, *args):
                """Create the drift and diffusion functions for the reverse SDE/ODE."""
                # NOTE(review): `rsde_parts` is not defined in any class visible
                # in this file — presumably supplied by a subclass or mixin;
                # confirm before calling `sde` on a reverse SDE.
                rsde_parts = self.rsde_parts(x, t, *args)
                total_drift, diffusion = rsde_parts["total_drift"], rsde_parts["diffusion"]
                return total_drift, diffusion

            def discretize(self, x, t, y, m, stepsize):
                """Create discretized iteration rules for the reverse diffusion sampler."""
                # NOTE: signature intentionally differs from the forward
                # `discretize` — the extra `m` is forwarded to the score model.
                f, G = discretize_fn(x, t, y, stepsize)
                if torch.is_complex(G):
                    # Keep only the imaginary part of a complex diffusion term.
                    G = G.imag
                # Reverse drift: f - G^2 * score; halved for the probability
                # flow ODE. Assumes 4-D inputs (batch, ch, H, W) — TODO confirm.
                rev_f = f - G[:, None, None, None] ** 2 * score_model(x, t, y, m) * (0.5 if self.probability_flow else 1.)
                # The probability-flow ODE is deterministic: zero diffusion.
                rev_G = torch.zeros_like(G) if self.probability_flow else G
                return rev_f, rev_G

        return RSDE()

    @abc.abstractmethod
    def copy(self):
        """Return a new instance of this SDE with the same hyperparameters."""
        pass
@SDERegistry.register("bbed")
class BBED(SDE):
    """Brownian Bridge with Exploding Diffusion coefficient (BBED) SDE:
    dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw
    """

    @staticmethod
    def add_argparse_args(parser):
        """Register BBED's hyperparameters on an argparse parser."""
        parser.add_argument("--sde-n", type=int, default=30, help="The number of timesteps in the SDE discretization. 30 by default")
        parser.add_argument("--T_sampling", type=float, default=0.999, help="The T so that t < T during sampling in the train step.")
        parser.add_argument("--k", type=float, default = 2.6, help="base factor for diffusion term")
        parser.add_argument("--theta", type=float, default = 0.52, help="root scale factor for diffusion term.")
        return parser

    def __init__(self, T_sampling, k, theta, N=1000, **kwargs):
        """Construct an Brownian Bridge with Exploding Diffusion Coefficient SDE with parameterization as in the paper.
        dx = (y-x)/(Tc-t) dt + sqrt(theta)*k^t dw
        """
        super().__init__(N)
        self.k = k
        self.logk = np.log(self.k)
        self.theta = theta
        self.N = N
        # Ei(-2*log k): constant term of the closed-form variance in `_std`.
        self.Eilog = sc.expi(-2*self.logk)
        # NOTE: binding T/Tc as plain instance attributes only works because the
        # `def T` / `def Tc` below replace SDE's abstract `T` *property* (a data
        # descriptor, which would forbid assignment) with plain functions that
        # an instance attribute may shadow. Do not delete those defs.
        self.T = T_sampling #for sampling in train step and inference
        self.Tc = 1 #for constructing the SDE, dont change this

    def copy(self):
        """Return a fresh BBED with the same hyperparameters (T, k, theta, N)."""
        return BBED(self.T, self.k, self.theta, N=self.N)

    def T(self):
        # Overrides SDE's abstract `T` property so __init__ can assign self.T;
        # after __init__ the instance attribute shadows this method, so this
        # body is never executed on a constructed instance.
        return self.T

    def Tc(self):
        # Same property-neutralizing trick as `T` above, for the bridge end time.
        return self.Tc

    def sde(self, x, t, y):
        """Drift and diffusion of the forward bridge at state x, time t, given conditioning y."""
        # Bridge drift pulls x toward y as t approaches Tc.
        drift = (y - x)/(self.Tc - t)
        # "Exploding" diffusion coefficient: sqrt(theta) * k^t.
        sigma = (self.k) ** t
        diffusion = sigma * np.sqrt(self.theta)
        return drift, diffusion

    def _mean(self, x0, t, y):
        # Bridge mean: linear interpolation from x0 (t=0) to y (t=Tc).
        # Assumes 4-D tensors (batch, ch, H, W) for the broadcast — TODO confirm.
        time = (t/self.Tc)[:, None, None, None]
        mean = x0*(1-time) + y*time
        return mean

    def _std(self, t):
        # Closed-form marginal std of the bridge, expressed via the exponential
        # integral Ei (scipy.special.expi). Evaluated in numpy on CPU (scipy has
        # no torch interface), then moved back to t's device.
        t_np = t.cpu().detach().numpy()
        Eis = sc.expi(2*(t_np-1)*self.logk) - self.Eilog
        h = 2*self.k**2*self.logk
        var = (self.k**(2*t_np)-1+t_np) + h*(1-t_np)*Eis
        # NOTE(review): torch.tensor(var) yields float64 here, so the result may
        # be promoted relative to t's dtype — confirm this is intended.
        var = torch.tensor(var).to(device=t.device)*(1-t)*self.theta
        return torch.sqrt(var)

    def marginal_prob(self, x0, t, y):
        """Mean and std of the marginal p_t(x | x0, y)."""
        return self._mean(x0, t, y), self._std(t)

    def prior_sampling(self, shape, y):
        """Sample x_T ~ p_T(.|y). Returns (x_T, z) where z is the standard normal noise used."""
        if shape != y.shape:
            warnings.warn(f"Target shape {shape} does not match shape of y {y.shape}! Ignoring target shape.")
        # Prior is centered on y with the bridge's std at t = T.
        std = self._std(self.T*torch.ones((y.shape[0],), device=y.device))
        z = torch.randn_like(y)
        x_T = y + z * std[:, None, None, None]
        return x_T, z

    def prior_logp(self, z):
        raise NotImplementedError("prior_logp for BBED not yet implemented!")
|