|
import numpy as np |
|
from scipy.sparse import issparse |
|
from scipy.sparse._sputils import convert_pydata_sparse_to_scipy |
|
from scipy.sparse.csgraph._tools import ( |
|
csgraph_to_dense, csgraph_from_dense, |
|
csgraph_masked_from_dense, csgraph_from_masked |
|
) |
|
|
|
DTYPE = np.float64 |
|
|
|
|
|
def validate_graph(csgraph, directed, dtype=DTYPE, |
|
csr_output=True, dense_output=True, |
|
copy_if_dense=False, copy_if_sparse=False, |
|
null_value_in=0, null_value_out=np.inf, |
|
infinity_null=True, nan_null=True): |
|
"""Routine for validation and conversion of csgraph inputs""" |
|
if not (csr_output or dense_output): |
|
raise ValueError("Internal: dense or csr output must be true") |
|
|
|
accept_fv = [null_value_in] |
|
if infinity_null: |
|
accept_fv.append(np.inf) |
|
if nan_null: |
|
accept_fv.append(np.nan) |
|
csgraph = convert_pydata_sparse_to_scipy(csgraph, accept_fv=accept_fv) |
|
|
|
|
|
|
|
if (not directed) and issparse(csgraph) and csgraph.format == "csc": |
|
csgraph = csgraph.T |
|
|
|
if issparse(csgraph): |
|
if csr_output: |
|
csgraph = csgraph.tocsr(copy=copy_if_sparse).astype(DTYPE, copy=False) |
|
else: |
|
csgraph = csgraph_to_dense(csgraph, null_value=null_value_out) |
|
elif np.ma.isMaskedArray(csgraph): |
|
if dense_output: |
|
mask = csgraph.mask |
|
csgraph = np.array(csgraph.data, dtype=DTYPE, copy=copy_if_dense) |
|
csgraph[mask] = null_value_out |
|
else: |
|
csgraph = csgraph_from_masked(csgraph) |
|
else: |
|
if dense_output: |
|
csgraph = csgraph_masked_from_dense(csgraph, |
|
copy=copy_if_dense, |
|
null_value=null_value_in, |
|
nan_null=nan_null, |
|
infinity_null=infinity_null) |
|
mask = csgraph.mask |
|
csgraph = np.asarray(csgraph.data, dtype=DTYPE) |
|
csgraph[mask] = null_value_out |
|
else: |
|
csgraph = csgraph_from_dense(csgraph, null_value=null_value_in, |
|
infinity_null=infinity_null, |
|
nan_null=nan_null) |
|
|
|
if csgraph.ndim != 2: |
|
raise ValueError("compressed-sparse graph must be 2-D") |
|
|
|
if csgraph.shape[0] != csgraph.shape[1]: |
|
raise ValueError("compressed-sparse graph must be shape (N, N)") |
|
|
|
return csgraph |
|
|