File size: 3,667 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""
Tests for the base module (sklearn.ensemble._base).
"""

# Authors: The scikit-learn developers
# SPDX-License-Identifier: BSD-3-Clause

from collections import OrderedDict

import numpy as np

from sklearn.datasets import load_iris
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.ensemble import BaggingClassifier
from sklearn.ensemble._base import _set_random_states
from sklearn.feature_selection import SelectFromModel
from sklearn.linear_model import Perceptron
from sklearn.pipeline import Pipeline


def test_base():
    # Check BaseEnsemble methods: _make_estimator, __len__ and __getitem__.
    ensemble = BaggingClassifier(
        estimator=Perceptron(random_state=None), n_estimators=3
    )

    iris = load_iris()
    ensemble.fit(iris.data, iris.target)
    ensemble.estimators_ = []  # empty the list and create estimators manually

    ensemble._make_estimator()
    random_state = np.random.RandomState(3)
    ensemble._make_estimator(random_state=random_state)
    ensemble._make_estimator(random_state=random_state)
    # append=False must build an estimator without registering it in the
    # ensemble; keep the returned object so that contract can be checked.
    detached = ensemble._make_estimator(append=False)

    assert len(ensemble) == 3
    assert len(ensemble.estimators_) == 3
    assert isinstance(detached, Perceptron)
    # No random_state was passed, so the base estimator's None is kept.
    assert detached.random_state is None
    assert all(detached is not est for est in ensemble.estimators_)

    assert isinstance(ensemble[0], Perceptron)
    assert ensemble[0].random_state is None
    # Estimators built from a shared RandomState get distinct integer seeds.
    assert isinstance(ensemble[1].random_state, int)
    assert isinstance(ensemble[2].random_state, int)
    assert ensemble[1].random_state != ensemble[2].random_state

    # n_estimators given as a NumPy integer must be accepted and honoured.
    np_int_ensemble = BaggingClassifier(
        estimator=Perceptron(), n_estimators=np.int32(3)
    )
    np_int_ensemble.fit(iris.data, iris.target)
    assert len(np_int_ensemble.estimators_) == 3


def test_set_random_states():
    # Smoke test: LinearDiscriminantAnalysis exposes no random_state
    # parameter, so _set_random_states must leave it alone without raising.
    _set_random_states(LinearDiscriminantAnalysis(), random_state=17)

    first = Perceptron(random_state=None)
    assert first.random_state is None

    # Even random_state=None must result in an integer seed being assigned.
    _set_random_states(first, None)
    assert isinstance(first.random_state, int)

    # A fixed integer random_state must give identically configured
    # estimators the same seed.
    _set_random_states(first, 3)
    assert isinstance(first.random_state, int)
    second = Perceptron(random_state=None)
    _set_random_states(second, 3)
    assert first.random_state == second.random_state

    # random_state parameters nested inside a meta-estimator.
    def make_steps():
        selector = SelectFromModel(Perceptron(random_state=None))
        classifier = Perceptron(random_state=None)
        return [("sel", selector), ("clf", classifier)]

    reference = Pipeline(make_steps())
    _set_random_states(reference, 3)
    (_, sel_step), (_, clf_step) = reference.steps
    assert isinstance(sel_step.estimator.random_state, int)
    assert isinstance(clf_step.random_state, int)
    # `reference` is not modified below, so its params can be read once.
    reference_params = reference.get_params()
    assert (
        reference_params["sel__estimator__random_state"]
        != reference_params["clf__random_state"]
    )

    # The seeds handed out must not depend on the order in which
    # get_params() lists the random_state parameters.
    class AlphaParamPipeline(Pipeline):
        def get_params(self, *args, **kwargs):
            items = super().get_params(*args, **kwargs).items()
            return OrderedDict(sorted(items))

    class RevParamPipeline(Pipeline):
        def get_params(self, *args, **kwargs):
            items = super().get_params(*args, **kwargs).items()
            return OrderedDict(sorted(items, reverse=True))

    for pipeline_cls in (AlphaParamPipeline, RevParamPipeline):
        candidate = pipeline_cls(make_steps())
        _set_random_states(candidate, 3)
        candidate_params = candidate.get_params()
        for key in ("sel__estimator__random_state", "clf__random_state"):
            assert reference_params[key] == candidate_params[key]