Spaces:
Sleeping
Sleeping
import evaluate
import nltk
import numpy as np
from nltk import edit_distance as compute_edit_distance

from src.utils.common_utils import compute_exprate
class Metrics:
    """Compute text-similarity metrics between decoded predictions and labels.

    The BLEU, WER and exact-match scorers are loaded once through
    ``evaluate.load`` and reused on every call to ``compute_metrics``.
    """

    def __init__(self, processor):
        """Store the processor and load the metric implementations.

        Args:
            processor: Object exposing a ``tokenizer`` attribute that provides
                ``batch_decode`` and ``pad_token_id``.
        """
        self.processor = processor
        self.bleu = evaluate.load("bleu")
        self.wer = evaluate.load("wer")
        self.exact_match = evaluate.load("exact_match")

    @staticmethod
    def _normalized_edit_distance(prediction, reference):
        """Return the edit distance divided by the longer string's length.

        Two empty strings are identical, so their distance is defined as 0.0
        instead of dividing by zero.
        """
        longest = max(len(prediction), len(reference))
        if longest == 0:
            return 0.0
        return compute_edit_distance(prediction, reference) / longest

    def compute_metrics(self, pred):
        """Decode predictions and labels, then score them.

        Args:
            pred: Object with ``predictions`` and ``label_ids`` token-id arrays
                (presumably a ``transformers.EvalPrediction`` -- confirm
                against the caller). ``label_ids`` is not modified.

        Returns:
            dict: Scores scaled to percentages and rounded to two decimals,
            under the keys ``bleu``, ``maximun_edit_distance``,
            ``exact_match``, ``wer``, ``exprate`` and ``exprate_error_1``
            through ``exprate_error_3``. Every value is ``0.0`` when the
            batch decodes to no samples.
        """
        tokenizer = self.processor.tokenizer
        pred_str = tokenizer.batch_decode(pred.predictions, skip_special_tokens=True)

        # -100 entries (presumably the loss ignore-index) are swapped for the
        # pad token id before decoding. np.where builds a new array, so the
        # caller's label_ids stays untouched.
        labels_ids = np.where(
            pred.label_ids == -100,
            tokenizer.pad_token_id,
            pred.label_ids,
        )
        label_str = tokenizer.batch_decode(labels_ids, skip_special_tokens=True)

        num_samples = len(pred_str)
        if num_samples == 0:
            # Nothing to score: report zeros rather than dividing by zero.
            return {
                "bleu": 0.0,
                "maximun_edit_distance": 0.0,
                "exact_match": 0.0,
                "wer": 0.0,
                "exprate": 0.0,
                "exprate_error_1": 0.0,
                "exprate_error_2": 0.0,
                "exprate_error_3": 0.0,
            }

        total_edit_distance, total_bleu, total_exact_match = 0, 0, 0
        for prediction, reference in zip(pred_str, label_str):
            # Normalized edit distance for this sample
            total_edit_distance += self._normalized_edit_distance(prediction, reference)

            # Per-sample BLEU; a sample whose BLEU computation divides by
            # zero contributes a score of 0.
            try:
                bleu_result = self.bleu.compute(
                    predictions=[prediction],
                    references=[reference],
                    max_order=4,  # Maximum n-gram order to use when computing BLEU score
                )
            except ZeroDivisionError:
                pass
            else:
                total_bleu += bleu_result["bleu"]

            # Exact match, ignoring spaces
            match_result = self.exact_match.compute(
                predictions=[prediction],
                references=[reference],
                regexes_to_ignore=[' '],
            )
            total_exact_match += match_result["exact_match"]

        bleu = total_bleu / num_samples
        exact_match = total_exact_match / num_samples
        # Turn the mean normalized distance (lower is better) into a
        # similarity score (higher is better).
        edit_distance = 1 - (total_edit_distance / num_samples)

        # Word error rate over the whole batch
        wer = self.wer.compute(predictions=pred_str, references=label_str)

        # Expression rate and its three error values from the project helper
        exprate, error_1, error_2, error_3 = compute_exprate(
            predictions=pred_str,
            references=label_str,
        )

        return {
            "bleu": round(bleu*100, 2),
            "maximun_edit_distance": round(edit_distance*100, 2),
            "exact_match": round(exact_match*100, 2),
            "wer": round(wer*100, 2),
            "exprate": round(exprate*100, 2),
            "exprate_error_1": round(error_1*100, 2),
            "exprate_error_2": round(error_2*100, 2),
            "exprate_error_3": round(error_3*100, 2),
        }