File size: 7,156 Bytes
7885a28
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
# This program is public domain
# Authors: Paul Kienzle, Nadav Horesh
'''
A unit test module for czt.py
'''
import pytest
from scipy._lib._array_api import xp_assert_close
from scipy.fft import fft
from scipy.signal import (czt, zoom_fft, czt_points, CZT, ZoomFFT)
import numpy as np


def check_czt(x):
    # Check that czt is the equivalent of normal fft
    y = fft(x)
    y1 = czt(x)
    xp_assert_close(y1, y, rtol=1e-13)

    # Check that interpolated czt is the equivalent of normal fft
    y = fft(x, 100*len(x))
    y1 = czt(x, 100*len(x))
    xp_assert_close(y1, y, rtol=1e-12)


def check_zoom_fft(x):
    # Check that zoom_fft is the equivalent of normal fft
    y = fft(x)
    y1 = zoom_fft(x, [0, 2-2./len(y)], endpoint=True)
    xp_assert_close(y1, y, rtol=1e-11, atol=1e-14)
    y1 = zoom_fft(x, [0, 2])
    xp_assert_close(y1, y, rtol=1e-11, atol=1e-14)

    # Test fn scalar
    y1 = zoom_fft(x, 2-2./len(y), endpoint=True)
    xp_assert_close(y1, y, rtol=1e-11, atol=1e-14)
    y1 = zoom_fft(x, 2)
    xp_assert_close(y1, y, rtol=1e-11, atol=1e-14)

    # Check that zoom_fft with oversampling is equivalent to zero padding
    over = 10
    yover = fft(x, over*len(x))
    y2 = zoom_fft(x, [0, 2-2./len(yover)], m=len(yover), endpoint=True)
    xp_assert_close(y2, yover, rtol=1e-12, atol=1e-10)
    y2 = zoom_fft(x, [0, 2], m=len(yover))
    xp_assert_close(y2, yover, rtol=1e-12, atol=1e-10)

    # Check that zoom_fft works on a subrange
    w = np.linspace(0, 2-2./len(x), len(x))
    f1, f2 = w[3], w[6]
    y3 = zoom_fft(x, [f1, f2], m=3*over+1, endpoint=True)
    idx3 = slice(3*over, 6*over+1)
    xp_assert_close(y3, yover[idx3], rtol=1e-13)


def test_1D():
    # Test of 1D version of the transforms

    rng = np.random.RandomState(0)  # Deterministic randomness

    # Random signals
    lengths = rng.randint(8, 200, 20)
    for length in lengths:
        x = rng.random(length)
        check_zoom_fft(x)
        check_czt(x)

    # Single-sample signal, checked with czt only: check_zoom_fft indexes
    # its frequency grid at position 6, so it needs at least 7 samples.
    # A fixed value (not an rng draw) keeps the random stream consumed
    # further down unchanged.
    check_czt(np.array([0.5]))

    # Gauss
    t = np.linspace(-2, 2, 128)
    x = np.exp(-t**2/0.01)
    check_zoom_fft(x)

    # Linear
    x = [1, 2, 3, 4, 5, 6, 7]
    check_zoom_fft(x)

    # Check near powers of two
    check_zoom_fft(range(126-31))
    check_zoom_fft(range(127-31))
    check_zoom_fft(range(128-31))
    check_zoom_fft(range(129-31))
    check_zoom_fft(range(130-31))

    # Check transform on n-D array input: transforming along the last axis
    # of a 3-D array must match transforming one 1-D slice of it.
    x = np.reshape(np.arange(3*2*28), (3, 2, 28))
    y1 = zoom_fft(x, [0, 2-2./28])
    y2 = zoom_fft(x[2, 0, :], [0, 2-2./28])
    xp_assert_close(y1[2, 0], y2, rtol=1e-13, atol=1e-12)

    y1 = zoom_fft(x, [0, 2], endpoint=False)
    y2 = zoom_fft(x[2, 0, :], [0, 2], endpoint=False)
    xp_assert_close(y1[2, 0], y2, rtol=1e-13, atol=1e-12)

    # Random (not a test condition)
    x = rng.rand(101)
    check_zoom_fft(x)

    # Spikes
    t = np.linspace(0, 1, 128)
    x = np.sin(2*np.pi*t*5)+np.sin(2*np.pi*t*13)
    check_zoom_fft(x)

    # Sines
    x = np.zeros(100, dtype=complex)
    x[[1, 5, 21]] = 1
    check_zoom_fft(x)

    # Sines plus complex component
    x += 1j*np.linspace(0, 0.5, x.shape[0])
    check_zoom_fft(x)


def test_large_prime_lengths():
    rng = np.random.RandomState(0)  # Deterministic randomness
    for N in (101, 1009, 10007):
        x = rng.rand(N)
        y = fft(x)
        y1 = czt(x)
        xp_assert_close(y, y1, rtol=1e-12)


@pytest.mark.slow
def test_czt_vs_fft():
    rng = np.random.RandomState(123)  # Deterministic randomness
    random_lengths = rng.exponential(100000, size=10).astype('int')
    for n in random_lengths:
        a = rng.randn(n)
        xp_assert_close(czt(a), fft(a), rtol=1e-11)


def test_empty_input():
    # Zero-length input must be rejected by both convenience functions.
    calls = (
        lambda: czt([]),
        lambda: zoom_fft([], 0.5),
    )
    for call in calls:
        with pytest.raises(ValueError, match='Invalid number of CZT'):
            call()


def test_0_rank_input():
    # A scalar (0-d) input has no axis to transform along, so both
    # functions are expected to fail with an IndexError.
    calls = (
        lambda: czt(5),
        lambda: zoom_fft(5, 0.5),
    )
    for call in calls:
        with pytest.raises(IndexError, match='tuple index out of range'):
            call()


@pytest.mark.parametrize('impulse', ([0, 0, 1], [0, 0, 1, 0, 0],
                                     np.concatenate((np.array([0, 0, 1]),
                                                     np.zeros(100)))))
@pytest.mark.parametrize('m', (1, 3, 5, 8, 101, 1021))
@pytest.mark.parametrize('a', (1, 2, 0.5, 1.1))
# Step that tests away from the unit circle, but not so far it explodes from
# numerical error
@pytest.mark.parametrize('w', (None, 0.98534 + 0.17055j))
def test_czt_math(impulse, m, w, a):
    # z-transform of an impulse is 1 everywhere
    xp_assert_close(czt(impulse[2:], m=m, w=w, a=a),
                    np.ones(m, dtype=np.complex128), rtol=1e-10)

    # z-transform of a delayed impulse is z**-1
    xp_assert_close(czt(impulse[1:], m=m, w=w, a=a),
                    czt_points(m=m, w=w, a=a)**-1, rtol=1e-10)

    # z-transform of a 2-delayed impulse is z**-2
    xp_assert_close(czt(impulse, m=m, w=w, a=a),
                    czt_points(m=m, w=w, a=a)**-2, rtol=1e-10)


def test_int_args():
    # Integer argument `a` was producing all 0s
    xp_assert_close(abs(czt([0, 1], m=10, a=2)), 0.5*np.ones(10), rtol=1e-15)
    xp_assert_close(czt_points(11, w=2),
                    1/(2**np.arange(11, dtype=np.complex128)), rtol=1e-30)


def test_czt_points():
    for N in (1, 2, 3, 8, 11, 100, 101, 10007):
        xp_assert_close(czt_points(N), np.exp(2j*np.pi*np.arange(N)/N),
                        rtol=1e-30)

    xp_assert_close(czt_points(7, w=1), np.ones(7, dtype=np.complex128), rtol=1e-30)
    xp_assert_close(czt_points(11, w=2.),
                    1/(2**np.arange(11, dtype=np.complex128)), rtol=1e-30)

    func = CZT(12, m=11, w=2., a=1)
    xp_assert_close(func.points(), 1/(2**np.arange(11)), rtol=1e-30)


@pytest.mark.parametrize('cls, args', [(CZT, (100,)), (ZoomFFT, (100, 0.2))])
def test_CZT_size_mismatch(cls, args):
    # A transform built for 100 input samples must refuse a 5-sample input.
    transform = cls(*args)
    short_input = np.arange(5)
    with pytest.raises(ValueError, match='CZT defined for'):
        transform(short_input)


def test_invalid_range():
    # `fn` may be a scalar or a [start, stop] pair; three values are invalid.
    bad_fn = [1, 2, 3]
    with pytest.raises(ValueError, match='2-length sequence'):
        ZoomFFT(100, bad_fn)


@pytest.mark.parametrize('m', [0, -11, 5.5, 4.0])
def test_czt_points_errors(m):
    # Point counts that are zero, negative or given as a float (even an
    # integral one such as 4.0) must be rejected.
    expected_message = 'Invalid number of CZT'
    with pytest.raises(ValueError, match=expected_message):
        czt_points(m)


@pytest.mark.parametrize('size', [0, -5, 3.5, 4.0])
def test_nonsense_size(size):
    # A size that is zero, negative or a float must be rejected, whether it
    # is the input length or the output length. This holds for both classes
    # and both convenience functions. (NumPy and SciPy fft() also raise
    # ValueError for a zero output size.)
    constructors = (
        lambda: CZT(size, 3),
        lambda: ZoomFFT(size, 0.2, 3),
        lambda: CZT(3, size),
        lambda: ZoomFFT(3, 0.2, size),
        lambda: czt([1, 2, 3], size),
        lambda: zoom_fft([1, 2, 3], 0.2, size),
    )
    for build in constructors:
        with pytest.raises(ValueError, match='Invalid number of CZT'):
            build()