# Extraction residue (not part of the module), preserved as comments:
# Spaces:
# Sleeping
# Sleeping
import copy
import os
from typing import Optional

import hickle
import numpy as np
from easydict import EasyDict

from ding.data.buffer import DequeBuffer
from ding.data.buffer.middleware import use_time_check, PriorityExperienceReplay
from ding.utils import BUFFER_REGISTRY
# NOTE(review): ``BUFFER_REGISTRY`` is imported at module level but never used
# here; DI-engine convention registers this wrapper via
# ``@BUFFER_REGISTRY.register('deque')`` — confirm whether that decorator was
# lost from this file and restore it if so.
class DequeBufferWrapper(object):
    """
    Overview:
        Adapter that wraps a ``DequeBuffer`` behind the classic replay-buffer
        interface (``push`` / ``sample`` / ``update`` / ``count`` /
        ``save_data`` / ``load_data``), with optional use-count eviction,
        prioritized experience replay, and rate-limited TensorBoard statistics.
    """

    # Default configuration; obtain a mutable deep copy via ``default_config``.
    config = dict(
        replay_buffer_size=10000,
        # Max times one item may be sampled before eviction; ``inf`` disables
        # the use-count middleware.
        max_use=float("inf"),
        # Minimum number of training iterations between two TB log points.
        train_iter_per_log=100,
        # Enable the prioritized-experience-replay middleware.
        priority=False,
        # Use importance-sampling weights (only meaningful with ``priority``).
        priority_IS_weight=False,
        priority_power_factor=0.6,
        IS_weight_power_factor=0.4,
        IS_weight_anneal_train_iter=int(1e5),
        # Hard cap applied to learner-provided priorities in ``update``.
        priority_max_limit=1000,
    )

    @classmethod  # bug fix: the ``cls: type`` signature shows this was meant to be a classmethod
    def default_config(cls: type) -> "EasyDict":
        """Return a deep copy of ``cls.config`` as an EasyDict, tagged with ``cfg_type``."""
        cfg = EasyDict(copy.deepcopy(cls.config))
        cfg.cfg_type = cls.__name__ + 'Dict'
        return cfg

    def __init__(
            self,
            cfg: "EasyDict",
            tb_logger: Optional[object] = None,
            exp_name: str = 'default_experiement',
            instance_name: str = 'buffer'
    ) -> None:
        """
        Arguments:
            - cfg: merged configuration (see class-level ``config`` for defaults).
            - tb_logger: optional TensorBoard ``SummaryWriter``-like object;
              when ``None``, statistics logging is skipped.
            - exp_name: experiment name (currently unused in this wrapper).
              NOTE(review): the default 'default_experiement' looks like a typo
              of 'default_experiment' but is kept for backward compatibility.
            - instance_name: prefix used for TB scalar tags.
        """
        self.cfg = cfg
        self.priority_max_limit = cfg.priority_max_limit
        self.name = '{}_iter'.format(instance_name)
        self.tb_logger = tb_logger
        self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
        self.last_log_train_iter = -1
        # Priority write-back bookkeeping. Robustness fix: always initialized
        # (not only when ``priority`` is on), so ``update`` called before any
        # ``sample`` is a safe no-op rather than an AttributeError.
        self.last_sample_index = None
        self.last_sample_meta = None

        # use_count middleware: evict items sampled more than ``max_use`` times.
        if self.cfg.max_use != float("inf"):
            self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
        # priority middleware
        if self.cfg.priority:
            self.buffer.use(
                PriorityExperienceReplay(
                    self.buffer,
                    IS_weight=self.cfg.priority_IS_weight,
                    priority_power_factor=self.cfg.priority_power_factor,
                    IS_weight_power_factor=self.cfg.IS_weight_power_factor,
                    IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
                )
            )

    def sample(self, size: int, train_iter: int = 0):
        """
        Sample ``size`` items and return their data payloads as a list, or
        ``None`` when the buffer cannot satisfy the request. TB statistics are
        emitted at most once every ``train_iter_per_log`` training iterations.
        """
        output = self.buffer.sample(size=size, ignore_insufficient=True)
        if len(output) == 0:
            return None
        meta = [o.meta for o in output]
        if self.cfg.priority:
            # Bug fix: record the sampled indices/meta on EVERY priority
            # sample, not only on logging iterations; previously the learner's
            # priority feedback passed to ``update`` was silently discarded
            # between log points.
            self.last_sample_index = [o.index for o in output]
            self.last_sample_meta = meta
        if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log:
            if self.tb_logger is not None:  # robustness: the logger is optional
                if self.cfg.max_use != float("inf"):
                    use_count_avg = np.mean([m['use_count'] for m in meta])
                    self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
                if self.cfg.priority:
                    priority_list = [m['priority'] for m in meta]
                    self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), np.mean(priority_list), train_iter)
                    self.tb_logger.add_scalar('{}/priority_max'.format(self.name), np.max(priority_list), train_iter)
                self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)
            self.last_log_train_iter = train_iter
        data = [o.data for o in output]
        # Guarded on ``priority`` too: 'priority_IS' only exists in the meta
        # when the priority middleware is active (avoids KeyError on a
        # misconfigured ``priority_IS_weight=True, priority=False``).
        if self.cfg.priority and self.cfg.priority_IS_weight:
            IS = [o.meta['priority_IS'] for o in output]
            for i in range(len(data)):
                data[i]['IS'] = IS[i]
        return data

    def push(self, data, cur_collector_envstep: int = -1) -> None:
        """
        Push a batch of items into the buffer. When priority replay is enabled
        and an item carries a 'priority' key, it is popped from the item and
        forwarded as buffer meta. ``cur_collector_envstep`` is accepted for
        interface compatibility but unused here.
        """
        for d in data:
            meta = {}
            if self.cfg.priority and 'priority' in d:
                meta['priority'] = d.pop('priority')
            self.buffer.push(d, meta=meta)

    def update(self, meta: dict) -> None:
        """
        Write back learner priorities (``meta['priority']``) for the most
        recently sampled batch. Values are clipped to ``priority_max_limit``
        so one outlier cannot dominate the sampling distribution. No-op when
        priority replay is disabled or nothing was sampled since the last call.
        """
        if not self.cfg.priority:
            return
        if self.last_sample_index is None:
            return
        new_meta = self.last_sample_meta
        for m, p in zip(new_meta, meta['priority']):
            m['priority'] = min(self.priority_max_limit, p)
        for idx, m in zip(self.last_sample_index, new_meta):
            self.buffer.update(idx, data=None, meta=m)
        # Consume the batch: a second ``update`` without a new ``sample`` is a no-op.
        self.last_sample_index = None
        self.last_sample_meta = None

    def count(self) -> int:
        """Return the number of items currently stored in the buffer."""
        return self.buffer.count()

    def save_data(self, file_name):
        """Persist buffer contents to ``file_name`` (delegates to the buffer)."""
        self.buffer.save_data(file_name)

    def load_data(self, file_name: str):
        """Load buffer contents from ``file_name`` (delegates to the buffer)."""
        self.buffer.load_data(file_name)