jmercat's picture
Removed history to avoid any unverified information being released
5769ee4
raw
history blame contribute delete
740 Bytes
import pytest
import torch
from mmcv import Config
from risk_biased.models.mlp import MLP
@pytest.fixture(scope="module")
def params():
torch.manual_seed(0)
cfg = Config()
cfg.batch_size = 4
cfg.input_dim = 10
cfg.output_dim = 15
cfg.latent_dim = 3
cfg.h_dim = 64
cfg.num_h_layers = 2
cfg.device = "cpu"
cfg.is_mlp_residual = True
return cfg
def test_mlp(params):
mlp = MLP(
params.input_dim,
params.output_dim,
params.h_dim,
params.num_h_layers,
params.is_mlp_residual,
)
input = torch.rand(params.batch_size, params.input_dim)
output = mlp(input)
# check shape
assert output.shape == (params.batch_size, params.output_dim)