import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Union, Any, List, Iterable, Optional

from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool
    """
    def __init__(
        self,
        layers: Tuple[int, int, int, int],
        output_dim: int,
        input_resolution: Union[int, Tuple[int, int]] = 224,
        width: int = 64,
        heads: int = 8,
        features_only: bool = False,
        out_indices: Optional[Iterable[int]] = None,
        reduction: int = 32,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
        assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
        self.input_resolution = input_resolution
        self.downsampling_rate = 32  # nominal downsampling rate of the full architecture (the effective rate is stored in self.reduction)

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2)

        self.features_only = features_only
        if features_only:
            self.out_indices = out_indices if out_indices is not None else range(5)
            self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices]  # map negative indices to positive indices
            self.out_indices = sorted(set(self.out_indices))  # remove duplicates and sort
            assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages"
            self.channels = width * 32  # channel count of the deepest (stage-4) feature map
        else:
            self.out_indices = None
            embed_dim = width * 32  # the ResNet feature dimension
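            # NOTE: the pooled sequence length below is computed for the nominal /32 output stride;
            # with reduction <= 16, layer4 keeps stride 1 and produces a larger grid than this assumes.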
            self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim)
            self.channels = output_dim

        # effective output stride of the final feature map: 16 when layer4 keeps stride 1 (reduction <= 16), else 32
        self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate
        self.clip_embed_dim = output_dim

    def _make_layer(self, planes, blocks, stride=1):
        layers = [Bottleneck(self._inplanes, planes, stride)]

        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))

        return nn.Sequential(*layers)

    def _stem(self, x: Tensor) -> Tensor:
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
        x = x.type(self.conv1.weight.dtype)
        x = self._stem(x)

        feats = [x] if self.features_only and 0 in self.out_indices else []

        x = self.layer1(x)
        if self.features_only and 1 in self.out_indices:
            feats.append(x)

        x = self.layer2(x)
        if self.features_only and 2 in self.out_indices:
            feats.append(x)

        x = self.layer3(x)
        if self.features_only and 3 in self.out_indices:
            feats.append(x)

        x = self.layer4(x)
        if self.features_only and 4 in self.out_indices:
            feats.append(x)

        if self.features_only:
            if len(self.out_indices) == 1:
                return feats[0]
            else:
                return feats
        else:
            x = self.attnpool(x)
            return x

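# Usage sketch for ModifiedResNet (illustrative hyperparameters, not tied to any particular
# pretrained checkpoint; kept as comments so that importing this module has no side effects):
#
#   model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, input_resolution=224, width=64, heads=32)
#   embeddings = model(torch.randn(2, 3, 224, 224))      # (2, 1024) pooled image embeddings
#
#   backbone = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, features_only=True, out_indices=(2, 3, 4))
#   c3, c4, c5 = backbone(torch.randn(2, 3, 224, 224))   # (2, 512, 28, 28), (2, 1024, 14, 14), (2, 2048, 7, 7)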

class VisionTransformer(nn.Module):
    def __init__(
        self,
        input_resolution: Union[int, Tuple[int, int]],
        patch_size: Union[int, Tuple[int, int]],
        output_dim: int,
        width: int,
        layers: int,
        heads: int,
        features_only: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
        patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
        assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
        assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}"
        assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid."
        assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}"
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.downsampling_rate = patch_size[0]

        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)

        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.num_patches_h = int(input_resolution[0] // patch_size[0])
        self.num_patches_w = int(input_resolution[1] // patch_size[1])
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width))
        self.ln_pre = LayerNorm(width)

        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)

        self.features_only = features_only  # if True, return the final patches instead of the CLS token
        if features_only:
            self.channels = width
        else:
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
            self.channels = output_dim

        self.reduction = patch_size[0]
        self.clip_embed_dim = output_dim

    def adjust_pos_embed(self, h: int, w: int) -> None:
        """
        Permanently resize the positional embedding matrix to match a new input resolution.

        Args:
            h: target input image height in pixels (must be divisible by the patch height).
            w: target input image width in pixels (must be divisible by the patch width).

        See the smoke test at the end of this module for a usage example.
        """
        assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}"
        if self.input_resolution[0] != h or self.input_resolution[1] != w:
            new_num_patches_h = int(h // self.patch_size[0])
            new_num_patches_w = int(w // self.patch_size[1])
            positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0)  # add batch dimension
            positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic").squeeze(0)  # remove batch dimension
            positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
            self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0))
            self.input_resolution = (h, w)
            self.num_patches_h = new_num_patches_h
            self.num_patches_w = new_num_patches_w

    def _interpolate_pos_embed(self, h: int, w: int) -> Tensor:
        """
        Interpolate the positional embedding matrix to match the size of the input image.

        Args:
            h: the required number of patches along the height dimension.
            w: the required number of patches along the width dimension.
        """
        if h == self.num_patches_h and w == self.num_patches_w:
            return self.positional_embedding
        else:
            positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0)  # add batch dimension
            positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0)  # remove batch dimension
            positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
            positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)
            return positional_embedding

    def forward(self, x: Tensor) -> Tensor:
        x = self.conv1(x)  # shape = [*, width, num_patches_h, num_patches_w]
        num_patches_h, num_patches_w = x.shape[-2:]

        positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        x = torch.cat([
                self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),  # prepend the CLS token, broadcast over the batch
                x
            ], dim=1)
        x = x + positional_embedding
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND. N: batch size, L: sequence length, D: feature dimension
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x)

        if self.features_only:
            x = x[:, 1:, :]  # remove the CLS token
            x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w)
        else:
            x = x[:, 0, :]
            x = x @ self.proj
        return x
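

if __name__ == "__main__":
    # Minimal smoke test (a sketch, with illustrative hyperparameters). Because of the relative
    # import of `.blocks`, run it as a module from the parent package, e.g. `python -m <package>.<this_module>`.
    images = torch.randn(2, 3, 224, 224)

    vit = VisionTransformer(input_resolution=224, patch_size=16, output_dim=512, width=768, layers=12, heads=12)
    print(vit(images).shape)  # torch.Size([2, 512]) -- projected CLS embeddings

    vit_dense = VisionTransformer(input_resolution=224, patch_size=16, output_dim=512, width=768, layers=12, heads=12, features_only=True)
    print(vit_dense(images).shape)  # torch.Size([2, 768, 14, 14]) -- patch features reshaped to a feature map

    # Other input sizes also work: the positional embedding is interpolated on the fly in forward(),
    # or resized permanently with adjust_pos_embed().
    vit.adjust_pos_embed(320, 320)
    print(vit(torch.randn(2, 3, 320, 320)).shape)  # torch.Size([2, 512])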