# Zhendong - Initial Commit (2e04998)
import torch
from torch.fft import fftn
def roll_quadrants(data, backwards=False):
    """
    Circularly shift both spatial axes of a batched 2D fourier transform by half their length.

    The forward direction moves the entry at index 0 of each spatial axis (the zero
    frequency of an unshifted transform) to the middle of that axis. Passing
    ``backwards=True`` rolls by the same amount in the opposite direction and therefore
    undoes a forward call.

    Args:
        data: fourier transform, (NxHxW)
        backwards: bool, if True roll in the negative direction (inverse of the forward shift)
    Returns:
        Shifted fourier transform.
    Raises:
        AttributeError: if data does not have exactly two dimensions after the batch dimension.
        RuntimeWarning: if any of the two spatial sizes is even.
    """
    spatial_rank = data.ndim - 1  # leading dimension is the batch
    if spatial_rank != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_rank}d.')

    spatial_shape = data.shape[1:]
    if not all(size % 2 for size in spatial_shape):
        raise RuntimeWarning('Roll quadrants for 2d input should only be used with uneven spatial sizes.')

    # Half of each (odd) axis length, i.e. (N - 1) / 2; negated to roll back.
    sign = -1 if backwards else 1
    offsets = [sign * (size // 2) for size in spatial_shape]
    return data.roll(offsets, dims=(1, 2))
def batch_fft(data, normalize=False):
    """
    Compute the 2D fourier transform of every sample in a batch.

    Args:
        data: input tensor, (NxHxW). Non-complex input is converted to a complex tensor
            with zero imaginary part before the transform.
        normalize: bool, if True the transform uses the 'ortho' norm, otherwise 'backward'.
    Returns:
        Batch fourier transform of input data.
    Raises:
        AttributeError: if data does not have exactly two dimensions after the batch dimension.
    """
    spatial_rank = data.ndim - 1  # leading dimension is the batch
    if spatial_rank != 2:
        raise AttributeError(f'Data must be 2d but it is {spatial_rank}d.')

    if not torch.is_complex(data):
        data = torch.complex(data, torch.zeros_like(data))

    # Transform over the two spatial axes only; the batch axis is left untouched.
    return fftn(data, dim=(1, 2), norm='ortho' if normalize else 'backward')
def azimuthal_average(image, center=None):
    """
    Calculate the azimuthally averaged radial profile.

    Every pixel is assigned to a radial bin given by the integer part of its distance to
    the center (bin size = 1); the result is the mean pixel value of each bin, ordered by
    increasing radius. Requires low frequencies to be at the center of the image.

    Args:
        image: Batch of 2D images, NxHxW. Any number of leading dimensions (including none)
            is accepted; the last two dimensions are treated as HxW.
        center: The [x,y] pixel coordinates used as the center. The default is
            None, which then uses the center of the image (including
            fractional pixels).
    Returns:
        Azimuthal average over the image around the center, as a float64 tensor whose last
        dimension enumerates the radial bins.
    Raises:
        AssertionError: if center is given but does not have length 2.
        NotImplementedError: if the image height H is even.
    """
    # modified to tensor inputs from https://www.astrobetter.com/blog/2010/03/03/fourier-transforms-of-images-in-python/
    # Check input shapes
    assert center is None or (len(center) == 2), f'Center has to be None or len(center)=2 ' \
                                                 f'(but it is len(center)={len(center)}.'

    H, W = image.shape[-2:]
    if H % 2 == 0:
        raise NotImplementedError('Not sure if implementation correct, please check')

    # Row (y) and column (x) index of every pixel; built by broadcasting so that no
    # torch.meshgrid call (and its indexing-default warning) is needed.
    ys = torch.arange(0, H).unsqueeze(1).expand(H, W)
    xs = torch.arange(0, W).unsqueeze(0).expand(H, W)
    if center is None:
        center = ((W - 1) / 2.0, (H - 1) / 2.0)

    # Compute radius for each pixel wrt center
    offsets = torch.stack([xs - center[0], ys - center[1]])
    if not offsets.is_floating_point():
        # An all-integer center leaves integer offsets, which norm() cannot handle.
        offsets = offsets.to(torch.get_default_dtype())
    r = offsets.norm(2, 0)

    # Sort pixels by radius so that every radial bin is one contiguous run
    r_sorted, order = r.flatten().sort()
    values = image.flatten(-2, -1)[..., order]

    # Integer part of the radii (bin size = 1); a non-zero difference marks the last
    # pixel of a bin.
    r_int = r_sorted.long()
    bin_change = r_int[1:] - r_int[:-1]
    last_idx = torch.where(bin_change != 0)[0]

    # Number of pixels per bin from the start offsets of consecutive bins
    edges = torch.cat([torch.tensor([0]), last_idx + 1, torch.tensor([H * W])])
    counts = edges[1:] - edges[:-1]

    # Index of the last pixel of every bin, including the outermost one
    last_idx = torch.cat([last_idx, torch.tensor([H * W - 1])])

    # Cumulative sum over the sorted pixels; differences between bin ends give bin sums
    csum = values.cumsum(-1, dtype=torch.float64)
    # The innermost bin may hold more than one pixel (off-pixel center, even W), so its
    # sum is the cumulative sum at the bin's last pixel rather than the first pixel alone.
    first_bin = csum[..., last_idx[:1]]
    other_bins = csum[..., last_idx[1:]] - csum[..., last_idx[:-1]]
    bin_sums = torch.cat([first_bin, other_bins], -1)

    return bin_sums / counts.to(bin_sums.device)  # normalize by counted pixels per bin
def get_spectrum(data, normalize=False):
    """
    Compute the azimuthally averaged power spectrum of a batch of 2D signals.

    The power spectrum (squared magnitude of the fourier transform) is shifted so that
    low frequencies sit at the center and is then averaged over rings of equal radius.

    Args:
        data: input tensor, (NxHxW)
        normalize: bool, forwarded to batch_fft.
    Returns:
        Radial profile of the power spectrum as returned by azimuthal_average.
    Raises:
        AttributeError: if data does not have exactly two dimensions after the batch dimension.
    """
    dim = data.ndim - 1  # subtract one for batch dimension
    if dim != 2:
        raise AttributeError(f'Data must be 2d but it is {dim}d.')

    freq = batch_fft(data, normalize=normalize)
    power_spec = freq.real ** 2 + freq.imag ** 2

    # For an even axis length, duplicate the entry at index size/2 so it is put at the end
    # of the spectrum and is not averaged with the mean value. Each spatial axis is
    # handled with its own length, so non-square inputs are padded at the right index
    # and both axes end up odd, as the shift below requires.
    for axis in (1, 2):
        size = power_spec.shape[axis]
        if size % 2 == 0:
            half = size // 2
            power_spec = torch.cat(
                [power_spec.narrow(axis, 0, half + 1), power_spec.narrow(axis, half, size - half)],
                dim=axis,
            )

    power_spec = roll_quadrants(power_spec)
    return azimuthal_average(power_spec)
def plot_std(mean, std, x=None, ax=None, **kwargs):
    """
    Plot a curve together with a shaded band of one standard deviation around it.

    Args:
        mean: values of the curve.
        std: half-width of the band; the band spans mean-std to mean+std.
        x: x coordinates. The default is None, which uses len(mean) evenly spaced
            points between 0 and 1.
        ax: axes to draw on. The default is None, which creates a new figure.
        **kwargs: forwarded to ax.plot. A 'c' or 'color' entry is also used for the band.
    Returns:
        The axes that were drawn on.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        _, ax = plt.subplots(1)

    # Shade the error margin in the same color as the line; 'c' wins over 'color'.
    band_style = {'alpha': 0.3}
    for key in ('c', 'color'):
        if key in kwargs:
            band_style['color'] = kwargs[key]
            break

    if x is None:
        x = torch.linspace(0, 1, len(mean))  # use normalized x axis

    ax.plot(x, mean, **kwargs)
    ax.fill_between(x, mean - std, mean + std, **band_style)
    return ax