from typing import Dict | |
import torch | |
from torchmetrics import functional as FM | |
def regression_metrics(preds: torch.Tensor,
                       target: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    regression_metrics
    Compute common evaluation metrics for a regression task.

    Parameters
    ----------
    preds : torch.Tensor
        Predicted continuous values. Expected to have the same shape as
        ``target`` (torchmetrics raises on a shape mismatch).
    target : torch.Tensor
        Ground-truth continuous values.

    Returns
    -------
    Dict[str, torch.Tensor]
        ``"mse"``: mean squared error, from
        ``torchmetrics.functional.mean_squared_error``.
        ``"mape"``: mean absolute percentage error, from
        ``torchmetrics.functional.mean_absolute_percentage_error``.
        Per the torchmetrics docs this is a ratio, not a value scaled
        by 100, and near-zero targets are clamped by a small epsilon.
        Confirm against the installed torchmetrics version.
    """
    mse: torch.Tensor = FM.mean_squared_error(preds=preds, target=target)
    mape: torch.Tensor = FM.mean_absolute_percentage_error(preds=preds,
                                                           target=target)
    # Keys are part of the public contract: callers index by "mse"/"mape".
    return {"mse": mse, "mape": mape}