|
import torch |
|
|
|
def merge_tensors_fft2(
    v0: torch.Tensor,
    v1: torch.Tensor,
    t: float,
    *,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merge two tensors by combining their Fourier spectra bin by bin.

    A 1-D input is transformed with a 1-D FFT; an input with two or more
    dimensions is transformed over its last two dimensions (leading
    dimensions are treated as a batch). The two spectra are combined per
    frequency bin, inverse-transformed, and the real part is returned.

    Per bin, let ``same`` mean that the *real* parts of the two spectra have
    equal sign (``torch.sign``, so 0 only matches 0):

    - real part: ``maximum(re0, re1)`` if ``same``; otherwise whichever of
      ``re0`` / ``re1`` has the larger magnitude (ties pick ``v1``).
    - imaginary part: ``(1 - t) * im0 + t * im1`` if ``same``; otherwise
      whichever of ``im0`` / ``im1`` has the larger magnitude (ties pick ``v1``).

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - t (float): Interpolation parameter (0 <= t <= 1). Only used for the
      imaginary part of bins where the real parts agree in sign.
    - device (torch.device | str | None, keyword-only): Device to compute on
      and to return the result on. Defaults to "cuda:0" when CUDA is
      available (the historical behaviour), otherwise to v0's device.

    Returns:
    - torch.Tensor: Real part of the inverse FFT of the merged spectrum; same
      shape as v0, located on ``device``.

    Raises:
    - ValueError: If v0 and v1 do not have the same shape.
    """
    if v0.shape != v1.shape:
        raise ValueError(
            "v0 and v1 must have the same shape, got "
            f"{tuple(v0.shape)} and {tuple(v1.shape)}"
        )

    # Previously hard-coded to "cuda:0", which crashed on CPU-only machines.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # 1-D tensors get a 1-D FFT; everything else is transformed over the last
    # two dims. fftn over (-1,) is the same transform as fft.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    re0, re1 = fft_v0.real, fft_v1.real
    im0, im1 = fft_v0.imag, fft_v1.imag

    # NOTE(review): the sign-agreement mask is computed from the REAL parts
    # and also drives the imaginary-part rule below. Preserved as-is to keep
    # results unchanged — confirm this is intended rather than a copy of the
    # real-part logic.
    same_sign = re0.sign() == re1.sign()

    # NOTE(review): when both real parts are negative, maximum() picks the one
    # closer to zero (smaller magnitude), unlike the positive case.
    merged_real = torch.where(
        same_sign,
        torch.maximum(re0, re1),
        torch.where(re0.abs() > re1.abs(), re0, re1),
    )
    merged_imag = torch.where(
        same_sign,
        (1 - t) * im0 + t * im1,
        torch.where(im0.abs() > im1.abs(), im0, im1),
    )

    # Build the result in a buffer of the FFT's own complex dtype; .real/.imag
    # are views, so copy_ writes straight into it.
    merged_fft = torch.zeros_like(fft_v0)
    merged_fft.real.copy_(merged_real)
    merged_fft.imag.copy_(merged_imag)

    return torch.fft.ifftn(merged_fft, dim=dims).real
|
|
|
def merge_tensors_fft_shell(
    v0: torch.Tensor,
    v1: torch.Tensor,
    *,
    device: "torch.device | str | None" = None,
) -> torch.Tensor:
    """
    Merge two tensors by taking a "maximal" combination of their Fourier spectra.

    A 1-D input is transformed with a 1-D FFT; an input with two or more
    dimensions is transformed over its last two dimensions (leading
    dimensions are treated as a batch). The two spectra are combined per
    frequency bin, inverse-transformed, and the real part is returned.
    No interpolation weight is involved.

    Per bin, let ``same`` mean that the *real* parts of the two spectra have
    equal sign (``torch.sign``, so 0 only matches 0):

    - real part: ``maximum(re0, re1)`` if ``same``; otherwise whichever of
      ``re0`` / ``re1`` has the larger magnitude (ties pick ``v1``).
    - imaginary part: ``maximum(im0, im1)`` if ``same``; otherwise whichever
      of ``im0`` / ``im1`` has the larger magnitude (ties pick ``v1``).

    Parameters:
    - v0 (torch.Tensor): The first input tensor.
    - v1 (torch.Tensor): The second input tensor; must have the same shape as v0.
    - device (torch.device | str | None, keyword-only): Device to compute on
      and to return the result on. Defaults to "cuda:0" when CUDA is
      available (the historical behaviour), otherwise to v0's device.

    Returns:
    - torch.Tensor: Real part of the inverse FFT of the merged spectrum; same
      shape as v0, located on ``device``.

    Raises:
    - ValueError: If v0 and v1 do not have the same shape.
    """
    if v0.shape != v1.shape:
        raise ValueError(
            "v0 and v1 must have the same shape, got "
            f"{tuple(v0.shape)} and {tuple(v1.shape)}"
        )

    # Previously hard-coded to "cuda:0", which crashed on CPU-only machines.
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else v0.device
    v0 = v0.to(device)
    v1 = v1.to(device)

    # 1-D tensors get a 1-D FFT; everything else is transformed over the last
    # two dims. fftn over (-1,) is the same transform as fft.
    dims = (-1,) if v0.dim() == 1 else (-2, -1)
    fft_v0 = torch.fft.fftn(v0, dim=dims)
    fft_v1 = torch.fft.fftn(v1, dim=dims)

    re0, re1 = fft_v0.real, fft_v1.real
    im0, im1 = fft_v0.imag, fft_v1.imag

    # NOTE(review): the sign-agreement mask is computed from the REAL parts
    # and also drives the imaginary-part rule below. Preserved as-is to keep
    # results unchanged — confirm this is intended rather than a copy of the
    # real-part logic.
    same_sign = re0.sign() == re1.sign()

    # NOTE(review): when both values are negative, maximum() picks the one
    # closer to zero (smaller magnitude), unlike the positive case.
    merged_real = torch.where(
        same_sign,
        torch.maximum(re0, re1),
        torch.where(re0.abs() > re1.abs(), re0, re1),
    )
    merged_imag = torch.where(
        same_sign,
        torch.maximum(im0, im1),
        torch.where(im0.abs() > im1.abs(), im0, im1),
    )

    # Build the result in a buffer of the FFT's own complex dtype; .real/.imag
    # are views, so copy_ writes straight into it.
    merged_fft = torch.zeros_like(fft_v0)
    merged_fft.real.copy_(merged_real)
    merged_fft.imag.copy_(merged_imag)

    return torch.fft.ifftn(merged_fft, dim=dims).real
|
|