from typing import List, Union
import numpy as np
import torch
from transformer_deploy.benchmarks.utils import compare_outputs, to_numpy


def check_accuracy(
    engine_name: str,
    pytorch_output: List[torch.Tensor],
    engine_output: List[Union[np.ndarray, torch.Tensor]],
    tolerance: float,
) -> None:
    """
    Compare engine predictions with a reference.
    Assert that the difference is under a threshold.
    :param engine_name: string used in the error message, if any
    :param pytorch_output: reference output used for the comparison
    :param engine_output: output from the engine
    :param tolerance: if the difference in outputs is above this threshold, an error is raised
    """
    pytorch_output = to_numpy(pytorch_output)
    engine_output = to_numpy(engine_output)
    discrepancy = compare_outputs(pytorch_output=pytorch_output, engine_output=engine_output)
    assert discrepancy <= tolerance, (
        f"{engine_name} discrepancy is too high ({discrepancy:.2f} > {tolerance}):\n"
        f"PyTorch:\n{pytorch_output}\n"
        f"VS\n"
        f"Engine:\n{engine_output}\n"
        f"Diff:\n"
        f"{torch.asarray(pytorch_output) - torch.asarray(engine_output)}\n"
        "Tolerance can be increased with the --atol parameter."
    )
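

# Illustrative usage sketch (not part of the original module): checks a dummy
# "engine" output against a PyTorch reference with a small injected error.
# It only relies on compare_outputs/to_numpy behaving as documented above,
# i.e. returning a scalar discrepancy measure over numpy-converted outputs.
if __name__ == "__main__":
    reference = [torch.tensor([[0.1, 0.9], [0.8, 0.2]])]
    # same values as the reference, shifted by 1e-4 to simulate engine rounding
    candidate = [np.asarray([[0.1, 0.9], [0.8, 0.2]], dtype=np.float32) + 1e-4]
    check_accuracy(
        engine_name="dummy engine",
        pytorch_output=reference,
        engine_output=candidate,
        tolerance=1e-1,
    )
    print("accuracy check passed")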