Spaces:
Sleeping
Sleeping
import pytest | |
from itertools import product | |
import numpy as np | |
import torch | |
from ding.rl_utils import coma_data, coma_error | |
# Candidate values intended for the ``weight`` argument of ``test_coma``:
# no weighting (None) and a random tensor of shape (128, 4, 8), which matches
# the (T, B, A) sizes that test uses. Adding 1 shifts torch.rand's [0, 1)
# samples so that every weight is at least 1.
random_weight = torch.rand(128, 4, 8) + 1
weight_args = [None, random_weight]
@pytest.mark.parametrize('weight', weight_args)
def test_coma(weight):
    """Smoke-test ``coma_error``: scalar loss terms and gradient flow.

    Builds random inputs, packs them with ``coma_data``, and checks that
    every term obtained by iterating the result of ``coma_error`` is a
    0-dim tensor, and that back-propagating the sum of those terms
    populates ``.grad`` on ``logit`` and ``q_value``.

    Args:
        weight: Value forwarded as the last argument of ``coma_data``.
            Parametrized over the module-level ``weight_args``.
    """
    # NOTE(review): axes are presumably (time, batch, agent, action) --
    # confirm against the coma_data documentation.
    T, B, A, N = 128, 4, 8, 32
    logit = torch.randn(T, B, A, N).requires_grad_(True)
    action = torch.randint(0, N, size=(T, B, A))
    reward = torch.rand(T, B)
    q_value = torch.randn(T, B, A, N).requires_grad_(True)
    target_q_value = torch.randn(T, B, A, N).requires_grad_(True)

    data = coma_data(logit, action, q_value, target_q_value, reward, weight)
    # NOTE(review): 0.99 / 0.95 are presumably the discount factor and the
    # TD-lambda coefficient -- confirm against coma_error's signature.
    loss = coma_error(data, 0.99, 0.95)
    assert all(term.shape == () for term in loss)

    # No backward pass has run yet, so the leaf tensors carry no gradient.
    assert logit.grad is None
    assert q_value.grad is None

    total_loss = sum(loss)
    total_loss.backward()
    assert isinstance(logit.grad, torch.Tensor)
    assert isinstance(q_value.grad, torch.Tensor)