Spaces:
Sleeping
Sleeping
File size: 1,429 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from functools import partial
import pytest
import torch
from chroma.layers.sde import sde_integrate, sde_integrate_heun
@pytest.fixture
def y0():
    """Initial state: a batch of independent 1D trajectories, all starting at 0.

    The tests compute per-trajectory statistics and then average them over
    this batch, so a large batch keeps those averages near their analytic
    values.
    """
    n_trajectories = 10_000
    return torch.zeros(n_trajectories)
@pytest.fixture
def tspan():
    """Integration interval ``(t_start, t_end)``; t_end < t_start, so it runs backward."""
    t_start, t_end = 0.5, 0.3
    return (t_start, t_end)
@pytest.fixture
def N():
    """Discretization count forwarded to both SDE integrators (presumably the step count)."""
    num_steps = 200
    return num_steps
@pytest.fixture
def exp_mean(y0, tspan):
    """Expected value of the time-averaged trajectory, averaged over the batch.

    ``sde_sample_func`` has unit drift, so (assuming the integrators add
    ``f * dt`` plus zero-mean noise each step) the deterministic part of each
    path is a linear ramp from ``y0`` to ``y0 + (t_end - t_start)``. Its time
    average is the midpoint ``y0 + (t_end - t_start) / 2``.

    Returns:
        A 0-dim tensor with the same dtype as ``y0``.
    """
    delta_t = tspan[1] - tspan[0]
    # ``y0 + scalar`` is already a tensor. The legacy ``torch.Tensor(tensor)``
    # wrapper used before was redundant, and it rejects inputs whose dtype
    # differs from the default float type.
    return (y0 + delta_t / 2).mean()
@pytest.fixture
def exp_var(tspan):
    """Expected per-trajectory variance over time, as a 1-element tensor.

    The result is the sum of two terms:
      * ``dt**2 / 12``: the time-variance of the linear ramp produced by the
        unit drift over an interval of length ``|dt|``;
      * ``|dt| / 6``: the expected time-variance of the diffusion path
        (assuming unit-variance Brownian increments).
    """
    span = tspan[1] - tspan[0]
    drift_part = span ** 2 / 12
    diffusion_part = abs(span) / 6
    return torch.Tensor([drift_part + diffusion_part])
def sde_sample_func(t, y):
    """Toy SDE with unit drift and standard-normal noise.

    Args:
        t: Current time (unused; the SDE is time-homogeneous).
        y: Current state tensor.

    Returns:
        ``(f, gZ)``: the drift (all ones) and a standard-normal sample
        (presumably the diffusion term ``g * Z`` with ``g = 1``), both with
        the same shape, dtype and device as ``y``.
    """
    f = torch.ones_like(y)
    # randn_like rather than randn(y.shape): the noise follows y's dtype and
    # device, matching the drift above. For float32 CPU input this is
    # identical to the previous call.
    gZ = torch.randn_like(y)
    return f, gZ
def test_sde_integrate(y0, tspan, N, exp_mean, exp_var):
    """``sde_integrate`` reproduces the analytic mean and variance of the toy SDE.

    The trajectory is stacked along the last dim. The mean and variance are
    taken over that dim for each trajectory, then averaged over the batch
    in ``y0``.
    """
    # Seed the RNG so the check is deterministic. Assuming unit-variance
    # Brownian noise, the mean estimate's standard error is only about half
    # the rtol=5e-2 tolerance (~2 sigma), so unseeded runs fail ~5% of the time.
    torch.manual_seed(0)
    y_trajectory = torch.stack(sde_integrate(sde_sample_func, y0, tspan, N), dim=-1)
    time_mean = torch.mean(y_trajectory, dim=-1).mean()
    time_var = torch.var(y_trajectory, dim=-1).mean()
    assert torch.allclose(time_mean, exp_mean, rtol=5e-2)
    assert torch.allclose(time_var, exp_var, rtol=5e-2)
def test_sde_integrate_heun(y0, tspan, N, exp_mean, exp_var):
    """``sde_integrate_heun`` reproduces the analytic mean and variance of the toy SDE.

    The trajectory is stacked along the last dim. The mean and variance are
    taken over that dim for each trajectory, then averaged over the batch
    in ``y0``.
    """
    # Seed the RNG so the check is deterministic. Assuming unit-variance
    # Brownian noise, the mean estimate's standard error is only about half
    # the rtol=5e-2 tolerance (~2 sigma), so unseeded runs fail ~5% of the time.
    torch.manual_seed(0)
    y_trajectory = torch.stack(
        sde_integrate_heun(sde_sample_func, y0, tspan, N), dim=-1
    )
    time_mean = torch.mean(y_trajectory, dim=-1).mean()
    time_var = torch.var(y_trajectory, dim=-1).mean()
    assert torch.allclose(time_mean, exp_mean, rtol=5e-2)
    assert torch.allclose(time_var, exp_var, rtol=5e-2)
|