Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from ding.data.buffer import DequeBuffer | |
from ding.data.buffer.middleware import clone_object, use_time_check, staleness_check, sample_range_view | |
from ding.data.buffer.middleware import PriorityExperienceReplay, group_sample | |
from ding.data.buffer.middleware.padding import padding | |
def test_clone_object():
    """Objects sampled through ``clone_object`` are copies: mutating them must not alter what the buffer stores."""
    buffer = DequeBuffer(size=10).use(clone_object())
    # One mutable object of each kind: a dict, a list and a tensor.
    originals = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
    for obj in originals:
        buffer.push(obj)

    def _mutate(payload):
        # Change the sampled object in place, in a way specific to its type.
        if isinstance(payload, dict):
            payload["key"] = "v2"
        elif isinstance(payload, list):
            payload.append("b")
        elif isinstance(payload, torch.Tensor):
            payload[0] = 3
        else:
            raise Exception("Unexpected type")

    def _assert_untouched(payload):
        # The stored object must still hold its original value.
        if isinstance(payload, dict):
            assert payload["key"] == "v1"
        elif isinstance(payload, list):
            assert len(payload) == 1
        elif isinstance(payload, torch.Tensor):
            assert payload[0] == 1
        else:
            raise Exception("Unexpected type")

    # Mutate every sampled copy...
    for record in buffer.sample(len(originals)):
        _mutate(record.data)
    # ...then resample and confirm the buffer contents are unchanged.
    for record in buffer.sample(len(originals)):
        _assert_untouched(record.data)
def get_data():
    """Build one synthetic sample: a random 4-element ``obs``, a random 1-element ``reward`` and a dummy ``info``."""
    # Draw order (obs first, then reward) is kept so seeded runs produce identical values.
    obs = torch.randn(4)
    reward = torch.randn(1)
    return dict(obs=obs, reward=reward, info='xxx')
def test_use_time_check():
    """Each record may be sampled at most ``max_use`` times; after that, sampling without replacement must fail."""
    num_records, max_use = 6, 2
    buffer = DequeBuffer(size=10)
    buffer.use(use_time_check(buffer, max_use=max_use))
    for _ in range(num_records):
        buffer.push(get_data())
    # Spend the whole use budget of every record.
    for _ in range(max_use):
        sampled = buffer.sample(size=num_records, replace=False)
        assert len(sampled) == num_records
    # Every record has reached its use limit, so even a single sample is expected to fail.
    with pytest.raises(ValueError):
        buffer.sample(size=1, replace=False)
def test_staleness_check():
    """Records collected more than ``max_staleness`` train iterations ago are rejected at sample time."""
    n_old = 6
    buffer = DequeBuffer(size=10)
    buffer.use(staleness_check(buffer, max_staleness=10))
    # A push without 'train_iter_data_collected' in its meta is refused.
    with pytest.raises(AssertionError):
        buffer.push(get_data())
    for _ in range(n_old):
        buffer.push(get_data(), meta={'train_iter_data_collected': 0})
    # Staleness 9, and exactly 10 (the boundary), are still acceptable.
    for train_iter in (9, 10):
        batch = buffer.sample(size=n_old, replace=False, train_iter_sample_data=train_iter)
        assert len(batch) == n_old
    for _ in range(2):
        buffer.push(get_data(), meta={'train_iter_data_collected': 5})
    assert buffer.count() == 8
    # At iteration 11 the six records from iteration 0 are too stale to sample...
    with pytest.raises(ValueError):
        buffer.sample(size=n_old, replace=False, train_iter_sample_data=11)
    # ...and they are gone afterwards, leaving only the two records from iteration 5.
    assert buffer.count() == 2
def test_priority():
    """PriorityExperienceReplay adds priority metadata on sample and applies priority updates."""
    n = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    # Two rounds of n pushes, each with an explicit priority in meta.
    for expected_count in (n, n + n):
        for _ in range(n):
            buffer.push(get_data(), meta={'priority': 2.0})
        assert buffer.count() == expected_count

    batch = buffer.sample(size=n + n, replace=False)
    assert len(batch) == n + n
    required_keys = {'priority', 'priority_idx', 'priority_IS'}
    for record in batch:
        assert required_keys <= set(record.meta.keys())
        record.meta['priority'] = 3.0
    # Write the raised priorities back into the buffer.
    for record in batch:
        buffer.update(record.index, record.data, record.meta)

    picked = buffer.sample(size=1)[0]
    assert picked.meta['priority'] == 3.0
    buffer.delete(picked.index)
    assert buffer.count() == n + n - 1
    buffer.clear()
    assert buffer.count() == 0
def test_priority_from_collector():
    """A priority placed inside the data dict, as a collector would, is used by PriorityExperienceReplay.

    Each pushed sample carries ``'priority': 2.0`` in its data and no explicit meta. The test checks
    that this priority appears in the sampled meta, and that later priority updates, deletion and
    clearing behave as usual.
    """
    N = 5
    buffer = DequeBuffer(size=10)
    buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
    for expected_count in (N, N + N):
        for _ in range(N):
            tmp_data = get_data()
            tmp_data['priority'] = 2.0
            # Push the annotated dict itself. Pushing a fresh get_data() here would never exercise
            # the collector-supplied priority this test is named after.
            buffer.push(tmp_data)
        assert buffer.count() == expected_count

    sampled = buffer.sample(size=N + N, replace=False)
    assert len(sampled) == N + N
    for item in sampled:
        meta = item.meta
        assert set(meta.keys()).issuperset({'priority', 'priority_idx', 'priority_IS'})
        # The priority supplied in the data dict must have been carried into the record's meta.
        assert meta['priority'] == 2.0
        meta['priority'] = 3.0
    # Use distinct names rather than rebinding the iterated list, so `sampled` stays intact.
    for item in sampled:
        buffer.update(item.index, item.data, item.meta)

    picked = buffer.sample(size=1)
    assert picked[0].meta['priority'] == 3.0
    buffer.delete(picked[0].index)
    assert buffer.count() == N + N - 1
    buffer.clear()
    assert buffer.count() == 0
def test_padding():
    """With ``padding`` installed, grouped sampling returns every group at the size of the largest one."""
    buffer = DequeBuffer(size=10)
    buffer.use(padding())
    # value & 5 gives group keys 0, 1, 4 and 5, with sizes [3, 3, 2, 2].
    for value in range(10):
        buffer.push(value, {"group": value & 5})
    groups = buffer.sample(4, groupby="group")
    assert len(groups) == 4
    # The two 2-element groups are expected to be padded up to 3.
    assert [len(group) for group in groups] == [3] * 4
def test_group_sample():
    """group_sample returns groups of ``size_in_group`` records; a short episode is padded with ``None`` data."""
    buffer = DequeBuffer(size=10)
    buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True))
    # Episode 0 contributes 4 records and episode 1 contributes 6, pushed in that order.
    for episode, length in ((0, 4), (1, 6)):
        for step in range(length):
            buffer.push(step, {"episode": episode})
    groups = buffer.sample(2, groupby="episode")
    assert len(groups) == 2
    for group in groups:
        assert len(group) == 5
        head_meta = group[0].meta
        if head_meta and "episode" in head_meta and head_meta["episode"] == 1:
            # Episode 1 has enough records to fill the group, so every entry carries data.
            assert all(record.data is not None for record in group)
        else:
            # Episode 0 has only 4 records, so exactly one padded entry with None data is expected.
            assert sum(record.data is None for record in group) == 1
def test_sample_range_view():
    """sample_range_view restricts a buffer view to sampling from a slice of the shared storage.

    The base buffer holds 5 'x', then 3 'y', then 2 'z' records. One view samples with
    ``start=-5, end=-2`` and must only ever see 'y'. A second view samples with ``start=-2``
    and must only ever see 'z'.
    """
    buffer_ = DequeBuffer(size=10)
    for tag, count in (('x', 5), ('y', 3), ('z', 2)):
        for _ in range(count):
            buffer_.push({'data': tag})

    buffer1 = buffer_.view()
    buffer1.use(sample_range_view(buffer1, start=-5, end=-2))
    for _ in range(10):
        sampled_data = buffer1.sample(1)
        assert sampled_data[0].data['data'] == 'y'

    buffer2 = buffer_.view()
    # Bind the middleware to the view it is installed on. The original passed buffer1 here, which
    # only worked because both views share the same underlying storage.
    buffer2.use(sample_range_view(buffer2, start=-2))
    for _ in range(10):
        sampled_data = buffer2.sample(1)
        assert sampled_data[0].data['data'] == 'z'