|
#include "GlobalLexicalModelUnlimited.h" |
|
#include <fstream> |
|
#include "moses/StaticData.h" |
|
#include "moses/InputFileStream.h" |
|
#include "moses/Hypothesis.h" |
|
#include "moses/TranslationTask.h" |
|
#include "util/string_piece_hash.hh" |
|
#include "util/string_stream.hh" |
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
// Constructs the feature from its configuration line.
//
// The feature has not been ported to the current feature function framework,
// so the constructor throws util::Exception unconditionally as its very first
// statement. Everything after that throw is legacy specification parsing that
// is currently unreachable; it is kept as a starting point for the port and
// has been made free of undefined behaviour so it is safe to re-enable.
//
// Expected layout of one model specification (taken from the diagnostic
// below):
//   <factor-src>-<factor-tgt> [ignore-punct] [use-bias] [context-type] [filename-src filename-tgt]
//
// line: configuration line, forwarded to StatelessFeatureFunction.
// Throws util::Exception always.
GlobalLexicalModelUnlimited::GlobalLexicalModelUnlimited(const std::string &line)
  :StatelessFeatureFunction(0, line)
{
  UTIL_THROW(util::Exception,
             "GlobalLexicalModelUnlimited hasn't been refactored for new feature function framework yet");

  // ---- unreachable until the throw above is removed ----
  const vector<string> modelSpec = Tokenize(line);

  for (size_t i = 0; i < modelSpec.size(); i++ ) {
    bool ignorePunctuation = true, biasFeature = false, restricted = false;
    size_t context = 0;
    string filenameSource, filenameTarget;
    vector< string > factors;
    vector< string > spec = Tokenize(modelSpec[i]," ");

    // read the optional punctuation, bias, context and word list fields
    if (spec.size() > 0) {
      if (spec.size() != 2 && spec.size() != 3 && spec.size() != 4 && spec.size() != 6) {
        std::cerr << "Format of glm feature is <factor-src>-<factor-tgt> [ignore-punct] [use-bias] "
                  << "[context-type] [filename-src filename-tgt]" << endl;
      }

      factors = Tokenize(spec[0],"-");
      if (spec.size() >= 2)
        ignorePunctuation = Scan<size_t>(spec[1]) != 0;
      if (spec.size() >= 3)
        biasFeature = Scan<size_t>(spec[2]) != 0;
      if (spec.size() >= 4)
        context = Scan<size_t>(spec[3]);
      if (spec.size() == 6) {
        filenameSource = spec[4];
        filenameTarget = spec[5];
        restricted = true;
      }
    } else
      factors = Tokenize(modelSpec[i],"-");

    // factors[0] and factors[1] are read right below, so a malformed factor
    // definition must stop here instead of only being reported.
    UTIL_THROW_IF2(factors.size() != 2,
                   "Wrong factor definition for global lexical model unlimited: " << modelSpec[i]);

    // NOTE(review): ignorePunctuation, biasFeature, context and the two
    // factor lists are parsed but not stored anywhere yet; wiring them into
    // the object is part of the pending port.
    const vector<FactorType> inputFactors = Tokenize<FactorType>(factors[0],",");
    const vector<FactorType> outputFactors = Tokenize<FactorType>(factors[1],",");
    throw runtime_error("GlobalLexicalModelUnlimited should be reimplemented as a stateful feature");

    if (restricted) {
      cerr << "loading word translation word lists from " << filenameSource << " and " << filenameTarget << endl;
      // Load into the object being constructed; the previous code called
      // Load() through a pointer that had just been set to NULL.
      if (!Load(filenameSource, filenameTarget)) {
        std::cerr << "Unable to load word lists for word translation feature from files "
                  << filenameSource
                  << " and "
                  << filenameTarget << endl;
      }
    }
  }
}
|
|
|
bool GlobalLexicalModelUnlimited::Load(const std::string &filePathSource, |
|
const std::string &filePathTarget) |
|
{ |
|
|
|
ifstream inFileSource(filePathSource.c_str()); |
|
if (!inFileSource) { |
|
cerr << "could not open file " << filePathSource << endl; |
|
return false; |
|
} |
|
|
|
std::string line; |
|
while (getline(inFileSource, line)) { |
|
m_vocabSource.insert(line); |
|
} |
|
|
|
inFileSource.close(); |
|
|
|
|
|
ifstream inFileTarget(filePathTarget.c_str()); |
|
if (!inFileTarget) { |
|
cerr << "could not open file " << filePathTarget << endl; |
|
return false; |
|
} |
|
|
|
while (getline(inFileTarget, line)) { |
|
m_vocabTarget.insert(line); |
|
} |
|
|
|
inFileTarget.close(); |
|
|
|
m_unrestricted = false; |
|
return true; |
|
} |
|
|
|
// Prepares the per-thread state for a new input: verifies that the task's
// source is a sentence and remembers a pointer to it in m_local so that
// EvaluateWhenApplied() can read the input words.
//
// ttask: translation task whose source is about to be decoded.
// Throws (via UTIL_THROW_IF2) when the source is not of type SentenceInput.
void GlobalLexicalModelUnlimited::InitializeForInput(ttasksptr const& ttask)
{
  UTIL_THROW_IF2(ttask->GetSource()->GetType() != SentenceInput,
                 "GlobalLexicalModel works only with sentence input.");
  // The type tag was checked above, so the downcast is known to be valid.
  // static_cast applies any base-to-derived pointer adjustment the class
  // layout requires; reinterpret_cast would reuse the address unchanged.
  Sentence const* s = static_cast<Sentence const*>(ttask->GetSource().get());
  m_local.reset(new ThreadLocalStorage);
  m_local->input = s;
}
|
|
|
void GlobalLexicalModelUnlimited::EvaluateWhenApplied(const Hypothesis& cur_hypo, ScoreComponentCollection* accumulator) const |
|
{ |
|
const Sentence& input = *(m_local->input); |
|
const TargetPhrase& targetPhrase = cur_hypo.GetCurrTargetPhrase(); |
|
|
|
for(size_t targetIndex = 0; targetIndex < targetPhrase.GetSize(); targetIndex++ ) { |
|
StringPiece targetString = targetPhrase.GetWord(targetIndex).GetString(0); |
|
|
|
if (m_ignorePunctuation) { |
|
|
|
char firstChar = targetString[0]; |
|
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); |
|
if(charIterator != m_punctuationHash.end()) |
|
continue; |
|
} |
|
|
|
if (m_biasFeature) { |
|
util::StringStream feature; |
|
feature << "glm_"; |
|
feature << targetString; |
|
feature << "~"; |
|
feature << "**BIAS**"; |
|
accumulator->SparsePlusEquals(feature.str(), 1); |
|
} |
|
|
|
boost::unordered_set<uint64_t> alreadyScored; |
|
for(size_t sourceIndex = 0; sourceIndex < input.GetSize(); sourceIndex++ ) { |
|
const StringPiece sourceString = input.GetWord(sourceIndex).GetString(0); |
|
|
|
|
|
if (m_ignorePunctuation) { |
|
|
|
char firstChar = sourceString[0]; |
|
CharHash::const_iterator charIterator = m_punctuationHash.find( firstChar ); |
|
if(charIterator != m_punctuationHash.end()) |
|
continue; |
|
} |
|
const uint64_t sourceHash = util::MurmurHashNative(sourceString.data(), sourceString.size()); |
|
|
|
if ( alreadyScored.find(sourceHash) == alreadyScored.end()) { |
|
bool sourceExists, targetExists; |
|
if (!m_unrestricted) { |
|
sourceExists = FindStringPiece(m_vocabSource, sourceString ) != m_vocabSource.end(); |
|
targetExists = FindStringPiece(m_vocabTarget, targetString) != m_vocabTarget.end(); |
|
} |
|
|
|
|
|
if (m_unrestricted || (sourceExists && targetExists)) { |
|
if (m_sourceContext) { |
|
if (sourceIndex == 0) { |
|
|
|
util::StringStream feature; |
|
feature << "glm_"; |
|
feature << targetString; |
|
feature << "~"; |
|
feature << "<s>,"; |
|
feature << sourceString; |
|
accumulator->SparsePlusEquals(feature.str(), 1); |
|
alreadyScored.insert(sourceHash); |
|
} |
|
|
|
|
|
for(int contextIndex = sourceIndex+1; contextIndex < input.GetSize(); contextIndex++ ) { |
|
StringPiece contextString = input.GetWord(contextIndex).GetString(0); |
|
bool contextExists; |
|
if (!m_unrestricted) |
|
contextExists = FindStringPiece(m_vocabSource, contextString ) != m_vocabSource.end(); |
|
|
|
if (m_unrestricted || contextExists) { |
|
util::StringStream feature; |
|
feature << "glm_"; |
|
feature << targetString; |
|
feature << "~"; |
|
feature << sourceString; |
|
feature << ","; |
|
feature << contextString; |
|
accumulator->SparsePlusEquals(feature.str(), 1); |
|
alreadyScored.insert(sourceHash); |
|
} |
|
} |
|
} else if (m_biphrase) { |
|
|
|
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; |
|
|
|
|
|
StringPiece targetContext; |
|
if (globalTargetIndex > 0) |
|
targetContext = cur_hypo.GetWord(globalTargetIndex-1).GetString(0); |
|
else |
|
targetContext = "<s>"; |
|
|
|
if (sourceIndex == 0) { |
|
StringPiece sourceTrigger = "<s>"; |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetContext, targetString); |
|
} else |
|
for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { |
|
StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); |
|
bool sourceTriggerExists = false; |
|
if (!m_unrestricted) |
|
sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); |
|
|
|
if (m_unrestricted || sourceTriggerExists) |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetContext, targetString); |
|
} |
|
|
|
|
|
StringPiece sourceContext; |
|
if (sourceIndex-1 >= 0) |
|
sourceContext = input.GetWord(sourceIndex-1).GetString(0); |
|
else |
|
sourceContext = "<s>"; |
|
|
|
if (globalTargetIndex == 0) { |
|
string targetTrigger = "<s>"; |
|
AddFeature(accumulator, sourceContext, sourceString, |
|
targetTrigger, targetString); |
|
} else |
|
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
|
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
|
bool targetTriggerExists = false; |
|
if (!m_unrestricted) |
|
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
|
|
|
if (m_unrestricted || targetTriggerExists) |
|
AddFeature(accumulator, sourceContext, sourceString, |
|
targetTrigger, targetString); |
|
} |
|
} else if (m_bitrigger) { |
|
|
|
int globalTargetIndex = cur_hypo.GetSize() - targetPhrase.GetSize() + targetIndex; |
|
|
|
if (sourceIndex == 0) { |
|
StringPiece sourceTrigger = "<s>"; |
|
bool sourceTriggerExists = true; |
|
|
|
if (globalTargetIndex == 0) { |
|
string targetTrigger = "<s>"; |
|
bool targetTriggerExists = true; |
|
|
|
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetTrigger, targetString); |
|
} else { |
|
|
|
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
|
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
|
bool targetTriggerExists = false; |
|
if (!m_unrestricted) |
|
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
|
|
|
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetTrigger, targetString); |
|
} |
|
} |
|
} |
|
|
|
else { |
|
|
|
for(int contextIndex = sourceIndex-1; contextIndex >= 0; contextIndex-- ) { |
|
StringPiece sourceTrigger = input.GetWord(contextIndex).GetString(0); |
|
bool sourceTriggerExists = false; |
|
if (!m_unrestricted) |
|
sourceTriggerExists = FindStringPiece(m_vocabSource, sourceTrigger ) != m_vocabSource.end(); |
|
|
|
if (globalTargetIndex == 0) { |
|
string targetTrigger = "<s>"; |
|
bool targetTriggerExists = true; |
|
|
|
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetTrigger, targetString); |
|
} else { |
|
|
|
for(int globalContextIndex = globalTargetIndex-1; globalContextIndex >= 0; globalContextIndex-- ) { |
|
StringPiece targetTrigger = cur_hypo.GetWord(globalContextIndex).GetString(0); |
|
bool targetTriggerExists = false; |
|
if (!m_unrestricted) |
|
targetTriggerExists = FindStringPiece(m_vocabTarget, targetTrigger ) != m_vocabTarget.end(); |
|
|
|
if (m_unrestricted || (sourceTriggerExists && targetTriggerExists)) |
|
AddFeature(accumulator, sourceTrigger, sourceString, |
|
targetTrigger, targetString); |
|
} |
|
} |
|
} |
|
} |
|
} else { |
|
util::StringStream feature; |
|
feature << "glm_"; |
|
feature << targetString; |
|
feature << "~"; |
|
feature << sourceString; |
|
accumulator->SparsePlusEquals(feature.str(), 1); |
|
alreadyScored.insert(sourceHash); |
|
|
|
} |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
// Builds the feature name
//   glm_<targetTrigger>,<targetWord>~<sourceTrigger>,<sourceWord>
// and adds 1 to that sparse feature in the accumulator.
//
// accumulator:   score collection that receives the feature.
// sourceTrigger: source-side trigger or context word.
// sourceWord:    source word being scored.
// targetTrigger: target-side trigger or context word.
// targetWord:    target word being scored.
void GlobalLexicalModelUnlimited::AddFeature(ScoreComponentCollection* accumulator,
    StringPiece sourceTrigger, StringPiece sourceWord,
    StringPiece targetTrigger, StringPiece targetWord) const
{
  util::StringStream name;
  // target side first, then the source side after the '~' separator
  name << "glm_" << targetTrigger << "," << targetWord
       << "~" << sourceTrigger << "," << sourceWord;
  accumulator->SparsePlusEquals(name.str(), 1);
}
|
|
|
} |
|
|