File size: 1,263 Bytes
4c41a36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from typing import Dict

import torch
from torchmetrics import functional as FM


def classification_metrics(
        preds: torch.Tensor,
        target: torch.Tensor,
        num_classes: int,
        average: str = 'macro',
        task: str = 'multiclass') -> Dict[str, torch.Tensor]:
    """
    Compute F1, precision and recall for a classification task.

    All three scores are computed with the torchmetrics functional API
    using the same ``task``, class count and ``average`` settings.

    Parameters
    ----------
    preds : torch.Tensor
        Model predictions (logits, probabilities or labels), in a form
        accepted by torchmetrics for the selected ``task``.
    target : torch.Tensor
        Ground-truth labels.
    num_classes : int
        Number of classes. It is forwarded both as ``num_classes`` and as
        ``num_labels``, so it also serves as the label count when
        ``task='multilabel'``.
    average : str, default 'macro'
        Reduction across classes, passed through to torchmetrics.
    task : str, default 'multiclass'
        torchmetrics task type, passed through unchanged.

    Returns
    -------
    Dict[str, torch.Tensor]
        Mapping with keys ``'f1'``, ``'precision'`` and ``'recall'``, each
        holding the tensor returned by the corresponding torchmetrics
        function.
    """
    shared_kwargs = dict(
        task=task,
        num_classes=num_classes,
        # torchmetrics reads ``num_labels`` (not ``num_classes``) when
        # ``task='multilabel'``; without it that task fails its argument
        # check. The other tasks do not use it, so forwarding is harmless.
        num_labels=num_classes,
        average=average,
    )
    f1 = FM.f1_score(preds=preds, target=target, **shared_kwargs)
    recall = FM.recall(preds=preds, target=target, **shared_kwargs)
    precision = FM.precision(preds=preds, target=target, **shared_kwargs)
    return dict(f1=f1, precision=precision, recall=recall)