# CLIP-EBC: models/clip/_clip/image_encoder.py
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from einops import rearrange
from typing import Tuple, Union, Any, List, Iterable, Optional
from .blocks import LayerNorm, Transformer, Bottleneck, AttentionPool2d


class ModifiedResNet(nn.Module):
    """
    A ResNet class that is similar to torchvision's but contains the following changes:
    - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
    - Performs anti-aliased strided convolutions, where an avgpool is prepended to convolutions with stride > 1.
    - The final pooling layer is a QKV attention instead of an average pool.
    In addition, this variant can return intermediate feature maps (features_only=True) and can reduce the
    output stride from 32 to 16 (reduction <= 16) by running the last stage at stride 1.
    """
def __init__(
self,
layers: Tuple[int, int, int, int],
output_dim: int,
        input_resolution: Union[int, Tuple[int, int]] = 224,
width: int = 64,
heads: int = 8,
features_only: bool = False,
out_indices: Optional[Iterable[int]] = None,
reduction: int = 32,
**kwargs: Any,
) -> None:
super().__init__()
input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
self.input_resolution = input_resolution
        self.downsampling_rate = 32  # nominal downsampling rate of the stem plus four stages; the effective rate is stored in self.reduction below
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.relu2 = nn.ReLU(inplace=True)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.relu3 = nn.ReLU(inplace=True)
self.avgpool = nn.AvgPool2d(2)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=1 if reduction <= 16 else 2)
self.features_only = features_only
if features_only:
self.out_indices = out_indices if out_indices is not None else range(5)
self.out_indices = [idx + 5 if idx < 0 else idx for idx in self.out_indices] # map negative indices to positive indices
self.out_indices = sorted(set(self.out_indices)) # remove duplicates and sort
assert min(self.out_indices) >= 0 and max(self.out_indices) <= 4, f"out_indices={self.out_indices} is invalid for a ResNet with 5 stages"
self.channels = width * 32 # the ResNet feature dimension
else:
self.out_indices = None
embed_dim = width * 32 # the ResNet feature dimension
            self.attnpool = AttentionPool2d((input_resolution[0] // 32) * (input_resolution[1] // 32), embed_dim, heads, output_dim)  # pools over the spatial positions at stride 32
self.channels = output_dim
self.reduction = self.downsampling_rate // 2 if reduction <= 16 else self.downsampling_rate
self.clip_embed_dim = output_dim

    def _make_layer(self, planes: int, blocks: int, stride: int = 1) -> nn.Sequential:
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
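    # Illustrative note (added, not from the original file): with width=64 and a ResNet-50-style
    # config layers=(3, 4, 6, 3), Bottleneck.expansion = 4 gives stage output channels
    # 256, 512, 1024 and 2048, which is why the attention-pool embed_dim is width * 32 (64 * 32 = 2048).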

    def _stem(self, x: Tensor) -> Tensor:
x = self.relu1(self.bn1(self.conv1(x)))
x = self.relu2(self.bn2(self.conv2(x)))
x = self.relu3(self.bn3(self.conv3(x)))
x = self.avgpool(x)
return x

    def forward(self, x: Tensor) -> Union[Tensor, List[Tensor]]:
        x = x.type(self.conv1.weight.dtype)  # match the parameter dtype (e.g. fp16 CLIP weights)
        x = self._stem(x)
        feats = [x] if self.features_only and 0 in self.out_indices else []  # stage 0 = stem output
x = self.layer1(x)
if self.features_only and 1 in self.out_indices:
feats.append(x)
x = self.layer2(x)
if self.features_only and 2 in self.out_indices:
feats.append(x)
x = self.layer3(x)
if self.features_only and 3 in self.out_indices:
feats.append(x)
x = self.layer4(x)
if self.features_only and 4 in self.out_indices:
feats.append(x)
if self.features_only:
if len(self.out_indices) == 1:
return feats[0]
else:
return feats
else:
x = self.attnpool(x)
return x
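# Hypothetical usage sketch (added for illustration; it assumes AttentionPool2d behaves like CLIP's,
# i.e. returns the pooled token of size output_dim):
#
#   model = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, input_resolution=224,
#                          width=64, heads=32, features_only=False)
#   x = torch.randn(2, 3, 224, 224)
#   y = model(x)          # attention-pooled embedding, shape (2, 1024)
#
#   backbone = ModifiedResNet(layers=(3, 4, 6, 3), output_dim=1024, features_only=True,
#                             out_indices=(3, 4))
#   feats = backbone(x)   # [stage-3 map (2, 1024, 14, 14), stage-4 map (2, 2048, 7, 7)]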


class VisionTransformer(nn.Module):
def __init__(
self,
input_resolution: Union[int, Tuple[int, int]],
patch_size: Union[int, Tuple[int, int]],
output_dim: int,
width: int,
layers: int,
heads: int,
features_only: bool = False,
**kwargs: Any,
) -> None:
super().__init__()
input_resolution = (input_resolution, input_resolution) if isinstance(input_resolution, int) else input_resolution
patch_size = (patch_size, patch_size) if isinstance(patch_size, int) else patch_size
assert isinstance(input_resolution, tuple) and len(input_resolution) == 2, f"input_resolution should be a tuple of length 2, but got {input_resolution}"
assert isinstance(patch_size, tuple) and len(patch_size) == 2, f"patch_size should be a tuple of length 2, but got {patch_size}"
assert patch_size[0] == patch_size[1], f"ViT only supports square patches, patch_size={patch_size} is invalid."
assert input_resolution[0] % patch_size[0] == 0 and input_resolution[1] % patch_size[1] == 0, f"input_resolution {input_resolution} should be divisible by patch_size {patch_size}"
self.input_resolution = input_resolution
self.patch_size = patch_size
self.downsampling_rate = patch_size[0]
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.num_patches_h = int(input_resolution[0] // patch_size[0])
self.num_patches_w = int(input_resolution[1] // patch_size[1])
self.positional_embedding = nn.Parameter(scale * torch.randn(self.num_patches_h * self.num_patches_w + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
        self.features_only = features_only  # if True, return the patch-token feature map instead of the projected CLS embedding
if features_only:
self.channels = width
else:
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
self.channels = output_dim
self.reduction = patch_size[0]
self.clip_embed_dim = output_dim

    def adjust_pos_embed(self, h: int, w: int) -> None:
        """
        Permanently resize the positional embedding to match a new input resolution (the parameter is re-registered).
        Args:
            h: the height of the new input image, in pixels.
            w: the width of the new input image, in pixels.
        """
assert h % self.patch_size[0] == 0 and w % self.patch_size[1] == 0, f"input_resolution {h, w} should be divisible by patch_size {self.patch_size}"
if self.input_resolution[0] != h or self.input_resolution[1] != w:
new_num_patches_h = int(h // self.patch_size[0])
new_num_patches_w = int(w // self.patch_size[1])
positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
            positional_embedding = F.interpolate(positional_embedding, size=(new_num_patches_h, new_num_patches_w), mode="bicubic").squeeze(0)  # remove batch dimension
positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
self.positional_embedding = nn.Parameter(torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0))
self.input_resolution = (h, w)
self.num_patches_h = new_num_patches_h
self.num_patches_w = new_num_patches_w
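    # Worked example (illustrative, not from the original file): a ViT-B/16 pretrained at 224x224
    # stores 14*14 + 1 = 197 positional embeddings. Calling adjust_pos_embed(384, 384) bicubically
    # resizes the 14x14 grid to 24x24 and re-registers a (24*24 + 1 = 577, width) parameter,
    # keeping the CLS position (index 0) unchanged.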

    def _interpolate_pos_embed(self, h: int, w: int) -> Tensor:
        """
        Interpolate the positional embedding on the fly to match the input's patch grid; the stored parameter is left unchanged.
        Args:
            h: the required number of patches along the height dimension.
            w: the required number of patches along the width dimension.
        Returns:
            The (interpolated) positional embedding of shape (h * w + 1, width), with the CLS position prepended.
        """
if h == self.num_patches_h and w == self.num_patches_w:
return self.positional_embedding
else:
positional_embedding = rearrange(self.positional_embedding[1:, :], "(h w) c -> c h w", h=self.num_patches_h, w=self.num_patches_w).unsqueeze(0) # add batch dimension
positional_embedding = F.interpolate(positional_embedding, size=(h, w), mode="bicubic").squeeze(0) # remove batch dimension
positional_embedding = rearrange(positional_embedding, "c h w -> (h w) c")
positional_embedding = torch.cat([self.positional_embedding[:1, :], positional_embedding], dim=0)
return positional_embedding
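    # Illustrative note (added): unlike adjust_pos_embed, this method does not modify the stored
    # parameter, so forward() can handle inputs whose patch grid differs from the training grid,
    # e.g. a (1, 3, 224, 320) batch with patch_size=16 yields a 14x20 grid and a
    # (14*20 + 1, width) embedding for that call only.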

    def forward(self, x: Tensor) -> Tensor:
x = self.conv1(x) # shape = [*, width, grid, grid]
num_patches_h, num_patches_w = x.shape[-2:]
positional_embedding = self._interpolate_pos_embed(num_patches_h, num_patches_w).to(x.dtype)
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([
self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device),
x
], dim=1)
x = x + positional_embedding
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND. N: batch size, L: sequence length, D: feature dimension
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
if self.features_only:
x = x[:, 1:, :] # remove the CLS token
x = rearrange(x, "n (h w) c -> n c h w", h=num_patches_h, w=num_patches_w)
else:
x = x[:, 0, :]
x = x @ self.proj
return x
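

if __name__ == "__main__":
    # Minimal smoke test added for illustration. The hyperparameters below (ViT-B/16-like) are
    # assumptions, not values taken from this repository's configs.
    vit = VisionTransformer(
        input_resolution=224,
        patch_size=16,
        output_dim=512,
        width=768,
        layers=12,
        heads=12,
        features_only=False,
    )
    dummy = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        out = vit(dummy)
    print(out.shape)  # expected: torch.Size([1, 512]), the projected CLS embedding

    # With features_only=True the forward pass returns the patch-token map reshaped to NCHW.
    vit_feats = VisionTransformer(
        input_resolution=224,
        patch_size=16,
        output_dim=512,
        width=768,
        layers=12,
        heads=12,
        features_only=True,
    )
    with torch.no_grad():
        fmap = vit_feats(dummy)
    print(fmap.shape)  # expected: torch.Size([1, 768, 14, 14])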