Spaces:
Running
Running
File size: 1,934 Bytes
67a9b5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
import numbers
from collections.abc import Sequence
import numpy as np
def split_by_num(x, num_splits, strict=True):
    """Split ``x`` into consecutive chunks of equal size.

    Args:
        x: a ``Sequence`` (list, tuple, str, ...) or ``np.ndarray``; it is
            sliced along its first axis.
        num_splits: a positive integer indicating the number of splits
        strict: if True, ``len(x)`` must be divisible by ``num_splits`` so
            that every chunk has the same length. If False, the chunk size
            is ``ceil(len(x) / num_splits)``; the last chunk may be shorter
            and fewer than ``num_splits`` chunks may be returned.

    Returns:
        A list of slices of ``x``. An empty ``x`` yields an empty list.

    References:
        numpy.split and numpy.array_split
    """
    # NB: np.ndarray is not Sequence
    assert isinstance(x, (Sequence, np.ndarray))
    assert isinstance(num_splits, numbers.Integral)
    # A zero count would divide by zero below and a negative one would
    # silently produce no chunks, so reject both up front.
    assert num_splits > 0, 'num_splits must be a positive integer'
    total = len(x)
    if strict:
        assert total % num_splits == 0
    # Ceiling division. Clamp to 1 so that an empty input does not hand
    # range() a zero step (which raises ValueError).
    split_size = max(1, (total + num_splits - 1) // num_splits)
    return [x[i: i + split_size] for i in range(0, total, split_size)]
def split_by_size(x, sizes):
    """Split ``x`` into consecutive pieces whose lengths are given by ``sizes``.

    Args:
        x: a ``Sequence`` or ``np.ndarray``
        sizes: a list or tuple of piece lengths; they must add up to
            ``len(x)``

    Returns:
        A list with one slice of ``x`` per entry of ``sizes``, in order.

    References:
        tf.split
        https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils/misc.py
    """
    # NB: np.ndarray is not Sequence
    assert isinstance(x, (Sequence, np.ndarray))
    assert isinstance(sizes, (list, tuple))
    assert sum(sizes) == len(x)
    # Turn the lengths into running boundaries: [0, s0, s0 + s1, ...]
    boundaries = [0]
    for length in sizes:
        boundaries.append(boundaries[-1] + length)
    return [x[lo: hi] for lo, hi in zip(boundaries, boundaries[1:])]
def split_by_slice(x, slices):
    """Cut ``x`` at the given positions and return the resulting pieces.

    Args:
        x: a ``Sequence`` or ``np.ndarray``
        slices: a list or tuple of cut positions. The first piece starts at
            index 0 and the last piece ends at ``len(x)``.

    Returns:
        A list of ``len(slices) + 1`` slices of ``x``.

    References:
        SliceLayer in Caffe, and numpy.split
    """
    # NB: np.ndarray is not Sequence
    assert isinstance(x, (Sequence, np.ndarray))
    assert isinstance(slices, (list, tuple))
    # Surround the cut positions with the implicit start and end of x.
    bounds = [0]
    bounds.extend(slices)
    bounds.append(len(x))
    return [x[lo: hi] for lo, hi in zip(bounds[:-1], bounds[1:])]
def split_by_ratio(x, ratios):
    """Split ``x`` into consecutive pieces proportional to ``ratios``.

    The ratios are normalised by their sum, accumulated, scaled by
    ``len(x)`` and rounded to obtain the cut positions.

    Args:
        x: a ``Sequence`` or ``np.ndarray``
        ratios: a list or tuple of relative weights, one per piece

    Returns:
        A list with one slice of ``x`` per entry of ``ratios``, in order.
    """
    # NB: np.ndarray is not Sequence
    assert isinstance(x, (Sequence, np.ndarray))
    assert isinstance(ratios, (list, tuple))
    total = sum(ratios)
    fractions = [weight / total for weight in ratios]
    bounds = []
    for count in range(len(fractions) + 1):
        # Cumulative share covered by the first `count` pieces.
        covered = sum(fractions[:count])
        bounds.append(int(round(len(x) * covered)))
    return [x[lo: hi] for lo, hi in zip(bounds, bounds[1:])]
|