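"""Tests for the cross-entropy solver (CrossEntropySolver) of the risk_biased MPC planner."""
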
import os

import pytest
import torch
from mmcv import Config

from risk_biased.mpc_planner.dynamics import PositionVelocityDoubleIntegrator
from risk_biased.mpc_planner.planner_cost import TrackingCost, TrackingCostParams
from risk_biased.mpc_planner.solver import CrossEntropySolver, CrossEntropySolverParams
from risk_biased.predictors.biased_predictor import (
    LitTrajectoryPredictorParams,
    LitTrajectoryPredictor,
)
from risk_biased.scene_dataset.loaders import SceneDataLoaders
from risk_biased.utils.cost import TTCCostTorch, TTCCostParams
from risk_biased.utils.risk import get_risk_estimator
from risk_biased.utils.planner_utils import to_state
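

# Shared test configuration: the repository's learning and planning configs
# merged together, then overridden with small solver and network sizes.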
@pytest.fixture(scope="module")
def params():
    torch.manual_seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    planning_config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "planning_config.py"
    )
    # Load the learning config first, then merge the planning config on top of it.
    cfg = Config.fromfile(config_path)
    cfg.update(Config.fromfile(planning_config_path))
    cfg.num_control_samples = 10
    cfg.num_elite = 3
    cfg.iter_max = 3
    cfg.smoothing_factor = 0.2
    cfg.mean_warm_start = True
    cfg.num_steps = 3
    cfg.num_steps_future = 5
    cfg.state_dim = 5
    cfg.dynamic_state_dim = 5
    cfg.map_state_dim = 2
    cfg.max_size_lane = 2
    cfg.latent_dim = 2
    cfg.hidden_dim = 64
    cfg.num_hidden_layers = 3
    return cfg
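

# End-to-end checks of CrossEntropySolver: reset behaviour, elite selection,
# control-distribution updates, prediction sampling, and the full solve() loop
# driven by a LitTrajectoryPredictor.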
class TestCrossEntropySolver:
    # Declared as an autouse fixture so it runs before every test and can
    # request the `params` fixture (a plain `setup` method could not).
    @pytest.fixture(autouse=True)
    def setup(self, params):
        self.solver_params = CrossEntropySolverParams.from_config(params)
        self.dynamics_model = PositionVelocityDoubleIntegrator(params.dt)
        self.interaction_cost_function = TTCCostTorch(TTCCostParams.from_config(params))
        self.tracking_cost_function = TrackingCost(
            TrackingCostParams.from_config(params)
        )
        self.risk_estimator = get_risk_estimator(params.risk_estimator)
        self.control_input_mean_default = torch.randn(
            1, params.num_steps_future, self.dynamics_model.control_dim
        )
        self.control_input_std_default = torch.rand_like(
            self.control_input_mean_default
        )
        self.solver_default = CrossEntropySolver(
            self.solver_params,
            self.dynamics_model,
            self.control_input_mean_default,
            self.control_input_std_default,
            self.tracking_cost_function,
            self.interaction_cost_function,
            self.risk_estimator,
        )
        predictor_params = LitTrajectoryPredictorParams.from_config(params)
        self.predictor = LitTrajectoryPredictor(
            predictor_params,
            TTCCostParams.from_config(params),
            SceneDataLoaders.unnormalize_trajectory,
        )
        self.normalizer = SceneDataLoaders.normalize_trajectory
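
    # reset() should restore the initial control-input distribution and clear
    # any cached ado prediction samples.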
    def test_reset(self):
        self.solver_default.reset()
        assert self.solver_default._iter_current == 0
        assert torch.allclose(
            self.solver_default._control_input_mean, self.control_input_mean_default
        )
        assert torch.allclose(
            self.solver_default._control_input_std, self.control_input_std_default
        )
        assert self.solver_default._latest_ado_position_future_samples is None
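
    # _get_elites keeps the num_elite control samples with the lowest risk and
    # returns them together with their risk values.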
    def test_get_elites(self, params):
        control_input = torch.randn(
            params.num_control_samples,
            1,
            params.num_steps_future,
            self.dynamics_model.control_dim,
        )
        risk = torch.Tensor(
            [0.0, 1.0, 0.1, 0.9, 0.2, 0.8, 0.3, 0.7, 0.4, 0.6]
        ).unsqueeze(-1)
        elite_control_input, elite_risk = self.solver_default._get_elites(
            control_input, risk
        )
        assert elite_control_input.shape == torch.Size(
            [
                params.num_elite,
                1,
                params.num_steps_future,
                self.dynamics_model.control_dim,
            ]
        )
        assert elite_risk.shape == torch.Size([params.num_elite, 1])
        assert torch.allclose(elite_control_input, control_input[[0, 2, 4]])
        assert torch.allclose(elite_risk, torch.Tensor([0.0, 0.1, 0.2]).unsqueeze(-1))
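
    # A smoothing factor of 0.0 should replace the distribution with the elite
    # statistics; a factor of 1.0 should leave it at its initial values.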
    # The parametrized values are illustrative: 0.0 and 1.0 are the two
    # smoothing extremes exercised by the assertions below.
    @pytest.mark.parametrize(
        "num_elite, smoothing_factor", [(3, 0.0), (3, 1.0), (5, 0.0), (5, 1.0)]
    )
    def test_update_control_distribution(self, params, num_elite, smoothing_factor):
        solver = CrossEntropySolver(
            self.solver_params,
            self.dynamics_model,
            self.control_input_mean_default,
            self.control_input_std_default,
            self.tracking_cost_function,
            self.interaction_cost_function,
            self.risk_estimator,
        )
        solver.params.num_elite = num_elite
        solver.params.smoothing_factor = smoothing_factor
        elite_control_input = torch.ones(
            num_elite, params.num_steps_future, self.dynamics_model.control_dim
        )
        solver._update_control_distribution(elite_control_input)
        if smoothing_factor == 0.0:
            assert torch.allclose(
                solver._control_input_mean, torch.ones_like(solver._control_input_mean)
            )
            assert torch.allclose(
                solver._control_input_std, torch.zeros_like(solver._control_input_std)
            )
        else:
            assert torch.allclose(
                solver._control_input_mean, solver.control_input_mean_init
            )
            assert torch.allclose(
                solver._control_input_std, solver.control_input_std_init
            )
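
    # sample_prediction should return one predicted ado position trajectory per
    # sample and per agent over the future horizon, plus sample weights.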
    # The parametrized risk levels and sample counts are illustrative choices.
    @pytest.mark.parametrize(
        "risk_level, num_prediction_samples", [(0.0, 1), (0.3, 4), (0.95, 8)]
    )
    def test_sample_prediction(self, params, risk_level, num_prediction_samples):
        num_agents = 1
        ado_state_history = to_state(
            torch.randn(num_agents, params.num_steps, params.state_dim), params.dt
        )
        ego_state_history = to_state(
            torch.randn(1, params.num_steps, params.state_dim), params.dt
        )
        ego_state_future = to_state(
            torch.randn(1, params.num_steps_future, params.state_dim), params.dt
        )
        ado_position_future_samples, weights = CrossEntropySolver.sample_prediction(
            self.predictor,
            ado_state_history,
            self.normalizer,
            ego_state_history,
            ego_state_future,
            num_prediction_samples=num_prediction_samples,
            risk_level=risk_level,
        )
        assert ado_position_future_samples.shape == torch.Size(
            [num_prediction_samples, num_agents, params.num_steps_future]
        )
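
    # solve() should run iter_max cross-entropy iterations, cache predictions of
    # the expected shape, and shift the initial control mean only when warm
    # starting is enabled.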
    # The parametrized values are illustrative; both warm-start settings are
    # needed to cover the two assertion branches below.
    @pytest.mark.parametrize("mean_warm_start", [True, False])
    @pytest.mark.parametrize("risk_level", [0.0, 0.8])
    @pytest.mark.parametrize("resample_prediction", [True, False])
    @pytest.mark.parametrize("risk_in_predictor", [True, False])
    def test_solve(
        self,
        params,
        mean_warm_start,
        risk_level,
        resample_prediction,
        risk_in_predictor,
    ):
        num_prediction_samples = 5
        num_agents = 1
        ego_state_history = to_state(
            torch.randn(num_agents, params.num_steps, params.state_dim), params.dt
        )
        ego_state_target_trajectory = to_state(
            torch.randn(num_agents, params.num_steps_future, params.state_dim),
            params.dt,
        )
        ado_state_history = to_state(
            torch.randn(num_agents, params.num_steps, 2), params.dt
        )
        self.solver_default.params.mean_warm_start = mean_warm_start
        self.solver_default.solve(
            self.predictor,
            ego_state_history,
            ego_state_target_trajectory,
            ado_state_history,
            self.normalizer,
            num_prediction_samples=num_prediction_samples,
            risk_level=risk_level,
            resample_prediction=resample_prediction,
            risk_in_predictor=risk_in_predictor,
        )
        assert self.solver_default._iter_current == params.iter_max
        assert self.solver_default.fetch_latest_prediction().shape == torch.Size(
            [num_prediction_samples, num_agents, params.num_steps_future]
        )
        if not mean_warm_start:
            assert torch.allclose(
                self.solver_default.control_input_mean_init,
                self.control_input_mean_default,
            )
            assert torch.allclose(
                self.solver_default.control_input_std_init,
                self.control_input_std_default,
            )
        else:
            # Warm start: the stored initial mean is the solved mean shifted one
            # step forward, while its last step and the std keep their defaults.
            assert torch.allclose(
                self.solver_default.control_input_mean_init[:, -1],
                self.control_input_mean_default[:, -1],
            )
            assert torch.allclose(
                self.solver_default.control_input_mean_init[:, :-1],
                self.solver_default._control_input_mean[:, 1:],
            )
            assert torch.allclose(
                self.solver_default.control_input_std_init,
                self.control_input_std_default,
            )