|
import torch |
|
from torch.fft import fftn |
|
|
|
|
|
def roll_quadrants(data, backwards=False):
    """
    Circularly shift a batch of 2D Fourier transforms so that the zero-frequency
    entry (index 0) moves to the center of each spatial axis, or back again.

    Along every spatial axis of size n the tensor is rolled by n // 2 (or by
    -(n // 2) when ``backwards`` is set), so a forward call followed by a
    backwards call restores the input.

    Args:
        data: batch of Fourier transforms, (NxHxW); H and W must be odd.
        backwards: bool, if True undo the forward shift, moving the center
            entry back to index 0.

    Returns:
        The rolled tensor, same shape and dtype as ``data``.

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
        RuntimeWarning: raised as an exception (not emitted as a warning) if
            any spatial size is even.
    """
    dim = data.ndim - 1
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    spatial_sizes = data.shape[1:]
    if any(size % 2 == 0 for size in spatial_sizes):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')

    sign = -1 if backwards else 1
    offsets = [sign * (size // 2) for size in spatial_sizes]
    return data.roll(offsets, dims=tuple(range(1, dim + 1)))
|
|
|
|
|
def batch_fft(data, normalize=False):
    """
    Compute the 2D Fourier transform of every image in a batch.

    Args:
        data: input tensor, (NxHxW). May be real (floating point or integer)
            or complex; ``fftn`` promotes real input to a complex dtype itself.
        normalize: bool, if True use 'ortho' normalization, otherwise the
            torch default 'backward' normalization.

    Returns:
        Complex tensor (NxHxW): the transform over the last two dimensions.

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
    """
    dim = data.ndim - 1
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    dims = tuple(range(1, dim + 1))
    norm = 'ortho' if normalize else 'backward'

    # fftn accepts real input directly (integer dtypes are promoted to the
    # default float type). Building the complex tensor by hand with
    # torch.complex() only worked for floating-point input.
    return fftn(data, dim=dims, norm=norm)
|
|
|
|
|
def azimuthal_average(image, center=None):
    """
    Calculate the azimuthally averaged radial profile.

    Pixels are grouped by the integer part of their Euclidean distance to
    ``center``; the result is the mean value of every non-empty group, in
    order of increasing radius. When applied to a power spectrum, low
    frequencies must be at the center of the image.

    Args:
        image: Batch of 2D images, NxHxW. Averaging is over the last two
            dims, so any number of leading dims (including none) works.
        center: The [x,y] pixel coordinates used as the center. The default
            is None, which then uses the center of the image (including
            fractional pixels), i.e. ((W-1)/2, (H-1)/2).

    Returns:
        Float64 tensor (..., n_bins): azimuthal average over the image
        around the center.
    """
    assert center is None or len(center) == 2, (
        f'Center has to be None or len(center)=2 (but it is len(center)={len(center)}).')

    H, W = image.shape[-2:]
    # Row (y) and column (x) index of every pixel, both shaped (H, W).
    rows = torch.arange(H).view(H, 1).expand(H, W)
    cols = torch.arange(W).view(1, W).expand(H, W)

    if center is None:
        center = ((W - 1) / 2.0, (H - 1) / 2.0)

    # Distance of every pixel to the center, (H, W).
    r = torch.stack([cols - center[0], rows - center[1]]).norm(2, 0)

    # Sort pixels by radius so that every radius bin is a contiguous run.
    r_sorted, order = r.flatten().sort()
    values = image.flatten(-2, -1)[..., order]
    r_bin = r_sorted.long()

    # Sorted position of the last pixel of each bin: every place the bin index
    # changes, plus the final pixel, which closes the outermost bin.
    last = torch.where(r_bin[1:] != r_bin[:-1])[0]
    last = torch.cat([last, torch.tensor([H * W - 1])])

    # Pixels per bin; bin k starts right after the end of bin k - 1 (bin 0 at 0).
    first = torch.cat([torch.tensor([0]), last[:-1] + 1])
    counts = last - first + 1

    # Per-bin sums from the cumulative sum sampled at each bin end. The first
    # bin's sum is the cumulative sum at its own end, not just the first pixel:
    # it can contain several pixels (even sizes or a fractional/custom center).
    csum = values.cumsum(-1, dtype=torch.float64)
    ends = csum[..., last]
    sums = torch.cat([ends[..., :1], ends[..., 1:] - ends[..., :-1]], dim=-1)

    return sums / counts.to(sums.device)
|
|
|
|
|
def get_spectrum(data, normalize=False):
    """
    Compute the azimuthally averaged power spectrum of a batch of 2D images.

    Args:
        data: input tensor, (NxHxW).
        normalize: bool, forwarded to ``batch_fft`` (True selects 'ortho'
            FFT normalization).

    Returns:
        Radial profile of the power spectrum per image, as produced by
        ``azimuthal_average``.

    Raises:
        AttributeError: if ``data`` does not have exactly two spatial dims.
    """
    dim = data.ndim - 1
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2

    # roll_quadrants only accepts odd sizes. For an even size n, duplicate the
    # Nyquist bin (index n // 2), giving index order [0..n//2, n//2, n//2+1..n-1]
    # and an odd size n + 1. Each axis uses its own size, so non-square inputs
    # are padded correctly.
    for axis in (1, 2):
        n = power_spec.shape[axis]
        if n % 2 == 0:
            half = n // 2
            index = torch.cat([torch.arange(half + 1), torch.arange(half, n)]).to(power_spec.device)
            power_spec = power_spec.index_select(axis, index)

    power_spec = roll_quadrants(power_spec)
    power_spec = azimuthal_average(power_spec)
    return power_spec
|
|
|
|
|
def plot_std(mean, std, x=None, ax=None, **kwargs):
    """
    Plot ``mean`` as a line with a shaded band from ``mean - std`` to ``mean + std``.

    Args:
        mean: y values (presumably 1D; must support ``len`` and arithmetic with ``std``).
        std: spread around ``mean``, same length as ``mean``.
        x: x coordinates; defaults to ``len(mean)`` evenly spaced points in [0, 1].
        ax: matplotlib Axes to draw on; a new figure and axes are created if None.
        **kwargs: forwarded to ``ax.plot``; a 'c' or 'color' entry also colors the band.

    Returns:
        The Axes that was drawn on.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(1)

    if x is None:
        x = torch.linspace(0, 1, len(mean))

    lines = ax.plot(x, mean, **kwargs)

    # Shade the band in the line's color: an explicit 'c'/'color' kwarg wins,
    # otherwise reuse the color matplotlib picked for the plotted line.
    err_kwargs = {'alpha': 0.3}
    if 'c' in kwargs:
        err_kwargs['color'] = kwargs['c']
    elif 'color' in kwargs:
        err_kwargs['color'] = kwargs['color']
    elif lines:
        err_kwargs['color'] = lines[0].get_color()

    ax.fill_between(x, mean - std, mean + std, **err_kwargs)
    return ax
|
|