Spaces:
Sleeping
Sleeping
import numpy as np | |
import pytest | |
from easydict import EasyDict | |
from ding.torch_utils import to_list | |
from lzero.mcts.buffer.game_buffer_efficientzero import EfficientZeroGameBuffer | |
# Game-buffer configuration passed to every EfficientZeroGameBuffer built in this module.
config = EasyDict({
    'batch_size': 10,
    'transition_num': 20,
    'priority_prob_alpha': 0.6,
    'priority_prob_beta': 0.4,
    'replay_buffer_size': 10000,
    'env_type': 'not_board_games',
    'use_priority': True,
    'action_type': 'fixed_action_space',
})
def test_push():
    """Exercise segment pushing and ``get_num_of_game_segments``.

    Pushes 20 segments one at a time, then two more via ``push_game_segments``,
    clears ``game_segment_buffer`` in place, and pushes five more, checking the
    segment count after every step.
    """
    buffer = EfficientZeroGameBuffer(config)

    # fake data: ten (s, a, r) rows and ten matching priorities
    data = [[1, 1, 1] for _ in range(10)]  # (s,a,r)
    meta = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}

    # _push_game_segment: each pushed segment is the fake data scaled by its index
    for i in range(20):
        buffer._push_game_segment(to_list(np.multiply(i, data)), meta)
    assert buffer.get_num_of_game_segments() == 20

    # push_game_segments, called here as [list_of_segments, list_of_metas]
    buffer.push_game_segments([[data, data], [meta, meta]])
    assert buffer.get_num_of_game_segments() == 22

    # Clear the segment list in place
    del buffer.game_segment_buffer[:]
    assert buffer.get_num_of_game_segments() == 0

    # _push_game_segment after the clear: the count should restart from zero
    for i in range(5):
        buffer._push_game_segment(to_list(np.multiply(i, data)), meta)
    assert buffer.get_num_of_game_segments() == 5
def test_update_priority():
    """Check that ``update_priority`` stores a batch priority in ``game_pos_priorities``.

    Fills the buffer with 20 fake segments, feeds a hand-built training batch
    that references positions 0 and 1, and verifies position 0 now holds the
    first batch priority.
    """
    buffer = EfficientZeroGameBuffer(config)

    # Fake segments: ten (s, a, r) rows, scaled by the segment index on each push.
    transitions = [[1, 1, 1] for _ in range(10)]
    segment_meta = {
        'done': True,
        'unroll_plus_td_steps': 5,
        'priorities': np.full(10, 0.9),
    }

    num_segments = 20
    for scale in range(num_segments):
        buffer._push_game_segment(to_list(np.multiply(scale, transitions)), segment_meta)
    assert buffer.get_num_of_game_segments() == num_segments

    # Fake training batch laid out as train_data = [current_batch, target_batch] with
    # current_batch = [obs_lst, action_lst, mask_lst, batch_index_list, weights, make_time_lst];
    # only the batch indices and make times are filled in.
    batch_index_list = [0, 1]
    make_time_lst = [999, 1000]
    current_batch = [[], [], [], batch_index_list, [], make_time_lst]
    target_batch = []
    train_data = [current_batch, target_batch]

    new_priorities = [0.999, 0.8]
    buffer.update_priority(train_data, new_priorities)
    assert buffer.game_pos_priorities[0] == 0.999
def test_sample_orig_data():
    """Push two fake segments and sample a batch of two via ``_sample_orig_data``.

    The sampled context is printed for inspection and checked against the
    five-field layout noted below.
    """
    buffer = EfficientZeroGameBuffer(config)

    # fake data: two segments of ten (s, a, r) rows each.
    # NOTE(review): data_2 has exactly the same contents as data_1; range(10, 20)
    # only sets its length. Use distinct values if the test should tell them apart.
    data_1 = [[1, 1, 1] for _ in range(10)]  # (s,a,r)
    meta_1 = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}
    data_2 = [[1, 1, 1] for _ in range(10, 20)]  # (s,a,r)
    meta_2 = {'done': True, 'unroll_plus_td_steps': 5, 'priorities': np.array([0.9] * 10)}

    # push
    buffer._push_game_segment(data_1, meta_1)
    buffer._push_game_segment(data_2, meta_2)

    context = buffer._sample_orig_data(batch_size=2)
    # context = (game_lst, game_pos_lst, indices_lst, weights, make_time)
    print(context)
    # Without an assertion this test could only fail by raising; pin the documented layout.
    assert len(context) == 5