Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
from einops import rearrange | |
from typing import Tuple, Union, Any, List, Iterable, Optional | |
from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d | |
class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Args:
        layers: number of Bottleneck blocks in each of the four residual stages.
        output_dim: output dimension of the attention pooling head (only built when ``features_only`` is False).
        input_resolution: input image size, either an int (square image) or an ``(h, w)`` pair (tuple or list).
        width: channel width of the stem; residual stage ``i`` (1-4) uses ``width * 2 ** (i - 1)`` planes.
        heads: number of attention heads of the attention pooling head.
        features_only: if True, no attention pooling head is built and ``forward`` returns intermediate
            feature maps instead of the pooled embedding.
        out_indices: which feature maps to return when ``features_only`` is True
            (0 = stem output, 1-4 = residual stages). Negative indices count from the end;
            duplicates are removed and the indices are sorted. Defaults to all five.
        reduction: if ``<= 16``, the last residual stage uses stride 1 instead of 2 and
            ``self.reduction`` is 16; otherwise ``self.reduction`` is 32.
        **kwargs: ignored; accepted so callers can pass a shared config.
    """
    def __init__(
        self,
        layers: Tuple[int, int, int, int],
        output_dim: int,
        input_resolution: Union[int, Tuple[int, int]] = 224,
        width: int = 64,
        heads: int = 8,
        features_only: bool = False,
        out_indices: Optional[Iterable[int]] = None,
        reduction: int = 32,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if isinstance(input_resolution, int):
            input_resolution = (input_resolution, input_resolution)
        elif isinstance(input_resolution, list):
            # Accept lists (e.g. parsed from YAML/JSON configs) as well as tuples.
            input_resolution = tuple(input_resolution)
        assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
        self.input_resolution = input_resolution
        self.downsampling_rate = 32  # the rate at which the input is downsampled by the network

        # the 3-layer stem
        self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(width // 2)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(width // 2)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(width)
        self.relu3 = nn.ReLU(inplace=True)
        self.avgpool = nn.AvgPool2d(2)

        # residual layers
        self._inplanes = width  # this is a *mutable* variable used during construction
        self.layer1 = self._make_layer(width, layers[0])
        self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
        self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
        # With reduction <= 16 the last stage keeps the spatial resolution of layer3.
        self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2)

        self.features_only = features_only
        if features_only:
            indices = out_indices if out_indices is not None else range(5)
            # Map negative indices to positive ones, remove duplicates and sort.
            indices = sorted({idx + 5 if idx < 0 else idx for idx in indices})
            if not indices:
                raise ValueError("out_indices must contain at least one index when features_only=True")
            assert min(indices) >= 0 and max(indices) <= 4, f"out_indices={indices} is invalid for a ResNet with 5 stages"
            self.out_indices = indices
            self.channels = width * 32  # the ResNet feature dimension
        else:
            self.out_indices = None
            embed_dim = width * 32  # the ResNet feature dimension
            # NOTE(review): the spatial size below assumes a /32 downsampling even when reduction <= 16
            # keeps layer4 at stride 1; confirm AttentionPool2d tolerates (e.g. interpolates) the mismatch.
            self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim)
            self.channels = output_dim

        self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate
        self.clip_embed_dim = output_dim

    def _make_layer(self, planes, blocks, stride=1):
        """Build one residual stage of ``blocks`` Bottlenecks; only the first block uses ``stride``."""
        layers = [Bottleneck(self._inplanes, planes, stride)]
        # Subsequent blocks (and the next stage) take the expanded channel count as input.
        self._inplanes = planes * Bottleneck.expansion
        for _ in range(1, blocks):
            layers.append(Bottleneck(self._inplanes, planes))
        return nn.Sequential(*layers)

    def _stem(self, x: Tensor) -> Tensor:
        """Apply the three conv-bn-relu stem layers followed by the 2x2 average pool."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.avgpool(x)
        return x

    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
        """
        Encode a batch of images.

        Returns:
            If ``features_only`` is False: the output of the attention pooling head.
            Otherwise: the feature maps selected by ``out_indices`` (a single tensor when exactly
            one index was requested, else a list ordered by stage index).
        """
        x = x.type(self.conv1.weight.dtype)
        # Stage index i corresponds to out_indices value i (0 = stem, 1-4 = residual stages).
        stages = (self._stem, self.layer1, self.layer2, self.layer3, self.layer4)
        feats = []
        for idx, stage in enumerate(stages):
            x = stage(x)
            if self.features_only and idx in self.out_indices:
                feats.append(x)

        if not self.features_only:
            return self.attnpool(x)
        return feats[0] if len(self.out_indices) == 1 else feats
class VisionTransformer(nn.Module):
    """
    CLIP-style Vision Transformer image encoder.

    The image is split into non-overlapping square patches by a strided convolution, a learnable
    class (CLS) token is prepended, learnable positional embeddings are added, and the token
    sequence is processed by a Transformer.

    Args:
        input_resolution: input image size, an int (square image) or an ``(h, w)`` pair (tuple or list).
            Both sides must be divisible by the patch size.
        patch_size: patch size, an int or an ``(h, w)`` pair; only square patches are supported.
        output_dim: dimension of the projected CLS embedding (the projection is only created when
            ``features_only`` is False).
        width: token embedding dimension.
        layers: number of transformer layers.
        heads: number of attention heads.
        features_only: if True, ``forward`` returns the final patch tokens as a feature map of shape
            ``(N, width, H / patch, W / patch)`` instead of the projected CLS token.
        **kwargs: ignored; accepted so callers can pass a shared config.
    """
    def __init__(
        self,
        input_resolution: Union[int, Tuple[int, int]],
        patch_size: Union[int, Tuple[int, int]],
        output_dim: int,
        width: int,
        layers: int,
        heads: int,
        features_only: bool = False,
        **kwargs: Any,
    ) -> None:
        super().__init__()
        if isinstance(input_resolution, int):
            input_resolution = (input_resolution, input_resolution)
        elif isinstance(input_resolution, list):
            # Accept lists (e.g. parsed from YAML/JSON configs) as well as tuples.
            input_resolution = tuple(input_resolution)
        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        elif isinstance(patch_size, list):
            patch_size = tuple(patch_size)
        assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
        assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}"
        assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid."
        assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}"
        self.input_resolution = input_resolution
        self.patch_size = patch_size
        self.downsampling_rate = patch_size[0]

        # Patch embedding: a conv whose kernel and stride both equal the patch size.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
        scale = width ** -0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(width))
        self.num_patches_h = int(input_resolution[0] // patch_size[0])
        self.num_patches_w = int(input_resolution[1] // patch_size[1])
        # One positional embedding per patch plus one for the CLS token (row 0).
        self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width))
        self.ln_pre = LayerNorm(width)
        self.transformer = Transformer(width, layers, heads)
        self.ln_post = LayerNorm(width)

        self.features_only = features_only  # if True, return the final patches instead of the CLS token
        if features_only:
            self.channels = width
        else:
            self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
            self.channels = output_dim

        self.reduction = patch_size[0]
        self.clip_embed_dim = output_dim

    def adjust_pos_embed(self, h: int, w: int) -> None:
        """
        Permanently adjust the size of the positional embedding matrix.

        The patch part of the embedding is bicubically resized to the new patch grid; the CLS row
        is kept unchanged. This assigns a new ``nn.Parameter``, so an optimizer created before
        this call would still reference the old tensor.

        Args:
            h: the height of the original input image.
            w: the width of the original input image.
        """
        assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}"
        if self.input_resolution[0] != h or self.input_resolution[1] != w:
            new_num_patches_h = int(h // self.patch_size[0])
            new_num_patches_w = int(w // self.patch_size[1])
            # No autograd graph is needed: the result is re-wrapped as a fresh leaf Parameter.
            with torch.no_grad():
                positional_embedding = self._interpolate_pos_embed(new_num_patches_h, new_num_patches_w)
            self.positional_embedding = nn.Parameter(positional_embedding)
            self.input_resolution = (h, w)
            self.num_patches_h = new_num_patches_h
            self.num_patches_w = new_num_patches_w

    def _interpolate_pos_embed(self, h: int, w: int) -> Tensor:
        """
        Interpolate the positional embedding matrix to match the size of the input image.

        Args:
            h: the required number of patches along the height dimension.
            w: the required number of patches along the width dimension.

        Returns:
            A ``(h * w + 1, width)`` tensor whose first row is the (unchanged) CLS embedding.
            When the grid already matches, the stored Parameter itself is returned.
        """
        if h == self.num_patches_h and w == self.num_patches_w:
            return self.positional_embedding
        patch_embed = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0)  # add batch dimension
        patch_embed = F.interpolate(patch_embed, size=(h, w), mode="bicubic").squeeze(0)  # remove batch dimension
        patch_embed = rearrange(patch_embed, "c h w -> (h w) c")
        return torch.cat([self.positional_embedding[:1, :], patch_embed], dim=0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Encode a batch of images.

        Returns:
            ``(N, width, H / patch, W / patch)`` patch features when ``features_only`` is True,
            otherwise the projected CLS embedding of shape ``(N, output_dim)``.
        """
        # Mirror ModifiedResNet.forward: run in the weights' dtype (e.g. fp16 models fed fp32 input).
        x = x.type(self.conv1.weight.dtype)
        x = self.conv1(x)  # shape = [*, width, grid, grid]
        num_patches_h, num_patches_w = x.shape[-2:]
        positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype)
        x = x.reshape(x.shape[0], x.shape[1], -1)  # shape = [*, width, grid ** 2]
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        cls_tokens = self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device)
        x = torch.cat([cls_tokens, x], dim=1)
        x = x + positional_embedding
        x = self.ln_pre(x)

        x = x.permute(1, 0, 2)  # NLD -> LND. N: batch size, L: sequence length, D: feature dimension
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_post(x)

        if self.features_only:
            x = x[:, 1:, :]  # remove the CLS token
            x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w)
        else:
            x = x[:, 0, :]
            x = x @ self.proj
        return x