|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include <cmath>

#include <boost/lexical_cast.hpp>
#include <boost/unordered_set.hpp>

#include "util/exception.hh"
#include "util/tokenize_piece.hh"

#include "moses/TranslationModel/PhraseDictionaryInterpolated.h"
|
|
|
using namespace std; |
|
|
|
namespace Moses |
|
{ |
|
|
|
// Constructs an interpolated phrase dictionary wrapping several component
// tables; the component dictionaries themselves are created later in Load().
// NOTE(review): numInputScores is accepted to match the factory signature but
// is not used here — it is forwarded per-table inside Load() instead.
PhraseDictionaryInterpolated::PhraseDictionaryInterpolated
(size_t numScoreComponent,size_t numInputScores,const PhraseDictionaryFeature* feature):
  PhraseDictionary(numScoreComponent,feature),
  m_targetPhrases(NULL),      // cached collection, (re)built per source phrase
  m_languageModels(NULL) {}   // borrowed pointer, set in Load()
|
|
|
bool PhraseDictionaryInterpolated::Load( |
|
const std::vector<FactorType> &input |
|
, const std::vector<FactorType> &output |
|
, const std::vector<std::string>& config |
|
, const std::vector<float> &weightT |
|
, size_t tableLimit |
|
, const LMList &languageModels |
|
, float weightWP) |
|
{ |
|
|
|
m_languageModels = &languageModels; |
|
m_weightT = weightT; |
|
m_tableLimit = tableLimit; |
|
m_weightWP = weightWP; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
UTIL_THROW_IF(config.size() < 7, util::Exception, "Missing fields from phrase table configuration: expected at least 7"); |
|
UTIL_THROW_IF(config[4] != "naive", util::Exception, "Unsupported combination mode: '" << config[4] << "'"); |
|
|
|
|
|
for (size_t i = 5; i < config.size()-1; ++i) { |
|
m_dictionaries.push_back(DictionaryHandle(new PhraseDictionaryTreeAdaptor( |
|
GetFeature()->GetNumScoreComponents(), |
|
GetFeature()->GetNumInputScores(), |
|
GetFeature()))); |
|
bool ret = m_dictionaries.back()->Load( |
|
input, |
|
output, |
|
config[i], |
|
weightT, |
|
0, |
|
languageModels, |
|
weightWP); |
|
if (!ret) return ret; |
|
} |
|
|
|
|
|
for (util::TokenIter<util::SingleCharacter, false> featureWeights(config.back(), util::SingleCharacter(';')); featureWeights; ++featureWeights) { |
|
m_weights.push_back(vector<float>()); |
|
float sum = 0; |
|
for (util::TokenIter<util::SingleCharacter, false> tableWeights(*featureWeights, util::SingleCharacter(',')); tableWeights; ++tableWeights) { |
|
const float weight = boost::lexical_cast<float>(*tableWeights); |
|
m_weights.back().push_back(weight); |
|
sum += weight; |
|
} |
|
UTIL_THROW_IF(m_weights.back().size() != m_dictionaries.size(), util::Exception, |
|
"Number of weights (" << m_weights.back().size() << |
|
") does not match number of dictionaries to combine (" << m_dictionaries.size() << ")"); |
|
UTIL_THROW_IF(abs(sum - 1) > 0.01, util::Exception, "Weights not normalised"); |
|
|
|
} |
|
|
|
|
|
|
|
UTIL_THROW_IF(m_weights.size() != 1 && m_weights.size() != GetFeature()->GetNumScoreComponents()-1, util::Exception, "Unexpected number of weight sets"); |
|
|
|
if (m_weights.size() == 1) { |
|
while(m_weights.size() < GetFeature()->GetNumScoreComponents()-1) { |
|
m_weights.push_back(m_weights[0]); |
|
} |
|
} |
|
|
|
return true; |
|
} |
|
|
|
/** Propagate per-sentence initialisation to every component dictionary. */
void PhraseDictionaryInterpolated::InitializeForInput(ttasksptr const& ttask)
{
  const size_t tableCount = m_dictionaries.size();
  for (size_t table = 0; table != tableCount; ++table) {
    m_dictionaries[table]->InitializeForInput(ttask);
  }
}
|
|
|
// Set of TargetPhrase pointers deduplicated by phrase *content* rather than
// pointer identity — hashing and equality presumably delegate to the
// pointed-to phrases via PhrasePtrHasher/PhrasePtrComparator (declared
// elsewhere in the project).
typedef
boost::unordered_set<TargetPhrase*,PhrasePtrHasher,PhrasePtrComparator> PhraseSet;
|
|
|
|
|
/**
 * Naively interpolate the target phrase collections of all component tables
 * for a given source phrase.
 *
 * For every distinct target phrase seen in any table, the per-feature scores
 * are combined as a weighted sum of probabilities (weights from m_weights,
 * scores exponentiated out of log space per table, then logged back), and
 * the resulting collection is pruned to m_tableLimit.
 *
 * NOTE(review): m_targetPhrases is managed with raw delete/new while the
 * return type is TargetPhraseCollection::shared_ptr — verify ownership at
 * the caller; mixing a cached raw pointer with shared_ptr construction risks
 * a double delete. Left unchanged here because the member's declaration is
 * outside this file.
 */
TargetPhraseCollection::shared_ptr
PhraseDictionaryInterpolated::GetTargetPhraseCollection(const Phrase& src) const
{

  delete m_targetPhrases;
  m_targetPhrases = new TargetPhraseCollection();

  // Collect all candidate phrases, both in one global set (for iteration)
  // and per table (for weight lookup).
  PhraseSet allPhrases;
  vector<PhraseSet> phrasesByTable(m_dictionaries.size());
  for (size_t i = 0; i < m_dictionaries.size(); ++i) {
    TargetPhraseCollection::shared_ptr phrases = m_dictionaries[i]->GetTargetPhraseCollection(src);
    if (phrases) {
      for (TargetPhraseCollection::const_iterator j = phrases->begin();
           j != phrases->end(); ++j) {
        allPhrases.insert(*j);
        phrasesByTable[i].insert(*j);
      }
    }
  }

  ScoreComponentCollection sparseVector;
  for (PhraseSet::const_iterator i = allPhrases.begin(); i != allPhrases.end(); ++i) {
    TargetPhrase* combinedPhrase = new TargetPhrase((Phrase)**i);
    combinedPhrase->SetSourcePhrase((*i)->GetSourcePhrase());
    combinedPhrase->SetAlignTerm(&((*i)->GetAlignTerm()));
    // BUG FIX: was SetAlignNonTerm(&((*i)->GetAlignTerm())) — a copy-paste
    // error that stored the *terminal* alignments in the non-terminal slot.
    combinedPhrase->SetAlignNonTerm(&((*i)->GetAlignNonTerm()));

    // Weighted sum of per-table probabilities for each feature except the
    // last component (the penalty slot, fixed to 1 below).
    Scores combinedScores(GetFeature()->GetNumScoreComponents());
    for (size_t j = 0; j < phrasesByTable.size(); ++j) {
      PhraseSet::const_iterator tablePhrase = phrasesByTable[j].find(combinedPhrase);
      if (tablePhrase != phrasesByTable[j].end()) {
        Scores tableScores = (*tablePhrase)->GetScoreBreakdown()
                             .GetScoresForProducer(GetFeature());
        for (size_t k = 0; k < tableScores.size()-1; ++k) {
          // exp() converts the stored log score back to a probability before
          // weighting; phrases absent from a table contribute 0.
          combinedScores[k] += m_weights[k][j] * exp(tableScores[k]);
        }
      }
    }

    // Back into log space for the decoder.
    for (size_t k = 0; k < combinedScores.size()-1; ++k) {
      combinedScores[k] = log(combinedScores[k]);
    }
    combinedScores.back() = 1;

    combinedPhrase->SetScore(
      GetFeature(),
      combinedScores,
      sparseVector,
      m_weightT,
      m_weightWP,
      *m_languageModels);
    m_targetPhrases->Add(combinedPhrase);
  }

  m_targetPhrases->Prune(true,m_tableLimit);

  return m_targetPhrases;
}
|
|
|
} |
|
|