Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,611 Bytes
137645c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 |
import torch
from torchdiffeq import odeint
# https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77
class ODE:
    """ODE solver class wrapping ``torchdiffeq.odeint`` on a fixed time grid.

    Integration runs in a normalized time ``t`` over the grid
    ``linspace(t0, t1, num_steps)``. Before every model call, ``t`` is mapped
    linearly from ``[0, 1]`` onto the model's own timestep range
    ``[T_start, T_end]`` (see ``time_linear_to_Timesteps``).
    """

    def __init__(
        self,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        """Store the integration grid and the solver settings.

        Args:
            t0: Start of the normalized integration interval; must be < ``t1``.
            t1: End of the normalized integration interval.
            sampler_type: Passed to ``odeint`` as ``method``.
            num_steps: Number of grid points, including both ``t0`` and ``t1``.
            atol: Absolute tolerance passed to ``odeint``.
            rtol: Relative tolerance passed to ``odeint``.

        Raises:
            AssertionError: If ``t0 >= t1`` (only forward-time integration).
        """
        assert t0 < t1, "ODE sampler has to be in forward time"
        # Times at which odeint is asked to report the solution.
        self.t = torch.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end):
        """Map ``t`` linearly so that ``t_start -> T_start`` and ``t_end -> T_end``.

        Computes ``T = k * t + b``. ``t`` may be a Python number or a tensor
        (the arithmetic broadcasts). ``t_start`` must differ from ``t_end``.
        """
        k = (T_end - T_start) / (t_end - t_start)
        b = T_start - t_start * k
        return k * t + b

    def sample(self, x, model, T_start, T_end, **model_kwargs):
        """Integrate ``model`` as the ODE right-hand side, starting from ``x``.

        Args:
            x: Initial state. Either a tensor or a tuple of tensors. The batch
                size is read from dim 0 (of the first element for tuples).
            model: Callable ``model(state, timesteps, **model_kwargs)``. It
                must return output with the same shape(s) as ``state``.
            T_start: Model timestep corresponding to normalized time 0.
            T_end: Model timestep corresponding to normalized time 1.
            **model_kwargs: Extra keyword arguments forwarded to every model
                call.

        Returns:
            The result of ``odeint`` evaluated on the grid ``self.t``.
        """
        is_tuple = isinstance(x, tuple)
        device = x[0].device if is_tuple else x.device

        def _fn(t, state):
            state_is_tuple = isinstance(state, tuple)
            batch_size = state[0].size(0) if state_is_tuple else state.size(0)
            # odeint supplies a single time value; give one copy per sample.
            # Allocate directly on the target device, not on host + copy.
            timesteps = torch.ones(batch_size, device=device) * t
            # The mapping assumes the normalized time axis is [0, 1].
            model_output = model(
                state,
                self.time_linear_to_Timesteps(timesteps, 0, 1, T_start, T_end),
                **model_kwargs,
            )
            if state_is_tuple:
                # Tuples have no `.shape`; compare element by element.
                assert isinstance(model_output, tuple) and len(model_output) == len(state) and all(
                    out.shape == inp.shape for out, inp in zip(model_output, state)
                ), "Output shape from ODE solver must match input shape"
            else:
                assert model_output.shape == state.shape, "Output shape from ODE solver must match input shape"
            return model_output

        grid = self.t.to(device)
        # One tolerance entry per state component; tuple states need one per element.
        n_components = len(x) if is_tuple else 1
        atol = [self.atol] * n_components
        rtol = [self.rtol] * n_components
        samples = odeint(
            _fn,
            x,
            grid,
            method=self.sampler_type,
            atol=atol,
            rtol=rtol,
        )
        return samples
|