Spaces:
Sleeping
Sleeping
import torch | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
import numpy as np | |
import os | |
import math | |
from typing import List, Tuple, Union, Optional | |
from . import _clip | |
from ..utils import _init_weights, make_resnet_layers, Bottleneck, BasicBlock | |
from .utils import format_count | |
# Directory containing this module, as an absolute path.
curr_dir = os.path.abspath(os.path.dirname(__file__))

# Encoder geometry of each supported CLIP backbone, as
# (reduction, channels, embed_dim):
#   resnet50        32, 2048, 1024
#   resnet101       32, 2048,  512
#   resnet50x4      32, 2560,  640
#   resnet50x16     32, 3072,  768
#   resnet50x64     32, 4096, 1024
#   vit_b_32        32,  768,  512
#   vit_b_16        16,  768,  512
#   vit_l_14        14, 1024,  768
#   vit_l_14_336px  14, 1024,  768

# Convolutional CLIP image encoders.
resnet_backbones = [
    "resnet50",
    "resnet101",
    "resnet50x4",
    "resnet50x16",
    "resnet50x64",
]

# Transformer CLIP image encoders (used with visual prompt tuning).
vit_backbones = [
    "vit_b_16",
    "vit_b_32",
    "vit_l_14",
    "vit_l_14_336px",
]
class CLIP_EBC(nn.Module):
    """CLIP-EBC density-map counting model.

    A CLIP image encoder produces a feature map. ResNet backbones use the encoder
    directly. ViT backbones are frozen and adapted with learnable visual prompt
    tokens (VPT). An optional decoder and a 1x1 projection map the features into
    CLIP's embedding space.

    Every spatial location is then scored, by scaled cosine similarity, against
    the text embeddings of one prompt per count bin. The per-location expected
    count is the softmax-weighted sum of ``anchor_points``.
    """

    def __init__(
        self,
        backbone: str,
        bins: List[Tuple[float, float]],
        anchor_points: List[float],
        reduction: Optional[int] = None,
        freeze_text_encoder: bool = True,
        prompt_type: str = "number",
        input_size: Optional[int] = None,
        num_vpt: Optional[int] = None,
        deep_vpt: Optional[bool] = None,
        vpt_drop: Optional[float] = None,
        decoder_block: Optional[nn.Module] = None,
        decoder_cfg: Optional[List[Union[str, int]]] = None,
    ) -> None:
        """Build the image encoder, optional decoder, projection and text branch.

        Args:
            backbone: one of ``resnet_backbones + vit_backbones``.
            bins: count bins. A bin whose two ends are equal is prompted as a
                single value.
            anchor_points: one representative count per bin. These turn bin
                probabilities into an expected count.
            reduction: output stride of the predicted map. Defaults to the
                encoder's own reduction; a different value triggers bilinear
                resampling in ``_forward``.
            freeze_text_encoder: if True, the text encoder is frozen and the
                text features are computed once here.
            prompt_type: ``"number"`` or ``"word"``, forwarded to ``format_count``.
            input_size: ViT only; forwarded to the image-encoder constructor.
            num_vpt: ViT only; number of visual prompt tokens per layer.
            deep_vpt: ViT only; True inserts fresh prompts before every
                transformer block, False only before the first.
            vpt_drop: ViT only; dropout probability applied to the prompts.
            decoder_block: block type passed to ``make_resnet_layers``.
                Required when ``decoder_cfg`` is given.
            decoder_cfg: decoder layer config. None means no decoder.
        """
        super().__init__()
        assert backbone in resnet_backbones + vit_backbones, f"Backbone should be in {resnet_backbones + vit_backbones}, got {backbone}"
        self.backbone = backbone

        # Image encoder
        if backbone in resnet_backbones:
            self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, out_indices=(-1,), reduction=reduction)
        else:
            assert input_size is not None, "Expected input_size to be an integer, got None."
            assert num_vpt is not None, "Expected num_vpt to be an integer, got None."
            assert deep_vpt is not None, "Expected deep_vpt to be a boolean, got None."
            assert vpt_drop is not None, "Expected vpt_drop to be a float, got None."

            self.image_encoder = getattr(_clip, f"{backbone}_img")(features_only=True, input_size=input_size)
            self.image_encoder_depth = len(self.image_encoder.transformer.resblocks)

            # Use VPT. Freeze the image encoder; only the prompt tokens are trained.
            for param in self.image_encoder.parameters():
                param.requires_grad = False

            self.num_vpt = num_vpt
            self.deep_vpt = deep_vpt

            patch_size = self.image_encoder.patch_size[0]
            # Uniform init bound for the prompt tokens.
            # NOTE(review): the VPT reference implementation uses
            # 3 * patch_h * patch_w here; this uses 3 * patch_size.
            # Kept as-is for reproducibility.
            val = math.sqrt(6. / float(3 * patch_size + self.image_encoder.channels))

            # One prompt set per transformer block for deep VPT, a single set for shallow VPT.
            for idx in range(self.image_encoder_depth if self.deep_vpt else 1):
                setattr(self, f"vpt_{idx}", nn.Parameter(torch.empty(self.num_vpt, self.image_encoder.channels)))
                nn.init.uniform_(getattr(self, f"vpt_{idx}"), -val, val)
                setattr(self, f"vpt_drop_{idx}", nn.Dropout(vpt_drop) if vpt_drop > 0 else nn.Identity())

        self.encoder_reduction = self.image_encoder.reduction
        self.reduction = self.encoder_reduction if reduction is None else reduction
        self.channels = self.image_encoder.channels
        self.clip_embed_dim = self.image_encoder.clip_embed_dim

        if decoder_cfg is not None:
            assert decoder_block is not None, "Expected decoder_block to be a nn.Module, got None."
            self.image_decoder = make_resnet_layers(decoder_block, decoder_cfg, in_channels=self.channels, expansion=1, dilation=1)
            self.image_decoder.apply(_init_weights)
            self.channels = decoder_cfg[-1]
        else:
            self.image_decoder = nn.Identity()

        # Map the (decoded) feature channels onto CLIP's joint embedding width.
        if self.channels != self.clip_embed_dim:
            self.projection = nn.Conv2d(in_channels=self.channels, out_channels=self.clip_embed_dim, kernel_size=1)
            self.projection.apply(_init_weights)
        else:
            self.projection = nn.Identity()

        # Text encoder
        assert prompt_type in ["number", "word"], f"Expected prompt_type to be 'number' or 'word', got {prompt_type}"
        self.prompt_type = prompt_type
        self.text_encoder = getattr(_clip, f"{backbone}_txt")()
        self.freeze_text_encoder = freeze_text_encoder
        if self.freeze_text_encoder:
            for param in self.text_encoder.parameters():
                param.requires_grad = False

        self.bins = bins
        # Shaped (1, N, 1, 1) so it broadcasts against per-bin probabilities of shape (B, N, H, W).
        self.anchor_points = torch.tensor(anchor_points, dtype=torch.float32, requires_grad=False).view(1, -1, 1, 1)

        self._get_text_prompts()
        self._tokenize_text_prompts()

        # A frozen text encoder gives constant features, so compute them once.
        if self.freeze_text_encoder:
            self._extract_text_features()
        else:
            self.text_features = None

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07), requires_grad=True)

    def _get_text_prompts(self) -> None:
        """Format one text prompt per bin into ``self.text_prompts``.

        Degenerate bins (equal ends) are formatted as a single value.
        """
        bins = [b[0] if b[0] == b[1] else b for b in self.bins]
        self.text_prompts = [format_count(b, self.prompt_type) for b in bins]
        print(f"Initialized model with text prompts: {self.text_prompts}")

    def _tokenize_text_prompts(self) -> None:
        """Replace ``self.text_prompts`` with its CLIP tokenization."""
        self.text_prompts = _clip.tokenize(self.text_prompts)

    def _extract_text_features(self) -> None:
        """Encode the tokenized prompts once, without gradients, into ``self.text_features``."""
        with torch.no_grad():
            self.text_features = self.text_encoder(self.text_prompts)

    def _prepare_vpt(self, layer: int, batch_size: int, device: torch.device) -> Tensor:
        """Return the prompt tokens for ``layer``, shaped ``(num_vpt, batch_size, hidden_dim)``.

        The layer's prompt parameter is broadcast over the batch and passed
        through that layer's prompt dropout.
        """
        if not self.deep_vpt:
            assert layer == 0, f"Expected layer to be 0 when using Shallow Visual Prompt Tuning, got {layer}"

        vpt = getattr(self, f"vpt_{layer}").to(device)
        vpt = vpt.unsqueeze(0).expand(batch_size, -1, -1)
        vpt = getattr(self, f"vpt_drop_{layer}")(vpt)
        vpt = vpt.permute(1, 0, 2)  # (num_vpt, batch_size, hidden_dim)
        assert vpt.shape[1] == batch_size, f"Expected the VPT to have the shape [L_vis B C], got {vpt.shape}."
        return vpt

    def _forward_vpt(self, x: Tensor) -> Tensor:
        """Run the frozen ViT encoder with visual prompt tokens.

        Args:
            x: images of shape ``(B, C, H, W)``.

        Returns:
            Patch features of shape ``(B, C_enc, H // patch_h, W // patch_w)``.
            The CLS token and prompt tokens are dropped.
        """
        device = x.device
        batch_size, _, height, width = x.shape
        num_h_patches, num_w_patches = height // self.image_encoder.patch_size[0], width // self.image_encoder.patch_size[1]

        image_features = self.image_encoder.conv1(x)
        image_features = image_features.reshape(batch_size, image_features.shape[1], -1)
        image_features = image_features.permute(0, 2, 1)  # (B, num_patches, C)
        image_features = torch.cat([
            self.image_encoder.class_embedding + torch.zeros(batch_size, 1, image_features.shape[-1], dtype=image_features.dtype, device=device),
            image_features,
        ], dim=1)  # (B, num_patches + 1, C)

        # Positional embeddings are resampled to the actual patch grid, so inputs need not match training size.
        pos_embedding = self.image_encoder._interpolate_pos_embed(num_h_patches, num_w_patches)
        image_features = image_features + pos_embedding
        image_features = self.image_encoder.ln_pre(image_features)
        image_features = image_features.permute(1, 0, 2)  # (num_patches + 1, B, C)
        assert image_features.shape[0] == num_h_patches * num_w_patches + 1 and image_features.shape[1] == batch_size, f"Expected image_features to have shape [num_patches + 1, B, C], got {image_features.shape}."

        vpt = self._prepare_vpt(0, batch_size, device)
        for idx in range(self.image_encoder_depth):
            # assemble: [CLS, prompts, patches]
            image_features = torch.cat([
                image_features[:1, :, :],  # CLS token
                vpt,
                image_features[1:, :, :],
            ], dim=0)

            # transformer
            image_features = self.image_encoder.transformer.resblocks[idx](image_features)

            # Choose the prompts for the next block. Deep VPT uses that layer's
            # fresh prompts; shallow VPT carries this block's prompt outputs forward.
            if idx < self.image_encoder_depth - 1:
                if self.deep_vpt:
                    vpt = self._prepare_vpt(idx + 1, batch_size, device)
                else:
                    vpt = image_features[1: (self.num_vpt + 1), :, :]

            # disassemble: strip the prompt tokens after every block (including the last).
            image_features = torch.cat([
                image_features[:1, :, :],  # CLS token
                image_features[(self.num_vpt + 1):, :, :],
            ], dim=0)

        image_features = image_features.permute(1, 0, 2)  # (B, num_patches + 1, C)
        image_features = self.image_encoder.ln_post(image_features)
        image_features = image_features[:, 1:, :].permute(0, 2, 1)  # (B, C, num_patches)
        image_features = image_features.reshape(batch_size, -1, num_h_patches, num_w_patches)
        return image_features

    def _forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Predict per-location bin logits and expected counts.

        Returns:
            In training mode, ``(logits, exp)``, where ``logits`` has shape
            ``(B, N_bins, H', W')`` and ``exp`` has shape ``(B, 1, H', W')``.
            Otherwise only ``exp``.
        """
        device = x.device

        x = self.image_encoder(x) if self.backbone in resnet_backbones else self._forward_vpt(x)
        if self.reduction != self.encoder_reduction:
            x = F.interpolate(x, scale_factor=self.encoder_reduction / self.reduction, mode="bilinear")
        x = self.image_decoder(x)
        x = self.projection(x)

        image_features = x.permute(0, 2, 3, 1)  # shape (B, H, W, C)
        text_features = self.text_encoder(self.text_prompts.to(device)) if self.text_features is None else self.text_features.to(device)  # shape (N, C)

        image_features = F.normalize(image_features, p=2, dim=-1)
        text_features = F.normalize(text_features, p=2, dim=-1)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits = logit_scale * image_features @ text_features.t()  # (B, H, W, N), logits per image
        logits = logits.permute(0, 3, 1, 2)  # (B, N, H, W)

        probs = logits.softmax(dim=1)
        exp = (probs * self.anchor_points.to(x.device)).sum(dim=1, keepdim=True)  # (B, 1, H, W)

        if self.training:
            return logits, exp
        else:
            return exp

    @staticmethod
    def _window_spans(length: int, window: int, stride: int) -> List[Tuple[int, int]]:
        """Return the ``(start, end)`` pixel spans of sliding windows along one axis.

        Windows advance by ``stride``. A window that would run past ``length``
        is shifted back so that it ends at ``length``. When the axis is shorter
        than the window, the start is clamped to 0 and the single window covers
        the whole axis. (Previously a negative start wrapped around under Python
        slicing and covered only the tail.)
        """
        num_windows = max(int(np.ceil((length - window) / stride)) + 1, 1)
        spans = []
        for i in range(num_windows):
            start = i * stride
            end = start + window
            if end > length:
                start, end = max(length - window, 0), length
            spans.append((start, end))
        return spans

    def _sliding_window_forward(self, x: Tensor) -> Tensor:
        """Evaluate an input whose size differs from the training window.

        The image is tiled with windows of the training size, using a stride
        equal to the window. The last window on each axis is shifted inward, so
        windows only overlap at the border. All windows are run through
        ``_forward`` in a single batch, and overlapping predictions are averaged.

        Returns:
            A detached float32 tensor of shape
            ``(B, C_out, H // reduction, W // reduction)`` on the device of the
            predictions. Assumes the window size is divisible by
            ``self.reduction``.
        """
        batch_size, _, image_height, image_width = x.shape
        window_height, window_width = self.image_encoder.input_resolution
        reduction = self.reduction

        row_spans = self._window_spans(image_height, window_height, window_height)
        col_spans = self._window_spans(image_width, window_width, window_width)
        spans = [(rows, cols) for rows in row_spans for cols in col_spans]

        windows = torch.cat([x[:, :, r0:r1, c0:c1] for (r0, r1), (c0, c1) in spans], dim=0)
        preds = self._forward(windows).detach()
        # torch.cat stacks window-major: entry k * batch_size + b is window k of image b.
        preds = preds.reshape(len(spans), batch_size, *preds.shape[1:])

        # assemble the density map
        map_shape = (batch_size, preds.shape[2], image_height // reduction, image_width // reduction)
        pred_map = torch.zeros(map_shape, dtype=torch.float32, device=preds.device)
        count_map = torch.zeros_like(pred_map)
        for k, ((r0, r1), (c0, c1)) in enumerate(spans):
            rows = slice(r0 // reduction, r1 // reduction)
            cols = slice(c0 // reduction, c1 // reduction)
            pred_map[:, :, rows, cols] += preds[k]
            count_map[:, :, rows, cols] += 1.0
        return pred_map / count_map  # average the overlapping regions

    def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor, Tensor]]:
        """Run the model on a batch of images of shape ``(B, C, H, W)``.

        ResNet backbones accept any input size. ViT backbones require the
        training window size during training. At evaluation time, a ViT input
        of any other size is processed with sliding windows.
        """
        assert len(x.shape) == 4, f"Expected input to have shape (B C H W), got {x.shape}."
        if "vit" not in self.backbone:
            return self._forward(x)

        image_height, image_width = x.shape[2], x.shape[3]
        window_height, window_width = self.image_encoder.input_resolution
        if self.training:
            assert (image_height, image_width) == (window_height, window_width), f"Expected input to have shape ({window_height} {window_width}), got ({image_height} {image_width})."
            return self._forward(x)
        if (image_height, image_width) == (window_height, window_width):  # evaluation, input size = training size
            return self._forward(x)
        # evaluation, input_size != training size, use sliding window prediction
        return self._sliding_window_forward(x)
def _clip_ebc(
    backbone: str,
    bins: List[Tuple[float, float]],
    anchor_points: List[float],
    reduction: Optional[int] = None,
    freeze_text_encoder: bool = True,
    prompt_type: str = "number",
    input_size: Optional[int] = None,
    num_vpt: Optional[int] = None,
    deep_vpt: Optional[bool] = None,
    vpt_drop: Optional[float] = None,
    decoder_block: Optional[nn.Module] = None,
    decoder_cfg: Optional[List[Union[str, int]]] = None
) -> CLIP_EBC:
    """Build a ``CLIP_EBC`` model with backbone-specific decoder defaults.

    Defaults:
        decoder_block: ``Bottleneck`` for ResNet backbones and ``BasicBlock``
            for ViT backbones, used when the argument is None. (Previously an
            explicit ``decoder_block`` was silently overwritten.)
        decoder_cfg: a per-backbone channel list, used when the argument is
            None.

    Every other argument is forwarded unchanged to ``CLIP_EBC``. An unknown
    backbone is left for ``CLIP_EBC`` to reject with its own assertion.
    """
    default_decoder_cfgs = {
        "resnet50": [2048],
        "resnet101": [2048, 1024],
        "resnet50x4": [1280],
        "resnet50x16": [1536],
        "resnet50x64": [2048],
        "vit_b_16": [768],
        "vit_b_32": [768],
        "vit_l_14": [1024],
        "vit_l_14_336px": [1024],
    }

    if decoder_block is None:
        decoder_block = Bottleneck if backbone in resnet_backbones else BasicBlock

    if decoder_cfg is None:
        default_cfg = default_decoder_cfgs.get(backbone)
        # Return a fresh list so callers never share (and mutate) the table entries.
        decoder_cfg = None if default_cfg is None else list(default_cfg)

    return CLIP_EBC(
        backbone=backbone,
        bins=bins,
        anchor_points=anchor_points,
        reduction=reduction,
        freeze_text_encoder=freeze_text_encoder,
        prompt_type=prompt_type,
        input_size=input_size,
        num_vpt=num_vpt,
        deep_vpt=deep_vpt,
        vpt_drop=vpt_drop,
        decoder_block=decoder_block,
        decoder_cfg=decoder_cfg,
    )