Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform | |
def test_scaling_transform():
    """Check that the functional and class-based inverse scalar transforms agree.

    Calls ``inverse_scalar_transform`` and an ``InverseScalarTransform`` instance
    on the same random logits, prints the elapsed time of each call, and asserts
    that both outputs have shape ``(16, 1)`` and are element-wise equal.
    """
    import time

    support_size = 300
    # NOTE(review): 601 presumably equals 2 * support_size + 1 bins -- confirm
    # against the lzero implementation.
    # A locally seeded generator keeps failures reproducible without mutating
    # the global RNG state that other tests may rely on.
    rng = torch.Generator().manual_seed(0)
    logit = torch.randn(16, 601, generator=rng)

    # perf_counter is monotonic and high-resolution, which suits timing short
    # calls better than the wall clock returned by time.time().
    start = time.perf_counter()
    output_1 = inverse_scalar_transform(logit, support_size)
    print('t1', time.perf_counter() - start)

    # Construction is deliberately left outside the timed region, so only the
    # forward call of the class-based variant is measured.
    handle = InverseScalarTransform(support_size)
    start = time.perf_counter()
    output_2 = handle(logit)
    print('t2', time.perf_counter() - start)

    assert output_1.shape == output_2.shape == (16, 1)
    assert (output_1 == output_2).all()