risk_biased_prediction / tests /risk_biased /models /test_interaction_encoder.py
jmercat's picture
Removed history to avoid any unverified information being released
5769ee4
import os
import pytest
import torch
import torch.nn as nn
from mmcv import Config
from risk_biased.models.cvae_encoders import (
CVAEEncoder,
BiasedEncoderNN,
FutureEncoderNN,
InferenceEncoderNN,
)
from risk_biased.models.latent_distributions import (
GaussianLatentDistribution,
QuantizedDistributionCreator,
)
from risk_biased.models.cvae_params import CVAEParams
@pytest.fixture(scope="module")
def params():
torch.manual_seed(0)
working_dir = os.path.dirname(os.path.realpath(__file__))
config_path = os.path.join(
working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
)
waymo_config_path = os.path.join(
working_dir, "..", "..", "..", "risk_biased", "config", "waymo_config.py"
)
paths = [config_path, waymo_config_path]
if isinstance(paths, str):
cfg = Config.fromfile(paths)
else:
cfg = Config.fromfile(paths[0])
for path in paths[1:]:
c = Config.fromfile(path)
cfg.update(c)
cfg.batch_size = 4
cfg.state_dim = 5
cfg.dynamic_state_dim = 5
cfg.map_state_dim = 2
cfg.num_steps = 3
cfg.num_steps_future = 4
cfg.latent_dim = 2
cfg.hidden_dim = 64
cfg.device = "cpu"
cfg.sequence_encoder_type = "LSTM"
cfg.sequence_decoder_type = "MLP"
return cfg
@pytest.mark.parametrize(
"num_agents, num_map_objects, type, interaction_nn_class",
[
(4, 5, "MLP", BiasedEncoderNN),
(2, 4, "LSTM", BiasedEncoderNN),
(3, 2, "maskedLSTM", BiasedEncoderNN),
(4, 5, "MLP", FutureEncoderNN),
(2, 4, "LSTM", FutureEncoderNN),
(3, 2, "maskedLSTM", FutureEncoderNN),
(4, 5, "MLP", InferenceEncoderNN),
(2, 4, "LSTM", InferenceEncoderNN),
(3, 2, "maskedLSTM", InferenceEncoderNN),
],
)
def test_attention_encoder_nn(
params,
num_agents: int,
num_map_objects: int,
type: str,
interaction_nn_class: nn.Module,
):
params.sequence_encoder_type = type
cvae_params = CVAEParams.from_config(params)
if interaction_nn_class == BiasedEncoderNN:
model = interaction_nn_class(
cvae_params,
num_steps=cvae_params.num_steps,
latent_dim=2 * cvae_params.latent_dim,
)
elif interaction_nn_class == FutureEncoderNN:
model = interaction_nn_class(
cvae_params,
num_steps=cvae_params.num_steps + cvae_params.num_steps_future,
latent_dim=2 * cvae_params.latent_dim,
)
else:
model = interaction_nn_class(
cvae_params,
num_steps=cvae_params.num_steps,
latent_dim=2 * cvae_params.latent_dim,
)
assert model.latent_dim == 2 * params.latent_dim
assert model.hidden_dim == params.hidden_dim
x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
offset = x[:, :, -1, :]
x = x - offset.unsqueeze(-2)
mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1
if interaction_nn_class == FutureEncoderNN:
y = torch.rand(
params.batch_size, num_agents, params.num_steps_future, params.state_dim
)
y = y - offset.unsqueeze(-2)
y_ego = y[:, 0:1]
mask_y = torch.rand(params.batch_size, num_agents, params.num_steps_future)
else:
y = None
y_ego = None
mask_y = None
x_ego = x[:, 0:1]
if interaction_nn_class == BiasedEncoderNN:
risk_level = torch.rand(params.batch_size, num_agents)
else:
risk_level = None
output = model(
x,
mask_x,
encoded_absolute,
encoded_map,
mask_map,
y=y,
mask_y=mask_y,
x_ego=x_ego,
y_ego=y_ego,
offset=offset,
risk_level=risk_level,
)
# check shape
assert output.shape == (params.batch_size, num_agents, 2 * params.latent_dim)
@pytest.mark.parametrize(
"num_agents, num_map_objects, type, interaction_nn_class, latent_distribution_class",
[
(2, 8, "MLP", BiasedEncoderNN, GaussianLatentDistribution),
(7, 5, "LSTM", BiasedEncoderNN, GaussianLatentDistribution),
(2, 10, "maskedLSTM", BiasedEncoderNN, QuantizedDistributionCreator),
(2, 8, "MLP", FutureEncoderNN, GaussianLatentDistribution),
(7, 5, "LSTM", FutureEncoderNN, QuantizedDistributionCreator),
(2, 10, "maskedLSTM", FutureEncoderNN, GaussianLatentDistribution),
(2, 8, "MLP", InferenceEncoderNN, QuantizedDistributionCreator),
(7, 5, "LSTM", InferenceEncoderNN, GaussianLatentDistribution),
(2, 10, "maskedLSTM", InferenceEncoderNN, GaussianLatentDistribution),
],
)
# TODO: Add test for QuantizedDistributionCreator
def test_attention_cvae_encoder(
params,
num_agents: int,
num_map_objects: int,
type: str,
interaction_nn_class,
latent_distribution_class,
):
params.sequence_encoder_type = type
if interaction_nn_class == FutureEncoderNN:
risk_level = None
y = torch.rand(
params.batch_size, num_agents, params.num_steps_future, params.state_dim
)
mask_y = torch.rand(params.batch_size, num_agents, params.num_steps_future)
else:
risk_level = torch.rand(params.batch_size, num_agents)
y = None
mask_y = None
if interaction_nn_class == BiasedEncoderNN:
model = interaction_nn_class(
CVAEParams.from_config(params),
num_steps=params.num_steps,
latent_dim=2 * params.latent_dim,
)
elif interaction_nn_class == FutureEncoderNN:
model = interaction_nn_class(
CVAEParams.from_config(params),
num_steps=params.num_steps + params.num_steps_future,
latent_dim=2 * params.latent_dim,
)
else:
model = interaction_nn_class(
CVAEParams.from_config(params),
num_steps=params.num_steps,
latent_dim=2 * params.latent_dim,
)
encoder = CVAEEncoder(model, GaussianLatentDistribution)
# check latent_dim
assert encoder.latent_dim == 2 * params.latent_dim
x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
offset = x[:, :, -1, :]
x = x - offset.unsqueeze(-2)
if y is not None:
y = y - offset.unsqueeze(-2)
x_ego = x[:, 0:1]
y_ego = y[:, 0:1]
else:
x_ego = x[:, 0:1]
y_ego = None
mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1
latent_distribution = encoder(
x=x,
mask_x=mask_x,
encoded_absolute=encoded_absolute,
encoded_map=encoded_map,
mask_map=mask_map,
y=y,
mask_y=mask_y,
x_ego=x_ego,
y_ego=y_ego,
offset=offset,
risk_level=risk_level,
)
latent_mean = latent_distribution.mu
latent_log_std = latent_distribution.logvar
# check shape
assert (
latent_mean.shape
== latent_log_std.shape
== (params.batch_size, num_agents, params.latent_dim)
)
latent_sample_1, weights = latent_distribution.sample()
# check shape when n_samples = 0
assert latent_sample_1.shape == latent_mean.shape
assert latent_sample_1.shape[:-1] == weights.shape
latent_sample_2, weights = latent_distribution.sample(n_samples=2)
# check shape when n_samples = 2
assert latent_sample_2.shape == (
params.batch_size,
num_agents,
2,
params.latent_dim,
)
latent_sample_3, weights = latent_distribution.sample()
# make sure sampling is non-deterministic
assert not torch.allclose(latent_sample_1, latent_sample_3)