# An implementation of the Blockwise Parallel Transformer.
# Also includes a reference implementation of the memory-efficient transformer.
import functools
from typing import NamedTuple
import flax.linen as nn
import jax
import jax.lax as lax
import jax.numpy as jnp
from einops import rearrange
# Compute the FFN blockwise without materializing the large hidden tensor, which
# allows training on sequences 4x longer than the memory-efficient transformer.
# Blockwise parallel transformer, Liu et al. 2023.
def blockwise_ffn(remat_ffn, inputs, chunk_size=2048, deterministic=True):
    """Apply a feed-forward module to `inputs` one slice at a time via `nn.scan`.

    Args:
        remat_ffn: the feed-forward module; it is invoked as
            `remat_ffn(hidden_states, deterministic=deterministic)`. It is
            meant to be rematerialized (e.g. with the
            `jax.checkpoint_policies.nothing_saveable` policy) so that the
            large hidden activation is not kept for every slice.
        inputs: array of shape (batch, seq_len, dim). `seq_len` must be
            divisible by `chunk_size` (otherwise the rearrange fails).
        chunk_size: number of sequence positions passed to `remat_ffn` per
            scan step.
        deterministic: forwarded unchanged to `remat_ffn`.

    Returns:
        The per-slice outputs reassembled to (batch, seq_len, d_out).
    """
    # (c n) with c == chunk_size: each of the n = seq_len // chunk_size scan
    # steps receives a (batch, chunk_size, dim) slice whose positions are
    # strided by n. This matches contiguous chunking only if `remat_ffn`
    # treats every position independently -- presumably true for an FFN.
    inputs = rearrange(inputs, 'b (c n) d -> b c n d', c=chunk_size)

    def scan_ffn(remat_ffn, carry, hidden_states):
        outputs = remat_ffn(hidden_states, deterministic=deterministic)
        return carry, outputs

    # Scan over the `n` axis of the (b, c, n, d) tensor.
    scan_axis = inputs.ndim - 2
    _, res = nn.scan(
        scan_ffn,
        # Parameters are shared by every step; dropout gets a fresh RNG per step.
        variable_broadcast="params",
        split_rngs={"params": False, "dropout": True},
        in_axes=scan_axis,
        out_axes=scan_axis,
    )(remat_ffn, None, inputs)
    res = rearrange(res, 'b c n d -> b (c n) d')
    return res
# Compute attention blockwise without materializing the full attention matrix.
# The approach was initially proposed in the memory-efficient transformer (Rabe et al. 2021);
# flash attention (Dao et al. 2022) proposes an efficient CUDA
# implementation; the blockwise parallel transformer
# (Liu et al. 2023) proposes computing both attention and the FFN blockwise, enabling 4x
# longer sequences than memory-efficient/flash attention and fusion of attention and FFN.
def blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=True,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=True,
        query_chunk_size=2048,
        key_chunk_size=2048,
        dtype=jnp.float32,
        policy=jax.checkpoint_policies.nothing_saveable,
        precision=None,
        float32_logits=True,
        prevent_cse=True,
    ):
    """Softmax attention computed block by block with a running softmax.

    Args:
        query, key, value: arrays of shape
            (batch, seq_len, num_heads, dim_per_head). The query length must
            be divisible by `query_chunk_size` and the key/value length by
            `key_chunk_size` (the inputs are reshaped into whole chunks).
        bias: optional additive attention bias. It must have four axes, each
            either 1 or equal to the matching entry of
            (batch, num_heads, q_len, kv_len); it can be used to mask out
            attention (e.g. padding).
        deterministic: when False and `attn_pdrop > 0`, attention dropout is
            applied.
        dropout_rng: PRNG key for the dropout mask; required when dropout is
            active.
        attn_pdrop: probability of masking out an attention entry.
        causal: whether to apply a causal mask.
        query_chunk_size, key_chunk_size: block sizes along the query and
            key/value sequence axes.
        dtype: dtype used for the query scaling, the bias chunks and the
            returned array.
        policy: a `jax.checkpoint` policy callable (one of
            `jax.checkpoint_policies`) applied to the per-block update.
        precision: forwarded to the `jnp.einsum` calls.
        float32_logits: cast query and key to float32 before computing logits.
        prevent_cse: forwarded to `jax.checkpoint`.

    Returns:
        Array of shape (batch, q_len, num_heads, dim_per_head) with dtype
        `dtype`.
    """
    query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
    if float32_logits:
        query = query.astype(jnp.float32)
        key = key.astype(jnp.float32)

    batch, q_len, num_heads, dim_per_head = query.shape
    batch, kv_len, num_heads, dim_per_head = key.shape
    batch, kv_len, num_heads, dim_per_head = value.shape

    # Split the sequence axes into chunks and move the chunk axis to the
    # front so that lax.scan iterates over chunks.
    num_q = q_len // query_chunk_size
    num_kv = kv_len // key_chunk_size
    query = query.reshape((batch, num_q, query_chunk_size, num_heads, dim_per_head))
    key = key.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))
    value = value.reshape((batch, num_kv, key_chunk_size, num_heads, dim_per_head))

    query = jnp.moveaxis(query, 1, 0)
    key = jnp.moveaxis(key, 1, 0)
    value = jnp.moveaxis(value, 1, 0)

    if bias is not None:
        for bias_dim, broadcast_dim in zip(bias.shape, (batch, num_heads, q_len, kv_len)):
            assert bias_dim == 1 or bias_dim == broadcast_dim
    if not deterministic and attn_pdrop > 0.0:
        attn_dropout_rng, dropout_rng = jax.random.split(dropout_rng)
        # True marks an attention entry that is to be masked out.
        attn_dropout = jax.random.bernoulli(attn_dropout_rng, attn_pdrop, (batch, num_heads, q_len, kv_len))
    else:
        attn_dropout = None

    # Only the two chunk indices remain free: (query_chunk_idx, key_chunk_idx).
    _chunk_bias_fn = functools.partial(
        _chunk_attention_bias,
        query_chunk_size, key_chunk_size, bias, deterministic,
        attn_dropout, attn_pdrop, causal, dtype)

    def scan_attention(args):
        query_chunk, query_chunk_idx = args

        @functools.partial(jax.checkpoint, prevent_cse=prevent_cse, policy=policy)
        def scan_kv_block(carry, args):
            key_chunk, value_chunk, key_chunk_idx = args
            (numerator, denominator, prev_max_score) = carry
            attn_weights = jnp.einsum('bqhd,bkhd->bqhk', query_chunk, key_chunk, precision=precision)
            bias_chunk = _chunk_bias_fn(query_chunk_idx, key_chunk_idx)
            # (b, h, q, k) -> (b, q, h, k) to line up with attn_weights.
            bias_chunk = jnp.moveaxis(bias_chunk, 1, 2)
            attn_weights = attn_weights + bias_chunk

            # Running softmax: track the maximum seen so far and rescale the
            # previously accumulated sums whenever it grows.
            max_score = jnp.max(attn_weights, axis=-1, keepdims=True)
            max_score = jnp.maximum(prev_max_score, max_score)
            max_score = jax.lax.stop_gradient(max_score)
            exp_weights = jnp.exp(attn_weights - max_score)
            exp_values = jnp.einsum(
                'bqhv,bvhd->bqhd', exp_weights, value_chunk, precision=precision
            )
            correction = jnp.exp(prev_max_score - max_score)
            numerator = numerator * correction + exp_values
            denominator = denominator * correction + exp_weights.sum(axis=-1, keepdims=True)
            return Carry(numerator, denominator, max_score), None

        def skip_upper_half(carry, args):
            _, _, key_chunk_idx = args
            skip_block = jnp.array(False)
            if causal:
                # A block is fully masked only when its last query position
                # comes before its first key position. Comparing positions
                # (rather than chunk indices) stays correct when the query
                # and key chunk sizes differ, and is identical when they match.
                last_query_pos = (query_chunk_idx + 1) * query_chunk_size - 1
                first_key_pos = key_chunk_idx * key_chunk_size
                skip_block = last_query_pos < first_key_pos
            return jax.lax.cond(
                skip_block,
                lambda carry, args: (carry, None),
                scan_kv_block,
                carry,
                args,
            )

        init_carry = Carry(
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            jnp.zeros((batch, query_chunk_size, num_heads, dim_per_head), dtype=query.dtype),
            (-jnp.inf) * jnp.ones((batch, query_chunk_size, num_heads, 1), dtype=query.dtype),
        )
        (numerator, denominator, max_score), _ = lax.scan(
            skip_upper_half, init_carry, xs=(key, value, jnp.arange(0, num_kv))
        )
        outputs = (numerator / denominator).astype(dtype)
        return outputs

    # Outer scan over query chunks; it carries no state.
    _, res = lax.scan(
        lambda _, x: ((), scan_attention(x)),
        (), xs=(query, jnp.arange(0, num_q))
    )
    res = rearrange(res, 'n b c h d -> b (n c) h d')
    return res
class Carry(NamedTuple):
    """Accumulator threaded through the key/value scan of blockwise attention."""

    # Running sum of exp(logit - max) * value.
    numerator: jax.Array
    # Running sum of exp(logit - max).
    denominator: jax.Array
    # Largest logit seen so far (the reference point for the exponentials).
    max_so_far: jax.Array
def _chunk_attention_bias(query_chunk_size, key_chunk_size,
bias, deterministic, attn_dropout, attn_pdrop, causal,
dtype, query_chunk_idx, key_chunk_idx):
query_offset = query_chunk_idx * query_chunk_size
key_offset = key_chunk_idx * key_chunk_size
chunk_bias = jnp.zeros((1, 1, 1, 1), dtype=dtype)
if bias is not None:
chunk_bias = lax.dynamic_slice(
start_indices=(0, 0, query_offset, key_offset),
slice_sizes=(*bias.shape[:2], min(bias.shape[-2], query_chunk_size), min(bias.shape[-1], key_chunk_size)),
if causal:
query_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(query_chunk_size, 1), dimension=0)
key_idx = lax.broadcasted_iota(dtype=jnp.int32, shape=(1, key_chunk_size), dimension=1)
offset = query_offset - key_offset
query_idx += offset
causal_mask_value = (query_idx < key_idx) * jnp.finfo(dtype).min
chunk_bias += causal_mask_value.reshape(1, 1, *causal_mask_value.shape)
if not deterministic and attn_pdrop > 0.0:
attn_dropout_slice = lax.dynamic_slice(
start_indices=(0, 0, query_offset, key_offset),
min(attn_dropout.shape[-2], query_chunk_size),
min(attn_dropout.shape[-1], key_chunk_size),
chunk_bias += attn_dropout_slice * jnp.finfo(dtype).min
return chunk_bias.astype(dtype)
if __name__ == '__main__':
    # Self-test: blockwise attention must agree with plain softmax attention.
    def reference_attn(query, key, value, causal, dtype):
        """Plain softmax attention that materializes the full logits matrix."""
        query = query / jnp.sqrt(query.shape[-1]).astype(dtype)
        logits = jnp.einsum("bqhc,bkhc->bhqk", query, key)
        if causal:
            mask_value = jnp.finfo(logits.dtype).min
            _, q_seq_len, _, _ = query.shape
            _, kv_seq_len, _, _ = key.shape
            mask_shape = (q_seq_len, kv_seq_len)
            row_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 0)
            col_ids = jax.lax.broadcasted_iota(jnp.int32, mask_shape, 1)
            # True where a key lies strictly after the query position.
            causal_mask = (row_ids < col_ids)[None, None, :, :]
            logits = logits + jnp.where(causal_mask, mask_value, 0.0)
        weights = jax.nn.softmax(logits, axis=-1)
        out = jnp.einsum("bhqk,bkhc->bqhc", weights, value)
        return out

    # random inputs
    shape = (1, 32, 8, 64)
    query = jax.random.normal(jax.random.PRNGKey(0), shape)
    key = jax.random.normal(jax.random.PRNGKey(1), shape)
    value = jax.random.normal(jax.random.PRNGKey(2), shape)

    causal = True
    chunk_size = 4
    # Pass the policy callable itself: calling nothing_saveable() evaluates
    # the policy and yields its boolean result instead of a policy.
    policy = jax.checkpoint_policies.nothing_saveable

    blockwise = blockwise_attn(
        query, key, value,
        bias=None,
        deterministic=False,
        dropout_rng=None,
        attn_pdrop=0.0,
        causal=causal,
        query_chunk_size=chunk_size,
        key_chunk_size=chunk_size,
        dtype=jnp.float32,
        policy=policy,
        precision='float32',
        float32_logits=True,
        prevent_cse=False,
    )
    reference = reference_attn(query, key, value, causal, 'float32')

    assert jnp.allclose(reference, blockwise, atol=1e-6)