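"""Unit tests for the sequence encoder/decoder and attention blocks in
risk_biased.models.nn_blocks.

These tests only check output shapes and the absence of NaN values on random
inputs; they do not verify the numerical behaviour of the blocks. They are
meant to be run with the pytest CLI (assuming pytest and the project
dependencies are installed).
"""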
import pytest
import torch
from mmcv import Config
from risk_biased.models.nn_blocks import (
    SequenceDecoderLSTM,
    SequenceDecoderMLP,
    SequenceEncoderMaskedLSTM,
    SequenceEncoderMLP,
    AttentionBlock,
)
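

# Shared fixture: a minimal mmcv Config holding the dimensions and settings
# used by all the tests below.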
@pytest.fixture(scope="module")
def params():
    torch.manual_seed(0)
    cfg = Config()
    cfg.batch_size = 4
    cfg.input_dim = 10
    cfg.output_dim = 15
    cfg.latent_dim = 3
    cfg.h_dim = 32
    cfg.num_attention_heads = 4
    cfg.num_h_layers = 2
    cfg.device = "cpu"
    return cfg
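

# AttentionBlock: given encoded agents, their absolute-position encodings, and
# encoded map objects (each with a validity mask), the block should return a
# (batch_size, num_agents, h_dim) tensor that contains no NaNs.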
def test_AttentionBlock(params):
    attention = AttentionBlock(params.h_dim, params.num_attention_heads)
    num_agents = 4
    num_map_objects = 8
    encoded_agents = torch.rand(params.batch_size, num_agents, params.h_dim)
    mask_agents = torch.rand(params.batch_size, num_agents) > 0.1
    encoded_absolute_agents = torch.rand(params.batch_size, num_agents, params.h_dim)
    encoded_map = torch.rand(params.batch_size, num_map_objects, params.h_dim)
    mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1
    output = attention(
        encoded_agents, mask_agents, encoded_absolute_agents, encoded_map, mask_map
    )
    # check shape
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()
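

# SequenceDecoderLSTM: from one h_dim feature per agent, decode a sequence of
# the requested length, giving a (batch_size, num_agents, sequence_length, h_dim)
# tensor with no NaNs.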
def test_SequenceDecoder(params):
    decoder = SequenceDecoderLSTM(params.h_dim)
    num_agents = 8
    sequence_length = 16
    input = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(input, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
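

# SequenceDecoderMLP: same call interface as the LSTM decoder, but the MLP
# variant is constructed with the number of hidden layers, the target sequence
# length, and a boolean flag.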
def test_SequenceDecoderMLP(params):
    sequence_length = 16
    decoder = SequenceDecoderMLP(
        params.h_dim, params.num_h_layers, sequence_length, True
    )
    num_agents = 8
    input = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(input, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
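

# SequenceEncoderMaskedLSTM: encode a per-agent input sequence, with invalid
# time steps flagged by mask_input, into a single h_dim feature vector per agent.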
def test_SequenceEncoder(params):
    encoder = SequenceEncoderMaskedLSTM(params.input_dim, params.h_dim)
    num_agents = 8
    sequence_length = 16
    input = torch.rand(params.batch_size, num_agents, sequence_length, params.input_dim)
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(input, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()
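

# SequenceEncoderMLP: MLP counterpart of the masked LSTM encoder, also producing
# one h_dim feature per agent from a masked input sequence.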
def test_SequenceEncoderMLP(params):
    sequence_length = 16
    num_agents = 8
    encoder = SequenceEncoderMLP(
        params.input_dim, params.h_dim, params.num_h_layers, sequence_length, True
    )
    input = torch.rand(params.batch_size, num_agents, sequence_length, params.input_dim)
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(input, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()
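

# A minimal sketch for running this file directly; the usual entry point is the
# pytest CLI (e.g. `pytest <path to this test file>`, the exact path depends on
# the repository layout).
if __name__ == "__main__":
    import sys

    sys.exit(pytest.main([__file__, "-v"]))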