import pytest
import torch
from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy


def test_upgo():
    T, B, N, N2 = 4, 8, 5, 7
    # tb_cross_entropy: 3 tests
    logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)
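    # a 5-dim logit is not supported and is expected to trip an assertion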
    logit = torch.randn(T, B, N, N2, 2).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    with pytest.raises(AssertionError):
        ce = tb_cross_entropy(logit, action)
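    # plain 3-dim (T, B, N) logit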
    logit = torch.randn(T, B, N).softmax(-1).requires_grad_(True)
    action = logit.argmax(-1).detach()
    ce = tb_cross_entropy(logit, action)
    assert ce.shape == (T, B)
    # upgo_returns
    rewards = torch.randn(T, B)
    bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
    returns = upgo_returns(rewards, bootstrap_values)
    assert returns.shape == (T, B)
    # upgo loss
    rhos = torch.randn(T, B)
    loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
    assert logit.requires_grad
    assert bootstrap_values.requires_grad
    for t in [logit, bootstrap_values]:
        assert t.grad is None
    loss.backward()
    for t in [logit]:
        assert isinstance(t.grad, torch.Tensor)