File size: 5,658 Bytes
2e04998
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import torch
from torch.fft import fftn


def roll_quadrants(data, backwards=False):
    """
    Roll a batch of 2d fourier transforms so the zero-frequency entry moves from
    index 0 to the center of each spatial axis, i.e. [-N/2, ..., +N/2] -> [0, ..., N-1].
    With ``backwards=True`` the roll is undone (center -> index 0).
    Args:
        data: fourier transform, (NxHxW); H and W must both be odd.
        backwards: bool, if True roll in the opposite direction (inverse shift)

    Returns:
    Shifted fourier transform (same shape as ``data``).

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dimensions.
        RuntimeWarning: raised as an exception (not emitted as a warning) when
            any spatial size is even.
    """
    dim = data.ndim - 1     # number of spatial dims (everything after the batch axis)
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    sizes = data.shape[1:]
    if not all(size % 2 for size in sizes):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')

    # Each spatial axis is rolled by floor(N/2) = (N-1)/2 for odd N; negated for the inverse.
    direction = -1 if backwards else 1
    offsets = [direction * (size // 2) for size in sizes]
    return torch.roll(data, shifts=offsets, dims=(1, 2))


def batch_fft(data, normalize=False):
    """
    Compute the 2d fourier transform of every image in a batch.
    Args:
        data: input tensor, (NxHxW). May be real (floating, integer or bool)
            or complex.
        normalize: bool, if True use orthonormal scaling (norm='ortho'),
            otherwise the default norm='backward' (forward transform unscaled).

    Returns:
    Batch fourier transform of input data: complex tensor (NxHxW), transformed
    over the two spatial dimensions.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dimensions.
    """
    dim = data.ndim - 1     # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    dims = tuple(range(1, dim + 1))  # add one for batch dimension
    norm = 'ortho' if normalize else 'backward'

    if not torch.is_complex(data):
        # torch.complex only accepts half/float/double parts, so integer, bool
        # (or bfloat16) input is promoted to the default float dtype first.
        if data.dtype not in (torch.float16, torch.float32, torch.float64):
            data = data.to(torch.get_default_dtype())
        data = torch.complex(data, torch.zeros_like(data))
    freq = fftn(data, dim=dims, norm=norm)

    return freq


def azimuthal_average(image, center=None):
    # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    Calculate the azimuthally averaged radial profile.
    Requires low frequencies to be at the center of the image.

    Every pixel is assigned to the radial bin given by the integer part of its
    distance to ``center`` (bin size 1). Each output entry is the mean of one
    non-empty bin, ordered by increasing radius; radii that no pixel falls into
    do not produce an entry.

    Args:
        image: Batch of 2D images, NxHxW. The last two dimensions are the
             spatial axes; any number of leading (batch) dimensions works.
        center: The [x,y] pixel coordinates used as the center. The default is
             None, which then uses the center of the image (including
             fractional pixels).

    Returns:
    Azimuthal average over the image around the center: float64 tensor of
    shape (..., K), where K is the number of non-empty radial bins.
    """
    # Check input shapes
    assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
                                                 f'(but it is len(center)={len(center)}).'
    # Pixel coordinates: h[i, j] = i (row), w[i, j] = j (column), i.e. meshgrid with 'ij' indexing
    H, W = image.shape[-2:]
    h = torch.arange(H).view(H, 1).expand(H, W)
    w = torch.arange(W).view(1, W).expand(H, W)

    if center is None:
        center = torch.tensor([(W - 1) / 2.0, (H - 1) / 2.0])

    # Compute radius for each pixel wrt center
    r = torch.stack([w - center[0], h - center[1]]).norm(2, 0)

    # Get sorted radii and the pixel values in the same order
    r_sorted, ind = r.flatten().sort()
    i_sorted = image.flatten(-2, -1)[..., ind]

    # Get the integer part of the radii (bin size = 1)
    r_int = r_sorted.long()             # attribute to the smaller integer

    # Radii are sorted, so every bin is a contiguous run; locate the last index of each run.
    rind = torch.where(r_int[1:] != r_int[:-1])[0]          # last index of every bin except the final one
    ends = torch.cat([rind, torch.tensor([H * W - 1])])      # last index of every bin

    # Number of pixels per bin (the first bin starts at index 0)
    nr = ends - torch.cat([torch.tensor([-1]), ends[:-1]])

    # Per-bin sums as differences of the cumulative sum taken at the bin ends.
    # The first bin's sum is the cumulative sum at its own end, which is correct
    # however many pixels it holds (e.g. even sizes with a fractional center).
    csim = i_sorted.cumsum(-1, dtype=torch.float64)
    csum_ends = csim[..., ends]
    tbin = torch.cat([csum_ends[..., :1], csum_ends[..., 1:] - csum_ends[..., :-1]], dim=-1)

    radial_prof = tbin / nr.to(tbin.device)         # normalize by counted bins

    return radial_prof


def get_spectrum(data, normalize=False):
    """
    Compute the azimuthally averaged power spectrum of a batch of 2d images.

    The power spectrum |F|^2 of every image is computed with ``batch_fft``.
    Every spatial axis of even length N gets its entry at index N/2 (the
    Nyquist frequency in fft ordering) duplicated, making that axis odd (N+1)
    as ``roll_quadrants`` requires; after centering, the two copies sit at both
    ends of the axis so they are not averaged with the mean value. The centered
    spectrum is then reduced with ``azimuthal_average``.

    Args:
        data: input tensor, (NxHxW).
        normalize: bool, forwarded to ``batch_fft`` (orthonormal fft scaling if True).

    Returns:
    Radial power profile of every image, as returned by ``azimuthal_average``.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dimensions.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2

    # Pad each spatial axis according to its own length, so non-square inputs
    # whose H and W differ in parity are handled correctly.
    for axis in (1, 2):
        n = power_spec.shape[axis]
        if n % 2 == 0:
            # index map [0, ..., n/2, n/2, n/2+1, ..., n-1]: duplicate the n/2 entry
            idx = torch.arange(n + 1, device=power_spec.device)
            idx[n // 2 + 1:] -= 1
            power_spec = power_spec.index_select(axis, idx)

    power_spec = roll_quadrants(power_spec)
    power_spec = azimuthal_average(power_spec)
    return power_spec


def plot_std(mean, std, x=None, ax=None, **kwargs):
    """
    Plot ``mean`` as a line with a translucent band spanning mean - std .. mean + std.
    Args:
        mean: y-values of the line.
        std: half-width of the shaded band at every point.
        x: x-values; if None, len(mean) evenly spaced points in [0, 1] are used.
        ax: axes to draw on; if None a new figure with one subplot is created.
        **kwargs: forwarded to ``ax.plot``; a 'c' (preferred) or 'color' entry
            is reused as the band colour.

    Returns:
    The axes that was drawn on.
    """
    import matplotlib.pyplot as plt
    if ax is None:
        _, ax = plt.subplots(1)

    # Shaded band uses the same colour as the line when one was given.
    band_kwargs = {'alpha': 0.3}
    for colour_key in ('c', 'color'):
        if colour_key in kwargs:
            band_kwargs['color'] = kwargs[colour_key]
            break

    if x is None:
        x = torch.linspace(0, 1, len(mean))     # normalized x axis
    ax.plot(x, mean, **kwargs)
    lower = mean - std
    upper = mean + std
    ax.fill_between(x, lower, upper, **band_kwargs)

    return ax