# Page-header residue from the hosting site ("Spaces: Sleeping / Sleeping"); not part of the module.
import os | |
import copy | |
import time | |
from typing import Union, Any, Optional, List, Dict, Tuple | |
import numpy as np | |
import hickle | |
from ding.worker.replay_buffer import IBuffer | |
from ding.utils import SumSegmentTree, MinSegmentTree, BUFFER_REGISTRY | |
from ding.utils import LockContext, LockContextType, build_logger, get_rank | |
from ding.utils.autolog import TickTime | |
from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController | |
def to_positive_index(idx: Union[int, None], size: int) -> int:
    """
    Overview:
        Convert a possibly negative (counted from the end) index into its non-negative equivalent.
    Arguments:
        - idx (:obj:`Union[int, None]`): Index to convert. ``None`` is passed through unchanged.
        - size (:obj:`int`): Length of the sequence the index refers to.
    Returns:
        - index (:obj:`int`): ``idx`` itself when it is ``None`` or non-negative, otherwise ``size + idx``.
    """
    if idx is not None and idx < 0:
        # Negative indices count backwards from ``size``.
        return size + idx
    return idx
class AdvancedReplayBuffer(IBuffer):
    r"""
    Overview:
        Prioritized replay buffer implementing the ``IBuffer`` interface.
        On top of plain circular-queue storage, this buffer provides:

            1) Prioritized experience replay implemented by segment trees.
            2) Data quality monitor. Monitors use count and staleness of each data.
            3) Throughput monitor and control.
            4) Logger. Logs 2) and 3) in tensorboard or text.
    Interface:
        start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
    Property:
        beta, replay_buffer_size, push_count
    """
    # NOTE(review): ``BUFFER_REGISTRY`` is imported by this file but not used in the visible source;
    # a registration decorator on this class may have been lost -- confirm against the upstream file.

    # Default configuration of the buffer.
    config = {
        'type': 'advanced',
        # Max length of the buffer.
        'replay_buffer_size': 4096,
        # Max use times of one data in the buffer. Data will be removed once it has been used too many times.
        'max_use': float("inf"),
        # Max staleness of one data in the buffer. Data will be removed if the duration from
        # collecting to training is too long, i.e. the data is too stale.
        'max_staleness': float("inf"),
        # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization.
        'alpha': 0.6,
        # (Float type) How much correction is used: 0 means no correction while 1 means full correction.
        'beta': 0.4,
        # Anneal step for beta: 0 means no annealing.
        'anneal_step': int(1e5),
        # Whether to track the used data. Used data are removed from the buffer and never used again.
        'enable_track_used_data': False,
        # Whether to deepcopy data when inserting and sampling. For safety purposes.
        'deepcopy': False,
        'thruput_controller': {
            # Rate limit. The ratio of "Sample Count" to "Push Count" should be in the [min, max] range.
            # If greater than max ratio, return ``None`` when calling ``sample``;
            # If smaller than min ratio, throw away the new data when calling ``push``.
            'push_sample_rate_limit': {
                'max': float("inf"),
                'min': 0,
            },
            # How many seconds the controller takes into account, i.e. for the past ``window_seconds``
            # seconds, the sample/push rate is calculated and compared with ``push_sample_rate_limit``.
            'window_seconds': 30,
            # The minimum ratio the buffer must satisfy before anything can be sampled.
            # The ratio is calculated as "Valid Count" divided by "Batch Size".
            # E.g. sample_min_limit_ratio = 2.0, valid_count = 50, batch_size = 32: sampling is forbidden.
            'sample_min_limit_ratio': 1,
        },
        # Configuration used by the monitor and logger. This part does not affect the buffer's function.
        'monitor': {
            'sampled_data_attr': {
                # How many past records are used for the moving average.
                'average_range': 5,
                # Print data attributes every ``print_freq`` samples.
                'print_freq': 200,  # times
            },
            'periodic_thruput': {
                # Every ``seconds`` seconds, thruput (push/sample/remove count) will be printed.
                'seconds': 60,
            },
        },
    }
def __init__(
        self,
        cfg: dict,
        tb_logger: Optional['SummaryWriter'] = None,  # noqa
        exp_name: Optional[str] = 'default_experiment',
        instance_name: Optional[str] = 'buffer',
) -> None:
    """
    Overview:
        Initialize the buffer: storage, priority segment trees, throughput controller, monitors and loggers.
    Arguments:
        - cfg (:obj:`dict`): Config dict. It is read with attribute access (e.g. ``cfg.replay_buffer_size``), \
            so it presumably is an attribute-style dict -- confirm against the caller.
        - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
        - exp_name (:obj:`Optional[str]`): Name of this experiment.
        - instance_name (:obj:`Optional[str]`): Name of this instance.
    """
    self._exp_name = exp_name
    self._instance_name = instance_name
    self._end_flag = False
    self._cfg = cfg
    self._rank = get_rank()
    self._replay_buffer_size = self._cfg.replay_buffer_size
    self._deepcopy = self._cfg.deepcopy
    # ``_data`` is a circular queue to store data (full data or meta data)
    self._data = [None for _ in range(self._replay_buffer_size)]
    # Current valid data count, indicating how many elements in ``self._data`` are valid.
    self._valid_count = 0
    # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
    self._push_count = 0
    # Points to the tail position where the next data can be inserted, i.e. latest inserted data's next position.
    self._tail = 0
    # Used to generate a unique id for each data: if a new data is inserted, its unique id will be this.
    self._next_unique_id = 0
    # Lock to guarantee thread safety
    self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
    # Points to the head of the circular queue, i.e. the stalest (oldest) data in this queue.
    # Because the buffer removes data due to staleness or use count, and at the beginning when the queue is
    # not filled with data head is always 0, ``head`` may be not equal to ``tail``;
    # otherwise, the two should be the same. Head is used to optimize the staleness check in ``_sample_check``.
    self._head = 0
    # use_count is {position_idx: use_count}
    self._use_count = {idx: 0 for idx in range(self._cfg.replay_buffer_size)}
    # Max priority till now. Used to initialize a data's priority if "priority" is not passed in with the data.
    self._max_priority = 1.0
    # A small positive number to avoid edge cases, e.g. "priority" == 0.
    self._eps = 1e-5
    # Data check function list, used in ``_append`` and ``_extend``. This buffer requires data to be dict.
    self.check_list = [lambda x: isinstance(x, dict)]

    self._max_use = self._cfg.max_use
    self._max_staleness = self._cfg.max_staleness
    self.alpha = self._cfg.alpha
    assert 0 <= self.alpha <= 1, self.alpha
    self._beta = self._cfg.beta
    assert 0 <= self._beta <= 1, self._beta
    self._anneal_step = self._cfg.anneal_step
    if self._anneal_step != 0:
        # Per-sample increment so that beta grows linearly towards 1 (see ``_sample_with_indices``).
        self._beta_anneal_step = (1 - self._beta) / self._anneal_step

    # Prioritized sample.
    # Capacity needs to be a power of 2. Read the private field directly so this does not depend on
    # ``replay_buffer_size`` being exposed as a property.
    capacity = int(np.power(2, np.ceil(np.log2(self._replay_buffer_size))))
    # Sum segtree and min segtree are used to sample data according to priority.
    self._sum_tree = SumSegmentTree(capacity)
    self._min_tree = MinSegmentTree(capacity)

    # Thruput controller
    push_sample_rate_limit = self._cfg.thruput_controller.push_sample_rate_limit
    self._always_can_push = push_sample_rate_limit['max'] == float('inf')
    self._always_can_sample = push_sample_rate_limit['min'] == 0
    self._use_thruput_controller = not self._always_can_push or not self._always_can_sample
    if self._use_thruput_controller:
        self._thruput_controller = ThruputController(self._cfg.thruput_controller)
    self._sample_min_limit_ratio = self._cfg.thruput_controller.sample_min_limit_ratio
    assert self._sample_min_limit_ratio >= 1

    # Monitor & Logger
    monitor_cfg = self._cfg.monitor
    log_path = './{}/log/{}'.format(self._exp_name, self._instance_name)
    if self._rank == 0:
        if tb_logger is not None:
            # Reuse the tensorboard logger handed in from outside; only build the text logger.
            self._logger, _ = build_logger(log_path, self._instance_name, need_tb=False)
            self._tb_logger = tb_logger
        else:
            self._logger, self._tb_logger = build_logger(log_path, self._instance_name)
    else:
        # Ranks other than 0 only get a text logger.
        self._logger, _ = build_logger(log_path, self._instance_name, need_tb=False)
        self._tb_logger = None
    self._start_time = time.time()
    # Sampled data attributes.
    self._cur_learner_iter = -1
    self._cur_collector_envstep = -1
    self._sampled_data_attr_print_count = 0
    self._sampled_data_attr_monitor = SampledDataAttrMonitor(
        TickTime(), expire=monitor_cfg.sampled_data_attr.average_range
    )
    self._sampled_data_attr_print_freq = monitor_cfg.sampled_data_attr.print_freq
    # Periodic thruput.
    if self._rank == 0:
        self._periodic_thruput_monitor = PeriodicThruputMonitor(
            self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger
        )

    # Used data remover
    self._enable_track_used_data = self._cfg.enable_track_used_data
    if self._enable_track_used_data:
        self._used_data_remover = UsedDataRemover()
def start(self) -> None: | |
""" | |
Overview: | |
Start the buffer's used_data_remover thread if enables track_used_data. | |
""" | |
if self._enable_track_used_data: | |
self._used_data_remover.start() | |
def close(self) -> None: | |
""" | |
Overview: | |
Clear the buffer; Join the buffer's used_data_remover thread if enables track_used_data. | |
Join periodic throughtput monitor, flush tensorboard logger. | |
""" | |
if self._end_flag: | |
return | |
self._end_flag = True | |
self.clear() | |
if self._rank == 0: | |
self._periodic_thruput_monitor.close() | |
self._tb_logger.flush() | |
self._tb_logger.close() | |
if self._enable_track_used_data: | |
self._used_data_remover.close() | |
def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]: | |
""" | |
Overview: | |
Sample data with length ``size``. | |
Arguments: | |
- size (:obj:`int`): The number of the data that will be sampled. | |
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. | |
- sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \ | |
means only sample among the last 10 data | |
Returns: | |
- sample_data (:obj:`list`): A list of data with length ``size`` | |
ReturnsKeys: | |
- necessary: original keys(e.g. `obs`, `action`, `next_obs`, `reward`, `info`), \ | |
`replay_unique_id`, `replay_buffer_idx` | |
- optional(if use priority): `IS`, `priority` | |
""" | |
if size == 0: | |
return [] | |
can_sample_stalenss, staleness_info = self._sample_check(size, cur_learner_iter) | |
if self._always_can_sample: | |
can_sample_thruput, thruput_info = True, "Always can sample because push_sample_rate_limit['min'] == 0" | |
else: | |
can_sample_thruput, thruput_info = self._thruput_controller.can_sample(size) | |
if not can_sample_stalenss or not can_sample_thruput: | |
self._logger.info( | |
'Refuse to sample due to -- \nstaleness: {}, {} \nthruput: {}, {}'.format( | |
not can_sample_stalenss, staleness_info, not can_sample_thruput, thruput_info | |
) | |
) | |
return None | |
with self._lock: | |
indices = self._get_indices(size, sample_range) | |
result = self._sample_with_indices(indices, cur_learner_iter) | |
# Deepcopy ``result``'s same indice datas in case ``self._get_indices`` may get datas with | |
# the same indices, i.e. the same datas would be sampled afterwards. | |
# if self._deepcopy==True -> all data is different | |
# if len(indices) == len(set(indices)) -> no duplicate data | |
if not self._deepcopy and len(indices) != len(set(indices)): | |
for i, index in enumerate(indices): | |
tmp = [] | |
for j in range(i + 1, size): | |
if index == indices[j]: | |
tmp.append(j) | |
for j in tmp: | |
result[j] = copy.deepcopy(result[j]) | |
self._monitor_update_of_sample(result, cur_learner_iter) | |
return result | |
def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None: | |
r""" | |
Overview: | |
Push a data into buffer. | |
Arguments: | |
- data (:obj:`Union[List[Any], Any]`): The data which will be pushed into buffer. Can be one \ | |
(in `Any` type), or many(int `List[Any]` type). | |
- cur_collector_envstep (:obj:`int`): Collector's current env step. | |
""" | |
push_size = len(data) if isinstance(data, list) else 1 | |
if self._always_can_push: | |
can_push, push_info = True, "Always can push because push_sample_rate_limit['max'] == float('inf')" | |
else: | |
can_push, push_info = self._thruput_controller.can_push(push_size) | |
if not can_push: | |
self._logger.info('Refuse to push because {}'.format(push_info)) | |
return | |
if isinstance(data, list): | |
self._extend(data, cur_collector_envstep) | |
else: | |
self._append(data, cur_collector_envstep) | |
def save_data(self, file_name: str):
    """
    Overview:
        Dump the raw storage list ``self._data`` to ``file_name`` with ``hickle``, creating the parent
        directory first when the path has one and it does not exist yet.
    Arguments:
        - file_name (:obj:`str`): Destination file path.
    """
    parent_dir = os.path.dirname(file_name)
    # A bare file name has an empty dirname; there is nothing to create in that case.
    if parent_dir != "" and not os.path.exists(parent_dir):
        os.makedirs(parent_dir)
    hickle.dump(py_obj=self._data, file_obj=file_name)
def load_data(self, file_name: str):
    """
    Overview:
        Load data from ``file_name`` with ``hickle`` and push it into the buffer with collector
        env step 0.
    Arguments:
        - file_name (:obj:`str`): Source file path.
    """
    loaded = hickle.load(file_name)
    self.push(loaded, 0)
def _sample_check(self, size: int, cur_learner_iter: int) -> Tuple[bool, str]: | |
r""" | |
Overview: | |
Do preparations for sampling and check whether data is enough for sampling | |
Preparation includes removing stale datas in ``self._data``. | |
Check includes judging whether this buffer has more than ``size`` datas to sample. | |
Arguments: | |
- size (:obj:`int`): The number of the data that will be sampled. | |
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. | |
Returns: | |
- can_sample (:obj:`bool`): Whether this buffer can sample enough data. | |
- str_info (:obj:`str`): Str type info, explaining why cannot sample. (If can sample, return "Can sample") | |
.. note:: | |
This function must be called before data sample. | |
""" | |
staleness_remove_count = 0 | |
with self._lock: | |
if self._max_staleness != float("inf"): | |
p = self._head | |
while True: | |
if self._data[p] is not None: | |
staleness = self._calculate_staleness(p, cur_learner_iter) | |
if staleness >= self._max_staleness: | |
self._remove(p) | |
staleness_remove_count += 1 | |
else: | |
# Since the circular queue ``self._data`` guarantees that data's staleness is decreasing | |
# from index self._head to index self._tail - 1, we can jump out of the loop as soon as | |
# meeting a fresh enough data | |
break | |
p = (p + 1) % self._replay_buffer_size | |
if p == self._tail: | |
# Traverse a circle and go back to the tail, which means can stop staleness checking now | |
break | |
str_info = "Remove {} elements due to staleness. ".format(staleness_remove_count) | |
if self._valid_count / size < self._sample_min_limit_ratio: | |
str_info += "Not enough for sampling. valid({}) / sample({}) < sample_min_limit_ratio({})".format( | |
self._valid_count, size, self._sample_min_limit_ratio | |
) | |
return False, str_info | |
else: | |
str_info += "Can sample." | |
return True, str_info | |
def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None: | |
r""" | |
Overview: | |
Append a data item into queue. | |
Add two keys in data: | |
- replay_unique_id: The data item's unique id, using ``generate_id`` to generate it. | |
- replay_buffer_idx: The data item's position index in the queue, this position may already have an \ | |
old element, then it would be replaced by this new input one. using ``self._tail`` to locate. | |
Arguments: | |
- ori_data (:obj:`Any`): The data which will be inserted. | |
- cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard. | |
""" | |
with self._lock: | |
if self._deepcopy: | |
data = copy.deepcopy(ori_data) | |
else: | |
data = ori_data | |
try: | |
assert self._data_check(data) | |
except AssertionError: | |
# If data check fails, log it and return without any operations. | |
self._logger.info('Illegal data type [{}], reject it...'.format(type(data))) | |
return | |
self._push_count += 1 | |
# remove->set weight->set data | |
if self._data[self._tail] is not None: | |
self._head = (self._tail + 1) % self._replay_buffer_size | |
self._remove(self._tail) | |
data['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id) | |
data['replay_buffer_idx'] = self._tail | |
self._set_weight(data) | |
self._data[self._tail] = data | |
self._valid_count += 1 | |
if self._rank == 0: | |
self._periodic_thruput_monitor.valid_count = self._valid_count | |
self._tail = (self._tail + 1) % self._replay_buffer_size | |
self._next_unique_id += 1 | |
self._monitor_update_of_push(1, cur_collector_envstep) | |
def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None: | |
r""" | |
Overview: | |
Extend a data list into queue. | |
Add two keys in each data item, you can refer to ``_append`` for more details. | |
Arguments: | |
- ori_data (:obj:`List[Any]`): The data list. | |
- cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard. | |
""" | |
with self._lock: | |
if self._deepcopy: | |
data = copy.deepcopy(ori_data) | |
else: | |
data = ori_data | |
check_result = [self._data_check(d) for d in data] | |
# Only keep data items that pass ``_data_check`. | |
valid_data = [d for d, flag in zip(data, check_result) if flag] | |
length = len(valid_data) | |
# When updating ``_data`` and ``_use_count``, should consider two cases regarding | |
# the relationship between "tail + data length" and "queue max length" to check whether | |
# data will exceed beyond queue's max length limitation. | |
if self._tail + length <= self._replay_buffer_size: | |
for j in range(self._tail, self._tail + length): | |
if self._data[j] is not None: | |
self._head = (j + 1) % self._replay_buffer_size | |
self._remove(j) | |
for i in range(length): | |
valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i) | |
valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size | |
self._set_weight(valid_data[i]) | |
self._push_count += 1 | |
self._data[self._tail:self._tail + length] = valid_data | |
else: | |
data_start = self._tail | |
valid_data_start = 0 | |
residual_num = len(valid_data) | |
while True: | |
space = self._replay_buffer_size - data_start | |
L = min(space, residual_num) | |
for j in range(data_start, data_start + L): | |
if self._data[j] is not None: | |
self._head = (j + 1) % self._replay_buffer_size | |
self._remove(j) | |
for i in range(valid_data_start, valid_data_start + L): | |
valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i) | |
valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size | |
self._set_weight(valid_data[i]) | |
self._push_count += 1 | |
self._data[data_start:data_start + L] = valid_data[valid_data_start:valid_data_start + L] | |
residual_num -= L | |
if residual_num <= 0: | |
break | |
else: | |
data_start = 0 | |
valid_data_start += L | |
self._valid_count += len(valid_data) | |
if self._rank == 0: | |
self._periodic_thruput_monitor.valid_count = self._valid_count | |
# Update ``tail`` and ``next_unique_id`` after the whole list is pushed into buffer. | |
self._tail = (self._tail + length) % self._replay_buffer_size | |
self._next_unique_id += length | |
self._monitor_update_of_push(length, cur_collector_envstep) | |
def update(self, info: dict) -> None: | |
r""" | |
Overview: | |
Update a data's priority. Use `repaly_buffer_idx` to locate, and use `replay_unique_id` to verify. | |
Arguments: | |
- info (:obj:`dict`): Info dict containing all necessary keys for priority update. | |
ArgumentsKeys: | |
- necessary: `replay_unique_id`, `replay_buffer_idx`, `priority`. All values are lists with the same length. | |
""" | |
with self._lock: | |
if 'priority' not in info: | |
return | |
data = [info['replay_unique_id'], info['replay_buffer_idx'], info['priority']] | |
for id_, idx, priority in zip(*data): | |
# Only if the data still exists in the queue, will the update operation be done. | |
if self._data[idx] is not None \ | |
and self._data[idx]['replay_unique_id'] == id_: # Verify the same transition(data) | |
assert priority >= 0, priority | |
assert self._data[idx]['replay_buffer_idx'] == idx | |
self._data[idx]['priority'] = priority + self._eps # Add epsilon to avoid priority == 0 | |
self._set_weight(self._data[idx]) | |
# Update max priority | |
self._max_priority = max(self._max_priority, priority) | |
else: | |
self._logger.debug( | |
'[Skip Update]: buffer_idx: {}; id_in_buffer: {}; id_in_update_info: {}'.format( | |
idx, id_, priority | |
) | |
) | |
def clear(self) -> None: | |
""" | |
Overview: | |
Clear all the data and reset the related variables. | |
""" | |
with self._lock: | |
for i in range(len(self._data)): | |
self._remove(i) | |
assert self._valid_count == 0, self._valid_count | |
self._head = 0 | |
self._tail = 0 | |
self._max_priority = 1.0 | |
def __del__(self) -> None: | |
""" | |
Overview: | |
Call ``close`` to delete the object. | |
""" | |
if not self._end_flag: | |
self.close() | |
def _set_weight(self, data: Dict) -> None: | |
r""" | |
Overview: | |
Set sumtree and mintree's weight of the input data according to its priority. | |
If input data does not have key "priority", it would set to ``self._max_priority`` instead. | |
Arguments: | |
- data (:obj:`Dict`): The data whose priority(weight) in segement tree should be set/updated. | |
""" | |
if 'priority' not in data.keys() or data['priority'] is None: | |
data['priority'] = self._max_priority | |
weight = data['priority'] ** self.alpha | |
idx = data['replay_buffer_idx'] | |
self._sum_tree[idx] = weight | |
self._min_tree[idx] = weight | |
def _data_check(self, d: Any) -> bool: | |
r""" | |
Overview: | |
Data legality check, using rules(functions) in ``self.check_list``. | |
Arguments: | |
- d (:obj:`Any`): The data which needs to be checked. | |
Returns: | |
- result (:obj:`bool`): Whether the data passes the check. | |
""" | |
# only the data passes all the check functions, would the check return True | |
return all([fn(d) for fn in self.check_list]) | |
def _get_indices(self, size: int, sample_range: slice = None) -> list: | |
r""" | |
Overview: | |
Get the sample index list according to the priority probability. | |
Arguments: | |
- size (:obj:`int`): The number of the data that will be sampled | |
Returns: | |
- index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``. | |
""" | |
# Divide [0, 1) into size intervals on average | |
intervals = np.array([i * 1.0 / size for i in range(size)]) | |
# Uniformly sample within each interval | |
mass = intervals + np.random.uniform(size=(size, )) * 1. / size | |
if sample_range is None: | |
# Rescale to [0, S), where S is the sum of all datas' priority (root value of sum tree) | |
mass *= self._sum_tree.reduce() | |
else: | |
# Rescale to [a, b) | |
start = to_positive_index(sample_range.start, self._replay_buffer_size) | |
end = to_positive_index(sample_range.stop, self._replay_buffer_size) | |
a = self._sum_tree.reduce(0, start) | |
b = self._sum_tree.reduce(0, end) | |
mass = mass * (b - a) + a | |
# Find prefix sum index to sample with probability | |
return [self._sum_tree.find_prefixsum_idx(m) for m in mass] | |
def _remove(self, idx: int, use_too_many_times: bool = False) -> None: | |
r""" | |
Overview: | |
Remove a data(set the element in the list to ``None``) and update corresponding variables, | |
e.g. sum_tree, min_tree, valid_count. | |
Arguments: | |
- idx (:obj:`int`): Data at this position will be removed. | |
""" | |
if use_too_many_times: | |
if self._enable_track_used_data: | |
# Must track this data, but in parallel mode. | |
# Do not remove it, but make sure it will not be sampled. | |
self._data[idx]['priority'] = 0 | |
self._sum_tree[idx] = self._sum_tree.neutral_element | |
self._min_tree[idx] = self._min_tree.neutral_element | |
return | |
elif idx == self._head: | |
# Correct `self._head` when the queue head is removed due to use_count | |
self._head = (self._head + 1) % self._replay_buffer_size | |
if self._data[idx] is not None: | |
if self._enable_track_used_data: | |
self._used_data_remover.add_used_data(self._data[idx]) | |
self._valid_count -= 1 | |
if self._rank == 0: | |
self._periodic_thruput_monitor.valid_count = self._valid_count | |
self._periodic_thruput_monitor.remove_data_count += 1 | |
self._data[idx] = None | |
self._sum_tree[idx] = self._sum_tree.neutral_element | |
self._min_tree[idx] = self._min_tree.neutral_element | |
self._use_count[idx] = 0 | |
def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list: | |
r""" | |
Overview: | |
Sample data with ``indices``; Remove a data item if it is used for too many times. | |
Arguments: | |
- indices (:obj:`List[int]`): A list including all the sample indices. | |
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. | |
Returns: | |
- data (:obj:`list`) Sampled data. | |
""" | |
# Calculate max weight for normalizing IS | |
sum_tree_root = self._sum_tree.reduce() | |
p_min = self._min_tree.reduce() / sum_tree_root | |
max_weight = (self._valid_count * p_min) ** (-self._beta) | |
data = [] | |
for idx in indices: | |
assert self._data[idx] is not None | |
assert self._data[idx]['replay_buffer_idx'] == idx, (self._data[idx]['replay_buffer_idx'], idx) | |
if self._deepcopy: | |
copy_data = copy.deepcopy(self._data[idx]) | |
else: | |
copy_data = self._data[idx] | |
# Store staleness, use and IS(importance sampling weight for gradient step) for monitor and outer use | |
self._use_count[idx] += 1 | |
copy_data['staleness'] = self._calculate_staleness(idx, cur_learner_iter) | |
copy_data['use'] = self._use_count[idx] | |
p_sample = self._sum_tree[idx] / sum_tree_root | |
weight = (self._valid_count * p_sample) ** (-self._beta) | |
copy_data['IS'] = weight / max_weight | |
data.append(copy_data) | |
if self._max_use != float("inf"): | |
# Remove datas whose "use count" is greater than ``max_use`` | |
for idx in indices: | |
if self._use_count[idx] >= self._max_use: | |
self._remove(idx, use_too_many_times=True) | |
# Beta annealing | |
if self._anneal_step != 0: | |
self._beta = min(1.0, self._beta + self._beta_anneal_step) | |
return data | |
def _monitor_update_of_push(self, add_count: int, cur_collector_envstep: int = -1) -> None: | |
r""" | |
Overview: | |
Update values in monitor, then update text logger and tensorboard logger. | |
Called in ``_append`` and ``_extend``. | |
Arguments: | |
- add_count (:obj:`int`): How many datas are added into buffer. | |
- cur_collector_envstep (:obj:`int`): Collector envstep, passed in by collector. | |
""" | |
if self._rank == 0: | |
self._periodic_thruput_monitor.push_data_count += add_count | |
if self._use_thruput_controller: | |
self._thruput_controller.history_push_count += add_count | |
self._cur_collector_envstep = cur_collector_envstep | |
def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) -> None: | |
r""" | |
Overview: | |
Update values in monitor, then update text logger and tensorboard logger. | |
Called in ``sample``. | |
Arguments: | |
- sample_data (:obj:`list`): Sampled data. Used to get sample length and data's attributes, \ | |
e.g. use, priority, staleness, etc. | |
- cur_learner_iter (:obj:`int`): Learner iteration, passed in by learner. | |
""" | |
if self._rank == 0: | |
self._periodic_thruput_monitor.sample_data_count += len(sample_data) | |
if self._use_thruput_controller: | |
self._thruput_controller.history_sample_count += len(sample_data) | |
self._cur_learner_iter = cur_learner_iter | |
use_avg = sum([d['use'] for d in sample_data]) / len(sample_data) | |
use_max = max([d['use'] for d in sample_data]) | |
priority_avg = sum([d['priority'] for d in sample_data]) / len(sample_data) | |
priority_max = max([d['priority'] for d in sample_data]) | |
priority_min = min([d['priority'] for d in sample_data]) | |
staleness_avg = sum([d['staleness'] for d in sample_data]) / len(sample_data) | |
staleness_max = max([d['staleness'] for d in sample_data]) | |
self._sampled_data_attr_monitor.use_avg = use_avg | |
self._sampled_data_attr_monitor.use_max = use_max | |
self._sampled_data_attr_monitor.priority_avg = priority_avg | |
self._sampled_data_attr_monitor.priority_max = priority_max | |
self._sampled_data_attr_monitor.priority_min = priority_min | |
self._sampled_data_attr_monitor.staleness_avg = staleness_avg | |
self._sampled_data_attr_monitor.staleness_max = staleness_max | |
self._sampled_data_attr_monitor.time.step() | |
out_dict = { | |
'use_avg': self._sampled_data_attr_monitor.avg['use'](), | |
'use_max': self._sampled_data_attr_monitor.max['use'](), | |
'priority_avg': self._sampled_data_attr_monitor.avg['priority'](), | |
'priority_max': self._sampled_data_attr_monitor.max['priority'](), | |
'priority_min': self._sampled_data_attr_monitor.min['priority'](), | |
'staleness_avg': self._sampled_data_attr_monitor.avg['staleness'](), | |
'staleness_max': self._sampled_data_attr_monitor.max['staleness'](), | |
'beta': self._beta, | |
} | |
if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0 and self._rank == 0: | |
self._logger.info("=== Sample data {} Times ===".format(self._sampled_data_attr_print_count)) | |
self._logger.info(self._logger.get_tabulate_vars_hor(out_dict)) | |
for k, v in out_dict.items(): | |
iter_metric = self._cur_learner_iter if self._cur_learner_iter != -1 else None | |
step_metric = self._cur_collector_envstep if self._cur_collector_envstep != -1 else None | |
if iter_metric is not None: | |
self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, iter_metric) | |
if step_metric is not None: | |
self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, step_metric) | |
self._sampled_data_attr_print_count += 1 | |
def _calculate_staleness(self, pos_index: int, cur_learner_iter: int) -> Optional[int]: | |
r""" | |
Overview: | |
Calculate a data's staleness according to its own attribute ``collect_iter`` | |
and input parameter ``cur_learner_iter``. | |
Arguments: | |
- pos_index (:obj:`int`): The position index. Staleness of the data at this index will be calculated. | |
- cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness. | |
Returns: | |
- staleness (:obj:`int`): Staleness of data at position ``pos_index``. | |
.. note:: | |
Caller should guarantee that data at ``pos_index`` is not None; Otherwise this function may raise an error. | |
""" | |
if self._data[pos_index] is None: | |
raise ValueError("Prioritized's data at index {} is None".format(pos_index)) | |
else: | |
# Calculate staleness, remove it if too stale | |
collect_iter = self._data[pos_index].get('collect_iter', cur_learner_iter + 1) | |
if isinstance(collect_iter, list): | |
# Timestep transition's collect_iter is a list | |
collect_iter = min(collect_iter) | |
# ``staleness`` might be -1, means invalid, e.g. collector does not report collecting model iter, | |
# or it is a demonstration buffer(which means data is not generated by collector) etc. | |
staleness = cur_learner_iter - collect_iter | |
return staleness | |
def count(self) -> int: | |
""" | |
Overview: | |
Count how many valid datas there are in the buffer. | |
Returns: | |
- count (:obj:`int`): Number of valid data. | |
""" | |
return self._valid_count | |
def beta(self) -> float: | |
return self._beta | |
def beta(self, beta: float) -> None: | |
self._beta = beta | |
def state_dict(self) -> dict: | |
""" | |
Overview: | |
Provide a state dict to keep a record of current buffer. | |
Returns: | |
- state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \ | |
With the dict, one can easily reproduce the buffer. | |
""" | |
return { | |
'data': self._data, | |
'use_count': self._use_count, | |
'tail': self._tail, | |
'max_priority': self._max_priority, | |
'anneal_step': self._anneal_step, | |
'beta': self._beta, | |
'head': self._head, | |
'next_unique_id': self._next_unique_id, | |
'valid_count': self._valid_count, | |
'push_count': self._push_count, | |
'sum_tree': self._sum_tree, | |
'min_tree': self._min_tree, | |
} | |
def load_state_dict(self, _state_dict: dict, deepcopy: bool = False) -> None: | |
""" | |
Overview: | |
Load state dict to reproduce the buffer. | |
Returns: | |
- state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. | |
""" | |
assert 'data' in _state_dict | |
if set(_state_dict.keys()) == set(['data']): | |
self._extend(_state_dict['data']) | |
else: | |
for k, v in _state_dict.items(): | |
if deepcopy: | |
setattr(self, '_{}'.format(k), copy.deepcopy(v)) | |
else: | |
setattr(self, '_{}'.format(k), v) | |
def replay_buffer_size(self) -> int: | |
return self._replay_buffer_size | |
def push_count(self) -> int: | |
return self._push_count | |