|
import typing as tp |
|
from einops import rearrange |
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
from torch.utils.checkpoint import checkpoint as torch_checkpoint |
|
from xformers import ops |
|
|
|
|
|
# Name of the attention implementation to use. Code in this module only checks
# for the value "torch", which selects the PyTorch scaled-dot-product path.
_efficient_attention_backend: str = "torch"
|
|
|
|
|
|
|
|
|
|
|
def _get_attention_time_dimension(memory_efficient: bool) -> int:
    """Return the index of the time axis in attention tensors.

    With the "torch" backend and memory-efficient attention enabled, tensors
    are laid out as ``[batch, heads, time, dim]`` (time on axis 2); every other
    combination uses ``[batch, time, heads, dim]`` (time on axis 1).
    """
    head_major_layout = _efficient_attention_backend == 'torch' and memory_efficient
    return 2 if head_major_layout else 1
|
|
|
|
|
|
|
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    The first ``dim // 2`` output channels hold cosines and the last
    ``dim // 2`` hold sines of the same phases. Channel ``i`` of each half uses
    ``phase = positions / max_period ** (i / (dim // 2 - 1))``, so the divisor
    goes from 1 (``i = 0``) up to ``max_period`` (last channel).

    Args:
        positions (torch.Tensor): Tensor of positions (any numeric dtype; it is
            cast to ``dtype``). It is broadcast against a ``[1, 1, dim // 2]``
            tensor, e.g. a ``[B, T, 1]`` tensor yields a ``[B, T, dim]`` output.
        dim (int): Dimension of the embedding; must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    assert dim % 2 == 0, f"dim must be even, got {dim}"
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    # Keep max_period as a 0-d tensor on the same device as `positions` so the
    # power below is computed there.
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)
    # With dim == 2 there is a single frequency and `half_dim - 1` is 0, which
    # would make the exponent 0 / 0 = NaN for the whole embedding. Clamping the
    # denominator gives that single channel a divisor of 1 (phase == positions).
    denom = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
|
|
|
|
def expand_repeated_kv(x: torch.Tensor, n_rep: int, memory_efficient: bool) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    Adapted from xlformers; equivalent to ``torch.repeat_interleave`` over the
    head axis, i.e. each head appears ``n_rep`` times consecutively. The head
    axis depends on the layout:

    * "torch" backend with ``memory_efficient``: ``[batch, heads, time, dim]``
      (heads on axis 1);
    * otherwise: ``[batch, time, heads, dim]`` (heads on axis 2).

    When ``n_rep == 1`` the input tensor itself is returned.
    """
    if n_rep == 1:
        return x

    head_major = _efficient_attention_backend == 'torch' and memory_efficient
    if head_major:
        batch, heads, steps, width = x.shape
        tiled = x.unsqueeze(2).expand(batch, heads, n_rep, steps, width)
        return tiled.reshape(batch, heads * n_rep, steps, width)

    batch, steps, heads, width = x.shape
    tiled = x.unsqueeze(3).expand(batch, steps, heads, n_rep, width)
    return tiled.reshape(batch, steps, heads * n_rep, width)
|
|
|
|
|
|
|
|
|
|
|
class StreamingMultiheadAttention(nn.Module):
    """Multi-head attention with an append-only key/value history for self-attention.

    Inside ``forward`` all per-head tensors use the ``"b h t d"`` layout
    (batch, heads, time, head dim).

    * Self-attention: every call appends the new keys/values to
      ``k_history``/``v_history`` along the time axis (axis 2), and the queries
      attend over the whole history. Set both attributes to ``None`` to start a
      new sequence.
    * Cross-attention: keys/values are projected from ``key``/``value`` on each
      call and no history is kept.

    Args:
        embed_dim (int): Model dimension (input and output channels).
        num_heads (int): Number of query heads.
        dropout (float): Attention-weight dropout, used only in training mode.
        bias (bool): Whether the in/out projections have a (zero-initialised) bias.
        causal (bool): Only validated (required when ``past_context`` is set).
            NOTE(review): no causal mask is applied in ``forward``.
        past_context (int, optional): Only validated; not used otherwise.
        custom (bool): Ignored; the packed-projection implementation is always
            used because ``forward`` only supports that path.
        memory_efficient (bool): Use ``F.scaled_dot_product_attention`` when the
            module-level backend is ``'torch'``; otherwise an explicit
            softmax(q k^T / sqrt(d)) v is computed.
        attention_as_float32 (bool): Compute attention in float32 and cast the
            result back to the query dtype.
        cross_attention (bool): Project query and key/value from separate inputs.
        kv_repeat (int): Number of query heads sharing one key/value head.
            Must divide ``num_heads``; must be 1 for cross-attention.
        device, dtype: Forwarded to the parameter constructors.
    """

    def __init__(self,
                 embed_dim,
                 num_heads, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim

        # Self-attention key/value cache in "b h t d" layout; None means empty.
        self.k_history = None
        self.v_history = None

        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.cross_attention = cross_attention
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat

        # `custom` is deliberately overridden: forward() only implements the
        # packed in-projection path, not the nn.MultiheadAttention one below.
        self.custom = True
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            # One packed projection producing [q | k | v] channels.
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        # When wrapping nn.MultiheadAttention (custom is False), remap the plain
        # parameter names in the checkpoint onto the `mha.` submodule.
        if not self.custom:
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(self,
                query,
                key=None,
                value=None):
        """Attend from ``query`` to ``key``/``value`` (cross) or to the cached history (self).

        Args:
            query (torch.Tensor): Input of shape ``[B, T, embed_dim]``.
            key, value (torch.Tensor, optional): Cross-attention sources of shape
                ``[B, S, embed_dim]``; unused for self-attention.
        Returns:
            torch.Tensor: Output of shape ``[B, T, embed_dim]``.
        """
        layout = "b h t d"
        if self.cross_attention:
            q, k, v = self._project_cross(query, key, value, layout)
        else:
            q, k, v = self._project_self(query, layout)
            k, v = self._extend_history(k, v)
            if self.kv_repeat > 1:
                # Each key/value head serves `kv_repeat` consecutive query heads.
                # Heads are on axis 1 in the fixed "b h t d" layout used here.
                k = torch.repeat_interleave(k, self.kv_repeat, dim=1)
                v = torch.repeat_interleave(v, self.kv_repeat, dim=1)
        x = self._attend(q, k, v)
        x = rearrange(x, f"{layout} -> b t (h d)", h=self.num_heads)
        return self.out_proj(x)

    def _project_cross(self, query, key, value, layout):
        """Project query, key and value with their own slices of the packed in-projection."""
        # kv_repeat == 1 is asserted for cross-attention, so the packed weight
        # splits into three equal [W_q; W_k; W_v] slices.
        dim = self.in_proj_weight.shape[0] // 3
        if self.in_proj_bias is None:
            bias_q, bias_k, bias_v = None, None, None
        else:
            bias_q = self.in_proj_bias[:dim]
            bias_k = self.in_proj_bias[dim: 2 * dim]
            bias_v = self.in_proj_bias[2 * dim:]
        q = F.linear(query, self.in_proj_weight[:dim], bias_q)
        k = F.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
        v = F.linear(value, self.in_proj_weight[2 * dim:], bias_v)
        return [rearrange(t, f"b t (h d) -> {layout}", h=self.num_heads) for t in (q, k, v)]

    def _project_self(self, query, layout):
        """Project ``query`` into per-head q, k, v using the packed in-projection."""
        projected = F.linear(query, self.in_proj_weight, self.in_proj_bias)
        if self.kv_repeat == 1:
            # Channels are packed as (p h d) with p = 3 for q, k, v.
            bound_layout = "b h p t d"
            packed = rearrange(projected, f"b t (p h d) -> {bound_layout}", p=3, h=self.num_heads)
            q, k, v = ops.unbind(packed, dim=2)
            return q, k, v
        # Grouped-query layout [q (embed_dim) | k (kv_dim) | v (kv_dim)], matching
        # out_dim = embed_dim + 2 * kv_dim in __init__.
        kv_heads = self.num_heads // self.kv_repeat
        kv_dim = (self.embed_dim // self.num_heads) * kv_heads
        q = projected[:, :, :self.embed_dim]
        k = projected[:, :, self.embed_dim: self.embed_dim + kv_dim]
        v = projected[:, :, self.embed_dim + kv_dim:]
        q = rearrange(q, f"b t (h d) -> {layout}", h=self.num_heads)
        k = rearrange(k, f"b t (h d) -> {layout}", h=kv_heads)
        v = rearrange(v, f"b t (h d) -> {layout}", h=kv_heads)
        return q, k, v

    def _extend_history(self, k, v):
        """Append ``k``/``v`` to the cache along time (axis 2) and return the full cache."""
        if self.k_history is None:
            self.k_history = k
            self.v_history = v
        else:
            self.k_history = torch.cat([self.k_history, k], 2)
            self.v_history = torch.cat([self.v_history, v], 2)
        return self.k_history, self.v_history

    def _attend(self, q, k, v):
        """Non-causal scaled dot-product attention over ``"b h t d"`` tensors."""
        dtype = q.dtype
        if self.attention_as_float32:
            q, k, v = [t.float() for t in (q, k, v)]
        p = self.dropout if self.training else 0
        if self.memory_efficient and _efficient_attention_backend == 'torch':
            x = F.scaled_dot_product_attention(q, k, v, is_causal=False, dropout_p=p)
        else:
            # Explicit softmax(q k^T / sqrt(d)) v, same math as the fused kernel
            # above with its default scale and no mask.
            q = q / q.shape[-1] ** 0.5
            weights = torch.softmax(torch.einsum("bhtd,bhsd->bhts", q, k), dim=-1)
            weights = F.dropout(weights, self.dropout, training=self.training).to(v)
            x = torch.einsum("bhts,bhsd->bhtd", weights, v)
        return x.to(dtype)
|
|
|
|
|
class StreamingTransformerLayer(nn.Module):
    """Pre-norm transformer layer: self-attention, optional cross-attention, GELU MLP.

    Each sub-block is added residually to its input after a LayerNorm
    (``norm1`` for self-attention, ``norm_cross`` for cross-attention,
    ``norm2`` for the feed-forward block).

    Args:
        d_model (int): Model dimension.
        num_heads (int): Number of attention heads.
        dim_feedforward (int): Hidden size of the feed-forward block.
        dropout (float): Attention dropout when ``attention_dropout`` is None;
            also the rate of ``dropout_cross``.
        bias_ff (bool): Bias in the feed-forward linears.
        bias_attn (bool): Bias in the attention projections.
        custom, memory_efficient, attention_as_float32: Forwarded to
            ``StreamingMultiheadAttention``.
        cross_attention (bool): Build a cross-attention block.
        attention_dropout (float, optional): Overrides ``dropout`` for attention.
        kv_repeat (int): Forwarded to the self-attention block only.
        norm (str): Accepted but unused; LayerNorm is always used.
        device, dtype: Passed to every parameterised submodule.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 d_model: int,
                 num_heads: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 bias_ff: bool = True,
                 bias_attn: bool = True,
                 custom: bool = False,
                 memory_efficient: bool = False,
                 attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1,
                 norm: str = 'layer_norm',
                 device=None,
                 dtype=None,
                 **kwargs):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}

        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn = StreamingMultiheadAttention(
            kv_repeat=kv_repeat,
            **attn_kwargs,
            **factory_kwargs)

        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.cross_attention = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True,
                **attn_kwargs,
                **factory_kwargs)
            # NOTE(review): constructed but not applied in forward().
            self.dropout_cross = nn.Dropout(dropout)
            self.norm_cross = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)

        # Like every other submodule, the norms follow the requested device/dtype.
        self.norm1 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)
        self.norm2 = nn.LayerNorm(d_model, eps=1e-5, **factory_kwargs)

    def forward(self,
                src,
                cross_attention_src=None):
        """Apply the layer to ``src``.

        Args:
            src (torch.Tensor): Input sequence; ``StreamingTransformer`` passes
                a ``[B, T, d_model]`` tensor.
            cross_attention_src (torch.Tensor, optional): Key/value source for
                cross-attention. When None the cross-attention block is skipped;
                when given, the layer must have been built with
                ``cross_attention=True``.
        Returns:
            torch.Tensor: Output with the same shape as ``src``.

        NOTE(review): ``src`` is not cast to the parameter dtype here; confirm
        callers pass inputs matching the loaded weights (e.g. float16).
        """
        x = src
        x = x + self.self_attn(self.norm1(x))
        if cross_attention_src is not None:
            x = x + self.cross_attention(
                query=self.norm_cross(x),
                key=cross_attention_src,
                value=cross_attention_src)
        x = x + self.linear2(F.gelu(self.linear1(self.norm2(x))))
        return x
|
|
|
|
|
class StreamingTransformer(nn.Module):
    """Stack of ``layer_class`` layers with an optional sinusoidal positional embedding.

    Args:
        d_model (int): Model dimension; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads per layer.
        num_layers (int): Number of layers to stack.
        dim_feedforward, dropout, bias_ff, bias_attn, custom, memory_efficient,
        attention_as_float32, cross_attention: Forwarded to every layer.
        positional_embedding (str): ``'sin'`` or ``'sin_rope'`` adds a sinusoidal
            embedding to the input. NOTE(review): no rotary embedding is applied
            for ``'sin_rope'``. Any other value adds no embedding.
        max_period (float): Maximum period of the sinusoidal embedding.
        layer_class: Layer constructor (defaults to ``StreamingTransformerLayer``).
        checkpointing (str): Anything other than ``'none'`` only sets
            ``_magma_checkpointed = True`` on each layer; ``forward`` itself does
            not apply activation checkpointing.
        device, dtype: Forwarded to every layer.
        **kwargs: Extra layer arguments (e.g. ``kv_repeat``).
    """

    def __init__(self, d_model: int,
                 num_heads: int,
                 num_layers: int,
                 dim_feedforward: int = 2048,
                 dropout: float = 0.1,
                 bias_ff: bool = True,
                 bias_attn: bool = True,
                 custom: bool = False,
                 memory_efficient: bool = False,
                 attention_as_float32: bool = False,
                 cross_attention: bool = False,
                 positional_embedding: str = 'sin',
                 max_period: float = 10_000,
                 layer_class=StreamingTransformerLayer,
                 checkpointing: str = 'none',
                 device=None,
                 dtype=None,
                 **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.checkpointing = checkpointing

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                layer._magma_checkpointed = True

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Run ``x`` (shape ``[B, T, C]``) through the positional embedding and all layers.

        Keyword Args:
            token_count (int or torch.Tensor): Offset added to the positions
                ``0..T-1`` (e.g. the number of tokens already processed).
                Defaults to 0.
            cross_attention_src (torch.Tensor, optional): Passed to every layer
                as ``cross_attention_src``. Defaults to None.
        Returns:
            torch.Tensor: Output of the last layer.
        """
        _, T, C = x.shape

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + kwargs.get('token_count', 0)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + pos_emb

        cross_attention_src = kwargs.get('cross_attention_src')
        for layer in self.layers:
            x = layer(x, cross_attention_src=cross_attention_src)
        return x
|
|