import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []
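
# Nothing is exported publicly: the jiterator entry points below
# (_create_jit_fn, _create_multi_output_jit_fn) are intentionally
# private while the API is in beta.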


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"
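
        # Example of a declaration the assembled pattern is intended to match:
        #   template <typename T> T my_op(T x, T y) { return x + y; }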

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )

        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]
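

# A minimal sketch of what _CodeParser extracts, on a hypothetical input:
#
#   parsed = _CodeParser(
#       "template <typename T> T add_one(T x) { return x + 1; }"
#   )
#   parsed.template_params  -> "<typename T>"
#   parsed.return_type      -> "T"
#   parsed.function_name    -> "add_one"
#   parsed.function_params  -> "(T x)"
#   parsed.function_body    -> "{ return x + 1; }"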


class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for a single output."
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # defer checking CUDA availability until the function is invoked.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        expanded_kwargs = self.kwargs_dict.copy()
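        # Call-site kwargs may override the defaults captured at creation time,
        # but only for keys that were declared when the function was created.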
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """

    Create a jiterator-generated cuda kernel for an elementwise op.



    The code string has to be a valid CUDA function that describes the computation for a single element. The code

    string has to follow the c++ template pattern, as shown in the example below. This function will be inlined

    into elementwise kernel template, and compiled on the fly. Compiled kernel will be cached in memory, as well as

    local temp dir.



    Jiterator-generated kernels accepts noncontiguous tensors, and supports broadcasting and type promotion.



    Args:

        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.

        kwargs (Dict, optional): Keyword arguments for generated function



    Example::



        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"

        jitted_fn = create_jit_fn(code_string, alpha=1.0)

        a = torch.rand(3, device='cuda')

        b = torch.rand(3, device='cuda')

        # invoke jitted function like a regular python function

        result = jitted_fn(a, b, alpha=3.14)



    code_string also allows multiple function definitions, and the last function will be treated as the entry function.



    Example::



        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"

        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"

        jitted_fn = create_jit_fn(code_string, val=0.0)

        a = torch.rand(3, device='cuda')

        b = torch.rand(3, device='cuda')

        # invoke jitted function like a regular python function

        result = jitted_fn(a, b)  # using default val=0.0



    Jiterator can be used together with python registration to override an operator's cuda kernel.

    Following example is overriding gelu's cuda kernel with relu.



    Example::



        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"

        my_gelu = create_jit_fn(code_string)

        my_lib = torch.library.Library("aten", "IMPL")

        my_lib.impl('aten::gelu', my_gelu, "CUDA")

        # torch.nn.GELU and torch.nn.function.gelu are now overridden

        a = torch.rand(3, device='cuda')

        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))



    .. warning::

        This API is in beta and may change in future releases.



    .. warning::

        This API only supports up to 8 inputs and 1 output



    .. warning::

        All input tensors must live in CUDA device

    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)


def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """

    Create a jiterator-generated cuda kernel for an elementwise op that supports returning one or more outputs.



    Args:

        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return value by reference.

        num_outputs(int): number of outputs return by the kernel

        kwargs (Dict, optional): Keyword arguments for generated function



    Example::



        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"

        jitted_fn = create_jit_fn(code_string, alpha=1.0)

        a = torch.rand(3, device='cuda')

        b = torch.rand(3, device='cuda')

        # invoke jitted function like a regular python function

        result = jitted_fn(a, b, alpha=3.14)



    .. warning::

        This API is in beta and may change in future releases.



    .. warning::

        This API only supports up to 8 inputs and 8 outputs

    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )