# Page-header residue from the source listing (not code):
#   nickovchinnikov's picture
#   Init
#   9d61c9b
import torch
from torch.nn import Module
class AddCoords(Module):
    r"""Append relative-coordinate channels to a feature map.

    AddCoords is a PyTorch module that adds additional channels to the input tensor containing the relative
    (normalized to `[-1, 1]`) coordinates of each input element along the specified number of spatial
    dimensions (`rank`). Essentially, it adds spatial context information to the tensor.

    Typically, these inputs are feature maps coming from some CNN, where the spatial organization of the input
    matters (such as an image or speech signal).

    This additional spatial context allows subsequent layers (such as convolutions) to learn position-dependent
    features. For example, in tasks where the absolute position of features matters (such as denoising and
    segmentation tasks), it helps the model to know where (in terms of relative position) the features are.

    Args:
        rank (int): The number of spatial dimensions of the input tensor. Must be 1, 2, or 3, corresponding to
            inputs of shape `(N, C, L)`, `(N, C, H, W)` or `(N, C, D, H, W)`. Any other value makes
            `forward` raise `NotImplementedError`.
        with_r (bool): Whether to append an extra radial channel. If True, one more channel is appended that
            holds the Euclidean (L2) distance of each position from the reference point whose coordinate is
            `0.5` on every axis (see the note below).

    Note:
        The coordinates span `[-1, 1]`, so the geometric centre is at `0`; the radial channel nevertheless
        subtracts `0.5` from each coordinate. This offset is kept unchanged so that the values produced for
        `rank` 1 and 2 stay identical for layers that were already trained on top of this module.
    """

    # Offset subtracted from every coordinate before computing the radial channel.
    _RADIAL_OFFSET = 0.5

    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    @staticmethod
    def _coord_channel(x: torch.Tensor, axis: int) -> torch.Tensor:
        r"""Build one coordinate channel of shape `(N, 1, *spatial)` that varies only along `axis` of `x`."""
        size = x.shape[axis]
        coord = torch.arange(size, dtype=torch.int32, device=x.device)
        # Map indices 0..size-1 onto [-1, 1]. A length-1 axis would divide by zero (NaN),
        # so the divisor is clamped to 1 and the single position gets coordinate -1.
        coord = coord.float() / max(size - 1, 1)
        coord = coord * 2 - 1

        # Put the values on `axis`, leave every other dim at size 1, then broadcast.
        view_shape = [1] * x.dim()
        view_shape[axis] = size
        target_shape = (x.shape[0], 1, *x.shape[2:])
        return coord.view(view_shape).expand(target_shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Concatenate coordinate channels (and optionally a radial channel) to `x`.

        One channel is added per spatial dimension, in the order of the spatial axes of `x` (dim 2 first).
        For example, for an image (`rank=2`) of shape `(N, C, H, W)`, the first added channel holds the
        normalized position along `H` and the second the normalized position along `W`.
        If `with_r` is True, a radial channel is appended after the coordinate channels.

        The input tensor is not modified; a new tensor is returned.

        Args:
            x (torch.Tensor): The input tensor of shape `(N, C, *spatial)` with `rank` spatial dimensions.

        Returns:
            out (torch.Tensor): Tensor of shape `(N, C + rank (+ 1 if with_r), *spatial)`.

        Raises:
            NotImplementedError: If `rank` is not 1, 2, or 3.
            ValueError: If `x` does not have exactly `rank + 2` dimensions.
        """
        if self.rank not in (1, 2, 3):
            raise NotImplementedError(
                f"AddCoords supports rank 1, 2 or 3, got {self.rank}",
            )

        if x.dim() != self.rank + 2:
            raise ValueError(
                f"expected an input with {self.rank + 2} dimensions for rank={self.rank}, "
                f"got {x.dim()}",
            )

        # One channel per spatial axis; spatial axes start after (batch, channel).
        coords = [self._coord_channel(x, axis) for axis in range(2, x.dim())]
        out = torch.cat([x, *coords], dim=1)

        if self.with_r:
            squared = torch.pow(coords[0] - self._RADIAL_OFFSET, 2)
            for coord in coords[1:]:
                squared = squared + torch.pow(coord - self._RADIAL_OFFSET, 2)
            out = torch.cat([out, torch.sqrt(squared)], dim=1)

        return out