# Utility helpers for nested-data manipulation, config merging, seeding and running statistics.
from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict | |
import copy | |
from ditk import logging | |
import random | |
from functools import lru_cache # in python3.9, we can change to cache | |
import numpy as np | |
import torch | |
import treetensor.torch as ttorch | |
def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
    """
    Overview:
        Get ``shape[0]`` of the tensor(s) nested inside ``data``, which is usually the batch size.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): Data to be analysed.
    Returns:
        - shape0 (:obj:`int`): First dimension length of ``data``, usually the batch size.
    Raises:
        - ValueError: If ``data`` is an empty dict (no value to inspect).
        - TypeError: If the (nested) element type is not supported.
    """
    if isinstance(data, (list, tuple)):
        # Assumes every element shares the same batch dimension; inspect the first one.
        return get_shape0(data[0])
    elif isinstance(data, dict):
        # Assumes every value shares the same batch dimension; inspect the first one.
        for v in data.values():
            return get_shape0(v)
        raise ValueError("Error in getting shape0, empty dict has no shape")
    elif isinstance(data, torch.Tensor):
        return data.shape[0]
    elif isinstance(data, ttorch.Tensor):

        def fn(t):
            # ``ttorch`` shape is a nested mapping; walk down to the first leaf shape entry.
            item = list(t.values())[0]
            if np.isscalar(item[0]):
                return item[0]
            else:
                return fn(item)

        return fn(data.shape)
    else:
        # Report the offending *type* (the original formatted the whole data object,
        # which is noisy and does not name the type at all).
        raise TypeError("Error in getting shape0, not support type: {}".format(type(data)))
def lists_to_dicts(
        data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
        recursive: bool = False,
) -> Union[Mapping[object, object], NamedTuple]:
    """
    Overview:
        Transform a list of dicts (or namedtuples) into a dict of lists (or a namedtuple of tuples).
    Arguments:
        - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`): \
            The list of dicts to be transformed.
        - recursive (:obj:`bool`): Whether to recursively transform dict-valued entries \
            (``'prev_state'`` entries are never recursed into).
    Returns:
        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): The resulting dict of lists.
    Raises:
        - ValueError: If ``data`` is empty.
        - TypeError: If the element type is neither dict nor namedtuple.
    Example:
        >>> from ding.utils import *
        >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
        {1: [1, 2], 10: [3, 4]}
    """
    if len(data) == 0:
        raise ValueError("empty data")
    first = data[0]
    if isinstance(first, dict):
        if recursive:
            new_data = {}
            for key in first.keys():
                column = [item[key] for item in data]
                # 'prev_state' holds RNN state structures that must stay as-is.
                if isinstance(first[key], dict) and key != 'prev_state':
                    new_data[key] = lists_to_dicts(column)
                else:
                    new_data[key] = column
        else:
            new_data = {key: [item[key] for item in data] for key in first.keys()}
    elif isinstance(first, tuple) and hasattr(first, '_fields'):  # namedtuple
        new_data = type(first)(*list(zip(*data)))
    else:
        raise TypeError("not support element type: {}".format(type(first)))
    return new_data
def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
    """
    Overview:
        Transform a dict of lists into a list of dicts (the inverse of ``lists_to_dicts``).
    Arguments:
        - data (:obj:`Mapping[object, list]`): The dict of lists to be transformed.
    Returns:
        - newdata (:obj:`List[Mapping[object, object]]`): The resulting list of dicts.
    Example:
        >>> from ding.utils import *
        >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
        [{1: 1, 10: 3}, {1: 2, 10: 4}]
    """
    keys = list(data.keys())
    # Each row pairs the i-th element of every value list, truncated to the shortest list.
    return [dict(zip(keys, row)) for row in zip(*data.values())]
def override(cls: type) -> Callable[[Callable], Callable]:
    """
    Overview:
        Decorator factory for documenting method overrides.
    Arguments:
        - cls (:obj:`type`): The superclass that provides the overridden method. If this \
          ``cls`` does not actually have the method, a ``NameError`` is raised.
    """

    def check_override(method: Callable) -> Callable:
        # ``dir`` covers inherited attributes too, so overriding a grandparent method passes.
        if method.__name__ in dir(cls):
            return method
        raise NameError("{} does not override any method of {}".format(method, cls))

    return check_override
def squeeze(data: object) -> object:
    """
    Overview:
        Squeeze data: unwrap a single-element tuple/list/dict into the bare element; \
        multi-element tuples/lists are normalised to a tuple; everything else passes through.
    Arguments:
        - data (:obj:`object`): The data to be squeezed.
    Example:
        >>> a = (4, )
        >>> a = squeeze(a)
        >>> print(a)
        >>> 4
    """
    if isinstance(data, (tuple, list)):
        return data[0] if len(data) == 1 else tuple(data)
    if isinstance(data, dict) and len(data) == 1:
        return next(iter(data.values()))
    return data
# Keys that have already fallen back to a default, so each key is warned about only once.
default_get_set = set()


def default_get(
        data: dict,
        name: str,
        default_value: Optional[Any] = None,
        default_fn: Optional[Callable] = None,
        judge_fn: Optional[Callable] = None
) -> Any:
    """
    Overview:
        Get ``data[name]`` if ``name`` exists in ``data``; otherwise produce a fallback \
        value with ``default_fn`` (preferred) or ``default_value``, optionally validate it \
        with ``judge_fn``, log a one-time warning per key, and return it.
    Arguments:
        - data (:obj:`dict`): The input data dictionary.
        - name (:obj:`str`): The key to look up.
        - default_value (:obj:`Optional[Any]`): The fallback value used when ``name`` is missing.
        - default_fn (:obj:`Optional[Callable]`): A zero-argument factory for the fallback \
          value; takes priority over ``default_value`` when both are given.
        - judge_fn (:obj:`Optional[Callable]`): A predicate that must accept the fallback value.
    Returns:
        - ret (:obj:`Any`): ``data[name]`` if present, otherwise the (validated) fallback value.
    """
    if name in data:
        return data[name]
    assert default_value is not None or default_fn is not None
    value = default_fn() if default_fn is not None else default_value
    if judge_fn:
        assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
    if name not in default_get_set:
        # Warn only the first time each key falls back to its default.
        logging.warning("{} use default value {}".format(name, value))
        default_get_set.add(name)
    return value
def list_split(data: list, step: int) -> List[list]:
    """
    Overview:
        Split a list into consecutive chunks of size ``step``.
    Arguments:
        - data (:obj:`list`): The list of data to split.
        - step (:obj:`int`): The chunk size.
    Returns:
        - ret (:obj:`list`): The list of full ``step``-sized chunks.
        - residual (:obj:`list`): The leftover elements; ``None`` when ``data`` divides evenly \
          (and the whole input when ``len(data) < step``).
    Example:
        >>> list_split([1,2,3,4],2)
        ([[1, 2], [3, 4]], None)
        >>> list_split([1,2,3,4],3)
        ([[1, 2, 3]], [4])
    """
    if len(data) < step:
        return [], data
    chunks = [data[i:i + step] for i in range(0, len(data) - step + 1, step)]
    leftover_start = len(chunks) * step
    residual = data[leftover_start:] if leftover_start < len(data) else None
    return chunks, residual
def error_wrapper(fn, default_ret, warning_msg=""):
    """
    Overview:
        Wrap the function so that any exception raised inside it is caught and \
        ``default_ret`` is returned instead.
    Arguments:
        - fn (:obj:`Callable`): The function to be wrapped.
        - default_ret (:obj:`Any`): The default return value when an exception occurs.
        - warning_msg (:obj:`str`): Optional message logged (once) when the fallback is used.
    Returns:
        - wrapper (:obj:`Callable`): The wrapped function.
    Examples:
        >>> # Used to check for FakeLink (Refer to utils.linklink_dist_helper.py)
        >>> def get_rank():  # Get the rank of linklink model, return 0 if use FakeLink.
        >>>     if is_fake_link:
        >>>         return 0
        >>>     return error_wrapper(link.get_rank, 0)()
    """

    def wrapper(*args, **kwargs):
        try:
            ret = fn(*args, **kwargs)
        except Exception as e:
            ret = default_ret
            if warning_msg != "":
                # ``one_time_warning`` takes a single message argument, so join the parts
                # here (the previous two-argument call raised a TypeError).
                one_time_warning(warning_msg + "\ndefault_ret = {}\terror = {}".format(default_ret, e))
        return ret

    return wrapper
class LimitedSpaceContainer:
    """
    Overview:
        A simple counter simulating a bounded pool of space pieces.
    Interfaces:
        ``__init__``, ``get_residual_space``, ``release_space``
    """

    def __init__(self, min_val: int, max_val: int) -> None:
        """
        Overview:
            Store the bounds and initialize current usage ``cur`` to ``min_val``.
        Arguments:
            - min_val (:obj:`int`): Min volume of the container, usually 0.
            - max_val (:obj:`int`): Max volume of the container.
        """
        assert max_val >= min_val
        self.min_val = min_val
        self.max_val = max_val
        self.cur = min_val

    def get_residual_space(self) -> int:
        """
        Overview:
            Claim all remaining pieces at once; ``cur`` becomes ``max_val``.
        Returns:
            - ret (:obj:`int`): The residual space, i.e. ``max_val - cur`` before the claim.
        """
        remaining = self.max_val - self.cur
        self.cur = self.max_val
        return remaining

    def acquire_space(self) -> bool:
        """
        Overview:
            Try to claim one piece of space.
        Returns:
            - flag (:obj:`bool`): ``True`` if a piece was available and claimed, else ``False``.
        """
        has_room = self.cur < self.max_val
        if has_room:
            self.cur += 1
        return has_room

    def release_space(self) -> None:
        """
        Overview:
            Return one piece of space; ``cur`` is clamped so it never drops below ``min_val``.
        """
        self.cur = max(self.min_val, self.cur - 1)

    def increase_space(self) -> None:
        """
        Overview:
            Enlarge the container by one piece (increment ``max_val``).
        """
        self.max_val += 1

    def decrease_space(self) -> None:
        """
        Overview:
            Shrink the container by one piece (decrement ``max_val``).
        """
        self.max_val -= 1
def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
    """
    Overview:
        Merge two dicts by calling ``deep_update``; neither input is mutated.
    Arguments:
        - original (:obj:`dict`): Dict 1 (may be ``None``, treated as empty).
        - new_dict (:obj:`dict`): Dict 2 (may be ``None``, treated as empty).
    Returns:
        - merged_dict (:obj:`dict`): A new dict with ``new_dict`` deeply merged into ``original``.
    """
    base = copy.deepcopy(original or {})
    updates = new_dict or {}
    if updates:  # skip the merge when there is nothing to apply
        deep_update(base, updates, True, [])
    return base
def deep_update(
        original: dict,
        new_dict: dict,
        new_keys_allowed: bool = False,
        whitelist: Optional[List[str]] = None,
        override_all_if_type_changes: Optional[List[str]] = None
):
    """
    Overview:
        Update ``original`` in place with values from ``new_dict``, recursively.
    Arguments:
        - original (:obj:`dict`): Dictionary with default values.
        - new_dict (:obj:`dict`): Dictionary with values to be updated.
        - new_keys_allowed (:obj:`bool`): Whether keys absent from ``original`` are allowed.
        - whitelist (:obj:`Optional[List[str]]`): Top-level keys whose dict values may \
          receive new subkeys regardless of ``new_keys_allowed``.
        - override_all_if_type_changes (:obj:`Optional[List[str]]`): Top-level keys with \
          dict values that are replaced wholesale when their ``"type"`` entry changes.
    .. note::
        If a new key is introduced in ``new_dict`` while ``new_keys_allowed`` is False, \
        a ``RuntimeError`` is raised. For sub-dicts under whitelisted keys, new subkeys \
        may always be introduced.
    """
    whitelist = whitelist or []
    override_all_if_type_changes = override_all_if_type_changes or []
    for key, new_value in new_dict.items():
        if key not in original and not new_keys_allowed:
            raise RuntimeError("Unknown config parameter `{}`. Base config have: {}.".format(key, original.keys()))
        old_value = original.get(key)
        if isinstance(old_value, dict) and isinstance(new_value, dict):
            # If the "type" entry differs for a listed key, replace the whole sub-dict.
            type_changed = (
                key in override_all_if_type_changes and "type" in new_value and "type" in old_value
                and new_value["type"] != old_value["type"]
            )
            if type_changed:
                original[key] = new_value
            elif key in whitelist:
                # Whitelisted key -> new subkeys are always allowed.
                deep_update(old_value, new_value, True)
            else:
                deep_update(old_value, new_value, new_keys_allowed)
        else:
            # Either side is not a dict: override the entire value.
            original[key] = new_value
    return original
def flatten_dict(data: dict, delimiter: str = "/") -> dict:
    """
    Overview:
        Flatten a nested dict, joining nested keys with ``delimiter``; see example. \
        The input is not mutated (leaf values are deep-copied).
    Arguments:
        - data (:obj:`dict`): The original nested dict.
        - delimiter (:obj:`str`): Delimiter placed between joined key segments.
    Returns:
        - data (:obj:`dict`): The flattened dict.
    Example:
        >>> a
        {'a': {'b': 100}}
        >>> flatten_dict(a)
        {'a/b': 100}
    """
    data = copy.deepcopy(data)

    def _flatten(prefix, sub, out):
        # Depth-first walk: dict values recurse with an extended key, leaves are emitted.
        for key, value in sub.items():
            full_key = key if prefix is None else delimiter.join([prefix, key])
            if isinstance(value, dict):
                _flatten(full_key, value, out)
            else:
                out[full_key] = value
        return out

    return _flatten(None, data, {})
def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
    """
    Overview:
        Side-effect function that seeds ``random``, ``numpy.random`` and torch's manual \
        seed (plus CUDA when requested and available) in one call. Usually used in an \
        entry script to fix randomness for every package and instance.
    Arguments:
        - seed (:obj:`int`): The seed to set.
        - use_cuda (:obj:`bool`): Whether to also seed CUDA.
    Examples:
        >>> # ../entry/xxxenv_xxxpolicy_main.py
        >>> ...
        # Set random seed for all package and instance
        >>> collector_env.seed(seed)
        >>> evaluator_env.seed(seed, dynamic_seed=False)
        >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
        >>> ...
        # Set up RL Policy, etc.
        >>> ...
    """
    seeders = [random.seed, np.random.seed, torch.manual_seed]
    if use_cuda and torch.cuda.is_available():
        seeders.append(torch.cuda.manual_seed)
    for seed_fn in seeders:
        seed_fn(seed)
@lru_cache()
def one_time_warning(warning_msg: str) -> None:
    """
    Overview:
        Print a warning message only once per distinct message: ``lru_cache`` memoizes \
        the call, so repeated calls with the same message are no-ops. (Without the \
        decorator the function logged on every call, contradicting its contract; the \
        ``lru_cache`` import at the top of this file was left unused.)
    Arguments:
        - warning_msg (:obj:`str`): Warning message.
    """
    logging.warning(warning_msg)
def split_fn(data, indices, start, end):
    """
    Overview:
        Recursively slice ``data`` by ``indices[start:end]``: lists and dicts are \
        traversed element-wise, strings and ``None`` pass through unchanged, and any \
        other value is fancy-indexed with ``indices[start:end]``.
    Arguments:
        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): The data to split.
        - indices (:obj:`np.ndarray`): The permutation of sample indices.
        - start (:obj:`int`): Start position inside ``indices``.
        - end (:obj:`int`): End position inside ``indices``.
    """
    if data is None or isinstance(data, str):
        return data
    if isinstance(data, list):
        return [split_fn(elem, indices, start, end) for elem in data]
    if isinstance(data, dict):
        return {key: split_fn(value, indices, start, end) for key, value in data.items()}
    return data[indices[start:end]]
def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
    """
    Overview:
        Yield mini-batches of exactly ``split_size`` samples sliced from every leaf of ``data``.
    Arguments:
        - data (:obj:`dict`): The batch data to be split.
        - split_size (:obj:`int`): The number of samples per yielded mini-batch.
        - shuffle (:obj:`bool`): Whether to shuffle sample order before splitting.
    """
    assert isinstance(data, dict), type(data)
    lengths = []
    for key, value in data.items():
        if value is None:
            continue
        if key in ('prev_state', 'prev_actor_state', 'prev_critic_state'):
            # RNN states are stored as per-sample containers, so plain len works.
            lengths.append(len(value))
        elif isinstance(value, (list, tuple)):
            if isinstance(value[0], str):
                # some buffer data contains useless string infos, such as 'buffer_id',
                # which should not be split, so we just skip it
                continue
            lengths.append(get_shape0(value[0]))
        elif isinstance(value, dict):
            lengths.append(len(value[list(value.keys())[0]]))
        else:
            lengths.append(len(value))
    assert len(lengths) > 0
    # Lengths are deliberately not asserted equal: with continuous actions,
    # data['logit'] is a list of length 2, so only the first length is trusted.
    total = lengths[0]
    assert split_size >= 1
    indices = np.random.permutation(total) if shuffle else np.arange(total)
    for begin in range(0, total, split_size):
        # The final slice is shifted back so every batch has exactly split_size samples.
        if begin + split_size > total:
            begin = total - split_size
        yield split_fn(data, indices, begin, begin + split_size)
class RunningMeanStd(object):
    """
    Overview:
        Maintain the running mean and std of a data stream via batched updates.
    Interfaces:
        ``__init__``, ``update``, ``reset``, ``new_shape``
    Properties:
        - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
    """

    def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
        """
        Overview:
            Initialize ``self.`` See ``help(type(self))`` for accurate \
            signature; setup the properties.
        Arguments:
            - epsilon (:obj:`float`): Initial pseudo-count; also keeps std well-defined \
              before any update.
            - shape (:obj:`tuple`): The array shape of the tracked mean/variance; an \
              empty tuple means scalar statistics.
            - device (:obj:`torch.device`): Device on which ``mean``/``std`` tensors are returned.
        """
        self._epsilon = epsilon
        self._shape = shape
        self._device = device
        self.reset()

    def update(self, x):
        """
        Overview:
            Update mean, variance, and count with a new batch (first axis of ``x`` is batch).
        Arguments:
            - x (:obj:`np.ndarray`): The batch of new samples.
        """
        batch_mean = np.mean(x, axis=0)
        batch_var = np.var(x, axis=0)
        batch_count = x.shape[0]

        new_count = batch_count + self._count
        mean_delta = batch_mean - self._mean
        new_mean = self._mean + mean_delta * batch_count / new_count
        # this method for calculating new variable might be numerically unstable
        m_a = self._var * self._count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
        new_var = m2 / new_count
        self._mean = new_mean
        self._var = new_var
        self._count = new_count

    def reset(self):
        """
        Overview:
            Reset the tracked statistics: ``_mean``, ``_var``, ``_count``.
        """
        if len(self._shape) > 0:
            self._mean = np.zeros(self._shape, 'float32')
            self._var = np.ones(self._shape, 'float32')
        else:
            self._mean, self._var = 0., 1.
        self._count = self._epsilon

    @property
    def mean(self) -> np.ndarray:
        """
        Overview:
            Property ``mean``: the running mean, as a scalar or a tensor on ``self._device``. \
            (The class docstring declares this a property; the decorator was missing, so \
            ``.mean`` returned a bound method instead of the value.)
        """
        if np.isscalar(self._mean):
            return self._mean
        else:
            return torch.FloatTensor(self._mean).to(self._device)

    @property
    def std(self) -> np.ndarray:
        """
        Overview:
            Property ``std``: running standard deviation computed from ``self._var`` with \
            a small epsilon for numerical stability.
        """
        std = np.sqrt(self._var + 1e-8)
        if np.isscalar(std):
            return std
        else:
            return torch.FloatTensor(std).to(self._device)

    @staticmethod
    def new_shape(obs_shape, act_shape, rew_shape):
        """
        Overview:
            Get new shape of observation, action, and reward; in this case unchanged. \
            (Takes no ``self``/``cls``, so it is marked ``@staticmethod`` to also work \
            when called on an instance.)
        Arguments:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        Returns:
            obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
        """
        return obs_shape, act_shape, rew_shape
def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Rewrite dict keys into legal python identifier strings so that they are \
        compatible with python magic methods such as ``__getattr__``.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
    """

    def _legalize(key: str) -> str:
        # Identifiers cannot start with a digit and cannot contain dots.
        prefixed = '_' + key if key[0].isdigit() else key
        return prefixed.replace('.', '_')

    return {_legalize(key): value for key, value in data.items()}
def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
    """
    Overview:
        Remove illegal items in dict info, i.e. ``str`` values, which are not compatible \
        with Tensor.
    Arguments:
        - data (:obj:`Dict[str, Any]`): The original dict data.
    Return:
        - new_data (:obj:`Dict[str, Any]`): The new dict data without illegal (``str``) items.
    """
    return {key: value for key, value in data.items() if not isinstance(value, str)}