|
#pragma once |
|
|
|
#include <memory> |
|
#include <vector> |
|
|
|
#include "moses/ChartParser.h" |
|
#include "moses/ChartTranslationOptionList.h" |
|
#include "moses/InputType.h" |
|
#include "moses/NonTerminal.h" |
|
#include "moses/StaticData.h" |
|
#include "moses/Syntax/S2T/Parsers/Parser.h" |
|
#include "moses/Syntax/S2T/PChart.h" |
|
|
|
#include "TailLatticeSearcher.h" |
|
|
|
namespace Moses |
|
{ |
|
namespace Syntax |
|
{ |
|
namespace S2T |
|
{ |
|
|
|
template<typename Callback> |
|
Scope3Parser<Callback>::Scope3Parser(PChart &chart, const RuleTrie &trie, |
|
std::size_t maxChartSpan) |
|
: Parser<Callback>(chart) |
|
, m_ruleTable(trie) |
|
, m_maxChartSpan(maxChartSpan) |
|
, m_latticeBuilder(chart) |
|
{ |
|
Init(); |
|
} |
|
|
|
template<typename Callback> |
|
Scope3Parser<Callback>::~Scope3Parser() |
|
{ |
|
delete m_patRoot; |
|
} |
|
|
|
template<typename Callback> |
|
void Scope3Parser<Callback>:: |
|
EnumerateHyperedges(const Range &range, Callback &callback) |
|
{ |
|
const std::size_t start = range.GetStartPos(); |
|
const std::size_t end = range.GetEndPos(); |
|
|
|
const std::vector<const PatternApplicationTrie *> &patNodes = |
|
m_patSpans[start][end-start+1]; |
|
|
|
for (std::vector<const PatternApplicationTrie *>::const_iterator |
|
p = patNodes.begin(); p != patNodes.end(); ++p) { |
|
const PatternApplicationTrie *patNode = *p; |
|
|
|
|
|
patNode->ReadOffPatternApplicationKey(m_patKey); |
|
|
|
|
|
m_symbolRangeCalculator.Calc(m_patKey, start, end, m_symbolRanges); |
|
|
|
|
|
|
|
m_latticeBuilder.Build(m_patKey, m_symbolRanges, m_lattice, |
|
m_quickCheckTable); |
|
|
|
|
|
|
|
const RuleTrie::Node::LabelMap &labelMap = patNode->m_node->GetLabelMap(); |
|
|
|
|
|
|
|
TailLatticeSearcher<Callback> searcher(m_lattice, m_patKey, m_symbolRanges); |
|
RuleTrie::Node::LabelMap::const_iterator q = labelMap.begin(); |
|
for (; q != labelMap.end(); ++q) { |
|
const std::vector<int> &labelSeq = q->first; |
|
TargetPhraseCollection::shared_ptr tpc = q->second; |
|
|
|
|
|
|
|
bool failCheck = false; |
|
std::size_t nonTermIndex = 0; |
|
for (std::size_t i = 0; i < m_patKey.size(); ++i) { |
|
if (m_patKey[i]->IsTerminalNode()) { |
|
continue; |
|
} |
|
if (!m_quickCheckTable[nonTermIndex][labelSeq[nonTermIndex]]) { |
|
failCheck = true; |
|
break; |
|
} |
|
++nonTermIndex; |
|
} |
|
if (failCheck) { |
|
continue; |
|
} |
|
searcher.Search(labelSeq, tpc, callback); |
|
} |
|
} |
|
} |
|
|
|
template<typename Callback> |
|
void Scope3Parser<Callback>::Init() |
|
{ |
|
|
|
SentenceMap sentMap; |
|
FillSentenceMap(sentMap); |
|
|
|
|
|
const RuleTrie::Node &root = m_ruleTable.GetRootNode(); |
|
m_patRoot = new PatternApplicationTrie(-1, -1, root, 0, 0); |
|
m_patRoot->Extend(root, -1, sentMap, false); |
|
|
|
|
|
InitRuleApplicationVector(); |
|
RecordPatternApplicationSpans(*m_patRoot); |
|
} |
|
|
|
|
|
template<typename Callback> |
|
void Scope3Parser<Callback>::InitRuleApplicationVector() |
|
{ |
|
std::size_t length = Base::m_chart.GetWidth(); |
|
m_patSpans.resize(length); |
|
for (std::size_t start = 0; start < length; ++start) { |
|
std::size_t maxSpan = length-start; |
|
m_patSpans[start].resize(maxSpan+1); |
|
} |
|
} |
|
|
|
template<typename Callback> |
|
void Scope3Parser<Callback>::FillSentenceMap(SentenceMap &sentMap) |
|
{ |
|
typedef PChart::Cell Cell; |
|
|
|
const std::size_t width = Base::m_chart.GetWidth(); |
|
for (std::size_t i = 0; i < width; ++i) { |
|
for (std::size_t j = i; j < width; ++j) { |
|
const Cell::TMap &map = Base::m_chart.GetCell(i, j).terminalVertices; |
|
for (Cell::TMap::const_iterator p = map.begin(); p != map.end(); ++p) { |
|
const Word &terminal = p->first; |
|
const PVertex &v = p->second; |
|
sentMap[terminal].push_back(&v); |
|
} |
|
} |
|
} |
|
} |
|
|
|
template<typename Callback> |
|
void Scope3Parser<Callback>::RecordPatternApplicationSpans( |
|
const PatternApplicationTrie &patNode) |
|
{ |
|
if (patNode.m_node->HasRules()) { |
|
int s1 = -1; |
|
int s2 = -1; |
|
int e1 = -1; |
|
int e2 = -1; |
|
patNode.DetermineStartRange(Base::m_chart.GetWidth(), s1, s2); |
|
patNode.DetermineEndRange(Base::m_chart.GetWidth(), e1, e2); |
|
|
|
int minSpan = patNode.Depth(); |
|
|
|
|
|
for (int i = s1; i <= s2; ++i) { |
|
for (int j = std::max(e1, i+minSpan-1); j <= e2; ++j) { |
|
std::size_t span = j-i+1; |
|
assert(span >= 1); |
|
if (span < std::size_t(minSpan)) { |
|
continue; |
|
} |
|
if (m_maxChartSpan && span > m_maxChartSpan) { |
|
break; |
|
} |
|
m_patSpans[i][span].push_back(&patNode); |
|
} |
|
} |
|
} |
|
|
|
for (std::vector<PatternApplicationTrie*>::const_iterator p = |
|
patNode.m_children.begin(); p != patNode.m_children.end(); ++p) { |
|
RecordPatternApplicationSpans(**p); |
|
} |
|
} |
|
|
|
} |
|
} |
|
} |
|
|