# Scraped from Hugging Face Space "Elite-text-gen-web" (status: Sleeping).
# Path: venv/lib/python3.10/site-packages/bitsandbytes/triton/quantize_rowwise.py
import math | |
import torch | |
import time | |
from bitsandbytes.triton.triton_utils import is_triton_available | |
if not is_triton_available():
    def quantize_rowwise(x: torch.Tensor):
        """Fallback used when Triton is not available: always returns ``None``.

        Callers must check for ``None`` rather than the ``(output, output_maxs)``
        tuple produced by the Triton implementation.
        """
        return None
else:
    import triton
    import triton.language as tl
    from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

    # rowwise quantize

    # TODO: autotune this better.
    #
    # One program instance quantizes one row: row `pid` occupies elements
    # [pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE) of the flat input buffer, so the
    # input must be contiguous row-major with row length BLOCK_SIZE. P2 is
    # BLOCK_SIZE rounded up to a power of two (tl.arange needs a power-of-two
    # extent); lanes at or beyond BLOCK_SIZE are masked off. n_elements is not
    # read by the kernel body (NOTE(review): presumably kept as an autotune key).
    #
    # @triton.jit is required: the launcher calls `_quantize_rowwise[grid](...)`,
    # and a plain Python function is not subscriptable.
    @triton.jit
    def _quantize_rowwise(
        x_ptr,
        output_ptr,
        output_maxs,
        n_elements,
        BLOCK_SIZE: tl.constexpr,
        P2: tl.constexpr,
    ):
        pid = tl.program_id(axis=0)
        block_start = pid * BLOCK_SIZE
        arange = tl.arange(0, P2)
        offsets = block_start + arange
        row_mask = arange < BLOCK_SIZE
        # Masked lanes are not loaded (their values are undefined).
        x = tl.load(x_ptr + offsets, mask=row_mask)

        abs_x = tl.abs(x)
        # Force masked lanes to 0 so they cannot win the row maximum.
        max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)
        # Scale into [-127, 127] and round to nearest integer.
        # NOTE(review): an all-zero row gives max_val == 0, i.e. 0/0 here; confirm the
        # stored value is acceptable. tl.libdevice matches the Triton version this file
        # targets; newer Triton releases relocated libdevice — confirm before upgrading.
        output = tl.libdevice.llrint(127. * (x / max_val))
        tl.store(output_ptr + offsets, output, mask=row_mask)
        # Row scale; stored through a float16 pointer by the launcher below.
        tl.store(output_maxs + pid, max_val)

    def quantize_rowwise(x: torch.Tensor):
        """Quantize each row of a 2-D CUDA tensor to int8 by its absolute maximum.

        Every element becomes ``round(127 * x / max(|row|))`` stored as int8, and
        each row's maximum is returned so the caller can dequantize.

        Args:
            x: CUDA tensor indexed as ``(rows, cols)``. NOTE(review): presumably a
                floating-point tensor (fp16/fp32) — confirm against callers.

        Returns:
            ``(output, output_maxs)``: an int8 tensor with the shape of ``x`` and a
            float16 tensor of shape ``(rows,)`` holding each row's max ``|x|``.
            For an empty input no kernel is launched and the maxima are zero.
        """
        assert x.is_cuda
        # The kernel addresses elements as row * cols + col, which is only correct
        # for contiguous row-major storage; this is a no-op when x already is.
        x = x.contiguous()
        output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)
        output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)
        if x.numel() == 0:
            # Nothing to launch; an empty row has no maximum, so report 0.
            return output, output_maxs.zero_()
        # Smallest power of two >= row length, computed with exact integer math.
        P2 = 1 << (x.shape[1] - 1).bit_length()
        n_elements = output.numel()
        grid = lambda meta: (x.shape[0],)
        _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)
        return output, output_maxs