from typing import List, Union

import numpy as np
import torch

from transformer_deploy.benchmarks.utils import compare_outputs, to_numpy


def check_accuracy(
    engine_name: str,
    pytorch_output: List[torch.Tensor],
    engine_output: List[Union[np.ndarray, torch.Tensor]],
    tolerance: float,
) -> None:
    """
    Compare engine predictions with a reference and assert that their difference is under a threshold.
    :param engine_name: string used in the error message, if any
    :param pytorch_output: reference output used for the comparison
    :param engine_output: output from the engine
    :param tolerance: if the difference between the outputs is above this threshold, an error is raised
    """
    pytorch_np = to_numpy(pytorch_output)
    engine_np = to_numpy(engine_output)
    discrepancy = compare_outputs(pytorch_output=pytorch_np, engine_output=engine_np)
    assert discrepancy <= tolerance, (
        f"{engine_name} discrepancy is too high ({discrepancy:.2f} > {tolerance}):\n"
        f"PyTorch:\n{pytorch_np}\n"
        f"VS\n"
        f"Engine:\n{engine_np}\n"
        f"Diff:\n"
        f"{pytorch_np - engine_np}\n"
        "Tolerance can be increased with the --atol parameter."
    )
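
# A minimal usage sketch, not part of the library itself: it assumes to_numpy
# converts a list of tensors/arrays to a numpy array and compare_outputs
# reduces both outputs to a single scalar discrepancy (both imported above).
# The engine name, dummy data, and tolerance below are illustrative only;
# the tolerance would normally come from the --atol CLI parameter.
if __name__ == "__main__":
    reference = [torch.ones((2, 4), dtype=torch.float32)]
    # Simulate an engine output that deviates slightly from the reference
    candidate = [np.ones((2, 4), dtype=np.float32) + 1e-4]
    check_accuracy(
        engine_name="ONNX Runtime",  # hypothetical engine name
        pytorch_output=reference,
        engine_output=candidate,
        tolerance=1e-2,  # hypothetical threshold
    )
    print("accuracy check passed")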