File size: 5,340 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
"""Test the 20news downloader, if the data is available,
or if specifically requested via environment variable
(e.g. for CI jobs)."""

from functools import partial
from unittest.mock import patch

import numpy as np
import pytest
import scipy.sparse as sp

from sklearn.datasets.tests.test_common import (
    check_as_frame,
    check_pandas_dependency_message,
    check_return_X_y,
)
from sklearn.preprocessing import normalize
from sklearn.utils._testing import assert_allclose_dense_sparse


def test_20news(fetch_20newsgroups_fxt):
    """Check category filtering and ``return_X_y`` on the raw-text fetcher."""
    full = fetch_20newsgroups_fxt(subset="all", shuffle=False)
    assert full.DESCR.startswith(".. _20newsgroups_dataset:")

    # Ask for the last two categories, deliberately in reversed order.
    reversed_tail = full.target_names[-1:-3:-1]
    reduced = fetch_20newsgroups_fxt(
        subset="all", categories=reversed_tail, shuffle=False
    )
    # The returned names must follow the ordering of the full dataset,
    # not the ordering of the request.
    assert reduced.target_names == full.target_names[-2:]
    # Only the labels 0 and 1 may appear in the reduced target.
    assert np.unique(reduced.target).tolist() == [0, 1]

    # filenames, target and data must all have the same length.
    n_files = len(reduced.filenames)
    assert n_files == len(reduced.target)
    assert n_files == len(reduced.data)

    # The first document of the reduced dataset must equal the first
    # document carrying the same category in the full dataset.
    first_category = reduced.target_names[reduced.target[0]]
    full_label = full.target_names.index(first_category)
    first_match = np.where(full.target == full_label)[0][0]
    assert reduced.data[0] == full.data[first_match]

    # return_X_y=True must yield a pair consistent with the bunch.
    X, y = fetch_20newsgroups_fxt(subset="all", shuffle=False, return_X_y=True)
    assert len(X) == len(full.data)
    assert y.shape == full.target.shape


def test_20news_length_consistency(fetch_20newsgroups_fxt):
    """Verify that key access and attribute access agree on field lengths.

    Non-regression test for a bug present in 0.16.1.
    """
    # Load the complete dataset
    bunch = fetch_20newsgroups_fxt(subset="all")
    for field in ("data", "target", "filenames"):
        assert len(bunch[field]) == len(getattr(bunch, field))


def test_20news_vectorized(fetch_20newsgroups_vectorized_fxt):
    """Check format, shape, dtype and description of each vectorized subset."""
    fetch = fetch_20newsgroups_vectorized_fxt
    n_train, n_test, n_features = 11314, 7532, 130107

    def check_bunch(bunch, n_samples):
        # Assertions shared by every subset; only the row count differs.
        assert sp.issparse(bunch.data) and bunch.data.format == "csr"
        assert bunch.data.shape == (n_samples, n_features)
        assert bunch.target.shape[0] == n_samples
        assert bunch.data.dtype == np.float64
        assert bunch.DESCR.startswith(".. _20newsgroups_dataset:")

    # subset = train
    bunch = fetch(subset="train")
    check_bunch(bunch, n_train)

    # subset = test
    bunch = fetch(subset="test")
    check_bunch(bunch, n_test)

    # return_X_y option, compared against the "test" bunch loaded just above
    fetch_test = partial(fetch, subset="test")
    check_return_X_y(bunch, fetch_test)

    # subset = all
    bunch = fetch(subset="all")
    check_bunch(bunch, n_train + n_test)


def test_20news_normalization(fetch_20newsgroups_vectorized_fxt):
    """Compare ``normalize=True`` output with manually normalized raw rows."""
    raw_bunch = fetch_20newsgroups_vectorized_fxt(normalize=False)
    normalized_bunch = fetch_20newsgroups_vectorized_fxt(normalize=True)

    # Only the first 100 rows are compared.
    head = slice(None, 100)
    X_norm = normalized_bunch["data"][head]
    X_raw = raw_bunch["data"][head]

    # The fetcher's own normalization must match normalize() on the raw rows.
    assert_allclose_dense_sparse(X_norm, normalize(X_raw))
    # Every normalized row must have unit norm.
    row_norms = np.linalg.norm(X_norm.todense(), axis=1)
    assert np.allclose(row_norms, 1)


def test_20news_as_frame(fetch_20newsgroups_vectorized_fxt):
    """Check the bunch returned with ``as_frame=True``.

    The test is skipped when pandas cannot be imported.
    """
    pd = pytest.importorskip("pandas")

    fetch = fetch_20newsgroups_vectorized_fxt
    bunch = fetch(as_frame=True)
    check_as_frame(bunch, fetch)

    frame = bunch.frame
    # NOTE(review): 130108 columns is presumably the feature columns plus
    # the target column -- confirm against the fetcher.
    assert frame.shape == (11314, 130108)
    assert all(isinstance(dtype, pd.SparseDtype) for dtype in bunch.data.dtypes)

    # Spot-check a handful of feature names among the frame's columns
    columns = frame.keys()
    expected_features = (
        "beginner",
        "beginners",
        "beginning",
        "beginnings",
        "begins",
        "begley",
        "begone",
    )
    for feature_name in expected_features:
        assert feature_name in columns
    assert "category_class" in columns
    assert bunch.target.name == "category_class"


def test_as_frame_no_pandas(fetch_20newsgroups_vectorized_fxt, hide_available_pandas):
    """Run the shared pandas-dependency message check on the vectorized fetcher.

    ``hide_available_pandas`` is requested only for its effect; presumably it
    makes pandas look unavailable -- confirm against the fixture definition.
    """
    fetch = fetch_20newsgroups_vectorized_fxt
    check_pandas_dependency_message(fetch)


def test_outdated_pickle(fetch_20newsgroups_vectorized_fxt):
    """Check that a cached pickle holding only ``(X, y)`` raises ValueError.

    ``os.path.exists`` is mocked to return True and ``joblib.load`` to return
    a 2-tuple; the fetcher must then raise a ValueError whose message matches
    "The cached dataset located in".
    """
    # mock that the dataset was cached
    cache_exists = patch("os.path.exists", return_value=True)
    # mock that we have an outdated pickle with only X and y returned
    outdated_load = patch("joblib.load", return_value=("X", "y"))

    with cache_exists, outdated_load:
        with pytest.raises(ValueError, match="The cached dataset located in"):
            fetch_20newsgroups_vectorized_fxt(as_frame=True)