Spaces:
Sleeping
Sleeping
import random | |
from typing import Callable, Union, List | |
from ding.data.buffer import BufferedData | |
from ding.utils import fastcopy | |
def padding(policy: str = "random"):
    """
    Overview:
        Fill the nested buffer list to the same size as the largest list.
        The default policy `random` will randomly select data from each group
        and fill it into the current group list.
    Arguments:
        - policy (:obj:`str`): Padding policy, supports `random`, `none`.
            `random` pads a group with copies (via `fastcopy.copy`) of that group's
            original members, drawn uniformly with replacement.
            `none` pads with `BufferedData(data=None, index=None, meta=None)` records.
    Returns:
        - middleware (:obj:`Callable`): A middleware `(action, chain, *args, **kwargs)`.
            For the `sample` action, grouped results are padded in place. Every other
            action is forwarded to `chain` unchanged.
    Raises:
        - ValueError: If `policy` is not one of the supported values.
    """
    # Fail fast: an unknown policy used to leave groups silently unpadded.
    if policy not in ("random", "none"):
        raise ValueError(f"Unsupported padding policy {policy!r}; expected 'random' or 'none'.")

    def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
        sampled_data = chain(*args, **kwargs)
        # A flat (ungrouped) result, or an empty one, has nothing to pad.
        if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
            return sampled_data
        max_len = max(len(group) for group in sampled_data)
        for group in sampled_data:
            missing = max_len - len(group)
            if missing == 0:
                continue
            if policy == "random":
                # Draw every pad from the group's original members before extending.
                # This keeps earlier pads from raising their source's odds of being
                # picked again. An empty group raises IndexError: nothing to copy.
                picks = random.choices(group, k=missing)
                group.extend(fastcopy.copy(item) for item in picks)
            else:
                group.extend(BufferedData(data=None, index=None, meta=None) for _ in range(missing))
        # Groups are padded in place, so this is the same list `chain` returned.
        return sampled_data

    def _padding(action: str, chain: Callable, *args, **kwargs):
        # Only `sample` results are padded; all other buffer actions pass through.
        if action == "sample":
            return sample(chain, *args, **kwargs)
        return chain(*args, **kwargs)

    return _padding