Spaces:
Sleeping
Sleeping
from itertools import product | |
import pytest | |
import torch | |
from ding.torch_utils import is_differentiable | |
from lzero.model.alphazero_model import PredictionNetwork | |
# Hyper-parameter grid for the PredictionNetwork tests.
# Each name below holds the list of candidate values for the parameter of the
# same name; `prediction_network_args` is their Cartesian product.
action_space_size = [2, 3]
batch_size = [100, 200]
num_res_blocks = [3]
num_channels = [3]
value_head_channels = [8]
policy_head_channels = [8]
# Each candidate is itself a list, hence the nested brackets.
fc_value_layers = [[16]]
fc_policy_layers = [[16]]
output_support_size = [2]

# Only indices 1 and 2 of this list are read elsewhere in this file.
# NOTE(review): the leading 1 does not match num_channels == 3 — confirm
# whether that is intentional.
observation_shape = [1, 3, 3]

# One tuple per combination, in exactly this field order:
# (action_space_size, batch_size, num_res_blocks, num_channels,
#  value_head_channels, policy_head_channels, fc_value_layers,
#  fc_policy_layers, output_support_size)
prediction_network_args = list(
    product(
        action_space_size,
        batch_size,
        num_res_blocks,
        num_channels,
        value_head_channels,
        policy_head_channels,
        fc_value_layers,
        fc_policy_layers,
        output_support_size,
    )
)
class TestAlphaZeroModel:
    """Output-shape tests for the AlphaZero ``PredictionNetwork``."""

    def output_check(self, model, outputs):
        """Reduce ``outputs`` to a scalar and pass it to ``is_differentiable``.

        Args:
            model: The module forwarded, together with the scalar, to
                ``is_differentiable``.
            outputs: A tensor, a list/tuple of tensors, or a dict whose
                values are tensors. Every tensor is summed and the sums are
                added together.

        Raises:
            TypeError: If ``outputs`` is none of the supported containers.
        """
        if isinstance(outputs, torch.Tensor):
            loss = outputs.sum()
        elif isinstance(outputs, (list, tuple)):
            loss = sum(t.sum() for t in outputs)
        elif isinstance(outputs, dict):
            loss = sum(v.sum() for v in outputs.values())
        else:
            # Fail with a clear message instead of an UnboundLocalError on
            # `loss` further down.
            raise TypeError(
                "outputs must be a torch.Tensor, list, tuple or dict, got {}".format(type(outputs).__name__)
            )
        is_differentiable(loss, model)

    # Without this decorator pytest would look for fixtures named after the
    # arguments; the names must stay in the same order as the fields of the
    # tuples in `prediction_network_args`.
    @pytest.mark.parametrize(
        'action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels, '
        'policy_head_channels, fc_value_layers, fc_policy_layers, output_support_size',
        prediction_network_args,
    )
    def test_prediction_network(
            self, action_space_size, batch_size, num_res_blocks, num_channels, value_head_channels,
            policy_head_channels,
            fc_value_layers, fc_policy_layers, output_support_size
    ):
        """Build a ``PredictionNetwork`` and check the shapes of its two outputs.

        The network is fed a random ``(batch_size, num_channels, H, W)``
        tensor, where ``H`` and ``W`` are the last two entries of the
        module-level ``observation_shape``. The first output must have shape
        ``(batch_size, action_space_size)`` and the second
        ``(batch_size, output_support_size)``.
        """
        height, width = observation_shape[1], observation_shape[2]
        obs = torch.rand(batch_size, num_channels, height, width)
        flatten_output_size_for_value_head = value_head_channels * height * width
        flatten_output_size_for_policy_head = policy_head_channels * height * width

        prediction_network = PredictionNetwork(
            action_space_size=action_space_size,
            continuous_action_space=False,
            num_res_blocks=num_res_blocks,
            num_channels=num_channels,
            value_head_channels=value_head_channels,
            policy_head_channels=policy_head_channels,
            fc_value_layers=fc_value_layers,
            fc_policy_layers=fc_policy_layers,
            output_support_size=output_support_size,
            flatten_output_size_for_value_head=flatten_output_size_for_value_head,
            flatten_output_size_for_policy_head=flatten_output_size_for_policy_head,
            last_linear_layer_init_zero=True,
        )

        policy, value = prediction_network(obs)
        assert policy.shape == torch.Size([batch_size, action_space_size])
        assert value.shape == torch.Size([batch_size, output_support_size])
if __name__ == "__main__":
    # Manual smoke run: builds one PredictionNetwork with a single fixed
    # configuration and checks the shapes of its two outputs.
    # The names below rebind the module-level parameter lists to scalars;
    # this only happens when the file is executed as a script.
    action_space_size = 2
    batch_size = 100
    num_res_blocks = 3
    num_channels = 3
    # Only printed below; it is not passed to PredictionNetwork.
    reward_head_channels = 2
    value_head_channels = 8
    policy_head_channels = 8
    fc_value_layers = [16]
    fc_policy_layers = [16]
    output_support_size = 2
    observation_shape = [1, 3, 3]

    # Spatial size comes from the last two entries of observation_shape.
    height, width = observation_shape[1], observation_shape[2]
    obs = torch.rand(batch_size, num_channels, height, width)
    flatten_output_size_for_value_head = value_head_channels * height * width
    flatten_output_size_for_policy_head = policy_head_channels * height * width

    print('=' * 20)
    print(
        batch_size, num_res_blocks, num_channels, action_space_size, reward_head_channels, fc_value_layers,
        fc_policy_layers, output_support_size
    )
    print('=' * 20)

    prediction_network = PredictionNetwork(
        action_space_size=action_space_size,
        # Passed explicitly so this call matches the one in
        # TestAlphaZeroModel.test_prediction_network.
        continuous_action_space=False,
        num_res_blocks=num_res_blocks,
        num_channels=num_channels,
        value_head_channels=value_head_channels,
        policy_head_channels=policy_head_channels,
        fc_value_layers=fc_value_layers,
        fc_policy_layers=fc_policy_layers,
        output_support_size=output_support_size,
        flatten_output_size_for_value_head=flatten_output_size_for_value_head,
        flatten_output_size_for_policy_head=flatten_output_size_for_policy_head,
        last_linear_layer_init_zero=True,
    )

    policy, value = prediction_network(obs)
    assert policy.shape == torch.Size([batch_size, action_space_size])
    assert value.shape == torch.Size([batch_size, output_support_size])