Spaces:
Running
Running
import re | |
from typing import Callable, List | |
import torch | |
from torch import Tensor | |
__all__: List[str] = [] | |
class _CodeParser: | |
def __init__(self, code_string: str): | |
optional_ws = r"\s*" | |
required_ws = r"\s+" | |
template_params = r"(?P<template_params>\<.+\>)" | |
return_type = r"(?P<return_type>\w+)" | |
function_name = r"(?P<function_name>\w+)" | |
function_params = r"(?P<function_params>\(.+\))" | |
function_body = r"(?P<function_body>\{.+\})" | |
pattern = ( | |
optional_ws | |
+ "template" | |
+ optional_ws | |
+ template_params | |
+ optional_ws | |
+ return_type | |
+ required_ws | |
+ function_name | |
+ optional_ws | |
+ function_params | |
+ optional_ws | |
+ function_body | |
+ optional_ws | |
) | |
result = re.match( | |
pattern, code_string, re.DOTALL | |
) # DOTALL for matching multiline | |
if result is None: | |
raise Exception( | |
f"Couldn't parse code, please check correctness:\n {code_string}" | |
) | |
self.template_params = result["template_params"] | |
self.return_type = result["return_type"] | |
self.function_name = result["function_name"] | |
self.function_params = result["function_params"] | |
self.function_body = result["function_body"] | |
class _JittedFunction: | |
def __init__( | |
self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs | |
): | |
self.code_string = code_string | |
assert ( | |
return_by_ref or num_outputs == 1 | |
), "Return by value only works for single output. " | |
self.return_by_ref = return_by_ref | |
self.num_outputs = num_outputs | |
parsed_code = _CodeParser(code_string) | |
self.kernel_name = parsed_code.function_name | |
self.kwargs_dict = kwargs | |
self.is_cuda_available = torch.cuda.is_available() | |
def __call__(self, *tensors: Tensor, **kwargs): | |
# Jiterator follow torch.cuda's lazy initialization behavior | |
# Defer checking cuda's availability at the function invocation time | |
assert ( | |
self.is_cuda_available | |
), "Jiterator is only supported on CUDA and ROCm GPUs, none are available." | |
assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs." | |
expanded_kwargs = self.kwargs_dict.copy() | |
for key, value in kwargs.items(): | |
if key in self.kwargs_dict: | |
expanded_kwargs[key] = value | |
else: | |
raise KeyError(f"{key} is not declared in function definition") | |
return torch._C._cuda_jiterator_compile_and_launch_kernel( | |
self.code_string, | |
self.kernel_name, | |
self.return_by_ref, | |
self.num_outputs, | |
tensors, | |
expanded_kwargs, | |
) | |
def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op.

    The code string has to be a valid CUDA function that describes the computation for a single element. The code
    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined
    into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as
    local temp dir.

    Jiterator-generated kernels accept noncontiguous tensors, and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for generated function

    Returns:
        A callable built with ``return_by_ref=False`` and ``num_outputs=1``.

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = _create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    code_string also allows multiple function definitions, and the last function will be treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = _create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with python registration to override an operator's cuda kernel.
    Following example is overriding gelu's cuda kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = _create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output

    .. warning::
        All input tensors must live in CUDA device
    """
    # Single-output kernels hand their result back by value.
    single_output_fn = _JittedFunction(
        code_string, return_by_ref=False, num_outputs=1, **kwargs
    )
    return single_output_fn
def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.
        num_outputs (int): number of outputs returned by the kernel
        kwargs (Dict, optional): Keyword arguments for generated function

    Returns:
        A callable built with ``return_by_ref=True`` and the given ``num_outputs``.

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        # the kernel writes its result through a reference, so it must be built
        # with the multi-output factory and an explicit output count
        jitted_fn = _create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs
    """
    # Multi-output kernels write their results through reference parameters.
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )