|
from collections.abc import Sequence |
|
from typing import Union |
|
|
|
import numpy as np |
|
from numba import njit |
|
|
|
from tianshou.data import Batch, HERReplayBuffer, PrioritizedReplayBuffer, ReplayBuffer |
|
from tianshou.data.batch import alloc_by_keys_diff, create_value |
|
from tianshou.data.types import RolloutBatchProtocol |
|
|
|
|
|
class ReplayBufferManager(ReplayBuffer): |
|
"""ReplayBufferManager contains a list of ReplayBuffer with exactly the same configuration. |
|
|
|
These replay buffers have contiguous memory layout, and the storage space each |
|
buffer has is a shallow copy of the topmost memory. |
|
|
|
:param buffer_list: a list of ReplayBuffer needed to be handled. |
|
|
|
.. seealso:: |
|
|
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. |
|
""" |
|
|
|
def __init__(self, buffer_list: list[ReplayBuffer] | list[HERReplayBuffer]) -> None: |
|
self.buffer_num = len(buffer_list) |
|
self.buffers = np.array(buffer_list, dtype=object) |
|
offset, size = [], 0 |
|
buffer_type = type(self.buffers[0]) |
|
kwargs = self.buffers[0].options |
|
for buf in self.buffers: |
|
assert buf._meta.is_empty() |
|
assert isinstance(buf, buffer_type) |
|
assert buf.options == kwargs |
|
offset.append(size) |
|
size += buf.maxsize |
|
self._offset = np.array(offset) |
|
self._extend_offset = np.array([*offset, size]) |
|
self._lengths = np.zeros_like(offset) |
|
super().__init__(size=size, **kwargs) |
|
self._compile() |
|
self._meta: RolloutBatchProtocol |
|
|
|
def _compile(self) -> None: |
|
lens = last = index = np.array([0]) |
|
offset = np.array([0, 1]) |
|
done = np.array([False, False]) |
|
_prev_index(index, offset, done, last, lens) |
|
_next_index(index, offset, done, last, lens) |
|
|
|
def __len__(self) -> int: |
|
return int(self._lengths.sum()) |
|
|
|
def reset(self, keep_statistics: bool = False) -> None: |
|
self.last_index = self._offset.copy() |
|
self._lengths = np.zeros_like(self._offset) |
|
for buf in self.buffers: |
|
buf.reset(keep_statistics=keep_statistics) |
|
|
|
def _set_batch_for_children(self) -> None: |
|
for offset, buf in zip(self._offset, self.buffers, strict=True): |
|
buf.set_batch(self._meta[offset : offset + buf.maxsize]) |
|
|
|
def set_batch(self, batch: RolloutBatchProtocol) -> None: |
|
super().set_batch(batch) |
|
self._set_batch_for_children() |
|
|
|
def unfinished_index(self) -> np.ndarray: |
|
return np.concatenate( |
|
[ |
|
buf.unfinished_index() + offset |
|
for offset, buf in zip(self._offset, self.buffers, strict=True) |
|
], |
|
) |
|
|
|
def prev(self, index: int | np.ndarray) -> np.ndarray: |
|
if isinstance(index, list | np.ndarray): |
|
return _prev_index( |
|
np.asarray(index), |
|
self._extend_offset, |
|
self.done, |
|
self.last_index, |
|
self._lengths, |
|
) |
|
return _prev_index( |
|
np.array([index]), |
|
self._extend_offset, |
|
self.done, |
|
self.last_index, |
|
self._lengths, |
|
)[0] |
|
|
|
def next(self, index: int | np.ndarray) -> np.ndarray: |
|
if isinstance(index, list | np.ndarray): |
|
return _next_index( |
|
np.asarray(index), |
|
self._extend_offset, |
|
self.done, |
|
self.last_index, |
|
self._lengths, |
|
) |
|
return _next_index( |
|
np.array([index]), |
|
self._extend_offset, |
|
self.done, |
|
self.last_index, |
|
self._lengths, |
|
)[0] |
|
|
|
def update(self, buffer: ReplayBuffer) -> np.ndarray: |
|
"""The ReplayBufferManager cannot be updated by any buffer.""" |
|
raise NotImplementedError |
|
|
|
def add( |
|
self, |
|
batch: RolloutBatchProtocol, |
|
buffer_ids: np.ndarray | list[int] | None = None, |
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
|
"""Add a batch of data into ReplayBufferManager. |
|
|
|
Each of the data's length (first dimension) must equal to the length of |
|
buffer_ids. By default buffer_ids is [0, 1, ..., buffer_num - 1]. |
|
|
|
Return (current_index, episode_reward, episode_length, episode_start_index). If |
|
the episode is not finished, the return value of episode_length and |
|
episode_reward is 0. |
|
""" |
|
|
|
new_batch = Batch() |
|
for key in set(self._reserved_keys).intersection(batch.get_keys()): |
|
new_batch.__dict__[key] = batch[key] |
|
batch = new_batch |
|
batch.__dict__["done"] = np.logical_or(batch.terminated, batch.truncated) |
|
assert {"obs", "act", "rew", "terminated", "truncated", "done"}.issubset(batch.get_keys()) |
|
if self._save_only_last_obs: |
|
batch.obs = batch.obs[:, -1] |
|
if not self._save_obs_next: |
|
batch.pop("obs_next", None) |
|
elif self._save_only_last_obs: |
|
batch.obs_next = batch.obs_next[:, -1] |
|
|
|
if buffer_ids is None: |
|
buffer_ids = np.arange(self.buffer_num) |
|
ptrs, ep_lens, ep_rews, ep_idxs = [], [], [], [] |
|
for batch_idx, buffer_id in enumerate(buffer_ids): |
|
ptr, ep_rew, ep_len, ep_idx = self.buffers[buffer_id]._add_index( |
|
batch.rew[batch_idx], |
|
batch.done[batch_idx], |
|
) |
|
ptrs.append(ptr + self._offset[buffer_id]) |
|
ep_lens.append(ep_len) |
|
ep_rews.append(ep_rew) |
|
ep_idxs.append(ep_idx + self._offset[buffer_id]) |
|
self.last_index[buffer_id] = ptr + self._offset[buffer_id] |
|
self._lengths[buffer_id] = len(self.buffers[buffer_id]) |
|
ptrs = np.array(ptrs) |
|
try: |
|
self._meta[ptrs] = batch |
|
except ValueError: |
|
batch.rew = batch.rew.astype(float) |
|
batch.done = batch.done.astype(bool) |
|
batch.terminated = batch.terminated.astype(bool) |
|
batch.truncated = batch.truncated.astype(bool) |
|
if self._meta.is_empty(): |
|
self._meta = create_value(batch, self.maxsize, stack=False) |
|
else: |
|
alloc_by_keys_diff(self._meta, batch, self.maxsize, False) |
|
self._set_batch_for_children() |
|
self._meta[ptrs] = batch |
|
return ptrs, np.array(ep_rews), np.array(ep_lens), np.array(ep_idxs) |
|
|
|
def sample_indices(self, batch_size: int | None) -> np.ndarray: |
|
|
|
if batch_size is not None and batch_size < 0: |
|
|
|
return np.array([], int) |
|
if self._sample_avail and self.stack_num > 1: |
|
all_indices = np.concatenate( |
|
[ |
|
buf.sample_indices(0) + offset |
|
for offset, buf in zip(self._offset, self.buffers, strict=True) |
|
], |
|
) |
|
if batch_size == 0: |
|
return all_indices |
|
if batch_size is None: |
|
batch_size = len(all_indices) |
|
return np.random.choice(all_indices, batch_size) |
|
if batch_size == 0 or batch_size is None: |
|
sample_num = np.zeros(self.buffer_num, int) |
|
else: |
|
buffer_idx = np.random.choice( |
|
self.buffer_num, |
|
batch_size, |
|
p=self._lengths / self._lengths.sum(), |
|
) |
|
sample_num = np.bincount(buffer_idx, minlength=self.buffer_num) |
|
|
|
sample_num[sample_num == 0] = -1 |
|
|
|
return np.concatenate( |
|
[ |
|
buf.sample_indices(int(bsz)) + offset |
|
for offset, buf, bsz in zip(self._offset, self.buffers, sample_num, strict=True) |
|
], |
|
) |
|
|
|
|
|
class PrioritizedReplayBufferManager(PrioritizedReplayBuffer, ReplayBufferManager): |
|
"""PrioritizedReplayBufferManager contains a list of PrioritizedReplayBuffer with exactly the same configuration. |
|
|
|
These replay buffers have contiguous memory layout, and the storage space each |
|
buffer has is a shallow copy of the topmost memory. |
|
|
|
:param buffer_list: a list of PrioritizedReplayBuffer needed to be handled. |
|
|
|
.. seealso:: |
|
|
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. |
|
""" |
|
|
|
def __init__(self, buffer_list: Sequence[PrioritizedReplayBuffer]) -> None: |
|
ReplayBufferManager.__init__(self, buffer_list) |
|
kwargs = buffer_list[0].options |
|
for buf in buffer_list: |
|
del buf.weight |
|
PrioritizedReplayBuffer.__init__(self, self.maxsize, **kwargs) |
|
|
|
|
|
class HERReplayBufferManager(ReplayBufferManager): |
|
"""HERReplayBufferManager contains a list of HERReplayBuffer with exactly the same configuration. |
|
|
|
These replay buffers have contiguous memory layout, and the storage space each |
|
buffer has is a shallow copy of the topmost memory. |
|
|
|
:param buffer_list: a list of HERReplayBuffer needed to be handled. |
|
|
|
.. seealso:: |
|
|
|
Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage. |
|
""" |
|
|
|
def __init__(self, buffer_list: list[HERReplayBuffer]) -> None: |
|
super().__init__(buffer_list) |
|
|
|
def _restore_cache(self) -> None: |
|
for buf in self.buffers: |
|
buf._restore_cache() |
|
|
|
def save_hdf5(self, path: str, compression: str | None = None) -> None: |
|
self._restore_cache() |
|
return super().save_hdf5(path, compression) |
|
|
|
def set_batch(self, batch: RolloutBatchProtocol) -> None: |
|
self._restore_cache() |
|
return super().set_batch(batch) |
|
|
|
def update(self, buffer: Union["HERReplayBuffer", "ReplayBuffer"]) -> np.ndarray: |
|
self._restore_cache() |
|
return super().update(buffer) |
|
|
|
def add( |
|
self, |
|
batch: RolloutBatchProtocol, |
|
buffer_ids: np.ndarray | list[int] | None = None, |
|
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: |
|
self._restore_cache() |
|
return super().add(batch, buffer_ids) |
|
|
|
|
|
@njit |
|
def _prev_index( |
|
index: np.ndarray, |
|
offset: np.ndarray, |
|
done: np.ndarray, |
|
last_index: np.ndarray, |
|
lengths: np.ndarray, |
|
) -> np.ndarray: |
|
index = index % offset[-1] |
|
prev_index = np.zeros_like(index) |
|
|
|
|
|
for start, end, cur_len, last in zip( |
|
offset[:-1], |
|
offset[1:], |
|
lengths, |
|
last_index, |
|
): |
|
mask = (start <= index) & (index < end) |
|
correct_cur_len = max(1, cur_len) |
|
if np.sum(mask) > 0: |
|
subind = index[mask] |
|
subind = (subind - start - 1) % correct_cur_len |
|
end_flag = done[subind + start] | (subind + start == last) |
|
prev_index[mask] = (subind + end_flag) % correct_cur_len + start |
|
return prev_index |
|
|
|
|
|
@njit |
|
def _next_index( |
|
index: np.ndarray, |
|
offset: np.ndarray, |
|
done: np.ndarray, |
|
last_index: np.ndarray, |
|
lengths: np.ndarray, |
|
) -> np.ndarray: |
|
index = index % offset[-1] |
|
next_index = np.zeros_like(index) |
|
|
|
|
|
for start, end, cur_len, last in zip( |
|
offset[:-1], |
|
offset[1:], |
|
lengths, |
|
last_index, |
|
): |
|
mask = (start <= index) & (index < end) |
|
correct_cur_len = max(1, cur_len) |
|
if np.sum(mask) > 0: |
|
subind = index[mask] |
|
end_flag = done[subind] | (subind == last) |
|
next_index[mask] = (subind - start + 1 - end_flag) % correct_cur_len + start |
|
return next_index |
|
|