File size: 1,611 Bytes
137645c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import torch
from torchdiffeq import odeint


# https://github.com/willisma/SiT/blob/main/transport/integrators.py#L77
class ODE:
    """Wrapper that integrates a model-defined ODE with ``torchdiffeq.odeint``.

    The integration grid is ``torch.linspace(t0, t1, num_steps)``. The solver
    method name and the tolerances are stored as given and forwarded to
    ``odeint`` when :meth:`sample` is called.
    """

    def __init__(
        self,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
    ):
        """Store the time grid and solver settings.

        Args:
            t0: Start of the integration interval; must be strictly below ``t1``.
            t1: End of the integration interval.
            sampler_type: Passed to ``odeint`` as ``method``.
            num_steps: Number of points in the ``torch.linspace`` time grid.
            atol: Absolute tolerance forwarded to ``odeint``.
            rtol: Relative tolerance forwarded to ``odeint``.

        Raises:
            AssertionError: If ``t0 >= t1``.
        """
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.t = torch.linspace(t0, t1, num_steps)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def time_linear_to_Timesteps(self, t, t_start, t_end, T_start, T_end):
        """Linearly map ``t`` from ``[t_start, t_end]`` onto ``[T_start, T_end]``.

        Computes ``T = k * t + b`` where the line passes through
        ``(t_start, T_start)`` and ``(t_end, T_end)``. ``t_start`` and
        ``t_end`` must differ, otherwise the slope is undefined.
        """
        # T = k * t + b
        k = (T_end - T_start) / (t_end - t_start)
        b = T_start - t_start * k
        return k * t + b

    def _drift(self, x, t, model, T_start, T_end, model_kwargs):
        """Evaluate the ODE right-hand side for one solver step.

        Args:
            x: Current state, a tensor or a tuple of tensors. The batch size
                is read from dimension 0 (of the first component for tuples).
            t: Current solver time, broadcast to one value per batch element.
            model: Callable invoked as ``model(x, timesteps, **model_kwargs)``.
            T_start: Model timestep corresponding to solver time 0.
            T_end: Model timestep corresponding to solver time 1.
            model_kwargs: Dict of extra keyword arguments for ``model``. Taken
                as a dict rather than ``**kwargs`` so a key such as ``"t"``
                cannot collide with this method's own parameters.

        Returns:
            The model output, unchanged.

        Raises:
            AssertionError: If the output shape(s) do not match the state.
        """
        is_tuple = isinstance(x, tuple)
        reference = x[0] if is_tuple else x
        # Allocate on the state's device directly rather than on CPU followed
        # by a transfer.
        batch_t = torch.ones(reference.size(0), device=reference.device) * t
        timesteps = self.time_linear_to_Timesteps(batch_t, 0, 1, T_start, T_end)
        model_output = model(x, timesteps, **model_kwargs)

        if is_tuple:
            # A tuple has no ``.shape``; compare component by component so
            # tuple states are validated instead of raising AttributeError.
            shapes_match = (
                isinstance(model_output, (tuple, list))
                and len(model_output) == len(x)
                and all(
                    out.shape == ref.shape
                    for out, ref in zip(model_output, x)
                )
            )
        else:
            shapes_match = model_output.shape == x.shape
        assert shapes_match, "Output shape from ODE solver must match input shape"
        return model_output

    def sample(self, x, model, T_start, T_end, **model_kwargs):
        """Integrate ``model`` over the stored time grid starting from ``x``.

        Args:
            x: Initial state, a tensor or a tuple of tensors.
            model: Callable invoked as ``model(state, timesteps, **model_kwargs)``
                whose output must match the state's shape(s).
            T_start: Model timestep corresponding to solver time 0.
            T_end: Model timestep corresponding to solver time 1.
            **model_kwargs: Extra keyword arguments forwarded to ``model``.

        Returns:
            Whatever ``odeint`` returns for the grid ``self.t``.
        """
        device = x[0].device if isinstance(x, tuple) else x.device

        def _fn(t, state):
            return self._drift(state, t, model, T_start, T_end, model_kwargs)

        # One tolerance entry per state component.
        num_components = len(x) if isinstance(x, tuple) else 1
        samples = odeint(
            _fn,
            x,
            self.t.to(device),
            method=self.sampler_type,
            atol=[self.atol] * num_components,
            rtol=[self.rtol] * num_components,
        )
        return samples