Spaces:
Sleeping
Sleeping
from typing import Callable, Any, List, Dict, Optional, Union, TYPE_CHECKING | |
import copy | |
import numpy as np | |
import torch | |
from ding.utils import SumSegmentTree, MinSegmentTree | |
from ding.data.buffer.buffer import BufferedData | |
if TYPE_CHECKING: | |
from ding.data.buffer.buffer import Buffer | |
class PriorityExperienceReplay:
    """
    Overview:
        The middleware that implements priority experience replay (PER).
        Each pushed item gets a slot (``priority_idx``) in a sum segment tree that stores
        ``priority ** priority_power_factor``; sampling is proportional to these values.
        When ``IS_weight`` is enabled, a min segment tree with the same layout is kept as well,
        and sampled items receive a normalized importance sampling weight (``priority_IS``).
        ``buffer_idx`` maps every tree slot to the index of the buffered item stored in it.
    """

    # Buffer actions that this middleware intercepts; all other actions are passed through.
    _HOOKED_ACTIONS = ("push", "sample", "update", "delete", "clear")
    # Old key spellings accepted by ``load_state_dict`` and the attributes they map to.
    _LEGACY_STATE_KEYS = {'sumtree': 'sum_tree', 'mintree': 'min_tree'}

    def __init__(
            self,
            buffer: 'Buffer',
            IS_weight: bool = True,
            priority_power_factor: float = 0.6,
            IS_weight_power_factor: float = 0.4,
            IS_weight_anneal_train_iter: int = int(1e5),
    ) -> None:
        """
        Arguments:
            - buffer (:obj:`Buffer`): The buffer to use PER.
            - IS_weight (:obj:`bool`): Whether use importance sampling or not.
            - priority_power_factor (:obj:`float`): The factor that adjust the sensitivity between\
                the sampling probability and the priority level.
            - IS_weight_power_factor (:obj:`float`): The factor that adjust the sensitivity between\
                the sample rarity and sampling probability in importance sampling.
            - IS_weight_anneal_train_iter (:obj:`float`): The factor that controls the increasing of\
                ``IS_weight_power_factor`` during training.
        """
        self.buffer = buffer
        self.buffer_idx = {}
        self.buffer_size = buffer.size
        self.IS_weight = IS_weight
        self.priority_power_factor = priority_power_factor
        self.IS_weight_power_factor = IS_weight_power_factor
        self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
        # Max priority till now, it's used to initialize data's priority if "priority" is not passed in with the data.
        self.max_priority = 1.0
        self._reset_trees()
        if self.IS_weight:
            # Increment applied to ``IS_weight_power_factor`` after every ``sample`` call (capped at 1.0).
            self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
        # Next tree slot to be written by ``push``; wraps around at ``buffer_size``.
        self.pivot = 0

    def _reset_trees(self) -> None:
        """
        Overview:
            Create empty segment trees sized for ``self.buffer_size``. The min tree is only
            created when importance sampling is enabled.
        """
        # Capacity needs to be the power of 2.
        capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
        self.sum_tree = SumSegmentTree(capacity)
        if self.IS_weight:
            self.min_tree = MinSegmentTree(capacity)

    def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> 'BufferedData':
        """
        Overview:
            Register the priority of ``data`` in the tree slot ``self.pivot``, then push it to the buffer.
            The priority is taken from ``meta['priority']``, or popped from ``data['priority']`` when no
            ``meta`` is given, and falls back to ``self.max_priority``.
        Arguments:
            - chain (:obj:`Callable`): The next push callable in the middleware chain.
            - data (:obj:`Any`): The data to push; must support ``in`` / ``pop`` when ``meta`` is None.
            - meta (:obj:`Optional[dict]`): Meta information; ``priority`` and ``priority_idx`` are written into it.
        Returns:
            - buffered (:obj:`BufferedData`): The object returned by ``chain``.
        """
        if meta is None:
            if 'priority' in data:
                meta = {'priority': data.pop('priority')}
            else:
                meta = {'priority': self.max_priority}
        elif 'priority' not in meta:
            meta['priority'] = self.max_priority
        meta['priority_idx'] = self.pivot
        self._update_tree(meta['priority'], self.pivot)
        buffered = chain(data, *args, meta=meta, **kwargs)
        # Remember which buffered item currently owns this tree slot.
        self.buffer_idx[self.pivot] = buffered.index
        self.pivot = (self.pivot + 1) % self.buffer_size
        return buffered

    def sample(self, chain: Callable, size: int, *args,
               **kwargs) -> Union[List['BufferedData'], List[List['BufferedData']]]:
        """
        Overview:
            Draw ``size`` items with probability proportional to their stored priority weight, using
            stratified sampling over the total priority mass. When ``IS_weight`` is enabled, write the
            normalized importance sampling weight into ``meta['priority_IS']`` and ``data['priority_IS']``
            of every sampled item and anneal ``IS_weight_power_factor`` by ``delta_anneal``.
        Arguments:
            - chain (:obj:`Callable`): The next sample callable in the middleware chain; it is called\
                with ``indices`` instead of ``size``.
            - size (:obj:`int`): The number of items to sample.
        Returns:
            - data: The sampled items returned by ``chain``.
        """
        # Root value of the sum tree, i.e. the sum of all stored priority weights.
        total = self.sum_tree.reduce()
        # Divide [0, 1) into size intervals on average and sample uniformly within each interval.
        mass = np.arange(size) * 1.0 / size + np.random.uniform(size=(size, )) * 1. / size
        # Rescale to [0, S), where S is the sum of all datas' priority.
        mass *= total
        slots = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
        indices = [self.buffer_idx[slot] for slot in slots]
        # Sample with indices
        data = chain(*args, indices=indices, **kwargs)
        if not self.IS_weight:
            return data

        beta = self.IS_weight_power_factor
        buffer_count = self.buffer.count()
        # The rarest item has the largest weight; it is used to normalize all weights.
        p_min = self.min_tree.reduce() / total
        max_weight = (buffer_count * p_min) ** (-beta)
        for item in data:
            meta = item.meta
            p_sample = self.sum_tree[meta['priority_idx']] / total
            weight = (buffer_count * p_sample) ** (-beta)
            meta['priority_IS'] = weight / max_weight
            item.data['priority_IS'] = torch.as_tensor([meta['priority_IS']]).float()  # for compatibility
        self.IS_weight_power_factor = min(1.0, beta + self.delta_anneal)
        return data

    def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
        """
        Overview:
            Forward the update to the buffer and, if it reports success, store the new priority
            (``meta['priority']`` plus a small epsilon) in the tree slot ``meta['priority_idx']``.
        Arguments:
            - chain (:obj:`Callable`): The next update callable in the middleware chain.
            - index (:obj:`str`): Index of the buffered item to update.
            - data (:obj:`Any`): New data, forwarded to ``chain``.
            - meta (:obj:`Any`): Dict-type meta holding ``priority`` and ``priority_idx``.
        """
        update_flag = chain(index, data, meta, *args, **kwargs)
        if not update_flag:
            return
        # when update succeed
        assert meta is not None, "Please indicate dict-type meta in priority update"
        new_priority, idx = meta['priority'], meta['priority_idx']
        assert new_priority >= 0, "new_priority should greater than 0, but found {}".format(new_priority)
        new_priority += 1e-5  # Add epsilon to avoid priority == 0
        self._update_tree(new_priority, idx)
        self.max_priority = max(self.max_priority, new_priority)

    def delete(self, chain: Callable, index: Union[str, List[str]], *args, **kwargs) -> None:
        """
        Overview:
            Clear the priority of the item(s) being deleted, then forward the deletion to the buffer.
            Only the tree slots that ``buffer_idx`` maps to the requested index/indices are reset;
            priorities of all other items are left untouched.
        Arguments:
            - chain (:obj:`Callable`): The next delete callable in the middleware chain.
            - index (:obj:`Union[str, List[str]]`): Index (or indices) of the buffered item(s) to delete.
        Returns:
            - The value returned by ``chain``.
        """
        targets = set(index) if isinstance(index, (list, tuple, set, frozenset)) else {index}
        # Collect first: ``buffer_idx`` must not change size while it is being iterated.
        stale_slots = [slot for slot, buffered_index in self.buffer_idx.items() if buffered_index in targets]
        for slot in stale_slots:
            self.sum_tree[slot] = self.sum_tree.neutral_element
            if self.IS_weight:
                # The min tree only exists when importance sampling is enabled.
                self.min_tree[slot] = self.min_tree.neutral_element
            del self.buffer_idx[slot]
        return chain(index, *args, **kwargs)

    def clear(self, chain: Callable) -> None:
        """
        Overview:
            Reset all PER state (max priority, trees, slot mapping and pivot), then clear the buffer.
        Arguments:
            - chain (:obj:`Callable`): The next clear callable in the middleware chain.
        """
        self.max_priority = 1.0
        self._reset_trees()
        self.buffer_idx = {}
        self.pivot = 0
        chain()

    def _update_tree(self, priority: float, idx: int) -> None:
        """
        Overview:
            Write ``priority ** priority_power_factor`` into slot ``idx`` of the sum tree and,
            when importance sampling is enabled, of the min tree.
        """
        weight = priority ** self.priority_power_factor
        self.sum_tree[idx] = weight
        if self.IS_weight:
            self.min_tree[idx] = weight

    def state_dict(self) -> Dict:
        """
        Overview:
            Return the PER state. Keys equal the attribute names so that ``load_state_dict`` can
            restore them directly; ``min_tree`` is only present when importance sampling is enabled.
        Returns:
            - state (:obj:`Dict`): References (not copies) to the current state objects.
        """
        state = {
            'max_priority': self.max_priority,
            'IS_weight_power_factor': self.IS_weight_power_factor,
            'sum_tree': self.sum_tree,
            'buffer_idx': self.buffer_idx,
        }
        if self.IS_weight:
            state['min_tree'] = self.min_tree
        return state

    def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
        """
        Overview:
            Set every entry of ``_state_dict`` as an attribute of this object. The old key spellings
            ``sumtree`` / ``mintree`` are mapped to ``sum_tree`` / ``min_tree``.
        Arguments:
            - _state_dict (:obj:`Dict`): State as produced by ``state_dict``.
            - deepcopy (:obj:`bool`): Whether to deep-copy the values before assigning them.
        """
        for key, value in _state_dict.items():
            attr = self._LEGACY_STATE_KEYS.get(key, str(key))
            setattr(self, attr, copy.deepcopy(value) if deepcopy else value)

    def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
        """
        Overview:
            Middleware entry point: dispatch ``action`` to the method of the same name if this
            middleware hooks it, otherwise call ``chain`` directly.
        """
        if action in self._HOOKED_ACTIONS:
            return getattr(self, action)(chain, *args, **kwargs)
        return chain(*args, **kwargs)