File size: 641 Bytes
35c1cfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch

def compute_accuracy(pad_outputs, pad_targets, ignore_label):
    """Compute token-level accuracy over the non-ignored target positions.

    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Scalar float tensor holding the accuracy (0.0 - 1.0).
            When every target equals ``ignore_label`` the division is 0 / 0
            and the result is NaN.

    """
    # Positions whose target is a real label (i.e. not the ignore id).
    valid = pad_targets.ne(ignore_label)
    predicted = torch.masked_select(pad_outputs, valid)
    expected = torch.masked_select(pad_targets, valid)
    correct_count = torch.eq(predicted, expected).sum()
    valid_count = valid.sum()
    # Intentionally returned as a tensor rather than a Python float.
    return correct_count.float() / valid_count.float()