File size: 2,371 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import warnings
from itertools import chain
import pytest
from sklearn import config_context
from sklearn.utils._chunking import gen_even_slices, get_chunk_n_rows
from sklearn.utils._testing import assert_array_equal
def test_gen_even_slices():
# check that gen_even_slices contains all samples
some_range = range(10)
joined_range = list(chain(*[some_range[slice] for slice in gen_even_slices(10, 3)]))
assert_array_equal(some_range, joined_range)
@pytest.mark.parametrize(
("row_bytes", "max_n_rows", "working_memory", "expected"),
[
(1024, None, 1, 1024),
(1024, None, 0.99999999, 1023),
(1023, None, 1, 1025),
(1025, None, 1, 1023),
(1024, None, 2, 2048),
(1024, 7, 1, 7),
(1024 * 1024, None, 1, 1),
],
)
def test_get_chunk_n_rows(row_bytes, max_n_rows, working_memory, expected):
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
actual = get_chunk_n_rows(
row_bytes=row_bytes,
max_n_rows=max_n_rows,
working_memory=working_memory,
)
assert actual == expected
assert type(actual) is type(expected)
with config_context(working_memory=working_memory):
with warnings.catch_warnings():
warnings.simplefilter("error", UserWarning)
actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows)
assert actual == expected
assert type(actual) is type(expected)
def test_get_chunk_n_rows_warns():
"""Check that warning is raised when working_memory is too low."""
row_bytes = 1024 * 1024 + 1
max_n_rows = None
working_memory = 1
expected = 1
warn_msg = (
"Could not adhere to working_memory config. Currently 1MiB, 2MiB required."
)
with pytest.warns(UserWarning, match=warn_msg):
actual = get_chunk_n_rows(
row_bytes=row_bytes,
max_n_rows=max_n_rows,
working_memory=working_memory,
)
assert actual == expected
assert type(actual) is type(expected)
with config_context(working_memory=working_memory):
with pytest.warns(UserWarning, match=warn_msg):
actual = get_chunk_n_rows(row_bytes=row_bytes, max_n_rows=max_n_rows)
assert actual == expected
assert type(actual) is type(expected)
|