File size: 6,033 Bytes
158b61b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
#include "InputTreeBuilder.h"

#include "moses/StaticData.h"

namespace Moses
{
namespace Syntax
{
namespace T2S
{

InputTreeBuilder::InputTreeBuilder(std::vector<FactorType> const& oFactors)
  : m_outputFactorOrder(oFactors)
{
}

void InputTreeBuilder::Build(const TreeInput &in,
                             const std::string &topLevelLabel,
                             InputTree &out)
{
  CreateNodes(in, topLevelLabel, out);
  ConnectNodes(out);
}

// Create the InputTree::Node objects but do not connect them.
void InputTreeBuilder::CreateNodes(const TreeInput &in,
                                   const std::string &topLevelLabel,
                                   InputTree &out)
{
  // Get the input sentence word count.  This includes the <s> and </s> symbols.
  const std::size_t numWords = in.GetSize();

  // Get the parse tree non-terminal nodes.  The parse tree covers the original
  // sentence only, not the <s> and </s> symbols, so at this point there is
  // no top-level node.
  std::vector<XMLParseOutput> xmlNodes = in.GetLabelledSpans();

  // Sort the XML nodes into post-order.  Prior to sorting they will be in the
  // order that TreeInput created them.  Usually that will be post-order, but
  // if, for example, the tree was binarized by relax-parse then it won't be.
  // In all cases, we assume that if two nodes cover the same span then the
  // first one is the lowest.
  SortXmlNodesIntoPostOrder(xmlNodes);

  // Copy the parse tree non-terminal nodes, but offset the ranges by 1 (to
  // allow for the <s> symbol at position 0).
  std::vector<XMLParseOutput> nonTerms;
  nonTerms.reserve(xmlNodes.size()+1);
  for (std::vector<XMLParseOutput>::const_iterator p = xmlNodes.begin();
       p != xmlNodes.end(); ++p) {
    std::size_t start = p->m_range.GetStartPos();
    std::size_t end = p->m_range.GetEndPos();
    nonTerms.push_back(XMLParseOutput(p->m_label, Range(start+1, end+1)));
  }
  // Add a top-level node that also covers <s> and </s>.
  nonTerms.push_back(XMLParseOutput(topLevelLabel, Range(0, numWords-1)));

  // Allocate space for the InputTree nodes.  In the case of out.nodes, this
  // step is essential because once created the PVertex objects must not be
  // moved around (through vector resizing) because InputTree keeps pointers
  // to them.
  out.nodes.reserve(numWords + nonTerms.size());
  out.nodesAtPos.resize(numWords);

  // Create the InputTree::Node objects.
  int prevStart = -1;
  int prevEnd = -1;
  for (std::vector<XMLParseOutput>::const_iterator p = nonTerms.begin();
       p != nonTerms.end(); ++p) {
    int start = static_cast<int>(p->m_range.GetStartPos());
    int end = static_cast<int>(p->m_range.GetEndPos());

    // Check if we've started ascending a new subtree.
    if (start != prevStart && end != prevEnd) {
      // Add a node for each terminal to the left of or below the first
      // nonTerm child of the subtree.
      for (int i = prevEnd+1; i <= end; ++i) {
        PVertex v(Range(i, i), in.GetWord(i));
        out.nodes.push_back(InputTree::Node(v));
        out.nodesAtPos[i].push_back(&out.nodes.back());
      }
    }
    // Add a node for the non-terminal.
    Word w(true);
    w.CreateFromString(Moses::Output, m_outputFactorOrder, p->m_label, true);
    PVertex v(Range(start, end), w);
    out.nodes.push_back(InputTree::Node(v));
    out.nodesAtPos[start].push_back(&out.nodes.back());

    prevStart = start;
    prevEnd = end;
  }
}

// Connect the nodes by filling in the node.children vectors.
void InputTreeBuilder::ConnectNodes(InputTree &out)
{
  // Create a vector that records the parent of each node (except the root).
  std::vector<InputTree::Node*> parents(out.nodes.size(), NULL);
  for (std::size_t i = 0; i < out.nodes.size()-1; ++i) {
    const InputTree::Node &node = out.nodes[i];
    std::size_t start = node.pvertex.span.GetStartPos();
    std::size_t end = node.pvertex.span.GetEndPos();
    // Find the next node (in post-order) that completely covers node's span.
    std::size_t j = i+1;
    while (true) {
      const InputTree::Node &succ = out.nodes[j];
      std::size_t succStart = succ.pvertex.span.GetStartPos();
      std::size_t succEnd = succ.pvertex.span.GetEndPos();
      if (succStart <= start && succEnd >= end) {
        break;
      }
      ++j;
    }
    parents[i] = &(out.nodes[j]);
  }

  // Add each node to its parent's list of children (except the root).
  for (std::size_t i = 0; i < out.nodes.size()-1; ++i) {
    InputTree::Node &child = out.nodes[i];
    InputTree::Node &parent = *(parents[i]);
    parent.children.push_back(&child);
  }
}

void InputTreeBuilder::SortXmlNodesIntoPostOrder(
  std::vector<XMLParseOutput> &nodes)
{
  // Sorting is based on both the value of a node and its original position,
  // so for each node construct a pair containing both pieces of information.
  std::vector<std::pair<XMLParseOutput *, int> > pairs;
  pairs.reserve(nodes.size());
  for (std::size_t i = 0; i < nodes.size(); ++i) {
    pairs.push_back(std::make_pair(&(nodes[i]), i));
  }

  // Sort the pairs.
  std::sort(pairs.begin(), pairs.end(), PostOrderComp);

  // Replace the original node sequence with the correctly sorted sequence.
  std::vector<XMLParseOutput> tmp;
  tmp.reserve(nodes.size());
  for (std::size_t i = 0; i < pairs.size(); ++i) {
    tmp.push_back(nodes[pairs[i].second]);
  }
  nodes.swap(tmp);
}

// Comparison function used by SortXmlNodesIntoPostOrder.
bool InputTreeBuilder::PostOrderComp(const std::pair<XMLParseOutput *, int> &x,
                                     const std::pair<XMLParseOutput *, int> &y)
{
  std::size_t xStart = x.first->m_range.GetStartPos();
  std::size_t xEnd = x.first->m_range.GetEndPos();
  std::size_t yStart = y.first->m_range.GetStartPos();
  std::size_t yEnd = y.first->m_range.GetEndPos();

  if (xEnd == yEnd) {
    if (xStart == yStart) {
      return x.second < y.second;
    } else {
      return xStart > yStart;
    }
  } else {
    return xEnd < yEnd;
  }
}

}  // T2S
}  // Syntax
}  // Moses