# nielklug's picture
# init
# 6ed21b9
import dataclasses
import itertools
from typing import List, Optional, Tuple
import nltk
import torch
from .downloader import load_trained_model
from ..parse_base import BaseParser, BaseInputExample
from ..ptb_unescape import ptb_unescape, guess_space_after
# Maps a parser language code to the language name handed to NLTK's
# `sent_tokenize` / `word_tokenize` when raw text is passed to `Parser`.
# Language codes missing from this table have no tokenizer: `Parser` looks
# them up with `.get(..., None)` and refuses raw-text input for them.
TOKENIZER_LOOKUP = {
    "en": "english",
    "de": "german",
    "fr": "french",
    "pl": "polish",
    "sv": "swedish",
}
# Maps a language code to a tuple of syntactic labels used as that language's
# signature: `guess_language` reports a language only when *every* label in
# its tuple is present in the parser's label vocabulary.
# NOTE: entries are tried in insertion order and the first full match wins,
# so reordering this table can change the result for vocabularies that
# satisfy more than one signature.
LANGUAGE_GUESS = {
    "ar": ("X", "XP", "WHADVP", "WHNP", "WHPP"),
    "zh": ("VSB", "VRD", "VPT", "VNV"),
    "en": ("WHNP", "WHADJP", "SINV", "SQ"),
    "de": ("AA", "AP", "CCP", "CH", "CNP", "VZ"),
    "fr": ("P+", "P+D+", "PRO+", "PROREL+"),
    "he": ("PREDP", "SYN_REL", "SYN_yyDOT"),
    "pl": ("formaczas", "znakkonca"),
    "sv": ("PSEUDO", "AVP", "XP"),
}
def guess_language(label_vocab):
    """Infer a parser's language from the syntactic labels it knows about.

    Training accepts arbitrary treebanks with almost no language-specific
    handling, but at inference time other pipeline pieces (for example a
    tokenizer) need to know which language the model was trained on. This
    function recovers that information from the label inventory alone.

    Args:
        label_vocab: Container of syntactic labels supporting the ``in``
            operator (membership is the only operation performed on it).

    Returns:
        The first language code in ``LANGUAGE_GUESS`` whose signature labels
        are all present in ``label_vocab``, or None when no signature matches.
    """
    matching_codes = (
        code
        for code, signature in LANGUAGE_GUESS.items()
        if all(label in label_vocab for label in signature)
    )
    # Signatures are tried in table order; the first full match wins.
    return next(matching_codes, None)
@dataclasses.dataclass
class InputSentence(BaseInputExample):
    """A single sentence to be handed to the parser.

    Every field is optional, but at least one of `words` and `escaped_words`
    has to be supplied. The parser tries to derive any field that is left as
    None from the fields that were provided.

    Fields:
        words: Unicode text of each word. Together with `space_after` this
            forms a reversible tokenization of the input text, and the pair
            is what the parser consumes as input.
        space_after: For each word, whether it is followed by whitespace.
        tags: Part-of-speech tags known before parsing, if any. The parser
            does not use them as input; it only passes them through to its
            output. When None, the parser performs its own tagging (and emits
            "UNK" tags if it was not trained to tag).
        escaped_words: The representation of each leaf to use in the output
            tree. When `words` is also given, `escaped_words` is ignored by
            the neural network and only used when building the output tree,
            so it can carry any dataset-specific text encoding such as
            transliteration.

    Example of how the fields differ for English PTB:

        (raw text):     "Fly safely."
        words:          "       Fly     safely  .       "
        space_after:    False   True    False   False   False
        tags:           ``      VB      RB      .       ''
        escaped_words:  ``      Fly     safely  .       ''
    """

    words: Optional[List[str]] = None
    space_after: Optional[List[bool]] = None
    tags: Optional[List[str]] = None
    escaped_words: Optional[List[str]] = None

    @property
    def tree(self):
        """Always None for an input sentence."""
        return None

    def leaves(self):
        """Return the leaf strings, i.e. `escaped_words` unchanged."""
        return self.escaped_words

    def pos(self):
        """Return (leaf, tag) pairs, using "UNK" for every tag if `tags` is None."""
        if self.tags is None:
            return [(leaf, "UNK") for leaf in self.escaped_words]
        return list(zip(self.escaped_words, self.tags))
class Parser:
    """Berkeley Neural Parser (benepar), integrated with NLTK.

    Use this class to apply the Berkeley Neural Parser to pre-tokenized
    datasets and treebanks, or when integrating the parser into an NLP
    pipeline that already performs tokenization, sentence splitting, and
    (optionally) part-of-speech tagging. For parsing starting with raw text,
    it is strongly encouraged that you use spaCy and benepar.BeneparComponent
    instead.

    Sample usage:

    >>> parser = benepar.Parser("benepar_en3")
    >>> input_sentence = benepar.InputSentence(
    ...     words=['"', 'Fly', 'safely', '.', '"'],
    ...     space_after=[False, True, False, False, False],
    ...     tags=['``', 'VB', 'RB', '.', "''"],
    ...     escaped_words=['``', 'Fly', 'safely', '.', "''"],
    ... )
    >>> parser.parse(input_sentence)

    Not all fields of benepar.InputSentence are required, but at least one of
    `words` and `escaped_words` must not be None. The parser will attempt to
    guess the value for missing fields. For example,

    >>> input_sentence = benepar.InputSentence(
    ...     words=['"', 'Fly', 'safely', '.', '"'],
    ... )
    >>> parser.parse(input_sentence)

    Although this class is primarily designed for use with data that has
    already been tokenized, to help with interactive use and debugging it
    also accepts simple text string inputs. However, using this class to
    parse from raw text is STRONGLY DISCOURAGED for any application where
    parsing accuracy matters. When parsing from raw text, use spaCy and
    benepar.BeneparComponent instead. The reason is that parser models do not
    ship with a tokenizer or sentence splitter, and some models may not
    include a part-of-speech tagger either. A toolkit must be used to fill in
    these pipeline components, and spaCy outperforms NLTK in all of these
    areas (sometimes by a large margin).

    >>> parser.parse('"Fly safely."')  # For debugging/interactive use only.
    """

    # Single-pass translation table producing PTB-style bracket escapes.
    # Equivalent to chaining one `.replace()` per bracket: none of the
    # replacement strings contains a bracket, so ordering cannot matter.
    _BRACKET_ESCAPES = str.maketrans(
        {
            "(": "-LRB-",
            ")": "-RRB-",
            "{": "-LCB-",
            "}": "-RCB-",
            "[": "-LSB-",
            "]": "-RSB-",
        }
    )

    def __init__(self, name, batch_size=64, language_code=None):
        """Load a trained parser model.

        The model is moved to the GPU when CUDA is available.

        Args:
            name (str): Model name, or path to pytorch saved model
            batch_size (int): Maximum number of sentences to process per batch
            language_code (str, optional): language code for the parser (e.g.
                'en', 'he', 'zh', etc). Our official trained models will set
                this automatically, so this argument is only needed if
                training on new languages or treebanks. When omitted, the
                language is guessed from the model's label vocabulary.
        """
        self._parser = load_trained_model(name)
        if torch.cuda.is_available():
            self._parser.cuda()

        if language_code is not None:
            self._language_code = language_code
        else:
            self._language_code = guess_language(self._parser.config["label_vocab"])
        # None when NLTK tokenization is not configured for this language;
        # raw-text input is rejected in that case.
        self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None)

        self.batch_size = batch_size

    def parse(self, sentence):
        """Parse a single sentence.

        Args:
            sentence (InputSentence or List[str] or str): Sentence to parse.
                If the input is of List[str], it is assumed to be a sequence
                of words and will behave the same as only setting the `words`
                field of InputSentence. If the input is of type str, the
                sentence will be tokenized using the default NLTK tokenizer
                (not recommended: if parsing from raw text, use spaCy and
                benepar.BeneparComponent instead).

        Returns:
            nltk.Tree

        Raises:
            ValueError: Under the same conditions as `parse_sents`.
        """
        return list(self.parse_sents([sentence]))[0]

    def parse_sents(self, sents):
        """Parse multiple sentences in batches.

        This is a generator: nothing is tokenized, validated or parsed until
        the result is iterated, so errors surface at iteration time.

        Args:
            sents (Iterable[InputSentence]): An iterable of sentences to be
                parsed. `sents` may also be a string, in which case it will
                be segmented into sentences using the default NLTK sentence
                splitter (not recommended: if parsing from raw text, use
                spaCy and benepar.BeneparComponent instead). Otherwise, each
                element of `sents` will be treated as a sentence. The
                elements of `sents` may also be List[str] or str: see
                Parser.parse() for documentation regarding these cases.

        Yields:
            nltk.Tree objects, one per input sentence.

        Raises:
            ValueError: If raw text is given but no NLTK tokenizer is
                configured for the parser's language, if an element is not an
                InputSentence, list, tuple or str, or if a sentence has
                missing or inconsistently sized fields.
        """
        if isinstance(sents, str):
            if self._tokenizer_lang is None:
                raise ValueError(
                    "No tokenizer available for this language. "
                    "Please split into individual sentences and tokens "
                    "before calling the parser."
                )
            sents = nltk.sent_tokenize(sents, self._tokenizer_lang)

        # Batching idiom: `batch_size` references to the *same* iterator are
        # zipped together, so each tuple pulls `batch_size` consecutive items.
        # The final, shorter batch is padded with a private sentinel.
        end_sentinel = object()
        for batch_sents in itertools.zip_longest(
            *([iter(sents)] * self.batch_size), fillvalue=end_sentinel
        ):
            # Padding only ever appears at the tail of the last batch, so
            # filtering it out is equivalent to stopping at the first one.
            batch_inputs = [
                self._with_missing_fields_filled(self._as_input_sentence(sent))
                for sent in batch_sents
                if sent is not end_sentinel
            ]

            batch_outputs = self._parser.parse(batch_inputs, return_compressed=True)
            for inp, output in zip(batch_inputs, batch_outputs):
                # If pos tags are provided as input, ignore any tags predicted
                # by the parser.
                if inp.tags is not None:
                    output = output.without_predicted_tags()
                yield output.to_tree(
                    inp.pos(),
                    self._parser.decoder.label_from_index,
                    self._parser.tag_from_index,
                )

    def _as_input_sentence(self, sent):
        """Coerce one element of `sents` into an InputSentence.

        A str is word-tokenized with NLTK and stored as `escaped_words`; a
        list or tuple is taken as `words`; an InputSentence is returned
        unchanged.

        Raises:
            ValueError: If `sent` is a str but no tokenizer is configured, or
                if `sent` is of an unsupported type.
        """
        if isinstance(sent, str):
            if self._tokenizer_lang is None:
                raise ValueError(
                    "No word tokenizer available for this language. "
                    "Please tokenize before calling the parser."
                )
            escaped_words = nltk.word_tokenize(sent, self._tokenizer_lang)
            return InputSentence(escaped_words=escaped_words)
        if isinstance(sent, (list, tuple)):
            return InputSentence(words=sent)
        if not isinstance(sent, InputSentence):
            raise ValueError(
                "Sentences must be one of: InputSentence, list, tuple, or str"
            )
        return sent

    def _with_missing_fields_filled(self, sent):
        """Return a copy of `sent` with `words`, `escaped_words` and
        `space_after` all populated and mutually consistent.

        Derivation rules:
          * missing `words` are obtained by PTB-unescaping `escaped_words`;
          * missing `escaped_words` are obtained from `words` by escaping
            brackets (``(`` -> ``-LRB-`` and so on);
          * missing `space_after` is all-False for "zh", all-True for "ar"
            and "he", and otherwise guessed from `words`.

        Args:
            sent (InputSentence): Sentence to complete. It is not mutated;
                `dataclasses.replace` is used to build updated copies.

        Returns:
            InputSentence with the three fields above set.

        Raises:
            ValueError: If `sent` is not an InputSentence, if both `words`
                and `escaped_words` are None, or if `escaped_words`,
                `space_after` or `tags` differ in length from `words`.
        """
        if not isinstance(sent, InputSentence):
            raise ValueError("Input is not an instance of InputSentence")

        if sent.words is None and sent.escaped_words is None:
            raise ValueError("At least one of words or escaped_words is required")
        elif sent.words is None:
            sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words))
        elif sent.escaped_words is None:
            escaped_words = [
                word.translate(self._BRACKET_ESCAPES) for word in sent.words
            ]
            sent = dataclasses.replace(sent, escaped_words=escaped_words)
        else:
            if len(sent.words) != len(sent.escaped_words):
                raise ValueError(
                    f"Length of words ({len(sent.words)}) does not match "
                    f"escaped_words ({len(sent.escaped_words)})"
                )

        if sent.space_after is None:
            if self._language_code == "zh":
                space_after = [False for _ in sent.words]
            elif self._language_code in ("ar", "he"):
                space_after = [True for _ in sent.words]
            else:
                space_after = guess_space_after(sent.words)
            sent = dataclasses.replace(sent, space_after=space_after)
        elif len(sent.words) != len(sent.space_after):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"space_after ({len(sent.space_after)})"
            )

        # Tags are zipped with the leaves when the output tree is built; a
        # length mismatch would silently drop or misalign leaves, so reject
        # it here alongside the other length checks.
        if sent.tags is not None and len(sent.words) != len(sent.tags):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"tags ({len(sent.tags)})"
            )

        # Internal invariant (not input validation): guaranteed by the
        # branches above.
        assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after)
        return sent