# Spaces:
# Sleeping
# Sleeping
import pytest | |
import torch | |
from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\ | |
td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\ | |
dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\ | |
v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error,\ | |
fqf_nstep_td_data, fqf_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error, bdq_nstep_td_error,\ | |
m_q_1step_td_data, m_q_1step_td_error | |
from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale | |
def test_q_nstep_td():
    """
    Exercise ``q_nstep_td_error`` for ``nstep`` in 1..9.

    For every ``nstep`` the default call must return a scalar loss and a per-sample
    TD error of shape ``(batch_size, )``, and back-propagation must populate ``q.grad``.
    The ``cum_reward=True`` variants (without and with an explicit ``value_gamma``)
    are then run; the last one must populate ``q.grad`` on its own.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    value_gamma = torch.tensor(0.9)
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)

        # cum_reward without value_gamma: only checked for running without error.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)

        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma)
        # Drop the gradient left by the first backward pass; otherwise the
        # assertion below would hold even if this loss produced no gradient.
        q.grad = None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
def test_bdq_nstep_td():
    """
    Exercise ``bdq_nstep_td_error`` for ``nstep`` in 1..9.

    ``q`` and ``next_q`` have shape ``(batch_size, branch_num, action_per_branch)`` and the
    actions have shape ``(batch_size, branch_num)``. The default call must return a scalar
    loss and a per-sample TD error of shape ``(batch_size, )`` and populate ``q.grad``; the
    ``cum_reward=True`` call with ``value_gamma`` must populate ``q.grad`` on its own.
    """
    batch_size = 8
    branch_num = 6
    action_per_branch = 3
    next_q = torch.randn(batch_size, branch_num, action_per_branch)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    next_action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
    value_gamma = torch.tensor(0.9)
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, branch_num, action_per_branch).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)

        # cum_reward without value_gamma: only checked for running without error.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)

        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = bdq_nstep_td_error(
            data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        # Drop the gradient left by the first backward pass; otherwise the
        # assertion below would hold even if this loss produced no gradient.
        q.grad = None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
def test_q_nstep_td_ngu():
    """
    Exercise ``q_nstep_td_error`` with ``gamma`` given as a list of per-sample tensors.

    For ``nstep`` in 1..9 the call must return a scalar loss and a per-sample TD error of
    shape ``(batch_size, )``, and back-propagation must populate ``q.grad``.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    # One discount tensor per sample instead of a single float.
    gamma = [torch.tensor(0.95) for _ in range(batch_size)]
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample = q_nstep_td_error(data, gamma, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
def test_dist_1step_td():
    """
    Exercise ``dist_1step_td_error`` on a ``(batch_size, action_dim, n_atom)`` distribution.

    The returned loss must be a scalar and back-propagation must populate ``dist.grad``.
    """
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    # ``abs()`` keeps every entry of the (unnormalised) distributions non-negative.
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
def test_q_1step_compatible():
    """
    Check that ``q_nstep_td_error`` with ``nstep=1`` matches ``q_1step_td_error``.

    Both losses are computed from the same tensors; the n-step data receives the reward
    with an extra leading dimension of size 1. The two scalar losses must be approximately equal.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    reward = torch.rand(batch_size)
    nstep_data = q_nstep_td_data(q, next_q, action, next_action, reward.unsqueeze(0), done, None)
    onestep_data = q_1step_td_data(q, next_q, action, next_action, reward, done, None)
    nstep_loss, _ = q_nstep_td_error(nstep_data, 0.99, nstep=1)
    onestep_loss = q_1step_td_error(onestep_data, 0.99)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()
def test_dist_nstep_td():
    """
    Exercise ``dist_nstep_td_error`` with ``nstep=5``.

    First call: no weight, default ``value_gamma``; the loss must be a scalar and populate
    ``dist.grad``. Second call: tensor weight and explicit ``value_gamma``; the loss must
    again be a scalar and populate ``dist.grad`` on its own.
    """
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    # ``abs()`` keeps every entry of the (unnormalised) distributions non-negative.
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)

    weight = torch.tensor([0.9])
    value_gamma = torch.tensor(0.9)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    # Drop the gradient left by the first backward pass; otherwise the
    # assertion below would hold even if this loss produced no gradient.
    dist.grad = None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)
def test_dist_nstep_multi_agent_td():
    """
    Exercise ``dist_nstep_td_error`` on inputs with an extra agent dimension.

    ``dist`` has shape ``(batch_size, agent_num, action_dim, n_atom)`` and actions have shape
    ``(batch_size, agent_num)``. The loss must be a scalar and populate ``dist.grad`` both
    without and with ``weight`` / ``value_gamma``. Finally, the weighted multi-agent loss
    must equal (within 1e-5) the mean of the losses computed agent by agent.
    """
    batch_size = 4
    action_dim = 3
    agent_num = 2
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    nstep = 5
    # ``abs()`` keeps every entry of the (unnormalised) distributions non-negative.
    dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    next_action = torch.randint(0, action_dim, size=(batch_size, agent_num))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)

    weight = 0.9
    value_gamma = 0.9
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
    loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
    assert loss.shape == ()
    # Drop the gradient left by the first backward pass; otherwise the
    # assertion below would hold even if this loss produced no gradient.
    dist.grad = None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)

    # Recompute the weighted loss one agent at a time and compare the average
    # with the joint multi-agent loss obtained above.
    agent_total_loss = 0
    for i in range(agent_num):
        data = dist_nstep_td_data(
            dist[:, i], next_n_dist[:, i], action[:, i], next_action[:, i], reward, done, weight
        )
        agent_loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
        agent_total_loss = agent_total_loss + agent_loss
    agent_average_loss = agent_total_loss / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5
def test_q_nstep_td_with_rescale():
    """
    Exercise ``q_nstep_td_error_with_rescale`` for ``nstep`` in 1..9.

    The returned loss must be a scalar and back-propagation must populate ``q.grad``.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)
def test_q_nstep_td_with_rescale_ngu():
    """
    Exercise ``q_nstep_td_error_with_rescale`` with ``gamma`` given as a list of per-sample tensors.

    For ``nstep`` in 1..9 the returned loss must be a scalar and back-propagation must
    populate ``q.grad``.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    # One discount tensor per sample instead of a single float.
    gamma = [torch.tensor(0.95) for _ in range(batch_size)]
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, _ = q_nstep_td_error_with_rescale(data, gamma, nstep=nstep)
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)
def test_qrdqn_nstep_td():
    """
    Exercise ``qrdqn_nstep_td_error`` for ``nstep`` in 1..9.

    ``q`` and ``next_q`` have shape ``(batch_size, action_dim, tau)``. The default call must
    return a scalar loss and a per-sample TD error of shape ``(batch_size, )`` and populate
    ``q.grad``; the call with an explicit ``value_gamma`` must keep the per-sample shape.
    """
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, action_dim, tau)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim, tau).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )
def test_dist_1step_compatible():
    """
    Check that ``dist_nstep_td_error`` with ``nstep=1`` matches ``dist_1step_td_error``.

    Both losses are computed from the same tensors; the n-step data receives the reward
    with an extra leading dimension of size 1. The two scalar losses must be approximately equal.
    """
    batch_size = 4
    action_dim = 3
    n_atom = 51
    v_min = -10.0
    v_max = 10.0
    # ``abs()`` keeps every entry of the (unnormalised) distributions non-negative.
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    onestep_data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    nstep_data = dist_nstep_td_data(dist, next_dist, action, next_action, reward.unsqueeze(0), done, None)
    onestep_loss = dist_1step_td_error(onestep_data, 0.95, v_min, v_max, n_atom)
    nstep_loss, _ = dist_nstep_td_error(nstep_data, 0.95, v_min, v_max, n_atom, nstep=1)
    assert pytest.approx(nstep_loss.item()) == onestep_loss.item()
def test_dist_1step_multi_agent_td():
    """
    Run ``dist_1step_td_error`` on inputs carrying an extra agent dimension.

    The joint loss must be a scalar and populate ``dist.grad``; it must also equal
    (within 1e-5) the mean of the losses obtained by slicing out each agent in turn.
    """
    batch_size, action_dim, agent_num, n_atom = 4, 3, 2, 51
    v_min, v_max = -10.0, 10.0
    dist_shape = (batch_size, agent_num, action_dim, n_atom)
    act_shape = (batch_size, agent_num)

    dist = torch.randn(*dist_shape).abs().requires_grad_(True)
    next_dist = torch.randn(*dist_shape).abs()
    done = torch.randint(0, 2, (batch_size, ))
    action = torch.randint(0, action_dim, size=act_shape)
    next_action = torch.randint(0, action_dim, size=act_shape)
    reward = torch.randn(batch_size)

    # Joint loss over all agents at once.
    joint_data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
    loss = dist_1step_td_error(joint_data, 0.95, v_min, v_max, n_atom)
    assert loss.shape == ()
    assert dist.grad is None
    loss.backward()
    assert isinstance(dist.grad, torch.Tensor)

    # Same computation, one agent slice at a time.
    per_agent_losses = []
    for agent_idx in range(agent_num):
        agent_data = dist_1step_td_data(
            dist[:, agent_idx],
            next_dist[:, agent_idx],
            action[:, agent_idx],
            next_action[:, agent_idx],
            reward,
            done,
            None,
        )
        per_agent_losses.append(dist_1step_td_error(agent_data, 0.95, v_min, v_max, n_atom))
    agent_average_loss = sum(per_agent_losses) / agent_num
    assert abs(agent_average_loss.item() - loss.item()) < 1e-5
def test_td_lambda():
    """
    Run ``td_lambda_error`` on a ``(T + 1, B)`` value tensor and a ``(T, B)`` reward tensor.

    The loss must be a scalar and back-propagation must populate ``value.grad``.
    """
    steps, batch = 8, 4
    value = torch.randn(steps + 1, batch).requires_grad_(True)
    reward = torch.rand(steps, batch)
    data = td_lambda_data(value, reward, None)
    loss = td_lambda_error(data)
    assert loss.shape == ()
    assert value.grad is None
    loss.backward()
    assert isinstance(value.grad, torch.Tensor)
def test_v_1step_td():
    """
    Exercise ``v_1step_td_error`` with and without a ``done`` tensor.

    With ``done`` given, the loss must be a scalar and populate ``v.grad``. With
    ``done=None`` the loss must populate ``v.grad`` on its own.
    """
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)

    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    # Drop the gradient left by the first backward pass; otherwise the
    # assertion below would hold even if this loss produced no gradient.
    v.grad = None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
def test_v_1step_multi_agent_td():
    """
    Exercise ``v_1step_td_error`` on ``(batch_size, agent_num)`` values.

    ``reward`` and ``done`` keep shape ``(batch_size, )``. With ``done`` given, the loss must
    be a scalar and populate ``v.grad``. With ``done=None`` the loss must populate ``v.grad``
    on its own.
    """
    batch_size = 5
    agent_num = 2
    v = torch.randn(batch_size, agent_num).requires_grad_(True)
    next_v = torch.randn(batch_size, agent_num)
    reward = torch.rand(batch_size)
    done = torch.zeros(batch_size)
    data = v_1step_td_data(v, next_v, reward, done, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)

    data = v_1step_td_data(v, next_v, reward, None, None)
    loss, td_error_per_sample = v_1step_td_error(data, 0.99)
    # Drop the gradient left by the first backward pass; otherwise the
    # assertion below would hold even if this loss produced no gradient.
    v.grad = None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
def test_v_nstep_td():
    """
    Exercise ``v_nstep_td_error`` with ``nstep=5``.

    The data tuple is built once with ``0.9`` and once with ``None`` as its fifth field
    (the sixth field is ``0.99`` in both cases). The first loss must be a scalar and
    populate ``v.grad``; the second must populate ``v.grad`` on its own.
    """
    batch_size = 5
    v = torch.randn(batch_size).requires_grad_(True)
    next_v = torch.randn(batch_size)
    reward = torch.rand(5, batch_size)
    done = torch.zeros(batch_size)
    data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    assert loss.shape == ()
    assert v.grad is None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)

    data = v_nstep_td_data(v, next_v, reward, done, None, 0.99)
    loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
    # Drop the gradient left by the first backward pass; otherwise the
    # assertion below would hold even if this loss produced no gradient.
    v.grad = None
    loss.backward()
    assert isinstance(v.grad, torch.Tensor)
def test_dqfd_nstep_td():
    """
    Exercise ``dqfd_nstep_td_error`` for ``nstep`` in 1..9 with every sample marked as expert.

    The call must return a scalar loss and a per-sample TD error of shape ``(batch_size, )``,
    and back-propagation must populate ``q.grad``.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    done_1 = torch.randint(0, 2, size=(batch_size, )).float()
    next_q_one_step = torch.randn(batch_size, action_dim)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
    is_expert = torch.ones((batch_size))
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = dqfd_nstep_td_data(
            q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
        )
        loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
            data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
        )
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        print(loss)
def test_q_nstep_sql_td():
    """
    Exercise ``q_nstep_sql_td_error`` for ``nstep`` in 1..9.

    The default call must return a scalar loss and a per-sample TD error of shape
    ``(batch_size, )`` and populate ``q.grad``. The ``cum_reward=True`` variants (without
    and with an explicit ``value_gamma``) are then run; the last one must populate
    ``q.grad`` on its own.
    """
    batch_size = 4
    action_dim = 3
    next_q = torch.randn(batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    value_gamma = torch.tensor(0.9)
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)

        # cum_reward without value_gamma: only checked for running without error.
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 0.5, nstep=nstep, cum_reward=True)

        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
            data, 0.95, 0.5, nstep=nstep, cum_reward=True, value_gamma=value_gamma
        )
        # Drop the gradient left by the first backward pass; otherwise the
        # assertion below would hold even if this loss produced no gradient.
        q.grad = None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
def test_iqn_nstep_td():
    """
    Exercise ``iqn_nstep_td_error`` for ``nstep`` in 1..9.

    ``q`` and ``next_q`` have shape ``(tau, batch_size, action_dim)`` and ``replay_quantile``
    has shape ``(tau, batch_size, 1)``. The default call must return a scalar loss and a
    per-sample TD error of shape ``(batch_size, )`` and populate ``q.grad``; the call with an
    explicit ``value_gamma`` must keep the per-sample shape.
    """
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(tau, batch_size, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(tau, batch_size, action_dim).requires_grad_(True)
        replay_quantile = torch.randn([tau, batch_size, 1])
        reward = torch.rand(nstep, batch_size)
        data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )
def test_fqf_nstep_td():
    """
    Exercise ``fqf_nstep_td_error`` for ``nstep`` in 1..9.

    ``q`` and ``next_q`` have shape ``(batch_size, tau, action_dim)`` and ``quantiles_hats``
    has shape ``(batch_size, tau)``. The default call must return a scalar loss and a
    per-sample TD error of shape ``(batch_size, )`` and populate ``q.grad``; the call with an
    explicit ``value_gamma`` must keep the per-sample shape.
    """
    batch_size = 4
    action_dim = 3
    tau = 3
    next_q = torch.randn(batch_size, tau, action_dim)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        # A fresh leaf tensor per iteration, so ``q.grad`` starts out as None.
        q = torch.randn(batch_size, tau, action_dim).requires_grad_(True)
        quantiles_hats = torch.randn([batch_size, tau])
        reward = torch.rand(nstep, batch_size)
        data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
        assert td_error_per_sample.shape == (batch_size, )
        assert loss.shape == ()
        assert q.grad is None
        loss.backward()
        assert isinstance(q.grad, torch.Tensor)
        loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
        assert td_error_per_sample.shape == (batch_size, )
def test_shape_fn_qntd():
    """
    Check the three values reported by ``shape_fn_qntd`` for both call forms.

    Whether the data is passed positionally or through the ``'data'`` keyword, the result
    must hold ``reward.shape[0]``, ``q.shape[0]`` and ``q.shape[1]`` at indices 0, 1 and 2.
    """
    batch_size, action_dim = 4, 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        expected = (reward.shape[0], q.shape[0], q.shape[1])
        # Positional form first, keyword form second.
        call_forms = (
            ([data, 0.95, 1], {}),
            ([], {'gamma': 0.95, 'nstep': 1, 'data': data}),
        )
        for args, kwargs in call_forms:
            shape = shape_fn_qntd(args, kwargs)
            for idx, want in enumerate(expected):
                assert shape[idx] == want
def test_shape_fn_dntd():
    """
    Check the four values reported by ``shape_fn_dntd`` for both call forms.

    Whether the arguments are passed positionally or by keyword, the result must hold
    ``reward.shape[0]``, ``dist.shape[0]``, ``dist.shape[1]`` and ``n_atom`` at indices 0..3.
    """
    batch_size, action_dim, n_atom = 4, 3, 51
    v_min, v_max = -10.0, 10.0
    nstep = 5
    dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
    next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(nstep, batch_size)
    data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
    expected = (reward.shape[0], dist.shape[0], dist.shape[1], n_atom)
    # Positional form first, keyword form second.
    call_forms = (
        ([data, 0.9, v_min, v_max, n_atom, nstep], {}),
        ([], {'data': data, 'gamma': 0.9, 'v_min': v_min, 'v_max': v_max, 'n_atom': n_atom, 'nstep': 5}),
    )
    for args, kwargs in call_forms:
        shape = shape_fn_dntd(args, kwargs)
        for idx, want in enumerate(expected):
            assert shape[idx] == want
def test_shape_fn_qntd_rescale():
    """
    Check the three values reported by ``shape_fn_qntd_rescale`` for both call forms.

    Whether the data is passed positionally or through the ``'data'`` keyword, the result
    must hold ``reward.shape[0]``, ``q.shape[0]`` and ``q.shape[1]`` at indices 0, 1 and 2.
    """
    batch_size, action_dim = 4, 3
    next_q = torch.randn(batch_size, action_dim)
    done = torch.randn(batch_size)
    action = torch.randint(0, action_dim, size=(batch_size, ))
    next_action = torch.randint(0, action_dim, size=(batch_size, ))
    for nstep in range(1, 10):
        q = torch.randn(batch_size, action_dim).requires_grad_(True)
        reward = torch.rand(nstep, batch_size)
        data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
        expected = (reward.shape[0], q.shape[0], q.shape[1])
        # Positional form first, keyword form second.
        call_forms = (
            ([data, 0.95, 1], {}),
            ([], {'gamma': 0.95, 'nstep': 1, 'data': data}),
        )
        for args, kwargs in call_forms:
            shape = shape_fn_qntd_rescale(args, kwargs)
            for idx, want in enumerate(expected):
                assert shape[idx] == want
def test_fn_td_lambda():
    """
    Check what ``shape_fn_td_lambda`` reports for keyword and positional call forms.

    As asserted here, the keyword form is compared with ``reward.shape[0]`` while the
    positional form is compared with the full ``reward.shape``.
    """
    steps, batch = 8, 4
    value = torch.randn(steps + 1, batch).requires_grad_(True)
    reward = torch.rand(steps, batch)
    data = td_lambda_data(value, reward, None)

    by_keyword = shape_fn_td_lambda([], {'data': data})
    assert by_keyword == reward.shape[0]

    by_position = shape_fn_td_lambda([data], {})
    assert by_position == reward.shape
def test_fn_m_q_1step_td_error():
    """
    Exercise ``m_q_1step_td_error`` on a batch of 128 samples with 9 actions.

    The loss must be a scalar and populate ``q.grad``. In addition, the mean of
    ``clip_frac`` must not exceed 1, ``action_gap`` must be positive, and the per-sample
    TD error must have shape ``(batch_size, )``.
    """
    batch_size = 128
    action_dim = 9
    q = torch.randn(batch_size, action_dim).requires_grad_(True)
    target_q_current = torch.randn(batch_size, action_dim).requires_grad_(False)
    target_q_next = torch.randn(batch_size, action_dim).requires_grad_(False)
    # Terminal flags are drawn from {0, 1}; kept as float so the dtype matches the other inputs.
    done = torch.randint(0, 2, size=(batch_size, )).float()
    action = torch.randint(0, action_dim, size=(batch_size, ))
    reward = torch.randn(batch_size)
    data = m_q_1step_td_data(q, target_q_current, target_q_next, action, reward, done, None)
    loss, td_error_per_sample, action_gap, clip_frac = m_q_1step_td_error(data, 0.99, 0.03, 0.6)
    assert loss.shape == ()
    assert q.grad is None
    loss.backward()
    assert isinstance(q.grad, torch.Tensor)
    assert clip_frac.mean().item() <= 1
    assert action_gap.item() > 0
    assert td_error_per_sample.shape == (batch_size, )