|
|
|
#pragma once |
|
|
|
#include <iostream> |
|
#include <sstream> |
|
|
|
#include "moses/DecodeGraph.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/Syntax/BoundedPriorityContainer.h" |
|
#include "moses/Syntax/CubeQueue.h" |
|
#include "moses/Syntax/PHyperedge.h" |
|
#include "moses/Syntax/RuleTable.h" |
|
#include "moses/Syntax/RuleTableFF.h" |
|
#include "moses/Syntax/SHyperedgeBundle.h" |
|
#include "moses/Syntax/SVertex.h" |
|
#include "moses/Syntax/SVertexRecombinationEqualityPred.h" |
|
#include "moses/Syntax/SVertexRecombinationHasher.h" |
|
#include "moses/Syntax/SymbolEqualityPred.h" |
|
#include "moses/Syntax/SymbolHasher.h" |
|
|
|
#include "DerivationWriter.h" |
|
#include "OovHandler.h" |
|
#include "PChart.h" |
|
#include "RuleTrie.h" |
|
#include "SChart.h" |
|
|
|
namespace Moses |
|
{ |
|
namespace Syntax |
|
{ |
|
namespace S2T |
|
{ |
|
|
|
template<typename Parser> |
|
Manager<Parser>::Manager(ttasksptr const& ttask) |
|
: Syntax::Manager(ttask) |
|
, m_pchart(m_source.GetSize(), Parser::RequiresCompressedChart()) |
|
, m_schart(m_source.GetSize()) |
|
{ } |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::InitializeCharts() |
|
{ |
|
|
|
for (std::size_t i = 0; i < m_source.GetSize(); ++i) { |
|
const Word &terminal = m_source.GetWord(i); |
|
|
|
|
|
PVertex tmp(Range(i,i), terminal); |
|
PVertex &pvertex = m_pchart.AddVertex(tmp); |
|
|
|
|
|
boost::shared_ptr<SVertex> v(new SVertex()); |
|
v->best = 0; |
|
v->pvertex = &pvertex; |
|
SChart::Cell &scell = m_schart.GetCell(i,i); |
|
SVertexStack stack(1, v); |
|
SChart::Cell::TMap::value_type x(terminal, stack); |
|
scell.terminalStacks.insert(x); |
|
} |
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::InitializeParsers(PChart &pchart, |
|
std::size_t ruleLimit) |
|
{ |
|
const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
|
|
const std::vector<DecodeGraph*> &graphs = |
|
StaticData::Instance().GetDecodeGraphs(); |
|
|
|
UTIL_THROW_IF2(ffs.size() != graphs.size(), |
|
"number of RuleTables does not match number of decode graphs"); |
|
|
|
for (std::size_t i = 0; i < ffs.size(); ++i) { |
|
RuleTableFF *ff = ffs[i]; |
|
std::size_t maxChartSpan = graphs[i]->GetMaxChartSpan(); |
|
|
|
|
|
|
|
|
|
const RuleTable *table = ff->GetTable(); |
|
assert(table); |
|
RuleTable *nonConstTable = const_cast<RuleTable*>(table); |
|
boost::shared_ptr<Parser> parser; |
|
typename Parser::RuleTrie *trie = |
|
dynamic_cast<typename Parser::RuleTrie*>(nonConstTable); |
|
assert(trie); |
|
parser.reset(new Parser(pchart, *trie, maxChartSpan)); |
|
m_parsers.push_back(parser); |
|
} |
|
|
|
|
|
|
|
m_oovs.clear(); |
|
std::size_t maxOovWidth = 0; |
|
FindOovs(pchart, m_oovs, maxOovWidth); |
|
if (!m_oovs.empty()) { |
|
|
|
OovHandler<typename Parser::RuleTrie> oovHandler(*ffs[0]); |
|
m_oovRuleTrie = oovHandler.SynthesizeRuleTrie(m_oovs.begin(), m_oovs.end()); |
|
|
|
boost::shared_ptr<Parser> parser( |
|
new Parser(pchart, *m_oovRuleTrie, maxOovWidth)); |
|
m_parsers.push_back(parser); |
|
} |
|
} |
|
|
|
|
|
|
|
template<typename Parser> |
|
void Manager<Parser>::FindOovs(const PChart &pchart, boost::unordered_set<Word> &oovs, |
|
std::size_t maxOovWidth) |
|
{ |
|
|
|
std::vector<const RuleTrie *> tries; |
|
const std::vector<RuleTableFF*> &ffs = RuleTableFF::Instances(); |
|
for (std::size_t i = 0; i < ffs.size(); ++i) { |
|
const RuleTableFF *ff = ffs[i]; |
|
if (ff->GetTable()) { |
|
const RuleTrie *trie = dynamic_cast<const RuleTrie*>(ff->GetTable()); |
|
assert(trie); |
|
tries.push_back(trie); |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
oovs.clear(); |
|
maxOovWidth = 0; |
|
|
|
|
|
for (std::size_t i = 1; i < pchart.GetWidth()-1; ++i) { |
|
for (std::size_t j = i; j < pchart.GetWidth()-1; ++j) { |
|
std::size_t width = j-i+1; |
|
const PChart::Cell::TMap &map = pchart.GetCell(i,j).terminalVertices; |
|
for (PChart::Cell::TMap::const_iterator p = map.begin(); |
|
p != map.end(); ++p) { |
|
const Word &word = p->first; |
|
assert(!word.IsNonTerminal()); |
|
bool found = false; |
|
for (std::vector<const RuleTrie *>::const_iterator q = tries.begin(); |
|
q != tries.end(); ++q) { |
|
const RuleTrie *trie = *q; |
|
if (trie->HasPreterminalRule(word)) { |
|
found = true; |
|
break; |
|
} |
|
} |
|
if (!found) { |
|
oovs.insert(word); |
|
maxOovWidth = std::max(maxOovWidth, width); |
|
} |
|
} |
|
} |
|
} |
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::Decode() |
|
{ |
|
|
|
const std::size_t popLimit = options()->cube.pop_limit; |
|
const std::size_t ruleLimit = options()->syntax.rule_limit; |
|
const std::size_t stackLimit = options()->search.stack_size; |
|
|
|
|
|
InitializeCharts(); |
|
|
|
|
|
InitializeParsers(m_pchart, ruleLimit); |
|
|
|
|
|
typename Parser::CallbackType callback(m_schart, ruleLimit); |
|
|
|
|
|
std::size_t size = m_source.GetSize(); |
|
for (int start = size-1; start >= 0; --start) { |
|
for (std::size_t width = 1; width <= size-start; ++width) { |
|
std::size_t end = start + width - 1; |
|
|
|
|
|
SChart::Cell &scell = m_schart.GetCell(start, end); |
|
|
|
Range range(start, end); |
|
|
|
|
|
|
|
|
|
callback.InitForRange(range); |
|
for (typename std::vector<boost::shared_ptr<Parser> >::iterator |
|
p = m_parsers.begin(); p != m_parsers.end(); ++p) { |
|
(*p)->EnumerateHyperedges(range, callback); |
|
} |
|
|
|
|
|
const BoundedPriorityContainer<SHyperedgeBundle> &bundles = |
|
callback.GetContainer(); |
|
|
|
|
|
|
|
CubeQueue cubeQueue(bundles.Begin(), bundles.End()); |
|
std::size_t count = 0; |
|
typedef boost::unordered_map<Word, std::vector<SHyperedge*>, |
|
SymbolHasher, SymbolEqualityPred > BufferMap; |
|
BufferMap buffers; |
|
while (count < popLimit && !cubeQueue.IsEmpty()) { |
|
SHyperedge *hyperedge = cubeQueue.Pop(); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const Word &lhs = hyperedge->label.translation->GetTargetLHS(); |
|
hyperedge->head->pvertex = &m_pchart.AddVertex(PVertex(range, lhs)); |
|
|
|
buffers[lhs].push_back(hyperedge); |
|
++count; |
|
} |
|
|
|
|
|
for (BufferMap::const_iterator p = buffers.begin(); p != buffers.end(); |
|
++p) { |
|
const Word &category = p->first; |
|
const std::vector<SHyperedge*> &buffer = p->second; |
|
std::pair<SChart::Cell::NMap::Iterator, bool> ret = |
|
scell.nonTerminalStacks.Insert(category, SVertexStack()); |
|
assert(ret.second); |
|
SVertexStack &stack = ret.first->second; |
|
RecombineAndSort(buffer, stack); |
|
} |
|
|
|
|
|
if (stackLimit > 0) { |
|
for (SChart::Cell::NMap::Iterator p = scell.nonTerminalStacks.Begin(); |
|
p != scell.nonTerminalStacks.End(); ++p) { |
|
SVertexStack &stack = p->second; |
|
if (stack.size() > stackLimit) { |
|
stack.resize(stackLimit); |
|
} |
|
} |
|
} |
|
|
|
|
|
|
|
|
|
|
|
} |
|
} |
|
} |
|
|
|
template<typename Parser> |
|
const SHyperedge *Manager<Parser>::GetBestSHyperedge() const |
|
{ |
|
const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); |
|
const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; |
|
if (stacks.Size() == 0) { |
|
return 0; |
|
} |
|
assert(stacks.Size() == 1); |
|
const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second; |
|
|
|
return stack[0]->best; |
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::ExtractKBest( |
|
std::size_t k, |
|
std::vector<boost::shared_ptr<KBestExtractor::Derivation> > &kBestList, |
|
bool onlyDistinct) const |
|
{ |
|
kBestList.clear(); |
|
if (k == 0 || m_source.GetSize() == 0) { |
|
return; |
|
} |
|
|
|
|
|
const SChart::Cell &cell = m_schart.GetCell(0, m_source.GetSize()-1); |
|
const SChart::Cell::NMap &stacks = cell.nonTerminalStacks; |
|
if (stacks.Size() == 0) { |
|
return; |
|
} |
|
assert(stacks.Size() == 1); |
|
const std::vector<boost::shared_ptr<SVertex> > &stack = stacks.Begin()->second; |
|
|
|
|
|
KBestExtractor extractor; |
|
|
|
if (!onlyDistinct) { |
|
|
|
extractor.Extract(stack, k, kBestList); |
|
return; |
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
const StaticData &staticData = StaticData::Instance(); |
|
const std::size_t nBestFactor = staticData.options()->nbest.factor; |
|
std::size_t numDerivations = (nBestFactor == 0) ? k*1000 : k*nBestFactor; |
|
|
|
|
|
KBestExtractor::KBestVec bigList; |
|
bigList.reserve(numDerivations); |
|
extractor.Extract(stack, numDerivations, bigList); |
|
|
|
|
|
std::set<Phrase> distinct; |
|
for (KBestExtractor::KBestVec::const_iterator p = bigList.begin(); |
|
kBestList.size() < k && p != bigList.end(); ++p) { |
|
boost::shared_ptr<KBestExtractor::Derivation> derivation = *p; |
|
Phrase translation = KBestExtractor::GetOutputPhrase(*derivation); |
|
if (distinct.insert(translation).second) { |
|
kBestList.push_back(derivation); |
|
} |
|
} |
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::PrunePChart(const SChart::Cell &scell, |
|
PChart::Cell &pcell) |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::RecombineAndSort(const std::vector<SHyperedge*> &buffer, |
|
SVertexStack &stack) |
|
{ |
|
|
|
|
|
|
|
|
|
|
|
typedef boost::unordered_map<SVertex *, SVertex *, |
|
SVertexRecombinationHasher, |
|
SVertexRecombinationEqualityPred> Map; |
|
Map map; |
|
for (std::vector<SHyperedge*>::const_iterator p = buffer.begin(); |
|
p != buffer.end(); ++p) { |
|
SHyperedge *h = *p; |
|
SVertex *v = h->head; |
|
assert(v->best == h); |
|
assert(v->recombined.empty()); |
|
std::pair<Map::iterator, bool> result = map.insert(Map::value_type(v, v)); |
|
if (result.second) { |
|
continue; |
|
} |
|
|
|
|
|
|
|
SVertex *storedVertex = result.first->second; |
|
if (h->label.futureScore > storedVertex->best->label.futureScore) { |
|
|
|
storedVertex->recombined.push_back(storedVertex->best); |
|
storedVertex->best = h; |
|
} else { |
|
storedVertex->recombined.push_back(h); |
|
} |
|
h->head->best = 0; |
|
delete h->head; |
|
h->head = storedVertex; |
|
} |
|
|
|
|
|
stack.clear(); |
|
stack.reserve(map.size()); |
|
for (Map::const_iterator p = map.begin(); p != map.end(); ++p) { |
|
stack.push_back(boost::shared_ptr<SVertex>(p->first)); |
|
} |
|
|
|
|
|
std::sort(stack.begin(), stack.end(), SVertexStackContentOrderer()); |
|
} |
|
|
|
template<typename Parser> |
|
void Manager<Parser>::OutputDetailedTranslationReport( |
|
OutputCollector *collector) const |
|
{ |
|
const SHyperedge *best = GetBestSHyperedge(); |
|
if (best == NULL || collector == NULL) { |
|
return; |
|
} |
|
long translationId = m_source.GetTranslationId(); |
|
std::ostringstream out; |
|
DerivationWriter::Write(*best, translationId, out); |
|
collector->Write(translationId, out.str()); |
|
} |
|
|
|
} |
|
} |
|
} |
|
|