# Spaces: Sleeping, Sleeping  (apparently page-status residue from the hosting page; not part of the module)
import warnings

import torch
import torch.nn as nn
from torch.nn import functional as F
class SelfAttention(nn.Module):
    """Multi-head self-attention with an optional key/value cache.

    Queries, keys and values come from one fused linear layer (``c_attn``);
    the attention itself is delegated to
    ``torch.nn.functional.scaled_dot_product_attention``.

    The KV cache is off after construction and nothing in this class turns it
    on. Code that wants caching must set ``kv_cache_enabled = True`` and call
    ``empty_kv_cache`` before the first forward pass, because that call is
    what creates ``kv_cache`` and ``kv_cache_first_empty_index``.
    """

    def __init__(self, config):
        """Validate ``config`` and build the projection and dropout layers.

        Args:
            config: Object exposing ``n_embd``, ``n_head``, ``bias``,
                ``dropout``, ``causal`` and ``attn_kernel_type``.

        Raises:
            AssertionError: If ``n_embd`` is not divisible by ``n_head``.
        """
        super().__init__()
        self._validate_config(config)
        self._initialize_parameters(config)

    def empty_kv_cache(self, batch_size: int, kv_cache_maxlen: int, dtype: torch.dtype) -> None:
        """Allocate a zeroed key/value cache and reset the write position.

        Args:
            batch_size: Size of the cache's batch dimension.
            kv_cache_maxlen: Number of time steps the cache can hold.
            dtype: Data type of the cache tensor.

        Raises:
            RuntimeError: If the KV cache is disabled.
        """
        if not self.kv_cache_enabled:
            raise RuntimeError("Trying to empty KV cache when it is disabled")

        # Layout: (2, B, maxlen, nh, hs); index 0 holds keys, index 1 values.
        fresh_cache = torch.zeros(
            2,
            batch_size,
            kv_cache_maxlen,
            self.n_head,
            self.n_embd // self.n_head,
            dtype=dtype,
            device=self.c_attn.weight.device,
        )
        # Registered as a buffer so the cache moves devices along with the
        # module; non-persistent so it stays out of the state dict.
        # TODO: get rid of re-allocation.
        self.register_buffer("kv_cache", fresh_cache, persistent=False)
        self.kv_cache_first_empty_index = 0

    def _initialize_parameters(self, config) -> None:
        """Create the layers and copy the hyper-parameters off ``config``.

        Args:
            config: Object exposing ``n_embd``, ``n_head``, ``bias``,
                ``dropout``, ``causal`` and ``attn_kernel_type``.
        """
        # key, query, value projections for all heads, fused into one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization applied after the output projection
        self.resid_dropout = nn.Dropout(config.dropout)

        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        self.causal = config.causal
        self.attn_kernel_type = config.attn_kernel_type
        # Not used by the methods of this class (attention dropout is passed
        # to the kernel as ``dropout_p``); kept so the module tree is unchanged.
        self.attn_dropout = nn.Dropout(config.dropout)

        self.kv_cache_enabled = False

    def _validate_config(self, config) -> None:
        """Check that the embedding splits evenly across the heads.

        Args:
            config: Object exposing ``n_embd`` and ``n_head``.

        Raises:
            AssertionError: If ``n_embd`` is not divisible by ``n_head``.
        """
        assert config.n_embd % config.n_head == 0, "Embedding dimension must be divisible by number of heads"

    def _update_kv_cache(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """Append ``k`` and ``v`` to the cache and return everything cached so far.

        The first call after ``empty_kv_cache`` may write several time steps
        at once; every later call must carry exactly one query step.

        Args:
            q: Query tensor; only its time dimension (dim 1) is inspected.
            k: Key tensor to append, time on dim 1.
            v: Value tensor to append, time on dim 1.

        Returns:
            Tuple ``(k, v)`` of views into the cache covering all filled steps.

        Raises:
            AssertionError: If the time dimensions are inconsistent with the
                current cache position.
            RuntimeError: If the write would run past the end of the cache.
        """
        q_time, k_time, v_time = q.shape[1], k.shape[1], v.shape[1]
        start = self.kv_cache_first_empty_index

        if start == 0:
            assert q_time == k_time and q_time == v_time
        else:
            assert (
                q_time == 1
            ), f"Only one query at a time is supported, but got q_time={q_time} for kv_cache_first_empty_index={self.kv_cache_first_empty_index}"

        end = start + q_time
        capacity = self.kv_cache.shape[2]
        # Without this guard a one-step write into a full cache targets a
        # zero-length slice; a size-1 source broadcasts into it, so the step
        # is dropped without any error while the index keeps advancing.
        if end > capacity:
            raise RuntimeError(
                f"KV cache overflow: writing {q_time} step(s) at position {start} "
                f"exceeds the cache length of {capacity}"
            )

        self.kv_cache[0, :, start:end] = k
        self.kv_cache[1, :, start:end] = v
        self.kv_cache_first_empty_index = end

        return self.kv_cache[0, :, :end], self.kv_cache[1, :, :end]

    def _torch_attn(self, c_x: torch.Tensor) -> torch.Tensor:
        """Run attention through ``scaled_dot_product_attention``.

        Args:
            c_x: Fused projections of shape (B, T, 3, nh, hs), with q, k, v
                stacked along dim 2.

        Returns:
            Attention output of shape (B, T, nh, hs).
        """
        q, k, v = c_x.unbind(dim=2)  # each (B, T, nh, hs)

        # With KV caching and causal attention, only the "prefill" call needs
        # a causal mask; the later one-step calls attend to the whole cache.
        # This has to be decided before the cache update moves
        # kv_cache_first_empty_index.
        is_causal_attn_mask = self.causal and (not self.kv_cache_enabled or self.kv_cache_first_empty_index == 0)

        if self.kv_cache_enabled:
            k, v = self._update_kv_cache(q, k, v)

        q = q.transpose(1, 2)  # (B, nh, T, hs)
        k = k.transpose(1, 2)  # (B, nh, T_kv, hs)
        v = v.transpose(1, 2)  # (B, nh, T_kv, hs)

        y = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            dropout_p=self.dropout if self.training else 0.0,
            is_causal=is_causal_attn_mask,
        )
        return y.transpose(1, 2)  # (B, nh, T, hs) -> (B, T, nh, hs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention followed by the output projection.

        Args:
            x: Input of shape (B, T, C), where C is the embedding size.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: If ``attn_kernel_type`` is not a supported kernel.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality (n_embd)

        # One matmul yields q, k and v for every head; split them out by view.
        c_x = self.c_attn(x).view(B, T, 3, self.n_head, C // self.n_head)  # (B, T, 3, nh, hs)

        if self.attn_kernel_type == "torch_attn":
            y = self._torch_attn(c_x)
        else:
            raise ValueError(f"Unknown attention kernel type: {self.attn_kernel_type}")

        # re-assemble all head outputs side by side: (B, T, nh, hs) -> (B, T, hs * nh)
        y = y.contiguous().view(B, T, C)

        # output projection
        return self.resid_dropout(self.c_proj(y))