Spaces:
Runtime error
Runtime error
import torch | |
def compute_accuracy(
    pad_outputs: torch.Tensor,
    pad_targets: torch.Tensor,
    ignore_label: int,
) -> torch.Tensor:
    """Calculate accuracy over non-ignored positions.

    Args:
        pad_outputs (LongTensor): Prediction tensors (B, Lmax).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Label id in ``pad_targets`` whose positions are
            excluded from both the count of correct predictions and the
            count of scored positions.

    Returns:
        torch.Tensor: 0-dim float32 tensor holding the accuracy
        (0.0 - 1.0), on the same device as the inputs. A tensor is returned
        deliberately rather than a Python float. If every position in
        ``pad_targets`` equals ``ignore_label``, the result is 0.0 instead
        of NaN.
    """
    mask = pad_targets != ignore_label
    # Correct predictions at non-ignored positions only. Comparing the full
    # tensors and AND-ing with the mask yields the same count as comparing
    # two masked_select copies, without allocating them.
    numerator = ((pad_outputs == pad_targets) & mask).sum()
    # Clamp so an all-ignored batch gives 0 / 1 = 0.0 rather than 0 / 0 = NaN.
    denominator = mask.sum().clamp(min=1)
    return numerator.float() / denominator.float()