import sys
from typing import Callable, Optional, Sequence, TypeVar, Union

import nltk
import numpy as np
from fuzzywuzzy import fuzz
from rouge import Rouge

# increase recursion depth to ensure ROUGE can be calculated for long sentences
if sys.getrecursionlimit() < 10_000:
    sys.setrecursionlimit(10_000)


def bleu(gold: list[str], pred: list[str]) -> float:
    """
    Calculate BLEU score, using smoothing method 2 with auto reweighting, in the range of 0~100.

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :return: BLEU score
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    return 100.0 * nltk.translate.bleu_score.sentence_bleu(
        [gold],
        pred,
        smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )


def batch_bleu(golds: list[list[str]], preds: list[list[str]]) -> list[float]:
    """
    Calculate BLEU score for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of BLEU scores
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    return [bleu(gold, pred) for gold, pred in zip(golds, preds)]


def corpus_bleu(golds: list[list[str]], preds: list[list[str]]) -> float:
    """
    Calculate corpus-level BLEU score for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: corpus-level BLEU score
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    return 100.0 * nltk.translate.bleu_score.corpus_bleu(
        [[gold] for gold in golds],
        preds,
        smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
        auto_reweigh=True,
    )


def edit_sim(
    gold: Union[str, list[str]], pred: Union[str, list[str]], sep: str = ' '
) -> float:
    """
    Calculate char-level edit similarity, in the range of 0~100.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator between tokens
    :return: char-level edit similarity
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    if isinstance(gold, list):
        gold = sep.join(gold)
    if isinstance(pred, list):
        pred = sep.join(pred)
    return fuzz.ratio(gold, pred)


def batch_edit_sim(
    golds: list[Union[str, list[str]]],
    preds: list[Union[str, list[str]]],
    sep: str = ' ',
) -> list[float]:
    """
    Calculate char-level edit similarity for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: list of char-level edit similarity
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    return [edit_sim(gold, pred, sep) for gold, pred in zip(golds, preds)]


T = TypeVar('T')


def exact_match(gold: T, pred: T) -> float:
    """
    Calculate exact match accuracy, in the range of {0, 100}.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :return: exact match accuracy
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    return 100.0 if gold == pred else 0.0
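# Illustrative usage sketch (helper name and token lists are hypothetical, not part
# of the metrics API): the sentence-level metrics above take pre-tokenized inputs
# and return scores on a 0~100 scale.
def _example_sentence_metrics() -> None:
    gold = ['return', 'max', '(', 'a', ',', 'b', ')']
    pred = ['return', 'min', '(', 'a', ',', 'b', ')']
    print(f'bleu:     {bleu(gold, pred):.2f}')         # smoothed sentence-level BLEU
    print(f'edit-sim: {edit_sim(gold, pred):.2f}')     # char-level similarity of the joined strings
    print(f'xmatch:   {exact_match(gold, pred):.2f}')  # either 0.0 or 100.0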
def batch_exact_match(golds: list[T], preds: list[T]) -> list[float]:
    """
    Calculate exact match accuracy for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :return: list of exact match accuracy
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    return [exact_match(gold, pred) for gold, pred in zip(golds, preds)]


def rouge_l(
    gold: Union[str, list[str]], pred: Union[str, list[str]], sep: str = ' '
) -> dict[str, float]:
    """
    Calculate ROUGE-L F1, precision, and recall scores, in the range of 0~100.

    :param gold: gold sentence or list of gold tokens
    :param pred: predicted sentence or list of predicted tokens
    :param sep: separator between tokens
    :return: {"p": precision, "r": recall, "f": F1}
    """
    if len(pred) == 0 or len(gold) == 0:
        return {'p': 0.0, 'r': 0.0, 'f': 0.0}
    if isinstance(gold, list):
        gold = sep.join(gold)
    if isinstance(pred, list):
        pred = sep.join(pred)
    try:
        rouge = Rouge()
        scores = rouge.get_scores(hyps=pred, refs=gold, avg=True)
        return {x: scores['rouge-l'][x] * 100.0 for x in ['p', 'r', 'f']}
    except ValueError:
        return {'p': 0.0, 'r': 0.0, 'f': 0.0}


def batch_rouge_l(
    golds: list[Union[str, list[str]]],
    preds: list[Union[str, list[str]]],
    sep: str = ' ',
) -> dict[str, list[float]]:
    """
    Calculate ROUGE-L F1, precision, and recall scores for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param sep: separator between tokens
    :return: {"p": list of precisions, "r": list of recalls, "f": list of F1s}
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    scores = [rouge_l(gold, pred, sep) for gold, pred in zip(golds, preds)]
    return {x: [score[x] for score in scores] for x in ['p', 'r', 'f']}


def accuracy(
    gold: list[str],
    pred: list[str],
    ignore: Optional[Sequence[str]] = None,
) -> float:
    """
    Calculate token-level accuracy, in the range of 0~100.
    If gold and pred are not the same length, the longer one is truncated.

    :param gold: list of gold tokens
    :param pred: list of predicted tokens
    :param ignore: list of (gold) tokens to ignore
    :return: accuracy
    """
    if len(pred) == 0 or len(gold) == 0:
        return 0.0
    if ignore is None:
        ignore = []
    i = 0
    total = 0
    match = 0
    while i < len(gold) and i < len(pred):
        if gold[i] in ignore:
            i += 1
            continue
        total += 1
        if gold[i] == pred[i]:
            match += 1
        i += 1
    if total == 0:
        return 0.0
    return 100.0 * match / total


def batch_accuracy(
    golds: list[list[str]],
    preds: list[list[str]],
    ignore: Optional[Sequence[str]] = None,
) -> list[float]:
    """
    Calculate token-level accuracy for a batch of sentences.

    :param golds: list of gold sentences
    :param preds: list of predicted sentences
    :param ignore: list of (gold) tokens to ignore
    :return: list of accuracy
    """
    if len(golds) != len(preds):
        raise ValueError('golds and preds must have the same length')
    return [accuracy(gold, pred, ignore) for gold, pred in zip(golds, preds)]


def first_match_to_topk(
    first_match_list: list[int], k_values: list[int]
) -> dict[int, list[float]]:
    """
    Calculate top-k accuracy from the first match ranks (1-indexed).

    :param first_match_list: first match ranks (1-indexed)
    :param k_values: k values to consider
    :return: a mapping from k to top-k accuracies (ranging from 0~100)
    """
    return {k: [100.0 if x <= k else 0.0 for x in first_match_list] for k in k_values}
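# Illustrative usage sketch (helper name and rank values are hypothetical):
# first_match_to_topk converts 1-indexed first-correct-candidate ranks into
# per-example top-k hit indicators, which can then be averaged into top-k accuracy.
def _example_first_match_to_topk() -> None:
    ranks = [1, 3, 7]  # rank of the first correct candidate for each example
    topk = first_match_to_topk(ranks, k_values=[1, 5])
    # topk == {1: [100.0, 0.0, 0.0], 5: [100.0, 100.0, 0.0]}
    print({k: sum(v) / len(v) for k, v in topk.items()})  # mean top-1 and top-5 accuracy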
def pass_at_k(n: int, c: int, k: int) -> float:
    """
    Estimate the pass@k metric following the Codex paper, but on a 0~100 scale.

    :param n: total number of samples
    :param c: number of correct samples
    :param k: k in pass@k
    :return: pass@k score
    """
    if n < k or (n - c) < k:
        # fallback to the (1 - (1-p)^k) formula
        return (1 - (1 - (c / n)) ** k) * 100
    else:
        return (1.0 - np.prod(1.0 - k / np.arange(n - c + 1, n + 1)).item()) * 100


def self_bleu(samples: list[list[str]]) -> float:
    """
    Calculate self-BLEU among the samples.

    :param samples: the chosen m samples
    :return: self-BLEU
    """
    if len(samples) == 0:
        return 100.0
    scores = []
    for i in range(len(samples)):
        scores.append(
            100.0
            * nltk.translate.bleu_score.sentence_bleu(
                [samples[j] for j in range(len(samples)) if j != i],
                samples[i],
                smoothing_function=nltk.translate.bleu_score.SmoothingFunction().method2,
                auto_reweigh=True,
            )
        )
    return np.mean(scores).item()


def self_edit_distance(samples: list[Union[str, list[str]]], sep: str = ' ') -> float:
    """
    Calculate self-edit-distance among the samples.

    :param samples: the chosen m samples
    :param sep: the separator between tokens
    :return: self-edit-distance
    """
    if len(samples) == 0:
        return 0.0
    scores = []
    for i in range(len(samples)):
        sample_i = samples[i]
        if not isinstance(sample_i, str):
            sample_i = sep.join(sample_i)
        for j in range(len(samples)):
            if i == j:
                continue
            sample_j = samples[j]
            if not isinstance(sample_j, str):
                sample_j = sep.join(sample_j)
            scores.append(100 - fuzz.ratio(sample_i, sample_j))
    return np.mean(scores).item()


QUALITY_METRICS: dict[str, Callable[[list[str], list[str]], float]] = {
    'bleu': bleu,
    'xmatch': exact_match,
    'edit-sim': edit_sim,
    'rouge-f': lambda g, p: rouge_l(g, p)['f'],
    'rouge-p': lambda g, p: rouge_l(g, p)['p'],
    'rouge-r': lambda g, p: rouge_l(g, p)['r'],
}
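# Illustrative usage sketch (helper name and sample counts are hypothetical):
# pass_at_k uses the unbiased Codex estimator when n - c >= k, and QUALITY_METRICS
# exposes the sentence-level metrics behind a common (gold_tokens, pred_tokens) -> score
# interface.
def _example_pass_at_k_and_quality_metrics() -> None:
    # 10 sampled completions, 3 of them correct
    print(f'pass@1: {pass_at_k(n=10, c=3, k=1):.2f}')  # 30.00
    print(f'pass@5: {pass_at_k(n=10, c=3, k=5):.2f}')  # 91.67

    gold = ['the', 'cat', 'sat', 'on', 'the', 'mat']
    pred = ['the', 'cat', 'sat', 'on', 'a', 'mat']
    for name, metric in QUALITY_METRICS.items():
        print(f'{name}: {metric(gold, pred):.2f}')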