""" | |
Tests for BLEU translation evaluation metric | |
""" | |
import io | |
import unittest | |
from nltk.data import find | |
from nltk.translate.bleu_score import ( | |
SmoothingFunction, | |
brevity_penalty, | |
closest_ref_length, | |
corpus_bleu, | |
modified_precision, | |
sentence_bleu, | |
) | |


class TestBLEU(unittest.TestCase):
    def test_modified_precision(self):
        """
        Examples from the original BLEU paper:
        https://www.aclweb.org/anthology/P02-1040.pdf
        """
        # Example 1: the "the*" example.
        # Reference sentences.
        ref1 = "the cat is on the mat".split()
        ref2 = "there is a cat on the mat".split()
        # Hypothesis sentence(s).
        hyp1 = "the the the the the the the".split()
        references = [ref1, ref2]
        # Testing modified unigram precision.
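        # "the" occurs 7 times in hyp1 but at most 2 times in any single
        # reference (ref1), so its count is clipped to 2: precision = 2/7.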
        hyp1_unigram_precision = float(modified_precision(references, hyp1, n=1))
        assert round(hyp1_unigram_precision, 4) == 0.2857
        # With assertAlmostEqual at 4 decimal places.
        self.assertAlmostEqual(hyp1_unigram_precision, 0.28571428, places=4)
        # Testing modified bigram precision.
        assert float(modified_precision(references, hyp1, n=2)) == 0.0

        # Example 2: the "of the" example.
        # Reference sentences.
        ref1 = (
            "It is a guide to action that ensures that the military "
            "will forever heed Party commands"
        ).split()
        ref2 = (
            "It is the guiding principle which guarantees the military "
            "forces always being under the command of the Party"
        ).split()
        ref3 = (
            "It is the practical guide for the army always to heed "
            "the directions of the party"
        ).split()
        # Hypothesis sentence(s).
        hyp1 = "of the".split()
        references = [ref1, ref2, ref3]
        # Testing modified unigram precision.
        assert float(modified_precision(references, hyp1, n=1)) == 1.0
        # Testing modified bigram precision.
        assert float(modified_precision(references, hyp1, n=2)) == 1.0
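        # Both unigrams and the lone bigram "of the" occur in ref2, giving
        # 2/2 and 1/1; such short, high-precision hypotheses are exactly what
        # the brevity penalty is designed to punish.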

        # Example 3: Proper MT outputs.
        hyp1 = (
            "It is a guide to action which ensures that the military "
            "always obeys the commands of the party"
        ).split()
        hyp2 = (
            "It is to insure the troops forever hearing the activity "
            "guidebook that party direct"
        ).split()
        references = [ref1, ref2, ref3]
        # Unigram precision.
        hyp1_unigram_precision = float(modified_precision(references, hyp1, n=1))
        hyp2_unigram_precision = float(modified_precision(references, hyp2, n=1))
        # Test unigram precision with assertAlmostEqual at 4 decimal places.
        self.assertAlmostEqual(hyp1_unigram_precision, 0.94444444, places=4)
        self.assertAlmostEqual(hyp2_unigram_precision, 0.57142857, places=4)
        # Test unigram precision with rounding.
        assert round(hyp1_unigram_precision, 4) == 0.9444
        assert round(hyp2_unigram_precision, 4) == 0.5714
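        # These are the paper's modified unigram precisions: 17/18 and 8/14.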

        # Bigram precision.
        hyp1_bigram_precision = float(modified_precision(references, hyp1, n=2))
        hyp2_bigram_precision = float(modified_precision(references, hyp2, n=2))
        # Test bigram precision with assertAlmostEqual at 4 decimal places.
        self.assertAlmostEqual(hyp1_bigram_precision, 0.58823529, places=4)
        self.assertAlmostEqual(hyp2_bigram_precision, 0.07692307, places=4)
        # Test bigram precision with rounding.
        assert round(hyp1_bigram_precision, 4) == 0.5882
        assert round(hyp2_bigram_precision, 4) == 0.0769
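        # Likewise the paper's modified bigram precisions: 10/17 and 1/13.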

    def test_brevity_penalty(self):
        # Test case from the brevity_penalty_closest function in mteval-v13a.pl.
        # Same test cases as in the doctest in nltk.translate.bleu_score.
        references = [["a"] * 11, ["a"] * 8]
        hypothesis = ["a"] * 7
        hyp_len = len(hypothesis)
        closest_ref_len = closest_ref_length(references, hyp_len)
        self.assertAlmostEqual(
            brevity_penalty(closest_ref_len, hyp_len), 0.8669, places=4
        )
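        # BP = exp(1 - r/c) when the hypothesis is shorter than the closest
        # reference: here r = 8 and c = 7, so exp(1 - 8/7) ~= 0.8669.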
        references = [["a"] * 11, ["a"] * 8, ["a"] * 6, ["a"] * 7]
        hypothesis = ["a"] * 7
        hyp_len = len(hypothesis)
        closest_ref_len = closest_ref_length(references, hyp_len)
        assert brevity_penalty(closest_ref_len, hyp_len) == 1.0
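        # A reference of exactly length 7 exists, so r == c and BP == 1.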

    def test_zero_matches(self):
        # Test case where there are 0 matches.
        references = ["The candidate has no alignment to any of the references".split()]
        hypothesis = "John loves Mary".split()
        # Test BLEU up to nth-order n-grams, where n is len(hypothesis).
        for n in range(1, len(hypothesis) + 1):
            weights = (1.0 / n,) * n  # Uniform weights.
            assert sentence_bleu(references, hypothesis, weights) == 0

    def test_full_matches(self):
        # Test case where there are 100% matches.
        references = ["John loves Mary".split()]
        hypothesis = "John loves Mary".split()
        # Test BLEU up to nth-order n-grams, where n is len(hypothesis).
        for n in range(1, len(hypothesis) + 1):
            weights = (1.0 / n,) * n  # Uniform weights.
            assert sentence_bleu(references, hypothesis, weights) == 1.0
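        # Identical sentences: every modified precision is 1 and the brevity
        # penalty is 1, so BLEU is exactly 1.0 at every order.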

    def test_partial_matches_hypothesis_longer_than_reference(self):
        references = ["John loves Mary".split()]
        hypothesis = "John loves Mary who loves Mike".split()
        # Since no 4-gram matches were found, the result should be zero:
        # exp(sum_i w_i * log(p_i)) = 0 because p_4 = 0 and log(0) -> -inf.
        self.assertAlmostEqual(sentence_bleu(references, hypothesis), 0.0, places=4)
        # Checks that the warning has been raised because len(reference) < 4.
        try:
            self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
        except AttributeError:
            pass  # unittest.TestCase.assertWarns is only supported in Python >= 3.2.


# @unittest.skip("Skipping fringe cases for BLEU.")
class TestBLEUFringeCases(unittest.TestCase):
    def test_case_where_n_is_bigger_than_hypothesis_length(self):
        # Test BLEU to nth-order n-grams, where n > len(hypothesis).
        references = ["John loves Mary ?".split()]
        hypothesis = "John loves Mary".split()
        n = len(hypothesis) + 1
        weights = (1.0 / n,) * n  # Uniform weights.
        # Since no n-gram matches were found, the result should be zero:
        # exp(sum_i w_i * log(p_i)) = 0 because p_n = 0 and log(0) -> -inf.
        self.assertAlmostEqual(
            sentence_bleu(references, hypothesis, weights), 0.0, places=4
        )
        # Checks that the warning has been raised because len(hypothesis) < 4.
        try:
            self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
        except AttributeError:
            pass  # unittest.TestCase.assertWarns is only supported in Python >= 3.2.

        # Test case where n > len(hypothesis) and also n > len(reference),
        # including the special case where reference == hypothesis.
        references = ["John loves Mary".split()]
        hypothesis = "John loves Mary".split()
        # Even for an exact match, a 3-token sentence has no 4-grams, so
        # p_4 = 0 and the unsmoothed score is still zero:
        # exp(sum_i w_i * log(p_i)) = 0 because log(0) -> -inf.
        self.assertAlmostEqual(
            sentence_bleu(references, hypothesis, weights), 0.0, places=4
        )

    def test_empty_hypothesis(self):
        # Test case where the hypothesis is empty.
        references = ["The candidate has no alignment to any of the references".split()]
        hypothesis = []
        assert sentence_bleu(references, hypothesis) == 0

    def test_length_one_hypothesis(self):
        # Test case where the hypothesis has length 1, with smoothing method 4.
        references = ["The candidate has no alignment to any of the references".split()]
        hypothesis = ["Foo"]
        method4 = SmoothingFunction().method4
        try:
            sentence_bleu(references, hypothesis, smoothing_function=method4)
        except ValueError:
            pass  # Some versions of method4 raise ValueError on length-1 hypotheses.

    def test_empty_references(self):
        # Test case where the reference is empty.
        references = [[]]
        hypothesis = "John loves Mary".split()
        assert sentence_bleu(references, hypothesis) == 0

    def test_empty_references_and_hypothesis(self):
        # Test case where both the references and the hypothesis are empty.
        references = [[]]
        hypothesis = []
        assert sentence_bleu(references, hypothesis) == 0

    def test_reference_or_hypothesis_shorter_than_fourgrams(self):
        # Test case where the length of reference or hypothesis
        # is shorter than 4.
        references = ["let it go".split()]
        hypothesis = "let go it".split()
        # Checks that the score for this hypothesis/reference pair is 0.0:
        # exp(sum_i w_i * log(p_i)) = 0 because p_4 = 0 and log(0) -> -inf.
        self.assertAlmostEqual(sentence_bleu(references, hypothesis), 0.0, places=4)
        # Checks that the warning has been raised.
        try:
            self.assertWarns(UserWarning, sentence_bleu, references, hypothesis)
        except AttributeError:
            pass  # unittest.TestCase.assertWarns is only supported in Python >= 3.2.


class TestBLEUvsMteval13a(unittest.TestCase):
    def test_corpus_bleu(self):
        ref_file = find("models/wmt15_eval/ref.ru")
        hyp_file = find("models/wmt15_eval/google.ru")
        mteval_output_file = find("models/wmt15_eval/mteval-13a.output")
        # Reads the BLEU scores from the `mteval-13a.output` file.
        # The order of the list corresponds to the order of the n-grams.
        with open(mteval_output_file) as mteval_fin:
            # The numbers are on the second-to-last line of the file.
            # The first and last items on that line are labels, not scores,
            # hence the [1:-1] slice. Materialize the map into a list so it
            # can be iterated over more than once below.
            mteval_bleu_scores = list(
                map(float, mteval_fin.readlines()[-2].split()[1:-1])
            )
        with open(ref_file, encoding="utf8") as ref_fin:
            with open(hyp_file, encoding="utf8") as hyp_fin:
                # Whitespace-tokenize each line.
                # Note: split() with no arguments also strips the newline.
                hypothesis = list(map(lambda x: x.split(), hyp_fin))
                # Note that corpus_bleu expects a list of lists of references.
                references = list(map(lambda x: [x.split()], ref_fin))
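                # Structure: hypothesis[i] is a token list, and references[i]
                # is a list holding one reference token list for sentence i.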
                # Without smoothing.
                for i, mteval_bleu in zip(range(1, 10), mteval_bleu_scores):
                    nltk_bleu = corpus_bleu(
                        references, hypothesis, weights=(1.0 / i,) * i
                    )
                    # Check that the BLEU score difference is less than 0.005.
                    # Note: This is an approximate comparison; as much as
                    # +/- 0.01 BLEU might be "statistically significant",
                    # the actual translation quality might not be.
                    assert abs(mteval_bleu - nltk_bleu) < 0.005
                # With the same smoothing method used in mteval-v13a.pl.
                chencherry = SmoothingFunction()
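                # method3 implements NIST geometric sequence smoothing, the
                # scheme used by mteval-v13a.pl (Chen & Cherry, 2014).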
                for i, mteval_bleu in zip(range(1, 10), mteval_bleu_scores):
                    nltk_bleu = corpus_bleu(
                        references,
                        hypothesis,
                        weights=(1.0 / i,) * i,
                        smoothing_function=chencherry.method3,
                    )
                    assert abs(mteval_bleu - nltk_bleu) < 0.005


class TestBLEUWithBadSentence(unittest.TestCase):
    def test_corpus_bleu_with_bad_sentence(self):
        hyp = "Teo S yb , oe uNb , R , T t , , t Tue Ar saln S , , 5istsi l , 5oe R ulO sae oR R"
        ref = (
            "Their tasks include changing a pump on the faulty stokehold . "
            "Likewise , two species that are very similar in morphology "
            "were distinguished using genetics ."
        )
        references = [[ref.split()]]
        hypotheses = [hyp.split()]
        try:  # Check that a warning is raised: the 2-gram overlap count is 0.
            with self.assertWarns(UserWarning):
                # Verify that the BLEU score collapses to 0.0 when the
                # 2-gram (and higher) overlap counts are 0.
                self.assertAlmostEqual(
                    corpus_bleu(references, hypotheses), 0.0, places=4
                )
        except AttributeError:  # assertWarns is only supported in Python >= 3.2.
            self.assertAlmostEqual(corpus_bleu(references, hypotheses), 0.0, places=4)


class TestBLEUWithMultipleWeights(unittest.TestCase):
    def test_corpus_bleu_with_multiple_weights(self):
        hyp1 = (
            "It is a guide to action which ensures that the military "
            "always obeys the commands of the party"
        ).split()
        ref1a = (
            "It is a guide to action that ensures that the military "
            "will forever heed Party commands"
        ).split()
        ref1b = (
            "It is the guiding principle which guarantees the military "
            "forces always being under the command of the Party"
        ).split()
        ref1c = (
            "It is the practical guide for the army always to heed "
            "the directions of the party"
        ).split()
        hyp2 = "he read the book because he was interested in world history".split()
        ref2a = "he was interested in world history because he read the book".split()
        weight_1 = (1, 0, 0, 0)
        weight_2 = (0.25, 0.25, 0.25, 0.25)
        weight_3 = (0, 0, 0, 0, 1)
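        # Passing a list of weight tuples makes corpus_bleu return one score
        # per tuple; weight_3 has five components, so it scores 5-grams only.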
        bleu_scores = corpus_bleu(
            list_of_references=[[ref1a, ref1b, ref1c], [ref2a]],
            hypotheses=[hyp1, hyp2],
            weights=[weight_1, weight_2, weight_3],
        )
        assert bleu_scores[0] == corpus_bleu(
            [[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_1
        )
        assert bleu_scores[1] == corpus_bleu(
            [[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_2
        )
        assert bleu_scores[2] == corpus_bleu(
            [[ref1a, ref1b, ref1c], [ref2a]], [hyp1, hyp2], weight_3
        )