# lmzjms's picture
# Upload 1162 files
# 0b32ad6 verified
import logging
import pytest
from s3prl.dataio.sampler import (
DistributedBatchSamplerWrapper,
FixedBatchSizeBatchSampler,
MaxTimestampBatchSampler,
)
# Module-level logger named after this module; nothing in the visible part of
# the file calls it.
logger = logging.getLogger(__name__)
def _merge_batch_indices(batch_indices):
    """Flatten an iterable of index batches into a single list.

    Args:
        batch_indices: An iterable whose items are themselves iterables of
            indices (e.g. a list of lists, or a batch sampler).

    Returns:
        A new list holding every index from every batch, in iteration order.
        Duplicates are kept; an empty input yields an empty list.
    """
    return [index for indices in batch_indices for index in indices]
@pytest.mark.parametrize("world_size", [1, 2, 3, 4, 5, 6, 7, 8])
def test_distributed_sampler(world_size):
    """Check that the per-rank shards together cover the wrapped sampler.

    A DistributedBatchSamplerWrapper is built for every rank in
    ``range(world_size)`` over the same list of batches. The indices
    collected across all ranks must match, as a sorted list, the indices of
    the original batches.
    """
    sampler = [[1, 2, 3], [4, 5, 6, 7], [8], [9, 10]]

    ddp_indices = []
    for rank in range(world_size):
        ddp_sampler = DistributedBatchSamplerWrapper(sampler, world_size, rank)
        ddp_indices += _merge_batch_indices(ddp_sampler)

    # Sorting makes the comparison independent of how batches were assigned
    # to ranks while still detecting any dropped or repeated index.
    assert sorted(ddp_indices) == sorted(_merge_batch_indices(sampler))
# Ten-element fixture lists shared by the sampler tests in this module.
# NOTE(review): `timestamps` is presumably the per-item timestamp input for a
# timestamp-based sampler test that is not visible here — confirm.
timestamps = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# `data` is the dataset handed to FixedBatchSizeBatchSampler below; its length
# is also one of the parametrized batch sizes for that test.
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
@pytest.mark.parametrize("batch_size", [1, 2, 3, len(data)])
def test_FixedBatchSizeBatchSampler(batch_size):
    """Check that FixedBatchSizeBatchSampler yields every dataset index once.

    The sampler is built twice over the same dataset, once with
    ``shuffle=False`` and once with ``shuffle=True``. After flattening and
    sorting, both runs must equal ``0 .. len(dataset) - 1``.
    """
    dataset = data

    sequential_batches = list(
        FixedBatchSizeBatchSampler(dataset, batch_size, shuffle=False)
    )
    shuffled_batches = list(
        FixedBatchSizeBatchSampler(dataset, batch_size, shuffle=True)
    )

    sequential_indices = sorted(_merge_batch_indices(sequential_batches))
    shuffled_indices = sorted(_merge_batch_indices(shuffled_batches))

    # Derive the expectation from the dataset that was actually sampled rather
    # than from another fixture that merely happens to have the same length.
    expected_indices = list(range(len(dataset)))

    # Two separate asserts so a failure reports which mode is wrong.
    assert sequential_indices == expected_indices
    assert shuffled_indices == expected_indices