Spaces:
Runtime error
Runtime error
File size: 16,720 Bytes
4d0eb62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 |
# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import List, Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from mmengine.model import BaseModel, BaseModule
from torch import nn
from mmpretrain.datasets.categories import CIFAR100_CATEGORIES_CN
from mmpretrain.registry import MODELS, TOKENIZER
from mmpretrain.structures import DataSample
from mmpretrain.utils import track_on_main_process
from .utils import OPENAI_PROMPT
# Named text prototypes: each value is iterated class-name by class-name when
# ChineseCLIP builds its zero-shot classifier.
PROTOTYPE_MAP = dict(cifar100=CIFAR100_CATEGORIES_CN)
# Named prompt collections: each value is an iterable of callables that turn
# a class name into a prompt text.
PROMPT_MAP = dict(openai=OPENAI_PROMPT)
class Bottleneck(nn.Module):
    """ResNet bottleneck block whose striding is done by average pooling.

    Main branch: 1x1 conv -> 3x3 conv -> ``AvgPool2d(stride)`` (identity when
    ``stride == 1``) -> 1x1 conv widening to ``planes * expansion`` channels.
    Every conv is followed by BatchNorm; ReLU follows the first two and the
    residual sum. Whenever the stride is > 1 or the channel count changes,
    the shortcut is average-pooled and projected by a 1x1 conv + BatchNorm.

    Args:
        inplanes (int): Number of input channels.
        planes (int): Bottleneck width; the block emits
            ``planes * expansion`` channels.
        stride (int): Spatial downsampling factor. Defaults to 1.
    """

    expansion = 4

    def __init__(self, inplanes, planes, stride=1):
        super().__init__()
        width_out = planes * self.expansion

        # Main branch (creation order kept so seeded init is unchanged).
        self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        if stride > 1:
            self.avgpool = nn.AvgPool2d(stride)
        else:
            self.avgpool = nn.Identity()
        self.conv3 = nn.Conv2d(planes, width_out, 1, bias=False)
        self.bn3 = nn.BatchNorm2d(width_out)
        self.relu = nn.ReLU(inplace=True)

        self.stride = stride

        # Shortcut branch. The submodule names '-1', '0' and '1' become part
        # of the state-dict keys, so they are kept verbatim.
        needs_projection = (stride > 1
                            or inplanes != planes * Bottleneck.expansion)
        if needs_projection:
            shortcut_layers = OrderedDict()
            shortcut_layers['-1'] = nn.AvgPool2d(stride)
            shortcut_layers['0'] = nn.Conv2d(
                inplanes, width_out, 1, stride=1, bias=False)
            shortcut_layers['1'] = nn.BatchNorm2d(width_out)
            self.downsample = nn.Sequential(shortcut_layers)
        else:
            self.downsample = None

    def forward(self, x: torch.Tensor):
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(self.avgpool(y)))
        shortcut = x if self.downsample is None else self.downsample(x)
        y += shortcut
        return self.relu(y)
class AttentionPool2d(nn.Module):
    """QKV attention pooling used in place of global average pooling.

    The feature map is flattened into a token sequence, a mean token is
    prepended, learned positional embeddings are added, and one multi-head
    attention step is run. The output at the mean-token position is returned.

    Args:
        spacial_dim (int): The positional embedding has ``spacial_dim**2 + 1``
            rows, so inputs must satisfy ``H * W == spacial_dim**2``.
        embed_dim (int): Channel dimension ``C`` of the input feature map.
        num_heads (int): Number of attention heads.
        output_dim (int, optional): Output dimension of the final projection.
            Defaults to None, which means ``embed_dim``.
    """

    def __init__(self,
                 spacial_dim: int,
                 embed_dim: int,
                 num_heads: int,
                 output_dim: Optional[int] = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spacial_dim**2 + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool an ``(N, C, H, W)`` feature map into ``(N, output_dim)``.

        Args:
            x (torch.Tensor): Feature map of shape ``(N, C, H, W)`` with
                ``C == embed_dim`` and ``H * W == spacial_dim**2``.

        Returns:
            torch.Tensor: Attention output at the mean-token position.
        """
        n, c, h, w = x.shape
        x = x.reshape(n, c, h * w).permute(2, 0, 1)  # NCHW -> (HW)NC
        # Prepend the spatial mean as the query/summary token.
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (HW+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (HW+1)NC
        x, _ = F.multi_head_attention_forward(
            query=x,
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            # Packed bias is split back into q/k/v in that order.
            in_proj_bias=torch.cat(
                [self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False)
        return x[0]
@MODELS.register_module()
class ModifiedResNet(BaseModule):
    """A modified ResNet contains the following changes:

    - Apply deep stem with an average pool instead of a max pool.
    - Performs anti-aliasing strided convolutions, where an avgpool is
      prepended to convolutions with stride > 1
    - The final pooling layer is a QKV attention instead of an average pool

    Args:
        depth (int): Network depth, one of the keys of ``arch_settings``
            (50, 101 or 152). Defaults to 50.
        base_channels (int): Output width of the stem and bottleneck width of
            the first stage. Defaults to 64.
        input_size (int): Input image side length. The network downsamples by
            32 (stride-2 stem conv, stem avgpool, three stride-2 stages), and
            the attention pool is built for an ``input_size // 32`` feature
            map. Defaults to 224.
        num_attn_heads (int): Number of heads in the attention pool.
            Defaults to 32.
        output_dim (int): Output dimension of the attention pool.
            Defaults to 1024.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.

    Raises:
        KeyError: If ``depth`` is not a key of ``arch_settings``.
    """  # noqa

    arch_settings = {
        50: (Bottleneck, (3, 4, 6, 3)),
        101: (Bottleneck, (3, 4, 23, 3)),
        152: (Bottleneck, (3, 8, 36, 3))
    }

    def __init__(self,
                 depth: int = 50,
                 base_channels: int = 64,
                 input_size: int = 224,
                 num_attn_heads: int = 32,
                 output_dim: int = 1024,
                 init_cfg: Optional[dict] = None):
        super().__init__(init_cfg=init_cfg)
        if depth not in self.arch_settings:
            raise KeyError(
                f'invalid depth {depth} for ModifiedResNet, expected one of '
                f'{sorted(self.arch_settings)}')
        self.input_size = input_size
        self.block, stage_blocks = self.arch_settings[depth]

        # the 3-layer stem
        self.conv1 = nn.Conv2d(
            3,
            base_channels // 2,
            kernel_size=3,
            stride=2,
            padding=1,
            bias=False)
        self.bn1 = nn.BatchNorm2d(base_channels // 2)
        self.conv2 = nn.Conv2d(
            base_channels // 2,
            base_channels // 2,
            kernel_size=3,
            padding=1,
            bias=False)
        self.bn2 = nn.BatchNorm2d(base_channels // 2)
        self.conv3 = nn.Conv2d(
            base_channels // 2,
            base_channels,
            kernel_size=3,
            padding=1,
            bias=False)
        self.bn3 = nn.BatchNorm2d(base_channels)
        self.avgpool = nn.AvgPool2d(2)
        self.relu = nn.ReLU(inplace=True)

        # residual layers
        # this is a *mutable* variable used during construction
        self._inplanes = base_channels
        self.layer1 = self._make_layer(base_channels, stage_blocks[0])
        self.layer2 = self._make_layer(
            base_channels * 2, stage_blocks[1], stride=2)
        self.layer3 = self._make_layer(
            base_channels * 4, stage_blocks[2], stride=2)
        self.layer4 = self._make_layer(
            base_channels * 8, stage_blocks[3], stride=2)

        # Channel width after the last stage, i.e. what the pool receives.
        embed_dim = base_channels * 8 * self.block.expansion
        self.attnpool = AttentionPool2d(input_size // 32, embed_dim,
                                        num_attn_heads, output_dim)

    def _make_layer(self, planes, blocks, stride=1):
        """Stack ``blocks`` residual blocks; only the first one strides.

        Advances ``self._inplanes`` to the stage's output width so that the
        next stage is built with the correct number of input channels.
        """
        layers = [self.block(self._inplanes, planes, stride)]
        self._inplanes = planes * self.block.expansion
        for _ in range(1, blocks):
            layers.append(self.block(self._inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run stem, four residual stages and the attention pool.

        The input is first cast to the dtype of ``conv1``'s weight so the
        model can run in reduced precision. Returns the attention-pool
        output, one ``output_dim`` vector per image.
        """

        def stem(x):
            for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2),
                             (self.conv3, self.bn3)]:
                x = self.relu(bn(conv(x)))
            x = self.avgpool(x)
            return x

        x = x.type(self.conv1.weight.dtype)
        x = stem(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.attnpool(x)

        return x
@MODELS.register_module()
class ChineseCLIP(BaseModel):
    """The implementation of `ChineseCLIP <https://arxiv.org/abs/2211.01335>`_.

    Only the ``'predict'`` mode (zero-shot classification) is supported by
    :meth:`forward`.

    Args:
        vision_backbone (dict): Config dict for vision backbone.
        text_backbone (dict): Config dict for text backbone. Its
            ``config['hidden_size']`` entry is used as the input size of the
            text projection.
        tokenizer (dict): Config dict for text tokenizer.
        proj_dim (int): Projection dimension for similarity computation.
        text_prototype (str | List[str]): Text prototype, which can be a key
            in `PROTOTYPE_MAP` or a list of class names.
        text_prompt (str): Key in `PROMPT_MAP` selecting the prompt templates
            used to build the text prototype. Defaults to 'openai'.
        context_length (int): The context length to use. Defaults to 52.
        data_preprocessor (Union[dict, nn.Module], optional): The config for
            preprocessing input data, or an already-built module. If None or
            no specified type, it will use "MultiModalDataPreprocessor" as
            type. See :class:`MultiModalDataPreprocessor` for more details.
            Defaults to None.
        init_cfg (dict, optional): The config to control the initialization.
            Defaults to None.
    """

    def __init__(self,
                 vision_backbone: dict,
                 text_backbone: dict,
                 tokenizer: dict,
                 proj_dim: int,
                 text_prototype: Union[str, List[str]],
                 text_prompt: str = 'openai',
                 context_length: int = 52,
                 data_preprocessor: Optional[Union[dict, nn.Module]] = None,
                 init_cfg: Optional[dict] = None):
        if not isinstance(data_preprocessor, nn.Module):
            # Work on a shallow copy so the caller's config is not mutated.
            data_preprocessor = dict(data_preprocessor or {})
            data_preprocessor.setdefault('type', 'MultiModalDataPreprocessor')
            data_preprocessor = MODELS.build(data_preprocessor)

        super().__init__(
            data_preprocessor=data_preprocessor, init_cfg=init_cfg)

        self.vision_backbone = MODELS.build(vision_backbone)
        self.text_backbone = MODELS.build(text_backbone)

        # ModifiedResNet's attention pool already ends in a projection
        # (c_proj), so only other backbones get a separate vision projection.
        # NOTE(review): the projections are created with torch.empty and are
        # presumably filled from a checkpoint — confirm.
        if not isinstance(self.vision_backbone, ModifiedResNet):
            self.vision_projection = nn.Parameter(
                torch.empty(self.vision_backbone.embed_dims, proj_dim))
        text_hidden_size = text_backbone['config']['hidden_size']
        self.text_projection = nn.Parameter(
            torch.empty(text_hidden_size, proj_dim))
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        self.tokenizer = TOKENIZER.build(tokenizer)
        self.context_length = context_length

        # for zero-shot classification
        if isinstance(text_prototype,
                      str) and text_prototype in PROTOTYPE_MAP:
            self.prototype = PROTOTYPE_MAP[text_prototype]
        else:
            self.prototype = text_prototype
        # Filled lazily by prepare_text_prototype() on the first predict().
        self.text_prototype_embeds = None

        self.prompt = PROMPT_MAP[text_prompt]

    def forward(
        self,
        images: torch.Tensor,
        data_samples: Optional[list] = None,
        mode: str = 'predict',
        **kwargs,
    ):
        """The unified entry for a forward process in both training and test.

        The method accepts the following modes:

        - "predict": Forward and return a list of data samples contain the
          predict results.

        Args:
            images (torch.Tensor): the preprocessed image tensor of shape
                ``(N, C, H, W)``.
            data_samples (List[DataSample], optional): The annotation data
                of every samples. Defaults to None.
            mode (str): Return what kind of value. Defaults to 'predict'.

        Returns:
            List[DataSample]: The prediction results for ``mode='predict'``.

        Raises:
            RuntimeError: If ``mode`` is not ``'predict'``.
        """
        if mode == 'predict':
            return self.predict(images, data_samples, **kwargs)
        else:
            raise RuntimeError(f'Invalid mode "{mode}".')

    def extract_image_feat(self, images: torch.Tensor) -> torch.Tensor:
        """Extract image latent features (not normalised).

        A ModifiedResNet's output is returned as-is. For any other backbone,
        the last element of its outputs is multiplied by
        ``self.vision_projection``.
        """
        if isinstance(self.vision_backbone, ModifiedResNet):
            return self.vision_backbone(images)
        return self.vision_backbone(images)[-1] @ self.vision_projection

    def extract_text_feat(self, texts: torch.Tensor) -> torch.Tensor:
        """Extract text latent features (not normalised) from token ids.

        Positions equal to the tokenizer's ``[PAD]`` id are masked out, and
        the hidden state at position 0 (the ``[CLS]`` slot written by
        :meth:`tokenize`) is projected by ``self.text_projection``.
        """
        pad_index = self.tokenizer.vocab['[PAD]']
        attn_mask = texts.ne(pad_index)
        # [batch_size, seq_length, hidden_size]
        x = self.text_backbone(texts, attention_mask=attn_mask)[0]
        return x[:, 0, :] @ self.text_projection

    def extract_feat(
        self, images: Optional[torch.Tensor], texts: Optional[torch.Tensor]
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Extract image and/or text latent features.

        If only one of ``images``/``texts`` is given, its raw (unnormalised)
        features are returned. If both are given, a tuple of L2-normalised
        ``(image_features, text_features)`` is returned.

        Raises:
            AssertionError: If both ``images`` and ``texts`` are None.
        """
        assert images is not None or texts is not None, \
            'text and image cannot both be None!'
        if images is None:
            return self.extract_text_feat(texts)
        elif texts is None:
            return self.extract_image_feat(images)

        image_features = self.extract_image_feat(images)
        text_features = self.extract_text_feat(texts)

        image_features = image_features / image_features.norm(
            dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(
            dim=-1, keepdim=True)

        return image_features, text_features

    def compute_similarity(self, images, texts):
        """Extract images and texts features and compute cosine similarity.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``logits_per_image`` of shape
            ``(num_images, num_texts)`` scaled by ``exp(logit_scale)``, and
            its transpose ``logits_per_text``.
        """
        image_features, text_features = self.extract_feat(
            images=images, texts=texts)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        # shape (num_images, num_texts) and (num_texts, num_images)
        return logits_per_image, logits_per_text

    def predict(
        self,
        images: torch.Tensor,
        data_samples: Optional[List[DataSample]] = None
    ) -> List[DataSample]:
        """Predict the classes of the input images.

        The prediction is for zero-shot classification and the text prototypes
        will be prepared in this function (once, then cached).

        Args:
            images (torch.Tensor): The input images.
            data_samples (List[DataSample], optional): The data samples with
                information from dataset. New ``DataSample`` objects are
                created when None (or for None entries). Defaults to None.

        Returns:
            List[DataSample]: The results of prediction, with ``pred_score``
            (softmax over classes) and ``pred_label`` set.
        """
        if self.text_prototype_embeds is None:
            self.prepare_text_prototype(device=images.device)

        image_features = self.extract_image_feat(images=images)
        image_features = image_features / image_features.norm(
            dim=-1, keepdim=True)

        # cosine similarity as logits: (N, D) @ (D, num_classes)
        logits_per_image = image_features @ self.text_prototype_embeds.to(
            image_features.device) * self.logit_scale.exp()

        pred_scores = F.softmax(logits_per_image, dim=1)
        pred_labels = pred_scores.argmax(dim=1, keepdim=True).detach()

        out_data_samples = []
        if data_samples is None:
            data_samples = [None] * pred_scores.size(0)
        for data_sample, score, label in zip(data_samples, pred_scores,
                                             pred_labels):
            if data_sample is None:
                data_sample = DataSample()

            data_sample.set_pred_score(score).set_pred_label(label)
            out_data_samples.append(data_sample)
        return out_data_samples

    def prepare_text_prototype(self, device) -> None:
        """Build and cache the text prototype embedding of every class.

        For each class name, every prompt in ``self.prompt`` is formatted,
        encoded and L2-normalised. The prompt embeddings are then averaged
        and re-normalised. The per-class vectors are stacked column-wise into
        ``self.text_prototype_embeds`` of shape ``(proj_dim, num_classes)``.
        """
        class_embeddings = []
        # The result is cached on the module and reused across calls, so it
        # must not hold an autograd graph through the text backbone.
        with torch.no_grad():
            for classname in track_on_main_process(
                    self.prototype, 'Prepare text prototype...'):
                # format with class
                texts = [prompt(classname) for prompt in self.prompt]
                tokenized_texts = self.tokenize(texts)
                class_features = self.extract_text_feat(
                    tokenized_texts.to(device))
                class_features = class_features / class_features.norm(
                    dim=-1, keepdim=True)
                class_feature = class_features.mean(dim=0)
                class_feature = class_feature / class_feature.norm()
                class_embeddings.append(class_feature)
        self.text_prototype_embeds = torch.stack(
            class_embeddings, dim=1).to(device)

    def tokenize(self, texts: Union[str, List[str]]) -> torch.LongTensor:
        """Returns the tokenized representation of given input string(s).

        Each text is lower-cased, has curly double quotes normalised to
        ``"``, and is encoded as ``[CLS] tokens [SEP]``. The tokens are
        truncated to fit in ``self.context_length`` and right-padded with the
        tokenizer's ``[PAD]`` id.

        Args:
            texts (Union[str, List[str]]): An input string or a list of input
                strings to tokenize.

        Returns:
            torch.Tensor: Token ids of shape ``(num_texts, context_length)``.
        """
        if isinstance(texts, str):
            texts = [texts]

        vocab = self.tokenizer.vocab
        cls_id = vocab['[CLS]']
        sep_id = vocab['[SEP]']
        # Pad with the same id that extract_text_feat() masks out.
        pad_id = vocab['[PAD]']

        all_tokens = []
        for text in texts:
            # adapt the text to Chinese BERT vocab: map the curly double
            # quotes U+201C / U+201D to the ASCII '"'.
            text = text.lower().replace('\u201c', '"').replace('\u201d', '"')
            # add special tokens, leaving room for [CLS] and [SEP]
            token_ids = self.tokenizer.convert_tokens_to_ids(
                self.tokenizer.tokenize(text))[:self.context_length - 2]
            all_tokens.append([cls_id] + token_ids + [sep_id])

        result = torch.full((len(all_tokens), self.context_length),
                            pad_id,
                            dtype=torch.long)

        for i, tokens in enumerate(all_tokens):
            assert len(tokens) <= self.context_length
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
|