import pytest

from sklearn.utils._mask import safe_mask
from sklearn.utils.fixes import CSR_CONTAINERS
from sklearn.utils.validation import check_random_state
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
def test_safe_mask(csr_container):
    """Check that the output of ``safe_mask`` can index dense and sparse data.

    A 5-row random matrix is built in dense form and converted with
    ``csr_container``. A mask with three ``True`` entries is passed through
    ``safe_mask`` for each representation, and indexing with the result must
    select exactly three rows.

    Parameters
    ----------
    csr_container : callable
        Constructor taken from ``CSR_CONTAINERS`` that wraps the dense array.
    """
    random_state = check_random_state(0)
    X = random_state.rand(5, 4)
    X_csr = csr_container(X)
    mask = [False, False, True, True, True]

    # Dense input: the returned mask must select the three flagged rows.
    mask = safe_mask(X, mask)
    assert X[mask].shape[0] == 3

    # Container input: the mask returned above is deliberately fed back in,
    # so this call receives the already-processed mask, not the raw list.
    mask = safe_mask(X_csr, mask)
    assert X_csr[mask].shape[0] == 3