import fvcore.nn.weight_init as weight_init
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
import logging
from functools import partial
from scipy import interpolate
from math import pi
from einops import rearrange, repeat
import warnings
from PIL import Image
import torch.utils.checkpoint as cp
from transformers import CLIPImageProcessor

# FlashAttention is only required when Attention is built with xattn=True;
# import it lazily so the module still loads in environments without it.
try:
    from ..utils.attention import FlashAttention, FlashMHA
except ImportError:
    FlashAttention = FlashMHA = None
# try:
#     import xformers.ops as xops
# except ImportError:
#     pass

logger = logging.getLogger(__name__)

BatchNorm2d = torch.nn.BatchNorm2d

class Conv2d(torch.nn.Conv2d):
    """
    A wrapper around :class:`torch.nn.Conv2d` to support empty inputs and more features.
    """

    def __init__(self, *args, **kwargs):
        """
        Extra keyword arguments supported in addition to those in `torch.nn.Conv2d`:

        Args:
            norm (nn.Module, optional): a normalization layer
            activation (callable(Tensor) -> Tensor): a callable activation function

        It assumes that the norm layer is applied before the activation.
        """
        norm = kwargs.pop("norm", None)
        activation = kwargs.pop("activation", None)
        super().__init__(*args, **kwargs)

        self.norm = norm
        self.activation = activation

    def forward(self, x):
        # torchscript does not support SyncBatchNorm yet
        # https://github.com/pytorch/pytorch/issues/40507
        # and we skip these codes in torchscript since:
        # 1. currently we only support torchscript in evaluation mode
        # 2. features needed by exporting module to torchscript are added in PyTorch 1.6 or
        #    later versions; `Conv2d` in these PyTorch versions already supports empty inputs.
        if not torch.jit.is_scripting():
            with warnings.catch_warnings(record=True):
                if x.numel() == 0 and self.training:
                    # https://github.com/pytorch/pytorch/issues/12013
                    assert not isinstance(
                        self.norm, torch.nn.SyncBatchNorm
                    ), "SyncBatchNorm does not support empty inputs!"

        x = F.conv2d(
            x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups
        )
        if self.norm is not None:
            x = self.norm(x)
        if self.activation is not None:
            x = self.activation(x)
        return x

def window_partition(x, window_size):
    """
    Partition into non-overlapping windows with padding if needed.

    Args:
        x (tensor): input tokens with [B, H, W, C].
        window_size (int): window size.

    Returns:
        windows: windows after partition with [B * num_windows, window_size, window_size, C].
        (Hp, Wp): padded height and width before partition.
    """
    B, H, W, C = x.shape

    pad_h = (window_size - H % window_size) % window_size
    pad_w = (window_size - W % window_size) % window_size
    if pad_h > 0 or pad_w > 0:
        x = F.pad(x, (0, 0, 0, pad_w, 0, pad_h))
    Hp, Wp = H + pad_h, W + pad_w

    x = x.view(B, Hp // window_size, window_size, Wp // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows, (Hp, Wp)

def window_unpartition(windows, window_size, pad_hw, hw):
    """
    Window unpartition into original sequences and remove padding.

    Args:
        windows (tensor): input tokens with [B * num_windows, window_size, window_size, C].
        window_size (int): window size.
        pad_hw (Tuple): padded height and width (Hp, Wp).
        hw (Tuple): original height and width (H, W) before padding.

    Returns:
        x: unpartitioned sequences with [B, H, W, C].
    """
    Hp, Wp = pad_hw
    H, W = hw
    B = windows.shape[0] // (Hp * Wp // window_size // window_size)
    x = windows.view(B, Hp // window_size, Wp // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, Hp, Wp, -1)

    if Hp > H or Wp > W:
        x = x[:, :H, :W, :].contiguous()
    return x

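
# Hedged usage sketch (not part of the original file): round-tripping a feature
# map through window_partition/window_unpartition. Shapes are illustrative, not
# a configuration used by this repo. Zero padding is added up to the next
# multiple of the window size and cropped away on the way back.
def _demo_window_partition():
    x = torch.randn(2, 60, 60, 256)                       # B, H, W, C
    windows, (Hp, Wp) = window_partition(x, 14)           # pads 60 -> 70
    assert windows.shape == (2 * 5 * 5, 14, 14, 256)
    x2 = window_unpartition(windows, 14, (Hp, Wp), (60, 60))
    assert torch.equal(x, x2)                             # exact round trip
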
def get_rel_pos(q_size, k_size, rel_pos):
    """
    Get relative positional embeddings according to the relative positions of
    query and key sizes.

    Args:
        q_size (int): size of query q.
        k_size (int): size of key k.
        rel_pos (Tensor): relative position embeddings (L, C).

    Returns:
        Extracted positional embeddings according to relative positions.
    """
    max_rel_dist = int(2 * max(q_size, k_size) - 1)
    use_log_interpolation = True

    # Interpolate rel pos if needed.
    if rel_pos.shape[0] != max_rel_dist:
        if not use_log_interpolation:
            # Linear interpolation of the table.
            rel_pos_resized = F.interpolate(
                rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1),
                size=max_rel_dist,
                mode="linear",
            )
            rel_pos_resized = rel_pos_resized.reshape(-1, max_rel_dist).permute(1, 0)
        else:
            # Geometric-progression ("log-spaced") interpolation of the table.
            src_size = rel_pos.shape[0]
            dst_size = max_rel_dist

            # q = 1.13492
            q = 1.0903078
            dis = []
            cur = 1
            for i in range(src_size // 2):
                dis.append(cur)
                cur += q ** (i + 1)
            r_ids = [-_ for _ in reversed(dis)]
            x = r_ids + [0] + dis
            t = dst_size // 2.0
            dx = np.arange(-t, t + 0.1, 1.0)

            all_rel_pos_bias = []
            for i in range(rel_pos.shape[1]):
                z = rel_pos[:, i].view(src_size).cpu().float().numpy()
                f = interpolate.interp1d(x, z, kind='cubic', fill_value="extrapolate")
                all_rel_pos_bias.append(
                    torch.Tensor(f(dx)).contiguous().view(-1, 1).to(rel_pos.device))
            rel_pos_resized = torch.cat(all_rel_pos_bias, dim=-1)
    else:
        rel_pos_resized = rel_pos

    # Scale the coords with short length if shapes for q and k are different.
    q_coords = torch.arange(q_size)[:, None] * max(k_size / q_size, 1.0)
    k_coords = torch.arange(k_size)[None, :] * max(q_size / k_size, 1.0)
    relative_coords = (q_coords - k_coords) + (k_size - 1) * max(q_size / k_size, 1.0)

    return rel_pos_resized[relative_coords.long()]

def add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, q_size, k_size):
    """
    Calculate decomposed Relative Positional Embeddings from :paper:`mvitv2`.
    https://github.com/facebookresearch/mvit/blob/19786631e330df9f3622e5402b4a419a263a2c80/mvit/models/attention.py   # noqa B950

    Args:
        attn (Tensor): attention map.
        q (Tensor): query q in the attention layer with shape (B, q_h * q_w, C).
        rel_pos_h (Tensor): relative position embeddings (Lh, C) for height axis.
        rel_pos_w (Tensor): relative position embeddings (Lw, C) for width axis.
        q_size (Tuple): spatial sequence size of query q with (q_h, q_w).
        k_size (Tuple): spatial sequence size of key k with (k_h, k_w).

    Returns:
        attn (Tensor): attention map with added relative positional embeddings.
    """
    q_h, q_w = q_size
    k_h, k_w = k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)

    B, _, dim = q.shape
    r_q = q.reshape(B, q_h, q_w, dim)
    rel_h = torch.einsum("bhwc,hkc->bhwk", r_q, Rh)
    rel_w = torch.einsum("bhwc,wkc->bhwk", r_q, Rw)

    attn = (
        attn.view(B, q_h, q_w, k_h, k_w) + rel_h[:, :, :, :, None] + rel_w[:, :, :, None, :]
    ).view(B, q_h * q_w, k_h * k_w)

    return attn

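
# Hedged sketch (illustrative only): adding decomposed relative position bias
# to a per-head attention map. The batch axis here stands in for the flattened
# batch*num_heads used by callers; all sizes below are made up for the demo.
def _demo_decomposed_rel_pos():
    B, (q_h, q_w), (k_h, k_w), dim = 2, (7, 7), (7, 7), 64
    q = torch.randn(B, q_h * q_w, dim)
    attn = torch.randn(B, q_h * q_w, k_h * k_w)
    rel_pos_h = torch.randn(2 * 7 - 1, dim)               # (2*max(q,k)-1, C) table
    rel_pos_w = torch.randn(2 * 7 - 1, dim)
    out = add_decomposed_rel_pos(attn, q, rel_pos_h, rel_pos_w, (q_h, q_w), (k_h, k_w))
    assert out.shape == (B, q_h * q_w, k_h * k_w)
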
def get_abs_pos(abs_pos, has_cls_token, hw):
    """
    Calculate absolute positional embeddings. If needed, resize embeddings and remove cls_token
    dimension for the original embeddings.

    Args:
        abs_pos (Tensor): absolute positional embeddings with (1, num_position, C).
        has_cls_token (bool): If true, has 1 embedding in abs_pos for cls token.
        hw (Tuple): size of input image tokens.

    Returns:
        Absolute positional embeddings after processing with shape (1, H, W, C).
    """
    h, w = hw
    if has_cls_token:
        abs_pos = abs_pos[:, 1:]
    xy_num = abs_pos.shape[1]
    size = int(math.sqrt(xy_num))
    assert size * size == xy_num

    if size != h or size != w:
        original_datatype = abs_pos.dtype
        new_abs_pos = F.interpolate(
            abs_pos.reshape(1, size, size, -1).permute(0, 3, 1, 2).float(),  # bf16 is not implemented
            size=(h, w),
            mode="bicubic",
            align_corners=False,
        ).to(original_datatype)
        return new_abs_pos.permute(0, 2, 3, 1)
    else:
        return abs_pos.reshape(1, h, w, -1)

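
# Hedged sketch: resizing a pretrained (14x14 + cls) position table to a 20x20
# token grid. The numbers are illustrative; the cls entry is dropped and the
# remaining grid is bicubically resized.
def _demo_get_abs_pos():
    abs_pos = torch.randn(1, 1 + 14 * 14, 768)            # cls token + 196 patches
    pos = get_abs_pos(abs_pos, has_cls_token=True, hw=(20, 20))
    assert pos.shape == (1, 20, 20, 768)
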
class PatchEmbed(nn.Module):
    """
    Image to Patch Embedding.
    """

    def __init__(
        self, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0), in_chans=3, embed_dim=768
    ):
        """
        Args:
            kernel_size (Tuple): kernel size of the projection layer.
            stride (Tuple): stride of the projection layer.
            padding (Tuple): padding size of the projection layer.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
        """
        super().__init__()

        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=kernel_size, stride=stride, padding=padding
        )

    def forward(self, x):
        x = self.proj(x)
        # B C H W -> B H W C
        x = x.permute(0, 2, 3, 1)
        return x

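
# Hedged sketch: a 224x224 image with 16x16 patches yields a 14x14 token grid
# in channels-last layout, which is what the transformer blocks below consume.
def _demo_patch_embed():
    embed = PatchEmbed(kernel_size=(16, 16), stride=(16, 16), embed_dim=768)
    tokens = embed(torch.randn(1, 3, 224, 224))
    assert tokens.shape == (1, 14, 14, 768)               # B, H, W, C
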
def broadcat(tensors, dim=-1):
    num_tensors = len(tensors)
    shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
    assert len(shape_lens) == 1, 'tensors must all have the same number of dimensions'
    shape_len = list(shape_lens)[0]
    dim = (dim + shape_len) if dim < 0 else dim
    dims = list(zip(*map(lambda t: list(t.shape), tensors)))
    expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
    assert all([*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]), 'invalid dimensions for broadcastable concatenation'
    max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
    expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
    expanded_dims.insert(dim, (dim, dims[dim]))
    expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
    tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
    return torch.cat(tensors, dim=dim)


def rotate_half(x):
    x = rearrange(x, '... (d r) -> ... d r', r=2)
    x1, x2 = x.unbind(dim=-1)
    x = torch.stack((-x2, x1), dim=-1)
    return rearrange(x, '... d r -> ... (d r)')

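
# Hedged sketch: rotate_half maps each (x1, x2) pair in the last dim to
# (-x2, x1), i.e. a 90-degree rotation per pair, so applying it twice negates
# the input. This is the building block of the rotary embeddings below.
def _demo_rotate_half():
    x = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    assert torch.equal(rotate_half(x), torch.tensor([[-2.0, 1.0, -4.0, 3.0]]))
    assert torch.equal(rotate_half(rotate_half(x)), -x)
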
class VisionRotaryEmbedding(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs_h = torch.einsum('..., f -> ... f', t, freqs)
        freqs_h = repeat(freqs_h, '... n -> ... (n r)', r=2)

        freqs_w = torch.einsum('..., f -> ... f', t, freqs)
        freqs_w = repeat(freqs_w, '... n -> ... (n r)', r=2)

        freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)

        self.register_buffer("freqs_cos", freqs.cos())
        self.register_buffer("freqs_sin", freqs.sin())
        # print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t, start_index=0):
        rot_dim = self.freqs_cos.shape[-1]
        end_index = start_index + rot_dim
        assert rot_dim <= t.shape[-1], f'feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}'
        t_left, t, t_right = t[..., :start_index], t[..., start_index:end_index], t[..., end_index:]
        t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
        return torch.cat((t_left, t, t_right), dim=-1)

class VisionRotaryEmbeddingFast(nn.Module):
    def __init__(
        self,
        dim,
        pt_seq_len=16,
        ft_seq_len=None,
        custom_freqs=None,
        freqs_for='lang',
        theta=10000,
        max_freq=10,
        num_freqs=1,
    ):
        super().__init__()
        if custom_freqs:
            freqs = custom_freqs
        elif freqs_for == 'lang':
            freqs = 1. / (theta ** (torch.arange(0, dim, 2)[:(dim // 2)].float() / dim))
        elif freqs_for == 'pixel':
            freqs = torch.linspace(1., max_freq / 2, dim // 2) * pi
        elif freqs_for == 'constant':
            freqs = torch.ones(num_freqs).float()
        else:
            raise ValueError(f'unknown modality {freqs_for}')

        if ft_seq_len is None:
            ft_seq_len = pt_seq_len
        t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len

        freqs = torch.einsum('..., f -> ... f', t, freqs)
        freqs = repeat(freqs, '... n -> ... (n r)', r=2)
        freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)

        freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
        freqs_sin = freqs.sin().view(-1, freqs.shape[-1])

        self.register_buffer("freqs_cos", freqs_cos)
        self.register_buffer("freqs_sin", freqs_sin)
        # print('======== shape of rope freq', self.freqs_cos.shape, '========')

    def forward(self, t):
        return t * self.freqs_cos + rotate_half(t) * self.freqs_sin

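
# Hedged sketch: the "fast" variant flattens the 2D (H, W) frequency grid to
# (H*W, rot_dim) and rotates the whole head dimension in one fused expression.
# The sizes below are illustrative, not a released configuration.
def _demo_rope_fast():
    rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16, ft_seq_len=16)
    # freqs_cos/freqs_sin are (16*16, 64); rotate q with a matching token count.
    q = torch.randn(1, 8, 16 * 16, 64)                    # B, heads, N, head_dim
    assert rope(q).shape == q.shape
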
class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight", "bias", "running_mean", and "running_var",
    initialized to perform an identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as an identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            out_dtype = x.dtype  # may be half
            return x * scale.to(out_dtype) + bias.to(out_dtype)
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This will silence the warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res

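
# Hedged sketch: freezing the BN statistics of a small, made-up conv stack.
# After conversion the BN layer is a FrozenBatchNorm2d, which holds only
# buffers and therefore exposes no trainable parameters.
def _demo_convert_frozen_bn():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    net = FrozenBatchNorm2d.convert_frozen_batchnorm(net)
    assert isinstance(net[1], FrozenBatchNorm2d)
    assert len(list(net[1].parameters())) == 0
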
class LayerNorm(nn.Module):
    """
    A LayerNorm variant, popularized by Transformers, that performs point-wise mean and
    variance normalization over the channel dimension for inputs that have shape
    (batch_size, channels, height, width).
    https://github.com/facebookresearch/ConvNeXt/blob/d1fa8f6fef0a165b27399986cc2bdacc92777e40/models/convnext.py#L119   # noqa B950
    """

    def __init__(self, normalized_shape, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x

class CNNBlockBase(nn.Module):
    """
    A CNN block is assumed to have input channels, output channels and a stride.
    The input and output of the `forward()` method must be NCHW tensors.
    The method can perform arbitrary computation but must match the given
    channels and stride specification.

    Attribute:
        in_channels (int):
        out_channels (int):
        stride (int):
    """

    def __init__(self, in_channels, out_channels, stride):
        """
        The `__init__` method of any subclass should also contain these arguments.

        Args:
            in_channels (int):
            out_channels (int):
            stride (int):
        """
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride

    def freeze(self):
        """
        Make this block not trainable.
        This method sets all parameters to `requires_grad=False`,
        and converts all BatchNorm layers to FrozenBatchNorm.

        Returns:
            the block itself
        """
        for p in self.parameters():
            p.requires_grad = False
        FrozenBatchNorm2d.convert_frozen_batchnorm(self)
        return self

def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): one of "BN", "SyncBN", "FrozenBN", "GN",
            "nnSyncBN", or "LN"; or a callable that takes a channel number and
            returns the normalization layer as an nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if norm is None:
        return None
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "LN": lambda channels: LayerNorm(channels),
        }[norm]
    return norm(out_channels)

class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        if self.drop_prob == 0. or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # work with tensors of any dim, not just 2D ConvNets
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize
        output = x.div(keep_prob) * random_tensor
        return output

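
# Hedged sketch: in eval mode DropPath is the identity; in training mode it
# zeroes whole samples and rescales survivors by 1/keep_prob so the expected
# output stays unchanged. Shapes below are illustrative.
def _demo_drop_path():
    dp = DropPath(drop_prob=0.5).eval()
    x = torch.randn(4, 7, 7, 32)
    assert torch.equal(dp(x), x)
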
class SwiGLU(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                 norm_layer=nn.LayerNorm, subln=False):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.w1 = nn.Linear(in_features, hidden_features)
        self.w2 = nn.Linear(in_features, hidden_features)

        self.act = act_layer()
        self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
        self.w3 = nn.Linear(hidden_features, out_features)

        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x1 = self.w1(x)
        x2 = self.w2(x)
        hidden = self.act(x1) * x2  # gated activation
        x = self.ffn_ln(hidden)
        x = self.w3(x)
        x = self.drop(x)
        return x

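
# Hedged sketch: SwiGLU's hidden width is conventionally chosen so its
# parameter count roughly matches a plain 4x MLP, which is why the blocks
# below default to mlp_ratio = 4 * 2 / 3. The widths here are illustrative.
def _demo_swiglu():
    mlp = SwiGLU(in_features=768, hidden_features=int(768 * 4 * 2 / 3))
    y = mlp(torch.randn(2, 10, 768))
    assert y.shape == (2, 10, 768)
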
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_head_dim=None,
        norm_layer=nn.LayerNorm,
        rope=None,
        xattn=True,
        subln=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.subln = subln
        self.q_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.k_proj = nn.Linear(dim, all_head_dim, bias=False)
        self.v_proj = nn.Linear(dim, all_head_dim, bias=False)

        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        self.rope = rope
        self.xattn = xattn
        self.proj = nn.Linear(all_head_dim, dim)
        self.inner_attn_ln = norm_layer(all_head_dim) if subln else nn.Identity()

        if self.xattn:
            assert FlashAttention is not None, \
                "xattn=True requires FlashAttention (see the optional import at the top of this file)"
            factory_kwargs = {'device': 'cuda', 'dtype': torch.float16}
            self.inner_attn = FlashAttention(attention_dropout=0.0, **factory_kwargs)

    def forward(self, x):
        B, H, W, C = x.shape
        x = x.view(B, -1, C)
        N = H * W

        q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
        k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
        v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

        q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
        k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
        v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)

        # rope
        if self.rope is not None:
            q = self.rope(q).type_as(v)
            k = self.rope(k).type_as(v)

        if self.xattn:
            q = q.permute(0, 2, 1, 3)  # B, num_heads, N, C -> B, N, num_heads, C
            k = k.permute(0, 2, 1, 3)
            v = v.permute(0, 2, 1, 3)
            kv = torch.stack([k, v], dim=2)
            x, attn_weights = self.inner_attn(q, kv, key_padding_mask=None, causal=False)
            # x = xops.memory_efficient_attention(q, k, v)
            x = x.reshape(B, N, -1)
            x = self.inner_attn_ln(x)
        else:
            q = q * self.scale
            attn = (q @ k.transpose(-2, -1))
            attn = attn.softmax(dim=-1).type_as(x)
            x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
            x = self.inner_attn_ln(x)

        x = self.proj(x)
        x = x.view(B, H, W, C)
        return x

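
# Hedged sketch: the attention block consumes and returns channels-last (B, H,
# W, C) maps. The rope must cover N = H * W tokens, so ft_seq_len matches the
# 7x7 grid here; dim = head_dim // 2, mirroring how EVAViT builds its ropes.
# Sizes are illustrative, and xattn=False avoids the FlashAttention dependency.
def _demo_attention():
    rope = VisionRotaryEmbeddingFast(dim=(512 // 8) // 2, pt_seq_len=16, ft_seq_len=7)
    attn = Attention(dim=512, num_heads=8, rope=rope, xattn=False)
    x = torch.randn(2, 7, 7, 512)
    assert attn(x).shape == (2, 7, 7, 512)
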
class ResBottleneckBlock(CNNBlockBase):
    """
    The standard bottleneck residual block without the last activation layer.
    It contains 3 conv layers with kernels 1x1, 3x3, 1x1.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bottleneck_channels,
        norm="LN",
        act_layer=nn.GELU,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            act_layer (callable): activation for all conv layers.
        """
        super().__init__(in_channels, out_channels, 1)

        self.conv1 = Conv2d(in_channels, bottleneck_channels, 1, bias=False)
        self.norm1 = get_norm(norm, bottleneck_channels)
        self.act1 = act_layer()

        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            3,
            padding=1,
            bias=False,
        )
        self.norm2 = get_norm(norm, bottleneck_channels)
        self.act2 = act_layer()

        self.conv3 = Conv2d(bottleneck_channels, out_channels, 1, bias=False)
        self.norm3 = get_norm(norm, out_channels)

        for layer in [self.conv1, self.conv2, self.conv3]:
            weight_init.c2_msra_fill(layer)
        for layer in [self.norm1, self.norm2]:
            layer.weight.data.fill_(1.0)
            layer.bias.data.zero_()
        # zero init last norm layer.
        self.norm3.weight.data.zero_()
        self.norm3.bias.data.zero_()

    def forward(self, x):
        out = x
        for layer in self.children():
            out = layer(out)
        out = x + out
        return out

class Block(nn.Module):
    """Transformer block with support for window attention and residual propagation blocks."""

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4 * 2 / 3,
        qkv_bias=True,
        drop_path=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        window_size=0,
        use_residual_block=False,
        rope=None,
        xattn=True,
        subln=False,
        # with_cp=True,
    ):
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            window_size (int): Window size for window attention blocks. If 0,
                window attention is not used.
            use_residual_block (bool): If True, use a residual block after the MLP block.
            rope (nn.Module or None): rotary position embedding applied inside attention.
            xattn (bool): If True, use FlashAttention inside the attention layer.
            subln (bool): If True, apply sub-LayerNorm after attention.
        """
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            rope=rope,
            xattn=xattn,
            subln=subln,
        )
        # self.with_cp = with_cp
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = SwiGLU(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            subln=True,
            norm_layer=norm_layer,
        )

        self.window_size = window_size

        self.use_residual_block = use_residual_block
        if use_residual_block:
            # Use a residual block with bottleneck channels = dim // 2
            self.residual = ResBottleneckBlock(
                in_channels=dim,
                out_channels=dim,
                bottleneck_channels=dim // 2,
                norm="LN",
            )

    def _forward(self, x):
        shortcut = x
        x = self.norm1(x)

        # Window partition
        if self.window_size > 0:
            H, W = x.shape[1], x.shape[2]
            x, pad_hw = window_partition(x, self.window_size)

        x = self.attn(x)

        # Reverse window partition
        if self.window_size > 0:
            x = window_unpartition(x, self.window_size, pad_hw, (H, W))

        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        if self.use_residual_block:
            x = self.residual(x.permute(0, 3, 1, 2)).permute(0, 2, 3, 1)

        return x

    def forward(self, x, with_cp=False):
        # if self.with_cp and self.training:
        if with_cp:
            x = cp.checkpoint(self._forward, x)
        else:
            x = self._forward(x)
        return x

# @BACKBONES.register_module()
class EVAViT(nn.Module):
    """
    This module implements the Vision Transformer (ViT) backbone in :paper:`vitdet`.
    "Exploring Plain Vision Transformer Backbones for Object Detection",
    https://arxiv.org/abs/2203.16527
    """

    def __init__(
        self,
        img_size=1024,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4 * 2 / 3,
        qkv_bias=True,
        drop_path_rate=0.0,
        norm_layer=partial(nn.LayerNorm, eps=1e-6),
        act_layer=nn.GELU,
        use_abs_pos=True,
        use_rel_pos=False,
        # sim_fpn=None,
        rope=True,
        pt_hw_seq_len=16,
        intp_freq=True,
        window_size=0,
        global_window_size=0,
        window_block_indexes=(),
        residual_block_indexes=(),
        pretrain_img_size=224,
        pretrain_use_cls_token=True,
        out_feature="last_feat",
        subln=False,
        xattn=True,
        # with_cp=True,
        frozen=False,
    ):
        """
        Args:
            img_size (int): Input image size.
            patch_size (int): Patch size.
            in_chans (int): Number of input image channels.
            embed_dim (int): Patch embedding dimension.
            depth (int): Depth of ViT.
            num_heads (int): Number of attention heads in each ViT block.
            mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            drop_path_rate (float): Stochastic depth rate.
            norm_layer (nn.Module): Normalization layer.
            act_layer (nn.Module): Activation layer.
            use_abs_pos (bool): If True, use absolute positional embeddings.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rope (bool): If True, use rotary position embeddings.
            pt_hw_seq_len (int): pretraining sequence length per spatial axis for rope.
            intp_freq (bool): If True, interpolate rope frequencies to the finetuning sequence length.
            window_size (int): Window size for window attention blocks.
            global_window_size (int): Window size for the remaining (global) blocks.
            window_block_indexes (list): Indexes for blocks using window attention.
            residual_block_indexes (list): Indexes for blocks using conv propagation.
            pretrain_img_size (int): input image size for pretraining models.
            pretrain_use_cls_token (bool): If True, pretraining models use a class token.
            out_feature (str): name of the feature from the last block.
            subln (bool): If True, use sub-LayerNorm inside attention.
            xattn (bool): If True, use FlashAttention.
            frozen (bool): If True, freeze all parameters.
        """
        super().__init__()
        self.pretrain_use_cls_token = pretrain_use_cls_token

        self.patch_embed = PatchEmbed(
            kernel_size=(patch_size, patch_size),
            stride=(patch_size, patch_size),
            in_chans=in_chans,
            embed_dim=embed_dim,
        )
        self.frozen = frozen
        self.gradient_checkpointing = False

        if use_abs_pos:
            # Initialize absolute positional embedding with pretrain image size.
            num_patches = (pretrain_img_size // patch_size) * (pretrain_img_size // patch_size)
            num_positions = (num_patches + 1) if pretrain_use_cls_token else num_patches
            self.pos_embed = nn.Parameter(torch.zeros(1, num_positions, embed_dim))
        else:
            self.pos_embed = None

        half_head_dim = embed_dim // num_heads // 2
        hw_seq_len = img_size // patch_size

        self.rope_win = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=window_size if intp_freq else None,
        )
        self.rope_glb = VisionRotaryEmbeddingFast(
            dim=half_head_dim,
            pt_seq_len=pt_hw_seq_len,
            ft_seq_len=hw_seq_len if intp_freq else None,
        )

        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]

        self.blocks = nn.ModuleList()
        for i in range(depth):
            block = Block(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                window_size=window_size if i in window_block_indexes else global_window_size,
                use_residual_block=i in residual_block_indexes,
                rope=self.rope_win if i in window_block_indexes else self.rope_glb,
                xattn=xattn,
                subln=subln,
                # with_cp=with_cp,
            )
            self.blocks.append(block)

        self._out_feature_channels = {out_feature: embed_dim}
        self._out_feature_strides = {out_feature: patch_size}
        self._out_features = [out_feature]

        # if self.pos_embed is not None:
        #     nn.init.trunc_normal_(self.pos_embed, std=0.02)
        if self.pos_embed is not None:
            nn.init.normal_(self.pos_embed, std=0.02)

        # MIN SHI: I disable the weight initialization since the weights will be
        # loaded automatically. **However, they will cause problems (deepspeed + bf16).**
        # self.apply(self._init_weights)
        self._freeze_stages()

    # def _init_weights(self, m):
    #     if isinstance(m, nn.Linear):
    #         nn.init.trunc_normal_(m.weight, std=0.02)
    #         if isinstance(m, nn.Linear) and m.bias is not None:
    #             nn.init.constant_(m.bias, 0)
    #     elif isinstance(m, nn.LayerNorm):
    #         nn.init.constant_(m.bias, 0)
    #         nn.init.constant_(m.weight, 1.0)

    def _freeze_stages(self):
        if self.frozen:
            self.eval()
            for m in self.parameters():
                m.requires_grad = False

    def forward(self, x):
        x = self.patch_embed(x)
        if self.pos_embed is not None:
            x = x + get_abs_pos(
                self.pos_embed, self.pretrain_use_cls_token, (x.shape[1], x.shape[2])
            )
        for blk in self.blocks:
            x = blk(x, with_cp=self.gradient_checkpointing)  # b, h, w, c

        x = x.permute(0, 3, 1, 2)  # b, c, h, w
        # if self.adapter is not None:
        #     outputs = self.adapter(x)
        # else:
        #     outputs = [x, ]
        # return outputs
        return x

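
# Hedged sketch: a deliberately tiny EVAViT (not a released configuration) to
# show the backbone contract: NCHW image in, NCHW feature map at the patch
# stride out. xattn=False avoids the FlashAttention dependency.
def _demo_eva_vit():
    model = EVAViT(img_size=64, patch_size=16, embed_dim=64, depth=2,
                   num_heads=2, window_size=4, window_block_indexes=(0,),
                   xattn=False).eval()
    with torch.no_grad():
        feat = model(torch.randn(1, 3, 64, 64))
    assert feat.shape == (1, 64, 4, 4)                    # B, C, H/16, W/16
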
'''
EVA ViT vision encoder for LLaVA
'''

class EVAVITVisionTower(nn.Module):
    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False
        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer  # NOTE: not implemented yet; this parameter has no effect
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
        self.args = args

        self.vision_tower, vision_tower_config = build_eva_vit(args=args,
                                                               model_name=vision_tower,
                                                               image_size=args.input_image_size)
        self.input_image_size = args.input_image_size
        self.vision_tower.config = vision_tower_config
        self.freeze_vision = args.freeze_vision

        if not self.is_loaded:
            self.load_model()
        # if not delay_load:
        #     self.load_model()
        # else:
        #     self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

    def load_model(self):
        if self.is_loaded:
            return
        # self.args.vision_tower_input_size = 224  # hardcode
        self.image_processor = CLIPImageProcessor(
            crop_size={"height": self.args.input_image_size, "width": self.args.input_image_size},
            size={'shortest_edge': self.args.input_image_size},
            image_mean=[0.48145466, 0.4578275, 0.40821073],
            image_std=[0.26862954, 0.26130258, 0.27577711])

        # load weights
        if self.args.vision_tower_pretrained_from is None:
            self.args.vision_tower_pretrained_from = "/lustre/fsw/portfolios/llmservice/users/fuxiaol/eva02_L_coco_det_sys_o365.pth"
        # pretrained_params = torch.load(self.args.vision_tower_pretrained_from)
        # if 'ema_state' in pretrained_params:
        #     pretrained_params = pretrained_params['ema_state']
        # elif 'module' in pretrained_params:
        #     pretrained_params = pretrained_params['module']
        # from collections import OrderedDict
        # new_params = OrderedDict()
        # kw = ""
        # if "det" in self.args.vision_tower_pretrained_from.lower():
        #     kw = "backbone.net."
        # elif "clip" in self.args.vision_tower_pretrained_from.lower():
        #     kw = "visual."
        # for k, v in pretrained_params.items():
        #     if len(kw) > 0:
        #         if kw in k and ("rope" not in k):
        #             new_params[k.replace(kw, "")] = v
        #     else:
        #         if "rope" not in k:
        #             new_params[k] = v
        # incompatiblekeys = self.vision_tower.load_state_dict(new_params, strict=False)
        # for k in incompatiblekeys[0]:
        #     if "rope" not in k:
        #         warnings.warn(f"Found incompatible key {k} in state dict.")
        # print(f"EVA-02 ckpt loaded from {self.args.vision_tower_pretrained_from}")

        if self.freeze_vision:
            self.vision_tower.requires_grad_(False)

        self.is_loaded = True

    # @torch.no_grad()
    def forward(self, images):
        if type(images) is list:
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = image_forward_out.flatten(2, 3).transpose(1, 2)  # b, n, c
                image_features.append(image_feature)
            return image_features
        else:
            image_forward_out = self.vision_tower(images.to(device=self.device, dtype=self.dtype))
            return image_forward_out

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        return next(self.vision_tower.parameters()).dtype

    @property
    def device(self):
        return next(self.vision_tower.parameters()).device

    @property
    def config(self):
        # if self.is_loaded:
        #     return self.vision_tower.config
        # else:
        #     return self.cfg_only
        # TODO
        return self.vision_tower.config

    @property
    def hidden_size(self):
        # return self.config.hidden_size
        return self.config['hidden_dim']

    @property
    def num_patches(self):
        # return (self.config.image_size // self.config.patch_size) ** 2
        return self.config['num_patches']

def build_eva_vit(args,
                  model_name=None,
                  image_size=224,
                  window_attn=True):
    if "336" in args.vision_tower_pretrained_from:
        pretrained_image_size = 336
    else:
        pretrained_image_size = 224

    if "clip" in args.vision_tower_pretrained_from.lower():
        subln = True
    else:
        subln = False

    if model_name == 'eva02-l-16':
        # Shilong suggested using this checkpoint:
        # https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
        if window_attn:
            # Every third block (indexes 2, 5, ..., 23) keeps global attention.
            window_block_indexes = (
                list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8))
                + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17))
                + list(range(18, 20)) + list(range(21, 23))
            )
        else:
            window_block_indexes = ()
        model = EVAViT(
            img_size=image_size,
            patch_size=16,
            window_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4 * 2 / 3,
            window_block_indexes=window_block_indexes,
            qkv_bias=True,
            drop_path_rate=0.0,
            xattn=False,
            # with_cp=False,
            # frozen=True,
        )
        # image_size = 224  # HARDCODE
        eva_config = dict(image_size=image_size,
                          patch_size=16,
                          window_size=16,
                          hidden_dim=1024,
                          depth=24,
                          num_heads=16,
                          window_block_indexes=window_block_indexes,
                          num_patches=image_size ** 2 // 16 ** 2,
                          pretrained_from=args.vision_tower_pretrained_from)
    elif model_name == 'eva02-l-14':
        # Shilong suggested using this checkpoint:
        # https://huggingface.co/Yuxin-CV/EVA-02/blob/main/eva02/det/eva02_L_coco_det_sys_o365.pth
        if window_attn:
            # Every third block (indexes 2, 5, ..., 23) keeps global attention.
            window_block_indexes = (
                list(range(0, 2)) + list(range(3, 5)) + list(range(6, 8))
                + list(range(9, 11)) + list(range(12, 14)) + list(range(15, 17))
                + list(range(18, 20)) + list(range(21, 23))
            )
        else:
            window_block_indexes = ()
        model = EVAViT(
            img_size=image_size,
            pretrain_img_size=pretrained_image_size,
            patch_size=14,
            window_size=16,
            in_chans=3,
            embed_dim=1024,
            depth=24,
            num_heads=16,
            mlp_ratio=4 * 2 / 3,
            window_block_indexes=window_block_indexes,
            qkv_bias=True,
            drop_path_rate=0.0,
            xattn=False,
            # with_cp=False,
            subln=subln,
            # frozen=True,
        )
        # image_size = 224  # HARDCODE
        eva_config = dict(image_size=image_size,
                          patch_size=14,
                          window_size=16,
                          hidden_dim=1024,
                          depth=24,
                          num_heads=16,
                          window_block_indexes=window_block_indexes,
                          num_patches=image_size ** 2 // 14 ** 2,
                          pretrained_from=args.vision_tower_pretrained_from)
    else:
        raise NotImplementedError
    return model, eva_config
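

# Hedged usage sketch: build_eva_vit reads a handful of attributes from an
# `args` namespace supplied by the LLaVA training script. The attribute names
# below mirror how this file accesses them and are otherwise assumptions, not
# a documented interface. Note that calling this allocates the full ~300M
# parameter EVA-02-L backbone.
def _demo_build_eva_vit():
    from types import SimpleNamespace
    args = SimpleNamespace(
        vision_tower_pretrained_from="eva02_L_coco_det_sys_o365.pth",  # hypothetical local path
        input_image_size=336,
        freeze_vision=True,
        mm_vision_select_layer=-2,
        mm_vision_select_feature='patch',
    )
    model, cfg = build_eva_vit(args, model_name='eva02-l-14', image_size=args.input_image_size)
    assert cfg['num_patches'] == (336 // 14) ** 2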