Question about generating Triton kernels
Hi,
I'm using this script to generate Triton kernels from Level_1 problems of KernelBench.
from kernelllm import KernelLLM
from datasets import load_dataset
import json
from tqdm import tqdm
# Initialize the model
model = KernelLLM()
# Load the KernelBench dataset, level_1
dataset = load_dataset("ScalingIntelligence/KernelBench", split="level_1")
# Process each example in the dataset
results = []
for example in tqdm(dataset, desc="Processing examples"):
    # Get the PyTorch code from the example
    pytorch_code = f'''
{example["code"]}
'''
    # Generate optimized Triton code
    optimized_code = model.generate_triton(pytorch_code, max_new_tokens=4096)
    # Create a new example with the original data plus the optimized code
    result = dict(example)
    result["triton_generate"] = optimized_code
    # Add to results
    results.append(result)
# Save results to JSONL file
output_file = "kernelllm_kernelbench_optimized.jsonl"
with open(output_file, "w") as f:
    for result in results:
        f.write(json.dumps(result) + "\n")
print(f"Results saved to {output_file}")
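If the raw output of generate_triton still contains Markdown fences (the prompt asks the model to "Output the new code in codeblocks"), saving it unchanged will not compile. A minimal cleanup sketch, assuming the answer is the last fenced block in the output; extract_code is a hypothetical helper, not part of KernelLLM:

```python
import re

# A fenced block: three backticks, an optional language tag, a newline, then the body.
_FENCE_RE = re.compile(r"`{3}[\w+-]*\n(.*?)`{3}", re.DOTALL)


def extract_code(generation: str) -> str:
    """Return the body of the last fenced code block, or the stripped text if there is none."""
    blocks = _FENCE_RE.findall(generation)
    if not blocks:
        return generation.strip()
    return blocks[-1].strip()
```

In the loop above, this would become result["triton_generate"] = extract_code(optimized_code).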
Could you please check if I am using the proper prompt format?
I am experiencing an issue where the model generates Triton code, but in the forward function it still calls the PyTorch methods instead of calling the Triton kernel.
I also noticed that the model is sensitive to the prompt format. For example, if I use pytorch_code = example["code"], the model does not include import torch in its response, which causes compilation errors.
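Until the prompt issue is sorted out, a post-processing workaround is to prepend the standard imports that the generated code uses but never imports. This is a sketch that assumes the names torch, nn, triton and tl always refer to the usual modules; add_missing_imports is hypothetical and patches the symptom rather than the model's behavior.

```python
import ast

# Hypothetical mapping from a name used in generated code to the import that provides it.
STANDARD_IMPORTS = {
    "torch": "import torch",
    "nn": "import torch.nn as nn",
    "triton": "import triton",
    "tl": "import triton.language as tl",
}


def _bound_by_imports(tree: ast.AST) -> set:
    """Names bound by import statements anywhere in the module."""
    names = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # "import torch.nn" binds "torch"; "import torch.nn as nn" binds only "nn".
                names.add(alias.asname or alias.name.split(".")[0])
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                names.add(alias.asname or alias.name)
    return names


def add_missing_imports(source: str) -> str:
    """Prepend standard imports for names that are used but never imported."""
    tree = ast.parse(source)
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    bound = _bound_by_imports(tree)
    missing = [stmt for name, stmt in STANDARD_IMPORTS.items() if name in used and name not in bound]
    return "\n".join(missing + [source]) if missing else source
```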
Additionally, if possible, could you share any correct kernels you were able to generate with KernelLLM for Level 1?
Thank you. I appreciate your help.
Hi @tehranixyz,
Here is an example prompt / answer / eval result triplet from level 1 that I obtained:
The model doesn't always solve this problem correctly, but sampling a couple of times gave me passing solutions, which is also reflected in the large gap between the pass@1 and pass@k evaluation results.
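For reference, pass@k is commonly computed with the unbiased estimator from the Codex/HumanEval paper, pass@k = 1 - C(n-c, k) / C(n, k) for n samples of which c pass; I'm assuming the numbers discussed here use that definition. With illustrative (not measured) counts of 2 passing samples out of 10, pass@1 is 0.2 while pass@5 is about 0.78, which shows how large the gap can get.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate from n samples, c of which are correct.

    Probability that at least one of k samples, drawn without replacement
    from the n, is correct: 1 - C(n - c, k) / C(n, k).
    """
    if n - c < k:
        # Every size-k subset contains at least one correct sample.
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)
```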
Perhaps comparing the full prompts that the model sees can help troubleshoot this?
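A minimal sketch of such a comparison with difflib from the standard library. architecture_section copies the text that surrounds the code in the example prompt; whether KernelLLM assembles its prompt exactly this way is an assumption, and both function names are hypothetical. The diff makes visible that the f-string wrapper in the script adds a blank line before and after the code inside the fences.

```python
import difflib

FENCE = "`" * 3  # three backticks, as in the example prompt


def architecture_section(code: str) -> str:
    """Hypothetical reconstruction of the part of the prompt that embeds the PyTorch code."""
    return (
        "You are given the following architecture:\n" + FENCE + "\n"
        + code
        + "\n" + FENCE + "\n\nOptimize the architecture named Model with custom Triton kernels!"
    )


def prompt_diff(a: str, b: str) -> list:
    """Line-by-line unified diff between two prompt strings."""
    return list(
        difflib.unified_diff(
            a.splitlines(), b.splitlines(), fromfile="prompt_a", tofile="prompt_b", lineterm=""
        )
    )


if __name__ == "__main__":
    code = "import torch\nimport torch.nn as nn"
    plain = architecture_section(code)
    wrapped = architecture_section(f"\n{code}\n")  # equivalent to the script's f''' ''' wrapper
    print("\n".join(prompt_diff(plain, wrapped)))
```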
{
"prompt": '<|begin_of_text|>You write custom Triton kernels to replace the pytorch operators in the given architecture to get speedups. \n\nYou have complete freedom to choose the set of operators you want to replace. You may make the decision to replace some operators with custom Triton kernels and leave others unchanged. You may replace multiple operators with custom implementations, consider operator fusion opportunities (combining multiple operators into a single kernel, for example, combining matmul+relu), or algorithmic changes (such as online softmax). You are only limited by your imagination.\n\n\nHere\'s an example to show you the syntax of inline embedding custom operators from the Triton DSL in torch: The example given architecture is:\n```\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\n\n\nclass Model(nn.Module):\n def __init__(self) -> None:\n super().__init__()\n\n def forward(self, a, b):\n return a + b\n\n\ndef get_inputs():\n # randomly generate input tensors based on the model architecture\n a = torch.randn(1, 128).cuda()\n b = torch.randn(1, 128).cuda()\n return [a, b]\n\n\ndef get_init_inputs():\n # randomly generate tensors required for initialization based on the model architecture\n return []\n\n```\nThe example new arch with custom Triton kernels looks like this:\n```\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef add_kernel(\n x_ptr, # Pointer to first input\n y_ptr, # Pointer to second input\n out_ptr, # Pointer to output\n n_elements, # Total number of elements in input/output\n BLOCK_SIZE: tl.constexpr,\n):\n # Each program handles a contiguous block of data of size BLOCK_SIZE\n block_start = tl.program_id(0) * BLOCK_SIZE\n # Create a range of offsets [0..BLOCK_SIZE-1]\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # Mask to ensure we don\'t go out of bounds\n mask = offsets < n_elements\n # Load input values\n x = tl.load(x_ptr + offsets, mask=mask, other=0.0)\n y = tl.load(y_ptr + offsets, mask=mask, other=0.0)\n # Perform the elementwise addition\n out = x + y\n # Store the result\n tl.store(out_ptr + offsets, out, mask=mask)\n\n\ndef triton_add(x: torch.Tensor, y: torch.Tensor):\n """\n This function wraps the Triton kernel call. It:\n 1. Ensures the inputs are contiguous on GPU.\n 2. Calculates the grid (blocks) needed.\n 3. Launches the Triton kernel.\n """\n assert x.is_cuda and y.is_cuda, "Tensors must be on CUDA."\n x = x.contiguous()\n y = y.contiguous()\n\n # Prepare output tensor\n out = torch.empty_like(x)\n\n # Number of elements in the tensor\n n_elements = x.numel()\n BLOCK_SIZE = 128 # Tunable parameter for block size\n\n # Determine the number of blocks needed\n grid = lambda meta: ((n_elements + meta["BLOCK_SIZE"] - 1) // meta["BLOCK_SIZE"],)\n\n # Launch the Triton kernel\n add_kernel[grid](x, y, out, n_elements, BLOCK_SIZE=BLOCK_SIZE)\n return out\n\n\nclass ModelNew(nn.Module):\n def __init__(self) -> None:\n super().__init__()\n\n def forward(self, a, b):\n # Instead of "return a + b", call our Triton-based addition\n return triton_add(a, b)\n\n```\n\nYou are given the following architecture:\n```\nimport torch\nimport torch.nn as nn\n\nclass Model(nn.Module):\n """\n A model that computes Hinge Loss for binary classification tasks.\n\n Parameters:\n None\n """\n def __init__(self):\n super(Model, self).__init__()\n\n def forward(self, predictions, targets):\n return torch.mean(torch.clamp(1 - predictions * targets, min=0))\n\nbatch_size = 128\ninput_shape = (1,)\ndim = 1\n\ndef get_inputs():\n return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]\n\ndef get_init_inputs():\n return []\n```\n\nOptimize the architecture named Model with custom Triton kernels! Name your optimized output architecture ModelNew. Output the new code in codeblocks. Please generate real code, NOT pseudocode, make sure the code compiles and is fully functional. Just output the new model code, no other text, and NO testing code! \n\n',
"metrics": {
"has_stop_token_in_generation_at_1": 100.0,
"correctness_at_1": 100.0,
"pass_at_1": 100.0,
"kernel_compiles_at_1": 100.0,
"timeout_at_1": false,
"runtime_at_1": 0.0299,
"runtime_stats": {"mean": 0.0299, "std": 0.00231, "min": 0.0276, "max": 0.0438, "num_trials": 100},
"error": null,
"error_stdout": null,
"error_stderr": null,
},
"generation": 'import torch\nimport triton\nimport triton.language as tl\nfrom torch._inductor.runtime.triton_heuristics import grid\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.runtime import triton_helpers\nimport torch.nn as nn\nassert_size_stride = torch._C._dynamo.guards.assert_size_stride\nempty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda\n\n\n@triton.jit\ndef triton_per_fused_clamp_mean_mul_rsub_sub_0(in_out_ptr0, in_ptr0,\n in_ptr1, xnumel, rnumel):\n XBLOCK: tl.constexpr = 1\n RBLOCK: tl.constexpr = 128\n xoffset = tl.program_id(0) * XBLOCK\n tl.full([1], xoffset, tl.int32)\n tl.full([RBLOCK], True, tl.int1)\n rindex = tl.arange(0, RBLOCK)[:]\n tl.full([RBLOCK], True, tl.int1)\n r0 = rindex\n tmp0 = tl.load(in_ptr0 + r0, None)\n tmp1 = tl.load(in_ptr1 + r0, None)\n tmp2 = tmp0 * tmp1\n tmp3 = 1.0\n tmp4 = tmp3 - tmp2\n tmp5 = 0.0\n tmp6 = triton_helpers.maximum(tmp4, tmp5)\n tmp7 = tl.broadcast_to(tmp6, [RBLOCK])\n tmp9 = triton_helpers.promote_to_tensor(tl.sum(tmp7, 0))\n tmp10 = 128.0\n tmp11 = tmp9 / tmp10\n tl.debug_barrier()\n tl.store(in_out_ptr0 + tl.full([1], 0, tl.int32), tmp11, None)\n\n\ndef call(args):\n arg0_1, arg1_1 = args\n args.clear()\n assert_size_stride(arg0_1, (128, 1), (1, 1))\n assert_size_stride(arg1_1, (128, 1), (1, 1))\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided_cuda((), (), torch.float32)\n buf1 = buf0\n del buf0\n get_raw_stream(0)\n triton_per_fused_clamp_mean_mul_rsub_sub_0[grid(1)](buf1, arg0_1,\n arg1_1, 1, 128, num_warps=2, num_stages=1)\n del arg0_1\n del arg1_1\n return buf1,\n\n\nclass ModelNew(nn.Module):\n """\n A model that computes Hinge Loss for binary classification tasks.\n\n Parameters:\n None\n """\n def __init__(self):\n super(ModelNew, self).__init__()\n\n def forward(self, input_0, input_1):\n arg0_1 = input_0\n arg1_1 = input_1\n output = call([arg0_1, arg1_1])\n return output[0]',
"source_file": "../KernelBench/KernelBench/level1/100_HingeLoss.py",
"reference": 'import torch\nimport torch.nn as nn\n\nclass Model(nn.Module):\n """\n A model that computes Hinge Loss for binary classification tasks.\n\n Parameters:\n None\n """\n def __init__(self):\n super(Model, self).__init__()\n\n def forward(self, predictions, targets):\n return torch.mean(torch.clamp(1 - predictions * targets, min=0))\n\nbatch_size = 128\ninput_shape = (1,)\ndim = 1\n\ndef get_inputs():\n return [torch.randn(batch_size, *input_shape), torch.randint(0, 2, (batch_size, 1)).float() * 2 - 1]\n\ndef get_init_inputs():\n return []',
}
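To see how the generated kernel maps onto the reference, here is a pure-Python mirror of triton_per_fused_clamp_mean_mul_rsub_sub_0 for its single program (grid(1)). It is an illustration, not a test of the real kernel: it runs in float64 on the CPU and sums sequentially, whereas the real kernel works on float32 GPU tensors and tl.sum may reduce in a different order. It also makes visible that RBLOCK = 128 and the divisor 128.0 are hard-coded, matching the assert_size_stride(..., (128, 1), (1, 1)) checks in call, so this kernel is specialized to batch_size = 128.

```python
def hinge_loss_reference(predictions, targets):
    """Plain-list version of torch.mean(torch.clamp(1 - predictions * targets, min=0))."""
    terms = [max(1.0 - p * t, 0.0) for p, t in zip(predictions, targets)]
    return sum(terms) / len(terms)


def hinge_loss_like_generated_kernel(predictions, targets):
    """Mirror of the generated kernel's single program, with its hard-coded sizes."""
    RBLOCK = 128  # RBLOCK: tl.constexpr = 128
    total = 0.0
    for r0 in range(RBLOCK):  # rindex = tl.arange(0, RBLOCK)
        tmp2 = predictions[r0] * targets[r0]  # tmp0 * tmp1 (the two tl.load results)
        tmp4 = 1.0 - tmp2  # tmp3 - tmp2, with tmp3 = 1.0
        tmp6 = max(tmp4, 0.0)  # triton_helpers.maximum(tmp4, tmp5), with tmp5 = 0.0
        total += tmp6  # tl.sum(tmp7, 0), accumulated sequentially here
    return total / 128.0  # tmp9 / tmp10, with tmp10 = 128.0


if __name__ == "__main__":
    preds, targs = [0.5] * 128, [1.0] * 128
    print(hinge_loss_reference(preds, targs), hinge_loss_like_generated_kernel(preds, targs))  # 0.5 0.5
    # With 256 elements the mirror still reads only the first 128 values.
    preds2, targs2 = preds + [10.0] * 128, [1.0] * 256
    print(hinge_loss_reference(preds2, targs2), hinge_loss_like_generated_kernel(preds2, targs2))  # 0.25 0.5
```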