from typing import List, Dict, Tuple | |
from ditk import logging | |
from copy import deepcopy | |
from easydict import EasyDict | |
from torch.utils.data import Dataset | |
from dataclasses import dataclass | |
import pickle | |
import easydict | |
import torch | |
import numpy as np | |
from ding.utils.bfs_helper import get_vi_sequence | |
from ding.utils import DATASET_REGISTRY, import_module, DatasetNormalizer | |
from ding.rl_utils import discount_cumsum | |
@dataclass
class DatasetStatistics:
""" | |
Overview: | |
Dataset statistics. | |
""" | |
    mean: np.ndarray  # mean of the observations
    std: np.ndarray  # standard deviation of the observations
    action_bounds: np.ndarray  # [min, max] bounds of the actions
class NaiveRLDataset(Dataset): | |
""" | |
Overview: | |
Naive RL dataset, which is used for offline RL algorithms. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
""" | |
def __init__(self, cfg) -> None: | |
""" | |
Overview: | |
Initialization method. | |
Arguments: | |
- cfg (:obj:`dict`): Config dict. | |
""" | |
assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg)) | |
if isinstance(cfg, EasyDict): | |
self._data_path = cfg.policy.collect.data_path | |
elif isinstance(cfg, str): | |
self._data_path = cfg | |
with open(self._data_path, 'rb') as f: | |
self._data: List[Dict[str, torch.Tensor]] = pickle.load(f) | |
def __len__(self) -> int: | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return len(self._data) | |
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
""" | |
Overview: | |
Get the item of the dataset. | |
""" | |
return self._data[idx] | |
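# A minimal usage sketch for NaiveRLDataset (the pickle path below is a hypothetical placeholder):
#   cfg = EasyDict({'policy': {'collect': {'data_path': './expert_demos.pkl'}}})
#   dataset = NaiveRLDataset(cfg)  # equivalently: NaiveRLDataset('./expert_demos.pkl')
#   transition = dataset[0]  # a transition dict of tensors, e.g. 'obs', 'action', 'reward', ...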
class D4RLDataset(Dataset): | |
""" | |
Overview: | |
D4RL dataset, which is used for offline RL algorithms. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
Properties: | |
- mean (:obj:`np.ndarray`): Mean of the dataset. | |
- std (:obj:`np.ndarray`): Std of the dataset. | |
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. | |
        - statistics (:obj:`DatasetStatistics`): Statistics of the dataset.
""" | |
def __init__(self, cfg: dict) -> None: | |
""" | |
Overview: | |
Initialization method. | |
Arguments: | |
- cfg (:obj:`dict`): Config dict. | |
""" | |
import gym | |
try: | |
            import d4rl  # register d4rl environments with OpenAI Gym
except ImportError: | |
import sys | |
            logging.warning("d4rl is not installed, please install it following https://github.com/rail-berkeley/d4rl")
sys.exit(1) | |
# Init parameters | |
data_path = cfg.policy.collect.get('data_path', None) | |
env_id = cfg.env.env_id | |
# Create the environment | |
if data_path: | |
d4rl.set_dataset_path(data_path) | |
env = gym.make(env_id) | |
dataset = d4rl.qlearning_dataset(env) | |
self._cal_statistics(dataset, env) | |
try: | |
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats: | |
dataset = self._normalize_states(dataset) | |
except (KeyError, AttributeError): | |
# do not normalize | |
pass | |
self._data = [] | |
self._load_d4rl(dataset) | |
def data(self) -> List: | |
return self._data | |
def __len__(self) -> int: | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return len(self._data) | |
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
""" | |
Overview: | |
Get the item of the dataset. | |
""" | |
return self._data[idx] | |
def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None: | |
""" | |
Overview: | |
Load the d4rl dataset. | |
Arguments: | |
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. | |
""" | |
for i in range(len(dataset['observations'])): | |
trans_data = {} | |
trans_data['obs'] = torch.from_numpy(dataset['observations'][i]) | |
trans_data['next_obs'] = torch.from_numpy(dataset['next_observations'][i]) | |
trans_data['action'] = torch.from_numpy(dataset['actions'][i]) | |
trans_data['reward'] = torch.tensor(dataset['rewards'][i]) | |
trans_data['done'] = dataset['terminals'][i] | |
self._data.append(trans_data) | |
def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True): | |
""" | |
Overview: | |
Calculate the statistics of the dataset. | |
Arguments: | |
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. | |
- env (:obj:`gym.Env`): The environment. | |
            - eps (:obj:`float`): A small epsilon added to the std to avoid division by zero.
            - add_action_buffer (:obj:`bool`): Whether to widen the action bounds by a 5% buffer.
""" | |
self._mean = dataset['observations'].mean(0) | |
self._std = dataset['observations'].std(0) + eps | |
action_max = dataset['actions'].max(0) | |
action_min = dataset['actions'].min(0) | |
if add_action_buffer: | |
action_buffer = 0.05 * (action_max - action_min) | |
action_max = (action_max + action_buffer).clip(max=env.action_space.high) | |
action_min = (action_min - action_buffer).clip(min=env.action_space.low) | |
self._action_bounds = np.stack([action_min, action_max], axis=0) | |
def _normalize_states(self, dataset): | |
""" | |
Overview: | |
Normalize the states. | |
Arguments: | |
- dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset. | |
""" | |
dataset['observations'] = (dataset['observations'] - self._mean) / self._std | |
dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std | |
return dataset | |
    @property
    def mean(self) -> np.ndarray:
""" | |
Overview: | |
Get the mean of the dataset. | |
""" | |
return self._mean | |
    @property
    def std(self) -> np.ndarray:
""" | |
Overview: | |
Get the std of the dataset. | |
""" | |
return self._std | |
    @property
    def action_bounds(self) -> np.ndarray:
""" | |
Overview: | |
Get the action bounds of the dataset. | |
""" | |
return self._action_bounds | |
    @property
    def statistics(self) -> DatasetStatistics:
""" | |
Overview: | |
Get the statistics of the dataset. | |
""" | |
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) | |
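# A minimal usage sketch for D4RLDataset, assuming d4rl is installed; env_id and the config layout
# below are placeholders chosen to match the fields accessed in __init__:
#   cfg = EasyDict({
#       'env': {'env_id': 'hopper-medium-v2', 'norm_obs': {'use_norm': False}},
#       'policy': {'collect': {'data_path': None}},
#   })
#   dataset = D4RLDataset(cfg)
#   stats = dataset.statistics  # DatasetStatistics(mean=..., std=..., action_bounds=...)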
class HDF5Dataset(Dataset): | |
""" | |
Overview: | |
        HDF5 dataset, which loads data saved in the HDF5 format and is used for offline RL algorithms.
        HDF5 is a common format for storing large numerical arrays in Python.
        For more details, please refer to https://support.hdfgroup.org/HDF5/.
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
Properties: | |
- mean (:obj:`np.ndarray`): Mean of the dataset. | |
- std (:obj:`np.ndarray`): Std of the dataset. | |
- action_bounds (:obj:`np.ndarray`): Action bounds of the dataset. | |
        - statistics (:obj:`DatasetStatistics`): Statistics of the dataset.
""" | |
def __init__(self, cfg: dict) -> None: | |
""" | |
Overview: | |
Initialization method. | |
Arguments: | |
- cfg (:obj:`dict`): Config dict. | |
""" | |
try: | |
import h5py | |
except ImportError: | |
import sys | |
            logging.warning("h5py package not found, please install it via `pip install h5py`")
sys.exit(1) | |
data_path = cfg.policy.collect.get('data_path', None) | |
if 'dataset' in cfg: | |
self.context_len = cfg.dataset.context_len | |
else: | |
self.context_len = 0 | |
data = h5py.File(data_path, 'r') | |
self._load_data(data) | |
self._cal_statistics() | |
try: | |
if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats: | |
self._normalize_states() | |
except (KeyError, AttributeError): | |
# do not normalize | |
pass | |
def __len__(self) -> int: | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return len(self._data['obs']) - self.context_len | |
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: | |
""" | |
Overview: | |
Get the item of the dataset. | |
Arguments: | |
- idx (:obj:`int`): The index of the dataset. | |
""" | |
if self.context_len == 0: # for other offline RL algorithms | |
return {k: self._data[k][idx] for k in self._data.keys()} | |
else: # for decision transformer | |
block_size = self.context_len | |
done_idx = idx + block_size | |
idx = done_idx - block_size | |
states = torch.as_tensor( | |
np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32 | |
).view(block_size, -1) | |
actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long) | |
rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32) | |
timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64) | |
traj_mask = torch.ones(self.context_len, dtype=torch.long) | |
return timesteps, states, actions, rtgs, traj_mask | |
def _load_data(self, dataset: Dict[str, np.ndarray]) -> None: | |
""" | |
Overview: | |
Load the dataset. | |
Arguments: | |
- dataset (:obj:`Dict[str, np.ndarray]`): The dataset. | |
""" | |
self._data = {} | |
for k in dataset.keys(): | |
logging.info(f'Load {k} data.') | |
self._data[k] = dataset[k][:] | |
def _cal_statistics(self, eps: float = 1e-3): | |
""" | |
Overview: | |
Calculate the statistics of the dataset. | |
Arguments: | |
- eps (:obj:`float`): Epsilon. | |
""" | |
self._mean = self._data['obs'].mean(0) | |
self._std = self._data['obs'].std(0) + eps | |
action_max = self._data['action'].max(0) | |
action_min = self._data['action'].min(0) | |
buffer = 0.05 * (action_max - action_min) | |
action_max = action_max.astype(float) + buffer | |
        action_min = action_min.astype(float) - buffer
self._action_bounds = np.stack([action_min, action_max], axis=0) | |
def _normalize_states(self): | |
""" | |
Overview: | |
Normalize the states. | |
""" | |
self._data['obs'] = (self._data['obs'] - self._mean) / self._std | |
self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std | |
    @property
    def mean(self) -> np.ndarray:
""" | |
Overview: | |
Get the mean of the dataset. | |
""" | |
return self._mean | |
    @property
    def std(self) -> np.ndarray:
""" | |
Overview: | |
Get the std of the dataset. | |
""" | |
return self._std | |
    @property
    def action_bounds(self) -> np.ndarray:
""" | |
Overview: | |
Get the action bounds of the dataset. | |
""" | |
return self._action_bounds | |
    @property
    def statistics(self) -> DatasetStatistics:
""" | |
Overview: | |
Get the statistics of the dataset. | |
""" | |
return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds) | |
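# A minimal usage sketch for HDF5Dataset (the file path is a placeholder). Without a `dataset` field in
# the config, context_len is 0 and __getitem__ returns a transition dict rather than the
# decision-transformer tuple:
#   cfg = EasyDict({'policy': {'collect': {'data_path': './expert_demos.hdf5'}}})
#   dataset = HDF5Dataset(cfg)
#   print(len(dataset), dataset.mean.shape, dataset.std.shape, dataset.action_bounds.shape)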
class D4RLTrajectoryDataset(Dataset): | |
""" | |
Overview: | |
D4RL trajectory dataset, which is used for offline RL algorithms. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
""" | |
    # reference scores from infos.py in the official d4rl GitHub repo
REF_MIN_SCORE = { | |
'halfcheetah': -280.178953, | |
'walker2d': 1.629008, | |
'hopper': -20.272305, | |
} | |
REF_MAX_SCORE = { | |
'halfcheetah': 12135.0, | |
'walker2d': 4592.3, | |
'hopper': 3234.3, | |
} | |
# calculated from d4rl datasets | |
D4RL_DATASET_STATS = { | |
'halfcheetah-medium-v2': { | |
'state_mean': [ | |
-0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164, | |
-0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436, | |
5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435, | |
0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445, | |
0.013382787816226482 | |
], | |
'state_std': [ | |
0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184, | |
0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577, | |
1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098, | |
5.671932697296143, 7.4982590675354 | |
] | |
}, | |
'halfcheetah-medium-replay-v2': { | |
'state_mean': [ | |
-0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193, | |
-0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682, | |
3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752, | |
0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994, | |
-0.015839405357837677 | |
], | |
'state_std': [ | |
0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494, | |
0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578, | |
1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416, | |
6.085654258728027, 7.25300407409668 | |
] | |
}, | |
'halfcheetah-medium-expert-v2': { | |
'state_mean': [ | |
-0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338, | |
-0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053, | |
8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784, | |
0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314 | |
], | |
'state_std': [ | |
0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533, | |
0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467, | |
1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797, | |
6.4811787605285645, 6.378620147705078 | |
] | |
}, | |
'walker2d-medium-v2': { | |
'state_mean': [ | |
1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026, | |
-0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777, | |
-0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654, | |
0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654 | |
], | |
'state_std': [ | |
0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724, | |
0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583, | |
1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145, | |
3.7445690631866455, 5.5851287841796875 | |
] | |
}, | |
'walker2d-medium-replay-v2': { | |
'state_mean': [ | |
1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221, | |
-0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662, | |
-0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088, | |
-0.08934258669614792, -0.2992438077926636, -0.5984178185462952 | |
], | |
'state_std': [ | |
0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303, | |
0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276, | |
2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096, | |
3.845186948776245, 5.4768385887146 | |
] | |
}, | |
'walker2d-medium-expert-v2': { | |
'state_mean': [ | |
1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075, | |
0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122, | |
3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811, | |
-0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786, | |
-0.27366524934768677 | |
], | |
'state_std': [ | |
0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586, | |
0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831, | |
1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857, | |
4.039782524108887, 5.891613960266113 | |
] | |
}, | |
'hopper-medium-v2': { | |
'state_mean': [ | |
1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081, | |
2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474, | |
-0.18540096282958984, -0.28461286425590515 | |
], | |
'state_std': [ | |
0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535, | |
0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754, | |
5.607253551483154 | |
] | |
}, | |
'hopper-medium-replay-v2': { | |
'state_mean': [ | |
1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224, | |
0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328, | |
-0.5287045240402222, -0.14465883374214172, -0.19652697443962097 | |
], | |
'state_std': [ | |
0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718, | |
1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137, | |
5.108601093292236 | |
] | |
}, | |
'hopper-medium-expert-v2': { | |
'state_mean': [ | |
1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415, | |
0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272, | |
-0.1766270101070404, -0.11862941086292267, -0.12097819894552231 | |
], | |
'state_std': [ | |
0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771, | |
0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893, | |
5.725032806396484 | |
] | |
}, | |
} | |
def __init__(self, cfg: dict) -> None: | |
""" | |
Overview: | |
Initialization method. | |
Arguments: | |
- cfg (:obj:`dict`): Config dict. | |
""" | |
dataset_path = cfg.dataset.data_dir_prefix | |
rtg_scale = cfg.dataset.rtg_scale | |
self.context_len = cfg.dataset.context_len | |
self.env_type = cfg.dataset.env_type | |
if 'hdf5' in dataset_path: # for mujoco env | |
try: | |
import h5py | |
import collections | |
except ImportError: | |
import sys | |
                logging.warning("h5py package not found, please install it via `pip install h5py`")
sys.exit(1) | |
dataset = h5py.File(dataset_path, 'r') | |
N = dataset['rewards'].shape[0] | |
data_ = collections.defaultdict(list) | |
use_timeouts = False | |
if 'timeouts' in dataset: | |
use_timeouts = True | |
episode_step = 0 | |
paths = [] | |
for i in range(N): | |
done_bool = bool(dataset['terminals'][i]) | |
if use_timeouts: | |
final_timestep = dataset['timeouts'][i] | |
else: | |
final_timestep = (episode_step == 1000 - 1) | |
for k in ['observations', 'actions', 'rewards', 'terminals']: | |
data_[k].append(dataset[k][i]) | |
if done_bool or final_timestep: | |
episode_step = 0 | |
episode_data = {} | |
for k in data_: | |
episode_data[k] = np.array(data_[k]) | |
paths.append(episode_data) | |
data_ = collections.defaultdict(list) | |
episode_step += 1 | |
self.trajectories = paths | |
# calculate state mean and variance and returns_to_go for all traj | |
states = [] | |
for traj in self.trajectories: | |
traj_len = traj['observations'].shape[0] | |
states.append(traj['observations']) | |
# calculate returns to go and rescale them | |
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale | |
# used for input normalization | |
states = np.concatenate(states, axis=0) | |
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 | |
# normalize states | |
for traj in self.trajectories: | |
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std | |
elif 'pkl' in dataset_path: | |
if 'dqn' in dataset_path: | |
# load dataset | |
with open(dataset_path, 'rb') as f: | |
self.trajectories = pickle.load(f) | |
if isinstance(self.trajectories[0], list): | |
# for our collected dataset, e.g. cartpole/lunarlander case | |
trajectories_tmp = [] | |
original_keys = ['obs', 'next_obs', 'action', 'reward'] | |
keys = ['observations', 'next_observations', 'actions', 'rewards'] | |
trajectories_tmp = [ | |
{ | |
key: np.stack( | |
[ | |
self.trajectories[eps_index][transition_index][o_key] | |
for transition_index in range(len(self.trajectories[eps_index])) | |
], | |
axis=0 | |
) | |
for key, o_key in zip(keys, original_keys) | |
} for eps_index in range(len(self.trajectories)) | |
] | |
self.trajectories = trajectories_tmp | |
states = [] | |
for traj in self.trajectories: | |
# traj_len = traj['observations'].shape[0] | |
states.append(traj['observations']) | |
# calculate returns to go and rescale them | |
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale | |
# used for input normalization | |
states = np.concatenate(states, axis=0) | |
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 | |
# normalize states | |
for traj in self.trajectories: | |
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std | |
else: | |
# load dataset | |
with open(dataset_path, 'rb') as f: | |
self.trajectories = pickle.load(f) | |
states = [] | |
for traj in self.trajectories: | |
states.append(traj['observations']) | |
# calculate returns to go and rescale them | |
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale | |
# used for input normalization | |
states = np.concatenate(states, axis=0) | |
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 | |
# normalize states | |
for traj in self.trajectories: | |
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std | |
else: | |
# -- load data from memory (make more efficient) | |
obss = [] | |
actions = [] | |
returns = [0] | |
done_idxs = [] | |
stepwise_returns = [] | |
transitions_per_buffer = np.zeros(50, dtype=int) | |
num_trajectories = 0 | |
while len(obss) < cfg.dataset.num_steps: | |
buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0] | |
i = transitions_per_buffer[buffer_num] | |
frb = FixedReplayBuffer( | |
data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs', | |
replay_suffix=buffer_num, | |
observation_shape=(84, 84), | |
stack_size=4, | |
update_horizon=1, | |
gamma=0.99, | |
observation_dtype=np.uint8, | |
batch_size=32, | |
replay_capacity=100000 | |
) | |
if frb._loaded_buffers: | |
done = False | |
curr_num_transitions = len(obss) | |
trajectories_to_load = cfg.dataset.trajectories_per_buffer | |
while not done: | |
states, ac, ret, next_states, next_action, next_reward, terminal, indices = \ | |
frb.sample_transition_batch(batch_size=1, indices=[i]) | |
states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84) | |
obss.append(states) | |
actions.append(ac[0]) | |
stepwise_returns.append(ret[0]) | |
if terminal[0]: | |
done_idxs.append(len(obss)) | |
returns.append(0) | |
if trajectories_to_load == 0: | |
done = True | |
else: | |
trajectories_to_load -= 1 | |
returns[-1] += ret[0] | |
i += 1 | |
if i >= 100000: | |
obss = obss[:curr_num_transitions] | |
actions = actions[:curr_num_transitions] | |
stepwise_returns = stepwise_returns[:curr_num_transitions] | |
returns[-1] = 0 | |
i = transitions_per_buffer[buffer_num] | |
done = True | |
num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load) | |
transitions_per_buffer[buffer_num] = i | |
actions = np.array(actions) | |
returns = np.array(returns) | |
stepwise_returns = np.array(stepwise_returns) | |
done_idxs = np.array(done_idxs) | |
# -- create reward-to-go dataset | |
start_index = 0 | |
rtg = np.zeros_like(stepwise_returns) | |
for i in done_idxs: | |
i = int(i) | |
curr_traj_returns = stepwise_returns[start_index:i] | |
for j in range(i - 1, start_index - 1, -1): # start from i-1 | |
rtg_j = curr_traj_returns[j - start_index:i - start_index] | |
rtg[j] = sum(rtg_j) | |
start_index = i | |
# -- create timestep dataset | |
start_index = 0 | |
timesteps = np.zeros(len(actions) + 1, dtype=int) | |
for i in done_idxs: | |
i = int(i) | |
timesteps[start_index:i + 1] = np.arange(i + 1 - start_index) | |
start_index = i + 1 | |
self.obss = obss | |
self.actions = actions | |
self.done_idxs = done_idxs | |
self.rtgs = rtg | |
self.timesteps = timesteps | |
# return obss, actions, returns, done_idxs, rtg, timesteps | |
def get_max_timestep(self) -> int: | |
""" | |
Overview: | |
Get the max timestep of the dataset. | |
""" | |
return max(self.timesteps) | |
def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]: | |
""" | |
Overview: | |
Get the state mean and std of the dataset. | |
""" | |
return deepcopy(self.state_mean), deepcopy(self.state_std) | |
def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]: | |
""" | |
Overview: | |
Get the d4rl dataset stats. | |
Arguments: | |
- env_d4rl_name (:obj:`str`): The d4rl env name. | |
""" | |
return self.D4RL_DATASET_STATS[env_d4rl_name] | |
def __len__(self) -> int: | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
if self.env_type != 'atari': | |
return len(self.trajectories) | |
else: | |
return len(self.obss) - self.context_len | |
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: | |
""" | |
Overview: | |
Get the item of the dataset. | |
Arguments: | |
- idx (:obj:`int`): The index of the dataset. | |
""" | |
if self.env_type != 'atari': | |
traj = self.trajectories[idx] | |
traj_len = traj['observations'].shape[0] | |
if traj_len > self.context_len: | |
# sample random index to slice trajectory | |
si = np.random.randint(0, traj_len - self.context_len) | |
states = torch.from_numpy(traj['observations'][si:si + self.context_len]) | |
actions = torch.from_numpy(traj['actions'][si:si + self.context_len]) | |
returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len]) | |
timesteps = torch.arange(start=si, end=si + self.context_len, step=1) | |
# all ones since no padding | |
traj_mask = torch.ones(self.context_len, dtype=torch.long) | |
else: | |
padding_len = self.context_len - traj_len | |
# padding with zeros | |
states = torch.from_numpy(traj['observations']) | |
states = torch.cat( | |
[states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0 | |
) | |
actions = torch.from_numpy(traj['actions']) | |
actions = torch.cat( | |
[actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0 | |
) | |
returns_to_go = torch.from_numpy(traj['returns_to_go']) | |
returns_to_go = torch.cat( | |
[ | |
returns_to_go, | |
torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype) | |
], | |
dim=0 | |
) | |
timesteps = torch.arange(start=0, end=self.context_len, step=1) | |
traj_mask = torch.cat( | |
[torch.ones(traj_len, dtype=torch.long), | |
torch.zeros(padding_len, dtype=torch.long)], dim=0 | |
) | |
return timesteps, states, actions, returns_to_go, traj_mask | |
        else:  # Atari case, which on average costs less than 0.001s per item
block_size = self.context_len | |
done_idx = idx + block_size | |
for i in self.done_idxs: | |
if i > idx: # first done_idx greater than idx | |
done_idx = min(int(i), done_idx) | |
break | |
idx = done_idx - block_size | |
states = torch.as_tensor( | |
np.array(self.obss[idx:done_idx]), dtype=torch.float32 | |
).view(block_size, -1) # (block_size, 4*84*84) | |
states = states / 255. | |
actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1) | |
rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1) | |
timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1) | |
traj_mask = torch.ones(self.context_len, dtype=torch.long) | |
return timesteps, states, actions, rtgs, traj_mask | |
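# A minimal usage sketch for D4RLTrajectoryDataset with a PyTorch DataLoader (all config values are
# placeholders); for a non-atari env_type each item is a (timesteps, states, actions, returns_to_go,
# traj_mask) tuple of length context_len:
#   cfg = EasyDict({'dataset': {'data_dir_prefix': './hopper_medium_v2.pkl', 'rtg_scale': 1000,
#                               'context_len': 20, 'env_type': 'mujoco'}})
#   dataset = D4RLTrajectoryDataset(cfg)
#   loader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
#   timesteps, states, actions, rtgs, mask = next(iter(loader))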
class D4RLDiffuserDataset(Dataset): | |
""" | |
Overview: | |
D4RL diffuser dataset, which is used for offline RL algorithms. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
""" | |
def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None: | |
""" | |
Overview: | |
Initialization method of D4RLDiffuserDataset. | |
Arguments: | |
- dataset_path (:obj:`str`): The dataset path. | |
- context_len (:obj:`int`): The length of the context. | |
- rtg_scale (:obj:`float`): The scale of the returns to go. | |
""" | |
self.context_len = context_len | |
# load dataset | |
with open(dataset_path, 'rb') as f: | |
self.trajectories = pickle.load(f) | |
if isinstance(self.trajectories[0], list): | |
# for our collected dataset, e.g. cartpole/lunarlander case | |
trajectories_tmp = [] | |
original_keys = ['obs', 'next_obs', 'action', 'reward'] | |
keys = ['observations', 'next_observations', 'actions', 'rewards'] | |
for key, o_key in zip(keys, original_keys): | |
trajectories_tmp = [ | |
{ | |
key: np.stack( | |
[ | |
self.trajectories[eps_index][transition_index][o_key] | |
for transition_index in range(len(self.trajectories[eps_index])) | |
], | |
axis=0 | |
) | |
} for eps_index in range(len(self.trajectories)) | |
] | |
self.trajectories = trajectories_tmp | |
states = [] | |
for traj in self.trajectories: | |
traj_len = traj['observations'].shape[0] | |
states.append(traj['observations']) | |
# calculate returns to go and rescale them | |
traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale | |
# used for input normalization | |
states = np.concatenate(states, axis=0) | |
self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6 | |
# normalize states | |
for traj in self.trajectories: | |
traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std | |
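# A minimal usage sketch for D4RLDiffuserDataset (the pickle path is a placeholder); the loaded
# trajectories are expected to provide 'observations' and 'rewards', from which normalized states and a
# 'returns_to_go' field are computed:
#   dataset = D4RLDiffuserDataset('./hopper_medium_v2.pkl', context_len=20, rtg_scale=1000)
#   print(dataset.state_mean.shape, dataset.state_std.shape)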
class FixedReplayBuffer(object): | |
""" | |
Overview: | |
Object composed of a list of OutofGraphReplayBuffers. | |
Interfaces: | |
``__init__``, ``get_transition_elements``, ``sample_transition_batch`` | |
""" | |
def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg | |
""" | |
Overview: | |
Initialize the FixedReplayBuffer class. | |
Arguments: | |
            - data_dir (:obj:`str`): Log directory from which to load the replay buffer.
- replay_suffix (:obj:`int`): If not None, then only load the replay buffer \ | |
corresponding to the specific suffix in data directory. | |
- args (:obj:`list`): Arbitrary extra arguments. | |
- kwargs (:obj:`dict`): Arbitrary keyword arguments. | |
""" | |
self._args = args | |
self._kwargs = kwargs | |
self._data_dir = data_dir | |
self._loaded_buffers = False | |
self.add_count = np.array(0) | |
self._replay_suffix = replay_suffix | |
if not self._loaded_buffers: | |
if replay_suffix is not None: | |
assert replay_suffix >= 0, 'Please pass a non-negative replay suffix' | |
self.load_single_buffer(replay_suffix) | |
else: | |
pass | |
# self._load_replay_buffers(num_buffers=50) | |
def load_single_buffer(self, suffix): | |
""" | |
Overview: | |
Load a single replay buffer. | |
Arguments: | |
- suffix (:obj:`int`): The suffix of the replay buffer. | |
""" | |
replay_buffer = self._load_buffer(suffix) | |
if replay_buffer is not None: | |
self._replay_buffers = [replay_buffer] | |
self.add_count = replay_buffer.add_count | |
self._num_replay_buffers = 1 | |
self._loaded_buffers = True | |
def _load_buffer(self, suffix): | |
""" | |
Overview: | |
            Load an OutOfGraphReplayBuffer replay buffer.
Arguments: | |
- suffix (:obj:`int`): The suffix of the replay buffer. | |
""" | |
try: | |
from dopamine.replay_memory import circular_replay_buffer | |
STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX | |
# pytype: disable=attribute-error | |
replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs) | |
replay_buffer.load(self._data_dir, suffix) | |
# pytype: enable=attribute-error | |
return replay_buffer | |
        # except tf.errors.NotFoundError:
        except Exception as e:
            raise RuntimeError('can not load the replay buffer from {}'.format(self._data_dir)) from e
def get_transition_elements(self): | |
""" | |
Overview: | |
Returns the transition elements. | |
""" | |
return self._replay_buffers[0].get_transition_elements() | |
def sample_transition_batch(self, batch_size=None, indices=None): | |
""" | |
Overview: | |
Returns a batch of transitions (including any extra contents). | |
Arguments: | |
- batch_size (:obj:`int`): The batch size. | |
- indices (:obj:`list`): The indices of the batch. | |
""" | |
buffer_index = np.random.randint(self._num_replay_buffers) | |
return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices) | |
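# A minimal usage sketch for FixedReplayBuffer, assuming dopamine is installed and Atari replay logs
# exist under the placeholder directory below (the keyword arguments mirror the ones used by
# D4RLTrajectoryDataset above):
#   frb = FixedReplayBuffer(
#       data_dir='./breakout/1/replay_logs', replay_suffix=0, observation_shape=(84, 84), stack_size=4,
#       update_horizon=1, gamma=0.99, observation_dtype=np.uint8, batch_size=32, replay_capacity=100000
#   )
#   if frb._loaded_buffers:
#       batch = frb.sample_transition_batch(batch_size=1, indices=[0])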
class PCDataset(Dataset): | |
""" | |
Overview: | |
Dataset for Procedure Cloning. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
""" | |
def __init__(self, all_data): | |
""" | |
Overview: | |
Initialization method of PCDataset. | |
Arguments: | |
- all_data (:obj:`tuple`): The tuple of all data. | |
""" | |
self._data = all_data | |
def __getitem__(self, item): | |
""" | |
Overview: | |
Get the item of the dataset. | |
Arguments: | |
- item (:obj:`int`): The index of the dataset. | |
""" | |
return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]} | |
def __len__(self): | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return self._data[0].shape[0] | |
def load_bfs_datasets(train_seeds=1, test_seeds=5): | |
""" | |
Overview: | |
Load BFS datasets. | |
Arguments: | |
- train_seeds (:obj:`int`): The number of train seeds. | |
- test_seeds (:obj:`int`): The number of test seeds. | |
""" | |
from dizoo.maze.envs import Maze | |
def load_env(seed): | |
ccc = easydict.EasyDict({'size': 16}) | |
e = Maze(ccc) | |
e.seed(seed) | |
e.reset() | |
return e | |
envs = [load_env(i) for i in range(train_seeds + test_seeds)] | |
observations_train = [] | |
observations_test = [] | |
bfs_input_maps_train = [] | |
bfs_input_maps_test = [] | |
bfs_output_maps_train = [] | |
bfs_output_maps_test = [] | |
for idx, env in enumerate(envs): | |
if idx < train_seeds: | |
observations = observations_train | |
bfs_input_maps = bfs_input_maps_train | |
bfs_output_maps = bfs_output_maps_train | |
else: | |
observations = observations_test | |
bfs_input_maps = bfs_input_maps_test | |
bfs_output_maps = bfs_output_maps_test | |
start_obs = env.process_states(env._get_obs(), env.get_maze_map()) | |
_, track_back = get_vi_sequence(env, start_obs) | |
env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0) | |
for i in range(env_observations.shape[0]): | |
bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) # [L, W, W] | |
            bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.int64)
for j in range(bfs_sequence.shape[0]): | |
bfs_input_maps.append(torch.from_numpy(bfs_input_map)) | |
bfs_output_maps.append(torch.from_numpy(bfs_sequence[j])) | |
observations.append(env_observations[i]) | |
bfs_input_map = bfs_sequence[j] | |
train_data = PCDataset( | |
( | |
torch.stack(observations_train, dim=0), | |
torch.stack(bfs_input_maps_train, dim=0), | |
torch.stack(bfs_output_maps_train, dim=0), | |
) | |
) | |
test_data = PCDataset( | |
( | |
torch.stack(observations_test, dim=0), | |
torch.stack(bfs_input_maps_test, dim=0), | |
torch.stack(bfs_output_maps_test, dim=0), | |
) | |
) | |
return train_data, test_data | |
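# A minimal usage sketch for load_bfs_datasets (requires the Maze env from dizoo); each PCDataset item is
# a dict with 'obs', 'bfs_in' and 'bfs_out' tensors used to supervise procedure cloning:
#   train_data, test_data = load_bfs_datasets(train_seeds=1, test_seeds=5)
#   sample = train_data[0]
#   print(sample['obs'].shape, sample['bfs_in'].shape, sample['bfs_out'].shape)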
class BCODataset(Dataset): | |
""" | |
Overview: | |
Dataset for Behavioral Cloning from Observation. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
Properties: | |
- obs (:obj:`np.ndarray`): The observation array. | |
- action (:obj:`np.ndarray`): The action array. | |
""" | |
def __init__(self, data=None): | |
""" | |
Overview: | |
Initialization method of BCODataset. | |
Arguments: | |
- data (:obj:`dict`): The data dict. | |
""" | |
if data is None: | |
raise ValueError('Dataset can not be empty!') | |
else: | |
self._data = data | |
def __len__(self): | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return len(self._data['obs']) | |
def __getitem__(self, idx): | |
""" | |
Overview: | |
Get the item of the dataset. | |
Arguments: | |
- idx (:obj:`int`): The index of the dataset. | |
""" | |
return {k: self._data[k][idx] for k in self._data.keys()} | |
    @property
    def obs(self) -> np.ndarray:
""" | |
Overview: | |
Get the observation array. | |
""" | |
return self._data['obs'] | |
    @property
    def action(self) -> np.ndarray:
""" | |
Overview: | |
Get the action array. | |
""" | |
return self._data['action'] | |
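# A minimal usage sketch for BCODataset (the arrays below are synthetic placeholders):
#   data = {'obs': np.random.randn(100, 4), 'action': np.random.randint(0, 2, size=(100, 1))}
#   dataset = BCODataset(data)
#   print(len(dataset), dataset.obs.shape, dataset.action.shape)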
class SequenceDataset(torch.utils.data.Dataset): | |
""" | |
Overview: | |
Dataset for diffuser. | |
Interfaces: | |
``__init__``, ``__len__``, ``__getitem__`` | |
""" | |
def __init__(self, cfg): | |
""" | |
Overview: | |
Initialization method of SequenceDataset. | |
Arguments: | |
- cfg (:obj:`dict`): The config dict. | |
""" | |
import gym | |
env_id = cfg.env.env_id | |
data_path = cfg.policy.collect.get('data_path', None) | |
env = gym.make(env_id) | |
dataset = env.get_dataset() | |
self.returns_scale = cfg.env.returns_scale | |
self.horizon = cfg.env.horizon | |
self.max_path_length = cfg.env.max_path_length | |
self.discount = cfg.policy.learn.discount_factor | |
self.discounts = self.discount ** np.arange(self.max_path_length)[:, None] | |
self.use_padding = cfg.env.use_padding | |
self.include_returns = cfg.env.include_returns | |
self.env_id = cfg.env.env_id | |
itr = self.sequence_dataset(env, dataset) | |
self.n_episodes = 0 | |
fields = {} | |
for k in dataset.keys(): | |
if 'metadata' in k: | |
continue | |
fields[k] = [] | |
fields['path_lengths'] = [] | |
for i, episode in enumerate(itr): | |
path_length = len(episode['observations']) | |
assert path_length <= self.max_path_length | |
fields['path_lengths'].append(path_length) | |
for key, val in episode.items(): | |
if key not in fields: | |
fields[key] = [] | |
if val.ndim < 2: | |
val = np.expand_dims(val, axis=-1) | |
shape = (self.max_path_length, val.shape[-1]) | |
arr = np.zeros(shape, dtype=np.float32) | |
arr[:path_length] = val | |
fields[key].append(arr) | |
if episode['terminals'].any() and cfg.env.termination_penalty and 'timeouts' in episode: | |
assert not episode['timeouts'].any(), 'Penalized a timeout episode for early termination' | |
fields['rewards'][-1][path_length - 1] += cfg.env.termination_penalty | |
self.n_episodes += 1 | |
for k in fields.keys(): | |
fields[k] = np.array(fields[k]) | |
self.normalizer = DatasetNormalizer(fields, cfg.policy.normalizer, path_lengths=fields['path_lengths']) | |
self.indices = self.make_indices(fields['path_lengths'], self.horizon) | |
self.observation_dim = cfg.env.obs_dim | |
self.action_dim = cfg.env.action_dim | |
self.fields = fields | |
self.normalize() | |
self.normed = False | |
if cfg.env.normed: | |
self.vmin, self.vmax = self._get_bounds() | |
self.normed = True | |
# shapes = {key: val.shape for key, val in self.fields.items()} | |
# print(f'[ datasets/mujoco ] Dataset fields: {shapes}') | |
def sequence_dataset(self, env, dataset=None): | |
""" | |
Overview: | |
Sequence the dataset. | |
Arguments: | |
- env (:obj:`gym.Env`): The gym env. | |
""" | |
import collections | |
N = dataset['rewards'].shape[0] | |
if 'maze2d' in env.spec.id: | |
dataset = self.maze2d_set_terminals(env, dataset) | |
data_ = collections.defaultdict(list) | |
        # The newer version of the dataset adds an explicit
        # timeouts field. Keep the old method for backwards compatibility.
use_timeouts = 'timeouts' in dataset | |
episode_step = 0 | |
for i in range(N): | |
done_bool = bool(dataset['terminals'][i]) | |
if use_timeouts: | |
final_timestep = dataset['timeouts'][i] | |
else: | |
final_timestep = (episode_step == env._max_episode_steps - 1) | |
for k in dataset: | |
if 'metadata' in k: | |
continue | |
data_[k].append(dataset[k][i]) | |
if done_bool or final_timestep: | |
episode_step = 0 | |
episode_data = {} | |
for k in data_: | |
episode_data[k] = np.array(data_[k]) | |
if 'maze2d' in env.spec.id: | |
episode_data = self.process_maze2d_episode(episode_data) | |
yield episode_data | |
data_ = collections.defaultdict(list) | |
episode_step += 1 | |
def maze2d_set_terminals(self, env, dataset): | |
""" | |
Overview: | |
Set the terminals for maze2d. | |
Arguments: | |
- env (:obj:`gym.Env`): The gym env. | |
- dataset (:obj:`dict`): The dataset dict. | |
""" | |
goal = env.get_target() | |
threshold = 0.5 | |
xy = dataset['observations'][:, :2] | |
distances = np.linalg.norm(xy - goal, axis=-1) | |
at_goal = distances < threshold | |
timeouts = np.zeros_like(dataset['timeouts']) | |
# timeout at time t iff | |
# at goal at time t and | |
# not at goal at time t + 1 | |
timeouts[:-1] = at_goal[:-1] * ~at_goal[1:] | |
timeout_steps = np.where(timeouts)[0] | |
path_lengths = timeout_steps[1:] - timeout_steps[:-1] | |
print( | |
f'[ utils/preprocessing ] Segmented {env.spec.id} | {len(path_lengths)} paths | ' | |
f'min length: {path_lengths.min()} | max length: {path_lengths.max()}' | |
) | |
dataset['timeouts'] = timeouts | |
return dataset | |
def process_maze2d_episode(self, episode): | |
""" | |
Overview: | |
            Process a maze2d episode by adding a `next_observations` field to it.
Arguments: | |
- episode (:obj:`dict`): The episode dict. | |
""" | |
assert 'next_observations' not in episode | |
length = len(episode['observations']) | |
next_observations = episode['observations'][1:].copy() | |
for key, val in episode.items(): | |
episode[key] = val[:-1] | |
episode['next_observations'] = next_observations | |
return episode | |
def normalize(self, keys=['observations', 'actions']): | |
""" | |
Overview: | |
            Normalize the fields of the dataset that will be predicted by the diffusion model.
Arguments: | |
- keys (:obj:`list`): The list of keys. | |
""" | |
for key in keys: | |
array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1) | |
normed = self.normalizer.normalize(array, key) | |
self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1) | |
def make_indices(self, path_lengths, horizon): | |
""" | |
Overview: | |
Make indices for sampling from dataset. Each index maps to a datapoint. | |
Arguments: | |
- path_lengths (:obj:`np.ndarray`): The path length array. | |
- horizon (:obj:`int`): The horizon. | |
""" | |
indices = [] | |
for i, path_length in enumerate(path_lengths): | |
max_start = min(path_length - 1, self.max_path_length - horizon) | |
if not self.use_padding: | |
max_start = min(max_start, path_length - horizon) | |
for start in range(max_start): | |
end = start + horizon | |
indices.append((i, start, end)) | |
indices = np.array(indices) | |
return indices | |
def get_conditions(self, observations): | |
""" | |
Overview: | |
Get the conditions on current observation for planning. | |
Arguments: | |
- observations (:obj:`np.ndarray`): The observation array. | |
""" | |
if 'maze2d' in self.env_id: | |
return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]} | |
else: | |
return {'condition_id': [0], 'condition_val': [observations[0]]} | |
def __len__(self): | |
""" | |
Overview: | |
Get the length of the dataset. | |
""" | |
return len(self.indices) | |
def _get_bounds(self): | |
""" | |
Overview: | |
Get the bounds of the dataset. | |
""" | |
print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True) | |
vmin = np.inf | |
vmax = -np.inf | |
for i in range(len(self.indices)): | |
value = self.__getitem__(i)['returns'].item() | |
vmin = min(value, vmin) | |
vmax = max(value, vmax) | |
print('✓') | |
return vmin, vmax | |
def normalize_value(self, value): | |
""" | |
Overview: | |
Normalize the value. | |
Arguments: | |
- value (:obj:`np.ndarray`): The value array. | |
""" | |
# [0, 1] | |
normed = (value - self.vmin) / (self.vmax - self.vmin) | |
# [-1, 1] | |
normed = normed * 2 - 1 | |
return normed | |
def __getitem__(self, idx, eps=1e-4): | |
""" | |
Overview: | |
Get the item of the dataset. | |
Arguments: | |
- idx (:obj:`int`): The index of the dataset. | |
- eps (:obj:`float`): The epsilon. | |
""" | |
path_ind, start, end = self.indices[idx] | |
observations = self.fields['normed_observations'][path_ind, start:end] | |
actions = self.fields['normed_actions'][path_ind, start:end] | |
done = self.fields['terminals'][path_ind, start:end] | |
# conditions = self.get_conditions(observations) | |
trajectories = np.concatenate([actions, observations], axis=-1) | |
if self.include_returns: | |
rewards = self.fields['rewards'][path_ind, start:] | |
discounts = self.discounts[:len(rewards)] | |
returns = (discounts * rewards).sum() | |
if self.normed: | |
returns = self.normalize_value(returns) | |
returns = np.array([returns / self.returns_scale], dtype=np.float32) | |
batch = { | |
'trajectories': trajectories, | |
'returns': returns, | |
'done': done, | |
'action': actions, | |
} | |
else: | |
batch = { | |
'trajectories': trajectories, | |
'done': done, | |
'action': actions, | |
} | |
batch.update(self.get_conditions(observations)) | |
return batch | |
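# A minimal usage sketch for SequenceDataset; all config values below (including the normalizer name) are
# placeholders chosen to match the fields accessed in __init__, and env_id is assumed to be a D4RL env
# that provides get_dataset():
#   cfg = EasyDict({
#       'env': {'env_id': 'hopper-medium-v2', 'returns_scale': 400.0, 'horizon': 32,
#               'max_path_length': 1000, 'use_padding': True, 'include_returns': True,
#               'termination_penalty': -100, 'obs_dim': 11, 'action_dim': 3, 'normed': False},
#       'policy': {'collect': {'data_path': None}, 'learn': {'discount_factor': 0.99},
#                  'normalizer': 'GaussianNormalizer'},
#   })
#   dataset = SequenceDataset(cfg)
#   batch = dataset[0]  # dict with 'trajectories', 'done', 'action', 'returns' and condition_* keys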
def hdf5_save(exp_data, expert_data_path): | |
""" | |
Overview: | |
Save the data to hdf5. | |
""" | |
try: | |
import h5py | |
except ImportError: | |
import sys | |
        logging.warning("h5py package not found, please install it via 'pip install h5py'")
sys.exit(1) | |
    dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')
dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip') | |
dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip') | |
dataset.create_dataset('reward', data=np.array([d['reward'].numpy() for d in exp_data]), compression='gzip') | |
dataset.create_dataset('done', data=np.array([d['done'] for d in exp_data]), compression='gzip') | |
dataset.create_dataset('next_obs', data=np.array([d['next_obs'].numpy() for d in exp_data]), compression='gzip') | |
def naive_save(exp_data, expert_data_path): | |
""" | |
Overview: | |
Save the data to pickle. | |
""" | |
with open(expert_data_path, 'wb') as f: | |
pickle.dump(exp_data, f) | |
def offline_data_save_type(exp_data, expert_data_path, data_type='naive'): | |
""" | |
Overview: | |
Save the offline data. | |
""" | |
globals()[data_type + '_save'](exp_data, expert_data_path) | |
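# A minimal usage sketch for the saving helpers: offline_data_save_type dispatches to `<data_type>_save`,
# so 'naive' pickles exp_data and 'hdf5' writes an HDF5 file (exp_data is a list of transition dicts as
# produced by the collector; the path is a placeholder):
#   offline_data_save_type(exp_data, './expert_demos.pkl', data_type='naive')
#   offline_data_save_type(exp_data, './expert_demos.pkl', data_type='hdf5')  # -> ./expert_demos_demos.hdf5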
def create_dataset(cfg, **kwargs) -> Dataset: | |
""" | |
Overview: | |
Create dataset. | |
""" | |
cfg = EasyDict(cfg) | |
import_module(cfg.get('import_names', [])) | |
return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs) | |
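# A minimal usage sketch for create_dataset: the dataset class is looked up in DATASET_REGISTRY via
# cfg.policy.collect.data_type (the 'naive' key below is an assumed registry name and must match a
# registered dataset class):
#   cfg = {'policy': {'collect': {'data_type': 'naive', 'data_path': './expert_demos.pkl'}}}
#   dataset = create_dataset(cfg)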