cs-mixer / timm /layers /pos_embed_sincos.py
crrrr30's picture
Upload folder using huggingface_hub
da716ed
""" Sin-cos, fourier, rotary position embedding modules and functions
Hacked together by / Copyright 2022 Ross Wightman
"""
import math
from typing import List, Tuple, Optional, Union
import torch
from torch import nn as nn
from .trace_utils import _assert
def pixel_freq_bands(
num_bands: int,
max_freq: float = 224.,
linear_bands: bool = True,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
if linear_bands:
bands = torch.linspace(1.0, max_freq / 2, num_bands, dtype=dtype, device=device)
else:
bands = 2 ** torch.linspace(0, math.log(max_freq, 2) - 1, num_bands, dtype=dtype, device=device)
return bands * torch.pi
def freq_bands(
num_bands: int,
temperature: float = 10000.,
step: int = 2,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> torch.Tensor:
bands = 1. / (temperature ** (torch.arange(0, num_bands, step, dtype=dtype, device=device) / num_bands))
return bands
def build_sincos2d_pos_embed(
feat_shape: List[int],
dim: int = 64,
temperature: float = 10000.,
reverse_coord: bool = False,
interleave_sin_cos: bool = False,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None
) -> torch.Tensor:
"""
Args:
feat_shape:
dim:
temperature:
reverse_coord: stack grid order W, H instead of H, W
interleave_sin_cos: sin, cos, sin, cos stack instead of sin, sin, cos, cos
dtype:
device:
Returns:
"""
assert dim % 4 == 0, 'Embed dimension must be divisible by 4 for sin-cos 2D position embedding'
pos_dim = dim // 4
bands = freq_bands(pos_dim, temperature=temperature, step=1, dtype=dtype, device=device)
if reverse_coord:
feat_shape = feat_shape[::-1] # stack W, H instead of H, W
grid = torch.stack(torch.meshgrid(
[torch.arange(s, device=device, dtype=dtype) for s in feat_shape])).flatten(1).transpose(0, 1)
pos2 = grid.unsqueeze(-1) * bands.unsqueeze(0)
# FIXME add support for unflattened spatial dim?
stack_dim = 2 if interleave_sin_cos else 1 # stack sin, cos, sin, cos instead of sin sin cos cos
pos_emb = torch.stack([torch.sin(pos2), torch.cos(pos2)], dim=stack_dim).flatten(1)
return pos_emb
def build_fourier_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
num_bands: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
include_grid: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
) -> List[torch.Tensor]:
"""
Args:
feat_shape: Feature shape for embedding.
bands: Pre-calculated frequency bands.
num_bands: Number of frequency bands (determines output dim).
max_res: Maximum resolution for pixel based freq.
temperature: Temperature for non-pixel freq.
linear_bands: Linear band spacing for pixel based freq.
include_grid: Include the spatial grid in output.
in_pixels: Output in pixel freq.
ref_feat_shape: Reference feature shape for resize / fine-tune.
dtype: Output dtype.
device: Output device.
Returns:
"""
if bands is None:
if in_pixels:
bands = pixel_freq_bands(
num_bands,
float(max_res),
linear_bands=linear_bands,
dtype=dtype,
device=device,
)
else:
bands = freq_bands(
num_bands,
temperature=temperature,
step=1,
dtype=dtype,
device=device,
)
else:
if device is None:
device = bands.device
if dtype is None:
dtype = bands.dtype
if in_pixels:
t = [torch.linspace(-1., 1., steps=s, device=device, dtype=dtype) for s in feat_shape]
else:
t = [torch.arange(s, device=device, dtype=dtype) for s in feat_shape]
if ref_feat_shape is not None:
# eva's scheme for resizing rope embeddings (ref shape = pretrain)
t = [x / f * r for x, f, r in zip(t, feat_shape, ref_feat_shape)]
grid = torch.stack(torch.meshgrid(t), dim=-1)
grid = grid.unsqueeze(-1)
pos = grid * bands
pos_sin, pos_cos = pos.sin(), pos.cos()
out = [grid, pos_sin, pos_cos] if include_grid else [pos_sin, pos_cos]
return out
class FourierEmbed(nn.Module):
def __init__(
self,
max_res: int = 224,
num_bands: int = 64,
concat_grid=True,
keep_spatial=False,
):
super().__init__()
self.max_res = max_res
self.num_bands = num_bands
self.concat_grid = concat_grid
self.keep_spatial = keep_spatial
self.register_buffer(
'bands',
pixel_freq_bands(max_res, num_bands),
persistent=False,
)
def forward(self, x):
B, C = x.shape[:2]
feat_shape = x.shape[2:]
emb = build_fourier_pos_embed(
feat_shape,
self.bands,
include_grid=self.concat_grid,
dtype=x.dtype,
device=x.device,
)
emb = torch.cat(emb, dim=-1)
emb = emb.transpose(-1, -2).flatten(len(feat_shape))
batch_expand = (B,) + (-1,) * (x.ndim - 1)
# FIXME support nD
if self.keep_spatial:
x = torch.cat([x, emb.unsqueeze(0).expand(batch_expand).permute(0, 3, 1, 2)], dim=1)
else:
x = torch.cat([x.permute(0, 2, 3, 1), emb.unsqueeze(0).expand(batch_expand)], dim=-1)
x = x.reshape(B, feat_shape.numel(), -1)
return x
def rot(x):
return torch.stack([-x[..., 1::2], x[..., ::2]], -1).reshape(x.shape)
def apply_rot_embed(x: torch.Tensor, sin_emb, cos_emb):
if sin_emb.ndim == 3:
return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
return x * cos_emb + rot(x) * sin_emb
def apply_rot_embed_list(x: List[torch.Tensor], sin_emb, cos_emb):
if isinstance(x, torch.Tensor):
x = [x]
return [t * cos_emb + rot(t) * sin_emb for t in x]
def apply_rot_embed_cat(x: torch.Tensor, emb):
sin_emb, cos_emb = emb.tensor_split(2, -1)
if sin_emb.ndim == 3:
return x * cos_emb.unsqueeze(1).expand_as(x) + rot(x) * sin_emb.unsqueeze(1).expand_as(x)
return x * cos_emb + rot(x) * sin_emb
def apply_keep_indices_nlc(x, pos_embed, keep_indices):
pos_embed = pos_embed.unsqueeze(0).expand(x.shape[0], -1, -1)
pos_embed = pos_embed.gather(1, keep_indices.unsqueeze(-1).expand(-1, -1, pos_embed.shape[-1]))
return pos_embed
def build_rotary_pos_embed(
feat_shape: List[int],
bands: Optional[torch.Tensor] = None,
dim: int = 64,
max_res: int = 224,
temperature: float = 10000.,
linear_bands: bool = False,
in_pixels: bool = True,
ref_feat_shape: Optional[List[int]] = None,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
):
"""
Args:
feat_shape: Spatial shape of the target tensor for embedding.
bands: Optional pre-generated frequency bands
dim: Output dimension of embedding tensor.
max_res: Maximum resolution for pixel mode.
temperature: Temperature (inv freq) for non-pixel mode
linear_bands: Linearly (instead of log) spaced bands for pixel mode
in_pixels: Pixel vs language (inv freq) mode.
dtype: Output dtype.
device: Output device.
Returns:
"""
sin_emb, cos_emb = build_fourier_pos_embed(
feat_shape,
bands=bands,
num_bands=dim // 4,
max_res=max_res,
temperature=temperature,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=ref_feat_shape,
device=device,
dtype=dtype,
)
num_spatial_dim = 1
# this would be much nicer as a .numel() call to torch.Size(), but torchscript sucks
for x in feat_shape:
num_spatial_dim *= x
sin_emb = sin_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
cos_emb = cos_emb.reshape(num_spatial_dim, -1).repeat_interleave(2, -1)
return sin_emb, cos_emb
class RotaryEmbedding(nn.Module):
""" Rotary position embedding
NOTE: This is my initial attempt at impl rotary embedding for spatial use, it has not
been well tested, and will likely change. It will be moved to its own file.
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.pos_embed_sin = None
self.pos_embed_cos = None
else:
# cache full sin/cos embeddings if shape provided up front
emb_sin, emb_cos = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed_sin',
emb_sin,
persistent=False,
)
self.register_buffer(
'pos_embed_cos',
emb_cos,
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
assert shape is not None
return build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
else:
return self.pos_embed_sin, self.pos_embed_cos
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
sin_emb, cos_emb = self.get_embed(x.shape[2:])
return apply_rot_embed(x, sin_emb, cos_emb)
class RotaryEmbeddingCat(nn.Module):
""" Rotary position embedding w/ concatenatd sin & cos
The following impl/resources were referenced for this impl:
* https://github.com/lucidrains/vit-pytorch/blob/6f3a5fcf0bca1c5ec33a35ef48d97213709df4ba/vit_pytorch/rvt.py
* https://blog.eleuther.ai/rotary-embeddings/
"""
def __init__(
self,
dim,
max_res=224,
temperature=10000,
in_pixels=True,
linear_bands: bool = False,
feat_shape: Optional[List[int]] = None,
ref_feat_shape: Optional[List[int]] = None,
):
super().__init__()
self.dim = dim
self.max_res = max_res
self.temperature = temperature
self.in_pixels = in_pixels
self.feat_shape = feat_shape
self.ref_feat_shape = ref_feat_shape
if feat_shape is None:
# only cache bands
if in_pixels:
bands = pixel_freq_bands(
dim // 4,
float(max_res),
linear_bands=linear_bands,
)
else:
bands = freq_bands(
dim // 4,
temperature=temperature,
step=1,
)
print(bands)
self.register_buffer(
'bands',
bands,
persistent=False,
)
self.embed = None
else:
# cache full sin/cos embeddings if shape provided up front
embeds = build_rotary_pos_embed(
feat_shape=feat_shape,
dim=dim,
max_res=max_res,
linear_bands=linear_bands,
in_pixels=in_pixels,
ref_feat_shape=self.ref_feat_shape,
)
self.bands = None
self.register_buffer(
'pos_embed',
torch.cat(embeds, -1),
persistent=False,
)
def get_embed(self, shape: Optional[List[int]] = None):
if self.bands is not None:
# rebuild embeddings every call, use if target shape changes
_assert(shape is not None, 'valid shape needed')
embeds = build_rotary_pos_embed(
shape,
self.bands,
in_pixels=self.in_pixels,
)
return torch.cat(embeds, -1)
else:
return self.pos_embed
def forward(self, x):
# assuming channel-first tensor where spatial dim are >= 2
pos_embed = self.get_embed(x.shape[2:])
return apply_rot_embed_cat(x, pos_embed)