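# NOTE(review): this file originally began with the stray lines "Spaces:",
# "Sleeping", "Sleeping" -- apparently page residue from wherever the file was
# copied. They are kept here as a comment so the module remains importable.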
from itertools import product

import numpy as np
import pytest
import torch

from ding.model.template import PG
from ding.torch_utils import is_differentiable
from ding.utils import squeeze
# Batch size: the leading dimension of every input tensor built by the tests below.
B = 4
class TestDiscretePG:
    """Shape and gradient-flow tests for the ``PG`` model template.

    Despite the class name, this covers both the default (discrete) configuration
    and the ``action_space='continuous'`` configuration of ``PG``.
    """

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and run ``is_differentiable(loss, model)``.

        Args:
            model: the module passed through to ``is_differentiable``.
            outputs: a ``torch.Tensor``, a list of tensors, or a dict whose
                values are tensors; each tensor is summed into the loss.

        Raises:
            TypeError: if ``outputs`` is none of the supported container types.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, list):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Fail with a clear message rather than an UnboundLocalError on ``loss``.
            raise TypeError(f"unsupported outputs type: {type(outputs).__name__}")
        is_differentiable(loss, model)

    def test_discrete_pg(self):
        """Default (discrete) PG on 4x84x84 observations with 5 actions."""
        obs_shape = (4, 84, 84)
        action_shape = 5
        model = PG(obs_shape, action_shape)
        inputs = torch.randn(B, *obs_shape)
        outputs = model(inputs)
        assert isinstance(outputs, dict)
        assert outputs['logit'].shape == (B, action_shape)
        # One sample per batch element.
        assert outputs['dist'].sample().shape == (B, )
        self.output_check(model, outputs['logit'])

    def test_continuous_pg(self):
        """Continuous PG on 32-dim observations; logits are a ``mu``/``sigma`` dict."""
        N = 32
        action_shape = (6, )
        model = PG(
            obs_shape=(N, ),
            action_shape=action_shape,
            action_space='continuous',
        )
        # compute_action
        obs = torch.randn(B, N)
        outputs = model(obs)
        assert isinstance(outputs, dict)
        action = outputs['dist'].sample()
        assert action.shape == (B, *action_shape)
        logit = outputs['logit']
        mu, sigma = logit['mu'], logit['sigma']
        assert mu.shape == (B, *action_shape)
        assert sigma.shape == (B, *action_shape)
        # Same loss as mu.sum() + sigma.sum(), checked via the shared helper.
        self.output_check(model, [mu, sigma])