import time

import pytest
import torch

import get_position_ids
from reference import DummyModel


# Parametrization cases for the position-id test. Each case has:
#   "name"      - label used as the pytest id,
#   "input_ids" - runs of filler token ids separated by the marker pair
#                 [151652, 151653],
#   "grid"      - one three-element row per marker pair in "input_ids".
# NOTE(review): 151652 / 151653 look like vision start/end token ids and the
# grid rows like (t, h, w) sizes -- nothing in this file defines them, so
# confirm against the model/tokenizer configuration.
VISION_CONFIGS = [
    {
        "name": "one_segment",
        "input_ids": (
            [10] * 5
            + [151652, 151653]
            + [20] * 5
        ),
        "grid": [[2, 4, 6]],
    },
    {
        "name": "two_segments",
        "input_ids": (
            [100] * 5
            + [151652, 151653]
            + [101] * 5
            + [151652, 151653]
            + [102] * 5
        ),
        "grid": [
            [2, 4, 6],
            [3, 4, 6],
        ],
    },
    {
        "name": "three_segments",
        "input_ids": (
            [11] * 5
            + [151652, 151653]
            + [12] * 6
            + [151652, 151653]
            + [13] * 7
            + [151652, 151653]
            + [14] * 8
        ),
        "grid": [
            [2, 4, 6],
            [3, 6, 6],
            [4, 4, 8],
        ],
    },
]


# Parametrization axes consumed by the tests in this file.
CUDA_DEVICES = ["cuda"]
SEEDS = [42]
# NOTE(review): DTYPES is not referenced by the test visible in this file,
# which hard-codes torch.int32; kept in case other code imports it.
DTYPES = [torch.int32]


@pytest.mark.skipif(not torch.cuda.is_available(),
                    reason="get_position_ids test requires a CUDA device")
@pytest.mark.parametrize("vision_config",
                         VISION_CONFIGS,
                         ids=[cfg["name"] for cfg in VISION_CONFIGS])
@pytest.mark.parametrize("seed", SEEDS)
@pytest.mark.parametrize("device", CUDA_DEVICES)
@torch.inference_mode()
def test_get_position_ids(vision_config, seed, device):
    """Check the ``get_position_ids`` extension against ``DummyModel``.

    Builds int32 tensors from ``vision_config`` on ``device``, computes
    position ids with ``DummyModel.get_position_ids`` and with
    ``get_position_ids.get_position_ids``, prints the wall-clock time of each
    call in milliseconds, and asserts that both results match.

    Args:
        vision_config: One entry of ``VISION_CONFIGS``; its ``"input_ids"``
            and ``"grid"`` values become the input tensors and ``"name"`` is
            used in the printed timing lines.
        seed: Value passed to ``torch.manual_seed``.
        device: Device string on which the tensors are created.
    """
    torch.manual_seed(seed)
    input_ids = torch.tensor(vision_config["input_ids"],
                             dtype=torch.int32,
                             device=device)
    image_grid_thw = torch.tensor(vision_config["grid"],
                                  dtype=torch.int32,
                                  device=device)

    dummy_model = DummyModel()

    # Reference path. Synchronize before and after so the measured interval
    # covers the queued GPU work rather than only the launch. Each timing is
    # a single call, so treat the printed numbers as indicative only.
    torch.cuda.synchronize()
    start_ref = time.perf_counter()
    pos_ids_ref = dummy_model.get_position_ids(input_ids, image_grid_thw)
    torch.cuda.synchronize()
    end_ref = time.perf_counter()
    ref_time = (end_ref - start_ref) * 1000
    print(f"\nVision config {vision_config['name']} - Reference time: {ref_time:.2f} ms")

    # Cast so the comparison below is int32 against int32.
    pos_ids_ref = pos_ids_ref.to(dtype=torch.int32)

    # Extension path. The output buffer takes the reference result's shape
    # and is passed as the first argument for the extension to fill; its
    # allocation is inside the timed region, as in the original measurement.
    torch.cuda.synchronize()
    start_ext = time.perf_counter()
    out = torch.empty(pos_ids_ref.shape, dtype=torch.int32, device=device)
    get_position_ids.get_position_ids(out, input_ids, image_grid_thw)
    torch.cuda.synchronize()
    end_ext = time.perf_counter()
    ext_time = (end_ext - start_ext) * 1000
    print(f"Vision config {vision_config['name']} - Extension time: {ext_time:.2f} ms\n")

    # .cpu() already produces a host copy, so no intermediate clone is needed.
    torch.testing.assert_close(out.cpu(), pos_ids_ref.cpu())