# Source: CLIP-EBC/models/clip/model.py (uploader avatar: Yiming-M; commit 570db9a, "🐣 born")
import torch
from torch import nn, Tensor
import torch.nn.functional as F
import numpy as np
import os
import math
from typing import List, Tuple, Union, Optional
from . import _clip
from ..utils import _init_weights, make_resnet_layers, Bottleneck, BasicBlock
from .utils import format_count
# Absolute path of the directory that contains this module.
curr_dir = os.path.dirname(os.path.abspath(__file__))

# Reference table for the supported image encoders:
#
#   backbone         | reduction | channels | embed_dim
#   -----------------+-----------+----------+----------
#   resnet50         |    32     |   2048   |   1024
#   resnet101        |    32     |   2048   |    512
#   resnet50x4       |    32     |   2560   |    640
#   resnet50x16      |    32     |   3072   |    768
#   resnet50x64      |    32     |   4096   |   1024
#   vit_b_32         |    32     |    768   |    512
#   vit_b_16         |    16     |    768   |    512
#   vit_l_14         |    14     |   1024   |    768
#   vit_l_14_336px   |    14     |   1024   |    768

# Backbone names accepted by CLIP_EBC, grouped by encoder family.
resnet_backbones = [
    "resnet50",
    "resnet101",
    "resnet50x4",
    "resnet50x16",
    "resnet50x64",
]
vit_backbones = [
    "vit_b_16",
    "vit_b_32",
    "vit_l_14",
    "vit_l_14_336px",
]
class CLIP_EBC(nn.Module):
    """Count-estimation model that matches image features against text prompts.

    An image encoder from ``_clip`` produces a spatial feature map, which is
    optionally decoded and projected to the text-embedding width. Every
    spatial location is compared (scaled cosine similarity) with the features
    of one text prompt per bin; a softmax over the bins weights
    ``anchor_points`` to give an expected value per location.

    For ViT backbones the image encoder is frozen and learnable prompt tokens
    (``vpt_{idx}`` parameters) are inserted between the class token and the
    patch tokens.
    """

    def __init__(
        self,
        backbone: str,
        bins: List[Tuple[float, float]],
        anchor_points: List[float],
        reduction: Optional[int] = None,
        freeze_text_encoder: bool = True,
        prompt_type: str = "number",
        input_size: Optional[int] = None,
        num_vpt: Optional[int] = None,
        deep_vpt: Optional[bool] = None,
        vpt_drop: Optional[float] = None,
        decoder_block: Optional[nn.Module] = None,
        decoder_cfg: Optional[List[Union[str, int]]] = None,
    ) -> None:
        """Build the image branch, the text branch and the prompt features.

        Args:
            backbone: One of ``resnet_backbones`` or ``vit_backbones``.
            bins: ``(low, high)`` pairs; one text prompt is generated per bin.
            anchor_points: One representative value per bin, used to turn the
                per-bin probabilities into an expected value.
            reduction: Desired ratio between input and output resolution.
                ``None`` keeps the encoder's own reduction.
            freeze_text_encoder: If true, the text encoder is frozen and the
                prompt features are computed once at construction time.
            prompt_type: ``"number"`` or ``"word"``; forwarded to
                ``format_count``.
            input_size: Required for ViT backbones; forwarded to the encoder.
            num_vpt: Required for ViT backbones; prompt tokens per layer.
            deep_vpt: Required for ViT backbones; if true every transformer
                block gets its own prompt tokens, otherwise only the first.
            vpt_drop: Required for ViT backbones; dropout probability applied
                to the prompt tokens (``<= 0`` disables dropout).
            decoder_block: Block type handed to ``make_resnet_layers``;
                required when ``decoder_cfg`` is given.
            decoder_cfg: Layer configuration handed to ``make_resnet_layers``;
                ``None`` disables the decoder.
        """
        super().__init__()
        assert backbone in resnet_backbones + vit_backbones, f"Backbone should be in {resnet_backbones + vit_backbones}, got {backbone}"
        self.backbone = backbone

        # Image encoder
        if backbone in resnet_backbones:
            self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, out_indices=(-1,), reduction=reduction)
        else:
            assert input_size is not None, "Expected input_size to be an integer, got None."
            assert num_vpt is not None, "Expected num_vpt to be an integer, got None."
            assert deep_vpt is not None, "Expected deep_vpt to be a boolean, got None."
            assert vpt_drop is not None, "Expected vpt_drop to be a float, got None."

            self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, input_size=input_size)
            self.image_encoder_depth = len(self.image_encoder.transformer.resblocks)

            # Use VPT. Freeze the image encoder.
            for param in self.image_encoder.parameters():
                param.requires_grad = False

            self.num_vpt = num_vpt
            self.deep_vpt = deep_vpt

            # Symmetric uniform initialisation range for the prompt tokens.
            patch_size = self.image_encoder.patch_size[0]
            val = math.sqrt(6. / float(3 * patch_size + self.image_encoder.channels))

            # One prompt set per transformer block (deep) or a single set (shallow).
            for idx in range(self.image_encoder_depth if self.deep_vpt else 1):
                setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.image_encoder.channels)))
                nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
                setattr(self, f"vpt_drop_{idx}", nn.Dropout(vpt_drop) if vpt_drop > 0 else nn.Identity())

        self.encoder_reduction = self.image_encoder.reduction
        self.reduction = self.encoder_reduction if reduction is None else reduction
        self.channels = self.image_encoder.channels
        self.clip_embed_dim = self.image_encoder.clip_embed_dim

        # Optional decoder on top of the encoder features.
        if decoder_cfg is not None:
            assert decoder_block is not None, "Expected decoder_block to be a nn.Module, got None."
            self.image_decoder = make_resnet_layers(decoder_block, decoder_cfg, in_channels=self.channels, expansion=1, dilation=1)
            self.image_decoder.apply(_init_weights)
            self.channels = decoder_cfg[-1]
        else:
            self.image_decoder = nn.Identity()

        # 1x1 projection so image features match the text-embedding width.
        if self.channels != self.clip_embed_dim:
            self.projection = nn.Conv2d(in_channels=self.channels, out_channels=self.clip_embed_dim, kernel_size=1)
            self.projection.apply(_init_weights)
        else:
            self.projection = nn.Identity()

        # Text encoder
        assert prompt_type in ["number", "word"], f"Expected prompt_type to be 'number' or 'word', got {prompt_type}"
        self.prompt_type = prompt_type
        self.text_encoder = getattr(_clip, f"{backbone}_txt")()
        self.freeze_text_encoder = freeze_text_encoder
        if self.freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        self.bins = bins
        # Shaped (1, N, 1, 1) so it broadcasts against (B, N, H, W) probabilities.
        self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1)

        self._get_text_prompts()
        self._tokenize_text_prompts()

        if self.freeze_text_encoder:
            # Frozen encoder: the prompt features never change, so cache them.
            self._extract_text_features()
        else:
            self.text_features = None

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)

    def _get_text_prompts(self) -> None:
        """Build one text prompt per bin and store them in ``self.text_prompts``."""
        # A degenerate bin (low == high) is passed to format_count as a single value.
        bins = [b[0] if b[0] == b[1] else b for b in self.bins]
        self.text_prompts = [format_count(b, self.prompt_type) for b in bins]
        print(f"Initialized model with text prompts: {self.text_prompts}")

    def _tokenize_text_prompts(self) -> None:
        """Replace the prompt strings in ``self.text_prompts`` by their tokenization."""
        self.text_prompts = _clip.tokenize(self.text_prompts)

    def _extract_text_features(self) -> None:
        """Encode the tokenized prompts once, without tracking gradients."""
        with torch.no_grad():
            self.text_features = self.text_encoder(self.text_prompts)

    def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
        """Return the prompt tokens of ``layer`` expanded over the batch.

        Args:
            layer: Index of the prompt set; must be 0 unless ``deep_vpt`` is set.
            batch_size: Number of samples the tokens are expanded to.
            device: Device the tokens are moved to.

        Returns:
            Tensor of shape ``(num_vpt, batch_size, hidden_dim)`` after dropout.
        """
        if not self.deep_vpt:
            assert layer == 0, f"Expected layer to be 0 when using Shallow Visual Prompt Tuning, got {layer}"

        vpt = getattr(self, f"vpt_{layer}").to(device)
        vpt = vpt.unsqueeze(0).expand(batch_size, -1, -1)
        vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
        vpt = vpt.permute(1, 0, 2)  # (num_vpt, batch_size, hidden_dim)
        assert vpt.shape[1] == batch_size, f"Expected the VPT to have the shape [L_vis B C], got {vpt.shape}."
        return vpt

    def _forward_vpt(self, x: Tensor) -> Tensor:
        """Run the ViT image encoder with prompt tokens inserted at every block.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.

        Returns:
            Feature map of shape ``(B, C', H // patch_h, W // patch_w)``.
        """
        device = x.device
        batch_size, _, height, width = x.shape
        num_h_patches, num_w_patches = height // self.image_encoder.patch_size[0], width // self.image_encoder.patch_size[1]

        # Patch embedding, flattened to a token sequence.
        image_features = self.image_encoder.conv1(x)
        image_features = image_features.reshape(batch_size, image_features.shape[1], -1)
        image_features = image_features.permute(0, 2, 1)  # (B, num_patches, C)

        # Prepend the class token; adding zeros broadcasts it over the batch.
        image_features = torch.cat([
            self.image_encoder.class_embedding + torch.zeros(batch_size, 1, image_features.shape[-1], dtype=image_features.dtype, device=device),
            image_features,
        ], dim=1)  # (B, num_patches + 1, C)

        pos_embedding = self.image_encoder._interpolate_pos_embed(num_h_patches, num_w_patches)
        image_features = image_features + pos_embedding

        image_features = self.image_encoder.ln_pre(image_features)
        image_features = image_features.permute(1, 0, 2)  # (num_patches + 1, B, C)
        assert image_features.shape[0] == num_h_patches * num_w_patches + 1 and image_features.shape[1] == batch_size, f"Expected image_features to have shape [num_patches + 1, B, C], got {image_features.shape}."

        vpt = self._prepare_vpt(0, batch_size, device)
        for idx in range(self.image_encoder_depth):
            # assemble: [CLS, prompt tokens, patch tokens]
            image_features = torch.cat([
                image_features[:1, :, :],  # CLS token
                vpt,
                image_features[1:, :, :],
            ], dim=0)

            # transformer
            image_features = self.image_encoder.transformer.resblocks[idx](image_features)

            # disassemble: choose the prompt tokens fed to the next block
            if idx < self.image_encoder_depth - 1:
                if self.deep_vpt:
                    # Deep: fresh learnable prompts for the next block.
                    vpt = self._prepare_vpt(idx + 1, batch_size, device)
                else:
                    # Shallow: carry this block's prompt outputs forward.
                    vpt = image_features[1: (self.num_vpt + 1), :, :]

            # Drop the prompt positions so only [CLS, patch tokens] remain.
            image_features = torch.cat([
                image_features[:1, :, :],  # CLS token
                image_features[(self.num_vpt + 1):, :, :],
            ], dim=0)

        image_features = image_features.permute(1, 0, 2)  # (B, num_patches + 1, C)
        image_features = self.image_encoder.ln_post(image_features)
        image_features = image_features[:, 1:, :].permute(0, 2, 1)  # (B, C, num_patches)
        image_features = image_features.reshape(batch_size, -1, num_h_patches, num_w_patches)
        return image_features

    def _forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Compute per-location bin logits and the expected value map.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.

        Returns:
            ``(logits, exp)`` in training mode and ``exp`` alone otherwise,
            where ``logits`` is ``(B, N, H', W')`` and ``exp`` is
            ``(B, 1, H', W')``.
        """
        device = x.device

        x = self.image_encoder(x) if self.backbone in resnet_backbones else self._forward_vpt(x)
        if self.reduction != self.encoder_reduction:
            # Resample the feature map to the requested output resolution.
            x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
        x = self.image_decoder(x)
        x = self.projection(x)

        image_features = x.permute(0, 2, 3, 1)  # shape (B, H, W, C)
        # Use the cached prompt features when the text encoder is frozen.
        text_features = self.text_encoder(self.text_prompts.to(device)) if self.text_features is None else self.text_features.to(device)  # shape (N, C)

        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()  # (B, H, W, N), logits per image
        logits = logits.permute(0, 3, 1, 2)  # (B, N, H, W)

        probs = logits.softmax(dim=1)
        exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True)  # (B, 1, H, W)

        if self.training:
            return logits, exp
        else:
            return exp

    @staticmethod
    def _window_starts(image_size: int, window_size: int, stride: int) -> List[int]:
        """Return the window origins along one axis.

        Windows are placed every ``stride`` pixels; a window that would run
        past the image edge is shifted back so that it ends exactly at the edge.
        """
        num_windows = int(math.ceil((image_size - window_size) / stride)) + 1
        last_start = image_size - window_size
        return [min(k * stride, last_start) for k in range(num_windows)]

    def _sliding_window_forward(self, x: Tensor, window_height: int, window_width: int) -> Tensor:
        """Predict on windows of the training size and stitch the results.

        The stride equals the window size, so windows only overlap where the
        last row/column of windows is shifted back to fit inside the image.
        Overlapping predictions are averaged.

        Args:
            x: Image batch of shape ``(B, C, H, W)`` with ``H >= window_height``
                and ``W >= window_width``.
            window_height: Window height in pixels.
            window_width: Window width in pixels.

        Returns:
            Detached float32 tensor of shape
            ``(B, 1, H // reduction, W // reduction)`` on the device of ``x``.
        """
        batch_size, _, image_height, image_width = x.shape
        # A smaller image would produce negative slice starts and wrong windows.
        assert image_height >= window_height and image_width >= window_width, f"Expected input to be at least ({window_height} {window_width}), got ({image_height} {image_width})."

        reduction = self.reduction
        tops = self._window_starts(image_height, window_height, window_height)
        lefts = self._window_starts(image_width, window_width, window_width)
        origins = [(top, left) for top in tops for left in lefts]

        # Window-major layout: index = window_idx * batch_size + sample_idx.
        windows = torch.cat(
            [x[:, :, top: top + window_height, left: left + window_width] for top, left in origins],
            dim=0,
        )  # (num_windows * B, C, h, w)

        preds = self._forward(windows).detach().float()
        preds = preds.reshape(len(origins), batch_size, *preds.shape[1:])  # (num_windows, B, 1, h', w')

        # assemble the density map
        map_height, map_width = image_height // reduction, image_width // reduction
        pred_map = torch.zeros(batch_size, preds.shape[2], map_height, map_width, dtype=torch.float32, device=preds.device)
        count_map = torch.zeros(1, 1, map_height, map_width, dtype=torch.float32, device=preds.device)
        for (top, left), pred in zip(origins, preds):
            rows = slice(top // reduction, (top + window_height) // reduction)
            cols = slice(left // reduction, (left + window_width) // reduction)
            pred_map[:, :, rows, cols] += pred
            count_map[:, :, rows, cols] += 1.

        return pred_map / count_map  # average the overlapping regions

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run the model on an image batch.

        ResNet backbones accept any input size. ViT backbones require the
        encoder's ``input_resolution`` during training; during evaluation a
        differently sized input is processed with sliding windows.

        Args:
            x: Image batch of shape ``(B, C, H, W)``.

        Returns:
            ``(logits, exp)`` in training mode, ``exp`` otherwise (see
            ``_forward``).
        """
        assert len(x.shape) == 4, f"Expected input to have shape (B C H W), got {x.shape}."

        if "vit" not in self.backbone:
            return self._forward(x)

        image_height, image_width = x.shape[2], x.shape[3]
        window_height, window_width = self.image_encoder.input_resolution

        if self.training:
            assert (image_height, image_width) == (window_height, window_width), f"Expected input to have shape ({window_height} {window_width}), got ({image_height} {image_width})."
            return self._forward(x)

        if (image_height, image_width) == (window_height, window_width):  # evaluation, input size = training size
            return self._forward(x)

        # evaluation, input_size != training size, use sliding window prediction
        return self._sliding_window_forward(x, window_height, window_width)
def _clip_ebc(
    backbone: str,
    bins: List[Tuple[float, float]],
    anchor_points: List[float],
    reduction: Optional[int] = None,
    freeze_text_encoder: bool = True,
    prompt_type: str = "number",
    input_size: Optional[int] = None,
    num_vpt: Optional[int] = None,
    deep_vpt: Optional[bool] = None,
    vpt_drop: Optional[float] = None,
    decoder_block: Optional[nn.Module] = None,
    decoder_cfg: Optional[List[Union[str, int]]] = None
) -> CLIP_EBC:
    """Create a ``CLIP_EBC`` with backbone-specific decoder defaults.

    ``decoder_block`` and ``decoder_cfg`` are only filled in when the caller
    leaves them as ``None``; explicit values are passed through unchanged.
    All remaining arguments are forwarded to ``CLIP_EBC`` as given.

    Returns:
        The constructed ``CLIP_EBC`` instance.
    """
    if backbone in resnet_backbones:
        default_block = Bottleneck
        default_cfgs = {
            "resnet50": [2048],
            "resnet50x4": [1280],
            "resnet50x16": [1536],
            "resnet50x64": [2048],
        }
        fallback_cfg = [2048, 1024]  # backbone == "resnet101"
    else:
        default_block = BasicBlock
        default_cfgs = {
            "vit_b_16": [768],
            "vit_b_32": [768],
        }
        fallback_cfg = [1024]  # remaining (vit_l_14*) backbones

    # Respect a caller-supplied block type instead of silently replacing it.
    if decoder_block is None:
        decoder_block = default_block
    if decoder_cfg is None:
        decoder_cfg = default_cfgs.get(backbone, fallback_cfg)

    return CLIP_EBC(
        backbone=backbone,
        bins=bins,
        anchor_points=anchor_points,
        reduction=reduction,
        freeze_text_encoder=freeze_text_encoder,
        prompt_type=prompt_type,
        input_size=input_size,
        num_vpt=num_vpt,
        deep_vpt=deep_vpt,
        vpt_drop=vpt_drop,
        decoder_block=decoder_block,
        decoder_cfg=decoder_cfg,
    )