File size: 11,887 Bytes
6ed21b9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
import dataclasses
import itertools
from typing import List, Optional, Tuple

import nltk
import torch

from .downloader import load_trained_model
from ..parse_base import BaseParser, BaseInputExample
from ..ptb_unescape import ptb_unescape, guess_space_after


# Parser language code -> language name handed to NLTK's sentence and word
# tokenizers (the second argument of nltk.sent_tokenize / nltk.word_tokenize).
TOKENIZER_LOOKUP = dict(
    en="english",
    de="german",
    fr="french",
    pl="polish",
    sv="swedish",
)

# Language code -> labels that must ALL be present in a model's label
# vocabulary for guess_language to pick that language. guess_language returns
# the first matching entry, so insertion order matters: some inventories
# overlap (e.g. "XP" is listed for both "ar" and "sv", "WHNP" for both "ar"
# and "en").
LANGUAGE_GUESS = dict(
    ar=("X", "XP", "WHADVP", "WHNP", "WHPP"),
    zh=("VSB", "VRD", "VPT", "VNV"),
    en=("WHNP", "WHADJP", "SINV", "SQ"),
    de=("AA", "AP", "CCP", "CH", "CNP", "VZ"),
    fr=("P+", "P+D+", "PRO+", "PROREL+"),
    he=("PREDP", "SYN_REL", "SYN_yyDOT"),
    pl=("formaczas", "znakkonca"),
    sv=("PSEUDO", "AVP", "XP"),
)


def guess_language(label_vocab):
    """Infer the parser's language code from its syntactic label inventory.

    Training accepts arbitrary treebank files with little language-specific
    logic. At inference time, however, other pipeline pieces (such as the NLTK
    tokenizers) need a language identity, so it is reconstructed here.

    Args:
        label_vocab: Any container supporting ``in`` over label strings (the
            caller passes the model's ``config["label_vocab"]``).

    Returns:
        The first language code in LANGUAGE_GUESS whose required labels all
        occur in ``label_vocab``, or None when no entry matches.
    """
    candidates = (
        code
        for code, needed in LANGUAGE_GUESS.items()
        if all(lbl in label_vocab for lbl in needed)
    )
    return next(candidates, None)


@dataclasses.dataclass
class InputSentence(BaseInputExample):
    """Parser input for a single sentence.

    Each input sentence needs at least one of `words` and `escaped_words`. All
    other fields are optional; the parser tries to derive any missing field
    from the ones that were supplied.

    Taken together, `words` and `space_after` are a reversible tokenization of
    the input text. `words` holds the Unicode text of each word, and
    `space_after` says whether that word is followed by whitespace. These two
    fields are what the parser consumes as input.

    `tags` holds part-of-speech tags known before parsing, if any. The parser
    does not use them as input, but it copies them into its output. When
    `tags` is None the parser tags the sentence itself. If the model was not
    trained to tag, the output uses "UNK" part-of-speech tags instead.

    `escaped_words` gives the leaf strings to place in the output tree. When
    `words` is supplied, the neural network ignores `escaped_words`; they are
    used only when building the output tree. This makes `escaped_words` a
    place for dataset-specific text encodings such as transliteration.

    Example of how these fields differ for English PTB:
        (raw text):     "Fly safely."
        words:          "       Fly     safely  .       "
        space_after:    False   True    False   False   False
        tags:           ``      VB      RB      .       ''
        escaped_words:  ``      Fly     safely  .       ''
    """

    words: Optional[List[str]] = None
    space_after: Optional[List[bool]] = None
    tags: Optional[List[str]] = None
    escaped_words: Optional[List[str]] = None

    @property
    def tree(self):
        # A parser input never carries a tree, so this is always None.
        return None

    def leaves(self):
        # Output-tree leaves are the escaped forms of the words.
        return self.escaped_words

    def pos(self):
        """Return (escaped_word, tag) pairs; tags default to "UNK" when absent."""
        if self.tags is None:
            return [(leaf, "UNK") for leaf in self.escaped_words]
        return list(zip(self.escaped_words, self.tags))


class Parser:
    """Berkeley Neural Parser (benepar), integrated with NLTK.

    Use this class to apply the Berkeley Neural Parser to pre-tokenized datasets
    and treebanks, or when integrating the parser into an NLP pipeline that
    already performs tokenization, sentence splitting, and (optionally)
    part-of-speech tagging. For parsing starting with raw text, it is strongly
    encouraged that you use spaCy and benepar.BeneparComponent instead.

    Sample usage:
    >>> parser = benepar.Parser("benepar_en3")
    >>> input_sentence = benepar.InputSentence(
    ...     words=['"', 'Fly', 'safely', '.', '"'],
    ...     space_after=[False, True, False, False, False],
    ...     tags=['``', 'VB', 'RB', '.', "''"],
    ...     escaped_words=['``', 'Fly', 'safely', '.', "''"],
    ... )
    >>> parser.parse(input_sentence)

    Not all fields of benepar.InputSentence are required, but at least one of
    `words` and `escaped_words` must not be None. The parser will attempt to
    guess the value for missing fields. For example,
    >>> input_sentence = benepar.InputSentence(
    ...     words=['"', 'Fly', 'safely', '.', '"'],
    ... )
    >>> parser.parse(input_sentence)

    Although this class is primarily designed for use with data that has already
    been tokenized, to help with interactive use and debugging it also accepts
    simple text string inputs. However, using this class to parse from raw text
    is STRONGLY DISCOURAGED for any application where parsing accuracy matters.
    When parsing from raw text, use spaCy and benepar.BeneparComponent instead.
    The reason is that parser models do not ship with a tokenizer or sentence
    splitter, and some models may not include a part-of-speech tagger either. A
    toolkit must be used to fill in these pipeline components, and spaCy
    outperforms NLTK in all of these areas (sometimes by a large margin).
    >>> parser.parse('"Fly safely."')  # For debugging/interactive use only.
    """

    # Translation table mapping bracket characters to the PTB escape tokens used
    # for output-tree leaves. A single str.translate pass is equivalent to the
    # chained replacements because no escape token contains a bracket.
    _BRACKET_ESCAPES = str.maketrans(
        {
            "(": "-LRB-",
            ")": "-RRB-",
            "{": "-LCB-",
            "}": "-RCB-",
            "[": "-LSB-",
            "]": "-RSB-",
        }
    )

    def __init__(self, name, batch_size=64, language_code=None):
        """Load a trained parser model.

        Args:
            name (str): Model name, or path to pytorch saved model
            batch_size (int): Maximum number of sentences to process per batch
            language_code (str, optional): language code for the parser (e.g.
                'en', 'he', 'zh', etc). Our official trained models will set
                this automatically, so this argument is only needed if training
                on new languages or treebanks.
        """
        self._parser = load_trained_model(name)
        # Run the model on GPU whenever CUDA is available.
        if torch.cuda.is_available():
            self._parser.cuda()
        if language_code is not None:
            self._language_code = language_code
        else:
            self._language_code = guess_language(self._parser.config["label_vocab"])
        # None when no NLTK tokenizer is known for the language; raw-text
        # inputs are then rejected in parse_sents.
        self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None)

        self.batch_size = batch_size

    def parse(self, sentence):
        """Parse a single sentence

        Args:
            sentence (InputSentence or List[str] or str): Sentence to parse.
                If the input is of List[str], it is assumed to be a sequence of
                words and will behave the same as only setting the `words` field
                of InputSentence. If the input is of type str, the sentence will
                be tokenized using the default NLTK tokenizer (not recommended:
                if parsing from raw text, use spaCy and benepar.BeneparComponent
                instead).

        Returns:
            nltk.Tree

        Raises:
            ValueError: if the sentence is malformed (see parse_sents).
        """
        # parse_sents yields exactly one tree per input sentence.
        return next(self.parse_sents([sentence]))

    def parse_sents(self, sents):
        """Parse multiple sentences in batches.

        Args:
            sents (Iterable[InputSentence]): An iterable of sentences to be
                parsed. `sents` may also be a string, in which case it will be
                segmented into sentences using the default NLTK sentence
                splitter (not recommended: if parsing from raw text, use spaCy
                and benepar.BeneparComponent instead). Otherwise, each element
                of `sents` will be treated as a sentence. The elements of
                `sents` may also be List[str] or str: see Parser.parse() for
                documentation regarding these cases.

        Yields:
            nltk.Tree objects, one per input sentence.

        Raises:
            ValueError: if `self.batch_size` is not positive, if raw text is
                given but no NLTK tokenizer is known for the parser's language,
                or if a sentence has an unsupported type or inconsistent
                fields. As this is a generator, errors surface on iteration.
        """
        # With batch_size <= 0 the zip_longest grouping below yields no batches,
        # which would silently drop every sentence.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size}"
            )

        if isinstance(sents, str):
            if self._tokenizer_lang is None:
                raise ValueError(
                    "No tokenizer available for this language. "
                    "Please split into individual sentences and tokens "
                    "before calling the parser."
                )
            sents = nltk.sent_tokenize(sents, self._tokenizer_lang)

        # Group the input into batch_size-sized tuples by zipping the SAME
        # iterator batch_size times; the final, short batch is padded with a
        # private sentinel that cannot collide with any real input.
        end_sentinel = object()
        for batch_sents in itertools.zip_longest(
            *([iter(sents)] * self.batch_size), fillvalue=end_sentinel
        ):
            batch_inputs = []
            for sent in batch_sents:
                if sent is end_sentinel:
                    break
                elif isinstance(sent, str):
                    if self._tokenizer_lang is None:
                        raise ValueError(
                            "No word tokenizer available for this language. "
                            "Please tokenize before calling the parser."
                        )
                    escaped_words = nltk.word_tokenize(sent, self._tokenizer_lang)
                    sent = InputSentence(escaped_words=escaped_words)
                elif isinstance(sent, (list, tuple)):
                    # Normalize tuples to lists to match InputSentence.words.
                    sent = InputSentence(words=list(sent))
                elif not isinstance(sent, InputSentence):
                    raise ValueError(
                        "Sentences must be one of: InputSentence, list, tuple, or str"
                    )
                batch_inputs.append(self._with_missing_fields_filled(sent))

            for inp, output in zip(
                batch_inputs, self._parser.parse(batch_inputs, return_compressed=True)
            ):
                # If pos tags are provided as input, ignore any tags predicted
                # by the parser.
                if inp.tags is not None:
                    output = output.without_predicted_tags()
                yield output.to_tree(
                    inp.pos(),
                    self._parser.decoder.label_from_index,
                    self._parser.tag_from_index,
                )

    def _with_missing_fields_filled(self, sent):
        """Return `sent` with words, escaped_words and space_after populated.

        Missing fields are derived from the provided ones. A new instance is
        built with dataclasses.replace; the input object is never mutated.

        - words are recovered from escaped_words via ptb_unescape.
        - escaped_words are derived from words by escaping bracket characters.
        - space_after is set per language: never for "zh", always for "ar"/"he",
          otherwise via the guess_space_after heuristic.

        Raises:
            ValueError: if `sent` is not an InputSentence, if both words and
                escaped_words are missing, or if words, escaped_words,
                space_after or tags have mismatched lengths.
        """
        if not isinstance(sent, InputSentence):
            raise ValueError("Input is not an instance of InputSentence")
        if sent.words is None and sent.escaped_words is None:
            raise ValueError("At least one of words or escaped_words is required")
        elif sent.words is None:
            sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words))
        elif sent.escaped_words is None:
            escaped_words = [word.translate(self._BRACKET_ESCAPES) for word in sent.words]
            sent = dataclasses.replace(sent, escaped_words=escaped_words)
        else:
            if len(sent.words) != len(sent.escaped_words):
                raise ValueError(
                    f"Length of words ({len(sent.words)}) does not match "
                    f"escaped_words ({len(sent.escaped_words)})"
                )

        # InputSentence.pos() zips escaped_words with tags, which would silently
        # truncate on a length mismatch and produce a wrong tree.
        if sent.tags is not None and len(sent.tags) != len(sent.words):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"tags ({len(sent.tags)})"
            )

        if sent.space_after is None:
            if self._language_code == "zh":
                space_after = [False for _ in sent.words]
            elif self._language_code in ("ar", "he"):
                space_after = [True for _ in sent.words]
            else:
                space_after = guess_space_after(sent.words)
            sent = dataclasses.replace(sent, space_after=space_after)
        elif len(sent.words) != len(sent.space_after):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"space_after ({len(sent.space_after)})"
            )

        # Internal invariant: every per-word field is aligned at this point.
        assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after)
        return sent