Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from ding.rl_utils import compute_q_retraces | |
def test_compute_q_retraces():
    """Smoke-test ``compute_q_retraces`` on random inputs.

    Checks that the result has shape ``(T + 1, B, 1)`` and contains only
    finite values. All inputs are drawn from a dedicated, seeded
    ``torch.Generator`` so that a failure reproduces exactly. Using a local
    generator also leaves the global torch RNG state seen by other tests
    untouched.
    """
    # Presumably T = sequence length, B = batch size; N is the number of
    # actions (q_values' last dim, and actions are drawn from [0, N)).
    T, B, N = 64, 32, 6
    gen = torch.Generator().manual_seed(0)
    q_values = torch.randn(T + 1, B, N, generator=gen)
    v_pred = torch.randn(T + 1, B, 1, generator=gen)
    rewards = torch.randn(T, B, generator=gen)
    # torch.rand samples from [0, 1), so the ratio lies in [0.8, 1.2).
    ratio = torch.rand(T, B, N, generator=gen) * 0.4 + 0.8
    assert ratio.max() <= 1.2 and ratio.min() >= 0.8
    weights = torch.rand(T, B, generator=gen)
    actions = torch.randint(0, N, size=(T, B), generator=gen)
    with torch.no_grad():
        q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.99)
    assert q_retraces.shape == (T + 1, B, 1)
    # Finite inputs should never yield NaN/inf targets.
    assert torch.isfinite(q_retraces).all()