import os
import itertools
import random
import uuid
from ditk import logging
import hickle
from typing import Any, Iterable, Iterator, List, Optional, Tuple, Union
from collections import Counter, defaultdict, deque, OrderedDict
from ding.data.buffer import Buffer, apply_middleware, BufferedData
from ding.utils import fastcopy
from ding.torch_utils import get_null_data


class BufferIndex:
    """
    Overview:
        Save index strings and their storage offsets as key-value pairs.
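    Examples (an illustrative sketch of the eviction behaviour):
        >>> index = BufferIndex(maxlen=2)
        >>> index.append("a")
        >>> index.append("b")
        >>> index.append("c")
        >>> index.has("a")  # "a" has been evicted, only the last ``maxlen`` keys remain
        False
        >>> index.get("c")  # physical offset of "c" in the backing deque
        1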
""" | |
def __init__(self, maxlen: int, *args, **kwargs): | |
self.maxlen = maxlen | |
self.__map = OrderedDict(*args, **kwargs) | |
self._last_key = next(reversed(self.__map)) if len(self) > 0 else None | |
self._cumlen = len(self.__map) | |
    def get(self, key: str) -> int:
        value = self.__map[key]
        # Convert the stored monotonic position into the current physical offset in the
        # backing deque, compensating for items evicted from the front once full.
        value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
        return value

    def __len__(self) -> int:
        return len(self.__map)

    def has(self, key: str) -> bool:
        return key in self.__map

    def append(self, key: str):
        self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
        self._last_key = key
        self._cumlen += 1
        # Evict the oldest key once capacity is exceeded
        if len(self) > self.maxlen:
            self.__map.popitem(last=False)

    def clear(self):
        self.__map = OrderedDict()
        self._last_key = None
        self._cumlen = 0


class DequeBuffer(Buffer):
    """
    Overview:
        A buffer implementation based on the deque structure.
    """

    def __init__(self, size: int, sliced: bool = False) -> None:
        """
        Overview:
            The initialization method of DequeBuffer.
        Arguments:
            - size (:obj:`int`): The maximum number of objects that the buffer can hold.
            - sliced (:obj:`bool`): Whether to slice data into pieces of length ``unroll_len``\
                when sampling by group. Default to ``False``.
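        Examples (an illustrative sketch of the FIFO eviction):
            >>> buffer = DequeBuffer(size=2)
            >>> for i in range(3):
            ...     _ = buffer.push(i)
            >>> [b.data for b in buffer]
            [1, 2]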
""" | |
super().__init__(size=size) | |
self.storage = deque(maxlen=size) | |
self.indices = BufferIndex(maxlen=size) | |
self.sliced = sliced | |
# Meta index is a dict which uses deque as values | |
self.meta_index = {} | |
    def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        """
        Overview:
            Push an object and its related meta information into the buffer.
        Arguments:
            - data (:obj:`Any`): The input object, which can be in any format.
            - meta (:obj:`Optional[dict]`): A dict that helps describe the data, such as\
                category, label, priority, etc. Default to ``None``.
        Returns:
            - buffered_data (:obj:`BufferedData`): The pushed data together with its generated index and meta.
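        Examples (illustrative; any picklable object can serve as the payload):
            >>> buffer = DequeBuffer(size=10)
            >>> buffered = buffer.push({"obs": 0}, meta={"priority": 1.0})
            >>> buffer.count()
            1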
""" | |
return self._push(data, meta) | |
    def sample(
            self,
            size: Optional[int] = None,
            indices: Optional[List[str]] = None,
            replace: bool = False,
            sample_range: Optional[slice] = None,
            ignore_insufficient: bool = False,
            groupby: Optional[str] = None,
            unroll_len: Optional[int] = None
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Randomly sample data from the buffer, or retrieve specific data by indices.
        Arguments:
            - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
                If ``indices`` is not specified, ``size`` is required to randomly sample the\
                corresponding number of objects from the buffer.
            - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
                Default to ``None``.
            - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
                determines whether previous samples will be put back into the buffer for subsequent\
                sampling. Default to ``False``, which means that duplicate samples will not appear in one\
                ``sample`` call.
            - sample_range (:obj:`Optional[slice]`): The range of indices to sample from. Default to ``None``,\
                which means no restriction on the range of indices for the sampling process.
            - ignore_insufficient (:obj:`bool`): Whether to ignore the case where the sampled size is smaller\
                than the required ``size``; if ``False``, a ``ValueError`` will be raised in that case.\
                Default to ``False``.
            - groupby (:obj:`Optional[str]`): If this parameter is activated, the method will return\
                ``size`` groups of objects, grouped by the given meta key.
            - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when\
                ``groupby`` is activated.
        Returns:
            - sampled_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): The sampling result.
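        Examples (illustrative):
            >>> buffer = DequeBuffer(size=10)
            >>> for i in range(5):
            ...     _ = buffer.push(i, meta={"env": "A" if i < 3 else "B"})
            >>> batch = buffer.sample(size=2)                    # two distinct records
            >>> groups = buffer.sample(size=2, groupby="env")    # one list of records per "env" value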
""" | |
storage = self.storage | |
if sample_range: | |
storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step)) | |
# Size and indices | |
assert size or indices, "One of size and indices must not be empty." | |
if (size and indices) and (size != len(indices)): | |
raise AssertionError("Size and indices length must be equal.") | |
if not size: | |
size = len(indices) | |
# Indices and groupby | |
assert not (indices and groupby), "Cannot use groupby and indicex at the same time." | |
# Groupby and unroll_len | |
assert not unroll_len or ( | |
unroll_len and groupby | |
), "Parameter unroll_len needs to be used in conjunction with groupby." | |
value_error = None | |
sampled_data = [] | |
if indices: | |
indices_set = set(indices) | |
hashed_data = filter(lambda item: item.index in indices_set, storage) | |
hashed_data = map(lambda item: (item.index, item), hashed_data) | |
hashed_data = dict(hashed_data) | |
# Re-sample and return in indices order | |
sampled_data = [hashed_data[index] for index in indices] | |
elif groupby: | |
sampled_data = self._sample_by_group( | |
size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced | |
) | |
else: | |
if replace: | |
sampled_data = random.choices(storage, k=size) | |
else: | |
try: | |
sampled_data = random.sample(storage, k=size) | |
except ValueError as e: | |
value_error = e | |
if value_error or len(sampled_data) != size: | |
if ignore_insufficient: | |
logging.warning( | |
"Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}". | |
format(self.count(), size) | |
) | |
else: | |
raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count())) | |
sampled_data = self._independence(sampled_data) | |
return sampled_data | |
    def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
        """
        Overview:
            Update the data and related meta information at a certain index.
        Arguments:
            - index (:obj:`str`): The index string of the record to be updated.
            - data (:obj:`Optional[Any]`): The data which is supposed to replace the old one. If you set it\
                to ``None``, nothing will happen to the old record.
            - meta (:obj:`Optional[dict]`): The new dict which is supposed to merge with the old one.
        Returns:
            - success (:obj:`bool`): Whether a record with the given index exists and has been updated.
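        Examples (illustrative; assumes ``buffer`` is a ``DequeBuffer`` instance):
            >>> buffered = buffer.push("old")
            >>> buffer.update(buffered.index, data="new", meta={"tag": 1})
            True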
""" | |
if not self.indices.has(index): | |
return False | |
i = self.indices.get(index) | |
item = self.storage[i] | |
if data is not None: | |
item.data = data | |
if meta is not None: | |
item.meta = meta | |
for key in self.meta_index: | |
self.meta_index[key][i] = meta[key] if key in meta else None | |
return True | |
    def delete(self, indices: Union[str, Iterable[str]]) -> None:
        """
        Overview:
            Delete data and the related meta information by the given indices.
        Arguments:
            - indices (:obj:`Union[str, Iterable[str]]`): The indices of the data to be deleted from the buffer.
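        Examples (illustrative; assumes ``buffer`` is a ``DequeBuffer`` instance):
            >>> buffered = buffer.push("to be removed")
            >>> buffer.delete(buffered.index)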
""" | |
if isinstance(indices, str): | |
indices = [indices] | |
del_idx = [] | |
for index in indices: | |
if self.indices.has(index): | |
del_idx.append(self.indices.get(index)) | |
if len(del_idx) == 0: | |
return | |
del_idx = sorted(del_idx, reverse=True) | |
for idx in del_idx: | |
del self.storage[idx] | |
remain_indices = [item.index for item in self.storage] | |
key_value_pairs = zip(remain_indices, range(len(indices))) | |
self.indices = BufferIndex(self.storage.maxlen, key_value_pairs) | |
    def save_data(self, file_name: str):
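        """
        Overview:
            Save the buffer's storage, indices and meta index to a file via ``hickle``.
        Arguments:
            - file_name (:obj:`str`): The path of the target file. A missing parent folder\
                will be created automatically.
        Examples (illustrative; assumes the stored payloads are serializable by ``hickle``):
            >>> buffer.save_data("./data/buffer.hkl")
        """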
        if not os.path.exists(os.path.dirname(file_name)):
            # If the folder for the specified file does not exist, it will be created.
            if os.path.dirname(file_name) != "":
                os.makedirs(os.path.dirname(file_name))
        hickle.dump(
            py_obj=(
                self.storage,
                self.indices,
                self.meta_index,
            ), file_obj=file_name
        )
    def load_data(self, file_name: str):
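        """
        Overview:
            Restore the buffer's storage, indices and meta index from a file produced by ``save_data``.
        Arguments:
            - file_name (:obj:`str`): The path of the source file.
        Examples (illustrative; assumes the file was previously written by ``save_data``):
            >>> buffer.load_data("./data/buffer.hkl")
        """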
        self.storage, self.indices, self.meta_index = hickle.load(file_name)
    def count(self) -> int:
        """
        Overview:
            The method that returns the current length of the buffer.
        """
        return len(self.storage)

    def get(self, idx: int) -> BufferedData:
        """
        Overview:
            The method that returns the ``BufferedData`` object at position ``idx`` in the storage\
            (a positional offset, not the uuid index string used by ``sample`` and ``update``).
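        Examples (illustrative; assumes the buffer is non-empty):
            >>> first = buffer.get(0)  # positional access into the underlying deque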
""" | |
return self.storage[idx] | |
def clear(self) -> None: | |
""" | |
Overview: | |
The method that clear all data, indices, and the meta information in the buffer. | |
""" | |
self.storage.clear() | |
self.indices.clear() | |
self.meta_index = {} | |
    def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
        index = uuid.uuid1().hex
        if meta is None:
            meta = {}
        buffered = BufferedData(data=data, index=index, meta=meta)
        self.storage.append(buffered)
        self.indices.append(index)
        # Add meta index
        for key in self.meta_index:
            self.meta_index[key].append(meta[key] if key in meta else None)
        return buffered
    def _independence(
            self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
    ) -> Union[List[BufferedData], List[List[BufferedData]]]:
        """
        Overview:
            Make sure that each record in the result is a distinct object. Note that this function\
            differs from ``clone_object``: records that occur only once still reference the buffer's\
            data, so modifying such a record may change the data in the buffer.
        Arguments:
            - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): Sampled data,\
                which can be nested if ``groupby`` has been set.
        """
        if len(buffered_samples) == 0:
            return buffered_samples
        occurred = defaultdict(int)

        for i, buffered in enumerate(buffered_samples):
            if isinstance(buffered, list):
                sampled_list = buffered
                # Loop over nested samples
                for j, buffered in enumerate(sampled_list):
                    occurred[buffered.index] += 1
                    if occurred[buffered.index] > 1:
                        sampled_list[j] = fastcopy.copy(buffered)
            elif isinstance(buffered, BufferedData):
                occurred[buffered.index] += 1
                if occurred[buffered.index] > 1:
                    buffered_samples[i] = fastcopy.copy(buffered)
            else:
                raise Exception("Get unexpected buffered type {}".format(type(buffered)))
        return buffered_samples
    def _sample_by_group(
            self,
            size: int,
            groupby: str,
            replace: bool = False,
            unroll_len: Optional[int] = None,
            storage: deque = None,
            sliced: bool = False
    ) -> List[List[BufferedData]]:
        """
        Overview:
            Sample by ``group`` instead of by individual records. The result is a collection of\
            ``size`` lists, and the lengths of these lists may differ from each other.
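        Examples (illustrative; assumes every record's meta contains an "env" key with at least\
            two distinct values):
            >>> groups = buffer._sample_by_group(size=2, groupby="env")
            >>> len(groups)
            2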
""" | |
if storage is None: | |
storage = self.storage | |
if groupby not in self.meta_index: | |
self._create_index(groupby) | |
def filter_by_unroll_len(): | |
"Filter groups by unroll len, ensure count of items in each group is greater than unroll_len." | |
group_count = Counter(self.meta_index[groupby]) | |
group_names = [] | |
for key, count in group_count.items(): | |
if count >= unroll_len: | |
group_names.append(key) | |
return group_names | |
if unroll_len and unroll_len > 1: | |
group_names = filter_by_unroll_len() | |
if len(group_names) == 0: | |
return [] | |
else: | |
group_names = list(set(self.meta_index[groupby])) | |
sampled_groups = [] | |
if replace: | |
sampled_groups = random.choices(group_names, k=size) | |
else: | |
try: | |
sampled_groups = random.sample(group_names, k=size) | |
except ValueError: | |
raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names))) | |
# Build dict like {"group name": [records]} | |
sampled_data = defaultdict(list) | |
for buffered in storage: | |
meta_value = buffered.meta[groupby] if groupby in buffered.meta else None | |
if meta_value in sampled_groups: | |
sampled_data[buffered.meta[groupby]].append(buffered) | |
final_sampled_data = [] | |
for group in sampled_groups: | |
seq_data = sampled_data[group] | |
# Filter records by unroll_len | |
if unroll_len: | |
# slice b unroll_len. If don’t do this, more likely obtain duplicate data, \ | |
# and the training will easily crash. | |
if sliced: | |
start_indice = random.choice(range(max(1, len(seq_data)))) | |
start_indice = start_indice // unroll_len | |
if start_indice == (len(seq_data) - 1) // unroll_len: | |
seq_data = seq_data[-unroll_len:] | |
else: | |
seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len] | |
else: | |
start_indice = random.choice(range(max(1, len(seq_data) - unroll_len))) | |
seq_data = seq_data[start_indice:start_indice + unroll_len] | |
final_sampled_data.append(seq_data) | |
return final_sampled_data | |
    def _create_index(self, meta_key: str):
        self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
        for data in self.storage:
            self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)

    def __iter__(self) -> Iterator[BufferedData]:
        return iter(self.storage)

    def __copy__(self) -> "DequeBuffer":
        buffer = type(self)(size=self.storage.maxlen)
        buffer.storage = self.storage
        buffer.meta_index = self.meta_index
        buffer.indices = self.indices
        return buffer