"""
This code is from AllenAI's Longformer:
https://github.com/allenai/longformer/
"""

from typing import Union
from functools import lru_cache

import torch
import os.path

class DiagonaledMM(torch.autograd.Function):
    '''Class to encapsulate tvm code for compiling a diagonal_mm function, in addition to calling
    this function from PyTorch
    '''

    function_dict = {}  # cache of compiled functions, keyed by (dtype, device)

    @staticmethod
    def _compile_function(dtype: str, device: str, b0: int = 4, b1: int = 4, b2: int = 16):
        '''Compiles a tvm function that computes diagonal_mm
        args:
        dtype: str in ['float64', 'float32', 'float16']
        device: str in ['cpu', 'cuda']
        b0, b1, b2: size of tensor tiles. Very important for good performance
        '''
        import tvm  # import tvm here, not at the top of the file, in case compilation is not needed
        from tvm.contrib import nvcc

        @tvm.register_func
        def tvm_callback_cuda_compile(code):
            """Use nvcc compiler for better perf."""
            ptx = nvcc.compile_cuda(code, target="ptx", arch='sm_52')
            return ptx

        assert dtype in ['float16', 'float32', 'float64']
        assert device in ['cpu', 'cuda']
        device = None if device == 'cpu' else device
        tgt_host = "llvm"

        b = tvm.var('b')  # batch size
        n = tvm.var('n')  # sequence length
        h = tvm.var('h')  # number of heads
        m = tvm.var('m')  # hidden dimension
        w = tvm.var('w')  # window size
        w_upper = tvm.var('w_upper')  # window size to the right of the word. Should be `0` or `w`
        padding = tvm.var('padding')  # padding value
        transpose_t1 = tvm.var('transpose_t1')  # whether t1 should be transposed
        t1d3 = tvm.var('t1d3')  # last dimension of t1
        t3d3 = tvm.var('t3d3')  # last dimension of t3 (the result tensor)
        X = tvm.placeholder((b, n, h, t1d3), name='X', dtype=dtype)  # first tensor
        Y = tvm.placeholder((b, n, h, m), name='Y', dtype=dtype)  # second tensor
        k = tvm.reduce_axis((0, t1d3), name='k')  # dimension to sum over
        D = tvm.placeholder((h,), name='D', dtype='int')  # dilation per head
        output_shape = (b, n, h, t3d3)  # shape of the result tensor
        algorithm = lambda l, i, q, j: tvm.sum(
            tvm.if_then_else(
                t3d3 == m,  # if the output's last dimension is m, then t1 is diagonaled and the output is not
                tvm.if_then_else(
                    transpose_t1 == 0,
                    tvm.if_then_else(
                        tvm.all(
                            i + D[q] * (k - w) >= 0,
                            i + D[q] * (k - w) < n,
                        ),
                        X[l, i, q, k] * Y[l, i + D[q] * (k - w), q, j],  # t1 is diagonaled
                        padding
                    ),
                    tvm.if_then_else(
                        tvm.all(
                            i + D[q] * (k - w_upper) >= 0,
                            i + D[q] * (k - w_upper) < n,
                        ),
                        X[l, i + D[q] * (k - w_upper), q, (w_upper + w) - k] * Y[l, i + D[q] * (k - w_upper), q, j],  # t1 is diagonaled and should be transposed
                        padding
                    ),
                ),
                tvm.if_then_else(
                    tvm.all(
                        i + D[q] * (j - w) >= 0,
                        i + D[q] * (j - w) < n,
                    ),
                    X[l, i, q, k] * Y[l, i + D[q] * (j - w), q, k],  # t1 is not diagonaled, but the output is
                    padding
                )
            ), axis=k)

        Z = tvm.compute(output_shape, algorithm, name='Z')
        s = tvm.create_schedule(Z.op)

        print('Lowering: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))

        # split long axes into smaller chunks and bind them to GPU blocks/threads
        ko, ki = s[Z].split(Z.op.reduce_axis[0], factor=b0)
        ZF = s.rfactor(Z, ki)

        j_outer, j_inner = s[Z].split(s[Z].op.axis[-1], factor=b1)
        i_outer, i_inner = s[Z].split(s[Z].op.axis[1], factor=b2)

        s[Z].bind(j_outer, tvm.thread_axis("blockIdx.x"))
        s[Z].bind(j_inner, tvm.thread_axis("threadIdx.y"))

        s[Z].bind(i_outer, tvm.thread_axis("blockIdx.y"))
        s[Z].bind(i_inner, tvm.thread_axis("threadIdx.z"))

        tx = tvm.thread_axis("threadIdx.x")
        s[Z].bind(s[Z].op.reduce_axis[0], tx)
        s[ZF].compute_at(s[Z], s[Z].op.reduce_axis[0])
        s[Z].set_store_predicate(tx.var.equal(0))

        print('Lowering with GPU splits: \n ===================== \n{}'.format(tvm.lower(s, [X, Y, D], simple_mode=True)))

        diagonaled_mm = tvm.build(s, [X, Y, Z, D, w, w_upper, padding, transpose_t1, t3d3], target=device, target_host=tgt_host, name='diagonaled_mm')
        return diagonaled_mm

    @staticmethod
    def _get_lib_filename(dtype: str, device: str):
        base_filename = 'longformer/lib/lib_diagonaled_mm'
        return '{}_{}_{}.so'.format(base_filename, dtype, device)

    @staticmethod
    def _save_compiled_function(f, dtype: str, device: str):
        if not os.path.exists('longformer/lib/'):
            os.makedirs('longformer/lib/')
        f.export_library(DiagonaledMM._get_lib_filename(dtype, device))

    @staticmethod
    def _load_compiled_function(dtype: str, device: str):
        from tvm.module import load
        filename = DiagonaledMM._get_lib_filename(dtype, device)
        current_dir = os.path.dirname(os.path.abspath(__file__))
        potential_dirs = ['../../', '../', './', f'{current_dir}/', f'{current_dir}/../']
        for potential_dir in potential_dirs:
            filepath = '{}{}'.format(potential_dir, filename)
            if os.path.isfile(filepath):
                print('Loading tvm binary from: {}'.format(filepath))
                return load(filepath)
        return None

    @staticmethod
    def _get_function(dtype: str, device: str):
        '''Loads the function from disk or compiles it, then caches it for reuse'''
        args = (dtype, device)
        if args not in DiagonaledMM.function_dict:
            diagonaled_mm = DiagonaledMM._load_compiled_function(dtype, device)  # try to load from disk
            if not diagonaled_mm:
                print('Tvm binary not found. Compiling ...')
                diagonaled_mm = DiagonaledMM._compile_function(dtype, device)  # compile
                DiagonaledMM._save_compiled_function(diagonaled_mm, dtype, device)  # save to disk
            # convert the tvm function into a pytorch-callable function via dlpack
            from tvm.contrib import dlpack
            diagonaled_mm_pytorch = dlpack.to_pytorch_func(diagonaled_mm)
            DiagonaledMM.function_dict[args] = diagonaled_mm_pytorch
        return DiagonaledMM.function_dict[args]

    @staticmethod
    def _diagonaled_mm(t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor, int],
                       is_t1_diagonaled: bool = False, transpose_t1: bool = False, padding: int = 0,
                       autoregressive: bool = False):
        '''Calls the compiled function after checking the input format. This function is called in three different modes.
        t1 x t2 = r ==> t1 and t2 are not diagonaled, but r is. Useful for query x key = attention_scores
        t1 x t2 = r ==> t1 is diagonaled, but t2 and r are not. Useful to compute attention_scores x value = context
        t1 x t2 = r ==> t1 is diagonaled and it should be transposed, but t2 and r are not diagonaled. Useful in some of
                        the calculations in the backward pass.
        '''
        dtype = str(t1.dtype).split('.')[1]
        device = t1.device.type
        assert len(t1.shape) == 4
        assert len(t1.shape) == len(t2.shape)
        assert t1.shape[:3] == t2.shape[:3]
        if isinstance(d, int):  # if d is an int, use the same dilation for all heads
            d = t1.new_full(size=(t1.shape[2],), fill_value=d, dtype=torch.int, requires_grad=False)

        assert len(d.shape) == 1
        assert d.shape[0] == t1.shape[2]  # number of dilation values should match number of heads

        b = t1.shape[0]  # batch size
        n = t1.shape[1]  # sequence length
        h = t1.shape[2]  # number of heads
        m = t2.shape[3]  # hidden dimension
        w_upper = 0 if autoregressive else w
        c = w_upper + w + 1  # number of diagonals
        if is_t1_diagonaled:
            assert t1.shape[3] == c
            r = t1.new_empty(b, n, h, m)  # allocate space for the result tensor
        else:
            assert not transpose_t1
            assert t1.shape[3] == m
            r = t1.new_empty(b, n, h, c)  # allocate space for the result tensor

        # get the compiled tvm function for this dtype/device (compile it if necessary)
        _diagonaled_mm_function = DiagonaledMM._get_function(dtype=dtype, device=device)

        # The last kernel argument (the size of the output's last dimension) is also how the kernel
        # decides which mode to run, so the hidden dimension and the number of diagonals must differ.
        if m == c:
            print(f'Error: the hidden dimension {m} shouldn\'t match number of diagonals {c}')
            assert False
        _diagonaled_mm_function(t1, t2, r, d, w, w_upper, padding, transpose_t1, m if is_t1_diagonaled else c)
        return r
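
    # Illustrative (hypothetical) shapes for the three modes above, assuming batch_size=1, seq_len=512,
    # num_heads=4, hidden_size=64, w=16 and autoregressive=False, so c = 2 * 16 + 1 = 33 diagonals:
    #   mode 1: t1 (1, 512, 4, 64) x t2 (1, 512, 4, 64) -> r (1, 512, 4, 33)   (query x key -> attention_scores)
    #   mode 2: t1 (1, 512, 4, 33) x t2 (1, 512, 4, 64) -> r (1, 512, 4, 64)   (attention_scores x value -> context)
    #   mode 3: same shapes as mode 2 but with transpose_t1=True; used in the backward pass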

    @staticmethod
    def _prepare_tensors(t):
        '''Fix `stride()` information of input tensor. This addresses some inconsistency in stride information in PyTorch.
        For a tensor t, if t.size(0) == 1, then the value of t.stride()[0] doesn't matter.
        TVM expects this value to be the `product(t.size()[1:])` but PyTorch sometimes sets it to `t.stride()[1]`.
        Here's an example to reproduce this issue:
            import torch
            print(torch.randn(1, 10).stride())
            > (10, 1)
            print(torch.randn(10, 1).t().contiguous().stride())
            > (1, 1)  # expected it to be (10, 1) as above
            print(torch.randn(10, 2).t().contiguous().stride())
            > (10, 1)  # but gets the expected stride if the first dimension is > 1
        '''
        assert t.is_contiguous()
        t_stride = list(t.stride())
        t_size = list(t.size())
        # Fix wrong stride information for the first dimension. This occurs when batch_size=1
        if t_size[0] == 1 and t_stride[0] == t_stride[1]:
            # In this case, the stride of the first dimension should be the product
            # of the sizes of all other dimensions
            t_stride[0] = t_size[1] * t_size[2] * t_size[3]
            t = t.as_strided(size=t_size, stride=t_stride)
        return t

    min_seq_len = 16  # shorter sequences trigger splitting errors in the compiled kernel

    @staticmethod
    def forward(ctx, t1: torch.Tensor, t2: torch.Tensor, w: int, d: Union[torch.Tensor, int], is_t1_diagonaled: bool = False, padding: int = 0, autoregressive: bool = False) -> torch.Tensor:
        '''Computes diagonal_mm of t1 and t2.
        args:
        t1: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals).
            t1 can be a regular tensor (e.g. `query_layer`) or a diagonaled one (e.g. `attention_scores`)
        t2: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size). This is always a non-diagonaled
            tensor, e.g. `key_layer` or `value_layer`
        w: int = window size; number of attentions on each side of the word
        d: torch.Tensor or int = dilation of attentions per attention head. If int, the same dilation value will be used for all
            heads. If torch.Tensor, it should be 1D of length=number of attention heads
        is_t1_diagonaled: is t1 a diagonaled or a regular tensor
        padding: the padding value to use when accessing invalid locations. This is mainly useful when the padding
            needs to be a very large negative value (to compute softmax of attentions). For other use cases,
            please use zero padding.
        autoregressive: if true, return only the lower triangle
        returns: torch.Tensor = (batch_size, seq_len, num_attention_heads, hidden_size|number_of_diagonals)
            if t1 is diagonaled, result is non-diagonaled, and vice versa
        '''
        batch_size, seq_len, num_attention_heads, hidden_size = t1.size()
        assert seq_len >= DiagonaledMM.min_seq_len, 'avoid splitting errors by using seq_len >= {}'.format(DiagonaledMM.min_seq_len)
        ctx.save_for_backward(t1, t2)
        ctx.w = w
        ctx.d = d
        ctx.is_t1_diagonaled = is_t1_diagonaled
        ctx.autoregressive = autoregressive
        t1 = DiagonaledMM._prepare_tensors(t1)
        t2 = DiagonaledMM._prepare_tensors(t2)
        output = DiagonaledMM._diagonaled_mm(t1, t2, w, d, is_t1_diagonaled=is_t1_diagonaled, padding=padding, autoregressive=autoregressive)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        t1, t2 = ctx.saved_tensors
        w = ctx.w
        d = ctx.d
        is_t1_diagonaled = ctx.is_t1_diagonaled
        autoregressive = ctx.autoregressive
        if not grad_output.is_contiguous():
            grad_output = grad_output.contiguous()  # the tvm kernel expects contiguous tensors
        grad_output = DiagonaledMM._prepare_tensors(grad_output)
        t1 = DiagonaledMM._prepare_tensors(t1)
        t2 = DiagonaledMM._prepare_tensors(t2)
        # gradient wrt t1: multiply grad_output by t2, with the diagonaled/non-diagonaled roles
        # of the operands flipped relative to the forward pass
        grad_t1 = DiagonaledMM._diagonaled_mm(grad_output, t2, w, d, is_t1_diagonaled=not is_t1_diagonaled, autoregressive=autoregressive)
        # gradient wrt t2: the diagonaled operand needs to be transposed
        if is_t1_diagonaled:
            grad_t2 = DiagonaledMM._diagonaled_mm(t1, grad_output, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
        else:
            grad_t2 = DiagonaledMM._diagonaled_mm(grad_output, t1, w, d, is_t1_diagonaled=True, transpose_t1=True, autoregressive=autoregressive)
        # one gradient per forward input: t1, t2, w, d, is_t1_diagonaled, padding, autoregressive
        return grad_t1, grad_t2, None, None, None, None, None


def _get_invalid_locations_mask_fixed_dilation(seq_len: int, w: int, d: int):
    diagonals_list = []
    for j in range(-d * w, d, d):
        diagonal_mask = torch.zeros(seq_len, device='cpu', dtype=torch.uint8)
        diagonal_mask[:-j] = 1
        diagonals_list.append(diagonal_mask)
    return torch.stack(diagonals_list, dim=-1)
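
# A small worked example (hypothetical inputs, shown only for illustration):
# _get_invalid_locations_mask_fixed_dilation(seq_len=2, w=2, d=1) iterates j over (-2, -1, 0) and returns
#     tensor([[1, 1, 0],
#             [1, 0, 0]], dtype=torch.uint8)
# i.e. position 0 cannot attend to relative offsets -2 and -1 (they fall before the start of the sequence),
# and position 1 cannot attend to offset -2. These locations are later filled with -inf by mask_invalid_locations.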

@lru_cache()
def _get_invalid_locations_mask(w: int, d: Union[torch.Tensor, int], autoregressive: bool, device: str):
    if isinstance(d, int):
        affected_seq_len = w * d
        mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
        mask = mask[None, :, None, :]
    else:
        affected_seq_len = w * d.max()
        head_masks = []
        d_list = d.cpu().numpy().tolist()
        for d in d_list:
            one_head_mask = _get_invalid_locations_mask_fixed_dilation(affected_seq_len, w, d)
            head_masks.append(one_head_mask)
        mask = torch.stack(head_masks, dim=-2)
        mask = mask[None, :, :, :]

    ending_mask = None if autoregressive else mask.flip(dims=(1, 3)).bool().to(device)
    return affected_seq_len, mask.bool().to(device), ending_mask

def mask_invalid_locations(input_tensor: torch.Tensor, w: int, d: Union[torch.Tensor, int], autoregressive: bool) -> None:
    '''Fills (in place) the attention-score locations that fall outside the sequence boundaries with -inf.'''
    affected_seq_len, beginning_mask, ending_mask = _get_invalid_locations_mask(w, d, autoregressive, input_tensor.device)
    seq_len = input_tensor.size(1)
    beginning_input = input_tensor[:, :affected_seq_len, :, :w+1]
    beginning_mask = beginning_mask[:, :seq_len].expand(beginning_input.size())
    beginning_input.masked_fill_(beginning_mask, -float('inf'))
    if not autoregressive:
        ending_input = input_tensor[:, -affected_seq_len:, :, -(w+1):]
        ending_mask = ending_mask[:, -seq_len:].expand(ending_input.size())
        ending_input.masked_fill_(ending_mask, -float('inf'))


diagonaled_mm = DiagonaledMM.apply
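

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module. It assumes a CUDA device and that a tvm kernel
    # can be loaded from longformer/lib/ or compiled on the fly; the sizes below are illustrative only.
    batch_size, seq_len, num_heads, head_dim, w = 1, 512, 4, 64, 16

    query = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', requires_grad=True)
    key = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', requires_grad=True)
    value = torch.randn(batch_size, seq_len, num_heads, head_dim, device='cuda', requires_grad=True)

    # mode 1: query x key -> banded attention scores with 2 * w + 1 diagonals
    scores = diagonaled_mm(query, key, w, 1, False, 0, False)
    mask_invalid_locations(scores, w, 1, False)  # -inf for positions outside the sequence
    probs = torch.nn.functional.softmax(scores, dim=-1)

    # mode 2: attention probabilities (diagonaled) x value -> context vectors
    context = diagonaled_mm(probs, value, w, 1, True, 0, False)
    print(scores.shape, context.shape)  # expected: (1, 512, 4, 33) and (1, 512, 4, 64)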