drbh
feat: mrope position id kernel and reference
e52d1ec
import time
import torch
import pytest
import get_position_ids # noqa: E402
from reference import DummyModel
# Each configuration includes:
# - name: A label for the test case.
# - input_ids: A list of token IDs (with vision start (151652) and vision end (151653) tokens embedded).
# - grid: A list of [t, h, w] values (one per vision segment).
#
# The cases below include:
# 1. one_segment: a single vision segment.
# 2. two_segments: two vision segments with extra text tokens afterward.
# 3. three_segments: three vision segments.
VISION_CONFIGS = [
{
"name": "one_segment",
"input_ids": (
[10] * 5 + # 5 text tokens before vision segment
[151652, 151653] + # vision tokens for segment 1
[20] * 5 # 5 extra text tokens after vision segment
),
"grid": [[2, 4, 6]] # one vision segment grid
},
{
"name": "two_segments",
"input_ids": (
[100] * 5 + # 5 text tokens for segment 1
[151652, 151653] + # vision tokens for segment 1
[101] * 5 + # 5 text tokens for segment 2
[151652, 151653] + # vision tokens for segment 2
[102] * 5 # 5 extra text tokens after last vision segment
),
"grid": [
[2, 4, 6], # vision segment 1 grid
[3, 4, 6] # vision segment 2 grid
],
},
{
"name": "three_segments",
"input_ids": (
[11] * 5 + # Segment 1: 5 text tokens
[151652, 151653] + # vision tokens for segment 1
[12] * 6 + # Segment 2: 6 text tokens
[151652, 151653] + # vision tokens for segment 2
[13] * 7 + # Segment 3: 7 text tokens
[151652, 151653] + # vision tokens for segment 3
[14] * 8 # 8 extra text tokens after the last vision segment
),
"grid": [
[2, 4, 6], # vision segment 1 grid
[3, 6, 6], # vision segment 2 grid
[4, 4, 8] # vision segment 3 grid
],
},
]
CUDA_DEVICES = ["cuda"] # List of CUDA devices; you can add more if needed.
SEEDS = [42] # Seeds for reproducibility.
DTYPES = [torch.int32] # In our test the tokens and grid are created with int32.
@pytest.mark.parametrize("vision_config",
VISION_CONFIGS,
ids=[cfg["name"] for cfg in VISION_CONFIGS])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_position_ids(vision_config, seed, device):
torch.manual_seed(seed)
input_ids = torch.tensor(vision_config["input_ids"], dtype=torch.int32, device=device)
image_grid_thw = torch.tensor(vision_config["grid"], dtype=torch.int32, device=device)
# Create a DummyModel instance from the reference implementation.
dummy_model = DummyModel()
# reference implementation
torch.cuda.synchronize()
start_ref = time.perf_counter()
pos_ids_ref = dummy_model.get_position_ids(input_ids, image_grid_thw)
torch.cuda.synchronize()
end_ref = time.perf_counter()
ref_time = (end_ref - start_ref) * 1000 # ms
print(f"\nVision config {vision_config['name']} - Reference time: {ref_time:.2f} ms")
# Convert reference output to int32 for comparison (since its returned as a float tensor).
pos_ids_ref = pos_ids_ref.to(dtype=torch.int32)
# kernel implementation
torch.cuda.synchronize()
start_ext = time.perf_counter()
out = torch.empty(pos_ids_ref.shape, dtype=torch.int32, device=device)
get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
torch.cuda.synchronize()
end_ext = time.perf_counter()
ext_time = (end_ext - start_ext) * 1000 # ms
print(f"Vision config {vision_config['name']} - Extension time: {ext_time:.2f} ms\n")
ext_out = out.clone()
# verify the results
torch.testing.assert_close(ext_out.cpu(), pos_ids_ref.cpu())