Hukuna's picture
Upload 275 files
e9e75df verified
raw
history blame
1.43 kB
from functools import partial
import pytest
import torch
from chroma.layers.sde import sde_integrate, sde_integrate_heun
@pytest.fixture
def y0():
    """Initial state: a batch of independent scalar trajectories, all at zero.

    Many 1D paths are integrated at once so the tests can compare sample
    statistics (mean and variance) against their analytic expectations.
    """
    num_trajectories = 10000
    return torch.zeros(num_trajectories)
@pytest.fixture
def tspan():
    """Integration interval, presumably ``(t_start, t_end)``.

    ``t_end < t_start`` here, so the signed span ``tspan[1] - tspan[0]`` used by
    the expectation fixtures is negative (NOTE(review): appears to exercise
    reverse-time integration -- confirm against ``sde_integrate``).
    """
    t_start, t_end = 0.5, 0.3
    return (t_start, t_end)
@pytest.fixture
def N():
    """Discretization count passed as ``N`` to both integrators (presumably the number of steps)."""
    num_steps = 200
    return num_steps
@pytest.fixture
def exp_mean(y0, tspan):
    """Expected state mean, averaged over time and then over trajectories.

    ``sde_sample_func`` uses a unit drift, so apart from zero-mean noise each
    path moves linearly by ``delta_t = tspan[1] - tspan[0]``. The time average
    of that ramp is ``y0 + delta_t / 2``.

    Returns:
        0-d tensor with the expected mean.
    """
    delta_t = tspan[1] - tspan[0]
    # Plain tensor arithmetic instead of the legacy ``torch.Tensor(...)``
    # constructor, which is redundant here and only accepts float32 tensors.
    return (y0 + delta_t / 2).mean()
@pytest.fixture
def exp_var(tspan):
    """Expected within-path variance over time, averaged over trajectories.

    The drift and diffusion contributions add independently:

    * drift: a unit-slope ramp spanning ``delta_t`` has variance
      ``delta_t ** 2 / 12``;
    * diffusion: a unit Brownian path on an interval of length
      ``T = |delta_t|`` has expected within-path variance
      ``T/2 - T/3 = T/6``. This assumes the integrators scale ``gZ`` by
      ``sqrt(|dt|)`` -- confirm in ``chroma.layers.sde``.

    Returns:
        0-d tensor with the expected variance. It is compared to a 0-d mean
        via ``torch.allclose``.
    """
    delta_t = tspan[1] - tspan[0]
    drift_var = delta_t ** 2 / 12
    diffusion_var = abs(delta_t) / 6
    # ``torch.tensor`` is the recommended factory; ``torch.Tensor([...])`` is the legacy constructor.
    return torch.tensor(drift_var + diffusion_var)
def sde_sample_func(t, y):
    """Toy SDE with constant unit drift and unit diffusion.

    Args:
        t: Current time. It is unused because the coefficients are constant.
        y: Current state tensor.

    Returns:
        Tuple ``(f, gZ)``:

        * ``f`` is the drift, all ones;
        * ``gZ`` is a standard-normal sample, i.e. diffusion ``g = 1`` times
          noise ``Z``.

        Both match ``y``'s shape, dtype and device.
    """
    f = torch.ones_like(y)
    # ``randn_like`` (rather than ``randn(y.shape)``) keeps the noise on y's
    # dtype/device, consistent with ``f``; float64 or GPU states then work too.
    gZ = torch.randn_like(y)
    return f, gZ
def _assert_moments(integrator, y0, tspan, N, exp_mean, exp_var, seed=0):
    """Run ``integrator`` on the toy SDE and compare path statistics to theory.

    The integrator's states are stacked along the last dim, so dim -1 indexes
    time and the leading dim indexes independent trajectories. Mean and
    variance are taken over time per path, then averaged over paths.

    The global torch RNG is seeded first because ``sde_sample_func`` draws its
    noise from it. With unit Brownian noise, each path's time average has std
    of about ``sqrt(|delta_t| / 3) ~= 0.26``. Over 10000 paths the standard
    error is therefore ~0.0026, against a 0.005 tolerance on the mean. An
    unseeded run would fail intermittently.
    """
    torch.manual_seed(seed)
    y_trajectory = torch.stack(integrator(sde_sample_func, y0, tspan, N), dim=-1)
    path_means = torch.mean(y_trajectory, dim=-1)
    path_vars = torch.var(y_trajectory, dim=-1)
    assert torch.allclose(path_means.mean(), exp_mean, rtol=5e-2)
    assert torch.allclose(path_vars.mean(), exp_var, rtol=5e-2)


def test_sde_integrate(y0, tspan, N, exp_mean, exp_var):
    """``sde_integrate`` reproduces the analytic mean and variance of the toy SDE."""
    _assert_moments(sde_integrate, y0, tspan, N, exp_mean, exp_var)


def test_sde_integrate_heun(y0, tspan, N, exp_mean, exp_var):
    """``sde_integrate_heun`` reproduces the analytic mean and variance of the toy SDE."""
    _assert_moments(sde_integrate_heun, y0, tspan, N, exp_mean, exp_var)