sumen / src /utils /metrics.py
hoang-quoc-trung's picture
Upload 8 files
3d52ce7 verified
import nltk
import evaluate
from nltk import edit_distance as compute_edit_distance
from src.utils.common_utils import compute_exprate
class Metrics:
    """Decode model predictions and score them against reference labels.

    The scores are BLEU, a normalised edit-distance similarity, exact match,
    word error rate, and the values returned by ``compute_exprate``.
    """

    def __init__(self, processor):
        """Keep the processor and load the ``evaluate`` metrics once.

        Args:
            processor: Object exposing a ``tokenizer`` attribute that provides
                ``batch_decode`` and ``pad_token_id``.
        """
        self.processor = processor
        self.bleu = evaluate.load("bleu")
        self.wer = evaluate.load("wer")
        self.exact_match = evaluate.load("exact_match")

    def compute_metrics(self, pred):
        """Compute all evaluation scores for one set of predictions.

        Args:
            pred: Object with ``predictions`` (predicted token ids) and
                ``label_ids`` (reference token ids, with ``-100`` marking
                ignored positions). Presumably a ``transformers``
                ``EvalPrediction`` -- confirm against the caller.
                NOTE: ``pred.label_ids`` is modified in place; every ``-100``
                entry is overwritten with the tokenizer's pad token id.

        Returns:
            dict: Scores scaled to percentages and rounded to two decimals,
            under the keys ``bleu``, ``maximun_edit_distance``,
            ``exact_match``, ``wer``, ``exprate``, ``exprate_error_1``,
            ``exprate_error_2`` and ``exprate_error_3``.
        """
        tokenizer = self.processor.tokenizer
        labels_ids = pred.label_ids
        pred_ids = pred.predictions

        pred_str = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
        # -100 is not a real token id, so swap it for the pad token before
        # decoding; skip_special_tokens=True is passed to drop special tokens.
        labels_ids[labels_ids == -100] = tokenizer.pad_token_id
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        total_edit_distance, total_bleu, total_exact_match = 0, 0, 0
        for prediction, reference in zip(pred_str, label_str):
            # Edit distance normalised by the longer of the two strings.
            # When both strings are empty they are identical, so the distance
            # is 0; dividing by max(0, 0) would raise ZeroDivisionError and
            # abort the whole evaluation.
            longest = max(len(prediction), len(reference))
            if longest:
                total_edit_distance += (
                    compute_edit_distance(prediction, reference) / longest
                )

            # Per-sample BLEU with n-grams up to order 4.
            try:
                bleu = self.bleu.compute(
                    predictions=[prediction],
                    references=[reference],
                    max_order=4,
                )
                total_bleu += bleu['bleu']
            except ZeroDivisionError:
                # BLEU could not be computed for this pair; the sample
                # contributes a score of 0 to the average.
                pass

            # Exact match, ignoring single spaces in both strings.
            exact_match = self.exact_match.compute(
                predictions=[prediction],
                references=[reference],
                regexes_to_ignore=[' '],
            )
            total_exact_match += exact_match['exact_match']

        num_samples = len(pred_str)
        bleu = total_bleu / num_samples
        exact_match = total_exact_match / num_samples
        # Turn the mean normalised distance (lower is better) into a
        # similarity score (higher is better).
        edit_distance = 1 - (total_edit_distance / num_samples)

        # Word error rate over the whole set at once.
        wer = self.wer.compute(predictions=pred_str, references=label_str)

        # Expression rate plus three further values; their exact meaning is
        # defined by compute_exprate in src.utils.common_utils.
        exprate, error_1, error_2, error_3 = compute_exprate(
            predictions=pred_str,
            references=label_str,
        )

        return {
            "bleu": round(bleu*100, 2),
            "maximun_edit_distance": round(edit_distance*100, 2),
            "exact_match": round(exact_match*100, 2),
            "wer": round(wer*100, 2),
            "exprate": round(exprate*100, 2),
            "exprate_error_1": round(error_1*100, 2),
            "exprate_error_2": round(error_2*100, 2),
            "exprate_error_3": round(error_3*100, 2),
        }