import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Union, Any, List, Iterable, Optional
from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(
self,
layers: Tuple[int, int, int, int],
output_dim: int,
input_resolution: int = 224,
width: int = 64,
heads: int = 8,
features_only: bool = False,
out_indices: Optional[Iterable[int]] = None,
reduction: int = 32,
**kwargs: Any,
) -> None:
super().__init__()
input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
self.input_resolution = input_resolution
self.downsampling_rate = 32 # the rate at which the input is downsampled by the network
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
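        # the stem downsamples by 4x overall: the stride-2 conv1 followed by the 2x2 average pool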
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2)
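        # when reduction <= 16, layer4 keeps stride 1, so the backbone downsamples by 16x overall instead of 32x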
self.features_only = features_only
if features_only:
self.out_indices = out_indices if out_indices is not None else range(5)
self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices] # map negative indices to positive indices
self.out_indices = sorted(set(self.out_indices)) # remove duplicates and sort
assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages"
self.channels = width * 32 # the ResNet feature dimension
else:
self.out_indices = None
embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim)
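        # the final pooling is QKV attention over the last feature map (see class docstring), producing an output_dim embedding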
self.channels = output_dim
self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate
self.clip_embed_dim = output_dim
def _make_layer(self, planes, blocks, stride=1):
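        # the first block of a stage applies the requested stride; the remaining blocks run at stride 1 on the expanded channel count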
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def _stem(self, x: Tensor) -> Tensor:
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x
def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
x = x.type(self.conv1.weight.dtype)
x = self._stem(x)
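        # when features_only is set, intermediate feature maps are collected for the requested stages (0 = stem, 1-4 = residual stages)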
feats = [x] if self.features_only and 0 in self.out_indices else []
x = self.layer1(x)
if self.features_only and 1 in self.out_indices:
feats.append(x)
x = self.layer2(x)
if self.features_only and 2 in self.out_indices:
feats.append(x)
x = self.layer3(x)
if self.features_only and 3 in self.out_indices:
feats.append(x)
x = self.layer4(x)
if self.features_only and 4 in self.out_indices:
feats.append(x)
if self.features_only:
if len(self.out_indices) == 1:
return feats[0]
else:
return feats
else:
x = self.attnpool(x)
return x
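# Usage sketch (assumed, not taken from this repo's own examples): a CLIP RN50-style backbone would be built roughly as
#   model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, width=64, heads=32)
#   embedding = model(torch.randn(1, 3, 224, 224))   # -> (1, 1024) image embedding
# With features_only=True and e.g. out_indices=(1, 2, 3, 4), the forward pass instead returns a list of intermediate feature maps.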
class VisionTransformer(nn.Module):
def __init__(
self,
input_resolution: Union[int, Tuple[int, int]],
patch_size: Union[int, Tuple[int, int]],
output_dim: int,
width: int,
layers: int,
heads: int,
features_only: bool = False,
**kwargs: Any,
) -> None:
super().__init__()
input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}"
assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid."
assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}"
self.input_resolution = input_resolution
self.patch_size = patch_size
self.downsampling_rate = patch_size[0]
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.num_patches_h = int(input_resolution[0] // patch_size[0])
self.num_patches_w = int(input_resolution[1] // patch_size[1])
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width))
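        # learned positional embeddings: one vector per patch plus one for the leading class token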
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.features_only = features_only # if True, return the final patches instead of the CLS token
if features_only:
self.channels = width
else:
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.channels = output_dim
self.reduction = patch_size[0]
self.clip_embed_dim = output_dim
def adjust_pos_embed(self, h: int, w: int) -> None:
"""
Permanently adjust the size of the positional embedding matrix.
Args:
h: the height of the original input image.
w: the width of the original input image.
"""
assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}"
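        # example (assumed usage): a model trained at 224x224 with patch_size 16 can be adapted to 448x448 inputs
        # via adjust_pos_embed(448, 448), which bicubically resizes the 14x14 grid of patch embeddings to 28x28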
if self.input_resolution[0] != h or self.input_resolution[1] != w:
new_num_patches_h = int(h // self.patch_size[0])
new_num_patches_w = int(w // self.patch_size[1])
positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
            positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic").squeeze(0)  # remove batch dimension
positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0))
self.input_resolution = (h, w)
self.num_patches_h = new_num_patches_h
self.num_patches_w = new_num_patches_w
def _interpolate_pos_embed(self, h: int, w: int) -> Tensor:
"""
Interpolate the positional embedding matrix to match the size of the input image.
Args:
h: the required number of patches along the height dimension.
w: the required number of patches along the width dimension.
"""
if h == self.num_patches_h and w == self.num_patches_w:
return self.positional_embedding
else:
positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0) # remove batch dimension
positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)
return positional_embedding
def forward(self, x: Tensor) -> Tensor:
x = self.conv1(x) # shape = [*, width, grid, grid]
num_patches_h, num_patches_w = x.shape[-2:]
positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype)
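        # the positional embedding is resized on the fly whenever the input patch grid differs from the one used at construction time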
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([
self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x
], dim=1)
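        # a learned class token is prepended, giving a sequence of 1 + num_patches tokens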
x = x + positional_embedding
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND. N: batch size, L: sequence length, D: feature dimension
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
if self.features_only:
x = x[:, 1:, :] # remove the CLS token
x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w)
else:
x = x[:, 0, :]
x = x @ self.proj
return x
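# Usage sketch (assumed, mirroring a ViT-B/16-style configuration rather than anything defined in this file):
#   vit = VisionTransformer(input_resolution=224, patch_size=16, output_dim=512, width=768, layers=12, heads=12)
#   embedding = vit(torch.randn(1, 3, 224, 224))   # -> (1, 512) image embedding
# With features_only=True, the forward pass instead returns the patch features reshaped to (1, 768, 14, 14).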