import logging
import random

from s3prl.dataio.sampler import SortedBucketingSampler, SortedSliceSampler

logger = logging.getLogger(__name__)


def test_sorted_slice_sampler():
    batch_size = 16
    max_length = 16000 * 5
    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(1000)]

    sampler = SortedSliceSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
    )

    for epoch in range(5):
        sampler.set_epoch(epoch)
        id2length = lengths

        for batch_ids in sampler:
            # Each batch should be sorted from the longest item to the shortest
            batch_lengths = [id2length[idx] for idx in batch_ids]
            assert sorted(batch_lengths, reverse=True) == batch_lengths

            # Batches whose longest item exceeds max_length are halved
            if batch_lengths[0] > max_length:
                assert len(batch_lengths) == batch_size // 2

        # Irregular batch sizes (neither full nor halved) should not repeat
        other_batch_sizes = [
            len(batch)
            for batch in sampler
            if len(batch) not in [batch_size, batch_size // 2]
        ]
        assert len(set(other_batch_sizes)) == len(other_batch_sizes)

    # The slice sampler starts one batch per item, so its length equals
    # the number of items
    assert len(sampler) == len(lengths)


def test_sorted_bucketing_sampler():
    batch_size = 16
    max_length = 16000 * 5
    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(1000)]

    sampler = SortedBucketingSampler(
        lengths,
        batch_size=batch_size,
        max_length=max_length,
        shuffle=False,
    )

    for epoch in range(5):
        sampler.set_epoch(epoch)
        id2length = lengths

        for batch_ids in sampler:
            batch_lengths = [id2length[idx] for idx in batch_ids]
            assert sorted(batch_lengths, reverse=True) == batch_lengths
            if batch_lengths[0] > max_length:
                assert len(batch_lengths) == batch_size // 2

        # At most the final (remainder) batch may have an irregular size.
        # The loop variable must not shadow the outer batch_size, or the
        # filter below would always be empty and the assertion vacuous.
        batch_sizes = [len(batch_indices) for batch_indices in sampler]
        other_batch_sizes = [
            size for size in batch_sizes if size not in [batch_size, batch_size // 2]
        ]
        assert len(other_batch_sizes) <= 1

    # Bucketing packs 8 to 16 items per batch, so the number of batches
    # falls strictly between len(lengths) / 16 and len(lengths) / 8
    assert len(lengths) / 16 < len(sampler) < len(lengths) / 8
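
# --- Usage sketch (an assumption for illustration, not part of the test suite) ---
# Both samplers iterate over lists of dataset indices, so they should plug into
# a torch.utils.data.DataLoader through its batch_sampler argument. WaveDataset
# below is a hypothetical toy dataset used only to make the sketch runnable.


def _demo_dataloader_integration():
    import torch
    from torch.utils.data import DataLoader, Dataset

    class WaveDataset(Dataset):
        """Hypothetical dataset returning variable-length 1-D waveforms."""

        def __init__(self, lengths):
            self.lengths = lengths

        def __len__(self):
            return len(self.lengths)

        def __getitem__(self, idx):
            return torch.zeros(self.lengths[idx])

    lengths = [random.randint(16000 * 3, 16000 * 8) for _ in range(100)]
    dataset = WaveDataset(lengths)
    sampler = SortedBucketingSampler(
        lengths,
        batch_size=4,
        max_length=16000 * 5,
        shuffle=True,
    )
    sampler.set_epoch(0)

    # batch_sampler already yields index lists, so batch_size stays unset here;
    # collate_fn=list keeps the variable-length waves as a plain Python list
    # instead of trying to stack them into one tensor
    loader = DataLoader(dataset, batch_sampler=sampler, collate_fn=list)
    for batch in loader:
        logger.info("batch of %d waves", len(batch))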