|
|
|
#pragma once |
|
|
|
#include "moses/DecodeGraph.h" |
|
#include "moses/ForestInput.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/Syntax/BoundedPriorityContainer.h" |
|
#include "moses/Syntax/CubeQueue.h" |
|
#include "moses/Syntax/PHyperedge.h" |
|
#include "moses/Syntax/RuleTable.h" |
|
#include "moses/Syntax/RuleTableFF.h" |
|
#include "moses/Syntax/SHyperedgeBundle.h" |
|
#include "moses/Syntax/SVertex.h" |
|
#include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
|
#include "moses/Syntax/SVertexRecombinationHasher.h" |
|
#include "moses/Syntax/SymbolEqualityPred.h" |
|
#include "moses/Syntax/SymbolHasher.h" |
|
#include "moses/Syntax/T2S/InputTree.h" |
|
#include "moses/Syntax/T2S/InputTreeBuilder.h" |
|
#include "moses/Syntax/T2S/InputTreeToForest.h" |
|
#include "moses/TreeInput.h" |
|
|
|
#include "DerivationWriter.h" |
|
#include "GlueRuleSynthesizer.h" |
|
#include "HyperTree.h" |
|
#include "RuleMatcherCallback.h" |
|
#include "TopologicalSorter.h" |
|
|
|
namespace Moses |
|
{ |
|
namespace Syntax |
|
{ |
|
namespace F2S |
|
{ |
|
|
|
template<typename RuleMatcher> |
|
Manager<RuleMatcher>::Manager(ttasksptr const& ttask) |
|
: Syntax::Manager(ttask) |
|
{ |
|
if (const ForestInput *p = dynamic_cast<const ForestInput*>(&m_source)) { |
|
m_forest = p->GetForest(); |
|
m_rootVertex = p->GetRootVertex(); |
|
m_sentenceLength = p->GetSize(); |
|
} else if (const TreeInput *p = dynamic_cast<const TreeInput*>(&m_source)) { |
|
T2S::InputTreeBuilder builder(options()->output.factor_order); |
|
T2S::InputTree tmpTree; |
|
builder.Build(*p, "Q", tmpTree); |
|
boost::shared_ptr<Forest> forest = boost::make_shared<Forest>(); |
|
m_rootVertex = T2S::InputTreeToForest(tmpTree, *forest); |
|
m_forest = forest; |
|
m_sentenceLength = p->GetSize(); |
|
} else { |
|
UTIL_THROW2("ERROR: F2S::Manager requires input to be a tree or forest"); |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::Decode() |
|
{ |
|
|
|
const std::size_t popLimit = options()->cube.pop_limit; |
|
const std::size_t ruleLimit = options()->syntax.rule_limit; |
|
const std::size_t stackLimit = options()->search.stack_size; |
|
|
|
|
|
InitializeStacks(); |
|
|
|
|
|
InitializeRuleMatchers(); |
|
|
|
|
|
RuleMatcherCallback callback(m_stackMap, ruleLimit); |
|
|
|
|
|
GlueRuleSynthesizer glueRuleSynthesizer(*options(), *m_glueRuleTrie); |
|
|
|
|
|
std::vector<const Forest::Vertex *> sortedVertices; |
|
TopologicalSorter sorter; |
|
sorter.Sort(*m_forest, sortedVertices); |
|
|
|
|
|
for (std::vector<const Forest::Vertex *>::const_iterator |
|
p = sortedVertices.begin(); p != sortedVertices.end(); ++p) { |
|
const Forest::Vertex &vertex = **p; |
|
|
|
|
|
if (vertex.incoming.empty()) { |
|
if (vertex.pvertex.span.GetStartPos() > 0 && |
|
vertex.pvertex.span.GetEndPos() < m_sentenceLength-1 && |
|
IsUnknownSourceWord(vertex.pvertex.symbol)) { |
|
m_oovs.insert(vertex.pvertex.symbol); |
|
} |
|
continue; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
callback.ClearContainer(); |
|
for (typename std::vector<boost::shared_ptr<RuleMatcher> >::iterator |
|
q = m_mainRuleMatchers.begin(); q != m_mainRuleMatchers.end(); ++q) { |
|
(*q)->EnumerateHyperedges(vertex, callback); |
|
} |
|
|
|
|
|
const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
|
callback.GetContainer(); |
|
|
|
|
|
|
|
if (bundles.Size() == 0) { |
|
for (std::vector<Forest::Hyperedge *>::const_iterator p = |
|
vertex.incoming.begin(); p != vertex.incoming.end(); ++p) { |
|
glueRuleSynthesizer.SynthesizeRule(**p); |
|
} |
|
m_glueRuleMatcher->EnumerateHyperedges(vertex, callback); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
|
std::size_t count = 0; |
|
std::vector<SHyperedge*> buffer; |
|
while (count < popLimit && !cubeQueue.IsEmpty()) { |
|
SHyperedge *hyperedge = cubeQueue.Pop(); |
|
|
|
|
|
hyperedge->head->pvertex = &(vertex.pvertex); |
|
|
|
buffer.push_back(hyperedge); |
|
++count; |
|
} |
|
|
|
|
|
SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; |
|
RecombineAndSort(buffer, stack); |
|
|
|
|
|
if (stackLimit > 0 && stack.size() > stackLimit) { |
|
stack.resize(stackLimit); |
|
} |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::InitializeRuleMatchers() |
|
{ |
|
const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
for (std::size_t i = 0; i < ffs.size(); ++i) { |
|
RuleTableFF *ff = ffs[i]; |
|
|
|
|
|
|
|
|
|
const RuleTable *table = ff->GetTable(); |
|
assert(table); |
|
RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
|
HyperTree *trie = dynamic_cast<HyperTree*>(nonConstTable); |
|
assert(trie); |
|
boost::shared_ptr<RuleMatcher> p(new RuleMatcher(*trie)); |
|
m_mainRuleMatchers.push_back(p); |
|
} |
|
|
|
|
|
|
|
|
|
m_glueRuleTrie.reset(new HyperTree(ffs[0])); |
|
m_glueRuleMatcher = boost::shared_ptr<RuleMatcher>( |
|
new RuleMatcher(*m_glueRuleTrie)); |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::InitializeStacks() |
|
{ |
|
|
|
assert(!m_forest->vertices.empty()); |
|
|
|
for (std::vector<Forest::Vertex *>::const_iterator |
|
p = m_forest->vertices.begin(); p != m_forest->vertices.end(); ++p) { |
|
const Forest::Vertex &vertex = **p; |
|
|
|
|
|
SVertexStack &stack = m_stackMap[&(vertex.pvertex)]; |
|
|
|
|
|
if (vertex.incoming.empty()) { |
|
boost::shared_ptr<SVertex> v(new SVertex()); |
|
v->best = 0; |
|
v->pvertex = &(vertex.pvertex); |
|
stack.push_back(v); |
|
} |
|
} |
|
} |
|
|
|
template<typename RuleMatcher> |
|
bool Manager<RuleMatcher>::IsUnknownSourceWord(const Word &w) const |
|
{ |
|
const std::size_t factorId = w[0]->GetId(); |
|
const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
for (std::size_t i = 0; i < ffs.size(); ++i) { |
|
RuleTableFF *ff = ffs[i]; |
|
const boost::unordered_set<std::size_t> &sourceTerms = |
|
ff->GetSourceTerminalSet(); |
|
if (sourceTerms.find(factorId) != sourceTerms.end()) { |
|
return false; |
|
} |
|
} |
|
return true; |
|
} |
|
|
|
template<typename RuleMatcher> |
|
const SHyperedge *Manager<RuleMatcher>::GetBestSHyperedge() const |
|
{ |
|
PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); |
|
assert(p != m_stackMap.end()); |
|
const SVertexStack &stack = p->second; |
|
assert(!stack.empty()); |
|
return stack[0]->best; |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::ExtractKBest( |
|
std::size_t k, |
|
std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
|
bool onlyDistinct) const |
|
{ |
|
kBestList.clear(); |
|
if (k == 0 || m_source.GetSize() == 0) { |
|
return; |
|
} |
|
|
|
|
|
PVertexToStackMap::const_iterator p = m_stackMap.find(&m_rootVertex->pvertex); |
|
assert(p != m_stackMap.end()); |
|
const SVertexStack &stack = p->second; |
|
assert(!stack.empty()); |
|
|
|
KBestExtractor extractor; |
|
|
|
if (!onlyDistinct) { |
|
|
|
extractor.Extract(stack, k, kBestList); |
|
return; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
const StaticData &staticData = StaticData::Instance(); |
|
const std::size_t nBestFactor = staticData.options()->nbest.factor; |
|
std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
|
|
|
|
|
KBestExtractor::KBestVec bigList; |
|
bigList.reserve(numDerivations); |
|
extractor.Extract(stack, numDerivations, bigList); |
|
|
|
|
|
std::set<Phrase> distinct; |
|
for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
|
kBestList.size() < k && p != bigList.end(); ++p) { |
|
boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
|
Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
|
if (distinct.insert(translation).second) { |
|
kBestList.push_back(derivation); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::RecombineAndSort( |
|
const std::vector<SHyperedge*> &buffer, SVertexStack &stack) |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
typedef boost::unordered_map<SVertex *, SVertex *, |
|
SVertexRecombinationHasher, |
|
SVertexRecombinationEqualityPred> Map; |
|
Map map; |
|
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
|
p != buffer.end(); ++p) { |
|
SHyperedge *h = *p; |
|
SVertex *v = h->head; |
|
assert(v->best == h); |
|
assert(v->recombined.empty()); |
|
std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
|
if (result.second) { |
|
continue; |
|
} |
|
|
|
|
|
|
|
SVertex *storedVertex = result.first->second; |
|
if (h->label.futureScore > storedVertex->best->label.futureScore) { |
|
|
|
storedVertex->recombined.push_back(storedVertex->best); |
|
storedVertex->best = h; |
|
} else { |
|
storedVertex->recombined.push_back(h); |
|
} |
|
h->head->best = 0; |
|
delete h->head; |
|
h->head = storedVertex; |
|
} |
|
|
|
|
|
stack.clear(); |
|
stack.reserve(map.size()); |
|
for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
|
stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
|
} |
|
|
|
|
|
std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
|
} |
|
|
|
template<typename RuleMatcher> |
|
void Manager<RuleMatcher>::OutputDetailedTranslationReport( |
|
OutputCollector *collector) const |
|
{ |
|
const SHyperedge *best = GetBestSHyperedge(); |
|
if (best == NULL || collector == NULL) { |
|
return; |
|
} |
|
long translationId = m_source.GetTranslationId(); |
|
std::ostringstream out; |
|
DerivationWriter::Write(*best, translationId, out); |
|
collector->Write(translationId, out.str()); |
|
} |
|
|
|
} |
|
} |
|
} |
|
|