|
import pickle |
|
|
|
import numpy as np |
|
import numpy.testing as npt |
|
from numpy.testing import assert_allclose, assert_equal |
|
from pytest import raises as assert_raises |
|
|
|
import numpy.ma.testutils as ma_npt |
|
|
|
from scipy._lib._util import ( |
|
getfullargspec_no_self as _getfullargspec, np_long |
|
) |
|
from scipy._lib._array_api_no_0d import xp_assert_equal |
|
from scipy import stats |
|
|
|
|
|
def check_named_results(res, attributes, ma=False, xp=None): |
|
for i, attr in enumerate(attributes): |
|
if ma: |
|
ma_npt.assert_equal(res[i], getattr(res, attr)) |
|
elif xp is not None: |
|
xp_assert_equal(res[i], getattr(res, attr)) |
|
else: |
|
npt.assert_equal(res[i], getattr(res, attr)) |
|
|
|
|
|
def check_normalization(distfn, args, distname): |
|
norm_moment = distfn.moment(0, *args) |
|
npt.assert_allclose(norm_moment, 1.0) |
|
|
|
if distname == "rv_histogram_instance": |
|
atol, rtol = 1e-5, 0 |
|
else: |
|
atol, rtol = 1e-7, 1e-7 |
|
|
|
normalization_expect = distfn.expect(lambda x: 1, args=args) |
|
npt.assert_allclose(normalization_expect, 1.0, atol=atol, rtol=rtol, |
|
err_msg=distname, verbose=True) |
|
|
|
_a, _b = distfn.support(*args) |
|
normalization_cdf = distfn.cdf(_b, *args) |
|
npt.assert_allclose(normalization_cdf, 1.0) |
|
|
|
|
|
def check_moment(distfn, arg, m, v, msg): |
|
m1 = distfn.moment(1, *arg) |
|
m2 = distfn.moment(2, *arg) |
|
if not np.isinf(m): |
|
npt.assert_almost_equal(m1, m, decimal=10, |
|
err_msg=msg + ' - 1st moment') |
|
else: |
|
npt.assert_(np.isinf(m1), |
|
msg + f' - 1st moment -infinite, m1={str(m1)}') |
|
|
|
if not np.isinf(v): |
|
npt.assert_almost_equal(m2 - m1 * m1, v, decimal=10, |
|
err_msg=msg + ' - 2ndt moment') |
|
else: |
|
npt.assert_(np.isinf(m2), msg + f' - 2nd moment -infinite, {m2=}') |
|
|
|
|
|
def check_mean_expect(distfn, arg, m, msg): |
|
if np.isfinite(m): |
|
m1 = distfn.expect(lambda x: x, arg) |
|
npt.assert_almost_equal(m1, m, decimal=5, |
|
err_msg=msg + ' - 1st moment (expect)') |
|
|
|
|
|
def check_var_expect(distfn, arg, m, v, msg): |
|
dist_looser_tolerances = {"rv_histogram_instance" , "ksone"} |
|
kwargs = {'rtol': 5e-6} if msg in dist_looser_tolerances else {} |
|
if np.isfinite(v): |
|
m2 = distfn.expect(lambda x: x*x, arg) |
|
npt.assert_allclose(m2, v + m*m, **kwargs) |
|
|
|
|
|
def check_skew_expect(distfn, arg, m, v, s, msg): |
|
if np.isfinite(s): |
|
m3e = distfn.expect(lambda x: np.power(x-m, 3), arg) |
|
npt.assert_almost_equal(m3e, s * np.power(v, 1.5), |
|
decimal=5, err_msg=msg + ' - skew') |
|
else: |
|
npt.assert_(np.isnan(s)) |
|
|
|
|
|
def check_kurt_expect(distfn, arg, m, v, k, msg): |
|
if np.isfinite(k): |
|
m4e = distfn.expect(lambda x: np.power(x-m, 4), arg) |
|
npt.assert_allclose(m4e, (k + 3.) * np.power(v, 2), |
|
atol=1e-5, rtol=1e-5, |
|
err_msg=msg + ' - kurtosis') |
|
elif not np.isposinf(k): |
|
npt.assert_(np.isnan(k)) |
|
|
|
|
|
def check_munp_expect(dist, args, msg): |
|
|
|
|
|
if dist._munp.__func__ != stats.rv_continuous._munp: |
|
res = dist.moment(5, *args) |
|
ref = dist.expect(lambda x: x ** 5, args, lb=-np.inf, ub=np.inf) |
|
if not np.isfinite(res): |
|
return |
|
|
|
assert_allclose(res, ref, atol=1e-10, rtol=1e-4, |
|
err_msg=msg + ' - higher moment / _munp') |
|
|
|
|
|
def check_entropy(distfn, arg, msg): |
|
ent = distfn.entropy(*arg) |
|
npt.assert_(not np.isnan(ent), msg + 'test Entropy is nan') |
|
|
|
|
|
def check_private_entropy(distfn, args, superclass): |
|
|
|
npt.assert_allclose(distfn._entropy(*args), |
|
superclass._entropy(distfn, *args)) |
|
|
|
|
|
def check_entropy_vect_scale(distfn, arg): |
|
|
|
sc = np.asarray([[1, 2], [3, 4]]) |
|
v_ent = distfn.entropy(*arg, scale=sc) |
|
s_ent = [distfn.entropy(*arg, scale=s) for s in sc.ravel()] |
|
s_ent = np.asarray(s_ent).reshape(v_ent.shape) |
|
assert_allclose(v_ent, s_ent, atol=1e-14) |
|
|
|
|
|
sc = [1, 2, -3] |
|
v_ent = distfn.entropy(*arg, scale=sc) |
|
s_ent = [distfn.entropy(*arg, scale=s) for s in sc] |
|
s_ent = np.asarray(s_ent).reshape(v_ent.shape) |
|
assert_allclose(v_ent, s_ent, atol=1e-14) |
|
|
|
|
|
def check_edge_support(distfn, args): |
|
|
|
x = distfn.support(*args) |
|
if isinstance(distfn, stats.rv_discrete): |
|
x = x[0]-1, x[1] |
|
|
|
npt.assert_equal(distfn.cdf(x, *args), [0.0, 1.0]) |
|
npt.assert_equal(distfn.sf(x, *args), [1.0, 0.0]) |
|
|
|
if distfn.name not in ('skellam', 'dlaplace'): |
|
|
|
npt.assert_equal(distfn.logcdf(x, *args), [-np.inf, 0.0]) |
|
npt.assert_equal(distfn.logsf(x, *args), [0.0, -np.inf]) |
|
|
|
npt.assert_equal(distfn.ppf([0.0, 1.0], *args), x) |
|
npt.assert_equal(distfn.isf([0.0, 1.0], *args), x[::-1]) |
|
|
|
|
|
npt.assert_(np.isnan(distfn.isf([-1, 2], *args)).all()) |
|
npt.assert_(np.isnan(distfn.ppf([-1, 2], *args)).all()) |
|
|
|
|
|
def check_named_args(distfn, x, shape_args, defaults, meths): |
|
|
|
|
|
|
|
signature = _getfullargspec(distfn._parse_args) |
|
npt.assert_(signature.varargs is None) |
|
npt.assert_(signature.varkw is None) |
|
npt.assert_(not signature.kwonlyargs) |
|
npt.assert_(list(signature.defaults) == list(defaults)) |
|
|
|
shape_argnames = signature.args[:-len(defaults)] |
|
if distfn.shapes: |
|
shapes_ = distfn.shapes.replace(',', ' ').split() |
|
else: |
|
shapes_ = '' |
|
npt.assert_(len(shapes_) == distfn.numargs) |
|
npt.assert_(len(shapes_) == len(shape_argnames)) |
|
|
|
|
|
shape_args = list(shape_args) |
|
|
|
vals = [meth(x, *shape_args) for meth in meths] |
|
npt.assert_(np.all(np.isfinite(vals))) |
|
|
|
names, a, k = shape_argnames[:], shape_args[:], {} |
|
while names: |
|
k.update({names.pop(): a.pop()}) |
|
v = [meth(x, *a, **k) for meth in meths] |
|
npt.assert_array_equal(vals, v) |
|
if 'n' not in k.keys(): |
|
|
|
npt.assert_equal(distfn.moment(1, *a, **k), |
|
distfn.moment(1, *shape_args)) |
|
|
|
|
|
k.update({'kaboom': 42}) |
|
assert_raises(TypeError, distfn.cdf, x, **k) |
|
|
|
|
|
def check_random_state_property(distfn, args): |
|
|
|
|
|
|
|
|
|
rndm = distfn.random_state |
|
|
|
|
|
np.random.seed(1234) |
|
distfn.random_state = None |
|
r0 = distfn.rvs(*args, size=8) |
|
|
|
|
|
distfn.random_state = 1234 |
|
r1 = distfn.rvs(*args, size=8) |
|
npt.assert_equal(r0, r1) |
|
|
|
distfn.random_state = np.random.RandomState(1234) |
|
r2 = distfn.rvs(*args, size=8) |
|
npt.assert_equal(r0, r2) |
|
|
|
|
|
if hasattr(np.random, 'default_rng'): |
|
|
|
rng = np.random.default_rng(1234) |
|
distfn.rvs(*args, size=1, random_state=rng) |
|
|
|
|
|
distfn.random_state = 2 |
|
orig_state = distfn.random_state.get_state() |
|
|
|
r3 = distfn.rvs(*args, size=8, random_state=np.random.RandomState(1234)) |
|
npt.assert_equal(r0, r3) |
|
|
|
|
|
npt.assert_equal(distfn.random_state.get_state(), orig_state) |
|
|
|
|
|
distfn.random_state = rndm |
|
|
|
|
|
def check_meth_dtype(distfn, arg, meths): |
|
q0 = [0.25, 0.5, 0.75] |
|
x0 = distfn.ppf(q0, *arg) |
|
x_cast = [x0.astype(tp) for tp in (np_long, np.float16, np.float32, |
|
np.float64)] |
|
|
|
for x in x_cast: |
|
|
|
distfn._argcheck(*arg) |
|
x = x[(distfn.a < x) & (x < distfn.b)] |
|
for meth in meths: |
|
val = meth(x, *arg) |
|
npt.assert_(val.dtype == np.float64) |
|
|
|
|
|
def check_ppf_dtype(distfn, arg): |
|
q0 = np.asarray([0.25, 0.5, 0.75]) |
|
q_cast = [q0.astype(tp) for tp in (np.float16, np.float32, np.float64)] |
|
for q in q_cast: |
|
for meth in [distfn.ppf, distfn.isf]: |
|
val = meth(q, *arg) |
|
npt.assert_(val.dtype == np.float64) |
|
|
|
|
|
def check_cmplx_deriv(distfn, arg): |
|
|
|
def deriv(f, x, *arg): |
|
x = np.asarray(x) |
|
h = 1e-10 |
|
return (f(x + h*1j, *arg)/h).imag |
|
|
|
x0 = distfn.ppf([0.25, 0.51, 0.75], *arg) |
|
x_cast = [x0.astype(tp) for tp in (np_long, np.float16, np.float32, |
|
np.float64)] |
|
|
|
for x in x_cast: |
|
|
|
distfn._argcheck(*arg) |
|
x = x[(distfn.a < x) & (x < distfn.b)] |
|
|
|
pdf, cdf, sf = distfn.pdf(x, *arg), distfn.cdf(x, *arg), distfn.sf(x, *arg) |
|
assert_allclose(deriv(distfn.cdf, x, *arg), pdf, rtol=1e-5) |
|
assert_allclose(deriv(distfn.logcdf, x, *arg), pdf/cdf, rtol=1e-5) |
|
|
|
assert_allclose(deriv(distfn.sf, x, *arg), -pdf, rtol=1e-5) |
|
assert_allclose(deriv(distfn.logsf, x, *arg), -pdf/sf, rtol=1e-5) |
|
|
|
assert_allclose(deriv(distfn.logpdf, x, *arg), |
|
deriv(distfn.pdf, x, *arg) / distfn.pdf(x, *arg), |
|
rtol=1e-5) |
|
|
|
|
|
def check_pickling(distfn, args): |
|
|
|
|
|
|
|
|
|
rndm = distfn.random_state |
|
|
|
|
|
distfn.random_state = 1234 |
|
distfn.rvs(*args, size=8) |
|
s = pickle.dumps(distfn) |
|
r0 = distfn.rvs(*args, size=8) |
|
|
|
unpickled = pickle.loads(s) |
|
r1 = unpickled.rvs(*args, size=8) |
|
npt.assert_equal(r0, r1) |
|
|
|
|
|
medians = [distfn.ppf(0.5, *args), unpickled.ppf(0.5, *args)] |
|
npt.assert_equal(medians[0], medians[1]) |
|
npt.assert_equal(distfn.cdf(medians[0], *args), |
|
unpickled.cdf(medians[1], *args)) |
|
|
|
|
|
frozen_dist = distfn(*args) |
|
pkl = pickle.dumps(frozen_dist) |
|
unpickled = pickle.loads(pkl) |
|
|
|
r0 = frozen_dist.rvs(size=8) |
|
r1 = unpickled.rvs(size=8) |
|
npt.assert_equal(r0, r1) |
|
|
|
|
|
if hasattr(distfn, "fit"): |
|
fit_function = distfn.fit |
|
pickled_fit_function = pickle.dumps(fit_function) |
|
unpickled_fit_function = pickle.loads(pickled_fit_function) |
|
assert fit_function.__name__ == unpickled_fit_function.__name__ == "fit" |
|
|
|
|
|
distfn.random_state = rndm |
|
|
|
|
|
def check_freezing(distfn, args): |
|
|
|
|
|
if isinstance(distfn, stats.rv_continuous): |
|
locscale = {'loc': 1, 'scale': 2} |
|
else: |
|
locscale = {'loc': 1} |
|
|
|
rv = distfn(*args, **locscale) |
|
assert rv.a == distfn(*args).a |
|
assert rv.b == distfn(*args).b |
|
|
|
|
|
def check_rvs_broadcast(distfunc, distname, allargs, shape, shape_only, otype): |
|
np.random.seed(123) |
|
sample = distfunc.rvs(*allargs) |
|
assert_equal(sample.shape, shape, f"{distname}: rvs failed to broadcast") |
|
if not shape_only: |
|
rvs = np.vectorize(lambda *allargs: distfunc.rvs(*allargs), otypes=otype) |
|
np.random.seed(123) |
|
expected = rvs(*allargs) |
|
assert_allclose(sample, expected, rtol=1e-13) |
|
|