Spaces:
Sleeping
Sleeping
File size: 2,317 Bytes
5caedb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import logging
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
logger = logging.getLogger(__name__)
def mse_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the per-sample mean squared error for regression outputs.

    Args:
        cfg: Experiment configuration; not used by this metric, kept for the
            shared metric-function signature.
        results: Mapping with ``"target_text"`` (one comma-separated string of
            numeric targets per sample) and ``"predictions"`` (one value or one
            row of values per sample).
        val_df: Validation dataframe; not used by this metric.
        raw_results: Not used by this metric.

    Returns:
        1-D float array holding, for each sample, the squared error averaged
        over that sample's target columns.

    Raises:
        ValueError: If targets and predictions differ in sample count or in
            per-sample width, if target rows are ragged or non-numeric, or if
            there is no data.
    """
    # dtype=float makes ragged target rows fail loudly instead of producing an
    # object array on older NumPy versions.
    target = np.array(
        [[float(t) for t in text.split(",")] for text in results["target_text"]],
        dtype=float,
    )
    predictions = np.asarray(results["predictions"])
    if len(target) != len(predictions):
        raise ValueError(
            f"Length of target ({len(target)}) and predictions ({len(predictions)}) "
            "should be the same."
        )
    if len(target) == 0:
        raise ValueError("No data to calculate MSE score")
    # A flat (n,) prediction vector would broadcast against the (n, k) target
    # into an (n, n) matrix and silently give wrong scores; add a column axis.
    predictions = predictions.reshape(len(predictions), -1)
    if predictions.shape != target.shape:
        raise ValueError(
            f"Shape of target {target.shape} and predictions {predictions.shape} "
            "should be the same."
        )
    return ((target - predictions) ** 2).mean(axis=1).astype("float")
def mae_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the per-sample mean absolute error for regression outputs.

    Args:
        cfg: Experiment configuration; not used by this metric, kept for the
            shared metric-function signature.
        results: Mapping with ``"target_text"`` (one comma-separated string of
            numeric targets per sample) and ``"predictions"`` (one value or one
            row of values per sample).
        val_df: Validation dataframe; not used by this metric.
        raw_results: Not used by this metric.

    Returns:
        1-D float array holding, for each sample, the absolute error averaged
        over that sample's target columns.

    Raises:
        ValueError: If targets and predictions differ in sample count or in
            per-sample width, if target rows are ragged or non-numeric, or if
            there is no data.
    """
    # dtype=float makes ragged target rows fail loudly instead of producing an
    # object array on older NumPy versions.
    target = np.array(
        [[float(t) for t in text.split(",")] for text in results["target_text"]],
        dtype=float,
    )
    predictions = np.asarray(results["predictions"])
    if len(target) != len(predictions):
        raise ValueError(
            f"Length of target ({len(target)}) and predictions ({len(predictions)}) "
            "should be the same."
        )
    if len(target) == 0:
        raise ValueError("No data to calculate MAE score")
    # A flat (n,) prediction vector would broadcast against the (n, k) target
    # into an (n, n) matrix and silently give wrong scores; add a column axis.
    predictions = predictions.reshape(len(predictions), -1)
    if predictions.shape != target.shape:
        raise ValueError(
            f"Shape of target {target.shape} and predictions {predictions.shape} "
            "should be the same."
        )
    return np.abs(target - predictions).mean(axis=1).astype("float")
class Metrics:
    """
    Metrics factory. Returns:
    - metric value
    - should it be maximized or minimized
    - Reduce function
    Maximized or minimized is needed for early stopping (saving best checkpoint)
    Reduce function to generate a single metric value, usually "mean" or "none"
    """

    # name -> (metric function, optimisation direction, reduce function name)
    _metrics = {
        "MSE": (mse_score, "min", "mean"),
        "MAE": (mae_score, "min", "mean"),
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered metric names, sorted alphabetically."""
        return sorted(cls._metrics.keys())

    @classmethod
    def get(cls, name: str) -> Any:
        """Access to Metrics.

        Args:
            name: metrics name

        Returns:
            The registered ``(metric_fn, direction, reduce)`` tuple for
            ``name``. Unknown names fall back to the ``"MSE"`` entry; a
            warning is logged so configuration typos do not go unnoticed.
        """
        try:
            return cls._metrics[name]
        except KeyError:
            # Keep the historical fallback for compatibility, but make it visible.
            logger.warning(
                "Unknown metric %r; falling back to 'MSE'. Available metrics: %s",
                name,
                cls.names(),
            )
            return cls._metrics["MSE"]
|