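# NOTE(review): the original first three lines ("Spaces:", "Sleeping",
# "Sleeping") look like web-page header residue captured along with this
# module rather than Python source; kept here as a comment so the file parses.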
import dataclasses | |
import itertools | |
from typing import List, Optional, Tuple | |
import nltk | |
import torch | |
from .downloader import load_trained_model | |
from ..parse_base import BaseParser, BaseInputExample | |
from ..ptb_unescape import ptb_unescape, guess_space_after | |
# Language code -> language name passed to NLTK's `sent_tokenize` and
# `word_tokenize`. Parser looks codes up here with `.get()`; a code missing from
# this table means raw-string input cannot be tokenized for that model.
TOKENIZER_LOOKUP = {
    "en": "english",
    "de": "german",
    "fr": "french",
    "pl": "polish",
    "sv": "swedish",
}

# Language code -> labels that must ALL be present in a model's label
# vocabulary for `guess_language` to report that code. Entries are tried in
# insertion order and the first full match wins, so the order matters whenever
# a vocabulary satisfies more than one entry (several signatures share labels,
# e.g. "XP" or "WHNP").
LANGUAGE_GUESS = {
    "ar": ("X", "XP", "WHADVP", "WHNP", "WHPP"),
    "zh": ("VSB", "VRD", "VPT", "VNV"),
    "en": ("WHNP", "WHADJP", "SINV", "SQ"),
    "de": ("AA", "AP", "CCP", "CH", "CNP", "VZ"),
    "fr": ("P+", "P+D+", "PRO+", "PROREL+"),
    "he": ("PREDP", "SYN_REL", "SYN_yyDOT"),
    "pl": ("formaczas", "znakkonca"),
    "sv": ("PSEUDO", "AVP", "XP"),
}
def guess_language(label_vocab):
    """Infer a parser's language code from its syntactic label inventory.

    Training scripts accept arbitrary treebanks with little language-specific
    logic, but at inference time other pipeline pieces (such as tokenizers)
    need to know which language the model handles. Each entry of
    ``LANGUAGE_GUESS`` names labels that must *all* be found in
    ``label_vocab``; entries are checked in dictionary order and the first one
    that matches is returned.

    Args:
        label_vocab: Any container supporting ``in`` for label strings
            (presumably the model's label-to-index mapping; confirm against
            the caller).

    Returns:
        The matching language code (e.g. ``"en"``), or None if nothing matches.
    """
    candidates = (
        code
        for code, signature in LANGUAGE_GUESS.items()
        if all(lbl in label_vocab for lbl in signature)
    )
    return next(candidates, None)
# The dataclass decorator generates the keyword-argument constructor used as
# `InputSentence(words=...)` and is required for the `dataclasses.replace`
# calls made on instances of this class.
@dataclasses.dataclass
class InputSentence(BaseInputExample):
    """Parser input for a single sentence.

    At least one of `words` and `escaped_words` is required for each input
    sentence. The remaining fields are optional: the parser will attempt to
    derive the value for any missing fields using the fields that are provided.

    `words` and `space_after` together form a reversible tokenization of the
    input text: they represent, respectively, the Unicode text for each word and
    an indicator for whether the word is followed by whitespace. These are used
    as inputs by the parser.

    `tags` is a list of part-of-speech tags, if available prior to running the
    parser. The parser does not actually use these tags as input, but it will
    pass them through to its output. If `tags` is None, the parser will perform
    its own part of speech tagging (if the parser was not trained to also do
    tagging, "UNK" part-of-speech tags will be used in the output instead).

    `escaped_words` are the representations of each leaf to use in the output
    tree. If `words` is provided, `escaped_words` will not be used by the neural
    network portion of the parser, and will only be incorporated when
    constructing the output tree. Therefore, `escaped_words` may be used to
    accommodate any dataset-specific text encoding, such as transliteration.

    Here is an example of the differences between these fields for English PTB:

        (raw text):     "Fly safely."
        words:          "       Fly     safely  .       "
        space_after:    False   True    False   False   False
        tags:           ``      VB      RB      .       ''
        escaped_words:  ``      Fly     safely  .       ''

    `tree` is always None: an InputSentence never carries a parse tree.
    """

    words: Optional[List[str]] = None
    space_after: Optional[List[bool]] = None
    tags: Optional[List[str]] = None
    escaped_words: Optional[List[str]] = None

    # Exposed as a property so that `example.tree` evaluates to None; as a
    # plain method, `example.tree` would be a (truthy) bound method object.
    # NOTE(review): presumably the parser reads `.tree` as an attribute, as for
    # other BaseInputExample subclasses; confirm against parse_base.
    @property
    def tree(self):
        """Always None: this input has no parse tree attached."""
        return None

    def leaves(self):
        """Return the leaf strings to place in the output tree (`escaped_words`)."""
        return self.escaped_words

    def pos(self):
        """Return a list of `(escaped_word, tag)` pairs for the output tree.

        Uses the provided `tags` when present; otherwise every word is paired
        with the placeholder tag "UNK".
        """
        if self.tags is not None:
            return list(zip(self.escaped_words, self.tags))
        else:
            return [(word, "UNK") for word in self.escaped_words]
class Parser:
    """Berkeley Neural Parser (benepar), integrated with NLTK.

    Use this class to apply the Berkeley Neural Parser to pre-tokenized datasets
    and treebanks, or when integrating the parser into an NLP pipeline that
    already performs tokenization, sentence splitting, and (optionally)
    part-of-speech tagging. For parsing starting with raw text, it is strongly
    encouraged that you use spaCy and benepar.BeneparComponent instead.

    Sample usage:
    >>> parser = benepar.Parser("benepar_en3")
    >>> input_sentence = benepar.InputSentence(
        words=['"', 'Fly', 'safely', '.', '"'],
        space_after=[False, True, False, False, False],
        tags=['``', 'VB', 'RB', '.', "''"],
        escaped_words=['``', 'Fly', 'safely', '.', "''"],
    )
    >>> parser.parse(input_sentence)

    Not all fields of benepar.InputSentence are required, but at least one of
    `words` and `escaped_words` must not be None. The parser will attempt to
    guess the value for missing fields. For example,
    >>> input_sentence = benepar.InputSentence(
        words=['"', 'Fly', 'safely', '.', '"'],
    )
    >>> parser.parse(input_sentence)

    Although this class is primarily designed for use with data that has already
    been tokenized, to help with interactive use and debugging it also accepts
    simple text string inputs. However, using this class to parse from raw text
    is STRONGLY DISCOURAGED for any application where parsing accuracy matters.
    When parsing from raw text, use spaCy and benepar.BeneparComponent instead.
    The reason is that parser models do not ship with a tokenizer or sentence
    splitter, and some models may not include a part-of-speech tagger either. A
    toolkit must be used to fill in these pipeline components, and spaCy
    outperforms NLTK in all of these areas (sometimes by a large margin).
    >>> parser.parse('"Fly safely."')  # For debugging/interactive use only.
    """

    # Translation table that replaces brackets in `words` with the PTB escape
    # tokens used for `escaped_words` when only `words` is supplied. A single
    # translate pass is equivalent to chained replaces because no replacement
    # text itself contains a bracket.
    _PTB_ESCAPE_TABLE = str.maketrans(
        {
            "(": "-LRB-",
            ")": "-RRB-",
            "{": "-LCB-",
            "}": "-RCB-",
            "[": "-LSB-",
            "]": "-RSB-",
        }
    )

    def __init__(self, name, batch_size=64, language_code=None):
        """Load a trained parser model.

        The model is moved to the GPU when `torch.cuda.is_available()`.

        Args:
            name (str): Model name, or path to pytorch saved model
            batch_size (int): Maximum number of sentences to process per batch;
                must be positive when parsing.
            language_code (str, optional): language code for the parser (e.g.
                'en', 'he', 'zh', etc). Our official trained models will set
                this automatically, so this argument is only needed if training
                on new languages or treebanks.
        """
        self._parser = load_trained_model(name)
        if torch.cuda.is_available():
            self._parser.cuda()
        if language_code is not None:
            self._language_code = language_code
        else:
            # May be None when the label inventory matches no known language.
            self._language_code = guess_language(self._parser.config["label_vocab"])
        # NLTK tokenizer language name, or None when raw-text input cannot be
        # tokenized for this language.
        self._tokenizer_lang = TOKENIZER_LOOKUP.get(self._language_code, None)
        self.batch_size = batch_size

    def parse(self, sentence):
        """Parse a single sentence

        Args:
            sentence (InputSentence or List[str] or str): Sentence to parse.
                If the input is of List[str], it is assumed to be a sequence of
                words and will behave the same as only setting the `words` field
                of InputSentence. If the input is of type str, the sentence will
                be tokenized using the default NLTK tokenizer (not recommended:
                if parsing from raw text, use spaCy and benepar.BeneparComponent
                instead).

        Returns:
            nltk.Tree
        """
        return list(self.parse_sents([sentence]))[0]

    def parse_sents(self, sents):
        """Parse multiple sentences in batches.

        Args:
            sents (Iterable[InputSentence]): An iterable of sentences to be
                parsed. `sents` may also be a string, in which case it will be
                segmented into sentences using the default NLTK sentence
                splitter (not recommended: if parsing from raw text, use spaCy
                and benepar.BeneparComponent instead). Otherwise, each element
                of `sents` will be treated as a sentence. The elements of
                `sents` may also be List[str] or str: see Parser.parse() for
                documentation regarding these cases.

        Yields:
            nltk.Tree objects, one per input sentence, in input order.

        Raises:
            ValueError: if `batch_size` is not positive; if raw text must be
                tokenized but no NLTK tokenizer is known for this language; if
                an element has an unsupported type; or if an InputSentence's
                fields are missing or have inconsistent lengths.
        """
        # zip_longest() over zero iterators yields nothing, so without this
        # check a non-positive batch size would silently produce no parses.
        if self.batch_size < 1:
            raise ValueError(
                f"batch_size must be a positive integer, got {self.batch_size!r}"
            )
        if isinstance(sents, str):
            if self._tokenizer_lang is None:
                raise ValueError(
                    "No tokenizer available for this language. "
                    "Please split into individual sentences and tokens "
                    "before calling the parser."
                )
            sents = nltk.sent_tokenize(sents, self._tokenizer_lang)

        # Zipping one shared iterator with itself groups `sents` into tuples of
        # `batch_size`; the last, shorter tuple is padded with `end_sentinel`.
        end_sentinel = object()
        for batch_sents in itertools.zip_longest(
            *([iter(sents)] * self.batch_size), fillvalue=end_sentinel
        ):
            batch_inputs = [
                self._with_missing_fields_filled(self._as_input_sentence(sent))
                for sent in batch_sents
                if sent is not end_sentinel
            ]
            outputs = self._parser.parse(batch_inputs, return_compressed=True)
            for inp, output in zip(batch_inputs, outputs):
                # If pos tags are provided as input, ignore any tags predicted
                # by the parser.
                if inp.tags is not None:
                    output = output.without_predicted_tags()
                yield output.to_tree(
                    inp.pos(),
                    self._parser.decoder.label_from_index,
                    self._parser.tag_from_index,
                )

    def _as_input_sentence(self, sent):
        """Coerce one element of `sents` (str, list/tuple of words, or
        InputSentence) to an InputSentence.

        Raises:
            ValueError: for a str when no word tokenizer is available, or for
                any unsupported element type.
        """
        if isinstance(sent, InputSentence):
            return sent
        if isinstance(sent, str):
            if self._tokenizer_lang is None:
                raise ValueError(
                    "No word tokenizer available for this language. "
                    "Please tokenize before calling the parser."
                )
            # NLTK word tokens are treated as the escaped (output) form.
            return InputSentence(
                escaped_words=nltk.word_tokenize(sent, self._tokenizer_lang)
            )
        if isinstance(sent, (list, tuple)):
            # Copy into a list to match the `List[str]` field type.
            return InputSentence(words=list(sent))
        raise ValueError(
            "Sentences must be one of: InputSentence, list, tuple, or str"
        )

    def _with_missing_fields_filled(self, sent):
        """Return `sent` (or an updated copy) with `words`, `escaped_words`
        and `space_after` all populated and length-checked.

        Raises:
            ValueError: if `sent` is not an InputSentence, if both `words` and
                `escaped_words` are None, or if the per-word fields (`words`,
                `escaped_words`, `space_after`, `tags`) differ in length.
        """
        if not isinstance(sent, InputSentence):
            raise ValueError("Input is not an instance of InputSentence")
        if sent.words is None and sent.escaped_words is None:
            raise ValueError("At least one of words or escaped_words is required")
        elif sent.words is None:
            sent = dataclasses.replace(sent, words=ptb_unescape(sent.escaped_words))
        elif sent.escaped_words is None:
            # Escape brackets as PTB tokens (-LRB-, -RRB-, -LCB-, ...).
            escaped_words = [
                word.translate(self._PTB_ESCAPE_TABLE) for word in sent.words
            ]
            sent = dataclasses.replace(sent, escaped_words=escaped_words)
        else:
            if len(sent.words) != len(sent.escaped_words):
                raise ValueError(
                    f"Length of words ({len(sent.words)}) does not match "
                    f"escaped_words ({len(sent.escaped_words)})"
                )

        if sent.space_after is None:
            # "zh": no whitespace between words; "ar"/"he": whitespace after
            # every word; any other language falls back to guess_space_after.
            if self._language_code == "zh":
                space_after = [False for _ in sent.words]
            elif self._language_code in ("ar", "he"):
                space_after = [True for _ in sent.words]
            else:
                space_after = guess_space_after(sent.words)
            sent = dataclasses.replace(sent, space_after=space_after)
        elif len(sent.words) != len(sent.space_after):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"space_after ({len(sent.space_after)})"
            )

        # Without this check, InputSentence.pos() would silently truncate to
        # the shorter of escaped_words/tags via zip().
        if sent.tags is not None and len(sent.tags) != len(sent.words):
            raise ValueError(
                f"Length of words ({len(sent.words)}) does not match "
                f"tags ({len(sent.tags)})"
            )

        assert len(sent.words) == len(sent.escaped_words) == len(sent.space_after)
        return sent