import dataclasses
from typing import List, Optional, Tuple

import nltk
from nltk.corpus.reader.bracket_parse import BracketParseCorpusReader
import tokenizations
import torch

from .benepar import ptb_unescape
from .benepar.parse_base import BaseInputExample
from .transliterate import TRANSLITERATIONS


@dataclasses.dataclass
class ParsingExample(BaseInputExample):
    """A single parse tree and sentence."""

    words: List[str]
    space_after: List[bool]
    tree: Optional[nltk.Tree] = None
    _pos: Optional[List[Tuple[str, str]]] = None
    word_tag_pairs: Optional[List[Tuple[str, str]]] = None

    def leaves(self):
        if self.tree is not None:
            return self.tree.leaves()
        elif self._pos is not None:
            return [word for word, tag in self._pos]
        else:
            return None

    def pos(self):
        if self.tree is not None:
            return self.tree.pos()
        else:
            return self._pos

    def without_gold_annotations(self):
        return dataclasses.replace(self, tree=None, _pos=self.pos())


class Treebank(torch.utils.data.Dataset):
    def __init__(self, examples):
        self.examples = examples

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, index):
        return self.examples[index]

    @property
    def trees(self):
        return [x.tree for x in self.examples]

    @property
    def sents(self):
        return [x.words for x in self.examples]

    @property
    def tagged_sents(self):
        return [x.pos() for x in self.examples]

    def filter_by_length(self, max_len):
        return Treebank([x for x in self.examples if len(x.leaves()) <= max_len])

    def without_gold_annotations(self):
        return Treebank([x.without_gold_annotations() for x in self.examples])
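

# A minimal sketch of constructing a Treebank by hand (the sentence is
# hypothetical and carries no gold tree, so `tree` and `pos()` stay None):
#
#   example = ParsingExample(
#       words=["Hello", ",", "world", "!"],
#       space_after=[False, True, False, False],  # renders as "Hello, world!"
#   )
#   treebank = Treebank([example])
#   assert len(treebank) == 1 and treebank.sents == [example.words]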


def read_text(text_path):
    """Read tokens and whitespace information from a CoNLL-U-style file.

    Only the ID, FORM, and MISC fields are used. Multiword token ranges (e.g.
    an ID of "1-2") are split back into their component words by character-
    aligning the combined surface form against the separate words.

    Args:
        text_path: Path to a tab-separated file in CoNLL-U or a similarly
            formatted (ID, FORM, MISC) layout.

    Returns:
        A list of (words, space_after) tuples, one per sentence.
    """
    sents = []
    sent = []
    end_of_multiword = 0
    multiword_combined = ""
    multiword_separate = []
    multiword_sp_after = False
    with open(text_path) as f:
        for line in f:
            if not line.strip() or line.startswith("#"):
                if sent:
                    sents.append(([w for w, sp in sent], [sp for w, sp in sent]))
                    sent = []
                    assert end_of_multiword == 0
                continue
            fields = line.split("\t", 2)
            num_or_range = fields[0]
            w = fields[1]

            if "-" in num_or_range:
                end_of_multiword = int(num_or_range.split("-")[1])
                multiword_combined = w
                multiword_separate = []
                multiword_sp_after = "SpaceAfter=No" not in fields[-1]
                continue
            elif int(num_or_range) <= end_of_multiword:
                multiword_separate.append(w)
                if int(num_or_range) == end_of_multiword:
                    # Character-align the combined surface form against the
                    # separate syntactic words to recover per-word spans.
                    _, separate_to_combined = tokenizations.get_alignments(
                        multiword_combined, multiword_separate
                    )
                    have_up_to = 0
                    for i, char_idxs in enumerate(separate_to_combined):
                        if i == len(multiword_separate) - 1:
                            word = multiword_combined[have_up_to:]
                            sent.append((word, multiword_sp_after))
                        elif char_idxs:
                            word = multiword_combined[have_up_to : max(char_idxs) + 1]
                            sent.append((word, False))
                            have_up_to = max(char_idxs) + 1
                        else:
                            sent.append(("", False))
                    assert int(num_or_range) == len(sent)
                    end_of_multiword = 0
                    multiword_combined = ""
                    multiword_separate = []
                    multiword_sp_after = False
                continue
            else:
                assert int(num_or_range) == len(sent) + 1
                sp = "SpaceAfter=No" not in fields[-1]
                sent.append((w, sp))
    return sents
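

# A minimal sketch of read_text on a CoNLL-U fragment (file contents are
# hypothetical and tab-separated; only the ID, FORM, and MISC fields matter,
# and the file is assumed to end with a blank line, as CoNLL-U files do):
#
#   1-2	don't	_
#   1	do	_
#   2	n't	_
#   3	stop	SpaceAfter=No
#
# The "1-2" range supplies the combined surface form "don't", which is split
# back into its two syntactic words by character alignment, so read_text
# returns [(["do", "n't", "stop"], [False, True, False])] (the exact split
# points come from tokenizations.get_alignments).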


def load_trees(const_path, text_path=None, text_processing="default"):
    """Load a treebank.

    The standard tree format presents an abstracted view of the raw text, with the
    assumption that a tokenizer and other early stages of the NLP pipeline have already
    been run. These can include formatting changes like escaping certain characters
    (e.g. -LRB-) or transliteration (see e.g. the Arabic and Hebrew SPMRL datasets).
    Tokens are not always delimited by whitespace, and the raw whitespace in the source
    text is thrown away in the PTB tree format. Moreover, in some treebanks the leaves
    of the trees are lemmas/stems rather than word forms.

    All of this is a mismatch for pre-trained transformer models, which typically do
    their own tokenization starting with raw unicode strings. A mismatch compared to
    pre-training often doesn't affect performance if you just want to report F1 scores
    within the same treebank, but it raises some questions when it comes to releasing a
    parser for general use: (1) Must the parser be integrated with a tokenizer that
    matches the treebank convention? In fact, many modern NLP libraries like spaCy train
    on dependency data that doesn't necessarily use the same tokenization convention as
    constituency treebanks. (2) Can the parser's pre-trained model be merged with other
    pre-trained system components (via methods like multi-task learning or adapters), or
    must it remain its own system because of tokenization mismatches?

    This tree-loading function aims to build a path towards parsing from raw text by
    using the `text_path` argument to specify an auxiliary file that can be used to
    recover the original unicode string for the text. Parser layers above the
    pre-trained model may still use gold tokenization during training, but this will
    possibly help make the parser more robust to tokenization mismatches.

    On the other hand, some benchmarks involve evaluating with gold tokenization, and
    naively switching to using raw text degrades performance substantially. This can
    hopefully be addressed by making the parser layers on top of the pre-trained
    transformers handle tokenization more intelligently, but this is still a work in
    progress and the option remains to use the data from the tree files with minimal
    processing controlled by the `text_processing` argument to clean up some escaping or
    transliteration.

    Args:
        const_path: Path to the file with one tree per line.
        text_path: (optional) Path to a file that provides the correct spelling for all
            tokens (without any escaping, transliteration, or other mangling) and
            information about whether there is whitespace after each token. Files in the
            CoNLL-U format (https://universaldependencies.org/format.html) are accepted,
            but the parser also accepts similarly-formatted files with just three fields
            (ID, FORM, MISC) instead of the usual ten. Text is recovered from the FORM
            field and any "SpaceAfter=No" annotations in the MISC field.
        text_processing: Text processing to use if no text_path is specified:
            - 'default': undo PTB-style escape sequences and attempt to guess whitespace
                surrounding punctuation
            - 'arabic': guess that all tokens are separated by spaces
            - 'arabic-translit': undo Buckwalter transliteration and guess that all
                tokens are separated by spaces
            - 'chinese': keep all tokens unchanged (i.e. do not attempt to find any
                escape sequences), and assume no whitespace between tokens
            - 'hebrew': guess that all tokens are separated by spaces
            - 'hebrew-translit': undo transliteration (see Sima'an et al. 2002) and
                guess that all tokens are separated by spaces

    Returns:
        A list of ParsingExample objects, which have the following attributes:
            - `tree` is an instance of nltk.Tree
            - `words` is a list of strings
            - `space_after` is a list of booleans
    """
    reader = BracketParseCorpusReader("", [const_path])
    trees = reader.parsed_sents()

    if text_path is not None:
        sents = read_text(text_path)
    elif text_processing in ("arabic-translit", "hebrew-translit"):
        translit = TRANSLITERATIONS[text_processing.replace("-translit", "")]
        sents = []
        for tree in trees:
            words = [translit(word) for word in tree.leaves()]
            sp_after = [True for _ in words]
            sents.append((words, sp_after))
    elif text_processing in ("arabic", "hebrew"):
        sents = []
        for tree in trees:
            words = tree.leaves()
            sp_after = [True for _ in words]
            sents.append((words, sp_after))
    elif text_processing == "chinese":
        sents = []
        for tree in trees:
            words = tree.leaves()
            sp_after = [False for _ in words]
            sents.append((words, sp_after))
    elif text_processing == "default":
        sents = []
        for tree in trees:
            words = ptb_unescape.ptb_unescape(tree.leaves())
            sp_after = ptb_unescape.guess_space_after(tree.leaves())
            sents.append((words, sp_after))
    else:
        raise ValueError(f"Bad value for text_processing: {text_processing}")

    assert len(trees) == len(sents)
    treebank = Treebank(
        [
            ParsingExample(tree=tree, words=words, space_after=space_after)
            for tree, (words, space_after) in zip(trees, sents)
        ]
    )
    for example in treebank:
        assert len(example.words) == len(example.leaves()), (
            "Constituency tree has a different number of tokens than the CoNLL-U or "
            "other file used to specify reversible tokenization."
        )
    return treebank
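

# A minimal usage sketch for load_trees (paths are hypothetical; any treebank
# in PTB bracket format, optionally with a CoNLL-U companion file, should
# work):
#
#   treebank = load_trees("data/train.trees", text_path="data/train.conllu")
#   zh = load_trees("data/ctb.trees", text_processing="chinese")
#   short = treebank.filter_by_length(40)         # drop overly long sentences
#   unlabeled = short.without_gold_annotations()  # strip gold trees for inference
#   loader = torch.utils.data.DataLoader(
#       short, batch_size=32, collate_fn=lambda batch: batch
#   )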


def load_text(const_path):
    """Load pre-tagged test data and wrap each sentence in a ParsingExample.

    Each line of the input file holds one whitespace-separated (token, tag)
    pair, and sentences are separated by blank lines, e.g.:

    Zum     APPRART.Dat.Sg.Neut
    ersten  ADJA.Sup.Dat.Sg.Neut
    sollen  VMFIN.1.Pl.Pres.*
    wir     PPER.1.Nom.Pl.*

    Lieber  ADJA.Pos.Nom.Sg.Masc
    mensch  NN.Nom.Sg.Masc
    gedenck VVIMP.2.Sg.Imp
    und     KON

    Note that the tag sequence, not the word forms, becomes the `words`
    attribute: the parser consumes tags as its input tokens. The original
    (word, tag) pairs are preserved in `word_tag_pairs`.

    Args:
        const_path: Path to the file with one (token, tag) pair per line and
            sentences separated by blank lines.

    Returns:
        A Treebank of ParsingExample objects, which have the following
        attributes:
            - `words` is a list of strings (here: the tags)
            - `space_after` is a list of booleans
            - `_pos` is a list of (tag, coarse tag) tuples
            - `word_tag_pairs` is a list of (word, tag) tuples
    """
    examples = []
    tags, tag_pairs, word_tag_pairs = [], [], []
    with open(const_path, "r", encoding="utf-8") as f:
        lines = f.readlines() + ["\n"]  # sentinel blank line flushes the last sentence
    for line in lines:
        if not line.strip():
            if tags:
                sp_after = ptb_unescape.guess_space_after(tags)
                examples.append(
                    ParsingExample(
                        words=tags,
                        space_after=sp_after,
                        _pos=tag_pairs,
                        word_tag_pairs=word_tag_pairs,
                    )
                )
                tags, tag_pairs, word_tag_pairs = [], [], []
        else:
            word, tag = line.strip().split()
            tags.append(tag)
            word_tag_pairs.append((word, tag))
            # Pair each full tag with its coarse POS (the part before ".").
            tag_pairs.append((tag, tag.split(".")[0]))

    return Treebank(examples)
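

# A minimal usage sketch for load_text (the path is hypothetical; the file
# follows the whitespace-separated word/tag format shown in the docstring):
#
#   treebank = load_text("data/test.tagged")
#   first = treebank[0]
#   first.words           # the tag sequence, used as the parser's input tokens
#   first.word_tag_pairs  # original (word, tag) pairs, kept for later reporting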


def load_single_text(words, tags):
    """Wrap a single pre-tagged sentence in a Treebank.

    Like load_text, but for one in-memory sentence instead of a file. The tag
    sequence becomes the `words` attribute (the parser consumes tags as its
    input tokens); the original (word, tag) pairs are preserved in
    `word_tag_pairs`.

    Args:
        words: List of token strings.
        tags: List of tag strings, one per token.

    Returns:
        A Treebank containing a single ParsingExample with the same attributes
        as those produced by load_text.
    """
    tag_pairs, word_tag_pairs = [], []
    sp_after = ptb_unescape.guess_space_after(tags)
    for word, tag in zip(words, tags):
        word_tag_pairs.append((word, tag))
        # Pair each full tag with its coarse POS (the part before ".").
        tag_pairs.append((tag, tag.split(".")[0]))

    example = ParsingExample(
        words=tags,
        space_after=sp_after,
        _pos=tag_pairs,
        word_tag_pairs=word_tag_pairs,
    )
    return Treebank([example])
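

# A minimal usage sketch for load_single_text (values are hypothetical):
#
#   words = ["Lieber", "mensch"]
#   tags = ["ADJA.Pos.Nom.Sg.Masc", "NN.Nom.Sg.Masc"]
#   treebank = load_single_text(words, tags)
#   assert treebank[0].words == tags  # tags serve as the parser's input tokens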