File size: 5,658 Bytes
2e04998 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import torch
from torch.fft import fftn
def roll_quadrants(data, backwards=False):
    """
    Center (or un-center) the low frequencies of a batched 2d Fourier transform.

    The forward direction rolls each spatial axis by floor(N/2), moving the
    zero-frequency entry from index 0 to the middle of the axis (fftshift for
    odd N). With ``backwards=True`` the roll is reversed (ifftshift), undoing
    the forward call.

    Args:
        data: fourier transform, (NxHxW); dimension 0 is the batch.
        backwards: bool, if True shift high frequencies back to center
            (i.e. invert the forward roll).

    Returns:
        Shifted fourier transform with the same shape as ``data``.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dims.
        RuntimeWarning: (raised as an exception) if any spatial size is even.
    """
    n_spatial = data.ndim - 1  # leading dimension is the batch
    if n_spatial != 2:
        raise AttributeError(f'Data must be 2d but it is {n_spatial}d.')
    spatial_sizes = data.shape[1:]
    if not all(size % 2 for size in spatial_sizes):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')
    # floor(N/2) per axis; negated to undo the forward roll
    direction = -1 if backwards else 1
    offsets = [direction * (size // 2) for size in spatial_sizes]
    return data.roll(offsets, dims=(1, 2))
def batch_fft(data, normalize=False):
    """
    Compute the 2d fourier transform of every image in a batch.

    Args:
        data: input tensor, (NxHxW). Real (including integer) or complex.
        normalize: if True use orthonormal scaling (``norm='ortho'``),
            otherwise the unscaled forward transform (``norm='backward'``).

    Returns:
        Complex tensor (NxHxW) holding the fourier transform over the two
        spatial dimensions of each image.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dims.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')
    dims = tuple(range(1, dim + 1))  # spatial dims, skipping the batch dimension
    norm = 'ortho' if normalize else 'backward'
    # fftn accepts real input directly and promotes integer tensors to the
    # default float dtype, so no manual complex conversion is needed
    # (torch.complex would reject integer tensors and cost an extra allocation).
    return fftn(data, dim=dims, norm=norm)
def azimuthal_average(image, center=None):
    # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    """
    Calculate the azimuthally averaged radial profile.

    Pixels are grouped into radial bins of width 1 (bin k holds the pixels
    whose distance r to the center satisfies k <= r < k + 1), and each bin is
    averaged. Requires low frequencies to be at the center of the image.

    Args:
        image: Batch of 2D images, NxHxW. Any number of leading dimensions
            (including none) is accepted; the last two are treated as HxW.
        center: The [x, y] (column, row) pixel coordinates used as the center.
            The default is None, which then uses the center of the image
            (including fractional pixels).

    Returns:
        float64 tensor of shape (..., B) with the mean of each non-empty radial
        bin, ordered by increasing radius. Bins are derived from where the
        integer radius changes, so an empty integer radius would be skipped
        rather than reported.
    """
    # Check input shapes
    assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
                                                  f'(but it is len(center)={len(center)}).'
    # Pixel row (h) and column (w) indices, created on the image's device
    H, W = image.shape[-2:]
    device = image.device
    h, w = torch.meshgrid(torch.arange(0, H, device=device),
                          torch.arange(0, W, device=device),
                          indexing='ij')
    if center is None:
        center = ((W - 1) / 2.0, (H - 1) / 2.0)
    # Radius of each pixel w.r.t. the center
    r = torch.stack([w - center[0], h - center[1]]).norm(2, 0)
    # Sort pixels by radius and reorder the image values accordingly
    r_sorted, ind = r.flatten().sort()
    i_sorted = image.flatten(-2, -1)[..., ind]
    # Integer part of the radius is the bin index (bin size = 1)
    r_int = r_sorted.long()
    # Sorted position of the last pixel of every bin: each place the bin index
    # changes, plus the final pixel which closes the last bin
    bin_change = torch.where(r_int[1:] != r_int[:-1])[0]
    bin_end = torch.cat([bin_change, torch.tensor([H * W - 1], device=device)])
    # Number of pixels per bin = distance between consecutive bin ends
    prev_end = torch.cat([torch.tensor([-1], device=device), bin_end[:-1]])
    nr = bin_end - prev_end
    # Per-bin sums from the cumulative sum sampled at each bin end. Bin 0 is
    # the cumulative sum at its own end, which also covers bins with several
    # pixels (fractional centers, custom centers).
    csim = i_sorted.cumsum(-1, dtype=torch.float64)
    cum_at_end = csim[..., bin_end]
    tbin = torch.cat([cum_at_end[..., :1], cum_at_end[..., 1:] - cum_at_end[..., :-1]], dim=-1)
    radial_prof = tbin / nr  # normalize by number of pixels in each bin
    return radial_prof
def _duplicate_nyquist(spec, dim):
    """Duplicate index n//2 along `dim` when that dim has even length n; odd lengths are returned unchanged."""
    n = spec.shape[dim]
    if n % 2:
        return spec
    half = n // 2
    return torch.cat([spec.narrow(dim, 0, half + 1),
                      spec.narrow(dim, half, 1),
                      spec.narrow(dim, half + 1, n - half - 1)], dim=dim)


def get_spectrum(data, normalize=False):
    """
    Compute the azimuthally averaged power spectrum of a batch of 2d images.

    Every even spatial dimension is padded to odd length by duplicating its
    N/2 entry, so that after centering that value ends up at the border of the
    spectrum and is not averaged with the mean value.

    Args:
        data: input tensor, (NxHxW). H and W may differ and each may be even
            or odd.
        normalize: forwarded to batch_fft; if True the transform uses
            orthonormal scaling.

    Returns:
        Radial profile of the power spectrum for each image in the batch, as
        produced by azimuthal_average.

    Raises:
        AttributeError: if ``data`` does not have exactly two non-batch dims.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')
    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2
    # Pad each even spatial dim on its own, so non-square inputs (H != W,
    # mixed parity) also end up with odd sizes before roll_quadrants.
    for axis in (1, 2):
        power_spec = _duplicate_nyquist(power_spec, axis)
    power_spec = roll_quadrants(power_spec)
    power_spec = azimuthal_average(power_spec)
    return power_spec
def plot_std(mean, std, x=None, ax=None, **kwargs):
    """
    Plot a mean curve together with a shaded band of +/- one standard deviation.

    Args:
        mean: y values of the line.
        std: standard deviation per point; the band spans mean-std .. mean+std.
        x: x values; defaults to len(mean) evenly spaced points in [0, 1].
        ax: axes to draw on; a new figure with one axes is created if None.
        **kwargs: forwarded to ``ax.plot``. A ``'c'`` (preferred) or
            ``'color'`` entry is also used as the band color.

    Returns:
        The axes that was drawn on.
    """
    if ax is None:
        # Import pyplot only when a figure must be created, so callers that
        # pass their own axes do not trigger pyplot/backend initialization.
        import matplotlib.pyplot as plt
        _, ax = plt.subplots(1)
    # plot error margins in same color as line
    err_kwargs = {'alpha': 0.3}
    if 'c' in kwargs:
        err_kwargs['color'] = kwargs['c']
    elif 'color' in kwargs:
        err_kwargs['color'] = kwargs['color']
    if x is None:
        x = torch.linspace(0, 1, len(mean))  # use normalized x axis
    ax.plot(x, mean, **kwargs)
    ax.fill_between(x, mean - std, mean + std, **err_kwargs)
    return ax
|