""" Linear Transformer proposed in "Transformers are RNNs: Fast Autoregressive Transformers with Linear Attention" Modified from: https://github.com/idiap/fast-transformers/blob/master/fast_transformers/attention/linear_attention.py """ import torch from torch.nn import Module import torch.nn.functional as F from einops.einops import rearrange if hasattr(F, 'scaled_dot_product_attention'): FLASH_AVAILABLE = True from torch.backends.cuda import sdp_kernel else: FLASH_AVAILABLE = False def crop_feature(query, key, value, x_mask, source_mask): mask_h0, mask_w0, mask_h1, mask_w1 = x_mask[0].sum(-2)[0], x_mask[0].sum(-1)[0], source_mask[0].sum(-2)[0], source_mask[0].sum(-1)[0] query = query[:, :mask_h0, :mask_w0, :] key = key[:, :mask_h1, :mask_w1, :] value = value[:, :mask_h1, :mask_w1, :] return query, key, value, mask_h0, mask_w0 def pad_feature(m, mask_h0, mask_w0, x_mask): bs, L, H, D = m.size() m = m.view(bs, mask_h0, mask_w0, H, D) if mask_h0 != x_mask.size(-2): m = torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2)-mask_h0, x_mask.size(-1), H, D, device=m.device, dtype=m.dtype)], dim=1) elif mask_w0 != x_mask.size(-1): m = torch.cat([m, torch.zeros(m.size(0), x_mask.size(-2), x_mask.size(-1)-mask_w0, H, D, device=m.device, dtype=m.dtype)], dim=2) return m class Attention(Module): def __init__(self, no_flash=False, nhead=8, dim=256, fp32=False): super().__init__() self.flash = FLASH_AVAILABLE and not no_flash self.nhead = nhead self.dim = dim self.fp32 = fp32 def attention(self, query, key, value, q_mask=None, kv_mask=None): assert q_mask is None and kv_mask is None, "Not support generalized attention mask yet." if self.flash and not self.fp32: args = [x.contiguous() for x in [query, key, value]] with sdp_kernel(enable_math= False, enable_flash= True, enable_mem_efficient= False): out = F.scaled_dot_product_attention(*args) elif self.flash: args = [x.contiguous() for x in [query, key, value]] out = F.scaled_dot_product_attention(*args) else: QK = torch.einsum("nlhd,nshd->nlsh", query, key) # Compute the attention and the weighted average softmax_temp = 1. / query.size(3)**.5 # sqrt(D) A = torch.softmax(softmax_temp * QK, dim=2) out = torch.einsum("nlsh,nshd->nlhd", A, value) return out def _forward(self, query, key, value, q_mask=None, kv_mask=None): if q_mask is not None: query, key, value, mask_h0, mask_w0 = crop_feature(query, key, value, q_mask, kv_mask) if self.flash: query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n nhead (h w) d', nhead=self.nhead, d=self.dim), [query, key, value]) else: query, key, value = map(lambda x: rearrange(x, 'n h w (nhead d) -> n (h w) nhead d', nhead=self.nhead, d=self.dim), [query, key, value]) m = self.attention(query, key, value, q_mask=None, kv_mask=None) if self.flash: m = rearrange(m, 'n nhead L d -> n L nhead d', nhead=self.nhead, d=self.dim) if q_mask is not None: m = pad_feature(m, mask_h0, mask_w0, q_mask) return m def forward(self, query, key, value, q_mask=None, kv_mask=None): """ Multi-head scaled dot-product attention, a.k.a full attention. 
        Args:
            query: [N, H0, W0, C] feature map, with C = nhead * dim
            key: [N, H1, W1, C]
            value: [N, H1, W1, C]
            q_mask: [N, H0, W0] boolean mask of valid query positions (optional)
            kv_mask: [N, H1, W1] boolean mask of valid key/value positions (optional)
        Returns:
            queried_values: [N, L, nhead, dim] with L = H0 * W0 if q_mask is None,
                otherwise [N, H0, W0, nhead, dim] (zero-padded back to the full mask size)
        """
        bs = query.size(0)
        if bs == 1 or q_mask is None:
            m = self._forward(query, key, value, q_mask=q_mask, kv_mask=kv_mask)
        else:  # process samples one at a time: faster training with padding masks when batch size > 1
            m_list = []
            for i in range(bs):
                m_list.append(self._forward(query[i:i+1], key[i:i+1], value[i:i+1],
                                            q_mask=q_mask[i:i+1], kv_mask=kv_mask[i:i+1]))
            m = torch.cat(m_list, dim=0)
        return m
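

# --------------------------------------------------------------------------- #
# Usage sketch (illustrative, not part of the original module): a minimal
# smoke test of the layouts described in forward()'s docstring. The feature-map
# size, head count, and per-head dimension below are arbitrary example values;
# no_flash=True forces the einsum fallback so the sketch also runs on CPU.
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    torch.manual_seed(0)
    n, h, w, nhead, dim = 1, 32, 32, 8, 32           # dim is the per-head channel size
    attn = Attention(no_flash=True, nhead=nhead, dim=dim)

    feat0 = torch.randn(n, h, w, nhead * dim)        # query feature map  [N, H, W, C]
    feat1 = torch.randn(n, h, w, nhead * dim)        # key/value feature map
    out = attn(feat0, feat1, feat1)
    print(out.shape)                                 # torch.Size([1, 1024, 8, 32])

    # Padded-input path: boolean masks mark the valid region of each map.
    q_mask = torch.zeros(n, h, w, dtype=torch.bool)
    kv_mask = torch.zeros(n, h, w, dtype=torch.bool)
    q_mask[:, :24, :] = True                         # only the top 24 rows of the query map are valid
    kv_mask[:, :16, :] = True
    out_masked = attn(feat0, feat1, feat1, q_mask=q_mask, kv_mask=kv_mask)
    print(out_masked.shape)                          # torch.Size([1, 32, 32, 8, 32]), zero-padded back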