|
|
|
|
|
'''
A unit test module for czt.py
'''
|
import pytest |
|
from scipy._lib._array_api import xp_assert_close |
|
from scipy.fft import fft |
|
from scipy.signal import (czt, zoom_fft, czt_points, CZT, ZoomFFT) |
|
import numpy as np |
|
|
|
|
|
def check_czt(x): |
|
|
|
y = fft(x) |
|
y1 = czt(x) |
|
xp_assert_close(y1, y, rtol=1e-13) |
|
|
|
|
|
y = fft(x, 100*len(x)) |
|
y1 = czt(x, 100*len(x)) |
|
xp_assert_close(y1, y, rtol=1e-12) |
|
|
|
|
|
def check_zoom_fft(x): |
|
|
|
y = fft(x) |
|
y1 = zoom_fft(x, [0, 2-2./len(y)], endpoint=True) |
|
xp_assert_close(y1, y, rtol=1e-11, atol=1e-14) |
|
y1 = zoom_fft(x, [0, 2]) |
|
xp_assert_close(y1, y, rtol=1e-11, atol=1e-14) |
|
|
|
|
|
y1 = zoom_fft(x, 2-2./len(y), endpoint=True) |
|
xp_assert_close(y1, y, rtol=1e-11, atol=1e-14) |
|
y1 = zoom_fft(x, 2) |
|
xp_assert_close(y1, y, rtol=1e-11, atol=1e-14) |
|
|
|
|
|
over = 10 |
|
yover = fft(x, over*len(x)) |
|
y2 = zoom_fft(x, [0, 2-2./len(yover)], m=len(yover), endpoint=True) |
|
xp_assert_close(y2, yover, rtol=1e-12, atol=1e-10) |
|
y2 = zoom_fft(x, [0, 2], m=len(yover)) |
|
xp_assert_close(y2, yover, rtol=1e-12, atol=1e-10) |
|
|
|
|
|
w = np.linspace(0, 2-2./len(x), len(x)) |
|
f1, f2 = w[3], w[6] |
|
y3 = zoom_fft(x, [f1, f2], m=3*over+1, endpoint=True) |
|
idx3 = slice(3*over, 6*over+1) |
|
xp_assert_close(y3, yover[idx3], rtol=1e-13) |
|
|
|
|
|
def test_1D():
    """Run the 1-D `czt` / `zoom_fft` consistency checks on assorted signals.

    The signals are: random vectors of random length, a narrow Gaussian
    pulse, a short list of integers, ``range`` objects of length 95 to 99,
    one more random vector, a sum of two sinusoids, three unit spikes in a
    complex vector, and those spikes with an added imaginary ramp.  In
    addition, `zoom_fft` on a 3-D array is compared against `zoom_fft` on one
    of its 1-D slices.

    Raises
    ------
    AssertionError
        If any of the comparisons falls outside its tolerance.
    """
    rng = np.random.RandomState(0)  # fixed seed keeps the test deterministic

    # Random signals with 20 random lengths drawn from [8, 200).
    # NOTE: a bare ``np.append(lengths, 1)`` used to follow the draw; it
    # returns a new array and its result was discarded, so it never added a
    # length-1 case.  It cannot simply be assigned back either, because
    # ``check_zoom_fft`` reads index 6 of a ``len(x)``-point grid and so
    # needs at least 7 samples.
    for length in rng.randint(8, 200, 20):
        x = rng.random(length)
        check_zoom_fft(x)
        check_czt(x)

    # Narrow Gaussian pulse centred in the window.
    t = np.linspace(-2, 2, 128)
    x = np.exp(-t**2/0.01)
    check_zoom_fft(x)

    # Short integer ramp passed as a plain list.
    x = [1, 2, 3, 4, 5, 6, 7]
    check_zoom_fft(x)

    # ``range`` inputs of length 95 through 99.
    for stop in (126, 127, 128, 129, 130):
        check_zoom_fft(range(stop-31))

    # n-D input: transforming the whole array and then indexing [2, 0] must
    # match transforming the 1-D slice x[2, 0, :] directly.
    x = np.reshape(np.arange(3*2*28), (3, 2, 28))
    y1 = zoom_fft(x, [0, 2-2./28])
    y2 = zoom_fft(x[2, 0, :], [0, 2-2./28])
    xp_assert_close(y1[2, 0], y2, rtol=1e-13, atol=1e-12)

    y1 = zoom_fft(x, [0, 2], endpoint=False)
    y2 = zoom_fft(x[2, 0, :], [0, 2], endpoint=False)
    xp_assert_close(y1[2, 0], y2, rtol=1e-13, atol=1e-12)

    # One more random vector, of length 101.
    x = rng.rand(101)
    check_zoom_fft(x)

    # Sum of two sinusoids (5 and 13 cycles over the window).
    t = np.linspace(0, 1, 128)
    x = np.sin(2*np.pi*t*5)+np.sin(2*np.pi*t*13)
    check_zoom_fft(x)

    # Three unit spikes in an otherwise zero complex vector.
    x = np.zeros(100, dtype=complex)
    x[[1, 5, 21]] = 1
    check_zoom_fft(x)

    # Same spikes plus a linearly increasing imaginary part.
    x += 1j*np.linspace(0, 0.5, x.shape[0])
    check_zoom_fft(x)
|
|
|
|
|
def test_large_prime_lengths(): |
|
rng = np.random.RandomState(0) |
|
for N in (101, 1009, 10007): |
|
x = rng.rand(N) |
|
y = fft(x) |
|
y1 = czt(x) |
|
xp_assert_close(y, y1, rtol=1e-12) |
|
|
|
|
|
@pytest.mark.slow
def test_czt_vs_fft():
    """Compare `czt` with `fft` on ten long random signals.

    Signal lengths are drawn from an exponential distribution with scale
    100000 and truncated to integers; the samples are standard normal.  The
    generator is seeded with 123, so lengths and data are reproducible.
    """
    rng = np.random.RandomState(123)
    sizes = rng.exponential(100000, size=10).astype('int')
    for size in sizes:
        signal = rng.randn(size)
        actual = czt(signal)
        desired = fft(signal)
        xp_assert_close(actual, desired, rtol=1e-11)
|
|
|
|
|
def test_empty_input():
    """An empty sequence must make `czt` and `zoom_fft` raise ValueError."""
    calls = (
        (czt, ([],)),
        (zoom_fft, ([], 0.5)),
    )
    for func, args in calls:
        with pytest.raises(ValueError, match='Invalid number of CZT'):
            func(*args)
|
|
|
|
|
def test_0_rank_input():
    """A scalar input must make `czt` and `zoom_fft` raise IndexError."""
    calls = (
        (czt, (5,)),
        (zoom_fft, (5, 0.5)),
    )
    for func, args in calls:
        with pytest.raises(IndexError, match='tuple index out of range'):
            func(*args)
|
|
|
|
|
@pytest.mark.parametrize('impulse', ([0, 0, 1], [0, 0, 1, 0, 0],
                                     np.concatenate((np.array([0, 0, 1]),
                                                     np.zeros(100)))))
@pytest.mark.parametrize('m', (1, 3, 5, 8, 101, 1021))
@pytest.mark.parametrize('a', (1, 2, 0.5, 1.1))
@pytest.mark.parametrize('w', (None, 0.98534 + 0.17055j))
def test_czt_math(impulse, m, w, a):
    """Check `czt` of a unit impulse delayed by 0, 1 and 2 samples.

    Every parametrized `impulse` starts with ``[0, 0, 1]``, so dropping the
    first two, one or zero samples puts the unit sample at index 0, 1 or 2.
    The expected transforms are all ones, ``z**-1`` and ``z**-2``
    respectively, where ``z = czt_points(m=m, w=w, a=a)``.

    Parameters
    ----------
    impulse : list or ndarray
        Sequence whose only non-zero entry is a 1 at index 2.
    m, w, a
        Forwarded unchanged to `czt` and `czt_points`.
    """
    params = dict(m=m, w=w, a=a)
    z = czt_points(**params)

    # Unit sample at index 0: the transform is 1 at every point.
    no_delay = czt(impulse[2:], **params)
    xp_assert_close(no_delay, np.ones(m, dtype=np.complex128), rtol=1e-10)

    # Unit sample at index 1: the transform is z**-1.
    one_delay = czt(impulse[1:], **params)
    xp_assert_close(one_delay, z**-1, rtol=1e-10)

    # Unit sample at index 2: the transform is z**-2.
    two_delay = czt(impulse, **params)
    xp_assert_close(two_delay, z**-2, rtol=1e-10)
|
|
|
|
|
def test_int_args(): |
|
|
|
xp_assert_close(abs(czt([0, 1], m=10, a=2)), 0.5*np.ones(10), rtol=1e-15) |
|
xp_assert_close(czt_points(11, w=2), |
|
1/(2**np.arange(11, dtype=np.complex128)), rtol=1e-30) |
|
|
|
|
|
def test_czt_points(): |
|
for N in (1, 2, 3, 8, 11, 100, 101, 10007): |
|
xp_assert_close(czt_points(N), np.exp(2j*np.pi*np.arange(N)/N), |
|
rtol=1e-30) |
|
|
|
xp_assert_close(czt_points(7, w=1), np.ones(7, dtype=np.complex128), rtol=1e-30) |
|
xp_assert_close(czt_points(11, w=2.), |
|
1/(2**np.arange(11, dtype=np.complex128)), rtol=1e-30) |
|
|
|
func = CZT(12, m=11, w=2., a=1) |
|
xp_assert_close(func.points(), 1/(2**np.arange(11)), rtol=1e-30) |
|
|
|
|
|
@pytest.mark.parametrize('cls, args', [(CZT, (100,)), (ZoomFFT, (100, 0.2))])
def test_CZT_size_mismatch(cls, args):
    """A transform object built with size 100 must reject a 5-sample input.

    Parameters
    ----------
    cls : type
        Transform class to instantiate (`CZT` or `ZoomFFT`).
    args : tuple
        Positional arguments passed to ``cls``.
    """
    transform = cls(*args)
    short_signal = np.arange(5)
    with pytest.raises(ValueError, match='CZT defined for'):
        transform(short_signal)
|
|
|
|
|
def test_invalid_range():
    """`ZoomFFT` must raise ValueError for a three-element range."""
    three_values = [1, 2, 3]
    with pytest.raises(ValueError, match='2-length sequence'):
        ZoomFFT(100, three_values)
|
|
|
|
|
@pytest.mark.parametrize('m', [0, -11, 5.5, 4.0])
def test_czt_points_errors(m):
    """`czt_points` must raise ValueError for each parametrized `m`.

    The cases are zero, a negative integer and two floats (one of them
    integer-valued).
    """
    expect_error = pytest.raises(ValueError, match='Invalid number of CZT')
    with expect_error:
        czt_points(m)
|
|
|
|
|
@pytest.mark.parametrize('size', [0, -5, 3.5, 4.0])
def test_nonsense_size(size):
    """An invalid `size` must raise ValueError wherever it is passed.

    `size` (zero, negative, or a float) is passed in turn as the first and
    as the last positional argument of `CZT` and `ZoomFFT`, and as the last
    positional argument of `czt` and `zoom_fft`.  Every call must fail with
    the same error message.
    """
    calls = (
        (CZT, (size, 3)),
        (ZoomFFT, (size, 0.2, 3)),
        (CZT, (3, size)),
        (ZoomFFT, (3, 0.2, size)),
        (czt, ([1, 2, 3], size)),
        (zoom_fft, ([1, 2, 3], 0.2, size)),
    )
    for func, args in calls:
        with pytest.raises(ValueError, match='Invalid number of CZT'):
            func(*args)
|
|