import os

import pytest
import torch
import torch.nn as nn
from mmcv import Config

from risk_biased.models.cvae_encoders import (
    CVAEEncoder,
    BiasedEncoderNN,
    FutureEncoderNN,
    InferenceEncoderNN,
)
from risk_biased.models.latent_distributions import (
    GaussianLatentDistribution,
    QuantizedDistributionCreator,
)
from risk_biased.models.cvae_params import CVAEParams
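

# Shared configuration fixture: loads the learning and Waymo configs, then
# overrides the dimensions with small values so the tests run quickly on CPU.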
@pytest.fixture
def params():
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    waymo_config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "waymo_config.py"
    )
    # Load the base learning config, then overlay the Waymo-specific settings.
    paths = [config_path, waymo_config_path]
    cfg = Config.fromfile(paths[0])
    for path in paths[1:]:
        cfg.update(Config.fromfile(path))
    cfg.batch_size = 4
    cfg.state_dim = 5
    cfg.dynamic_state_dim = 5
    cfg.map_state_dim = 2
    cfg.num_steps = 3
    cfg.num_steps_future = 4
    cfg.latent_dim = 2
    cfg.hidden_dim = 64
    cfg.device = "cpu"
    cfg.sequence_encoder_type = "LSTM"
    cfg.sequence_decoder_type = "MLP"
    return cfg
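

# NOTE: The parametrize decorators below were absent from this copy of the
# file. The parameter names match the test signatures, but the listed values
# are illustrative stand-ins, not necessarily the original ones.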
@pytest.mark.parametrize("num_agents, num_map_objects", [(4, 5)])
@pytest.mark.parametrize("type", ["MLP", "LSTM"])
@pytest.mark.parametrize(
    "interaction_nn_class",
    [BiasedEncoderNN, FutureEncoderNN, InferenceEncoderNN],
)
def test_attention_encoder_nn(
    params,
    num_agents: int,
    num_map_objects: int,
    type: str,
    interaction_nn_class: nn.Module,
):
    params.sequence_encoder_type = type
    cvae_params = CVAEParams.from_config(params)
    # FutureEncoderNN encodes both past and future steps; the other encoders
    # only see the past.
    num_steps = cvae_params.num_steps
    if interaction_nn_class == FutureEncoderNN:
        num_steps += cvae_params.num_steps_future
    model = interaction_nn_class(
        cvae_params,
        num_steps=num_steps,
        latent_dim=2 * cvae_params.latent_dim,
    )
    assert model.latent_dim == 2 * params.latent_dim
    assert model.hidden_dim == params.hidden_dim
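    # Random observed trajectories, shifted so that positions are relative to
    # each agent's last observed state.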
    x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
    offset = x[:, :, -1, :]
    x = x - offset.unsqueeze(-2)
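    # Random boolean masks drop roughly 10% of time steps and map objects.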
    mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
    encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
    encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
    mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1
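    # FutureEncoderNN additionally consumes the future trajectory y, shifted
    # into the same relative frame as x.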
    if interaction_nn_class == FutureEncoderNN:
        y = torch.rand(
            params.batch_size, num_agents, params.num_steps_future, params.state_dim
        )
        y = y - offset.unsqueeze(-2)
        y_ego = y[:, 0:1]
        mask_y = (
            torch.rand(params.batch_size, num_agents, params.num_steps_future) > 0.1
        )
    else:
        y = None
        y_ego = None
        mask_y = None
    x_ego = x[:, 0:1]
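    # BiasedEncoderNN additionally conditions on a per-agent risk level.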
    if interaction_nn_class == BiasedEncoderNN:
        risk_level = torch.rand(params.batch_size, num_agents)
    else:
        risk_level = None
    output = model(
        x,
        mask_x,
        encoded_absolute,
        encoded_map,
        mask_map,
        y=y,
        mask_y=mask_y,
        x_ego=x_ego,
        y_ego=y_ego,
        offset=offset,
        risk_level=risk_level,
    )
    # check shape
    assert output.shape == (params.batch_size, num_agents, 2 * params.latent_dim)
    # TODO: Add test for QuantizedDistributionCreator
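

# The CVAE encoder test wraps each encoder network in a CVAEEncoder and
# checks the resulting latent distribution parameters and sampling interface.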
@pytest.mark.parametrize("num_agents, num_map_objects", [(4, 5)])
@pytest.mark.parametrize("type", ["MLP", "LSTM"])
@pytest.mark.parametrize(
    "interaction_nn_class, latent_distribution_class",
    [
        (BiasedEncoderNN, GaussianLatentDistribution),
        (FutureEncoderNN, GaussianLatentDistribution),
        (InferenceEncoderNN, GaussianLatentDistribution),
    ],
)
def test_attention_cvae_encoder(
    params,
    num_agents: int,
    num_map_objects: int,
    type: str,
    interaction_nn_class,
    latent_distribution_class,
):
    params.sequence_encoder_type = type
    if interaction_nn_class == FutureEncoderNN:
        risk_level = None
        y = torch.rand(
            params.batch_size, num_agents, params.num_steps_future, params.state_dim
        )
        mask_y = (
            torch.rand(params.batch_size, num_agents, params.num_steps_future) > 0.1
        )
    else:
        risk_level = torch.rand(params.batch_size, num_agents)
        y = None
        mask_y = None
    # As above, only the future encoder sees the future time steps.
    num_steps = params.num_steps
    if interaction_nn_class == FutureEncoderNN:
        num_steps += params.num_steps_future
    model = interaction_nn_class(
        CVAEParams.from_config(params),
        num_steps=num_steps,
        latent_dim=2 * params.latent_dim,
    )
    encoder = CVAEEncoder(model, latent_distribution_class)
    # check latent_dim
    assert encoder.latent_dim == 2 * params.latent_dim
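    # Inputs mirror the encoder NN test: trajectories expressed relative to
    # each agent's last observed state, plus random boolean masks.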
    x = torch.rand(params.batch_size, num_agents, params.num_steps, params.state_dim)
    offset = x[:, :, -1, :]
    x = x - offset.unsqueeze(-2)
    x_ego = x[:, 0:1]
    if y is not None:
        y = y - offset.unsqueeze(-2)
        y_ego = y[:, 0:1]
    else:
        y_ego = None
    mask_x = torch.rand(params.batch_size, num_agents, params.num_steps) > 0.1
    encoded_absolute = torch.rand(params.batch_size, num_agents, params.hidden_dim)
    encoded_map = torch.rand(params.batch_size, num_map_objects, params.hidden_dim)
    mask_map = torch.rand(params.batch_size, num_map_objects) > 0.1
    latent_distribution = encoder(
        x=x,
        mask_x=mask_x,
        encoded_absolute=encoded_absolute,
        encoded_map=encoded_map,
        mask_map=mask_map,
        y=y,
        mask_y=mask_y,
        x_ego=x_ego,
        y_ego=y_ego,
        offset=offset,
        risk_level=risk_level,
    )
    latent_mean = latent_distribution.mu
    latent_logvar = latent_distribution.logvar
    # check shape
    assert (
        latent_mean.shape
        == latent_logvar.shape
        == (params.batch_size, num_agents, params.latent_dim)
    )
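    # Sampling: with the default n_samples the sample matches the shape of the
    # mean; an explicit n_samples inserts an extra sample dimension.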
    latent_sample_1, weights = latent_distribution.sample()
    # check shape when n_samples = 0
    assert latent_sample_1.shape == latent_mean.shape
    assert latent_sample_1.shape[:-1] == weights.shape
    latent_sample_2, weights = latent_distribution.sample(n_samples=2)
    # check shape when n_samples = 2
    assert latent_sample_2.shape == (
        params.batch_size,
        num_agents,
        2,
        params.latent_dim,
    )
    latent_sample_3, weights = latent_distribution.sample()
    # make sure sampling is non-deterministic
    assert not torch.allclose(latent_sample_1, latent_sample_3)