|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import os |
|
from typing import Callable, List, Any, Tuple, Dict |
|
import warnings |
|
|
|
import torch |
|
from torch import nn, Tensor |
|
|
|
from .attention import Attention, MemEffAttention |
|
from .drop_path import DropPath |
|
from .layer_scale import LayerScale |
|
from .mlp import Mlp |
|
|
|
|
|
logger = logging.getLogger("dinov2")

# xFormers can be switched off by defining the XFORMERS_DISABLED environment variable
# (any value, even empty, counts as "defined").
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None

try:
    if not XFORMERS_ENABLED:
        warnings.warn("xFormers is disabled (Block)")
        # Jump to the shared "not available" handling below.
        raise ImportError
    # Only bound when the import succeeds; the nested-tensor code paths rely on them.
    from xformers.ops import fmha, scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
    warnings.warn("xFormers is available (Block)")
except ImportError:
    XFORMERS_AVAILABLE = False
    warnings.warn("xFormers is not available (Block)")
|
|
|
|
|
class Block(nn.Module):
    """Transformer block: ``x + attn(norm1(x))`` followed by ``x + mlp(norm2(x))``.

    Each residual branch is multiplied by a ``LayerScale`` when ``init_values`` is truthy
    (otherwise by an identity) and, during training, regularized by dropping the branch
    for some samples at rate ``drop_path``.

    Args:
        dim: Feature size passed to the norm, attention and feed-forward layers.
        num_heads: Number of attention heads, forwarded to ``attn_class``.
        mlp_ratio: Feed-forward hidden width as a multiple of ``dim``.
        qkv_bias: Forwarded to ``attn_class``.
        proj_bias: Forwarded to ``attn_class``.
        ffn_bias: Forwarded to ``ffn_layer`` as ``bias``.
        drop: Forwarded to ``attn_class`` as ``proj_drop`` and to ``ffn_layer`` as ``drop``.
        attn_drop: Forwarded to ``attn_class``.
        init_values: Initial value for ``LayerScale``; falsy values (``None``, 0) disable it.
        drop_path: Drop rate shared by both residual branches during training.
        act_layer: Activation factory forwarded to ``ffn_layer``.
        norm_layer: Normalization factory, called as ``norm_layer(dim)``.
        attn_class: Attention module factory.
        ffn_layer: Feed-forward module factory.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()

        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Feed-forward branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Read by forward() to choose between the training-time drop strategies.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        """Apply the attention and feed-forward residual branches to ``x``.

        - Training with ``drop_path > 0.1``: each branch runs only on a random subset of the
          batch and is rescaled (see ``drop_add_residual_stochastic_depth``).
        - Training with ``0 < drop_path <= 0.1``: each branch goes through its ``DropPath``.
        - Otherwise: plain residual additions.
        """

        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # Higher drop rates: dropped samples skip the branch's computation entirely.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # The FFN branch uses its own DropPath module (drop_path1 was previously reused here).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x
|
|
|
|
|
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Add ``residual_func`` output to a random subset of the batch, rescaled.

    ``max(int(b * (1 - sample_drop_ratio)), 1)`` samples are drawn without replacement.
    Only those samples are passed through ``residual_func``. Their residuals are multiplied
    by ``b / subset_size`` and added back, so the expected update equals applying the
    residual to every sample. All other samples are returned unchanged.

    Args:
        x: Input batch with the batch on axis 0; any rank >= 2.
        residual_func: Maps a sub-batch to a residual with the same per-sample shape.
        sample_drop_ratio: Fraction of the batch that skips the residual.

    Returns:
        A new tensor with the shape and dtype of ``x``.
    """
    batch_size = x.shape[0]
    sample_subset_size = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch_size, device=x.device)[:sample_subset_size]
    residual_scale_factor = batch_size / sample_subset_size

    residual = residual_func(x[brange])

    # Flatten the non-batch axes so each sample's residual is scattered back with one
    # index_add along dim 0. Cast in case the residual came back in a different dtype
    # (e.g. under autocast).
    x_plus_residual = torch.index_add(
        x.flatten(1),
        0,
        brange,
        residual.flatten(1).to(dtype=x.dtype),
        alpha=residual_scale_factor,
    )
    return x_plus_residual.view_as(x)
|
|
|
|
|
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick a random subset of batch indices and the residual rescale factor.

    Args:
        x: Tensor whose axis 0 is the batch axis (other axes are ignored).
        sample_drop_ratio: Fraction of the batch to drop; at least one index is always kept.

    Returns:
        ``(brange, residual_scale_factor)``: a 1-D tensor of distinct indices on ``x.device``,
        and ``batch_size / len(brange)``.
    """
    batch_size = x.shape[0]
    sample_subset_size = max(int(batch_size * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(batch_size, device=x.device)[:sample_subset_size]
    residual_scale_factor = batch_size / sample_subset_size
    return brange, residual_scale_factor
|
|
|
|
|
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add a scaled residual into the rows of ``x`` selected by ``brange``.

    Args:
        x: Base tensor; its rows along dim 0 are indexed by ``brange``.
        brange: Index tensor of the rows of ``x`` that receive a residual.
        residual: One residual per entry of ``brange``; cast to ``x.dtype`` before adding.
        residual_scale_factor: Multiplier (``alpha``) applied to the residual.
        scaling_vector: Optional extra scale (the LayerScale ``gamma`` at this file's call
            sites). When given, xFormers' ``scaled_index_add`` is used instead of
            ``torch.index_add``.

    Returns:
        Without ``scaling_vector``: a new tensor with the non-batch axes of ``x`` flattened
        (callers reshape it with ``view_as``). With it: the result of ``scaled_index_add``.
    """
    cast_residual = residual.to(dtype=x.dtype)
    if scaling_vector is not None:
        return scaled_index_add(
            x, brange, cast_residual, scaling=scaling_vector, alpha=residual_scale_factor
        )
    return torch.index_add(
        x.flatten(1), 0, brange, cast_residual.flatten(1), alpha=residual_scale_factor
    )
|
|
|
|
|
# Module-level memo of attention biases, keyed by a tuple of (batch size, x.shape[1]) pairs
# built in get_attn_bias_and_cat. No eviction is performed here, so it grows with the number
# of distinct input layouts seen.
attn_bias_cache: Dict[Tuple, Any] = dict()
|
|
|
|
|
def get_attn_bias_and_cat(x_list, branges=None):
    """Concatenate a list of tensors into one sequence and return a matching attention bias.

    The bias is an xFormers ``BlockDiagonalMask`` with one block per sample (each of length
    ``x.shape[1]``). It is built once per layout and memoized in ``attn_bias_cache``, and the
    per-input batch sizes are stored on it as ``_batch_sizes``.

    Args:
        x_list: Tensors to nest; presumably each is (batch, tokens, dim) with a shared ``dim``.
            TODO confirm with callers.
        branges: Optional per-input index tensors. When given, only those samples are
            gathered (via ``index_select_cat``) and counted in the bias.

    Returns:
        ``(attn_bias, cat_tensors)``, where ``cat_tensors`` has a leading axis of size 1 and
        all selected samples laid end to end along axis 1.
    """
    if branges is None:
        batch_sizes = [x.shape[0] for x in x_list]
    else:
        batch_sizes = [brange.shape[0] for brange in branges]
    cache_key = tuple((bs, x.shape[1]) for bs, x in zip(batch_sizes, x_list))

    if cache_key not in attn_bias_cache:
        # One block per sample: every sample of input i contributes x_list[i].shape[1] tokens.
        seqlens = [x.shape[1] for bs, x in zip(batch_sizes, x_list) for _ in range(bs)]
        bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        bias._batch_sizes = batch_sizes
        attn_bias_cache[cache_key] = bias

    if branges is None:
        cat_tensors = torch.cat([x.reshape([1, -1, *x.shape[2:]]) for x in x_list], dim=1)
    else:
        flat_inputs = [x.flatten(1) for x in x_list]
        cat_tensors = index_select_cat(flat_inputs, branges).view(1, -1, x_list[0].shape[-1])

    return attn_bias_cache[cache_key], cat_tensors
|
|
|
|
|
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """List version of ``drop_add_residual_stochastic_depth`` that uses one nested call.

    For each input, a random subset of its batch is chosen. The subsets are concatenated
    into a single sequence with a block-diagonal attention bias, and ``residual_func`` runs
    once on that sequence. The output is split back per input, and each residual is added
    to its own rows with the matching rescale factor.

    Args:
        x_list: Input tensors (batch on axis 0).
        residual_func: Called as ``residual_func(x_cat, attn_bias=attn_bias)``.
        sample_drop_ratio: Fraction of each batch that skips the residual.
        scaling_vector: Optional extra scale forwarded to ``add_residual``.

    Returns:
        A list with one updated tensor per input, each reshaped to that input's shape.
    """
    # Pick the kept subset of each batch and its rescale factor (consumes RNG in list order).
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [brange for brange, _ in branges_scales]
    residual_scale_factors = [scale for _, scale in branges_scales]

    # Gather the kept samples into one sequence along with the matching attention bias.
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # Run the branch once, then split its output back per input via the bias object.
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))

    return [
        add_residual(x, brange, residual, scale, scaling_vector).view_as(x)
        for x, brange, residual, scale in zip(x_list, branges, residual_list, residual_scale_factors)
    ]
|
|
|
|
|
class NestedTensorBlock(Block):
    """``Block`` that also accepts a list of tensors and processes them as one nested batch.

    A ``Tensor`` input is handled by ``Block.forward``. A ``list`` input is concatenated
    into a single sequence with a block-diagonal xFormers attention bias. That path requires
    xFormers and a ``MemEffAttention`` attention module.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """Run the block on a list of tensors nested into a single attention call.

        Args:
            x_list: Tensors to process together. Presumably each is (batch, tokens, dim) with
                a shared ``dim``; TODO confirm with callers.

        Returns:
            A list of output tensors, one per input. The eval path relies on the
            xFormers bias's ``split``.

        Raises:
            AssertionError: if ``self.attn`` is not a ``MemEffAttention``.
        """
        # This path calls ``self.attn(..., attn_bias=...)``, which is only supported here for
        # MemEffAttention. Raise explicitly so the check survives ``python -O``.
        if not isinstance(self.attn, MemEffAttention):
            raise AssertionError("forward_nested requires the attention module to be MemEffAttention")

        if self.training and self.sample_drop_ratio > 0.0:
            # LayerScale is not applied inside these branches. Its gamma is passed as
            # ``scaling_vector`` and applied by add_residual (via scaled_index_add) instead.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Guard on ls2 itself before reading ls2.gamma (previously ls1 was checked).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list

        def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

        def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        attn_bias, x = get_attn_bias_and_cat(x_list)
        x = x + attn_residual_func(x, attn_bias=attn_bias)
        x = x + ffn_residual_func(x)
        return attn_bias.split(x)

    def forward(self, x_or_x_list):
        """Dispatch on input type: a ``Tensor`` goes to ``Block.forward``, a ``list`` to ``forward_nested``.

        Raises:
            AssertionError: for a list input when xFormers is unavailable, or for any other
                input type.
        """
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        if isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        raise AssertionError(
            f"NestedTensorBlock expects a Tensor or a list of Tensors, got {type(x_or_x_list).__name__}"
        )
|
|