File size: 484 Bytes
79aac9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
"""Tests for the `residual_rms` kernel.

Run `pytest tests/kernels/test_residual_rms.py`.
"""

from typing import List, Tuple

import pytest
import torch

from residual_rms._ops import ops
from residual_rms.residual_rms import residual_rms


@pytest.mark.parametrize("shape", [(2, 3, 4, 5), (2, 3, 4, 5, 6)])
def test_residual_rms(shape: Tuple[int, ...]) -> None:
    """Check that `residual_rms(out, x)` agrees with `ops.residual_rms(x)`.

    The Python-level `residual_rms` function writes its result into a
    caller-provided, zero-initialised ``out`` tensor. The op exposed through
    `residual_rms._ops.ops` returns its result. Both run on the same random
    input, and the two results must match under `torch.allclose`'s default
    tolerances.

    Args:
        shape: Shape of the random input tensor. It is parametrized with one
            rank-4 shape and one rank-5 shape.
    """
    # Fixed seed, so a tolerance failure reproduces with the same input.
    torch.manual_seed(0)
    x = torch.randn(shape)
    out = torch.zeros_like(x)
    # Keep the original call order: fill `out` first, then compute the
    # reference from the same `x`.
    residual_rms(out, x)
    expected = ops.residual_rms(x)
    assert torch.allclose(out, expected)