import logging

import pytest

from s3prl.dataio.sampler import (
    DistributedBatchSamplerWrapper,
    FixedBatchSizeBatchSampler,
    MaxTimestampBatchSampler,
)

logger = logging.getLogger(__name__)


def _merge_batch_indices(batch_indices):
    """Flatten a list of batches into a single flat list of indices."""
    all_indices = []
    for indices in batch_indices:
        all_indices += indices
    return all_indices


@pytest.mark.parametrize("world_size", [1, 2, 3, 4, 5, 6, 7, 8])
def test_distributed_sampler(world_size):
    # Batches of unequal sizes, so uneven splits across ranks get exercised.
    sampler = [[1, 2, 3], [4, 5, 6, 7], [8], [9, 10]]
    ddp_indices = []
    for rank in range(world_size):
        ddp_sampler = DistributedBatchSamplerWrapper(sampler, world_size, rank)
        ddp_indices += _merge_batch_indices(ddp_sampler)
    # Taken together, the per-rank samplers must cover exactly the indices
    # of the original sampler: nothing dropped, nothing duplicated.
    assert sorted(ddp_indices) == sorted(_merge_batch_indices(sampler))


timestamps = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]


@pytest.mark.parametrize("batch_size", [1, 2, 3, len(data)])
def test_FixedBatchSizeBatchSampler(batch_size):
    dataset = data
    # Shuffling may reorder batches but must never drop or duplicate indices.
    iter1 = list(FixedBatchSizeBatchSampler(dataset, batch_size, shuffle=False))
    iter2 = list(FixedBatchSizeBatchSampler(dataset, batch_size, shuffle=True))
    indices1 = sorted(_merge_batch_indices(iter1))
    indices2 = sorted(_merge_batch_indices(iter2))
    assert indices1 == indices2 == list(range(len(dataset)))
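

# MaxTimestampBatchSampler is imported above but never exercised, and the
# `timestamps` list appears to exist for exactly this purpose, so a minimal
# sketch of the missing test follows. NOTE: the positional
# (timestamps, max_timestamp) call is an assumption mirroring the samplers
# above; verify it against the actual
# s3prl.dataio.sampler.MaxTimestampBatchSampler API before relying on this.
@pytest.mark.parametrize("max_timestamp", [10, 14, sum(timestamps)])
def test_MaxTimestampBatchSampler(max_timestamp):
    sampler = MaxTimestampBatchSampler(timestamps, max_timestamp)
    batch_indices = list(sampler)
    # Every index must be scheduled exactly once across all batches.
    assert sorted(_merge_batch_indices(batch_indices)) == list(
        range(len(timestamps))
    )
    # No batch may exceed the timestamp budget. The plain sum is a lower
    # bound on any padded-batch cost, so this check holds whether the
    # sampler accounts by sum or by max-length padding.
    for batch in batch_indices:
        assert sum(timestamps[i] for i in batch) <= max_timestamp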