import torch
import torch_xla
import torch_xla.distributed.spmd as xs
import torch_xla.core.xla_model as xm
import pickle
import jax
import os

from torch_xla.experimental.custom_kernel import (
    FlashAttention,
    jax_import_guard,
    trace_pallas,
)


def flash_attention(
    q,  # [batch_size, num_heads, q_seq_len, d_model]
    k,  # [batch_size, num_heads, kv_seq_len, d_model]
    v,  # [batch_size, num_heads, kv_seq_len, d_model]
    causal=False,
    q_segment_ids=None,  # [batch_size, q_seq_len]
    kv_segment_ids=None,  # [batch_size, kv_seq_len]
    sm_scale=1.0,
    *,
    ab=None,  # [batch_size, num_heads, q_seq_len, kv_seq_len]
    partition_spec=None,
    mesh=None,
):
    # TODO: support SPMD and Dynamo with segment_ids.
    return SPMDFlashAttention.apply(
        q,
        k,
        v,
        causal,
        q_segment_ids,
        kv_segment_ids,
        sm_scale,
        ab,
        partition_spec,
        mesh,
    )
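

# Usage sketch (illustrative only; `mesh` and the partition spec are placeholders that
# the caller builds, e.g. as in the __main__ block at the bottom of this file):
#
#     out = flash_attention(
#         q, k, v, causal=True,
#         partition_spec=("data", None, None, None), mesh=mesh,
#     )
#
# The partition spec follows the [batch, num_heads, seq_len, head_dim] layout of q, k, v.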


class SPMDFlashAttention(FlashAttention):
    """
    This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
    where we only take q, k, v, and causal as input and set the block sizes for the user.
    """

    @staticmethod
    def forward(
        ctx,
        q,
        k,
        v,
        causal,
        q_segment_ids,
        kv_segment_ids,
        sm_scale,
        ab,
        partition_spec,
        mesh,
    ):
        # Import JAX within the function so that we don't need to call jax_import_guard()
        # in the global scope, which could cause problems for xmp.spawn.
        jax_import_guard()
        import jax  # noqa: F401
        from jax.experimental.pallas.ops.tpu.flash_attention import (
            _flash_attention_impl,
        )

        ctx.causal = causal
        ctx.sm_scale = sm_scale
        ctx.partition_spec = partition_spec
        ctx.mesh = mesh
        ctx.q_full_shape = None
        ctx.kv_full_shape = None

        save_residuals = q.requires_grad or k.requires_grad or v.requires_grad

        # SPMD integration.
        # mark_sharding is in-place, therefore save the full q, k, v for the backward.
        full_q = q
        full_k = k
        full_v = v
        full_ab = ab
        if partition_spec is not None:
            ctx.q_full_shape = q.shape
            ctx.kv_full_shape = k.shape
            q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
            k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
            v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
            if ab is not None:
                ab = xs.enable_manual_sharding(
                    ab, partition_spec, mesh=mesh
                ).global_tensor

        # Compute the shapes and dtypes of the kernel outputs: o, l, m.
        shapes = [q.shape]
        dtypes = [q.dtype]
        if save_residuals:
            res_shape = list(q.shape)
            res_shape[-1] = FlashAttention.MIN_BLOCK_SIZE
            for _ in range(2):
                shapes.append(res_shape)
                dtypes.append(torch.float32)

        with torch.no_grad():
            if (
                partition_spec is not None
                and q_segment_ids is not None
                and kv_segment_ids is not None
            ):
                # partition_spec is for q, k, v with shape [batch, num_head, seq_len, head_dim];
                # the segment ids have shape [batch, seq_len], so the spec is reduced to the
                # batch and sequence dimensions.
                segment_id_partition_spec = (partition_spec[0], partition_spec[2])
                q_segment_ids = xs.enable_manual_sharding(
                    q_segment_ids, segment_id_partition_spec, mesh=mesh
                ).global_tensor
                kv_segment_ids = xs.enable_manual_sharding(
                    kv_segment_ids, segment_id_partition_spec, mesh=mesh
                ).global_tensor
            segment_ids, q_segment_ids_fa, kv_segment_ids_fa = (
                FlashAttention.prepare_segment_ids(q_segment_ids, kv_segment_ids)
            )
            ctx.segment_ids = segment_ids

            # We can't call flash_attention directly because we need to override the
            # save_residuals flag so that the kernel also returns l and m, which the
            # backward pass needs; in doing so we lose its shape checks.
            # TODO: replicate the shape checks of flash_attention.
            # Tracing and execution are separated here just to support SegmentIds.
            payload, _ = trace_pallas(
                _flash_attention_impl,
                q,
                k,
                v,
                ab,
                segment_ids,
                save_residuals,
                causal,
                sm_scale,
                min(FlashAttention.DEFAULT_BLOCK_SIZES["block_b"], q.shape[0]),
                min(FlashAttention.DEFAULT_BLOCK_SIZES["block_q"], q.shape[2]),
                min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major"], k.shape[2]),
                min(FlashAttention.DEFAULT_BLOCK_SIZES["block_k"], k.shape[2]),
                False,
                static_argnums=range(5, 13),
                use_cache=True,
            )
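            # Positional arguments 5-12 of _flash_attention_impl (save_residuals, causal,
            # sm_scale, the four block sizes, and debug) are compile-time constants, hence
            # static_argnums=range(5, 13); only q, k, v, ab, and segment_ids are traced as
            # runtime tensors.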

            args = [q, k, v]
            if ab is not None:
                args += [ab]
            if segment_ids is not None:
                args += [q_segment_ids_fa, kv_segment_ids_fa]
            o = torch_xla._XLAC._xla_tpu_custom_call(args, payload, shapes, dtypes)

            if not save_residuals:
                o = o[0]
                # SPMD integration
                if partition_spec is not None:
                    o = xs.disable_manual_sharding(
                        o, partition_spec, ctx.q_full_shape, mesh=mesh
                    ).global_tensor
                return o

            o, *aux = o
            l, m = (v[..., 0] for v in aux[-2:])  # noqa: E741
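            # l and m are the per-row softmax statistics (normalizer and running max)
            # that the kernel returns alongside o; the backward pass needs them to
            # recompute attention blockwise. They come back padded to MIN_BLOCK_SIZE in
            # the last dimension, so only lane 0 is kept here.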

        # SPMD integration
        if partition_spec is not None:
            o = xs.disable_manual_sharding(
                o, partition_spec, ctx.q_full_shape, mesh=mesh
            ).global_tensor
            l = xs.disable_manual_sharding(  # noqa: E741
                l, partition_spec[0:3], ctx.q_full_shape[0:3], mesh=mesh
            ).global_tensor
            m = xs.disable_manual_sharding(
                m, partition_spec[0:3], ctx.q_full_shape[0:3], mesh=mesh
            ).global_tensor

        ctx.save_for_backward(
            full_q,
            full_k,
            full_v,
            o,
            l,
            m,
            q_segment_ids_fa,
            kv_segment_ids_fa,
            full_ab,
        )
        return o

    @staticmethod
    def backward(ctx, grad_output):
        from jax.experimental.pallas.ops.tpu.flash_attention import (
            _flash_attention_bwd_dq,
            _flash_attention_bwd_dkv,
        )

        q, k, v, o, l, m, q_segment_ids_fa, kv_segment_ids_fa, ab = (  # noqa: E741
            ctx.saved_tensors
        )
        causal = ctx.causal
        sm_scale = ctx.sm_scale
        partition_spec = ctx.partition_spec
        mesh = ctx.mesh
        q_full_shape = ctx.q_full_shape
        kv_full_shape = ctx.kv_full_shape
        segment_ids = ctx.segment_ids
        grad_q = grad_k = grad_v = grad_ab = None

        grad_i = torch.sum(
            o.to(torch.float32) * grad_output.to(torch.float32), axis=-1
        )  # [batch_size, num_heads, q_seq_len]

        expanded_l = l.unsqueeze(-1).expand(
            [-1 for _ in l.shape] + [FlashAttention.MIN_BLOCK_SIZE]
        )
        expanded_m = m.unsqueeze(-1).expand(
            [-1 for _ in m.shape] + [FlashAttention.MIN_BLOCK_SIZE]
        )
        expanded_grad_i = grad_i.unsqueeze(-1).expand(
            [-1 for _ in grad_i.shape] + [FlashAttention.MIN_BLOCK_SIZE]
        )

        # SPMD integration
        if partition_spec is not None:
            q = xs.enable_manual_sharding(q, partition_spec, mesh=mesh).global_tensor
            k = xs.enable_manual_sharding(k, partition_spec, mesh=mesh).global_tensor
            v = xs.enable_manual_sharding(v, partition_spec, mesh=mesh).global_tensor
            expanded_l = xs.enable_manual_sharding(
                expanded_l, partition_spec, mesh=mesh
            ).global_tensor
            expanded_m = xs.enable_manual_sharding(
                expanded_m, partition_spec, mesh=mesh
            ).global_tensor
            grad_output = xs.enable_manual_sharding(
                grad_output, partition_spec, mesh=mesh
            ).global_tensor
            expanded_grad_i = xs.enable_manual_sharding(
                expanded_grad_i, partition_spec, mesh=mesh
            ).global_tensor
            if ab is not None:
                ab = xs.enable_manual_sharding(
                    ab, partition_spec, mesh=mesh
                ).global_tensor

        if ctx.needs_input_grad[0]:
            payload, _ = trace_pallas(
                _flash_attention_bwd_dq,
                q,
                k,
                v,
                ab,
                segment_ids,
                l,
                m,
                grad_output,
                grad_i,
                block_q_major=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dq"], q.shape[2]
                ),
                block_k_major=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dq"], k.shape[2]
                ),
                block_k=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dq"], k.shape[2]
                ),
                sm_scale=sm_scale,
                causal=causal,
                mask_value=FlashAttention.DEFAULT_MASK_VALUE,
                debug=False,
                static_argnames=[
                    "block_q_major",
                    "block_k_major",
                    "block_k",
                    "sm_scale",
                    "causal",
                    "mask_value",
                    "debug",
                ],
                use_cache=True,
            )

            args = [q, k, v]
            if ab is not None:
                args += [ab]
            if segment_ids is not None:
                args += [q_segment_ids_fa, kv_segment_ids_fa]
            args += [expanded_l, expanded_m, grad_output, expanded_grad_i]

            outputs = [q]
            if ab is not None:
                outputs += [ab]
            grads = torch_xla._XLAC._xla_tpu_custom_call(
                args, payload, [i.shape for i in outputs], [i.dtype for i in outputs]
            )
            if ctx.needs_input_grad[0]:
                grad_q = grads[0]
            if ctx.needs_input_grad[-3]:
                grad_ab = grads[1]

        if ctx.needs_input_grad[1] or ctx.needs_input_grad[2]:
            payload, _ = trace_pallas(
                _flash_attention_bwd_dkv,
                q,
                k,
                v,
                ab,
                segment_ids,
                l,
                m,
                grad_output,
                grad_i,
                block_q_major=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_q_major_dkv"], q.shape[2]
                ),
                block_k_major=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_k_major_dkv"], k.shape[2]
                ),
                block_k=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_k_dkv"], k.shape[2]
                ),
                block_q=min(
                    FlashAttention.DEFAULT_BLOCK_SIZES["block_q_dkv"], q.shape[2]
                ),
                sm_scale=sm_scale,
                causal=causal,
                mask_value=FlashAttention.DEFAULT_MASK_VALUE,
                debug=False,
                static_argnames=[
                    "block_q_major",
                    "block_k_major",
                    "block_k",
                    "block_q",
                    "sm_scale",
                    "causal",
                    "mask_value",
                    "debug",
                ],
                use_cache=True,
            )
            grads = torch_xla._XLAC._xla_tpu_custom_call(
                args, payload, [k.shape, v.shape], [k.dtype, v.dtype]
            )
            if ctx.needs_input_grad[1]:
                grad_k = grads[0]
            if ctx.needs_input_grad[2]:
                grad_v = grads[1]

        # SPMD integration
        if partition_spec is not None:
            grad_q = xs.disable_manual_sharding(
                grad_q, partition_spec, q_full_shape, mesh=mesh
            ).global_tensor
            grad_k = xs.disable_manual_sharding(
                grad_k, partition_spec, kv_full_shape, mesh=mesh
            ).global_tensor
            grad_v = xs.disable_manual_sharding(
                grad_v, partition_spec, kv_full_shape, mesh=mesh
            ).global_tensor

        return grad_q, grad_k, grad_v, None, None, None, None, grad_ab, None, None
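

# Self-check script: run the SPMD flash attention forward and backward twice on the
# same pickled random inputs and compare the outputs and gradients across runs.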
if __name__ == "__main__": | |
if len(os.sys.argv) < 2: | |
print("Usage: python custom_kernel_spmd.py <use_spmd>") | |
os.sys.exit(1) | |

    # Interpret the CLI flag as a boolean; the raw argv string would always be truthy.
    use_spmd = os.sys.argv[1].lower() in ("1", "true", "yes")
    jax.config.update("jax_default_matmul_precision", "highest")

    mesh, attn_spec = None, None
    if use_spmd:
        import torch_xla.runtime as xr
        from torch_xla.distributed.spmd import Mesh
        import numpy as np

        xr.use_spmd()
        num_devices = xr.global_runtime_device_count()
        mesh_shape = (1, 1, num_devices)
        device_ids = np.array(range(num_devices))
        mesh = Mesh(device_ids, mesh_shape, ("data", "model", "sequence"))
        attn_spec = ("data", None, None, None)

    batch_size = 1000
    data_path = "data.pkl"
    if os.path.exists(data_path):
        with open(data_path, "rb") as f:
            q, k, v, mask = pickle.load(f)
    else:
        q = torch.randn(batch_size, 2, 128, 4)
        k = torch.randn(batch_size, 2, 128, 4)
        v = torch.randn(batch_size, 2, 128, 4)
        mask = torch.rand(batch_size, 128)
        with open(data_path, "wb") as f:
            pickle.dump((q, k, v, mask), f)

    q, k, v, mask = q.to("xla"), k.to("xla"), v.to("xla"), mask.to("xla")
    q.requires_grad = True
    k.requires_grad = True
    v.requires_grad = True
    q.retain_grad()
    k.retain_grad()
    v.retain_grad()

    q_segment_indexes = torch.ones(
        batch_size, q.shape[2], device=q.device, dtype=torch.float32
    )

    grads_path = "grads.pkl"
    if os.path.exists(grads_path):
        print(f"loading precomputed output and gradients from {grads_path}")
        with open(grads_path, "rb") as f:
            o, q_grad, k_grad, v_grad = pickle.load(f)
        o, q_grad, k_grad, v_grad = (
            o.to("xla"),
            q_grad.to("xla"),
            k_grad.to("xla"),
            v_grad.to("xla"),
        )
    else:
        o = SPMDFlashAttention.apply(
            q, k, v, False, q_segment_indexes, mask, 1.0, attn_spec, mesh
        )
        print(f"created output with shape {o.shape}", flush=True)

        loss = o.sum()
        loss.backward()
        xm.mark_step()

        q_grad = q.grad
        k_grad = k.grad
        v_grad = v.grad

        o_cpu = o.cpu()
        with open(grads_path, "wb") as f:
            pickle.dump([o_cpu, q_grad.cpu(), k_grad.cpu(), v_grad.cpu()], f)

    q.grad = None
    k.grad = None
    v.grad = None

    o2 = SPMDFlashAttention.apply(
        q, k, v, False, q_segment_indexes, mask, 1.0, attn_spec, mesh
    )
    loss = o2.sum()
    loss.backward()
    xm.mark_step()

    print(
        "comparing saved/loaded gradients against the gradients from the second run:"
    )
    for i, j in [(q_grad, q.grad), (k_grad, k.grad), (v_grad, v.grad)]:
        print(torch.allclose(i, j, rtol=1e-14))
    print("same comparison with the arguments swapped:")
    for i, j in [(q_grad, q.grad), (k_grad, k.grad), (v_grad, v.grad)]:
        print(torch.allclose(j, i, rtol=1e-14))
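    # torch.allclose(a, b) tests |a - b| <= atol + rtol * |b|, so rtol scales with the
    # second argument; swapping the arguments confirms the result is not an artifact of
    # that asymmetry.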
print(f"comparing second output with shape: {o2.shape}", flush=True) | |
print(torch.allclose(o, o2, rtol=1e-14)) | |