#include "BleuScorer.h" |
#define BOOST_TEST_MODULE MertBleuScorer |
#include <boost/test/unit_test.hpp> |
#include <cmath> |
#include "Ngram.h" |
#include "Vocabulary.h" |
#include "Util.h" |
using namespace MosesTuning; |
namespace |
{ |
// N-gram table queried by the Check*gram() helpers in this file. The pointer
// is not owned here; a test installs its own table with SetNgramCounts().
NgramCounts* g_counts = NULL;

// Returns the currently installed n-gram table.
//
// The precondition is enforced with BOOST_REQUIRE_MESSAGE instead of
// assert(): assert() expands to nothing when NDEBUG is defined, so a missing
// SetNgramCounts() call would otherwise reach the caller as a null pointer
// and be dereferenced there. A failed BOOST_REQUIRE aborts only the running
// test case and is reported by the test framework.
NgramCounts* GetNgramCounts()
{
  BOOST_REQUIRE_MESSAGE(g_counts != NULL,
                        "SetNgramCounts() must be called before n-gram lookups");
  return g_counts;
}
// Installs `counts` as the table that GetNgramCounts() hands out.
// Only the pointer is stored (nothing is copied), so the pointed-to table
// must stay alive for as long as lookups go through GetNgramCounts().
void SetNgramCounts(NgramCounts* counts)
{
  g_counts = counts;
}
struct Unigram { |
Unigram(const std::string& a) { |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(a)); |
} |
NgramCounts::Key instance; |
}; |
struct Bigram { |
Bigram(const std::string& a, const std::string& b) { |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(a)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(b)); |
} |
NgramCounts::Key instance; |
}; |
struct Trigram { |
Trigram(const std::string& a, const std::string& b, const std::string& c) { |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(a)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(b)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(c)); |
} |
NgramCounts::Key instance; |
}; |
struct Fourgram { |
Fourgram(const std::string& a, const std::string& b, |
const std::string& c, const std::string& d) { |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(a)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(b)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(c)); |
instance.push_back(mert::VocabularyFactory::GetVocabulary()->Encode(d)); |
} |
NgramCounts::Key instance; |
}; |
bool CheckUnigram(const std::string& str) |
{ |
Unigram unigram(str); |
NgramCounts::Value v; |
return GetNgramCounts()->Lookup(unigram.instance, &v); |
} |
bool CheckBigram(const std::string& a, const std::string& b) |
{ |
Bigram bigram(a, b); |
NgramCounts::Value v; |
return GetNgramCounts()->Lookup(bigram.instance, &v); |
} |
bool CheckTrigram(const std::string& a, const std::string& b, |
const std::string& c) |
{ |
Trigram trigram(a, b, c); |
NgramCounts::Value v; |
return GetNgramCounts()->Lookup(trigram.instance, &v); |
} |
bool CheckFourgram(const std::string& a, const std::string& b, |
const std::string& c, const std::string& d) |
{ |
Fourgram fourgram(a, b, c, d); |
NgramCounts::Value v; |
return GetNgramCounts()->Lookup(fourgram.instance, &v); |
} |
// Feeds four one-line reference translations to `scorer`, passing ids 0
// through 3 in the order listed. Each reference goes through its own
// stringstream and every OpenReferenceStream() call is checked.
void SetUpReferences(BleuScorer& scorer)
{
  // NOTE(review): "israli" in the last reference looks like a typo for
  // "israeli". It is kept verbatim here; confirm before changing the fixture.
  const char* const references[] = {
    "israeli officials are responsible for airport security",
    "israel is in charge of the security at this airport",
    "the security work for this airport is the responsibility of the israel government",
    "israli side was in charge of the security of this airport"
  };
  const int num_references =
      static_cast<int>(sizeof(references) / sizeof(references[0]));

  for (int id = 0; id < num_references; ++id) {
    std::stringstream ref;
    ref << references[id] << std::endl;
    BOOST_CHECK(scorer.OpenReferenceStream(&ref, id));
  }
}
} |
// Exercises the reference-length-type accessor pair: a default-constructed
// scorer is expected to report CLOSEST, and each SetReferenceLengthType()
// call must be reflected by the following GetReferenceLengthType().
BOOST_AUTO_TEST_CASE(bleu_reference_type)
{
  BleuScorer scorer;

  // Value expected before any setter call.
  BOOST_CHECK_EQUAL(scorer.GetReferenceLengthType(), BleuScorer::CLOSEST);

  // Switch to AVERAGE and read it back.
  scorer.SetReferenceLengthType(BleuScorer::AVERAGE);
  BOOST_CHECK_EQUAL(scorer.GetReferenceLengthType(), BleuScorer::AVERAGE);

  // Switch to SHORTEST and read it back.
  scorer.SetReferenceLengthType(BleuScorer::SHORTEST);
  BOOST_CHECK_EQUAL(scorer.GetReferenceLengthType(), BleuScorer::SHORTEST);
}
// Checks that a "reflen:<type>" configuration string passed to the
// constructor selects the matching reference-length type.
// Each scorer lives in its own scope, so only one exists at a time.
BOOST_AUTO_TEST_CASE(bleu_reference_type_with_config)
{
  {
    BleuScorer average_scorer("reflen:average");
    BOOST_CHECK_EQUAL(average_scorer.GetReferenceLengthType(),
                      BleuScorer::AVERAGE);
  }

  {
    BleuScorer shortest_scorer("reflen:shortest");
    BOOST_CHECK_EQUAL(shortest_scorer.GetReferenceLengthType(),
                      BleuScorer::SHORTEST);
  }
}
// Checks BleuScorer::CountNgrams() on one eight-token sentence: the returned
// length, the number of distinct n-grams collected, the vocabulary size, and
// the presence of every expected 1- to 4-gram in the collected table.
BOOST_AUTO_TEST_CASE(bleu_count_ngrams)
{
  // Installs a table for the Check*gram() helpers and uninstalls it again on
  // scope exit. Without this the file-level pointer would keep the address of
  // the stack-local `counts` below after this test case has returned.
  struct ScopedNgramCounts {
    explicit ScopedNgramCounts(NgramCounts* table) { SetNgramCounts(table); }
    ~ScopedNgramCounts() { SetNgramCounts(NULL); }
  };

  BleuScorer scorer;
  std::string line = "I saw a girl with a telescope .";

  // Everything below depends on the counts, so stop here if counting fails.
  NgramCounts counts;
  BOOST_REQUIRE(scorer.CountNgrams(line, counts, kBleuNgramOrder) == 8);

  // "a" occurs twice: 7 distinct unigrams + 7 bigrams + 6 trigrams
  // + 5 fourgrams = 25 entries.
  BOOST_CHECK_EQUAL(static_cast<std::size_t>(25), counts.size());

  mert::Vocabulary* vocab = scorer.GetVocab();
  BOOST_CHECK_EQUAL(static_cast<std::size_t>(7), vocab->size());

  // Every token of the sentence must be known to the scorer's vocabulary.
  std::vector<std::string> res;
  Tokenize(line.c_str(), ' ', &res);
  std::vector<int> ids(res.size());
  for (std::size_t i = 0; i < res.size(); ++i) {
    BOOST_CHECK(vocab->Lookup(res[i], &ids[i]));
  }

  // Declared after `counts`, so it is destroyed (and the pointer cleared)
  // before `counts` itself goes away.
  ScopedNgramCounts installed_counts(&counts);

  // unigrams
  for (std::size_t i = 0; i < res.size(); ++i) {
    BOOST_CHECK(CheckUnigram(res[i]));
  }

  // bigrams
  BOOST_CHECK(CheckBigram("I", "saw"));
  BOOST_CHECK(CheckBigram("saw", "a"));
  BOOST_CHECK(CheckBigram("a", "girl"));
  BOOST_CHECK(CheckBigram("girl", "with"));
  BOOST_CHECK(CheckBigram("with", "a"));
  BOOST_CHECK(CheckBigram("a", "telescope"));
  BOOST_CHECK(CheckBigram("telescope", "."));

  // trigrams
  BOOST_CHECK(CheckTrigram("I", "saw", "a"));
  BOOST_CHECK(CheckTrigram("saw", "a", "girl"));
  BOOST_CHECK(CheckTrigram("a", "girl", "with"));
  BOOST_CHECK(CheckTrigram("girl", "with", "a"));
  BOOST_CHECK(CheckTrigram("with", "a", "telescope"));
  BOOST_CHECK(CheckTrigram("a", "telescope", "."));

  // fourgrams
  BOOST_CHECK(CheckFourgram("I", "saw", "a", "girl"));
  BOOST_CHECK(CheckFourgram("saw", "a", "girl", "with"));
  BOOST_CHECK(CheckFourgram("a", "girl", "with", "a"));
  BOOST_CHECK(CheckFourgram("girl", "with", "a", "telescope"));
  BOOST_CHECK(CheckFourgram("with", "a", "telescope", "."));
}
// Runs prepareStats() for sentence 0 on a six-token hypothesis against the
// references from SetUpReferences() and checks the first eight statistics.
BOOST_AUTO_TEST_CASE(bleu_clipped_counts)
{
  BleuScorer scorer;
  SetUpReferences(scorer);

  std::string hypothesis("israeli officials responsibility of airport safety");
  ScoreStats entry;
  scorer.prepareStats(0, hypothesis, entry);

  BOOST_CHECK_EQUAL(entry.size(),
                    static_cast<std::size_t>(2 * kBleuNgramOrder + 1));

  // Expected values at even indices 0, 2, 4, 6.
  // NOTE(review): presumably the clipped match counts for 1- to 4-grams,
  // going by the test name; confirm against BleuScorer.
  const int expected_even[] = { 5, 2, 0, 0 };
  // Expected values at odd indices 1, 3, 5, 7. These equal the number of
  // 1-, 2-, 3- and 4-grams in a six-token sentence.
  const int expected_odd[] = { 6, 5, 4, 3 };

  for (int n = 0; n < 4; ++n) {
    BOOST_CHECK_EQUAL(entry.get(2 * n), expected_even[n]);
  }
  for (int n = 0; n < 4; ++n) {
    BOOST_CHECK_EQUAL(entry.get(2 * n + 1), expected_odd[n]);
  }
}
// Feeds a fixed nine-element statistics vector to calculateScore() and
// expects 0.5115 within a 0.01 percent tolerance.
BOOST_AUTO_TEST_CASE(calculate_actual_score)
{
  // The fixture below has exactly 2 * 4 + 1 entries.
  BOOST_REQUIRE(4 == kBleuNgramOrder);

  // NOTE(review): presumably four pairs of per-order statistics followed by
  // a reference length; confirm against BleuScorer.
  const ScoreStatsType values[] = { 6, 6, 4, 5, 2, 4, 1, 3, 7 };
  std::vector<ScoreStatsType> stats(
      values, values + sizeof(values) / sizeof(values[0]));

  BleuScorer scorer;
  BOOST_CHECK_CLOSE(0.5115f, scorer.calculateScore(stats), 0.01);
}
// Feeds a fixed nine-element statistics vector to smoothedSentenceBleu() and
// checks the result for three argument combinations, each within a 0.01
// percent tolerance.
BOOST_AUTO_TEST_CASE(sentence_level_bleu)
{
  // The fixture below has exactly 2 * 4 + 1 entries.
  BOOST_REQUIRE(4 == kBleuNgramOrder);

  const float values[] = { 6.0f, 6.0f, 4.0f, 5.0f, 2.0f, 4.0f, 1.0f, 3.0f, 7.0f };
  std::vector<float> stats(values, values + sizeof(values) / sizeof(values[0]));

  // Default arguments.
  BOOST_CHECK_CLOSE(0.5985f, smoothedSentenceBleu(stats), 0.01);
  // Explicit second argument of 0.5.
  BOOST_CHECK_CLOSE(0.5624f, smoothedSentenceBleu(stats, 0.5), 0.01 );
  // Second argument 1.0 with the third argument set to true.
  BOOST_CHECK_CLOSE(0.5067f, smoothedSentenceBleu(stats, 1.0, true), 0.01);
}