|
import cupy |
|
import os |
|
import re |
|
import torch |
|
import typing |
|
from pathlib import Path |
|
import platform |
|
|
|
|
|
|
|
|
|
# Module-level cache. Holds the current GPU's name under the "device" key and,
# for each specialization key produced by cuda_kernel(), a dict of the form
# {"strFunction": ..., "strKernel": ...} with the fully substituted source.
objCudacache = {}
|
|
|
|
|
def cuda_int32(intIn: int):
    """Wrap a Python int as a ``cupy.int32`` scalar (e.g. for kernel arguments)."""
    objScalar = cupy.int32(intIn)
    return objScalar
|
|
|
|
|
|
|
|
|
|
|
def cuda_float32(fltIn: float):
    """Wrap a Python float as a ``cupy.float32`` scalar (e.g. for kernel arguments)."""
    objScalar = cupy.float32(fltIn)
    return objScalar
|
|
|
|
|
|
|
|
|
|
|
def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict, **replace_kwargs) -> str:
    """Specialize a CUDA kernel template and cache the resulting source.

    Substitutes ``{{name}}`` placeholders with the values in *objVariables*,
    expands the ``SIZE_d(tensor)``, ``OFFSET_n(tensor, i0, ...)`` and
    ``VALUE_n(tensor, i0, ...)`` macros using the tensors' sizes and strides,
    applies the verbatim *replace_kwargs* substitutions, and stores the final
    source in ``objCudacache`` under a key that uniquely identifies this
    specialization.

    Parameters
    ----------
    strFunction:
        Name of the kernel entry point inside *strKernel*.
    strKernel:
        CUDA C source text containing the template placeholders.
    objVariables:
        Mapping of template names to ``None`` / bool / int / float / str
        scalars or ``torch.Tensor`` arguments; a tensor's dtype determines the
        ``{{type}}`` substitution, and its size/stride drive the macros above.
    replace_kwargs:
        Extra literal ``old -> new`` string replacements applied to the source.

    Returns
    -------
    str
        The cache key under which the specialized kernel source was stored;
        pass it to ``cuda_launch`` to compile and retrieve the function.

    Raises
    ------
    AssertionError
        If a variable has an unsupported type, or a tensor has an unsupported
        dtype.
    """
    if "device" not in objCudacache:
        # The device name participates in every cache key, so a kernel
        # specialized on one GPU model is never reused on a different one.
        objCudacache["device"] = torch.cuda.get_device_name()

    # Build a key that captures the full specialization: function name, every
    # variable (scalars by value, tensors by dtype/shape/stride), device name.
    strKey = strFunction

    for strVariable, objValue in objVariables.items():
        strKey += strVariable

        if objValue is None:
            continue
        elif isinstance(objValue, bool):  # before int: bool is an int subclass
            strKey += str(objValue)
        elif isinstance(objValue, (int, float)):
            strKey += str(objValue)
        elif isinstance(objValue, str):
            strKey += objValue
        elif isinstance(objValue, torch.Tensor):
            strKey += str(objValue.dtype) + str(objValue.shape) + str(objValue.stride())
        else:
            raise AssertionError(f"unsupported variable type: {strVariable} {type(objValue)}")

    strKey += objCudacache["device"]

    if strKey not in objCudacache:
        # Tensor dtype -> C type used for the `{{type}}` placeholder.
        objCtypes = {
            torch.uint8: "unsigned char",
            torch.float16: "half",
            torch.float32: "float",
            torch.float64: "double",
            torch.int32: "int",
            torch.int64: "long",
        }

        # 1) Plain `{{name}}` placeholder substitution (tensors set `{{type}}`).
        for strVariable, objValue in objVariables.items():
            if objValue is None:
                continue
            elif isinstance(objValue, (bool, int, float)):
                strKernel = strKernel.replace("{{" + strVariable + "}}", str(objValue))
            elif isinstance(objValue, str):
                strKernel = strKernel.replace("{{" + strVariable + "}}", objValue)
            elif isinstance(objValue, torch.Tensor):
                if objValue.dtype not in objCtypes:
                    raise AssertionError(f"unsupported tensor dtype: {strVariable} {objValue.dtype}")
                strKernel = strKernel.replace("{{type}}", objCtypes[objValue.dtype])
            else:
                raise AssertionError(f"unsupported variable type: {strVariable} {type(objValue)}")

        # 2) SIZE_d(tensor) -> the literal value of tensor.size(d).
        while True:
            objMatch = re.search(r"(SIZE_)([0-4])(\()([^\)]*)(\))", strKernel)

            if objMatch is None:
                break

            intDim = int(objMatch.group(2))
            strTensor = objMatch.group(4)
            intSizes = objVariables[strTensor].size()

            strKernel = strKernel.replace(objMatch.group(), str(intSizes[intDim]))

        # 3) OFFSET_n(tensor, i0, ..., i{n-1}) -> flat element offset computed
        # from the tensor's strides. `{`/`}` in an index expression are
        # template shorthand for parentheses.
        while True:
            objMatch = re.search(r"(OFFSET_)([0-4])(\()([^\)]+)(\))", strKernel)

            if objMatch is None:
                break

            intArgs = int(objMatch.group(2))
            strArgs = objMatch.group(4).split(",")

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()
            strIndex = [
                "((" + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + ")*" + str(intStrides[intArg]) + ")"
                for intArg in range(intArgs)
            ]

            strKernel = strKernel.replace(objMatch.group(0), "(" + str.join("+", strIndex) + ")")

        # 4) VALUE_n(tensor, i0, ..., i{n-1}) -> tensor[flat offset]. Unlike the
        # regexes above, the index expressions may contain nested parentheses,
        # so the matching closing parenthesis is found by explicit counting.
        while True:
            objMatch = re.search(r"(VALUE_)([0-4])(\()", strKernel)

            if objMatch is None:
                break

            intStart = objMatch.span()[1]
            intStop = objMatch.span()[1]
            intParentheses = 1

            while True:
                intParentheses += 1 if strKernel[intStop] == "(" else 0
                intParentheses -= 1 if strKernel[intStop] == ")" else 0

                if intParentheses == 0:
                    break

                intStop += 1

            intArgs = int(objMatch.group(2))
            strArgs = strKernel[intStart:intStop].split(",")

            assert intArgs == len(strArgs) - 1  # macro arity must match its suffix

            strTensor = strArgs[0]
            intStrides = objVariables[strTensor].stride()

            strIndex = [
                "((" + strArgs[intArg + 1].replace("{", "(").replace("}", ")").strip() + ")*" + str(intStrides[intArg]) + ")"
                for intArg in range(intArgs)
            ]

            strKernel = strKernel.replace(
                "VALUE_" + str(intArgs) + "(" + strKernel[intStart:intStop] + ")",
                strTensor + "[" + str.join("+", strIndex) + "]",
            )

        # 5) Caller-supplied verbatim replacements.
        for strReplace, strValue in replace_kwargs.items():
            strKernel = strKernel.replace(strReplace, strValue)

        objCudacache[strKey] = {"strFunction": strFunction, "strKernel": strKernel}

    return strKey
|
|
|
|
|
|
|
def get_cuda_home_path():
    """Locate the CUDA toolkit root for NVRTC.

    Returns ``$CUDA_HOME`` verbatim when set. Otherwise falls back to the
    ``lib`` directory bundled with the installed torch package, but only when
    that directory exists and contains an ``nvrtc-builtins`` library; returns
    ``None`` when no usable location is found.
    """
    if "CUDA_HOME" in os.environ:
        return os.environ["CUDA_HOME"]

    torch_lib_path = Path(torch.__file__).parent / "lib"

    if not torch_lib_path.exists():
        # Original fell off the end here; make the None return explicit.
        return None

    has_nvrtc = any("nvrtc-builtins" in entry.name for entry in torch_lib_path.iterdir())
    return str(torch_lib_path.resolve()) if has_nvrtc else None
|
|
|
@cupy.memoize(for_each_device=True)
def cuda_launch(strKey: str):
    """Compile the cached kernel source for *strKey* and return its function.

    *strKey* must be a key previously returned by ``cuda_kernel``. The result
    is memoized per CUDA device. Before compiling, ``CUDA_HOME``/``CUDA_PATH``
    are set so NVRTC can locate its headers, preferring the location reported
    by ``get_cuda_home_path()``.
    """
    cuda_home = get_cuda_home_path()
    if cuda_home is None:
        # Conventional default install location when nothing better is found.
        cuda_home = "/usr/local/cuda/"
    os.environ["CUDA_HOME"] = cuda_home
    os.environ["CUDA_PATH"] = cuda_home

    objCached = objCudacache[strKey]
    return cupy.RawModule(code=objCached["strKernel"]).get_function(objCached["strFunction"])
|
|