|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "ScfgRule.h" |
|
|
|
#include <algorithm> |
|
|
|
#include "Node.h" |
|
#include "Subgraph.h" |
|
#include "SyntaxNode.h" |
|
#include "SyntaxNodeCollection.h" |
|
|
|
namespace MosesTraining |
|
{ |
|
namespace Syntax |
|
{ |
|
namespace GHKM |
|
{ |
|
|
|
/**
 * Builds an SCFG rule from a GHKM graph fragment.
 *
 * Source side: LHS is the generic "X"; the RHS consists of the fragment's
 * leaves that have a non-empty span, ordered by PartitionOrderComp. TREE
 * leaves become "X" non-terminals, SOURCE leaves become terminals.
 *
 * Target side: LHS is the fragment root's label; the RHS is built from
 * fragment.GetTargetLeaves(). m_alignment receives one
 * (source RHS index, target RHS index) pair per aligned symbol.
 *
 * @param fragment             the graph fragment the rule is extracted from.
 * @param sourceNodeCollection optional source-side syntax. When non-null,
 *                             PushSourceLabel is called once per source RHS
 *                             symbol (terminals included, fallback "XRHS")
 *                             and then once for the fragment root
 *                             (fallback "XLHS").
 */
ScfgRule::ScfgRule(const Subgraph &fragment,
                   const SyntaxNodeCollection *sourceNodeCollection)
  : m_graphFragment(fragment)
  , m_sourceLHS("X", NonTerminal)
  , m_targetLHS(fragment.GetRoot()->GetLabel(), NonTerminal)
  , m_pcfgScore(fragment.GetPcfgScore())
  , m_hasSourceLabels(sourceNodeCollection)
{
  // ---- Source RHS -------------------------------------------------------

  // Only leaves with a non-empty span take part in the source side.
  const std::set<const Node *> &leaves = fragment.GetLeaves();
  std::vector<const Node *> sourceRHSNodes;
  sourceRHSNodes.reserve(leaves.size());
  for (std::set<const Node *>::const_iterator p(leaves.begin());
       p != leaves.end(); ++p) {
    const Node &leaf = **p;
    if (!leaf.GetSpan().empty()) {
      sourceRHSNodes.push_back(&leaf);
    }
  }

  std::sort(sourceRHSNodes.begin(), sourceRHSNodes.end(), PartitionOrderComp);

  // Maps each node that can appear on the target RHS to the source RHS
  // positions it is aligned with; consumed when building m_alignment below.
  std::map<const Node *, std::vector<int> > sourceOrder;

  m_sourceRHS.reserve(sourceRHSNodes.size());
  m_numberOfNonTerminals = 0;
  int srcIndex = 0;
  for (std::vector<const Node *>::const_iterator p(sourceRHSNodes.begin());
       p != sourceRHSNodes.end(); ++p, ++srcIndex) {
    const Node &sinkNode = **p;
    if (sinkNode.GetType() == TREE) {
      // Non-terminal gap. The same TREE node is looked up again when it is
      // met among the target leaves.
      m_sourceRHS.push_back(Symbol("X", NonTerminal));
      sourceOrder[&sinkNode].push_back(srcIndex);
      ++m_numberOfNonTerminals;
    } else {
      assert(sinkNode.GetType() == SOURCE);
      m_sourceRHS.push_back(Symbol(sinkNode.GetLabel(), Terminal));

      // Record this source position for every TARGET-type parent.
      const std::vector<Node *> &parents(sinkNode.GetParents());
      for (std::vector<Node *>::const_iterator q(parents.begin());
           q != parents.end(); ++q) {
        if ((*q)->GetType() == TARGET) {
          sourceOrder[*q].push_back(srcIndex);
        }
      }
    }

    if (sourceNodeCollection) {
      // One source label per source RHS symbol, terminals included.
      PushSourceLabel(sourceNodeCollection, &sinkNode, "XRHS");
    }
  }

  // ---- Target RHS and alignment -------------------------------------------

  std::vector<const Node *> targetLeaves;
  fragment.GetTargetLeaves(targetLeaves);

  m_alignment.reserve(targetLeaves.size());
  m_targetRHS.reserve(targetLeaves.size());

  for (std::vector<const Node *>::const_iterator p(targetLeaves.begin());
       p != targetLeaves.end(); ++p) {
    const Node &leaf = **p;
    if (leaf.GetSpan().empty()) {
      // Leaf with an empty span (presumably unaligned target material —
      // TODO confirm): its target words become terminals with no alignment.
      // Bound by const reference so the word list is not copied.
      const std::vector<std::string> &targetWords = leaf.GetTargetWords();
      for (std::vector<std::string>::const_iterator q(targetWords.begin());
           q != targetWords.end(); ++q) {
        m_targetRHS.push_back(Symbol(*q, Terminal));
      }
    } else if (leaf.GetType() == SOURCE) {
      // SOURCE-type leaves contribute nothing to the target RHS.
    } else {
      const SymbolType type = (leaf.GetType() == TREE) ? NonTerminal : Terminal;
      m_targetRHS.push_back(Symbol(leaf.GetLabel(), type));
      const int tgtIndex = static_cast<int>(m_targetRHS.size()) - 1;

      std::map<const Node *, std::vector<int> >::const_iterator q(
        sourceOrder.find(&leaf));
      assert(q != sourceOrder.end());
      // Guard as well as assert: with NDEBUG a missing entry would otherwise
      // dereference end().
      if (q != sourceOrder.end()) {
        const std::vector<int> &sourceIndices = q->second;
        for (std::vector<int>::const_iterator r(sourceIndices.begin());
             r != sourceIndices.end(); ++r) {
          m_alignment.push_back(std::make_pair(*r, tgtIndex));
        }
      }
    }
  }

  // Source label for the LHS comes last, after all RHS labels.
  if (sourceNodeCollection) {
    PushSourceLabel(sourceNodeCollection, fragment.GetRoot(), "XLHS");
  }
}
|
|
|
void ScfgRule::PushSourceLabel(const SyntaxNodeCollection *sourceNodeCollection, |
|
const Node *node, |
|
const std::string &nonMatchingLabel) |
|
{ |
|
ContiguousSpan span = Closure(node->GetSpan()); |
|
if (sourceNodeCollection->HasNode(span.first,span.second)) { |
|
std::vector<SyntaxNode*> sourceLabels = |
|
sourceNodeCollection->GetNodes(span.first,span.second); |
|
if (!sourceLabels.empty()) { |
|
|
|
m_sourceLabels.push_back(sourceLabels.back()->label); |
|
} |
|
} else { |
|
|
|
m_sourceLabels.push_back(nonMatchingLabel); |
|
} |
|
} |
|
|
|
|
|
void ScfgRule::UpdateSourceLabelCoocCounts(std::map< std::string, std::map<std::string,float>* > &coocCounts, float count) const |
|
{ |
|
std::map<int, int> sourceToTargetNTMap; |
|
std::map<int, int> targetToSourceNTMap; |
|
|
|
for (Alignment::const_iterator p(m_alignment.begin()); |
|
p != m_alignment.end(); ++p) { |
|
if ( m_sourceRHS[p->first].GetType() == NonTerminal ) { |
|
assert(m_targetRHS[p->second].GetType() == NonTerminal); |
|
sourceToTargetNTMap[p->first] = p->second; |
|
} |
|
} |
|
|
|
size_t sourceIndex = 0; |
|
size_t sourceNonTerminalIndex = 0; |
|
for (std::vector<Symbol>::const_iterator p=m_sourceRHS.begin(); |
|
p != m_sourceRHS.end(); ++p, ++sourceIndex) { |
|
if ( p->GetType() == NonTerminal ) { |
|
const std::string &sourceLabel = m_sourceLabels[sourceNonTerminalIndex]; |
|
int targetIndex = sourceToTargetNTMap[sourceIndex]; |
|
const std::string &targetLabel = m_targetRHS[targetIndex].GetValue(); |
|
++sourceNonTerminalIndex; |
|
|
|
std::map<std::string,float>* countMap = NULL; |
|
std::map< std::string, std::map<std::string,float>* >::iterator iter = coocCounts.find(sourceLabel); |
|
if ( iter == coocCounts.end() ) { |
|
std::map<std::string,float> *newCountMap = new std::map<std::string,float>(); |
|
std::pair< std::map< std::string, std::map<std::string,float>* >::iterator, bool > inserted = |
|
coocCounts.insert( std::pair< std::string, std::map<std::string,float>* >(sourceLabel, newCountMap) ); |
|
assert(inserted.second); |
|
countMap = (inserted.first)->second; |
|
} else { |
|
countMap = iter->second; |
|
} |
|
std::pair< std::map<std::string,float>::iterator, bool > inserted = |
|
countMap->insert( std::pair< std::string,float>(targetLabel, count) ); |
|
if ( !inserted.second ) { |
|
(inserted.first)->second += count; |
|
} |
|
} |
|
} |
|
} |
|
|
|
} |
|
} |
|
} |
|
|