|
from typing import Dict, Any, Callable, List, Tuple, NamedTuple, Text, Optional |
|
import numpy as np |
|
from spacy.tokens.token import Token |
|
from spacy.tokens.span import Span |
|
|
|
from lattice import Product as L |
|
|
|
from heuristics import Heuristics |
|
|
|
# A superlative: a list of tokens, as returned (wrapped in an outer list) by
# `find_superlatives`.
Sup = List[Token]

# A relation: a list of tokens paired with an entity, as built by
# `Entity._get_rel_sups`.
Rel = Tuple[List[Token], "Entity"]

# Heuristics instance `Entity.extract` falls back to when none is passed in.
DEFAULT_HEURISTICS = Heuristics()
|
|
|
|
|
def find_superlatives(tokens, heuristics) -> List[Sup]:
    """Return ``[tokens]`` if any token text is a superlative keyword, else ``[]``.

    The entries of ``heuristics.superlatives`` are tried in order. As soon as
    the text of some token is found in an entry's ``keywords``, ``tokens`` is
    sorted *in place* by token index ``i`` and that same list object is
    returned as the only element of the result. When nothing matches,
    ``tokens`` is left untouched and an empty list is returned.
    """
    for rule in heuristics.superlatives:
        for candidate in tokens:
            if candidate.text in rule.keywords:
                # First hit wins; later heuristics are not consulted.
                tokens.sort(key=lambda item: item.i)
                return [tokens]
    return []
|
|
|
def expand_chunks(doc, chunks):
    """Grow every chunk over neighbouring tokens that descend from it.

    Args:
        doc: indexable, sized sequence of tokens exposing ``is_ancestor``.
        chunks: mapping whose values expose ``start`` and ``end`` offsets
            into ``doc``.

    Returns:
        A dict with the same keys as ``chunks``; each value is a new ``Span``
        over ``doc`` with the (possibly widened) boundaries.

    Entries stored under a different key are treated as "other" chunks,
    whether or not they hold the same span object.
    """

    def descends_from(span, idx):
        # True when some token inside `span` is an ancestor of doc[idx].
        return any(doc[j].is_ancestor(doc[idx]) for j in range(span.start, span.end))

    def reaches_other_chunk(own_key, idx, count_membership):
        # True when doc[idx] is an ancestor of a token stored under another
        # key, or (only if `count_membership`) sits inside such a chunk.
        for other_key in chunks:
            if own_key != other_key:
                other = chunks[other_key]
                for j in range(other.start, other.end):
                    if doc[idx].is_ancestor(doc[j]) or (count_membership and idx == j):
                        return True
        return False

    expanded = {}
    for key in chunks:
        chunk = chunks[key]

        # Leftwards: the whole prefix is scanned, so the boundary ends up at
        # the smallest qualifying index (the scan never stops early).
        left = chunk.start
        for idx in reversed(range(chunk.start)):
            if descends_from(chunk, idx) and not reaches_other_chunk(key, idx, False):
                left = idx

        # Rightwards: the scan stops at the first token that is not a
        # descendant of the chunk.
        # NOTE(review): the reviewed copy had lost its indentation; the
        # original `else: break` is taken to pair with the descendant test
        # (not with the other-chunk test) -- confirm against version control.
        right = chunk.end
        for idx in range(chunk.end, len(doc)):
            if not descends_from(chunk, idx):
                break
            if not reaches_other_chunk(key, idx, True):
                right = idx + 1

        expanded[key] = Span(doc=doc, start=left, end=right)
    return expanded
|
|
|
class Entity(NamedTuple):
    """Represents an entity with locative constraints extracted from the parse.

    Fields:
        head: the chunk `extract` looked up for the entity's head token.
        relations: ``(tokens, entity)`` pairs. `_get_rel_sups` emits one each
            time it reaches a non-keyword token of a different chunk, pairing
            the tokens gathered on the way with whatever `extract` returns
            there (possibly ``None``).
        superlatives: token lists produced by `find_superlatives`.

    Equality (``==`` and ``!=``) is defined on `text`, `relations` and
    `superlatives` rather than on the raw tuple fields.
    """

    head: Span
    relations: List[Rel]
    superlatives: List[Sup]

    @classmethod
    def extract(cls, head, chunks, heuristics: Optional[Heuristics] = None) -> Optional["Entity"]:
        """Extract entities from a spacy parse.

        Jointly recursive with `_get_rel_sups`.

        Args:
            head: token to root the entity at; its index ``i`` is looked up
                in ``chunks``.
            chunks: mapping keyed by token index whose values are chunks.
            heuristics: keyword heuristics; ``DEFAULT_HEURISTICS`` is used
                when ``None``.

        Returns:
            The extracted entity, or ``None`` when neither ``head`` nor its
            first child is a key of ``chunks``.
        """
        if heuristics is None:
            heuristics = DEFAULT_HEURISTICS

        if head.i not in chunks:
            # The token has no chunk of its own: retry with its first child.
            children = list(head.children)
            if not children or children[0].i not in chunks:
                return None
            head = children[0]

        rels, sups = cls._get_rel_sups(head, head, [], chunks, heuristics)
        return cls(chunks[head.i], rels, sups)

    @classmethod
    def _get_rel_sups(cls, token, head, tokens, chunks, heuristics) -> Tuple[List[Rel], List[Sup]]:
        """Walk the parse below ``token`` collecting relations and superlatives.

        Args:
            token: token currently being visited.
            head: head token of the entity under construction; passed
                unchanged through the recursion.
            tokens: tokens gathered on the way down. A visited token is added
                only if it is outside every chunk or is a keyword. The list
                is never mutated here.
            chunks: mapping keyed by token index whose values are chunks.
            heuristics: object providing ``relations`` and ``superlatives``
                (iterables of objects with ``keywords``) and ``null_keywords``.

        Returns:
            ``(relations, superlatives)`` found at or below ``token``.
        """
        hchunk = chunks[head.i]
        is_keyword = any(token.text in h.keywords for h in heuristics.relations)
        is_keyword |= token.text in heuristics.null_keywords

        # A non-keyword token that belongs to a different chunk starts another
        # entity; the tokens gathered so far describe how it relates to `head`.
        if token.i in chunks and chunks[token.i] is not hchunk and not is_keyword:
            subhead = cls.extract(token, chunks, heuristics)
            # sorted() instead of list.sort(): `tokens` may be the very list
            # object the caller still hands to sibling branches.
            return [(sorted(tokens, key=lambda tok: tok.i), subhead)], []

        # A token without children ends the chain; check it for superlatives.
        n_children = len(list(token.children))
        if n_children == 0:
            return [], find_superlatives(tokens + [token], heuristics)

        relations = []
        superlatives = []
        is_keyword |= any(token.text in h.keywords for h in heuristics.superlatives)
        # Identical for every child, so computed once.
        new_tokens = tokens + [token] if token.i not in chunks or is_keyword else tokens
        for child in token.children:
            if token.i in chunks and child.i in chunks and chunks[token.i] is chunks[child.i]:
                if not any(child.text in h.keywords for h in heuristics.superlatives):
                    # Same-chunk child that is not a superlative keyword: do
                    # not descend into it. If it is the only child, the chain
                    # ends here, so check it for superlatives.
                    # NOTE(review): the reviewed copy had lost its
                    # indentation; `continue` is assumed to apply whether or
                    # not n_children == 1 -- confirm against version control.
                    if n_children == 1:
                        sups = find_superlatives(tokens + [token], heuristics)
                        superlatives.extend(sups)
                    continue
            subrel, subsup = cls._get_rel_sups(child, head, new_tokens, chunks, heuristics)
            relations.extend(subrel)
            superlatives.extend(subsup)
        return relations, superlatives

    def expand(self, span: Optional[Span] = None) -> Text:
        """Render the entity as text, optionally extended towards ``span``.

        Without ``span`` the result covers the head tokens plus every token
        reachable from them through ``children``. With ``span``, each of its
        tokens is searched for depth-first starting from the head tokens; when
        found, the chain of ``head`` links above it (up to the token that is
        its own head) and everything below it are added to the head tokens.

        Returns:
            The texts of the selected tokens, de-duplicated, ordered by token
            index ``i`` and joined with single spaces.
        """
        tokens = list(self.head)
        targets = [None] if span is None else span
        for target_token in targets:
            include = False
            stack = list(self.head)
            while stack:
                token = stack.pop()
                if token == target_token:
                    # Add every ancestor of the target, ending with the token
                    # that is its own head.
                    token2 = target_token.head
                    while token2.head != token2:
                        tokens.append(token2)
                        token2 = token2.head
                    tokens.append(token2)
                    # Drop the rest of the search; from here on only the
                    # target's subtree is traversed and all of it is kept.
                    stack = []
                    include = True
                if target_token is None or include:
                    tokens.append(token)
                # NOTE(review): the reviewed copy had lost its indentation;
                # children are assumed to be pushed for every popped token
                # (otherwise a target outside the head could never be
                # reached) -- confirm against version control.
                stack.extend(token.children)
        ordered = sorted(set(tokens), key=lambda x: x.i)
        return ' '.join(token.text for token in ordered)

    def __eq__(self, other: object) -> bool:
        """Compare by `text`, `relations` and `superlatives`.

        Returns ``NotImplemented`` for operands that are not `Entity`
        instances. `extract` may return ``None``, which then appears inside
        ``relations``; comparing such entities must yield ``False`` rather
        than fail on ``None.text``.
        """
        if not isinstance(other, Entity):
            return NotImplemented
        if self.text != other.text:
            return False
        if self.relations != other.relations:
            return False
        if self.superlatives != other.superlatives:
            return False
        return True

    def __ne__(self, other: object) -> bool:
        """Inverse of `__eq__`.

        ``tuple`` defines its own ``__ne__``; without this override ``!=``
        would compare the raw fields (``head`` itself rather than its text)
        and could disagree with ``==``.
        """
        result = self.__eq__(other)
        if result is NotImplemented:
            return NotImplemented
        return not result

    @property
    def text(self) -> Text:
        """Get the text predicate associated with this entity."""
        return self.head.text
|
|