Spaces:
Running
Running
from cmath import isnan | |
import pytest | |
import torch | |
from mmcv import Config | |
from risk_biased.models.nn_blocks import ( | |
SequenceDecoderLSTM, | |
SequenceDecoderMLP, | |
SequenceEncoderMaskedLSTM, | |
SequenceEncoderMLP, | |
AttentionBlock, | |
) | |
@pytest.fixture
def params():
    """Provide a small configuration shared by the nn_blocks tests.

    Registered as a pytest fixture so tests can request it by argument name.
    It uses the default function scope, so ``torch.manual_seed(0)`` runs again
    before each test. That makes each test's random tensors independent of
    test order.

    Returns:
        An mmcv ``Config`` carrying batch size, input/output/latent/hidden
        dimensions, attention head count, hidden layer count and device name.
    """
    torch.manual_seed(0)
    cfg = Config()
    cfg.batch_size = 4
    cfg.input_dim = 10
    cfg.output_dim = 15
    cfg.latent_dim = 3
    cfg.h_dim = 32
    cfg.num_attention_heads = 4
    cfg.num_h_layers = 2
    cfg.device = "cpu"
    return cfg
def test_AttentionBlock(params):
    """AttentionBlock should yield one NaN-free h_dim vector per agent.

    Agent features, an agent mask, absolute agent features, map features and
    a map mask are drawn at random. The output must have shape
    ``(batch_size, num_agents, h_dim)``.
    """
    batch, hidden = params.batch_size, params.h_dim
    n_agents, n_objects = 4, 8

    # Build the module before sampling inputs: this keeps the RNG draw order
    # (weight init first, then inputs) the same.
    block = AttentionBlock(hidden, params.num_attention_heads)

    agents_feat = torch.rand(batch, n_agents, hidden)
    # Boolean masks that are True with probability ~0.9.
    agents_valid = torch.rand(batch, n_agents) > 0.1
    agents_abs_feat = torch.rand(batch, n_agents, hidden)
    map_feat = torch.rand(batch, n_objects, hidden)
    map_valid = torch.rand(batch, n_objects) > 0.1

    result = block(agents_feat, agents_valid, agents_abs_feat, map_feat, map_valid)

    expected_shape = (batch, n_agents, hidden)
    assert result.shape == expected_shape
    has_nan = torch.isnan(result).any()
    assert not has_nan
def test_SequenceDecoder(params):
    """Check SequenceDecoderLSTM output shape and absence of NaNs.

    The decoder receives one ``h_dim`` vector per agent plus a target
    sequence length. It is expected to return a tensor of shape
    ``(batch_size, num_agents, sequence_length, h_dim)``.
    """
    decoder = SequenceDecoderLSTM(params.h_dim)
    num_agents = 8
    sequence_length = 16
    encoded = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(encoded, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
def test_SequenceDecoderMLP(params):
    """Check SequenceDecoderMLP output shape and absence of NaNs.

    The decoder is built for a fixed ``sequence_length`` and receives one
    ``h_dim`` vector per agent. It is expected to return a tensor of shape
    ``(batch_size, num_agents, sequence_length, h_dim)``.
    """
    sequence_length = 16
    # NOTE(review): the meaning of the trailing positional ``True`` is defined
    # by SequenceDecoderMLP's signature in nn_blocks. Confirm it, and consider
    # passing it by keyword.
    decoder = SequenceDecoderMLP(
        params.h_dim, params.num_h_layers, sequence_length, True
    )
    num_agents = 8
    encoded = torch.rand(params.batch_size, num_agents, params.h_dim)
    output = decoder(encoded, sequence_length)
    assert output.shape == (
        params.batch_size,
        num_agents,
        sequence_length,
        params.h_dim,
    )
    assert not torch.isnan(output).any()
def test_SequenceEncoder(params):
    """Check SequenceEncoderMaskedLSTM output shape and absence of NaNs.

    The encoder receives a per-agent sequence of ``input_dim`` features and a
    per-step boolean mask. It is expected to return one ``h_dim`` vector per
    agent, i.e. shape ``(batch_size, num_agents, h_dim)``.
    """
    encoder = SequenceEncoderMaskedLSTM(params.input_dim, params.h_dim)
    num_agents = 8
    sequence_length = 16
    sequences = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    # True with probability ~0.9. Presumably True marks valid time steps;
    # confirm in nn_blocks.
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(sequences, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()
def test_SequenceEncoderMLP(params):
    """Check SequenceEncoderMLP output shape and absence of NaNs.

    The encoder is built for a fixed ``sequence_length``. It receives a
    per-agent sequence of ``input_dim`` features plus a per-step boolean mask,
    and is expected to return shape ``(batch_size, num_agents, h_dim)``.
    """
    sequence_length = 16
    num_agents = 8
    # NOTE(review): the meaning of the trailing positional ``True`` is defined
    # by SequenceEncoderMLP's signature in nn_blocks. Confirm it, and consider
    # passing it by keyword.
    encoder = SequenceEncoderMLP(
        params.input_dim, params.h_dim, params.num_h_layers, sequence_length, True
    )
    sequences = torch.rand(
        params.batch_size, num_agents, sequence_length, params.input_dim
    )
    # True with probability ~0.9. Presumably True marks valid time steps;
    # confirm in nn_blocks.
    mask_input = torch.rand(params.batch_size, num_agents, sequence_length) > 0.1
    output = encoder(sequences, mask_input)
    assert output.shape == (params.batch_size, num_agents, params.h_dim)
    assert not torch.isnan(output).any()