Spaces:
Running
Running
import os | |
import pytest | |
import torch | |
import numpy as np | |
from mmcv import Config | |
from risk_biased.utils.cost import BaseCostTorch, TTCCostTorch, DistanceCostTorch | |
from risk_biased.utils.cost import BaseCostNumpy, TTCCostNumpy, DistanceCostNumpy | |
from risk_biased.utils.cost import ( | |
CostParams, | |
TTCCostParams, | |
DistanceCostParams, | |
) | |
@pytest.fixture
def params():
    """Build a fresh learning config with fixed cost settings for the tests.

    Loads ``risk_biased/config/learning_config.py`` (located three directories
    above this file) and overrides the cost-related fields the tests rely on.

    Function scope (pytest's default) gives every test its own config object.
    The tests assign ``params.cost_reduce``, so a shared instance would carry
    that setting from one test into the next.

    Both the torch and the NumPy global RNGs are seeded, so inputs drawn with
    ``torch.rand`` or ``np.random.uniform`` afterwards are reproducible.

    Returns:
        The loaded ``mmcv`` ``Config`` with the overrides applied.
    """
    torch.manual_seed(0)
    # The NumPy code path draws from the global NumPy RNG; seed it as well.
    np.random.seed(0)
    working_dir = os.path.dirname(os.path.realpath(__file__))
    config_path = os.path.join(
        working_dir, "..", "..", "..", "risk_biased", "config", "learning_config.py"
    )
    cfg = Config.fromfile(config_path)
    cfg.cost_scale = 1
    cfg.cost_reduce = "mean"
    cfg.ego_length = 4
    cfg.ego_width = 1.75
    cfg.distance_bandwidth = 2
    cfg.time_bandwidth = 2
    cfg.min_velocity_diff = 0.01
    return cfg
def get_fake_input(batch_size, num_steps, is_torch, use_mask, num_agents=0):
    """Sample random positions and velocities for two entities, plus an optional mask.

    ``x1``, ``x2``, ``v1`` and ``v2`` are drawn uniformly in ``[0, 1)``.
    Their shape is ``(batch_size, num_steps, 2)`` when ``num_agents <= 0``,
    and ``(batch_size, num_agents, num_steps, 2)`` otherwise.

    Args:
        batch_size: Size of the leading dimension.
        num_steps: Number of time steps.
        is_torch: Draw ``torch`` tensors with ``torch.rand`` when True.
            Otherwise draw NumPy arrays with ``np.random.uniform``.
        use_mask: When True, also draw a boolean mask shaped like the inputs
            without their last dimension. Each entry is True with
            probability 0.9 (uniform sample > 0.1).
        num_agents: Number of agents. A value ``<= 0`` omits the agent axis.

    Returns:
        Tuple ``(x1, x2, v1, v2, mask)``. ``mask`` is ``None`` when
        ``use_mask`` is False.
    """
    agent_dims = [num_agents] if num_agents > 0 else []
    shape = [batch_size, *agent_dims, num_steps, 2]

    def _sample(size):
        # A single sampler keeps the RNG consumption order fixed:
        # x1, x2, v1, v2, then the mask.
        if is_torch:
            return torch.rand(size)
        return np.random.uniform(size=size)

    x1, x2, v1, v2 = (_sample(shape) for _ in range(4))
    mask = (_sample(shape[:-1]) > 0.1) if use_mask else None
    return x1, x2, v1, v2, mask
@pytest.mark.parametrize(
    "reduce, batch_size, num_steps, is_torch, use_mask, num_agents",
    [
        ("mean", 4, 10, True, False, 0),
        ("min", 4, 10, False, True, 0),
        ("max", 2, 5, True, True, 3),
        ("mean", 2, 1, False, False, 3),
    ],
)
def test_base_cost(
    params,
    reduce: str,
    batch_size: int,
    num_steps: int,
    is_torch: bool,
    use_mask: bool,
    num_agents: int,
) -> None:
    """Check that the base cost is identically zero and has the reduced shape.

    The base cost is evaluated on random inputs: torch or NumPy, with or
    without a mask, with or without an agent dimension. The result must drop
    the step and coordinate dimensions, and every entry must equal zero.

    Args:
        params: Config fixture. Its ``cost_reduce`` field is overwritten here.
        reduce: Reduction name written to ``params.cost_reduce``.
        batch_size: Forwarded to ``get_fake_input``.
        num_steps: Forwarded to ``get_fake_input``.
        is_torch: Use ``BaseCostTorch`` with torch inputs when True.
            Otherwise use ``BaseCostNumpy`` with NumPy inputs.
        use_mask: Forwarded to ``get_fake_input``.
        num_agents: Forwarded to ``get_fake_input``. A value ``<= 0`` means
            no agent dimension.
    """
    params.cost_reduce = reduce
    cost_params = CostParams.from_config(params)
    cost_class = BaseCostTorch if is_torch else BaseCostNumpy
    base_cost = cost_class(cost_params)
    x1, x2, v1, v2, mask = get_fake_input(
        batch_size, num_steps, is_torch, use_mask, num_agents
    )
    cost, _ = base_cost(x1, x2, v1, v2, mask)

    expected_shape = (batch_size, num_agents) if num_agents > 0 else (batch_size,)
    assert cost.shape == expected_shape
    assert (cost == 0).all()
    assert base_cost.scale == params.cost_scale
    # NOTE(review): the params fixture sets both bandwidths to 2, yet the base
    # cost is expected to expose 1. Presumably the base class does not read
    # them from CostParams; confirm against risk_biased.utils.cost.
    assert base_cost.distance_bandwidth == 1
    assert base_cost.time_bandwidth == 1
@pytest.mark.parametrize(
    "param_class, cost_class, is_torch",
    [
        (TTCCostParams, TTCCostTorch, True),
        (TTCCostParams, TTCCostNumpy, False),
        (DistanceCostParams, DistanceCostTorch, True),
        (DistanceCostParams, DistanceCostNumpy, False),
    ],
)
@pytest.mark.parametrize(
    "reduce, batch_size, num_steps, use_mask, num_agents",
    [
        ("mean", 16, 10, False, 0),
        ("min", 16, 10, True, 0),
        ("max", 8, 10, True, 3),
    ],
)
def test_generic_cost(
    params,
    param_class,
    cost_class,
    reduce: str,
    batch_size: int,
    num_steps: int,
    is_torch: bool,
    use_mask: bool,
    num_agents: int,
) -> None:
    """Check the shape, scaling and reduction ordering of a concrete cost.

    The test covers:

    * the reduced output shape;
    * that the cost is not identically zero;
    * that raising ``scale`` never lowers the cost and strictly raises some of it;
    * that the ``min``, ``mean`` and ``max`` reductions are ordered.

    Args:
        params: Config fixture. Its ``cost_reduce`` field is overwritten here.
        param_class: Parameter class exposing ``from_config(params)``.
        cost_class: Cost class built from an instance of ``param_class``.
        reduce: Reduction name under test.
        batch_size: Forwarded to ``get_fake_input``.
        num_steps: Forwarded to ``get_fake_input``.
        is_torch: Must match ``cost_class`` (torch inputs for torch classes).
        use_mask: Forwarded to ``get_fake_input``.
        num_agents: Forwarded to ``get_fake_input``.
    """
    params.cost_reduce = reduce
    cost_params = param_class.from_config(params)
    x1, x2, v1, v2, mask = get_fake_input(
        batch_size, num_steps, is_torch, use_mask, num_agents
    )
    compute_cost = cost_class(cost_params)
    cost, _ = compute_cost(x1, x2, v1, v2, mask)

    # The step and coordinate dimensions are reduced away.
    expected_shape = (batch_size, num_agents) if num_agents > 0 else (batch_size,)
    assert cost.shape == expected_shape
    assert (cost != 0).any()
    assert compute_cost.scale == params.cost_scale

    # Rescale the cost for comparison.
    compute_cost.scale = params.cost_scale + 10
    assert compute_cost.scale != params.cost_scale
    rescaled_cost, _ = compute_cost(x1, x2, v1, v2, mask)
    # No rescaled cost is smaller; entries that were zero may stay equal.
    assert (rescaled_cost >= cost).all()
    # At least some rescaled costs are strictly larger than the unscaled ones.
    assert (rescaled_cost > cost).any()

    # Reference reductions over the same inputs AND the same mask. Comparing
    # a masked reduction with an unmasked one would not guarantee the
    # min <= mean <= max ordering checked below.
    params.cost_reduce = "mean"
    cost_function_mean = cost_class(param_class.from_config(params))
    cost_mean, _ = cost_function_mean(x1, x2, v1, v2, mask)
    params.cost_reduce = "min"
    cost_function_min = cost_class(param_class.from_config(params))
    cost_min, _ = cost_function_min(x1, x2, v1, v2, mask)

    # The max reduction is at least the mean.
    if reduce == "max":
        assert (cost >= cost_mean).all()
    # The min reduction is at most any other reduction.
    assert (cost_mean >= cost_min).all()
    assert (cost >= cost_min).all()