File size: 6,036 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
from torch.nn import Module


class AddCoords(Module):
    r"""AddCoords is a PyTorch module that adds additional channels to the input tensor containing the relative
    (normalized to `[-1, 1]`) coordinates of each input element along the specified number of dimensions (`rank`).
    Essentially, it adds spatial context information to the tensor.

    Typically, these inputs are feature maps coming from some CNN, where the spatial organization of the input
    matters (such as an image or speech signal).

    This additional spatial context allows subsequent layers (such as convolutions) to learn position-dependent
    features. For example, in tasks where the absolute position of features matters (such as denoising and
    segmentation tasks), it helps the model to know where (in terms of relative position) the features are.

    Args:
        rank (int): The dimensionality of the input tensor. That is to say, this tells us how many dimensions the
                    input tensor's spatial context has. It's assumed to be 1, 2, or 3 corresponding to some 1D, 2D,
                    or 3D data (like an image).

        with_r (bool): Boolean indicating whether to add an extra radial distance channel or not. If True, an extra
                       channel is appended, which measures the Euclidean (L2) distance from the center of the image.
                       This might be useful when the proximity to the center of the image is important to the task.
    """

    def __init__(self, rank: int, with_r: bool = False):
        super().__init__()
        self.rank = rank
        self.with_r = with_r

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        r"""Forward pass of the AddCoords module. Depending on the rank of the tensor, it adds one or more new channels
        with relative coordinate values. If `with_r` is True, an extra radial channel is included.

        For example, for an image (`rank=2`), two channels would be added which contain the normalized x and y
        coordinates respectively of each pixel.

        Calling the forward method updates the original tensor `x` with the added channels.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            out (torch.Tensor): The input tensor with added coordinate and possibly radial channels.
        """
        if self.rank == 1:
            batch_size_shape, _, dim_x = x.shape
            xx_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            xx_channel = xx_range[None, None, :]

            xx_channel = xx_channel.float() / (dim_x - 1)
            xx_channel = xx_channel * 2 - 1
            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1)

            out = torch.cat([x, xx_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2))
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 2:
            batch_size_shape, _, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32, device=x.device)
            yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32, device=x.device)

            xx_range = torch.arange(dim_y, dtype=torch.int32, device=x.device)
            yy_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            xx_range = xx_range[None, None, :, None]
            yy_range = yy_range[None, None, :, None]

            xx_channel = torch.matmul(xx_range, xx_ones)
            yy_channel = torch.matmul(yy_range, yy_ones)

            # transpose y
            yy_channel = yy_channel.permute(0, 1, 3, 2)

            xx_channel = xx_channel.float() / (dim_y - 1)
            yy_channel = yy_channel.float() / (dim_x - 1)

            xx_channel = xx_channel * 2 - 1
            yy_channel = yy_channel * 2 - 1

            xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
            yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)

            out = torch.cat([x, xx_channel, yy_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2),
                )
                out = torch.cat([out, rr], dim=1)

        elif self.rank == 3:
            batch_size_shape, _, dim_z, dim_y, dim_x = x.shape
            xx_ones = torch.ones([1, 1, 1, 1, dim_x], dtype=torch.int32, device=x.device)
            yy_ones = torch.ones([1, 1, 1, 1, dim_y], dtype=torch.int32, device=x.device)
            zz_ones = torch.ones([1, 1, 1, 1, dim_z], dtype=torch.int32, device=x.device)

            xy_range = torch.arange(dim_y, dtype=torch.int32, device=x.device)
            xy_range = xy_range[None, None, None, :, None]

            yz_range = torch.arange(dim_z, dtype=torch.int32, device=x.device)
            yz_range = yz_range[None, None, None, :, None]

            zx_range = torch.arange(dim_x, dtype=torch.int32, device=x.device)
            zx_range = zx_range[None, None, None, :, None]

            xy_channel = torch.matmul(xy_range, xx_ones)
            xx_channel = torch.cat([xy_channel + i for i in range(dim_z)], dim=2)

            yz_channel = torch.matmul(yz_range, yy_ones)
            yz_channel = yz_channel.permute(0, 1, 3, 4, 2)
            yy_channel = torch.cat([yz_channel + i for i in range(dim_x)], dim=4)

            zx_channel = torch.matmul(zx_range, zz_ones)
            zx_channel = zx_channel.permute(0, 1, 4, 2, 3)
            zz_channel = torch.cat([zx_channel + i for i in range(dim_y)], dim=3)

            out = torch.cat([x, xx_channel, yy_channel, zz_channel], dim=1)

            if self.with_r:
                rr = torch.sqrt(
                    torch.pow(xx_channel - 0.5, 2)
                    + torch.pow(yy_channel - 0.5, 2)
                    + torch.pow(zz_channel - 0.5, 2),
                )
                out = torch.cat([out, rr], dim=1)
        else:
            raise NotImplementedError

        return out