File size: 3,886 Bytes
7dd9869
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch
import spacy, nltk
from nltk.tree import Tree
import numpy as np

def collapse_unary_strip_pos(tree, strip_top=True):
    """Collapse unary chains and strip part-of-speech tags.

    Every node whose only child is a word (a preterminal) is replaced by that
    word. The remaining unary chains are then merged with
    ``Tree.collapse_unary(collapsePOS=True, joinChar="::")``, so e.g.
    ``(S (VP ...))`` becomes a single node labeled ``"S::VP"``.

    If the root is labeled "TOP", "ROOT", "S1" or "VROOT":
      * with ``strip_top=True`` the root is dropped when it has a single
        child, otherwise its label is set to "";
      * with ``strip_top=False`` and a single subtree child, the root label is
        prefixed onto that child's label (e.g. ``"TOP::S"``) and the child is
        returned.

    Args:
        tree: an ``nltk.tree.Tree``. It is not modified; ``strip_pos`` builds
            new tree nodes and all later mutation happens on those.
        strip_top: whether to drop a synthetic root node (see above).

    Returns:
        The collapsed tree. When the whole input reduces to a single word
        (e.g. the input is itself a preterminal, or a synthetic root wraps one
        word), that word string is returned instead.
    """

    def strip_pos(node):
        # Bare words can occur as direct children of multi-child nodes; they
        # have no label to strip and are kept as-is.
        if isinstance(node, str):
            return node
        # A preterminal (single word child) is replaced by its word.
        if len(node) == 1 and isinstance(node[0], str):
            return node[0]
        return nltk.tree.Tree(node.label(), [strip_pos(child) for child in node])

    collapsed_tree = strip_pos(tree)
    if isinstance(collapsed_tree, str):
        # The input was a single preterminal: nothing left to collapse.
        return collapsed_tree
    # collapse_unary is called without collapseRoot, so the root label is
    # handled explicitly below.
    collapsed_tree.collapse_unary(collapsePOS=True, joinChar="::")
    if collapsed_tree.label() in ("TOP", "ROOT", "S1", "VROOT"):
        if strip_top:
            if len(collapsed_tree) == 1:
                collapsed_tree = collapsed_tree[0]
            else:
                collapsed_tree.set_label("")
        elif len(collapsed_tree) == 1 and isinstance(
                collapsed_tree[0], nltk.tree.Tree):
            # Fold the root label into its only subtree. A lone word child has
            # no label, so in that case the root is kept unchanged.
            collapsed_tree[0].set_label(
                collapsed_tree.label() + "::" + collapsed_tree[0].label())
            collapsed_tree = collapsed_tree[0]
    return collapsed_tree

def _get_labeled_spans(tree, spans_out, start):
    """Recursively append the labeled spans of ``tree`` to ``spans_out``.

    ``tree`` is expected to be the output of ``collapse_unary_strip_pos``:
    every node is a bare word (``str``), has more than one child, or has a
    single word child.

    Args:
        tree: a word string, or a node exposing ``label()``, ``len()`` and
            iteration over its children.
        spans_out: list that receives ``(start, end, label)`` tuples in
            post-order (children before parents); ``end`` is inclusive.
        start: index of the first word covered by ``tree``.

    Returns:
        The index one past the last word covered by ``tree``.
    """
    if not isinstance(tree, str):
        assert len(tree) > 1 or isinstance(
            tree[0], str
        ), "Must call collapse_unary_strip_pos first"

        # Thread the next free word index through the children, left to right.
        next_free = start
        for subtree in tree:
            next_free = _get_labeled_spans(subtree, spans_out, next_free)

        # Store a closed interval: next_free is exclusive, the span end is not.
        last_word = next_free - 1
        spans_out.append((start, last_word, tree.label()))
        return next_free

    # A single word occupies one position and contributes no span of its own.
    return start + 1

def get_labeled_spans(tree):
    """Convert a tree into a list of labeled spans.

    The tree is first passed through ``collapse_unary_strip_pos`` (with its
    default ``strip_top=True``). Part-of-speech tags are therefore removed.
    Unary chains are merged with "::", so e.g. (S (VP ...)) yields a single
    span labeled "S::VP". A root labeled TOP/ROOT/S1/VROOT is dropped if it
    has one child; if it has several, it is relabeled "" and still produces
    a span covering the whole sentence.

    Args:
        tree: an nltk.tree.Tree object.

    Returns:
        A list of (span_start, span_end, span_label) tuples in post-order
        (each span appears after the spans nested inside it). The start and
        end indices are the first and last words of the span (a closed
        interval).
    """
    tree = collapse_unary_strip_pos(tree)
    spans_out = []
    _get_labeled_spans(tree, spans_out, start=0)
    return spans_out

def padded_chart_from_spans(label_vocab, spans, num_words=64):
    """Build a fixed-size chart of span-label ids from a list of spans.

    Every cell starts at -100; the lower triangle is *not* zeroed. Only cells
    of spans whose label is in ``label_vocab`` receive a different value.

    Args:
        label_vocab: mapping from span label to integer id.
        spans: iterable of ``(start, end, label)`` tuples with closed-interval
            word indices, e.g. as returned by ``get_labeled_spans``.
        num_words: side length of the square chart. The default of 64 is the
            previously hard-coded size.

    Returns:
        A ``(num_words, num_words)`` integer ``np.ndarray``.

    Raises:
        IndexError: if a span whose label is in ``label_vocab`` has an index
            ``>= num_words``.
    """
    chart = np.full((num_words, num_words), -100, dtype=int)
    for start, end, label in spans:
        # Labels missing from the vocabulary are skipped; their cells stay -100.
        if label in label_vocab:
            chart[start, end] = label_vocab[label]
    return chart

def chart_from_tree(label_vocab, tree, verbose=False):
    """Convert a parse tree into a chart of span-label ids.

    Args:
        label_vocab: mapping from span label (as produced by
            ``get_labeled_spans``) to integer id.
        tree: the parse tree; the number of its leaves sets the chart size.
        verbose: when truthy, also return the list of labeled spans.

    Returns:
        ``chart`` or, if ``verbose`` is truthy, the tuple ``(chart, spans)``.
        ``chart`` is an ``(n, n)`` integer array with
        ``n = len(tree.leaves())``:
          * cells strictly below the diagonal hold -100;
          * ``chart[start, end]`` holds the id of each span whose label is in
            ``label_vocab``;
          * every other cell on or above the diagonal is 0.
    """
    spans = get_labeled_spans(tree)
    size = len(tree.leaves())

    # np.tril(..., -1) keeps -100 strictly below the diagonal and zeroes every
    # cell with start <= end.
    chart = np.tril(np.full((size, size), -100, dtype=int), -1)

    for first, last, name in spans:
        # Unary chains never seen in training can appear in dev/test trees.
        # Their labels are absent from the vocabulary, so the cell is left
        # unmarked rather than recorded as a constituent.
        if name not in label_vocab:
            continue
        chart[first, last] = label_vocab[name]

    return (chart, spans) if verbose else chart

def pad_charts(charts, padding_value=-100, max_len=64):
    """Embed a batch of span charts into fixed-size, START/END-aware charts.

    The model's input text contains START and END tokens that the parse
    charts do not. Each chart is therefore shifted by one position: cell
    ``(s, e)`` of a chart is written to cell ``(s + 1, e + 1)`` of its padded
    chart.

    Cells not covered by a chart keep their initial value:
      * ``padding_value`` strictly below the diagonal;
      * 0 on or above the diagonal.

    As a result, spans touching START or END, and positions past the
    sentence, carry label 0.

    :param charts: sequence of square 2-D charts, one per sentence (e.g. as
        returned by ``chart_from_tree``).
    :param padding_value: fill value for cells strictly below the diagonal.
    :param max_len: side length of every padded chart. The default of 64 is
        the previously hard-coded size.
    :return: ``np.ndarray`` of shape ``(len(charts), max_len, max_len)``.
    :raises ValueError: if a chart of size ``n`` does not fit, i.e.
        ``n + 1 > max_len``.
    """
    # Build with torch.full so the dtype is unchanged (int64 for an integer
    # padding_value), then convert explicitly: the result is a NumPy array.
    base = torch.full((len(charts), max_len, max_len), padding_value).numpy()
    # np.tril acts on the last two axes, zeroing every cell on/above the diagonal.
    padded_charts = np.tril(base, -1)
    for i, chart in enumerate(charts):
        chart_size = len(chart)
        if chart_size + 1 > max_len:
            raise ValueError(
                f"chart {i} has size {chart_size}; at most {max_len - 1} fits "
                f"in a padded chart of size {max_len}"
            )
        # Offset by one so row/column 0 is reserved for the START token.
        padded_charts[i, 1:chart_size + 1, 1:chart_size + 1] = chart
    return padded_charts