File size: 2,393 Bytes
07423df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 |
import logging
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
from scipy.special import softmax
from sklearn.metrics import log_loss, roc_auc_score
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def accuracy_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Return a per-sample exact-match indicator between prediction and target.

    Args:
        cfg: not read by this function; kept for the shared metric signature.
        results: mapping with ``"predicted_text"`` and ``"target_text"``
            entries; every element of both must be convertible with ``int``.
        val_df: not read by this function; kept for the shared metric
            signature.
        raw_results: not read by this function; kept for the shared metric
            signature.

    Returns:
        Float array holding 1.0 where the integer prediction equals the
        integer target and 0.0 elsewhere.

    Raises:
        ValueError: if an element cannot be parsed by ``int``.
    """
    predictions = np.array(list(map(int, results["predicted_text"])))
    targets = np.array(list(map(int, results["target_text"])))
    matches = predictions == targets
    return matches.astype("float")
def auc_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the ROC AUC of the logits against the integer targets.

    Args:
        cfg: configuration object; only ``cfg.dataset.num_classes`` is read.
        results: mapping with ``"logits"`` (scores passed to
            ``roc_auc_score``) and ``"target_text"`` (labels convertible
            with ``int``).
        val_df: not read by this function; kept for the shared metric
            signature.
        raw_results: not read by this function; kept for the shared metric
            signature.

    Returns:
        The value returned by ``sklearn.metrics.roc_auc_score``.
    """
    scores = np.array(results["logits"])
    labels = np.array(list(map(int, results["target_text"])))

    n_classes = cfg.dataset.num_classes
    if n_classes > 1:
        # Row i of the identity matrix is the one-hot vector of class i, so
        # indexing with the label array one-hot encodes every target.
        labels = np.eye(n_classes)[labels]

    return roc_auc_score(labels, scores, multi_class="ovr")
def logloss_score(
    cfg: Any,
    results: Dict,
    val_df: pd.DataFrame,
    raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
    """Compute the log loss (cross-entropy) of the logits against the targets.

    The logits are first turned into probabilities: softmax over axis 1 when
    ``cfg.dataset.num_classes > 1``, the logistic sigmoid otherwise.

    Args:
        cfg: configuration object; only ``cfg.dataset.num_classes`` is read.
        results: mapping with ``"logits"`` and ``"target_text"`` (labels
            convertible with ``int``).
            NOTE(review): assumes ``"logits"`` holds raw, un-normalised model
            outputs also in the single-class case — confirm against the
            caller.
        val_df: not read by this function; kept for the shared metric
            signature.
        raw_results: not read by this function; kept for the shared metric
            signature.

    Returns:
        The value returned by ``sklearn.metrics.log_loss``.
    """
    # Imported locally so this function does not depend on a change to the
    # module-level import block.
    from scipy.special import expit

    logits = np.array(results["logits"])
    target = np.array([int(text) for text in results["target_text"]])

    if cfg.dataset.num_classes > 1:
        # One-hot encode the targets and normalise each row of logits.
        target = np.eye(cfg.dataset.num_classes)[target]
        probabilities = softmax(logits, axis=1)
    else:
        # A single logit per sample has to be squashed into (0, 1) before it
        # can be scored as a probability.
        probabilities = expit(logits)

    # Clip here instead of passing ``eps=1e-7``: that keyword was deprecated
    # in scikit-learn 1.3 and removed in 1.5, where it raises a TypeError.
    eps = 1e-7
    probabilities = np.clip(probabilities, eps, 1 - eps)
    return log_loss(target, probabilities)
class Metrics:
    """
    Metrics factory. Each registered name maps to a tuple of:
        - the function computing the metric
        - "max" or "min": whether the metric should be maximized or minimized
        - the reduce function name

    Maximized or minimized is needed for early stopping (saving best
    checkpoint). The reduce function generates a single metric value,
    usually "mean" or "none".
    """

    _metrics = {
        "AUC": (auc_score, "max", "mean"),
        "Accuracy": (accuracy_score, "max", "mean"),
        "LogLoss": (logloss_score, "min", "mean"),
    }

    @classmethod
    def names(cls) -> List[str]:
        """Return the registered metric names in sorted order."""
        return sorted(cls._metrics)

    @classmethod
    def get(cls, name: str) -> Any:
        """Look up a metric by name.

        Args:
            name: metrics name

        Returns:
            The ``(function, "max" | "min", reduce)`` tuple registered under
            ``name``; the "LogLoss" entry when ``name`` is not registered.
        """
        if name in cls._metrics:
            return cls._metrics[name]
        # Unknown names fall back to LogLoss instead of raising.
        return cls._metrics["LogLoss"]
|