from typing import Optional

import torch

try:
    from ._ops import ops
except ImportError as e:
    # Fall back to the extension registered directly with torch.ops when
    # the generated ._ops shim is unavailable.
    try:
        import _quantization

        ops = torch.ops._quantization
    except ImportError:
        # Neither binding is present; surface the original import error.
        raise e

def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool:
    """Return True if the CUTLASS scaled-mm FP8 kernel supports the given
    CUDA device capability (encoded as an integer, e.g. 90 for compute
    capability 9.0)."""
    return ops.cutlass_scaled_mm_supports_fp8(cuda_device_capability)

def cutlass_scaled_mm(a: torch.Tensor,
                      b: torch.Tensor,
                      scale_a: torch.Tensor,
                      scale_b: torch.Tensor,
                      out_dtype: torch.dtype,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Matrix-multiply the quantized operands a and b, applying the
    dequantization scales scale_a and scale_b and optionally adding bias,
    producing an (m, n) result in out_dtype."""
    # Kernel preconditions: both dimensions of b must be multiples of 16,
    # the output dtype must be fp16/bf16, and any bias must match the
    # output's width and dtype.
    assert b.shape[0] % 16 == 0 and b.shape[1] % 16 == 0
    assert out_dtype is torch.bfloat16 or out_dtype is torch.float16
    assert bias is None or (bias.shape[0] == b.shape[1]
                            and bias.dtype == out_dtype)

    m = a.shape[0]
    n = b.shape[1]

    # Allocate the output buffer; the kernel writes the result in place.
    out = torch.empty((m, n), dtype=out_dtype, device=a.device)
    ops.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
    return out
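

# Illustrative usage sketch, not part of the original module: quantize two
# fp16 matrices to FP8 with per-tensor scales and multiply them. The choice
# of torch.float8_e4m3fn, the 448.0 e4m3 max value, the 1-element fp32
# scale tensors, the column-major layout for b, and the 10 * major + minor
# capability encoding are assumptions made for this demo, not guarantees
# made by this file.
if __name__ == "__main__":
    assert torch.cuda.is_available(), "demo requires a CUDA device"
    major, minor = torch.cuda.get_device_capability()
    if not cutlass_scaled_mm_supports_fp8(10 * major + minor):
        raise SystemExit("FP8 scaled-mm is not supported on this device")

    m, k, n = 32, 64, 128
    a_fp16 = torch.randn(m, k, device="cuda", dtype=torch.float16)
    b_fp16 = torch.randn(k, n, device="cuda", dtype=torch.float16)

    # Per-tensor symmetric quantization to FP8 (e4m3 finite max is 448.0).
    scale_a = (a_fp16.abs().amax() / 448.0).to(torch.float32).reshape(1)
    scale_b = (b_fp16.abs().amax() / 448.0).to(torch.float32).reshape(1)
    a_fp8 = (a_fp16 / scale_a).to(torch.float8_e4m3fn)
    # Lay b out column-major, which CUTLASS scaled-mm kernels commonly
    # expect for the second operand.
    b_fp8 = (b_fp16 / scale_b).to(torch.float8_e4m3fn).t().contiguous().t()

    out = cutlass_scaled_mm(a_fp8, b_fp8, scale_a, scale_b, torch.float16)
    print(out.shape, out.dtype)  # torch.Size([32, 128]) torch.float16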