File size: 780 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
from collections import Counter

import pytest

from s3prl.dataio.sampler import BalancedWeightedSampler


@pytest.mark.parametrize("duplicate", [10000, 100000])
def test_balanced_weighted_sampler(duplicate: int):
    labels = ["a", "a", "b", "a"]
    batch_size = 5
    prev_diff_ratio = 1.0
    sampler = BalancedWeightedSampler(
        labels, batch_size=batch_size, duplicate=duplicate, seed=0
    )
    indices = list(sampler)
    assert len(indices[0]) == batch_size

    counter = Counter()
    for batch_indices in indices:
        for idx in batch_indices:
            counter.update(labels[idx])

    diff_ratio = abs(counter["a"] - counter["b"]) / duplicate * len(labels)
    assert diff_ratio < prev_diff_ratio
    prev_diff_ratio = diff_ratio

    diff_ratio < 0.05