|
import numpy as np |
|
import pytest |
|
from scipy.sparse.csgraph import connected_components |
|
|
|
from sklearn.metrics.pairwise import pairwise_distances |
|
from sklearn.neighbors import kneighbors_graph |
|
from sklearn.utils.graph import _fix_connected_components |
|
|
|
|
|
def test_fix_connected_components():
    """A kNN graph with several components is reduced to a single one."""
    points = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(points, n_neighbors=2, mode="distance")

    # Precondition: the raw neighbors graph must start out disconnected,
    # otherwise the final assertion would not exercise the fix.
    n_before, component_labels = connected_components(knn_graph)
    assert n_before > 1

    fixed_graph = _fix_connected_components(
        points, knn_graph, n_before, component_labels
    )

    n_after, _ = connected_components(fixed_graph)
    assert n_after == 1
|
|
|
|
|
def test_fix_connected_components_precomputed():
    """A precomputed distance matrix is accepted; a neighbors graph is not."""
    points = np.array([0, 1, 2, 5, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(points, n_neighbors=2, mode="distance")

    # Precondition: the raw neighbors graph must start out disconnected.
    n_before, labels_before = connected_components(knn_graph)
    assert n_before > 1

    # The pairwise distance matrix stands in for the raw samples.
    distance_matrix = pairwise_distances(points)
    fixed_graph = _fix_connected_components(
        distance_matrix, knn_graph, n_before, labels_before, metric="precomputed"
    )

    n_after, labels_after = connected_components(fixed_graph)
    assert n_after == 1

    # Passing the neighbors graph itself as the "precomputed" input is
    # expected to be rejected with a RuntimeError.
    expect_sparse_error = pytest.raises(
        RuntimeError, match="does not work with a sparse"
    )
    with expect_sparse_error:
        _fix_connected_components(
            fixed_graph, fixed_graph, n_after, labels_after, metric="precomputed"
        )
|
|
|
|
|
def test_fix_connected_components_wrong_mode():
    """An unrecognised ``mode`` string makes the call raise ``ValueError``."""
    X = np.array([0, 1, 2, 5, 6, 7])[:, None]
    graph = kneighbors_graph(X, n_neighbors=2, mode="distance")
    n_connected_components, labels = connected_components(graph)

    # The call is expected to raise, so its return value is deliberately not
    # bound to a name (an assignment here would be dead code).
    with pytest.raises(ValueError, match="Unknown mode"):
        _fix_connected_components(
            X, graph, n_connected_components, labels, mode="foo"
        )
|
|
|
|
|
def test_fix_connected_components_connectivity_mode():
    """In connectivity mode every stored edge value of the result equals 1."""
    X = np.array([0, 1, 6, 7])[:, None]
    graph = kneighbors_graph(X, n_neighbors=1, mode="connectivity")
    n_connected_components, labels = connected_components(graph)

    # Guard against a vacuous pass: if the input graph were already connected,
    # nothing would need fixing and the final assertion would hold without
    # exercising the connectivity-mode fill at all.
    assert n_connected_components > 1

    graph = _fix_connected_components(
        X, graph, n_connected_components, labels, mode="connectivity"
    )
    assert np.all(graph.data == 1)
|
|
|
|
|
def test_fix_connected_components_distance_mode():
    """In distance mode the fixed graph's stored values are not all 1."""
    points = np.array([0, 1, 6, 7]).reshape(-1, 1)
    knn_graph = kneighbors_graph(points, n_neighbors=1, mode="distance")

    # Before the fix, every stored edge weight equals 1.
    assert (knn_graph.data == 1).all()

    n_components, component_labels = connected_components(knn_graph)
    fixed_graph = _fix_connected_components(
        points, knn_graph, n_components, component_labels, mode="distance"
    )

    # After the fix, at least one stored weight differs from 1.
    assert not (fixed_graph.data == 1).all()
|
|