import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )
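        # For reference, the assembled pattern is equivalent to:
        #   \s*template\s*(?P<template_params>\<.+\>)\s*(?P<return_type>\w+)\s+
        #   (?P<function_name>\w+)\s*(?P<function_params>\(.+\))\s*(?P<function_body>\{.+\})\s*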
        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]
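
# Illustrative sketch (not executed; the sample kernel string is hypothetical) of
# what _CodeParser extracts from a kernel definition:
#
#     parsed = _CodeParser("template <typename T> T my_op(T x) { return x + x; }")
#     parsed.template_params  # "<typename T>"
#     parsed.return_type      # "T"
#     parsed.function_name    # "my_op"
#     parsed.function_params  # "(T x)"
#     parsed.function_body    # "{ return x + x; }"
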
class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for single output. "
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()
    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy-initialization behavior:
        # defer the CUDA availability check to function invocation time.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."
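        # Call-time keyword arguments may only override scalar kwargs that were
        # declared when the function was created; unknown names raise KeyError below.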
        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated CUDA kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element.
    The code string has to follow the C++ template pattern, as shown in the example below. The function
    will be inlined into an elementwise kernel template and compiled on the fly. The compiled kernel is
    cached in memory, as well as in a local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with Python registration to override an operator's CUDA kernel.
    The following example overrides gelu's CUDA kernel with relu:

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output

    .. warning::
        All input tensors must be on a CUDA device
    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)
def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated CUDA kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for the generated function

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)
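
    A kernel can also declare more than one reference output. A minimal sketch (the kernel
    below is illustrative and not from the original documentation; one output tensor is
    expected back per reference parameter)::

        code_string = "template <typename T> void sincos_kernel(T x, T& s, T& c) { s = ::sin(x); c = ::cos(x); }"
        jitted_fn = create_multi_output_jit_fn(code_string, num_outputs=2)
        a = torch.rand(3, device='cuda')
        # two outputs, one per T& parameter
        s, c = jitted_fn(a)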
    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs
    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )