Spaces:
Sleeping
Sleeping
# Copyright (c) 2023, Tri Dao. | |
from typing import Tuple, Optional, Union | |
import torch | |
import triton | |
import triton.language as tl | |
# `all_gather_into_tensor` and `reduce_scatter_tensor` are the newer public names for
# `_all_gather_base` and `_reduce_scatter_base`; they require a recent version of
# PyTorch. For backward compatibility with older PyTorch, the statement below aliases
# the legacy all-gather function under the new name when the new name is missing
# (only the all-gather alias is installed by this statement).
# The alias is skipped when the legacy function is absent too, so that importing
# this module does not fail with an AttributeError.
if not hasattr(torch.distributed, "all_gather_into_tensor") and hasattr(
    torch.distributed, "_all_gather_base"
):
    torch.distributed.all_gather_into_tensor = torch.distributed._all_gather_base
# Forward kernel of the fused cross-entropy loss.
#
# Launch grid: axis 0 indexes rows of `logits`, axis 1 indexes column blocks of
# BLOCK_SIZE classes. Each program handles one (row, column block) pair and writes
# one value into `loss_ptr`, `lse_ptr` (and `z_loss_ptr` when not SPLIT) at flat
# offset `col_block_idx * n_rows + row_idx`.
#
# The launch site does not pass HAS_SMOOTHING, so it is derived from the `smoothing`
# argument through `triton.heuristics`; `triton.jit` is what makes the
# `kernel[grid](...)` launch syntax available.
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_fwd_kernel(
    loss_ptr,  # data ptrs
    lse_ptr,
    z_loss_ptr,
    logits_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    n_rows,
    logits_row_stride,  # strides
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
    # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE
    SPLIT: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Widen the stride to int64 before multiplying so the row offset is computed in 64 bits.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    # Out-of-range columns are loaded as -inf so they contribute exp(-inf) = 0 to the LSE.
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    max_logits = tl.max(logits, 0)
    if HAS_SMOOTHING:
        # Masked columns hold -inf; replace them with 0 so they do not poison the plain sum.
        sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)
    # Log-sum-exp of this block, shifted by the block maximum before exponentiating.
    lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits
    tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)
    if label_idx == ignore_index:
        loss = 0.0
        z_loss = 0.0
    else:
        # Convert the global class index into an index local to this partition.
        label_idx -= class_start_idx
        if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(
            n_cols, (col_block_idx + 1) * BLOCK_SIZE
        ):
            # The label falls inside this column block.
            logits_label = tl.load(logits_ptr + label_idx) * logit_scale
            if HAS_SMOOTHING:
                loss = (
                    (lse if not SPLIT else 0.0)
                    - smoothing * sum_logits / total_classes
                    - (1 - smoothing) * logits_label
                )
            else:
                loss = (lse if not SPLIT else 0.0) - logits_label
        else:
            # If label is out of bounds, we set the CE loss to 0.0. But we still want the smoothing loss
            if HAS_SMOOTHING:
                loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)
            else:
                loss = 0.0
        if not SPLIT:
            # z-loss: lse_square_scale * lse^2, added to the loss and also reported separately.
            z_loss = lse_square_scale * lse * lse
            loss += z_loss
        else:
            # With SPLIT the block-local lse is not final, so no z-loss is produced here.
            z_loss = 0.0
    tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)
    if not SPLIT:
        tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)
# Backward kernel of the fused cross-entropy loss.
#
# Launch grid: axis 0 indexes rows, axis 1 indexes column blocks of BLOCK_SIZE
# classes. Each program writes the gradient w.r.t. one block of one row of the
# logits into `dlogits_ptr`.
#
# The launch site does not pass HAS_SMOOTHING, so it is derived from the `smoothing`
# argument through `triton.heuristics`; `triton.jit` is what makes the
# `kernel[grid](...)` launch syntax available.
@triton.heuristics(
    {
        "HAS_SMOOTHING": lambda args: args["smoothing"] > 0.0,
    }
)
@triton.jit
def cross_entropy_bwd_kernel(
    dlogits_ptr,  # data ptrs
    dloss_ptr,
    logits_ptr,
    lse_ptr,
    labels_ptr,
    smoothing,
    logit_scale,
    lse_square_scale,
    ignore_index,
    total_classes,
    class_start_idx,  # Useful for tensor parallel when each rank only has a subset of classes
    n_cols,  # shapes
    logits_row_stride,  # strides
    dlogits_row_stride,
    dloss_row_stride,
    BLOCK_SIZE: tl.constexpr,
    HAS_SMOOTHING: tl.constexpr,
):
    row_idx = tl.program_id(0)
    col_block_idx = tl.program_id(1)
    # Widen the strides to int64 before multiplying so row offsets are computed in 64 bits.
    logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)
    dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)
    col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    label_idx = tl.load(labels_ptr + row_idx)
    if label_idx != ignore_index:
        dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)
    else:
        # Ignored rows get a zero upstream gradient, hence an all-zero dlogits row.
        dloss = 0.0
    logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float("inf")).to(
        tl.float32
    ) * logit_scale
    # One LSE value per row is read here (offset is the row index only).
    lse = tl.load(lse_ptr + row_idx)
    # softmax probabilities: exp(logit - lse)
    probs = tl.exp(logits - lse)
    # Gradient of the z-loss term lse_square_scale * lse^2 is 2 * lse_square_scale * lse * probs.
    probs += 2.0 * lse_square_scale * lse * probs
    # Convert the global class index into an index local to this partition.
    label_idx -= class_start_idx
    if HAS_SMOOTHING:
        smooth_positive = 1.0 - smoothing
        smooth_negative = smoothing / total_classes
        probs = tl.where(col_offsets == label_idx, probs - smooth_positive, probs) - smooth_negative
    else:
        probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)
    # Chain rule through the `logits * logit_scale` scaling applied when the logits were loaded.
    tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)
class CrossEntropyLoss(torch.autograd.Function):
    """Autograd function computing per-row cross-entropy loss with Triton kernels.

    Supports label smoothing, logit scaling, an auxiliary z-loss
    (``lse_square_scale * lse ** 2``), ignored labels, and a vocabulary that is
    partitioned across the ranks of ``process_group``.
    """

    @staticmethod
    def forward(
        ctx,
        logits,
        labels,
        smoothing=0.0,
        logit_scale=1.0,
        lse_square_scale=0.0,
        ignore_index=-100,
        inplace_backward=False,
        process_group=None,
    ):
        """Compute the losses and z-losses for every row of ``logits``.

        Arguments:
            logits: 2-D tensor (n_rows, n_cols); n_cols is this rank's share of the classes.
            labels: 1-D tensor (n_rows,) of class indices (global indices when
                ``process_group`` is given).
            smoothing: label smoothing factor.
            logit_scale: logits are multiplied by this value inside the kernel.
            lse_square_scale: coefficient of the z-loss term.
            ignore_index: rows whose label equals this value get zero loss.
            inplace_backward: if True, backward writes the gradient into the saved
                logits tensor instead of allocating a new one.
            process_group: if not None, classes are split across the ranks of this group.
        Returns:
            (losses, z_losses), both float32 of shape (n_rows,). ``z_losses`` is
            marked non-differentiable.
        """
        n_rows, n_cols = logits.shape
        assert labels.shape == (n_rows,)
        world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)
        total_classes = world_size * n_cols
        rank = 0 if process_group is None else torch.distributed.get_rank(process_group)
        class_start_idx = rank * n_cols

        if logits.stride(-1) != 1:
            logits = logits.contiguous()
        # The kernels read labels at `labels_ptr + row_idx`, i.e. with unit stride.
        labels = labels.contiguous()
        # Set these similar to https://github.com/openai/triton/blob/main/python/tutorials/02-fused-softmax.py
        MAX_BLOCK_SIZE = 64 * 1024
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)
        num_warps = (
            4
            if BLOCK_SIZE < 2048
            else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))
        )
        # We may split the lse computation across multiple blocks, then do a reduction
        # lse(local_lse) to get the final LSE. This is faster for large n_cols (e.g., > 64k)
        # where having just one thread block processing more than 64k elements is slow.
        split = world_size > 1 or n_cols > MAX_BLOCK_SIZE
        n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE
        loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)
        losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_fwd_kernel[(n_rows, n_splits)](
                losses,  # data ptrs
                lse,
                z_losses,
                logits,
                labels,
                smoothing,
                logit_scale,
                lse_square_scale,
                ignore_index,
                total_classes,
                class_start_idx,
                n_cols,  # shapes
                n_rows,
                logits.stride(0),  # strides
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
                SPLIT=split,
            )

        if split:
            # If there's no smoothing, if labels are in the vocab of this partition, losses contains
            # - predicted logit, and 0 otherwise.
            # If there's smoothing=0.1, for labels in the vocab of this partition, losses contains
            # -0.9 * predicted logit - 0.1 * sum logit / total_classes.
            # For labels not in the vocab of this partition, losses contains
            # -0.1 * sum logit / total_classes.
            if n_splits > 1:
                # Reduce the per-block partial results of this rank.
                lse = torch.logsumexp(lse, dim=0)
                losses = losses.sum(dim=0)
            if world_size > 1:
                lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)
                torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)
                # The all-reduce is started asynchronously and waited on only after
                # the global LSE has been computed.
                handle_losses = torch.distributed.all_reduce(
                    losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True
                )
                lse = torch.logsumexp(lse_allgather, dim=0)
                handle_losses.wait()
            # After the allreduce, if there's no smoothing, the total losses are - predicted_logit,
            # we just have to add the (global) lse.
            # If there's smoothing=0.1, the total losses are
            # -0.9 * predicted_logit - 0.1 * sum logit / total_classes.
            # Again, we just have to add the (global) lse.
            losses += lse
            if lse_square_scale != 0.0:
                z_losses = lse_square_scale * lse.square()
                z_losses.masked_fill_(labels == ignore_index, 0.0)
                losses += z_losses
            else:
                z_losses = torch.zeros_like(losses)
            # The lse was added to every row above; zero out the ignored rows again.
            losses.masked_fill_(labels == ignore_index, 0.0)

        ctx.save_for_backward(logits, lse, labels)
        ctx.mark_non_differentiable(z_losses)
        ctx.smoothing = smoothing
        ctx.logit_scale = logit_scale
        ctx.lse_square_scale = lse_square_scale
        ctx.ignore_index = ignore_index
        ctx.total_classes = total_classes
        ctx.class_start_idx = class_start_idx
        ctx.inplace_backward = inplace_backward

        return losses, z_losses

    @staticmethod
    def backward(ctx, grad_losses, grad_z_losses):
        """Compute the gradient w.r.t. ``logits``; all other inputs get ``None``."""
        del grad_z_losses  # z_losses are only for logging.

        logits, lse, labels = ctx.saved_tensors
        # With inplace_backward the gradient overwrites the saved logits tensor.
        dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)
        n_rows, n_cols = logits.shape
        BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)
        num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)
        grid = lambda META: (n_rows, triton.cdiv(n_cols, META["BLOCK_SIZE"]))  # noqa
        # Need this, otherwise Triton tries to launch from cuda:0 and we get
        # ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
        with torch.cuda.device(logits.device.index):
            cross_entropy_bwd_kernel[grid](
                dlogits,  # data ptrs
                grad_losses,
                logits,
                lse,
                labels,
                ctx.smoothing,
                ctx.logit_scale,
                ctx.lse_square_scale,
                ctx.ignore_index,
                ctx.total_classes,
                ctx.class_start_idx,
                n_cols,  # shapes
                logits.stride(0),  # strides
                dlogits.stride(0),
                grad_losses.stride(0),
                BLOCK_SIZE=BLOCK_SIZE,  # constants
                num_warps=num_warps,
            )
        # One gradient per forward input (8 inputs): only `logits` is differentiable.
        return dlogits, None, None, None, None, None, None, None
def cross_entropy_loss(
    logits: torch.Tensor,
    labels: torch.Tensor,
    label_smoothing: float = 0.0,
    logit_scale: float = 1.0,
    lse_square_scale: float = 0.0,
    ignore_index=-100,
    inplace_backward: bool = False,
    process_group=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Functional entry point for the fused cross-entropy loss.

    Arguments:
        logits: (batch, vocab_size)
        labels: (batch,)
        label_smoothing: float. Label smoothing factor.
        logit_scale: float. The logits are multiplied by this scale before the loss
            is calculated.
        lse_square_scale: float. If > 0, lse_square_scale * lse(logits) ^ 2 is added
            to the loss (the "z-loss").
        ignore_index: int. Entries where labels == ignore_index get a loss of 0.0.
        inplace_backward: bool. If True, the backward pass is done in-place by
            overwriting the logits, which saves memory.
        process_group: if not None, we're doing Tensor Parallel: each process holds
            one part of the vocab and the loss is aggregated across processes.
    Returns:
        losses: (batch,), float
        z_losses: (batch,), float
    """
    # Arguments are forwarded positionally, in the order of CrossEntropyLoss.forward's
    # parameters (after ctx).
    forward_args = (
        logits,
        labels,
        label_smoothing,
        logit_scale,
        lse_square_scale,
        ignore_index,
        inplace_backward,
        process_group,
    )
    return CrossEntropyLoss.apply(*forward_args)