File size: 2,557 Bytes
7885a28 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
"""Tests for spline filtering."""
import pytest
import numpy as np
from scipy._lib._array_api import assert_almost_equal
from scipy import ndimage
from scipy.conftest import array_api_compatible
# Short alias for the ``skip_xp_backends`` pytest marker used just below.
skip_xp_backends = pytest.mark.skip_xp_backends
# Module-level marks, applied by pytest to every test in this file:
#   * ``array_api_compatible`` (imported from ``scipy.conftest``),
#   * the ``skip_xp_backends`` fixture, requested via ``usefixtures``,
#   * a ``skip_xp_backends`` marker configured with ``cpu_only=True`` and
#     ``exceptions=['cupy', 'jax.numpy']``.
# NOTE(review): presumably this limits the tests to CPU-only backends while
# still allowing the listed exceptions -- confirm against scipy's conftest.
pytestmark = [array_api_compatible, pytest.mark.usefixtures("skip_xp_backends"),
skip_xp_backends(cpu_only=True, exceptions=['cupy', 'jax.numpy'],)]
def get_spline_knot_values(order):
    """Return the knot values from a B-spline's center rightwards.

    Parameters
    ----------
    order : int
        Spline order; orders 0 through 5 are tabulated.

    Returns
    -------
    list of int
        Unnormalized knot values, center value first. A new list is
        returned on every call.

    Raises
    ------
    KeyError
        If `order` is not one of the tabulated orders.
    """
    # Position in this tuple is the spline order.
    per_order = (
        [1],            # order 0
        [1],            # order 1
        [6, 1],         # order 2
        [4, 1],         # order 3
        [230, 76, 1],   # order 4
        [66, 26, 1],    # order 5
    )
    table = dict(enumerate(per_order))
    return table[order]
def make_spline_knot_matrix(xp, n, order, mode='mirror'):
    """Build the matrix whose inverse gives the spline coefficients.

    All arithmetic is done with NumPy; the result is converted to the
    ``xp`` namespace only at the very end.

    Parameters
    ----------
    xp : array namespace
        Namespace whose ``asarray`` is used to wrap the result.
    n : int
        Number of rows and columns of the square matrix.
    order : int
        Spline order, forwarded to `get_spline_knot_values`.
    mode : {'mirror', 'reflect', 'grid-wrap'}, optional
        Boundary mode. It selects which columns of the first and last
        rows receive the extra knot values. Default is 'mirror'.

    Returns
    -------
    array
        ``(n, n)`` matrix divided by the center knot value plus twice the
        sum of the remaining knot values.

    Raises
    ------
    ValueError
        If `mode` is not one of the supported modes.
    """
    weights = get_spline_knot_values(order)

    # Symmetric band: weights[k] goes on the k-th sub- and super-diagonal.
    # For k == 0 both assignments hit the main diagonal, which is harmless.
    band = np.zeros((n, n))
    for offset, weight in enumerate(weights):
        upper = np.arange(offset, n)
        lower = upper - offset
        band[upper, lower] = weight
        band[lower, upper] = weight

    # The center value counts once, every other knot value twice.
    total = weights[0] + 2 * sum(weights[1:])

    # (first column, column stride) of the boundary correction per mode.
    boundary_layout = {
        'mirror': (1, 1),
        'reflect': (0, 1),
        'grid-wrap': (-1, -1),
    }
    if mode not in boundary_layout:
        raise ValueError(f'unsupported mode {mode}')
    first_col, stride = boundary_layout[mode]

    # Add the knot values that fall off either end back into the leading
    # rows, and symmetrically (counted from the end) into the trailing rows.
    for row in range(len(weights) - 1):
        for k, weight in enumerate(weights[row + 1:]):
            col = first_col + stride * k
            band[row, col] += weight
            band[-row - 1, -col - 1] += weight

    return xp.asarray(band / total)
@pytest.mark.parametrize('order', [0, 1, 2, 3, 4, 5])
@pytest.mark.parametrize('mode', ['mirror', 'grid-wrap', 'reflect'])
def test_spline_filter_vs_matrix_solution(order, mode, xp):
    """Spline-filtering the identity must invert the knot matrix."""
    size = 100
    identity = xp.eye(size, dtype=xp.float64)

    # Filter the identity along axis 0 first, then along axis 1.
    along_axis0, along_axis1 = (
        ndimage.spline_filter1d(identity, axis=axis, order=order, mode=mode)
        for axis in (0, 1)
    )

    knot_matrix = make_spline_knot_matrix(xp, size, order, mode=mode)

    # The axis-1 result is checked against the transposed knot matrix.
    assert_almost_equal(identity, along_axis0 @ knot_matrix)
    assert_almost_equal(identity, along_axis1 @ knot_matrix.T)
|