|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "LoaderCompact.h" |
|
|
|
#include "moses/AlignmentInfoCollection.h" |
|
#include "moses/InputFileStream.h" |
|
#include "moses/Util.h" |
|
#include "moses/Timer.h" |
|
#include "moses/Word.h" |
|
#include "Trie.h" |
|
|
|
#include <istream> |
|
#include <sstream> |
|
|
|
namespace Moses |
|
{ |
|
|
|
// Load a compact-format rule table from inFile into ruleTable.
//
// The file is consumed strictly top to bottom: a version line (which must be
// exactly "1"), the vocabulary section, the source phrase section, the target
// phrase section, the alignment section and finally the rule section. Once
// the rules are in the table, SortAndPrune() is applied to it.
//
// Parameters:
//   opts      - not referenced by this loader
//   input     - factor types forwarded to LoadVocabularySection()
//   output    - not referenced by this loader
//   inFile    - path of the compact rule table to read
//   (size_t)  - unnamed and ignored
//   ruleTable - trie that receives the loaded rules
//
// Returns false when the version line is not "1" or when LoadRuleSection()
// reports failure; returns true otherwise.
bool RuleTableLoaderCompact::Load(AllOptions const& opts,
                                  const std::vector<FactorType> &input,
                                  const std::vector<FactorType> &output,
                                  const std::string &inFile,
                                  size_t /* unused */,
                                  RuleTableTrie &ruleTable)
{
  PrintUserTime("Start loading compact rule table");

  InputFileStream tableStream(inFile);
  LineReader lines(tableStream);

  // Only format version "1" is understood.
  lines.ReadLine();
  const bool versionOk = (lines.m_line == "1");
  if (!versionOk) {
    std::cerr << "Unexpected compact rule table format: " << lines.m_line;
    return false;
  }

  // Symbols shared by the source and target phrase sections.
  std::vector<Word> symbols;
  LoadVocabularySection(lines, input, symbols);

  // Source-side phrases (the LHS ids are read but not used further here).
  std::vector<Phrase> srcPhrases;
  std::vector<size_t> srcLhsIds;
  LoadPhraseSection(lines, symbols, srcPhrases, srcLhsIds);

  // Target-side phrases together with their LHS symbol ids.
  std::vector<Phrase> tgtPhrases;
  std::vector<size_t> tgtLhsIds;
  LoadPhraseSection(lines, symbols, tgtPhrases, tgtLhsIds);

  // Alignment sets; the source phrases are needed to classify each point.
  std::vector<const AlignmentInfo *> alignments;
  LoadAlignmentSection(lines, alignments, srcPhrases);

  // Rules reference everything loaded above by numeric id.
  const bool rulesLoaded = LoadRuleSection(lines, symbols, srcPhrases,
                                           tgtPhrases, tgtLhsIds, alignments,
                                           ruleTable);
  if (rulesLoaded) {
    // Post-process the populated table.
    SortAndPrune(ruleTable);
  }

  return rulesLoaded;
}
|
|
|
// Read the vocabulary section: one count line followed by that many symbol
// lines, one symbol per line.
//
// A symbol whose first character is '[' and whose last character is ']' is
// treated as a non-terminal; the enclosing brackets are stripped from the
// line before the Word is built. Every Word is created with direction Input
// and the supplied factorTypes.
//
// Parameters:
//   reader      - line source positioned at the section's count line
//   factorTypes - factor types passed to Word::CreateFromString()
//   vocabulary  - resized to the announced count and filled in file order
void RuleTableLoaderCompact::LoadVocabularySection(
  LineReader &reader,
  const std::vector<FactorType> &factorTypes,
  std::vector<Word> &vocabulary)
{
  // The section starts with the number of symbols that follow.
  reader.ReadLine();
  const size_t symbolCount = std::atoi(reader.m_line.c_str());

  vocabulary.resize(symbolCount);

  std::string &symbol = reader.m_line;
  for (std::vector<Word>::iterator word = vocabulary.begin();
       word != vocabulary.end(); ++word) {
    reader.ReadLine();

    const size_t length = symbol.size();
    const bool bracketed = (symbol[0] == '[' && symbol[length-1] == ']');
    if (bracketed) {
      // Both tests passing implies length >= 2, so dropping the last and
      // then the first character leaves exactly the text between brackets.
      symbol.erase(length - 1, 1);
      symbol.erase(0, 1);
    }

    word->CreateFromString(Input, factorTypes, symbol, bracketed);
  }
}
|
|
|
// Read one phrase section: a count line followed by that many phrase lines.
//
// Each phrase line is a space-separated list of integers. The first integer
// is stored as the phrase's LHS id; every following integer is an index into
// vocab, and the referenced Word is appended to the phrase.
//
// The input is trusted: a line without any token, or an index outside vocab,
// is not detected here.
//
// Parameters:
//   reader     - line source positioned at the section's count line
//   vocab      - symbols the word indices refer to
//   rhsPhrases - resized to the announced count; phrase i receives the words
//                of line i
//   lhsIds     - resized to the announced count; entry i receives the first
//                integer of line i
void RuleTableLoaderCompact::LoadPhraseSection(
  LineReader &reader,
  const std::vector<Word> &vocab,
  std::vector<Phrase> &rhsPhrases,
  std::vector<size_t> &lhsIds)
{
  // The section starts with the number of phrases that follow.
  reader.ReadLine();
  const size_t numPhrases = std::atoi(reader.m_line.c_str());

  rhsPhrases.resize(numPhrases, Phrase(0));
  lhsIds.resize(numPhrases);

  // Offsets of the tokens within the current line; reused across lines.
  std::vector<size_t> starts;
  for (size_t phraseIdx = 0; phraseIdx != numPhrases; ++phraseIdx) {
    reader.ReadLine();
    starts.clear();
    FindTokens(starts, reader.m_line);

    const char *text = reader.m_line.c_str();

    // Token 0 is the left-hand-side id.
    lhsIds[phraseIdx] = std::atoi(text + starts[0]);

    // Tokens 1..n are vocabulary indices making up the right-hand side.
    size_t tok = 1;
    while (tok < starts.size()) {
      const int wordId = std::atoi(text + starts[tok]);
      rhsPhrases[phraseIdx].AddWord(vocab[wordId]);
      ++tok;
    }
  }
}
|
|
|
// Read the alignment section: a count line followed by that many lines, each
// holding whitespace-separated "a-b" alignment points.
//
// For alignment set i the points are split into two collections depending on
// whether word `a` of sourcePhrases[i] is a non-terminal. Both collections
// are registered with AlignmentInfoCollection and the returned pointers are
// stored at alignmentSets[2*i] (non-terminal points) and alignmentSets[2*i+1]
// (terminal points).
//
// NOTE(review): set i is always classified against sourcePhrases[i]; this
// presumes alignment sets and source phrases share an index space - confirm
// against the writer of this format.
//
// Parameters:
//   reader        - line source positioned at the section's count line
//   alignmentSets - resized to twice the announced count and filled
//   sourcePhrases - phrases used to tell terminal from non-terminal points
void RuleTableLoaderCompact::LoadAlignmentSection(
  LineReader &reader, std::vector<const AlignmentInfo *> &alignmentSets, std::vector<Phrase> &sourcePhrases)
{
  typedef std::pair<size_t, size_t> AlignPoint;

  // The section starts with the number of alignment sets that follow.
  reader.ReadLine();
  const size_t numSets = std::atoi(reader.m_line.c_str());

  // Two slots per set: non-terminal points first, terminal points second.
  alignmentSets.resize(numSets * 2);

  // Scratch containers, reused for every line.
  AlignmentInfo::CollType termLinks, nonTermLinks;
  std::vector<std::string> pairStrings;
  std::vector<size_t> coords;

  for (size_t setIdx = 0; setIdx != numSets; ++setIdx) {
    termLinks.clear();
    nonTermLinks.clear();
    pairStrings.clear();

    reader.ReadLine();
    Tokenize(pairStrings, reader.m_line);

    for (size_t t = 0; t < pairStrings.size(); ++t) {
      // Split "a-b" into its two numeric halves.
      coords.clear();
      Tokenize<size_t>(coords, pairStrings[t], "-");
      const AlignPoint link(coords[0], coords[1]);

      const bool onNonTerminal =
        sourcePhrases[setIdx].GetWord(link.first).IsNonTerminal();
      (onNonTerminal ? nonTermLinks : termLinks).insert(link);
    }

    alignmentSets[2 * setIdx] =
      AlignmentInfoCollection::Instance().Add(nonTermLinks);
    alignmentSets[2 * setIdx + 1] =
      AlignmentInfoCollection::Instance().Add(termLinks);
  }
}
|
|
|
// Read the rule section and insert every rule into ruleTable.
//
// The section is a count line followed by that many rule lines. A rule line
// is space-separated and laid out as:
//
//   <source phrase id> <target phrase id> <alignment set id>
//   <score 1> ... <score N> :...
//
// where N is ruleTable.GetNumScoreComponents() and the token after the last
// score must start with ':'. Anything after that token is ignored.
//
// Parameters:
//   reader        - line source positioned at the section's count line
//   vocab         - symbols; used to look up each rule's target LHS
//   sourcePhrases - source sides, indexed by source phrase id
//   targetPhrases - target sides, indexed by target phrase id
//   targetLhsIds  - vocab index of the target LHS, per target phrase id
//   alignmentSets - alignment infos, indexed by the rule's alignment set id
//   ruleTable     - trie that receives the rules
//
// Returns false (after writing a diagnostic to std::cerr) when a rule line
// has too few columns, references an id that is out of range, or does not
// have ':' where the column after the scores is expected. Returns true when
// every announced rule was inserted.
bool RuleTableLoaderCompact::LoadRuleSection(
  LineReader &reader,
  const std::vector<Word> &vocab,
  const std::vector<Phrase> &sourcePhrases,
  const std::vector<Phrase> &targetPhrases,
  const std::vector<size_t> &targetLhsIds,
  const std::vector<const AlignmentInfo *> &alignmentSets,
  RuleTableTrie &ruleTable)
{
  // The section starts with the number of rules that follow.
  reader.ReadLine();
  const size_t ruleCount = std::atoi(reader.m_line.c_str());

  const size_t numScoreComponents = ruleTable.GetNumScoreComponents();
  // Index of the token that must start with ':' (three ids, then the scores).
  const size_t colonColumn = 3 + numScoreComponents;

  std::vector<float> scoreVector(numScoreComponents);
  std::vector<size_t> tokenPositions;
  for (size_t i = 0; i < ruleCount; ++i) {
    reader.ReadLine();

    tokenPositions.clear();
    FindTokens(tokenPositions, reader.m_line);

    // Every index used below is at most colonColumn, so make sure the line
    // really has that many tokens before touching tokenPositions.
    if (tokenPositions.size() <= colonColumn) {
      std::cerr << "Too few columns (" << tokenPositions.size()
                << "<" << (colonColumn + 1)
                << ") in compact rule on line " << reader.m_lineNum << '\n';
      return false;
    }

    const char *charLine = reader.m_line.c_str();

    // The first three columns are the ids of the source phrase, the target
    // phrase and the alignment set.
    const int sourcePhraseId = std::atoi(charLine+tokenPositions[0]);
    const int targetPhraseId = std::atoi(charLine+tokenPositions[1]);
    const int alignmentSetId = std::atoi(charLine+tokenPositions[2]);

    // atoi may yield negative or oversized values on a damaged file; reject
    // them instead of indexing the tables out of bounds.
    const bool idsInRange =
      sourcePhraseId >= 0 &&
      static_cast<size_t>(sourcePhraseId) < sourcePhrases.size() &&
      targetPhraseId >= 0 &&
      static_cast<size_t>(targetPhraseId) < targetPhrases.size() &&
      static_cast<size_t>(targetPhraseId) < targetLhsIds.size() &&
      targetLhsIds[targetPhraseId] < vocab.size() &&
      alignmentSetId >= 0 &&
      static_cast<size_t>(alignmentSetId) < alignmentSets.size();
    if (!idsInRange) {
      std::cerr << "Out-of-range phrase, alignment or vocabulary id"
                << " in compact rule on line " << reader.m_lineNum << '\n';
      return false;
    }

    const Phrase &sourcePhrase = sourcePhrases[sourcePhraseId];
    const Phrase &targetPhrasePhrase = targetPhrases[targetPhraseId];
    Word sourceLHS("X");
    // NOTE(review): LoadAlignmentSection stores two entries per alignment
    // set (2*i and 2*i+1) but the raw id is used as the index here - confirm
    // this is the intended entry.
    const AlignmentInfo *alignNonTerm = alignmentSets[alignmentSetId];

    // Then there is one score per score component.
    // NOTE(review): scoreVector is filled but never handed to the target
    // phrase in this function - confirm whether the scores should be
    // assigned before EvaluateInIsolation().
    for (size_t j = 0; j < numScoreComponents; ++j) {
      float score = std::atof(charLine+tokenPositions[3+j]);
      scoreVector[j] = FloorScore(TransformScore(score));
    }
    if (reader.m_line[tokenPositions[colonColumn]] != ':') {
      std::cerr << "Size of scoreVector != number ("
                << scoreVector.size() << "!=" << numScoreComponents
                << ") of score components on line " << reader.m_lineNum
                << '\n';
      return false;
    }

    // The remaining columns are ignored.

    // Allocate only once the line has been fully validated, so that the
    // early returns above cannot leak these objects. The target phrase
    // presumably takes ownership of targetLhs, and the collection of the
    // target phrase - verify against TargetPhrase / TargetPhraseCollection.
    const Word *targetLhs = new Word(vocab[targetLhsIds[targetPhraseId]]);
    TargetPhrase *targetPhrase = new TargetPhrase(targetPhrasePhrase, &ruleTable);
    targetPhrase->SetAlignNonTerm(alignNonTerm);
    targetPhrase->SetTargetLHS(targetLhs);

    targetPhrase->EvaluateInIsolation(sourcePhrase, ruleTable.GetFeaturesToApply());

    // Insert the rule into the table.
    TargetPhraseCollection::shared_ptr coll;
    coll = GetOrCreateTargetPhraseCollection(ruleTable, sourcePhrase,
           *targetPhrase, &sourceLHS);
    coll->Add(targetPhrase);
  }

  return true;
}
|
|
|
} |
|
|