Spaces:
Sleeping
Sleeping
import pytest
import torch
from ding.torch_utils import is_differentiable

from lzero.model.common import RepresentationNetwork
class TestCommon:
    """Unit tests for ``RepresentationNetwork`` from ``lzero.model.common``."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar loss and run ``is_differentiable`` on it.

        Every tensor in ``outputs`` is summed into one scalar, which is then
        passed together with ``model`` to ``ding.torch_utils.is_differentiable``.

        Args:
            model: the module handed to ``is_differentiable``.
            outputs: a ``torch.Tensor``, a list/tuple of tensors, or a dict
                whose values are tensors.

        Raises:
            TypeError: if ``outputs`` is not one of the supported types.
            ValueError: if ``outputs`` is an empty list, tuple or dict.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        else:
            if isinstance(outputs, (list, tuple)):
                tensors = list(outputs)
            elif isinstance(outputs, dict):
                tensors = list(outputs.values())
            else:
                # Reject unsupported types explicitly rather than failing later
                # on an unbound ``loss``.
                raise TypeError(
                    'outputs must be a Tensor, list, tuple or dict, got {}'.format(type(outputs).__name__)
                )
            if not tensors:
                # sum() over an empty sequence yields the int 0, which cannot be
                # back-propagated through.
                raise ValueError('outputs must contain at least one tensor')
            loss = sum(t.sum() for t in tensors)
        is_differentiable(loss, model)

    @pytest.mark.parametrize('batch_size', [2, 10])
    def test_representation_network(self, batch_size):
        """Check the output shape of ``RepresentationNetwork`` with ``downsample=False``.

        A ``(batch_size, 1, 3, 3)`` observation is expected to produce a
        ``(batch_size, 16, 3, 3)`` state for ``num_channels=16``.
        """
        obs = torch.rand(batch_size, 1, 3, 3)
        representation_network = RepresentationNetwork(
            observation_shape=[1, 3, 3], num_res_blocks=1, num_channels=16, downsample=False
        )
        state = representation_network(obs)
        # Compare against the parametrized batch size rather than a hard-coded value.
        assert state.shape == torch.Size([batch_size, 16, 3, 3])