File size: 484 Bytes
79aac9d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
"""Tests for the `residual_rms` kernel.
Run `pytest tests/kernels/test_residual_rms.py`.
"""
from typing import List, Tuple

import pytest
import torch

from residual_rms._ops import ops
from residual_rms.residual_rms import residual_rms


@pytest.mark.parametrize("shape", [(2, 3, 4, 5), (2, 3, 4, 5, 6)])
def test_residual_rms(shape: Tuple[int, ...]) -> None:
    """Check the ``residual_rms`` wrapper against ``ops.residual_rms``.

    ``residual_rms(out, x)`` is called with a preallocated, zero-filled
    ``out`` buffer, and the buffer's contents are then compared with the
    tensor returned by ``ops.residual_rms(x)`` via ``torch.allclose``.

    Args:
        shape: Shape of the random input tensor; the decorator runs this
            test for a 4-D and a 5-D shape.
    """
    # Draw inputs from a locally seeded generator so a failing case
    # reproduces exactly on rerun, without touching torch's global RNG.
    gen = torch.Generator().manual_seed(0)
    x = torch.randn(shape, generator=gen)
    # Output buffer with the same shape/dtype/device as x; the comparison
    # below relies on the wrapper having written its result into it.
    out = torch.zeros_like(x)
    residual_rms(out, x)
    expected = ops.residual_rms(x)
    assert torch.allclose(out, expected), (
        f"residual_rms mismatch for shape {shape}: "
        f"max abs diff {(out - expected).abs().max().item()}"
    )