# Spaces: Sleeping / Sleeping  (page-banner residue captured along with this file; not part of the code)
import sys | |
import timeit | |
import torch | |
import random | |
import pytest | |
import numpy as np | |
from ding.data.buffer import DequeBuffer | |
from ding.data.buffer.middleware import clone_object, PriorityExperienceReplay | |
# Buffer capacities to benchmark, e.g. 1000, 10000, 100000.
size_list = [1000, 10000]
# Side lengths of the square observation tensor, e.g. 32 (32*32), 128 (128*128), 512 (512*512).
data_dim_list = [32, 128]
# Number of calls per timing run (passed to ``timeit.repeat`` as ``number``).
repeats = 100
class BufferBenchmark:
    """Wrap a ``DequeBuffer`` behind zero-argument operations suitable for ``timeit``.

    One fixed payload (``{"obs": <data_dim x data_dim random tensor>}``) is
    created at construction time and pushed repeatedly, so the timed
    operations do not include tensor creation.
    """

    # Number of items requested by every sampling operation.
    SAMPLE_SIZE = 128

    def __init__(self, buffer_size: int, data_dim: int, buffer_type: str = 'base') -> None:
        """Build the buffer, attach optional middleware and create the payload.

        Args:
            buffer_size: Passed to ``DequeBuffer`` as ``size``.
            data_dim: Side length of the square random tensor used as payload.
            buffer_type: ``"clone"`` attaches ``clone_object()``; ``"priority"``
                attaches ``PriorityExperienceReplay`` and adds a ``priority``
                entry to the pushed meta. Any other value (including the
                default ``'base'``) leaves the buffer without middleware.
        """
        self._buffer = DequeBuffer(size=buffer_size)
        self._meta = {}
        if buffer_type == "clone":
            self._buffer.use(clone_object())
        elif buffer_type == "priority":
            self._buffer.use(PriorityExperienceReplay(self._buffer))
            # Every pushed item carries the same fixed priority value.
            self._meta["priority"] = 2.0
        self._data = {"obs": torch.rand(data_dim, data_dim)}

    def data_storage(self) -> float:
        """Return ``sys.getsizeof`` of the payload tensor's storage, in KB.

        NOTE(review): ``sys.getsizeof`` is shallow in general; this relies on
        the storage object reporting its data size -- confirm for the torch
        version in use.
        """
        return sys.getsizeof(self._data["obs"].storage()) / 1024

    def count(self) -> int:
        """Return the number of items reported by the buffer's ``count()``."""
        return self._buffer.count()

    def push_op(self) -> None:
        """Push the fixed payload with the shared meta dict."""
        self._buffer.push(self._data, meta=self._meta)

    def push_with_group_info(self, num_keys: int = 256) -> None:
        """Push the fixed payload tagged with a random ``group`` id.

        Args:
            num_keys: Number of distinct group ids; the id is drawn as
                ``int(random.random() * num_keys)``.
        """
        # Work on a copy so the shared meta dict never gains a ``group`` key.
        meta = self._meta.copy()
        meta['group'] = int(random.random() * num_keys)
        self._buffer.push(self._data, meta=meta)

    def sample_op(self) -> None:
        """Sample ``SAMPLE_SIZE`` items without replacement; the result is discarded."""
        self._buffer.sample(self.SAMPLE_SIZE, replace=False)

    def replace_sample_op(self) -> None:
        """Sample ``SAMPLE_SIZE`` items with replacement; the result is discarded."""
        self._buffer.sample(self.SAMPLE_SIZE, replace=True)

    def groupby_sample_op(self) -> None:
        """Sample ``SAMPLE_SIZE`` with ``groupby="group"``; the result is discarded."""
        self._buffer.sample(self.SAMPLE_SIZE, groupby="group")
def get_mean_std(res, number=None):
    """Scale raw ``timeit.repeat`` totals to time per 1000 operations.

    Args:
        res: Sequence of timings, each the total time of one run of
            ``number`` calls.
        number: Calls per run. Defaults to the module-level ``repeats``,
            which is what the benchmark passes to ``timeit.repeat``.

    Returns:
        ``(mean, std)`` of ``res``, each multiplied by ``1000 / number``.
    """
    if number is None:
        number = repeats
    # total-for-`number`-ops -> total-for-1000-ops
    return np.mean(res) * 1000.0 / number, np.std(res) * 1000.0 / number
def _report_timing(title, op):
    """Time ``op`` with ``timeit.repeat`` and print its scaled mean/std under ``title``."""
    mean, std = get_mean_std(timeit.repeat(op, number=repeats))
    print("{}: mean {:.4f} s, std {:.4f} s".format(title, mean, std))


def test_benchmark(buffer_type):
    """Benchmark push and sample operations for every size/dim combination.

    For each entry of ``size_list`` x ``data_dim_list`` a fresh
    ``BufferBenchmark`` is built, pushing is timed, the buffer is filled to
    capacity, and the sampling variants are timed. Results are printed.

    Args:
        buffer_type: Forwarded to ``BufferBenchmark``; groupby sampling is
            skipped when it equals ``'priority'``.

    NOTE(review): no parametrization or fixture for ``buffer_type`` is visible
    here -- confirm how pytest is expected to supply it.
    """
    for size in size_list:
        # Every sampling op draws 128 items, so smaller buffers cannot be tested.
        assert size >= 128, "size is too small, please set an int no less than 128!"
        for dim in data_dim_list:
            buffer_test = BufferBenchmark(size, dim, buffer_type)
            print("exp-buffer_{}_{}-data_{:.2f}_KB".format(buffer_type, size, buffer_test.data_storage()))

            # test pushing
            _report_timing("Empty Push Test", buffer_test.push_op)

            # fill the buffer before sampling tests
            for _ in range(size):
                buffer_test.push_with_group_info()
            assert buffer_test.count() == size, "buffer is not full when testing sampling!"

            # test sampling without replace
            _report_timing("No-Replace Sample Test", buffer_test.sample_op)

            # test sampling with replace
            _report_timing("Replace Sample Test", buffer_test.replace_sample_op)

            # test groupby sampling
            # NOTE(review): presumably unsupported together with the priority
            # middleware -- confirm.
            if buffer_type != 'priority':
                _report_timing("Groupby Sample Test", buffer_test.groupby_sample_op)

            print("=" * 100)