File size: 5,575 Bytes
92e0882 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional
import numpy as np
from spacy.tokens.token import Token
from spacy.tokens.span import Span
from lattice import Product as L
from heuristics import Heuristics
# A relation: the tokens gathered on the way to another entity, paired with
# that entity (see the `(tokens, subhead)` tuples built in `_get_rel_sups`).
Rel = Tuple[List[Token], "Entity"]
# A superlative: the list of tokens returned by `find_superlatives`.
Sup = List[Token]
# Heuristics used by `Entity.extract` when the caller does not supply any.
DEFAULT_HEURISTICS = Heuristics()
def find_superlatives(tokens, heuristics) -> List[Sup]:
    """Return `[tokens]` if any token is a superlative keyword, else `[]`.

    A token matches when its `text` is contained in the `keywords` of one of
    `heuristics.superlatives`. On a match, `tokens` is sorted IN PLACE by
    token index (`.i`) and that very list object is the single element of the
    result. Without a match, `tokens` is left untouched.
    """
    matched = any(
        word.text in rule.keywords
        for rule in heuristics.superlatives
        for word in tokens
    )
    if not matched:
        return []
    tokens.sort(key=lambda word: word.i)
    return [tokens]
def expand_chunks(doc, chunks):
    """Widen every chunk over neighbouring tokens that descend from it.

    Args:
        doc: indexable token sequence (`doc[i]`, `len(doc)`) whose tokens
            provide `is_ancestor`.
        chunks: mapping from key to an object with `start` / `end` bounds.

    Returns:
        A new dict with the same keys as `chunks`, each mapped to a `Span`
        over `doc` built from the (possibly widened) `start` and `end`.
    """
    expanded = {}
    for key in chunks:
        chunk = chunks[key]
        start = chunk.start
        end = chunk.end
        # Scan leftwards from the chunk. There is no early exit, so `start`
        # ends up at the leftmost index that passes both tests below.
        for i in range(chunk.start-1, -1, -1):
            # Token i must have an ancestor among the tokens of this chunk ...
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                # ... and must not be an ancestor of any token of another chunk.
                if not any(any(doc[i].is_ancestor(doc[j]) for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    start = i
        # Scan rightwards from the chunk.
        for i in range(chunk.end, len(doc)):
            if any(doc[j].is_ancestor(doc[i]) for j in range(chunk.start, chunk.end)):
                # Same test as on the left, except that a token lying inside
                # another chunk (`i == j`) is rejected as well.
                if not any(any(doc[i].is_ancestor(doc[j]) or i == j for j in range(chunks[key2].start, chunks[key2].end)) for key2 in chunks if key != key2):
                    end = i+1
            # NOTE(review): the indentation of this `else` was reconstructed;
            # it is attached to the descendant test here, i.e. the rightward
            # scan stops at the first token that does not descend from the
            # chunk. Confirm against the upstream file.
            else:
                break
        expanded[key] = Span(doc=doc, start=start, end=end)
    return expanded
class Entity(NamedTuple):
    """Represents an entity with locative constraints extracted from the parse.

    Fields:
        head: the chunk that names the entity (`chunks[head.i]` in `extract`).
        relations: `(tokens, entity)` pairs leading to other entities; the
            entity part may be `None` when `extract` found nothing.
        superlatives: token lists produced by `find_superlatives`.

    Equality (`==` and `!=`) is based on `text`, `relations` and
    `superlatives`, not on the identity of the `head` object.
    """

    head: Span
    relations: List[Rel]
    superlatives: List[Sup]

    @classmethod
    def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> Optional["Entity"]:
        """Extract entities from a spacy parse.

        Jointly recursive with `_get_rel_sups`.

        Args:
            head: token to start from.
            chunks: mapping keyed by token index (`.i`) giving the chunk for
                that token.
            heuristics: keyword heuristics; `DEFAULT_HEURISTICS` when `None`.

        Returns:
            The extracted entity, or `None` when neither `head` nor its first
            child is a key of `chunks`.
        """
        if heuristics is None:
            heuristics = DEFAULT_HEURISTICS

        if head.i not in chunks:
            # Handles predicative cases.
            children = list(head.children)
            if children and children[0].i in chunks:
                head = children[0]
                # TODO: Also extract predicative relations.
            else:
                return None
        hchunk = chunks[head.i]
        rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
        return cls(hchunk, rels, sups)

    @classmethod
    def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
        """Collect relations and superlatives from the subtree rooted at `token`.

        Args:
            token: node currently being visited.
            head: head token of the entity being built.
            tokens: tokens gathered on the way down from `head`; this list is
                not modified.
            chunks: mapping keyed by token index (`.i`).
            heuristics: provides `relations` and `superlatives` (objects with
                `keywords`) plus `null_keywords`.

        Returns:
            A `(relations, superlatives)` pair for this subtree.
        """
        hchunk = chunks[head.i]
        is_keyword = any(token.text in h.keywords for h in heuristics.relations)
        is_keyword |= token.text in heuristics.null_keywords

        # Found another entity head.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            # Sort a copy: `tokens` may be the very list the caller keeps
            # passing to sibling subtrees.
            path = sorted(tokens, key=lambda tok: tok.i)
            subhead = cls.extract(token, chunks, heuristics)
            return [(path, subhead)], []

        # End of a chain of modifiers.
        n_children = len(list(token.children))
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
        for child in token.children:
            if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
                if not any(child.text in h.keywords for h in heuristics.superlatives):
                    if n_children == 1:
                        # Catches "the goat on the left"
                        sups = find_superlatives(tokens + [token], heuristics)
                        superlatives.extend(sups)
                    # NOTE(review): placement of this `continue` (skip every
                    # same-chunk child that is not a superlative keyword) was
                    # reconstructed; confirm against the upstream file.
                    continue
            # Extend the path only with tokens outside any chunk or keywords.
            new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Optional[Span] = None) -> Text:
        """Return the text of the head chunk widened over the dependency tree.

        With `span=None` every token reachable from the head through
        `children` is included. Otherwise, for each token of `span` that is
        reachable from the head, the chain of `.head` links from that token
        up to the root and the token's own subtree are included.

        The selected tokens are de-duplicated, ordered by index (`.i`) and
        their `text` joined with single spaces.
        """
        tokens = list(self.head)
        if span is None:
            span = [None]
        for target_token in span:
            include = False
            stack = list(self.head)
            while stack:
                token = stack.pop()
                if token == target_token:
                    # Add every ancestor of the target, up to and including
                    # the root (the token that is its own head).
                    token2 = target_token.head
                    while token2.head != token2:
                        tokens.append(token2)
                        token2 = token2.head
                    tokens.append(token2)
                    # Drop pending nodes: from here on only the target's
                    # subtree is walked.
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                # NOTE(review): children are pushed for every visited token
                # (needed to reach a target below the head); this nesting was
                # reconstructed, confirm against the upstream file.
                stack.extend(token.children)
        ordered = sorted(set(tokens), key=lambda x: x.i)
        return ' '.join([token.text for token in ordered])

    def __eq__(self, other: object) -> bool:
        """Compare by `text`, `relations` and `superlatives`.

        Returns `NotImplemented` for non-`Entity` operands (for example the
        `None` that `extract` may place in a relation), so such comparisons
        evaluate to `False` instead of raising `AttributeError`.
        """
        if not isinstance(other, Entity):
            return NotImplemented
        if self.text != other.text:
            return False
        if self.relations != other.relations:
            return False
        if self.superlatives != other.superlatives:
            return False
        return True

    def __ne__(self, other: object) -> bool:
        """Negation of `__eq__`.

        Needed because the inherited `tuple.__ne__` compares the raw fields
        and could therefore disagree with the text-based `__eq__`.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text
|