Spaces:
Sleeping
Sleeping
import pytest | |
import numpy as np | |
import random | |
import torch | |
from ding.data.level_replay.level_sampler import LevelSampler | |
def test_level_sampler():
    """Smoke-test ``LevelSampler`` with one synthetic rollout.

    Builds a sampler over ``num_seeds`` training seeds, feeds it a batch of
    random rollout tensors through ``update_with_rollouts``, and checks that
    ``sample()`` returns a plain ``int``.
    """
    num_seeds = 500
    obs_shape = [3, 64, 64]
    action_shape = 15
    collector_env_num = 16
    level_replay_dict = dict(
        strategy='min_margin',
        score_transform='rank',
        temperature=0.1,
    )
    N = 10  # width of the second ``logit`` dimension
    collector_sample_length = 160

    train_seeds = list(range(num_seeds))
    level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)

    # Synthetic rollout data; only shapes are controlled, values are random.
    value = torch.randn(collector_sample_length)
    reward = torch.randn(collector_sample_length)
    adv = torch.randn(collector_sample_length)
    # NOTE(review): ``done`` holds random floats rather than 0/1 flags --
    # confirm this is what ``update_with_rollouts`` is meant to be fed.
    done = torch.randn(collector_sample_length)
    logit = torch.randn(collector_sample_length, N)

    # Draw each environment's seed from ``train_seeds`` itself.
    # ``random.randint(0, num_seeds)`` is inclusive on both ends and could
    # return ``num_seeds``, a value that is not among the seeds the sampler
    # was constructed with (presumably rejected by the sampler's seed lookup
    # -- verify against LevelSampler), which made the test intermittently
    # use an invalid seed.
    seeds = [random.choice(train_seeds) for _ in range(collector_env_num)]

    # Repeat each environment's seed once per step it contributes, keeping
    # all steps of one environment contiguous.
    steps_per_env = collector_sample_length // collector_env_num
    all_seeds = torch.Tensor([seed for seed in seeds for _ in range(steps_per_env)])

    train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
    level_sampler.update_with_rollouts(train_data, collector_env_num)
    sample_seed = level_sampler.sample()
    assert isinstance(sample_seed, int)