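"""
Unit tests for the energy-based model templates in ding.model.template.ebm:
the EBM and AutoregressiveEBM networks and the DFO, AutoRegressiveDFO and MCMC
stochastic optimizers that sample and infer actions from them.
"""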
import pytest
import torch
import numpy as np

from ding.model.template.ebm import EBM, AutoregressiveEBM
from ding.model.template.ebm import DFO, AutoRegressiveDFO, MCMC

# batch, negative_samples, obs_shape, action_shape
B, N, O, A = 32, 1024, 11, 3

class TestEBM:

    def test_forward(self):
        obs = torch.randn(B, N, O)
        action = torch.randn(B, N, A)
        ebm = EBM(O, A)
        energy = ebm(obs, action)
        # one scalar energy per (batch, negative sample) pair
        assert energy.shape == (B, N)
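
# The stochastic optimizers below share a common interface: ``set_action_bounds``
# fixes the per-dimension action range, ``sample`` tiles each observation across N
# candidate actions, and ``infer`` reduces them to a single action per observation.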
class TestDFO:
    # optimizer and model shared across tests; actions bounded to [0, 1]^A
    opt = DFO(train_samples=N, inference_samples=N)
    opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
    ebm = EBM(O, A)

    def test_sample(self):
        obs = torch.randn(B, O)
        tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
        assert tiled_obs.shape == (B, N, O)
        assert action_samples.shape == (B, N, A)

    def test_infer(self):
        obs = torch.randn(B, O)
        action = self.opt.infer(obs, self.ebm)
        assert action.shape == (B, A)

class TestAutoregressiveEBM:

    def test_forward(self):
        obs = torch.randn(B, N, O)
        action = torch.randn(B, N, A)
        arebm = AutoregressiveEBM(O, A)
        energy = arebm(obs, action)
        # the autoregressive model emits one energy per action dimension
        assert energy.shape == (B, N, A)

class TestAutoregressiveDFO:
    opt = AutoRegressiveDFO(train_samples=N, inference_samples=N)
    opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
    ebm = AutoregressiveEBM(O, A)

    def test_sample(self):
        obs = torch.randn(B, O)
        tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
        assert tiled_obs.shape == (B, N, O)
        assert action_samples.shape == (B, N, A)

    def test_infer(self):
        obs = torch.randn(B, O)
        action = self.opt.infer(obs, self.ebm)
        assert action.shape == (B, A)
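
# MCMC is the gradient-based optimizer: besides the sample/infer interface it exposes
# helpers for the energy gradient w.r.t. actions, Langevin-style refinement steps,
# and an optional gradient penalty, all exercised directly below.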
class TestMCMC:
    opt = MCMC(iters=3, train_samples=N, inference_samples=N)
    opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
    obs = torch.randn(B, N, O)
    action = torch.randn(B, N, A)
    ebm = EBM(O, A)

    def test_gradient_wrt_act(self):
        ebm = EBM(O, A)
        # inference mode: plain gradient of the energy w.r.t. the action samples
        de_dact = MCMC._gradient_wrt_act(self.obs, self.action, ebm)
        assert de_dact.shape == (B, N, A)
        # train mode: create_graph keeps the gradient differentiable, so a loss
        # built from it can backpropagate into the EBM parameters
        de_dact = MCMC._gradient_wrt_act(self.obs, self.action, ebm, create_graph=True)
        loss = de_dact.pow(2).sum()
        loss.backward()
        assert de_dact.shape == (B, N, A)
        assert ebm.net[0].weight.grad is not None

    def test_langevin_step(self):
        stepsize = 1
        action = self.opt._langevin_step(self.obs, self.action, stepsize, self.ebm)
        assert action.shape == (B, N, A)
        # TODO: new action should have lower energy

    def test_langevin_action_given_obs(self):
        action = self.opt._langevin_action_given_obs(self.obs, self.action, self.ebm)
        assert action.shape == (B, N, A)

    def test_grad_penalty(self):
        ebm = EBM(O, A)
        self.opt.add_grad_penalty = True
        loss = self.opt.grad_penalty(self.obs, self.action, ebm)
        loss.backward()
        assert ebm.net[0].weight.grad is not None

    def test_sample(self):
        obs = torch.randn(B, O)
        tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
        assert tiled_obs.shape == (B, N, O)
        assert action_samples.shape == (B, N, A)

    def test_infer(self):
        obs = torch.randn(B, O)
        action = self.opt.infer(obs, self.ebm)
        assert action.shape == (B, A)